In [None]:
"""
FAIR CLIMATE DOWNSCALING: PINN vs Auto-AI
Correct PDE with Simplified Coefficients
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import copy
import random
from torch.utils.data import Dataset, DataLoader

# Record the library build and accelerator availability alongside the results.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.9.0+cpu
CUDA available: False


In [None]:
# =============================================================================
# HYPERPARAMETERS - Modify these to experiment
# =============================================================================

# --- Reproducibility ----------------------------------------------------------
SEED = 42

# --- Data generation ----------------------------------------------------------
N_FIELDS = 2000     # total number of synthetic fields
NX_FINE = 64        # grid points in the fine (target) resolution
NX_COARSE = 16      # grid points in the coarse (input) resolution
NT_STEPS = 10       # stored time levels per field

# --- Training -----------------------------------------------------------------
EPOCHS = 100        # passes over the training set
LR = 1e-3           # Adam learning rate
BATCH_TRAIN = 256   # mini-batch size for training
BATCH_VAL = 512     # mini-batch size for validation / test

# --- Ground-truth physics -----------------------------------------------------
ALPHA_TRUE_BASE = 0.01   # mean diffusivity
ALPHA_TRUE_VAR = 0.006   # amplitude of the sinusoidal spatial variation
PHYSICS_NOISE = 0.2      # scale of the unmodeled forcing (author's note: 1-2% of signal)
OBS_NOISE = 0.05         # scale of the observation noise

# The PINN deliberately uses a constant diffusivity equal to the mean of the
# true field: an approximation of α(x), not a different equation.
ALPHA_PINN_SIMPLE = ALPHA_TRUE_BASE

# Both methods start from the same physics weight.
LAMBDA_INIT = 10.0

# --- Device -------------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Seeding (NumPy, then PyTorch, then the stdlib RNG) -----------------------
for _seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    _seed_rng(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

print(f"Device: {DEVICE}")
print(f"Ground truth uses: α(x) = {ALPHA_TRUE_BASE} + {ALPHA_TRUE_VAR}·sin(10πx)")
print(f"PINN uses simplified: α = {ALPHA_PINN_SIMPLE} (constant)")

Device: cpu
Ground truth uses: α(x) = 0.01 + 0.006·sin(10πx)
PINN uses simplified: α = 0.01 (constant)


In [None]:
# =============================================================================
# GROUND TRUTH DATA GENERATION
# =============================================================================

def generate_climate_fields_realistic():
    """
    Generate synthetic fine/coarse field pairs from a noisy 1-D diffusion model.

    Each field follows
        ∂u/∂t = α(x)·∂²u/∂x² + ε(x,t)
    advanced with an explicit FTCS step on a grid with periodic neighbours, where:
    - α(x) = ALPHA_TRUE_BASE + ALPHA_TRUE_VAR·sin(10πx) varies spatially
    - ε is white noise smoothed with a [0.25, 0.5, 0.25] kernel and scaled by
      PHYSICS_NOISE (unmodeled forcing / subgrid processes)

    Coarse fields are block means of the fine field over NX_FINE // NX_COARSE
    neighbouring cells. For t >= 1 the coarse field additionally receives
    OBS_NOISE-scaled Gaussian noise; at t = 0 the noise is applied to the fine
    initial condition before averaging.

    Reads the module-level configuration (N_FIELDS, NX_FINE, NX_COARSE,
    NT_STEPS, ALPHA_*, PHYSICS_NOISE, OBS_NOISE) and draws from the global
    NumPy random state, so results depend on the seed set beforehand.

    Returns:
        (x_fine, dx, dt, alpha_true, coarse_all, fine_all) where coarse_all has
        shape (N_FIELDS, NT_STEPS, NX_COARSE) and fine_all has shape
        (N_FIELDS, NT_STEPS, NX_FINE), both float32.

    Raises:
        ValueError: if NX_FINE is not an integer multiple of NX_COARSE.
    """
    # Block averaging below reshapes the fine grid into (NX_COARSE, block);
    # fail early with a readable message instead of a reshape error mid-loop.
    if NX_FINE % NX_COARSE != 0:
        raise ValueError(
            f"NX_FINE ({NX_FINE}) must be a multiple of NX_COARSE ({NX_COARSE})"
        )

    x_fine = np.linspace(0.0, 1.0, NX_FINE)
    dx = x_fine[1] - x_fine[0]

    # Spatially-varying diffusivity (GROUND TRUTH)
    alpha_true = ALPHA_TRUE_BASE + ALPHA_TRUE_VAR * np.sin(10 * np.pi * x_fine)
    alpha_max = alpha_true.max()
    alpha_min = alpha_true.min()

    # FTCS stability: r = α·dt/dx² ≤ 0.5; dt is sized from the largest α.
    r = 0.25
    dt = r * dx**2 / alpha_max

    print(f"\n{'='*70}")
    print(f"DATA GENERATION")
    print(f"{'='*70}")
    print(f"Ground truth α(x) range: [{alpha_min:.4f}, {alpha_max:.4f}]")
    print(f"Spatial variation: {(alpha_max/alpha_min - 1)*100:.1f}%")
    print(f"PINN simplified α: {ALPHA_PINN_SIMPLE:.4f} (constant)")
    print(f"This is APPROXIMATION (not wrong!)")
    print(f"Stability ratio r = {r:.3f} (safe < 0.5)")
    print(f"Time step dt = {dt:.6f}")
    print(f"Spatial step dx = {dx:.6f}")
    print(f"{'='*70}\n")

    fine_all = np.zeros((N_FIELDS, NT_STEPS, NX_FINE), dtype=np.float32)
    coarse_all = np.zeros((N_FIELDS, NT_STEPS, NX_COARSE), dtype=np.float32)
    block = NX_FINE // NX_COARSE

    # Smoothing kernel that correlates the forcing between neighbouring cells.
    kernel = np.array([0.25, 0.5, 0.25])

    print("Generating fields...")
    for i in range(N_FIELDS):
        if (i+1) % 500 == 0:
            print(f"  Generated {i+1}/{N_FIELDS} fields")

        # Random initial condition (smooth + high-freq components)
        A = np.random.uniform(5.0, 10.0)
        B = np.random.uniform(0.0, 4.0)
        phase = np.random.uniform(0.0, 2*np.pi)

        u0 = (A * np.sin(2*np.pi * x_fine) +
              B * np.sin(4*np.pi * x_fine) +
              0.5 * np.cos(2*np.pi * x_fine) +
              0.3 * np.sin(16*np.pi * x_fine + phase))
        u0 += OBS_NOISE * np.random.randn(NX_FINE)

        fine_all[i, 0] = u0
        coarse_all[i, 0] = u0.reshape(NX_COARSE, block).mean(axis=1)

        # Time evolution using FTCS
        u = u0.copy()
        for t in range(1, NT_STEPS):
            # PDE: ∂u/∂t = α(x)·∂²u/∂x²
            # np.roll supplies the wrapped (j+1) and (j-1) neighbours, so the
            # whole row is updated at once with the same per-cell arithmetic
            # as an explicit loop over j.
            u_xx = (np.roll(u, -1) - 2*u + np.roll(u, 1)) / dx**2
            u_new = u + alpha_true * dt * u_xx

            # Add unmodeled forcing (spatially correlated)
            white_noise = np.random.randn(NX_FINE)
            correlated_noise = np.convolve(white_noise, kernel, mode='same')
            u_new += PHYSICS_NOISE * correlated_noise

            u = u_new
            fine_all[i, t] = u

            # Coarse observations with measurement noise
            coarse_true = u.reshape(NX_COARSE, block).mean(axis=1)
            coarse_all[i, t] = coarse_true + OBS_NOISE * np.random.randn(NX_COARSE)

    print(f"  Generated {N_FIELDS}/{N_FIELDS} fields ✓")
    print(f"Fine data shape: {fine_all.shape}")
    print(f"Coarse data shape: {coarse_all.shape}")

    return x_fine, dx, dt, alpha_true, coarse_all, fine_all

In [None]:
# =============================================================================
# DATASET PREPARATION
# =============================================================================

def build_dataset(x_fine, coarse, fine):
    """
    Build the sample table for 1-step forecasting.

    One sample is produced per (field, time level t, fine grid point j), with
    t running over every level except the last:
        Input:  [x_fine[j], t / (Nt - 1), coarse[field, t, :]]
        Output: fine[field, t + 1, j]
    Rows are ordered field-major, then time, then grid point.

    Args:
        x_fine: 1-D array of fine grid coordinates (at least fine.shape[2] long).
        coarse: array of shape (N_fields, Nt, Nx_coarse).
        fine:   array of shape (>= N_fields, >= Nt, Nx_fine).

    Returns:
        (X, Y): float32 arrays of shape (N, 2 + Nx_coarse) and (N, 1) with
        N = N_fields * (Nt - 1) * Nx_fine. When there is nothing to pair up
        (no fields, fewer than two time levels, or no grid points) both arrays
        are returned with zero rows but the correct number of columns.

    Raises:
        ValueError: if `fine` or `x_fine` is too small for the extents implied
            by `coarse`.
    """
    n_fields, n_times, n_coarse = coarse.shape
    n_fine = fine.shape[2]
    n_steps = n_times - 1          # the last level has no "next" target
    n_features = 2 + n_coarse

    if fine.shape[0] < n_fields or fine.shape[1] < n_times:
        raise ValueError(
            f"fine has shape {fine.shape} but coarse implies at least "
            f"({n_fields}, {n_times}, ...)"
        )
    x_fine = np.asarray(x_fine)
    if x_fine.shape[0] < n_fine:
        raise ValueError(
            f"x_fine has {x_fine.shape[0]} points but fine has {n_fine}"
        )

    if n_fields == 0 or n_steps <= 0 or n_fine == 0:
        return (np.empty((0, n_features), dtype=np.float32),
                np.empty((0, 1), dtype=np.float32))

    # Fill a (field, time, point, feature) array by broadcasting rather than
    # appending one small concatenated row per sample.
    X = np.empty((n_fields, n_steps, n_fine, n_features), dtype=np.float32)
    X[..., 0] = x_fine[:n_fine]                                  # varies with point
    X[..., 1] = (np.arange(n_steps) / n_steps)[None, :, None]    # varies with time
    X[..., 2:] = coarse[:, :n_steps, None, :]                    # shared by all points

    Y = np.asarray(fine[:n_fields, 1:n_times], dtype=np.float32)

    return X.reshape(-1, n_features), Y.reshape(-1, 1)


class ClimateDataset(Dataset):
    """
    PyTorch Dataset for climate downscaling.

    Wraps a feature array X and a target array Y (NumPy arrays with one row
    per sample) as tensors. `torch.from_numpy` is used, so the tensors share
    memory with the arrays passed in.

    Raises:
        ValueError: if X and Y do not contain the same number of samples.
    """
    def __init__(self, X, Y):
        # Catch a length mismatch here; otherwise it only shows up as an
        # indexing error (or silently ignored targets) part-way through an epoch.
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same number of samples, "
                f"got {len(X)} and {len(Y)}"
            )
        self.X = torch.from_numpy(X)
        self.Y = torch.from_numpy(Y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i]

In [None]:
# =============================================================================
# NEURAL NETWORK MODEL
# =============================================================================

class DownscaleNet(nn.Module):
    """
    Fully connected tanh network used for downscaling.

    Input:  one row per query, laid out as [x, t, coarse_field]
    Output: a single fine-scale value per row

    Args:
        in_dim: number of input features.
        hidden: width of every hidden layer.
        depth:  number of hidden (Linear + Tanh) layers before the output layer.
    """
    def __init__(self, in_dim, hidden=64, depth=3):
        super().__init__()
        # Layer widths from the input through the last hidden layer.
        widths = [in_dim] + [hidden] * depth

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        # Scalar read-out on top of the last hidden width (or the raw input
        # when there are no hidden layers).
        stack.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        return self.net(x)


# Smoke test: build a throwaway network, push a random batch through it and
# report the shapes and parameter count. The temporaries are deleted afterwards.
print("Testing model architecture...")
in_dim = NX_COARSE + 2  # x, t, plus one entry per coarse cell
test_model = DownscaleNet(in_dim).to(DEVICE)
test_input = torch.randn(10, in_dim).to(DEVICE)
test_output = test_model(test_input)
print(
    f"Input shape: {test_input.shape}",
    f"Output shape: {test_output.shape}",
    f"Model parameters: {sum(p.numel() for p in test_model.parameters()):,}",
    sep="\n",
)
del test_model, test_input, test_output
print("Model test passed ✓")

Testing model architecture...
Input shape: torch.Size([10, 18])
Output shape: torch.Size([10, 1])
Model parameters: 9,601
Model test passed ✓


In [None]:
# =============================================================================
# CORRECT PHYSICS-INFORMED LOSS
# =============================================================================

def compute_losses_CORRECT(model, xb, yb, lambda_phys, dx, dt, for_training):
    """
    Data + physics loss for one mini-batch.

    Enforces the heat equation with a constant (simplified) diffusivity:

        ∂u/∂t = α_simple × ∂²u/∂x²

    The network's time input (column 1 of `xb`) is the NORMALISED step index
    t_norm = step / (NT_STEPS - 1), while x (column 0) is the physical
    coordinate. Physical time is t = t_norm · (NT_STEPS - 1) · dt, so by the
    chain rule the equation in the network's own coordinates is

        ∂u/∂t_norm = α_simple · (NT_STEPS - 1) · dt · ∂²u/∂x²

    and that is the residual evaluated here.

    Args:
        model: network mapping rows of [x, t_norm, coarse_field] to one value.
        xb: input batch; columns 0 and 1 are x and t_norm.
        yb: target batch with the same shape as the model output.
        lambda_phys: weight of the physics term in the total loss.
        dx: spatial step (accepted for interface symmetry; not used).
        dt: physical time step between stored levels, used to rescale the
            normalised time derivative.
        for_training: when true, keep the graph through the derivative terms so
            the physics loss can be backpropagated; when false the derivatives
            are only evaluated.

    Returns:
        (pred, mse, phys, total) with total = mse + lambda_phys * phys.

    Note:
        Assumes the samples were built with NT_STEPS time levels, which is how
        t_norm is defined when the dataset is assembled.
    """
    xb = xb.to(DEVICE)
    yb = yb.to(DEVICE)

    # Split out the coordinates as separate leaf tensors so derivatives with
    # respect to x and t can be taken independently of the coarse-field columns.
    xcoord = xb[:, [0]].clone()
    tcoord = xb[:, [1]].clone()
    xcoord.requires_grad_(True)
    tcoord.requires_grad_(True)

    # Forward pass
    inp = torch.cat([xcoord, tcoord, xb[:, 2:]], dim=1)
    pred = model(inp)

    # Data loss (MSE)
    mse = F.mse_loss(pred, yb)

    # ∂u/∂t_norm. The forward graph must survive for the x-derivatives below;
    # a differentiable result is only needed when the loss is backpropagated.
    u_t = torch.autograd.grad(
        pred, tcoord,
        grad_outputs=torch.ones_like(pred),
        create_graph=for_training,
        retain_graph=True
    )[0]

    # ∂u/∂x (always differentiable: it is differentiated again just below)
    u_x = torch.autograd.grad(
        pred, xcoord,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,
        retain_graph=True
    )[0]

    # ∂²u/∂x²
    u_xx = torch.autograd.grad(
        u_x, xcoord,
        grad_outputs=torch.ones_like(u_x),
        create_graph=for_training
    )[0]

    # Physical duration spanned by t_norm ∈ [0, 1]; converts the physical
    # diffusivity into the normalised-time units the network differentiates in.
    time_span = float((NT_STEPS - 1) * dt)
    alpha_scaled = torch.tensor(
        ALPHA_PINN_SIMPLE * time_span, dtype=pred.dtype, device=DEVICE
    )

    # Physics residual: R = ∂u/∂t_norm - α_simple·T·∂²u/∂x²  (≈ 0 for solutions)
    residual = u_t - alpha_scaled * u_xx

    # Physics loss (mean squared residual)
    phys = torch.mean(residual**2)

    # Total loss
    total = mse + lambda_phys * phys

    return pred, mse, phys, total


# Confirm the definition and restate the equation the physics term targets.
print(
    "Physics loss function defined ✓",
    f"PINN enforces: ∂u/∂t = {ALPHA_PINN_SIMPLE}·∂²u/∂x²",
    "This is CORRECT structure with SIMPLIFIED coefficient!",
    sep="\n",
)

Physics loss function defined ✓
PINN enforces: ∂u/∂t = 0.01·∂²u/∂x²
This is CORRECT structure with SIMPLIFIED coefficient!


In [None]:
# =============================================================================
# TRAINING UTILITIES
# =============================================================================

def run_epoch(model, loader, optimizer, lam, dx, dt, train=True):
    """
    Make one pass over `loader`, optionally updating the model.

    Args:
        model: network being trained or evaluated.
        loader: iterable of (inputs, targets) batches.
        optimizer: optimizer stepped after each batch when `train` is true;
            not touched otherwise (callers pass None for evaluation).
        lam: physics-loss weight forwarded to the loss function.
        dx, dt: grid spacings forwarded to the loss function.
        train: True to update weights, False to only measure the losses.

    Returns:
        (total, mse, phys): per-sample averages over the whole pass.
    """
    model.train(bool(train))

    # Running sample-weighted sums, in the order (total, mse, phys).
    running = [0.0, 0.0, 0.0]
    n_seen = 0

    for xb, yb in loader:
        if train:
            optimizer.zero_grad()

        _, mse, phys, total = compute_losses_CORRECT(
            model, xb, yb, lam, dx, dt, train
        )

        if train:
            total.backward()
            # Clip the global gradient norm to 1.0 for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        batch_size = xb.size(0)
        for slot, term in enumerate((total, mse, phys)):
            running[slot] += term.item() * batch_size
        n_seen += batch_size

    return tuple(acc / n_seen for acc in running)


def auto_ai_update_lambda(lam, mse_val, phys_val):
    """
    Auto-AI validator: choose the next λ from validation losses.

    Three candidates are scored, in the order λ, λ/2, 2λ, after clamping each
    to [1e-4, 100]. The score is MSE_val + candidate·Phys_val and the
    lowest-scoring candidate is returned; on a tie the earliest candidate
    (the current λ first) wins.

    NOTE(review): MSE_val and Phys_val are fixed numbers here, so the scores
    differ only through candidate·Phys_val. Whenever Phys_val > 0 the smallest
    candidate therefore wins and λ halves on every call until it reaches the
    lower clamp. Confirm that this is the intended behaviour.
    """
    floor, ceiling = 1e-4, 100.0

    trial_lams = []
    for raw in (lam, lam/2, lam*2):
        trial_lams.append(max(floor, min(raw, ceiling)))

    scores = np.array([mse_val + cand * phys_val for cand in trial_lams])
    return trial_lams[int(scores.argmin())]


print("Training utilities defined ✓")

Training utilities defined ✓


In [None]:
# =============================================================================
# GENERATE GROUND TRUTH DATA
# =============================================================================

# Generate the synthetic fields once; later cells reuse these names.
print("\n" + "\n".join(["="*70, "GENERATING GROUND TRUTH CLIMATE DATA", "="*70]))

x_fine, dx, dt, alpha_true, C, fine = generate_climate_fields_realistic()

print("\n" + "\n".join(["="*70, "Data generation complete ✓", "="*70]))


GENERATING GROUND TRUTH CLIMATE DATA

DATA GENERATION
Ground truth α(x) range: [0.0040, 0.0160]
Spatial variation: 299.8%
PINN simplified α: 0.0100 (constant)
This is APPROXIMATION (not wrong!)
Stability ratio r = 0.250 (safe < 0.5)
Time step dt = 0.003937
Spatial step dx = 0.015873

Generating fields...
  Generated 500/2000 fields
  Generated 1000/2000 fields
  Generated 1500/2000 fields
  Generated 2000/2000 fields
  Generated 2000/2000 fields ✓
Fine data shape: (2000, 10, 64)
Coarse data shape: (2000, 10, 16)

Data generation complete ✓


In [None]:
# Split the generated fields into train/val/test and wrap them in DataLoaders.

print("\n" + "="*70)
print("PREPARING DATASETS")
print("="*70)

# Split by field (not by sample) in 80/10/10 proportions. The boundaries are
# derived from the number of fields actually generated, so changing N_FIELDS
# in the configuration cell keeps the split valid.
n_fields_total = C.shape[0]
n_train_end = (8 * n_fields_total) // 10
n_val_end = (9 * n_fields_total) // 10

idx_train = np.arange(0, n_train_end)               # 80%
idx_val = np.arange(n_train_end, n_val_end)         # 10%
idx_test = np.arange(n_val_end, n_fields_total)     # 10%

if min(len(idx_train), len(idx_val), len(idx_test)) == 0:
    raise ValueError(
        f"Need enough fields for non-empty train/val/test splits, got {n_fields_total}"
    )

print(f"Train fields: {len(idx_train)}")
print(f"Val fields: {len(idx_val)}")
print(f"Test fields: {len(idx_test)}")

# Build datasets
print("\nBuilding datasets...")
Xtr, Ytr = build_dataset(x_fine, C[idx_train], fine[idx_train])
Xva, Yva = build_dataset(x_fine, C[idx_val], fine[idx_val])
Xte, Yte = build_dataset(x_fine, C[idx_test], fine[idx_test])

print(f"Train samples: {Xtr.shape[0]:,}")
print(f"Val samples: {Xva.shape[0]:,}")
print(f"Test samples: {Xte.shape[0]:,}")

# Create data loaders (only the training loader shuffles)
train_loader = DataLoader(
    ClimateDataset(Xtr, Ytr),
    batch_size=BATCH_TRAIN,
    shuffle=True
)
val_loader = DataLoader(
    ClimateDataset(Xva, Yva),
    batch_size=BATCH_VAL
)
test_loader = DataLoader(
    ClimateDataset(Xte, Yte),
    batch_size=BATCH_VAL
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
print("="*70)


PREPARING DATASETS
Train fields: 1600
Val fields: 200
Test fields: 200

Building datasets...
Train samples: 921,600
Val samples: 115,200
Test samples: 115,200

Train batches: 3600
Val batches: 225
Test batches: 225


In [None]:
# Model initialization: both methods start from identical weights and the
# same physics weight, so later differences come from how λ is handled.

print("\n" + "\n".join(["="*70, "INITIALIZING MODELS", "="*70]))

in_dim = Xtr.shape[1]
print(f"Input dimension: {in_dim}")

base_model = DownscaleNet(in_dim).to(DEVICE)

# Independent copies of the same initial network.
PINN, AUTO = copy.deepcopy(base_model), copy.deepcopy(base_model)

print(f"PINN parameters: {sum(p.numel() for p in PINN.parameters()):,}")
print(f"AUTO parameters: {sum(p.numel() for p in AUTO.parameters()):,}")

# One Adam optimizer per model, same learning rate.
opt_pinn = torch.optim.Adam(PINN.parameters(), lr=LR)
opt_auto = torch.optim.Adam(AUTO.parameters(), lr=LR)

lam_pinn = lam_auto = LAMBDA_INIT
lam_history = []   # λ chosen after each Auto-AI epoch

print(f"Initial λ (both): {LAMBDA_INIT}")
print("="*70)


INITIALIZING MODELS
Input dimension: 18
PINN parameters: 9,601
AUTO parameters: 9,601
Initial λ (both): 10.0


In [None]:
# -----------------------------------------------------------------------------
# PINN baseline: the physics weight stays at lam_pinn for the whole run
# -----------------------------------------------------------------------------
print("\n" + "\n".join(["="*70, "TRAINING PINN", f"Fixed λ = {lam_pinn:.2f}", "="*70]) + "\n")

for ep in range(1, EPOCHS+1):
    tr_tot, tr_mse, tr_phys = run_epoch(
        PINN, train_loader, opt_pinn, lam_pinn, dx, dt, train=True
    )
    va_tot, va_mse, va_phys = run_epoch(
        PINN, val_loader, None, lam_pinn, dx, dt, train=False
    )

    # Log the first epoch, the last epoch and every tenth one.
    if ep in (1, EPOCHS) or ep % 10 == 0:
        print(" | ".join([
            f"[PINN] Epoch {ep:3d}/{EPOCHS}",
            f"Train: MSE={tr_mse:.4e} Phys={tr_phys:.4e}",
            f"Val: MSE={va_mse:.4e} Phys={va_phys:.4e}",
        ]))

print("\n" + "\n".join(["="*70, "PINN training complete ✓", "="*70]))

# -----------------------------------------------------------------------------
# Auto-AI: same loop, but λ is re-selected from validation losses every epoch
# -----------------------------------------------------------------------------
print("\n" + "\n".join([
    "="*70, "TRAINING AUTO-AI", f"Adaptive λ (starts at {lam_auto:.2f})", "="*70
]) + "\n")

for ep in range(1, EPOCHS+1):
    tr_tot, tr_mse, tr_phys = run_epoch(
        AUTO, train_loader, opt_auto, lam_auto, dx, dt, train=True
    )
    va_tot, va_mse, va_phys = run_epoch(
        AUTO, val_loader, None, lam_auto, dx, dt, train=False
    )

    # Pick the λ for the next epoch and record it (the log shows the new value).
    lam_auto = auto_ai_update_lambda(lam_auto, va_mse, va_phys)
    lam_history.append(lam_auto)

    if ep in (1, EPOCHS) or ep % 10 == 0:
        print(" | ".join([
            f"[AUTO] Epoch {ep:3d}/{EPOCHS}",
            f"Train: MSE={tr_mse:.4e} Phys={tr_phys:.4e}",
            f"Val: MSE={va_mse:.4e} Phys={va_phys:.4e}",
            f"λ={lam_auto:.4e}",
        ]))

print("\n" + "\n".join([
    "="*70, "Auto-AI training complete ✓", f"Final λ = {lam_auto:.4e}", "="*70
]))


TRAINING PINN
Fixed λ = 10.00

[PINN] Epoch   1/100 | Train: MSE=5.6508e+00 Phys=1.6477e-02 | Val: MSE=2.7855e+00 Phys=2.0367e-02
[PINN] Epoch  10/100 | Train: MSE=5.4205e-01 Phys=9.5573e-03 | Val: MSE=4.6484e-01 Phys=7.6368e-03
[PINN] Epoch  20/100 | Train: MSE=5.1302e-01 Phys=7.2592e-03 | Val: MSE=4.3703e-01 Phys=7.7247e-03
[PINN] Epoch  30/100 | Train: MSE=4.6591e-01 Phys=6.3122e-03 | Val: MSE=4.0113e-01 Phys=9.3638e-03
[PINN] Epoch  40/100 | Train: MSE=4.4387e-01 Phys=5.6276e-03 | Val: MSE=4.9700e-01 Phys=9.2291e-03
[PINN] Epoch  50/100 | Train: MSE=4.3106e-01 Phys=5.2231e-03 | Val: MSE=3.5679e-01 Phys=4.3712e-03
[PINN] Epoch  60/100 | Train: MSE=4.2699e-01 Phys=5.0824e-03 | Val: MSE=3.8101e-01 Phys=7.7784e-03
[PINN] Epoch  70/100 | Train: MSE=4.1440e-01 Phys=5.0493e-03 | Val: MSE=3.9043e-01 Phys=2.2991e-03
[PINN] Epoch  80/100 | Train: MSE=4.0525e-01 Phys=4.4326e-03 | Val: MSE=4.1392e-01 Phys=2.8546e-03
[PINN] Epoch  90/100 | Train: MSE=4.0049e-01 Phys=4.3964e-03 | Val: MSE=3.621

In [None]:
# Held-out evaluation of both trained models. Each model is scored with the λ
# it finished training with; the metric plots are drawn in a later cell.
print("\n" + "\n".join(["="*70, "FINAL TEST EVALUATION", "="*70]))

mse_pinn, phys_pinn = run_epoch(
    PINN, test_loader, None, lam_pinn, dx, dt, train=False
)[1:]
tot_pinn = mse_pinn + lam_pinn * phys_pinn

mse_auto, phys_auto = run_epoch(
    AUTO, test_loader, None, lam_auto, dx, dt, train=False
)[1:]
tot_auto = mse_auto + lam_auto * phys_auto

# Results table; λ is printed fixed-point for PINN and scientific for Auto-AI.
print(f"\n{'Method':<10} {'Test MSE':<12} {'Phys Loss':<12} {'Total Loss':<12} {'λ':<10}")
print("-" * 70)
for method, mse_v, phys_v, tot_v, lam_v, lam_spec in (
    ("PINN", mse_pinn, phys_pinn, tot_pinn, lam_pinn, "<10.2f"),
    ("Auto-AI", mse_auto, phys_auto, tot_auto, lam_auto, "<10.4e"),
):
    print(f"{method:<10} {mse_v:<12.4e} {phys_v:<12.4e} {tot_v:<12.4e} {lam_v:{lam_spec}}")

# Relative reduction in test MSE of Auto-AI versus PINN, in percent.
improvement = (mse_pinn - mse_auto) / mse_pinn * 100
print(f"\nAuto-AI MSE improvement: {improvement:+.2f}%")
print("="*70)


FINAL TEST EVALUATION


NameError: name 'run_epoch' is not defined

In [None]:
# =============================================================================
# VISUALIZATION: TEST METRICS
# =============================================================================

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

labels = ['PINN', 'Auto-AI']
x_pos = np.arange(len(labels))

# Panels 1 and 2: side-by-side bars for test MSE and physics residual.
for ax, heights, y_label, title in (
    (axes[0], [mse_pinn, mse_auto], 'Test MSE', 'Test MSE (Lower is Better)'),
    (axes[1], [phys_pinn, phys_auto], 'Physics Loss', 'Physics Residual'),
):
    ax.bar(x_pos, heights,
           color=['#E74C3C', '#3498DB'], alpha=0.8, edgecolor='black')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(labels, fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')

# Panel 3: λ chosen after each Auto-AI epoch against the fixed PINN value.
lam_ax = axes[2]
lam_epochs = range(1, len(lam_history)+1)
lam_ax.plot(lam_epochs, lam_history,
            'b-', linewidth=2, label='Auto-AI λ(t)')
lam_ax.axhline(y=lam_pinn, color='r', linestyle='--',
               linewidth=2, label=f'PINN λ={lam_pinn:.1f}')
lam_ax.set_xlabel('Epoch', fontsize=12)
lam_ax.set_ylabel('λ (log scale)', fontsize=12)
lam_ax.set_title('Lambda Evolution', fontsize=13, fontweight='bold')
lam_ax.set_yscale('log')
lam_ax.grid(True, alpha=0.3, linestyle='--')
lam_ax.legend(fontsize=11)

plt.tight_layout()
plt.show()

print(
    "Key Observations:",
    f"• PINN maintains λ={lam_pinn:.1f} throughout training",
    f"• Auto-AI adapts λ from {LAMBDA_INIT:.1f} → {lam_auto:.4f}",
    f"• Final λ ratio: {lam_auto/lam_pinn:.2e}x of initial",
    sep="\n",
)

In [None]:
# =============================================================================
# LAMBDA STABILITY DIAGNOSTIC
# =============================================================================

def test_lambda_stability(epochs_test=50, verbose=True):
    """
    Train a fresh model on a small subset with adaptive λ and report whether
    λ settles, oscillates, or ends at one of its clamps.

    Uses fields 0-199 for training and 200-249 for validation, taken from the
    module-level data (x_fine, C, fine), and the module-level dx, dt, LR and
    LAMBDA_INIT.

    Args:
        epochs_test: number of train/validate/update-λ rounds to run.
        verbose: print a progress line on epoch 1 and every fifth epoch.

    Returns:
        dict with keys 'final_lambda', 'lambda_history', 'hit_bounds',
        'oscillating', 'stabilized' and 'variation'.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("LAMBDA STABILITY DIAGNOSTIC TEST")
    print(rule)

    # Small subset keeps the diagnostic quick.
    train_fields = np.arange(0, 200)
    val_fields = np.arange(200, 250)

    Xtr_mini, Ytr_mini = build_dataset(x_fine, C[train_fields], fine[train_fields])
    Xva_mini, Yva_mini = build_dataset(x_fine, C[val_fields], fine[val_fields])

    mini_train_loader = DataLoader(
        ClimateDataset(Xtr_mini, Ytr_mini), batch_size=128, shuffle=True
    )
    mini_val_loader = DataLoader(
        ClimateDataset(Xva_mini, Yva_mini), batch_size=256
    )

    # Fresh network and optimizer, independent of the main experiment.
    test_model = DownscaleNet(Xtr_mini.shape[1]).to(DEVICE)
    test_opt = torch.optim.Adam(test_model.parameters(), lr=LR)

    lam = LAMBDA_INIT
    lam_hist, mse_hist, phys_hist, j_hist = [], [], [], []

    print(f"\nRunning {epochs_test} epochs with adaptive λ...")
    print(f"Initial λ = {lam:.4f}\n")

    for ep in range(1, epochs_test + 1):
        # One training pass, then one validation pass, both with the current λ.
        _, tr_mse, tr_phys = run_epoch(
            test_model, mini_train_loader, test_opt, lam, dx, dt, train=True
        )
        _, va_mse, va_phys = run_epoch(
            test_model, mini_val_loader, None, lam, dx, dt, train=False
        )

        # Select the next λ; J below is evaluated with the updated value.
        lam_old = lam
        lam = auto_ai_update_lambda(lam, va_mse, va_phys)

        lam_hist.append(lam)
        mse_hist.append(va_mse)
        phys_hist.append(va_phys)
        j_hist.append(va_mse + lam * va_phys)

        if verbose and (ep == 1 or ep % 5 == 0):
            change = "STABLE" if lam == lam_old else "CHANGED"
            print(f"Epoch {ep:3d}: λ={lam:.6f} | MSE={va_mse:.4e} | "
                  f"Phys={va_phys:.4e} | {change}")

    # ------------------------------------------------------------------ analysis
    print("\n" + rule)
    print("STABILITY ANALYSIS")
    print(rule)

    lam_array = np.array(lam_hist)
    last_lam = lam_array[-1]

    # Did the final λ end up at (or next to) one of the clamps?
    hit_lower = (last_lam <= 1.1e-4)
    hit_upper = (last_lam >= 99.0)

    # Oscillation: more than five sign flips among the steps of the last 20 values.
    is_oscillating = False
    if len(lam_hist) > 10:
        changes = np.diff(lam_array[-20:])
        direction_changes = np.sum(changes[:-1] * changes[1:] < 0)
        is_oscillating = direction_changes > 5

    # Stabilisation: under 1% relative spread over the last ten values.
    is_stable = False
    relative_variation = np.inf
    if len(lam_hist) >= 10:
        tail = lam_array[-10:]
        relative_variation = np.std(tail) / (np.mean(tail) + 1e-10)
        is_stable = relative_variation < 0.01

    print(f"\nFinal λ: {lam:.6e}")
    print(f"λ range: [{lam_array.min():.6e}, {lam_array.max():.6e}]")
    print(f"Recent variation (last 10): {relative_variation:.4%}")

    def _bad_if(flag):
        # Marker for checks where a true value is the undesirable outcome.
        return 'YES ❌' if flag else 'NO ✓'

    print(f"\n✓ Hit lower bound (1e-4): {_bad_if(hit_lower)}")
    print(f"✓ Hit upper bound (100): {_bad_if(hit_upper)}")
    print(f"✓ Oscillating: {_bad_if(is_oscillating)}")
    print(f"✓ Stabilized: {'YES ✓' if is_stable else 'NO ❌'}")

    # ------------------------------------------------------------------ figures
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # (axis, series, line format, y-label, title) for each panel.
    panels = [
        (axes[0, 0], lam_hist, 'b-', 'λ', 'Lambda Evolution'),
        (axes[0, 1], mse_hist, 'g-', 'Validation MSE', 'MSE Evolution'),
        (axes[1, 0], phys_hist, 'orange', 'Validation Physics Loss',
         'Physics Loss Evolution'),
        (axes[1, 1], j_hist, 'purple', 'J = MSE + λ·Phys',
         'Validation Loss (J) Evolution'),
    ]
    for k, (ax, series, fmt, y_label, title) in enumerate(panels):
        ax.plot(series, fmt, linewidth=2)
        if k == 0:
            # Only the λ panel shows the clamp limits.
            ax.axhline(y=1e-4, color='r', linestyle='--', alpha=0.5, label='Lower bound')
            ax.axhline(y=100, color='r', linestyle='--', alpha=0.5, label='Upper bound')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3)
        if k == 0:
            ax.legend()

    plt.tight_layout()
    plt.show()

    # Diagnostics for programmatic use.
    return {
        'final_lambda': lam,
        'lambda_history': lam_hist,
        'hit_bounds': hit_lower or hit_upper,
        'oscillating': is_oscillating,
        'stabilized': is_stable,
        'variation': relative_variation
    }

# Run the diagnostic test
diagnostics = test_lambda_stability(epochs_test=50, verbose=True)

In [None]:
# =============================================================================
# HORIZON ERROR ANALYSIS
# =============================================================================

def evaluate_horizon_error(model, x_fine, C, fine, field_id, horizons, Nt, dx, dt):
    """
    Compute forecast error at different horizons using autoregressive prediction.

    For each horizon the rollout restarts from the observed coarse field at
    t = 0. Every step predicts the full fine field in one batched forward
    pass; between steps the prediction is block-averaged back to the coarse
    grid and fed in as the next input.

    Args:
        model: Trained neural network
        x_fine: Fine grid coordinates
        C: Coarse observations array, indexed [field, time, coarse point]
        fine: Ground truth fine fields, indexed [field, time, fine point]
        field_id: Which field to evaluate
        horizons: Horizons to test, e.g. [1, 2, 3, 5, 8]; each must be >= 1.
            Values larger than Nt - 1 are evaluated at Nt - 1.
        Nt: Number of time steps (must be >= 2)
        dx: Spatial resolution (not used)
        dt: Time step (not used)

    Returns:
        List of MSE values, one per requested horizon

    Raises:
        ValueError: if Nt < 2 or any horizon is < 1. No prediction exists in
            those cases, so the error would otherwise be computed from an
            undefined or stale field.
    """
    if Nt < 2:
        raise ValueError(f"Nt must be at least 2 to forecast a step, got {Nt}")

    errors = []
    model.eval()

    Nx_coarse = C.shape[2]
    Nx_fine = len(x_fine)
    block = Nx_fine // Nx_coarse

    # Input rows are [x, t_norm, coarse_field]; x never changes, so fill it once
    # and overwrite only the time and coarse columns at each step.
    inputs = np.empty((Nx_fine, 2 + Nx_coarse), dtype=np.float32)
    inputs[:, 0] = x_fine

    for h in horizons:
        if h < 1:
            raise ValueError(f"horizons must be >= 1, got {h}")
        h_actual = min(h, Nt - 1)

        # Start from t=0 with true coarse observation
        coarse_current = C[field_id, 0].copy()

        # Autoregressive prediction for h steps
        for step in range(h_actual):
            inputs[:, 1] = step / (Nt - 1)
            inputs[:, 2:] = coarse_current

            # Predict the fine field at every grid point in a single pass
            with torch.no_grad():
                batch = torch.from_numpy(inputs).to(DEVICE)
                fine_pred = model(batch).squeeze(1).cpu().numpy()

            # Update coarse for next step (if not last step)
            if step < h_actual - 1:
                # Downsample predicted fine to coarse
                coarse_current = fine_pred.reshape(Nx_coarse, block).mean(axis=1)

        # Compute MSE against ground truth at time h
        gt = fine[field_id, h_actual]
        mse = np.mean((fine_pred - gt) ** 2)
        errors.append(mse)

        print(f"  Horizon {h}: MSE = {mse:.6e}")

    return errors


def plot_horizon_bar_chart(horizons, err_pinn, err_auto, save_path='horizon_1.png'):
    """
    Draw grouped bars of PINN and Auto-AI error per forecast horizon, save the
    figure to disk and display it.

    Args:
        horizons: Horizon values used as x tick labels, e.g. [1, 2, 3, 5, 8]
        err_pinn: PINN MSE per horizon
        err_auto: Auto-AI MSE per horizon
        save_path: Output file for the figure
    """
    plt.figure(figsize=(10, 6))

    positions = np.arange(len(horizons))
    bar_width = 0.35
    # Appearance shared by both bar series.
    bar_look = dict(alpha=0.8, edgecolor='black', linewidth=1.2)

    pinn_bars = plt.bar(positions - bar_width/2, err_pinn, bar_width,
                        label='PINN', color='#d62728', **bar_look)
    auto_bars = plt.bar(positions + bar_width/2, err_auto, bar_width,
                        label='Auto-AI', color='#1f77b4', **bar_look)

    # Axes, title, ticks, legend and grid
    plt.xlabel('Forecast Horizon (steps ahead)', fontsize=13, fontweight='bold')
    plt.ylabel('Mean Squared Error (MSE)', fontsize=13, fontweight='bold')
    plt.title('Forecast Error vs Horizon: PINN vs Auto-AI',
              fontsize=14, fontweight='bold', pad=15)
    plt.xticks(positions, horizons, fontsize=12)
    plt.yticks(fontsize=12)
    plt.legend(fontsize=12, loc='upper left', framealpha=0.9)
    plt.grid(axis='y', alpha=0.3, linestyle='--', linewidth=0.8)

    # Annotate each bar with its value, centred just above the bar top.
    for bar in (*pinn_bars, *auto_bars):
        top = bar.get_height()
        centre = bar.get_x() + bar.get_width()/2.
        plt.text(centre, top, f'{top:.2e}',
                 ha='center', va='bottom', fontsize=9, rotation=0)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"\n✓ Horizon plot saved to: {save_path}")
    plt.show()


# =============================================================================
# RUN HORIZON ANALYSIS
# =============================================================================

print("\n" + "\n".join(["="*70, "HORIZON ERROR ANALYSIS", "="*70]))

# The first held-out test field is used for the rollout.
field_id = idx_test[0]
print(f"\nAnalyzing field ID: {field_id}")

# Forecast lengths (in steps) to evaluate.
horizons = [1, 2, 3, 5, 8]
print(f"Testing horizons: {horizons}")

# PINN rollout errors
print("\n[PINN] Computing horizon errors...")
err_pinn = evaluate_horizon_error(
    PINN, x_fine, C, fine, field_id, horizons, NT_STEPS, dx, dt
)

# Auto-AI rollout errors
print("\n[Auto-AI] Computing horizon errors...")
err_auto = evaluate_horizon_error(
    AUTO, x_fine, C, fine, field_id, horizons, NT_STEPS, dx, dt
)

# Side-by-side table with the relative improvement of Auto-AI over PINN.
table_header = f"{'Horizon':<10} {'PINN MSE':<15} {'Auto-AI MSE':<15} {'Improvement':<15}"
print("\n" + "\n".join(["="*70, "HORIZON ERROR COMPARISON", "="*70, table_header, "-"*70]))
for h, ep, ea in zip(horizons, err_pinn, err_auto):
    if ep > 0:
        improvement = (ep - ea) / ep * 100
    else:
        # No meaningful ratio when the PINN error is not positive.
        improvement = 0
    print(f"{h:<10} {ep:<15.6e} {ea:<15.6e} {improvement:+.2f}%")
print("="*70)

# Bar chart of the same numbers, also written to disk.
print("\nCreating horizon bar chart...")
plot_horizon_bar_chart(horizons, err_pinn, err_auto, save_path='horizon_1.png')

print("\n✓ Horizon analysis complete!")