# Batch Optimization for All Temperatures (Parametric Curve Fitting)

**Runs the parametric curve-fitting optimization pipeline for all temperature datasets.**

## Curve-Fitting Approach

We fit the parametric curve V(L) where both V and L are computed from parameters a, b:
- L(z) = (2/π) ∫ sqrt(g(z'))/sqrt((z⁴f(z))/(z'⁴f(z')) - 1) dz'
- V(z) = 2π ∫ [1/z'² (sqrt(fg)/sqrt(1 - z'⁴f(z)/(z⁴f(z'))) - 1)] dz' - 2π/z

## Loss Function

```
Loss = DataFitError + λ·NECPenalty
```

where:
- DataFitError = Mean[(V_model - V_data)²]
- NECPenalty = Max[0, Max_z[Σ(n+1)(aₙ+bₙ)zⁿ]]²
- λ = 100 (tunable weight)

**Outputs for each temperature:**
- `results/T{temp}_params.txt` - Optimized parameters
- `results/T{temp}_fit.png` - Plot of data vs optimized model
- `results/T{temp}_convergence.png` - Convergence plots for a, b, objective
- `results/all_temperatures_summary.txt` - Summary table of all results

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import differential_evolution, Bounds
from sklearn.linear_model import LinearRegression
from dataset_HR import AdSBHDataset
from model_HR_new import AdSBHNet
from constants import dreal, dcomplex
import glob
import os
import re
from pathlib import Path

## 1. Setup Output Directory

In [2]:
# Create results directory
results_dir = Path('results')
results_dir.mkdir(exist_ok=True)
print(f"Results will be saved to: {results_dir.absolute()}")

Results will be saved to: /Users/helitakko/Dropbox/Own/deep-learning-EE/env-mac/complex_wilson/results


## 2. Find All Temperature Files

In [3]:
# Find all lattice data files (only 1607 series)
files = sorted(glob.glob('1607latticeT*.txt'))

# Extract temperature from filename
temp_files = []
for f in files:
    match = re.search(r'T(\d+)', f)
    if match:
        temp = int(match.group(1))
        temp_files.append((temp, f))

temp_files.sort()

print(f"Found {len(temp_files)} temperature datasets (1607 series only):")
for temp, f in temp_files:
    print(f"  T = {temp:3d} MeV: {f}")

Found 9 temperature datasets (1607 series only):
  T = 113 MeV: 1607latticeT113.txt
  T = 226 MeV: 1607latticeT226.txt
  T = 254 MeV: 1607latticeT254.txt
  T = 271 MeV: 1607latticeT271.txt
  T = 290 MeV: 1607latticeT290.txt
  T = 312 MeV: 1607latticeT312.txt
  T = 338 MeV: 1607latticeT338.txt
  T = 369 MeV: 1607latticeT369.txt
  T = 406 MeV: 1607latticeT406.txt


## 3. Define Optimization Pipeline (Same as Single-Temperature Notebook)

In [4]:
N_COEFFS = 4

@torch.no_grad()
def connected_branch(model, device, dt, zmin=0.02, zmax=0.9995, Nc=3000):
    zs = torch.linspace(zmin, zmax, Nc, dtype=dt, device=device)
    Lc = model.integrate_L(zs).real
    Vc = model.integrate_V(zs).real
    idx = torch.argsort(Lc)
    Lc, Vc = Lc[idx], Vc[idx]
    imax = torch.argmax(Lc)
    return Lc[:imax+1], Vc[:imax+1]

@torch.no_grad()
def interp_1d(x, y, xq, eps=1e-12):
    if x.numel() < 2:
        return torch.full_like(xq, y[0] if y.numel() > 0 else 0.0)
    xq = torch.as_tensor(xq, dtype=x.dtype, device=x.device).reshape(-1)
    pos = torch.searchsorted(x, xq, right=True)
    i0 = (pos - 1).clamp(0, x.numel() - 2)
    i1 = i0 + 1
    x0, x1, y0, y1 = x[i0], x[i1], y[i0], y[i1]
    w = (xq - x0) / (x1 - x0 + eps)
    v_lin = y0 + w * (y1 - y0)
    mL = (y[1] - y[0]) / (x[1] - x[0] + eps)
    mR = (y[-1] - y[-2]) / (x[-1] - x[-2] + eps)
    v_left = y[0] + mL * (xq - x[0])
    v_right = y[-1] + mR * (xq - x[-1])
    v = torch.where(xq < x[0], v_left, v_lin)
    v = torch.where(xq > x[-1], v_right, v)
    return v

@torch.no_grad()
def pack_params(model):
    a_np = model.a.detach().cpu().numpy().astype(np.float64)
    b_np = model.b.detach().cpu().numpy().astype(np.float64)
    N = len(a_np)
    return np.concatenate([a_np[0:1], a_np[1:], b_np[1:]])

@torch.no_grad()
def unpack_params(model, theta, device, dt):
    theta = np.asarray(theta, dtype=np.float64)
    N = model.a.numel()
    a0 = theta[0]
    a_rest = theta[1:N]
    b_rest = theta[N:2*N-1]
    b0 = -a0
    a_np = np.concatenate([[a0], a_rest])
    b_np = np.concatenate([[b0], b_rest])
    model.a.copy_(torch.tensor(a_np, dtype=dt, device=device))
    model.b.copy_(torch.tensor(b_np, dtype=dt, device=device))

def get_bounds(N):
    lower = np.concatenate([[-0.5], [-0.1]*(N-1), [0.05]*(N-1)])
    upper = np.concatenate([[-0.05], [0.6]*(N-1), [0.3]*(N-1)])
    return Bounds(lower, upper)

@torch.no_grad()
def nec_penalty(a, b, device, dt, num_samples=100):
    """
    Compute NEC (Null Energy Condition) violation penalty.
    
    NEC requires: -(3/2z) * [a'(z) + b'(z)] >= 0 for all z in [0, 1]
    This means: (a0+b0) + 2(a1+b1)z + 3(a2+b2)z^2 + 4(a3+b3)z^3 <= 0
    
    We penalize violations where P(z) > 0.
    """
    z = torch.linspace(0.01, 0.99, num_samples, dtype=dt, device=device)
    
    # Compute polynomial P(z) = sum (n+1)(an+bn)z^n
    P = (a[0] + b[0]) + 2*(a[1] + b[1])*z + 3*(a[2] + b[2])*z**2 + 4*(a[3] + b[3])*z**3
    
    # NEC violation if P > 0
    violation = torch.relu(P).max()
    
    return float(violation.item())

In [None]:
def optimize_params_all_temperatures(filename, temp, verbose=True):
    """
    Run parametric curve-fitting optimization for a single temperature.
    
    Uses the new loss formulation:
        Loss = DataFitError + λ·NECPenalty
    
    Returns:
        dict with keys: a, b, coef, shift, obj_final, Lmax_model, success
    """
    if verbose:
        print(f"\n{'='*70}")
        print(f"OPTIMIZING T = {temp} MeV")
        print(f"File: {filename}")
        print(f"{'='*70}")
    
    # Load data
    dataset = AdSBHDataset(file=filename)
    mask_L = dataset.L < 1.4
    L_all = dataset.L[mask_L]
    V_all = dataset.V[mask_L]
    sigma_all = dataset.sigma[mask_L]
    
    if verbose:
        print(f"  Dataset: {len(L_all)} points")
        print(f"  L range: [{L_all.min():.4f}, {L_all.max():.4f}]")
    
    # Setup model
    model = AdSBHNet(N=N_COEFFS, std=0.1)
    device = model.a.device
    dt = model.a.dtype
    
    L_all = L_all.to(device=device, dtype=dt)
    V_all = V_all.to(device=device)
    sigma_all = sigma_all.to(device=device)
    
    # Initialize model
    model.a.zero_()
    model.b.zero_()
    model.a[0] = -0.25
    model.b[0] = 0.25
    
    if model.a.numel() > 1:
        model.a[1:] = torch.tensor([0.26, 0.31, 0.33], dtype=dt, device=device)
        model.b[1:] = torch.tensor([0.12, 0.12, 0.13], dtype=dt, device=device)
    
    # LS initialization for coef and shift
    Lm, Vm = connected_branch(model, device, dt)
    if Lm.numel() >= 2:
        mask = (L_all >= Lm[0]) & (L_all <= Lm[-1])
        if mask.sum() >= 3:
            L_fit = L_all[mask]
            V_fit = V_all.real[mask]
            X_fit = interp_1d(Lm, Vm, L_fit)
            X_np = X_fit.cpu().numpy().reshape(-1, 1)
            Y_np = V_fit.cpu().numpy()
            reg = LinearRegression().fit(X_np, Y_np)
            coef = max(float(reg.coef_[0]), 1e-6)
            shift = float(reg.intercept_)
            model.logcoef.copy_(torch.tensor(np.log(coef), dtype=dt, device=device))
            model.shift.copy_(torch.tensor(shift, dtype=dt, device=device))
    
    if verbose:
        print(f"  LS init: coef={np.exp(float(model.logcoef.detach())):.6f}, shift={float(model.shift.detach()):.6f}")
        print(f"  Initial swallowtail at L_max = {Lm[-1].item():.4f}")
    
    theta0 = pack_params(model)
    bounds = get_bounds(N_COEFFS)
    
    # Define objective with new loss formulation
    iteration_counter = [0]
    a_history = []
    b_history = []
    obj_history = []
    
    @torch.no_grad()
    def objective(theta, return_components=False):
        """
        Loss = DataFitError + λ·NECPenalty
        
        where:
          DataFitError = Mean[(V_model - V_data)²]
          NECPenalty = Max[0, Max_z[Σ(n+1)(aₙ+bₙ)zⁿ]]²
        """
        unpack_params(model, theta, device, dt)
        iteration_counter[0] += 1
        a_history.append(model.a.detach().cpu().numpy().copy())
        b_history.append(model.b.detach().cpu().numpy().copy())
        
        try:
            Lm, Vm = connected_branch(model, device, dt)
        except:
            obj_history.append(1e10)
            return 1e10 if not return_components else (1e10, {})
        
        if Lm.numel() < 5 or not torch.isfinite(Lm).all() or not torch.isfinite(Vm).all():
            obj_history.append(1e10)
            return 1e10 if not return_components else (1e10, {})
        
        # Interpolate V_model at data points L
        V_model = interp_1d(Lm, Vm, L_all) + model.shift
        residuals = V_model - V_all.real
        
        # Data fit error: Mean[(V_model - V_data)²]
        w_stat = 1.0 / (sigma_all.real**2 + 1e-12)
        L_weight = (L_all / L_all.max()).clamp_min(0).pow(2.0)
        w_total = w_stat * L_weight
        
        data_fit_error = float((w_total * residuals**2).mean().item())
        
        # NEC penalty: Max[0, Max_z[P(z)]]²
        z_samples = torch.linspace(0.01, 0.99, 100, dtype=dt, device=device)
        P = ((model.a[0] + model.b[0]) + 
             2*(model.a[1] + model.b[1])*z_samples + 
             3*(model.a[2] + model.b[2])*z_samples**2 + 
             4*(model.a[3] + model.b[3])*z_samples**3)
        
        nec_violation = torch.relu(P).max()
        nec_penalty_term = float((nec_violation ** 2).item())
        
        # Weight parameter λ for NEC penalty
        lambda_nec = 100.0
        
        # Total loss
        total_loss = data_fit_error + lambda_nec * nec_penalty_term
        
        # Additional soft constraints
        Lmax_model = float(Lm[-1].item())
        swallowtail_penalty = 0.0
        target_L_min, target_L_max = 0.35, 0.60
        
        if Lmax_model < target_L_min:
            swallowtail_penalty = 10.0 * ((target_L_min - Lmax_model) / target_L_min) ** 2
        elif Lmax_model > target_L_max:
            swallowtail_penalty = 10.0 * ((Lmax_model - target_L_max) / target_L_max) ** 2
        
        Lmax_data = float(L_all.max().item())
        coverage_penalty = 0.0
        if Lmax_model < Lmax_data:
            gap = (Lmax_data - Lmax_model) / max(Lmax_data, 1e-6)
            coverage_penalty = 5.0 * (gap**2)
        
        reg_L2 = 1e-5 * float((model.a**2).sum().item() + (model.b**2).sum().item())
        
        # Total objective
        total = total_loss + swallowtail_penalty + coverage_penalty + reg_L2
        obj_history.append(total)
        
        if return_components:
            return total, {
                'data_fit_error': data_fit_error,
                'nec_penalty': nec_penalty_term,
                'nec_violation': float(nec_violation.item()),
                'total_loss': total_loss,
                'swallowtail_penalty': swallowtail_penalty,
                'coverage_penalty': coverage_penalty,
                'reg_L2': reg_L2,
                'Lmax_model': Lmax_model,
                'Lmax_data': Lmax_data,
                'chi2_data': data_fit_error  # For backward compatibility
            }
        return total
    
    # Run optimization
    if verbose:
        print(f"  Running differential evolution...")
        print(f"  Loss = DataFitError + λ·NECPenalty (λ=100)")
    
    result = differential_evolution(
        objective,
        bounds=[(bounds.lb[i], bounds.ub[i]) for i in range(len(bounds.lb))],
        maxiter=50,
        popsize=10,
        strategy='best1bin',
        atol=1e-6,
        tol=1e-6,
        mutation=(0.5, 1.5),
        recombination=0.7,
        seed=42,
        polish=True,
        workers=1,
        disp=False
    )
    
    unpack_params(model, result.x, device, dt)
    obj_final, comp_final = objective(result.x, return_components=True)
    
    if verbose:
        print(f"  Optimization complete!")
        print(f"    Success: {result.success}")
        print(f"    Final loss: {comp_final['total_loss']:.4e}")
        print(f"      DataFitError: {comp_final['data_fit_error']:.4e}")
        print(f"      NECPenalty: {comp_final['nec_penalty']:.4e}")
        print(f"      NEC violation: {comp_final['nec_violation']:.4e}")
        print(f"    Swallowtail at L_max = {comp_final['Lmax_model']:.4f}")
        print(f"    Function evaluations: {result.nfev}")
    
    # Extract final parameters
    a_final = model.a.detach().cpu().numpy()
    b_final = model.b.detach().cpu().numpy()
    coef_final = np.exp(float(model.logcoef.detach()))
    shift_final = float(model.shift.detach())
    
    # Generate plots
    if verbose:
        print(f"  Generating plots...")
    
    # Plot 1: Fit
    fig, ax = plt.subplots(figsize=(8, 6))
    with torch.no_grad():
        Lm_final, Vm_final = connected_branch(model, device, dt, Nc=5000)
        V_branch = Vm_final + model.shift
    
    ax.plot(Lm_final.cpu(), V_branch.cpu(), 'r-', linewidth=2.5,
            label='Optimized', alpha=0.8, zorder=2)
    ax.errorbar(L_all.cpu(), V_all.real.cpu(), yerr=sigma_all.real.cpu(),
                fmt='o', markersize=6, alpha=0.7, label='Data', color='blue', zorder=1)
    ax.set_xlabel(r'$T L$', fontsize=13)
    ax.set_ylabel(r'$V/T$', fontsize=13)
    ax.set_title(f'T = {temp} MeV: Data vs Optimized Model', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(results_dir / f'T{temp}_fit.png', dpi=150, bbox_inches='tight')
    plt.close()
    
    # Plot 2: Convergence
    a_hist_np = np.array(a_history)
    b_hist_np = np.array(b_history)
    iterations = np.arange(1, len(a_history) + 1)
    
    n_a = a_hist_np.shape[1]
    n_b = b_hist_np.shape[1]
    
    warm = plt.cm.OrRd(np.linspace(0.9, 0.45, n_a))
    cold = plt.cm.Blues(np.linspace(0.9, 0.45, n_b))
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Coefficients
    ax = axes[0]
    for i in range(n_a):
        ax.plot(iterations, a_hist_np[:, i], color=warm[i], label=f'a[{i}]', linewidth=2)
    for i in range(n_b):
        ax.plot(iterations, b_hist_np[:, i], color=cold[i], linestyle='--', label=f'b[{i}]', linewidth=2)
    ax.set_xlabel('Iteration', fontsize=12)
    ax.set_ylabel('Coefficient Value', fontsize=12)
    ax.set_title(f'T = {temp} MeV: Coefficient Convergence', fontsize=13)
    ax.legend(ncol=2, fontsize=10)
    ax.grid(alpha=0.3)
    
    # Objective
    ax = axes[1]
    ax.plot(iterations, obj_history, 'k-', linewidth=2, alpha=0.7)
    ax.set_xlabel('Iteration', fontsize=12)
    ax.set_ylabel('Loss Value', fontsize=12)
    ax.set_title(f'T = {temp} MeV: Loss Convergence', fontsize=13)
    ax.set_yscale('log')
    ax.grid(alpha=0.3, which='both')
    
    plt.tight_layout()
    plt.savefig(results_dir / f'T{temp}_convergence.png', dpi=150, bbox_inches='tight')
    plt.close()
    
    # Save parameters to text file
    with open(results_dir / f'T{temp}_params.txt', 'w') as f:
        f.write(f"Temperature: {temp} MeV\n")
        f.write(f"File: {filename}\n")
        f.write(f"\n{'='*60}\n")
        f.write(f"PARAMETRIC CURVE FITTING RESULTS\n")
        f.write(f"{'='*60}\n\n")
        
        f.write("Loss = DataFitError + λ·NECPenalty\n")
        f.write(f"  DataFitError: {comp_final['data_fit_error']:.6e}\n")
        f.write(f"  NECPenalty: {comp_final['nec_penalty']:.6e}\n")
        f.write(f"  λ: 100.0\n")
        f.write(f"  Total Loss: {comp_final['total_loss']:.6e}\n\n")
        
        f.write("Coefficients a:\n")
        for i, val in enumerate(a_final):
            f.write(f"  a[{i}] = {val:.10f}\n")
        
        f.write("\nCoefficients b:\n")
        for i, val in enumerate(b_final):
            f.write(f"  b[{i}] = {val:.10f}\n")
        
        f.write("\nScale and shift:\n")
        f.write(f"  coef = {coef_final:.10f}\n")
        f.write(f"  shift = {shift_final:.10f}\n")
        
        f.write("\nLoss components:\n")
        for k, v in comp_final.items():
            f.write(f"  {k}: {v:.6e}\n")
        
        f.write("\nConstraints:\n")
        f.write(f"  a[0] + b[0] = {(a_final[0] + b_final[0]):.6e}\n")
        f.write(f"  a[0] < 0: {a_final[0] < 0}\n")
        f.write(f"  NEC satisfied: {comp_final['nec_violation'] < 1e-6}\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("Python format:\n")
        f.write(f"a = {a_final.tolist()}\n")
        f.write(f"b = {b_final.tolist()}\n")
        f.write(f"coef = {coef_final}\n")
        f.write(f"shift = {shift_final}\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("Mathematica format:\n")
        a_str = ", ".join([f"{x:.17g}" for x in a_final])
        b_str = ", ".join([f"{x:.17g}" for x in b_final])
        f.write(f"a = {{{a_str}}};\n")
        f.write(f"b = {{{b_str}}};\n")
        f.write(f"coef = {coef_final:.17g};\n")
        f.write(f"shift = {shift_final:.17g};\n")
    
    if verbose:
        print(f"  Saved to {results_dir}/T{temp}_*")
    
    return {
        'temp': temp,
        'filename': filename,
        'a': a_final,
        'b': b_final,
        'coef': coef_final,
        'shift': shift_final,
        'obj_final': obj_final,
        'total_loss': comp_final['total_loss'],
        'data_fit_error': comp_final['data_fit_error'],
        'nec_penalty': comp_final['nec_penalty'],
        'nec_violation': comp_final['nec_violation'],
        'Lmax_model': comp_final['Lmax_model'],
        'chi2_data': comp_final['chi2_data'],
        'success': result.success,
        'nfev': result.nfev
    }

## 4. Run Batch Optimization for All Temperatures

In [6]:
# Run optimization for all temperatures
all_results = []

print(f"\n\n{'='*70}")
print(f"STARTING BATCH OPTIMIZATION FOR {len(temp_files)} TEMPERATURES")
print(f"{'='*70}\n")

for idx, (temp, filename) in enumerate(temp_files, 1):
    print(f"\n{'#'*70}")
    print(f"# TEMPERATURE {idx}/{len(temp_files)}: T = {temp} MeV")
    print(f"{'#'*70}")
    
    try:
        result = optimize_params_all_temperatures(filename, temp, verbose=True)
        all_results.append(result)
        
        # Print summary for this temperature
        print(f"\n  RESULTS FOR T = {temp} MeV:")
        print(f"    a = [{result['a'][0]:.4f}, {result['a'][1]:.4f}, {result['a'][2]:.4f}, {result['a'][3]:.4f}]")
        print(f"    b = [{result['b'][0]:.4f}, {result['b'][1]:.4f}, {result['b'][2]:.4f}, {result['b'][3]:.4f}]")
        print(f"    coef = {result['coef']:.6f}, shift = {result['shift']:.4f}")
        print(f"    Swallowtail L_max = {result['Lmax_model']:.4f}")
        print(f"    Chi2 = {result['chi2_data']:.4e}")
        print(f"    Success: {result['success']}")
        
    except Exception as e:
        print(f"\n  *** ERROR for T={temp}: {e}")
        print(f"  *** Skipping this temperature and continuing...\n")
        import traceback
        traceback.print_exc()
        continue

print(f"\n\n{'='*70}")
print(f"BATCH OPTIMIZATION COMPLETE")
print(f"Successfully optimized {len(all_results)}/{len(temp_files)} temperatures")
print(f"{'='*70}")



STARTING BATCH OPTIMIZATION FOR 9 TEMPERATURES


######################################################################
# TEMPERATURE 1/9: T = 113 MeV
######################################################################

OPTIMIZING T = 113 MeV
File: 1607latticeT113.txt
  Dataset: 28 points
  L range: [0.0513, 0.7682]

  *** ERROR for T=113: a leaf Variable that requires grad is being used in an in-place operation.
  *** Skipping this temperature and continuing...


######################################################################
# TEMPERATURE 2/9: T = 226 MeV
######################################################################

OPTIMIZING T = 226 MeV
File: 1607latticeT226.txt
  Dataset: 25 points
  L range: [0.1025, 1.3444]

  *** ERROR for T=226: a leaf Variable that requires grad is being used in an in-place operation.
  *** Skipping this temperature and continuing...


######################################################################
# TEMPERATURE 3/9: T = 254 MeV
#

Traceback (most recent call last):
  File "/var/folders/hq/yn5dxnvn1vqgp11804klfycw0000gp/T/ipykernel_12591/1964647463.py", line 14, in <module>
    result = optimize_params_all_temperatures(filename, temp, verbose=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/hq/yn5dxnvn1vqgp11804klfycw0000gp/T/ipykernel_12591/3076048715.py", line 35, in optimize_params_all_temperatures
    model.a.zero_()
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
Traceback (most recent call last):
  File "/var/folders/hq/yn5dxnvn1vqgp11804klfycw0000gp/T/ipykernel_12591/1964647463.py", line 14, in <module>
    result = optimize_params_all_temperatures(filename, temp, verbose=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/hq/yn5dxnvn1vqgp11804klfycw0000gp/T/ipykernel_12591/3076048715.py", line 35, in optimize_params_all_temperatures
    model.a.zero_()
RuntimeEr

## 5. Create Summary Table

In [None]:
# Save summary table
summary_file = results_dir / 'all_temperatures_summary.txt'

with open(summary_file, 'w') as f:
    f.write("="*100 + "\n")
    f.write("SUMMARY: ALL TEMPERATURES\n")
    f.write("="*100 + "\n\n")
    
    # Header
    f.write(f"{'T[MeV]':>7} {'a[0]':>12} {'a[1]':>12} {'a[2]':>12} {'a[3]':>12} ")
    f.write(f"{'b[1]':>12} {'b[2]':>12} {'b[3]':>12} {'coef':>10} {'shift':>10} ")
    f.write(f"{'L_max':>8} {'chi2':>10} {'Success':>7}\n")
    f.write("-"*100 + "\n")
    
    for r in all_results:
        f.write(f"{r['temp']:7d} ")
        f.write(f"{r['a'][0]:12.6f} {r['a'][1]:12.6f} {r['a'][2]:12.6f} {r['a'][3]:12.6f} ")
        f.write(f"{r['b'][1]:12.6f} {r['b'][2]:12.6f} {r['b'][3]:12.6f} ")
        f.write(f"{r['coef']:10.6f} {r['shift']:10.4f} ")
        f.write(f"{r['Lmax_model']:8.4f} {r['chi2_data']:10.2e} ")
        f.write(f"{'Yes' if r['success'] else 'No':>7}\n")
    
    f.write("\n" + "="*100 + "\n")
    f.write(f"Total: {len(all_results)} temperatures\n")
    f.write(f"Results saved to: {results_dir.absolute()}\n")
    f.write("="*100 + "\n")

print(f"\nSummary table saved to: {summary_file}")
print(f"\nAll results saved to: {results_dir.absolute()}/")

## 6. Display Summary Table

## 6. Display Summary Table

In [None]:
def plot_all_temperatures_combined(all_results):
    """
    Plot all temperatures together with color gradient from cold to warm.
    """
    n_temps = len(all_results)
    if n_temps == 0:
        print("No results to plot!")
        return
    
    # Create color map from cold (blue) to warm (red)
    colors = plt.cm.coolwarm(np.linspace(0.0, 1.0, n_temps))
    
    fig, ax = plt.subplots(figsize=(12, 8))
    
    print(f"\nPlotting {n_temps} temperatures together...")
    
    for idx, result in enumerate(all_results):
        temp = result['temp']
        filename = result['filename']
        color = colors[idx]
        
        # Load data
        dataset = AdSBHDataset(file=filename)
        mask_L = dataset.L < 1.4
        L_data = dataset.L[mask_L]
        V_data = dataset.V[mask_L]
        sigma_data = dataset.sigma[mask_L]
        
        # Setup model with optimized parameters
        model = AdSBHNet(N=N_COEFFS, std=0.1)
        device = model.a.device
        dt = model.a.dtype
        
        # Set optimized parameters
        model.a.copy_(torch.tensor(result['a'], dtype=dt, device=device))
        model.b.copy_(torch.tensor(result['b'], dtype=dt, device=device))
        model.logcoef.copy_(torch.tensor(np.log(result['coef']), dtype=dt, device=device))
        model.shift.copy_(torch.tensor(result['shift'], dtype=dt, device=device))
        
        # Compute optimized curve
        with torch.no_grad():
            Lm, Vm = connected_branch(model, device, dt, Nc=5000)
            V_branch = (Vm + model.shift).cpu()
            Lm = Lm.cpu()
        
        # Plot data points (with error bars)
        ax.errorbar(L_data.cpu(), V_data.real.cpu(), yerr=sigma_data.real.cpu(),
                    fmt='o', markersize=5, alpha=0.6, color=color, 
                    label=f'T={temp} MeV (data)', zorder=1)
        
        # Plot optimized curve
        ax.plot(Lm, V_branch, '-', linewidth=2.5, color=color, alpha=0.9, zorder=2)
    
    ax.set_xlabel(r'$T L$', fontsize=14)
    ax.set_ylabel(r'$V/T$ (real part)', fontsize=14)
    ax.set_title(f'All Temperatures: Data and Optimized Models\\n(Cold to Warm: T = {all_results[0][\"temp\"]} → {all_results[-1][\"temp\"]} MeV)', 
                 fontsize=15)
    ax.grid(alpha=0.3)
    
    # Create custom legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', 
               markersize=8, label='Data points'),
        Line2D([0], [0], color='gray', linewidth=2.5, label='Optimized curves')
    ]
    ax.legend(handles=legend_elements, fontsize=12, loc='best')
    
    # Add colorbar to show temperature gradient
    sm = plt.cm.ScalarMappable(cmap=plt.cm.coolwarm, 
                                norm=plt.Normalize(vmin=all_results[0]['temp'], 
                                                   vmax=all_results[-1]['temp']))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, label='Temperature (MeV)', pad=0.02)
    cbar.ax.tick_params(labelsize=11)
    
    plt.tight_layout()
    
    # Save plot
    output_file = results_dir / 'all_temperatures_combined.png'
    plt.savefig(output_file, dpi=200, bbox_inches='tight')
    print(f"Saved combined plot to: {output_file}")
    
    plt.show()
    
    return fig, ax

# Create the combined plot
fig, ax = plot_all_temperatures_combined(all_results)

## 7. Combined Plot: All Temperatures Together

Plot all temperature data and optimized curves in one figure, color-coded from cold (blue) to hot (red).

In [None]:
# Display summary
print("\n" + "="*100)
print("SUMMARY TABLE")
print("="*100)
print(f"{'T[MeV]':>7} {'a[0]':>12} {'a[1]':>12} {'a[2]':>12} {'a[3]':>12} ", end="")
print(f"{'b[1]':>12} {'b[2]':>12} {'b[3]':>12} {'coef':>10} {'shift':>10} ", end="")
print(f"{'L_max':>8} {'chi2':>10} {'Success':>7}")
print("-"*100)

for r in all_results:
    print(f"{r['temp']:7d} ", end="")
    print(f"{r['a'][0]:12.6f} {r['a'][1]:12.6f} {r['a'][2]:12.6f} {r['a'][3]:12.6f} ", end="")
    print(f"{r['b'][1]:12.6f} {r['b'][2]:12.6f} {r['b'][3]:12.6f} ", end="")
    print(f"{r['coef']:10.6f} {r['shift']:10.4f} ", end="")
    print(f"{r['Lmax_model']:8.4f} {r['chi2_data']:10.2e} ", end="")
    print(f"{'Yes' if r['success'] else 'No':>7}")

print("="*100)