In [None]:
#
# This script systematically trains different model architectures as regression models
# to conduct a scientifically rigorous architectural generalization study.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import json
import gc

warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# Root folder of the project; every other path below is derived from it.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
# Dataset root: holds the per-split sub-folders and the normalisation stats file
# read by MultiVariableARDataset.
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
# Destination of the best checkpoint (.pt) and metadata (.json) of each architecture.
OUTPUT_DIR = PROJECT_PATH / 'architecture_study_models'
# NOTE: `parents` is not set, so PROJECT_PATH itself must already exist.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Training Hyperparameters (MUST remain constant for all models) ---
EPOCHS = 25  # Upper bound on epochs; early stopping may end a run sooner
BATCH_SIZE = 8  # Used for both the training and the validation loader
LEARNING_RATE = 1e-4  # Passed to AdamW
EARLY_STOPPING_PATIENCE = 7  # Consecutive epochs without a new best val loss before stopping
GRADIENT_CLIP = 1.0  # Max gradient norm handed to clip_grad_norm_
INPUT_CHANNELS = 5 # All models will use the full 5-variable input

# --- 2. DATASET & LOSS FUNCTION ---
class MultiVariableARDataset(Dataset):
    """Paired ``*_predictor.npy`` / ``*_target.npy`` samples, z-score normalised on load.

    Each predictor file found in ``data_dir / split`` is paired with the target
    file that shares its name prefix. Normalisation statistics are read from
    ``data_dir / 'normalization_stats_multi_variable.joblib'`` and must provide
    the keys ``predictor_mean``, ``predictor_std``, ``target_mean`` and
    ``target_std``.

    Args:
        data_dir: Dataset root containing the split folders and the stats file.
        split: Name of the sub-folder to read samples from.

    Raises:
        FileNotFoundError: If the split folder contains no predictor files.
    """

    def __init__(self, data_dir: Path, split: str = 'train'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        # NOTE: joblib.load unpickles the file; only point this at trusted data.
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if not self.predictor_files:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor, target)`` as float32 tensors; the target gains a leading axis."""
        pred_path = Path(self.predictor_files[idx])
        # Swap the suffix on the file name only (last occurrence), so a parent
        # directory that happens to contain '_predictor.npy' is left untouched.
        target_name = '_target.npy'.join(pred_path.name.rsplit('_predictor.npy', 1))
        targ_path = pred_path.with_name(target_name)

        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)

        predictor_norm = (predictor_data - self.stats['predictor_mean'][:, None, None]) / (self.stats['predictor_std'][:, None, None] + 1e-8)
        target_norm = (target_data - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)

        # float64 statistics promote the results to float64; cast back so the
        # tensors always match the float32 model weights.
        predictor_norm = np.asarray(predictor_norm, dtype=np.float32)
        target_norm = np.asarray(target_norm, dtype=np.float32)
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_norm).unsqueeze(0)

class AdvancedLoss(nn.Module):
    """Weighted sum of MSE, ``1 - SSIM`` and an L1 loss on Sobel gradients.

    Half-precision inputs (float16 / bfloat16) are promoted to float32 before
    the loss terms are evaluated: the SSIM statistics square the inputs, which
    overflows float16 for magnitudes above roughly 255.

    Args:
        mse_weight: Weight of the mean-squared-error term.
        ssim_weight: Weight of the structural-similarity term.
        gradient_weight: Weight of the Sobel-gradient term.
    """

    def __init__(self, mse_weight=0.6, ssim_weight=0.2, gradient_weight=0.2):
        super().__init__()
        self.mse_weight, self.ssim_weight, self.gradient_weight = mse_weight, ssim_weight, gradient_weight
        self.mse_loss, self.l1_loss = nn.MSELoss(), nn.L1Loss()
        # Build the Sobel kernels once instead of on every call. They are
        # non-persistent buffers, so they follow .to()/.cuda() but never end up
        # in a state_dict.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', sobel_x.view(1, 1, 3, 3), persistent=False)
        self.register_buffer('sobel_y', sobel_y.view(1, 1, 3, 3), persistent=False)

    @staticmethod
    def _promote(pred, target):
        """Cast both tensors to their common dtype, using float32 instead of half types."""
        dtype = torch.promote_types(pred.dtype, target.dtype)
        if dtype in (torch.float16, torch.bfloat16):
            dtype = torch.float32
        return pred.to(dtype), target.to(dtype)

    def ssim_loss(self, pred, target, window_size=11):
        """Return ``1 - mean(SSIM)`` using a uniform ``window_size`` averaging window."""
        pred, target = self._promote(pred, target)
        pad = window_size // 2

        def local_mean(t):
            return F.avg_pool2d(t, window_size, 1, padding=pad)

        mu1, mu2 = local_mean(pred), local_mean(target)
        mu1_sq, mu2_sq, mu1_mu2 = mu1**2, mu2**2, mu1 * mu2
        sigma1_sq = local_mean(pred**2) - mu1_sq
        sigma2_sq = local_mean(target**2) - mu2_sq
        sigma12 = local_mean(pred*target) - mu1_mu2
        # Stabilising constants of the SSIM formula.
        c1, c2 = 0.01**2, 0.03**2
        ssim_map = ((2*mu1_mu2 + c1) * (2*sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
        return 1 - ssim_map.mean()

    def gradient_loss(self, pred, target):
        """Return the L1 distance between the Sobel x/y responses of ``pred`` and ``target``."""
        pred, target = self._promote(pred, target)
        # One depthwise kernel per channel, so multi-channel maps work too; for
        # single-channel input this is the plain 1x1x3x3 convolution.
        channels = pred.shape[-3]
        kernel_x = self.sobel_x.to(device=pred.device, dtype=pred.dtype).repeat(channels, 1, 1, 1)
        kernel_y = self.sobel_y.to(device=pred.device, dtype=pred.dtype).repeat(channels, 1, 1, 1)

        def sobel_pair(img):
            return (F.conv2d(img, kernel_x, padding=1, groups=channels),
                    F.conv2d(img, kernel_y, padding=1, groups=channels))

        pred_grad_x, pred_grad_y = sobel_pair(pred)
        target_grad_x, target_grad_y = sobel_pair(target)
        return self.l1_loss(pred_grad_x, target_grad_x) + self.l1_loss(pred_grad_y, target_grad_y)

    def forward(self, pred, target):
        """Return the scalar weighted combination of the three loss terms."""
        pred, target = self._promote(pred, target)
        mse = self.mse_loss(pred, target)
        ssim = self.ssim_loss(pred, target)
        grad = self.gradient_loss(pred, target)
        return (self.mse_weight * mse + self.ssim_weight * ssim + self.gradient_weight * grad)

# --- 3. MODEL ARCHITECTURES ---
# (Note: Final activation is NOT Tanh for regression models)
class SimplifiedAttention(nn.Module):
    """Channel gate followed by a spatial gate, both sigmoid-activated.

    The channel gate squeezes the feature map to one value per channel, passes
    it through a two-layer 1x1-conv bottleneck (reduction factor 16) and
    rescales the channels. The spatial gate is a single 7x7 convolution that
    produces one weight per pixel.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Keep at least one hidden channel: `channels // 16` is 0 for fewer than
        # 16 channels, which would create an empty bottleneck. For
        # channels >= 16 the layer shapes are unchanged.
        hidden = max(channels // 16, 1)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        """Apply the channel gate, then the spatial gate computed from the gated map."""
        x = x * self.channel_att(x)
        return x * self.spatial_att(x)

class OriginalUNet(nn.Module): # This is your MultiVarAdvancedDownscaler
    """U-Net style encoder/decoder with an optional attention block at the bottleneck.

    Args:
        input_channels: Number of channels of the input tensor.
        base_channels: Width of the first encoder stage; stage ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: Number of encoder stages (and of decoder stages).
        use_attention: Whether a ``SimplifiedAttention`` block follows the bottleneck.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]

        # Encoder: one conv block per stage, with a strided conv between stages.
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        last_stage = len(self.channels) - 1
        prev_ch = input_channels
        for stage, width in enumerate(self.channels):
            self.encoders.append(self._conv_block(prev_ch, width))
            if stage < last_stage:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            prev_ch = width

        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # Decoder: walk the stages from deepest to shallowest.
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        for stage in reversed(range(depth)):
            up_in = bottleneck_ch if stage == depth - 1 else self.channels[stage + 1]
            up_out = self.channels[stage]
            self.upsamplers.append(nn.ConvTranspose2d(up_in, up_out, 2, stride=2))
            # The decoder block receives the upsampled map concatenated with its skip.
            self.decoders.append(self._conv_block(up_out * 2, up_out))

        head_ch = self.channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(self.channels[0], head_ch, 3, padding=1),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_ch, 1, 1),
        )

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + BatchNorm + ReLU layers mapping ``in_ch`` to ``out_ch`` channels."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return a single-channel map with the same spatial size as ``x``."""
        skips = []
        n_down = len(self.downsamplers)
        for stage, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if stage < n_down:
                x = self.downsamplers[stage](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Decoder stage k is paired with the k-th skip counted from the deepest.
        # The deepest skip was never downsampled, so the first upsampled map is
        # larger than its skip and gets resized back by the interpolate branch.
        for up, dec, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = up(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = dec(torch.cat([x, skip], dim=1))
        return self.final_conv(x)

class ResNetBlock(nn.Module):
    """Residual block: two bias-free 3x3 conv + BatchNorm layers with an identity shortcut.

    Args:
        channels: Number of input and output channels (the block keeps the width).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        shortcut = x
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y += shortcut
        return self.relu(y)

class ResNetUNet(nn.Module):
    """U-Net whose encoder, bottleneck and decoder stages are built from ``ResNetBlock`` pairs.

    Args:
        input_channels: Number of channels of the input tensor.
        base_channels: Width of the first stage; stage ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: Number of encoder stages; the decoder has ``depth - 1`` stages.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4):
        super().__init__()
        self.depth = depth
        channels = [base_channels * min(2**i, 8) for i in range(depth)]

        # Stem that lifts the input to the width of the first stage.
        self.initial_conv = nn.Conv2d(input_channels, channels[0], 3, padding=1)

        # Encoder: the strided conv between stages also widens the features.
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        last_stage = len(channels) - 1
        for stage, width in enumerate(channels):
            self.encoders.append(nn.Sequential(ResNetBlock(width), ResNetBlock(width)))
            if stage < last_stage:
                self.downsamplers.append(nn.Conv2d(width, channels[stage + 1], 3, stride=2, padding=1))

        deepest = channels[-1]
        self.bottleneck = nn.Sequential(ResNetBlock(deepest), ResNetBlock(deepest), ResNetBlock(deepest))

        # Decoder: upsample, then fuse with the skip in two residual blocks and
        # a 1x1 conv that halves the concatenated width again.
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        for stage in range(depth - 1, 0, -1):
            out_width = channels[stage - 1]
            fused = out_width * 2
            self.upsamplers.append(nn.ConvTranspose2d(channels[stage], out_width, 2, stride=2))
            self.decoders.append(nn.Sequential(ResNetBlock(fused), ResNetBlock(fused), nn.Conv2d(fused, out_width, 1)))

        head_ch = channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(channels[0], head_ch, 3, padding=1),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_ch, 1, 1),
        )

    def forward(self, x):
        """Return a single-channel map with the same spatial size as ``x``."""
        x = self.initial_conv(x)
        skips = []
        # zip stops at the shorter list, so the last encoder is handled below.
        for encoder, downsampler in zip(self.encoders, self.downsamplers):
            x = encoder(x)
            skips.append(x)
            x = downsampler(x)
        x = self.encoders[-1](x)
        skips.append(x)
        x = self.bottleneck(x)

        # The deepest skip is not concatenated; decoder stage k uses the skip
        # one level above the previous one, starting at the second deepest.
        for upsampler, decoder, skip in zip(self.upsamplers, self.decoders, reversed(skips[:-1])):
            x = upsampler(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decoder(torch.cat([x, skip], dim=1))
        return self.final_conv(x)

class LightweightCNN(nn.Module):
    """Small plain encoder/decoder CNN without skip connections.

    Three conv + max-pool stages reduce the resolution by a factor of 8, a
    fourth conv stage widens to 256 channels, and three transposed convs bring
    the resolution back up before a final 3x3 conv emits one channel.

    Args:
        input_channels: Number of channels of the input tensor.
    """

    def __init__(self, input_channels=5):
        super().__init__()
        layers = []
        width = input_channels

        # (output channels, kernel size, followed by 2x2 max-pool?)
        for out_ch, kernel, pooled in ((32, 7, True), (64, 5, True), (128, 3, True), (256, 3, False)):
            layers += [
                nn.Conv2d(width, out_ch, kernel, padding=kernel // 2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
            if pooled:
                layers.append(nn.MaxPool2d(2))
            width = out_ch

        # Each transposed conv doubles the spatial resolution.
        for out_ch in (128, 64, 32):
            layers += [nn.ConvTranspose2d(width, out_ch, 4, stride=2, padding=1), nn.ReLU(inplace=True)]
            width = out_ch

        layers.append(nn.Conv2d(width, 1, 3, padding=1)) # No Tanh
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Return a single-channel map resized, if needed, to the spatial size of ``x``."""
        target_size = x.shape[-2:]
        y = self.features(x)
        if y.shape[-2:] == target_size:
            return y
        # Sizes not divisible by 8 come back smaller from the pool/upsample round trip.
        return F.interpolate(y, size=target_size, mode='bilinear', align_corners=False)

# --- 4. ARCHITECTURE CONFIGURATION ---
# Maps a run name to the model class and the keyword arguments used to build it.
# The run name is also used for the checkpoint/metadata file names.
# NOTE(review): 'original_unet' relies on OriginalUNet's default use_attention=True,
# so it builds exactly the same network as 'attention_unet'; the two runs differ
# only through random initialisation and data shuffling. If a no-attention
# baseline was intended, one of them needs 'use_attention': False - confirm.
ARCHITECTURE_CONFIGS = {
    'original_unet': {'class': OriginalUNet, 'params': {'input_channels': INPUT_CHANNELS}},
    'resnet_unet': {'class': ResNetUNet, 'params': {'input_channels': INPUT_CHANNELS}},
    'attention_unet': {'class': OriginalUNet, 'params': {'input_channels': INPUT_CHANNELS, 'use_attention': True}}, # Your original is an attention unet
    'lightweight_cnn': {'class': LightweightCNN, 'params': {'input_channels': INPUT_CHANNELS}},
}

# --- 5. CONSISTENT TRAINING FUNCTION ---
def train_one_architecture(config):
    """Train one architecture with the shared hyperparameters and keep its best checkpoint.

    The model is trained with AdamW and ``AdvancedLoss`` for at most ``EPOCHS``
    epochs, using mixed precision when CUDA is available. After every epoch the
    mean validation loss is computed (in full precision); whenever it improves,
    the weights are written to ``OUTPUT_DIR / f"arch_{name}.pt"`` together with a
    JSON file holding the name, class name and constructor parameters. Training
    stops early after ``EARLY_STOPPING_PATIENCE`` epochs without improvement.

    Args:
        config: Dict with the keys ``'name'`` (run name), ``'class'`` (model
            class) and ``'params'`` (keyword arguments for the constructor;
            must be JSON-serialisable).
    """
    arch_name = config['name']

    print("\n" + "="*80)
    print(f"🏛️ STARTING ARCHITECTURE RUN: {arch_name}")
    print("="*80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    train_dataset = MultiVariableARDataset(DATA_DIR, 'train')
    val_dataset = MultiVariableARDataset(DATA_DIR, 'val')
    # Pinned host memory only helps (and is only available) with a CUDA device.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_cuda)

    model = config['class'](**config['params']).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fn = AdvancedLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [{arch_name}]")
        for predictor, target in pbar:
            predictor, target = predictor.to(device), target.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_cuda):
                prediction = model(predictor)
                loss = loss_fn(prediction, target)
            scaler.scale(loss).backward()
            # The gradients are still multiplied by the loss scale here. Unscale
            # them first, otherwise GRADIENT_CLIP is compared against the scaled
            # norm and the effective step is crushed. No-op when AMP is disabled.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            # step() skips the update if the unscaled gradients contain inf/NaN.
            scaler.step(optimizer)
            scaler.update()
            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for predictor, target in val_loader:
                predictor, target = predictor.to(device), target.to(device)
                prediction = model(predictor)
                val_loss += loss_fn(prediction, target).item()

        # Mean of the per-batch losses.
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch+1} | Val Loss: {avg_val_loss:.6f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            model_path = OUTPUT_DIR / f"arch_{arch_name}.pt"
            metadata_path = OUTPUT_DIR / f"arch_{arch_name}.json"
            torch.save(model.state_dict(), model_path)
            with open(metadata_path, 'w') as f:
                # Save info needed to reload the model later
                json.dump({
                    'name': arch_name,
                    'class': model.__class__.__name__,
                    'params': config['params']
                }, f, indent=2)
            print(f"   ✅ Best model saved to {model_path}")
        else:
            patience_counter += 1

        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"   🛑 Early stopping triggered.")
            break

    print(f"🏛️ FINISHED ARCHITECTURE RUN: {arch_name} | Best Val Loss: {best_val_loss:.6f}")
    # Release Python references and cached GPU blocks before the next architecture.
    gc.collect()
    torch.cuda.empty_cache()

# --- 6. MAIN EXECUTION SCRIPT ---
def main():
    """Train every entry of ARCHITECTURE_CONFIGS in turn, tagging each with its run name."""
    for arch_name in ARCHITECTURE_CONFIGS:
        arch_config = ARCHITECTURE_CONFIGS[arch_name]
        # The training function reads the run name from the config itself.
        arch_config['name'] = arch_name
        train_one_architecture(arch_config)

    print("\n\n🎉🎉🎉 All architecture models trained successfully! 🎉🎉🎉")


if __name__ == "__main__":
    main()



🏛️ STARTING ARCHITECTURE RUN: original_unet


Epoch 1/25 [original_unet]: 100%|██████████| 150/150 [03:52<00:00,  1.55s/it, loss=0.6161]


Epoch 1 | Val Loss: 0.640447
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_original_unet.pt


Epoch 2/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=1.0256]


Epoch 2 | Val Loss: 14.177297


Epoch 3/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.79it/s, loss=2.3693]


Epoch 3 | Val Loss: 2.716541


Epoch 4/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.79it/s, loss=1.8890]


Epoch 4 | Val Loss: 2.593938


Epoch 5/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.79it/s, loss=2.0437]


Epoch 5 | Val Loss: 5.336793


Epoch 6/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=3.2288]


Epoch 6 | Val Loss: 6.151578


Epoch 7/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.81it/s, loss=3.8546]


Epoch 7 | Val Loss: 6.281056


Epoch 8/25 [original_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=3.4636]


Epoch 8 | Val Loss: 8.409728
   🛑 Early stopping triggered.
🏛️ FINISHED ARCHITECTURE RUN: original_unet | Best Val Loss: 0.640447

🏛️ STARTING ARCHITECTURE RUN: resnet_unet


Epoch 1/25 [resnet_unet]: 100%|██████████| 150/150 [01:35<00:00,  1.58it/s, loss=0.7202]


Epoch 1 | Val Loss: 0.815562
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_resnet_unet.pt


Epoch 2/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.0095]


Epoch 2 | Val Loss: 1.556403


Epoch 3/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.9695]


Epoch 3 | Val Loss: 2.878543


Epoch 4/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.2798]


Epoch 4 | Val Loss: 2.071224


Epoch 5/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.1761]


Epoch 5 | Val Loss: 1.065803


Epoch 6/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=0.7664]


Epoch 6 | Val Loss: 0.894059


Epoch 7/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.3900]


Epoch 7 | Val Loss: 1.285973


Epoch 8/25 [resnet_unet]: 100%|██████████| 150/150 [01:15<00:00,  1.98it/s, loss=1.0887]


Epoch 8 | Val Loss: 1.120279
   🛑 Early stopping triggered.
🏛️ FINISHED ARCHITECTURE RUN: resnet_unet | Best Val Loss: 0.815562

🏛️ STARTING ARCHITECTURE RUN: attention_unet


Epoch 1/25 [attention_unet]: 100%|██████████| 150/150 [00:48<00:00,  3.08it/s, loss=0.5930]


Epoch 1 | Val Loss: 57.362974
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_attention_unet.pt


Epoch 2/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=0.9696]


Epoch 2 | Val Loss: 1.586533
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_attention_unet.pt


Epoch 3/25 [attention_unet]: 100%|██████████| 150/150 [00:32<00:00,  4.66it/s, loss=1.5181]


Epoch 3 | Val Loss: 4.686310


Epoch 4/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=1.7635]


Epoch 4 | Val Loss: 4.739056


Epoch 5/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.79it/s, loss=2.0737]


Epoch 5 | Val Loss: 3.381161


Epoch 6/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=2.2166]


Epoch 6 | Val Loss: 7.329491


Epoch 7/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.81it/s, loss=2.3247]


Epoch 7 | Val Loss: 1.960474


Epoch 8/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=2.0185]


Epoch 8 | Val Loss: 1.812630


Epoch 9/25 [attention_unet]: 100%|██████████| 150/150 [00:31<00:00,  4.80it/s, loss=2.4235]


Epoch 9 | Val Loss: 1.949027
   🛑 Early stopping triggered.
🏛️ FINISHED ARCHITECTURE RUN: attention_unet | Best Val Loss: 1.586533

🏛️ STARTING ARCHITECTURE RUN: lightweight_cnn


Epoch 1/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:23<00:00,  6.32it/s, loss=0.6553]


Epoch 1 | Val Loss: 0.645011
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_lightweight_cnn.pt


Epoch 2/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.06it/s, loss=0.7888]


Epoch 2 | Val Loss: 0.629705
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_lightweight_cnn.pt


Epoch 3/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 19.97it/s, loss=0.5621]


Epoch 3 | Val Loss: 0.634293


Epoch 4/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.52it/s, loss=0.4842]


Epoch 4 | Val Loss: 0.631113


Epoch 5/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 19.84it/s, loss=0.6128]


Epoch 5 | Val Loss: 0.614137
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_lightweight_cnn.pt


Epoch 6/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.47it/s, loss=0.6296]


Epoch 6 | Val Loss: 0.618693


Epoch 7/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 19.73it/s, loss=0.5259]


Epoch 7 | Val Loss: 0.629660


Epoch 8/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.00it/s, loss=0.9393]


Epoch 8 | Val Loss: 0.616834


Epoch 9/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.19it/s, loss=1.0355]


Epoch 9 | Val Loss: 0.629355


Epoch 10/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 21.00it/s, loss=0.7389]


Epoch 10 | Val Loss: 0.612613
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_lightweight_cnn.pt


Epoch 11/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 19.97it/s, loss=0.6777]


Epoch 11 | Val Loss: 0.606358
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/architecture_study_models/arch_lightweight_cnn.pt


Epoch 12/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.23it/s, loss=0.8178]


Epoch 12 | Val Loss: 0.609687


Epoch 13/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.01it/s, loss=0.4720]


Epoch 13 | Val Loss: 0.630632


Epoch 14/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.18it/s, loss=0.6124]


Epoch 14 | Val Loss: 0.607495


Epoch 15/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.11it/s, loss=0.6216]


Epoch 15 | Val Loss: 0.606666


Epoch 16/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 19.94it/s, loss=0.6403]


Epoch 16 | Val Loss: 0.610036


Epoch 17/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.73it/s, loss=0.6624]


Epoch 17 | Val Loss: 0.627995


Epoch 18/25 [lightweight_cnn]: 100%|██████████| 150/150 [00:07<00:00, 20.09it/s, loss=0.4933]


Epoch 18 | Val Loss: 0.608713
   🛑 Early stopping triggered.
🏛️ FINISHED ARCHITECTURE RUN: lightweight_cnn | Best Val Loss: 0.606358


🎉🎉🎉 All architecture models trained successfully! 🎉🎉🎉


In [None]:
#
# This script performs a comprehensive, quantitative evaluation comparing the
# results of the architectural generalization study.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import pandas as pd
from scipy.stats import wasserstein_distance
from scipy.ndimage import sobel
from torchvision.transforms import Resize
import json
import math

warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# Root folder of the project; every other path below is derived from it.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
# Dataset root: holds the per-split sub-folders and the normalisation stats file.
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
MODEL_DIR = PROJECT_PATH / 'architecture_study_models' # Directory where arch models were saved
# Output folder for the evaluation results.
OUTPUT_DIR = PROJECT_PATH / 'final_evaluation_results'
# NOTE: `parents` is not set, so PROJECT_PATH itself must already exist.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Evaluation Configuration ---
# Thresholds for the CSI metric; a pixel counts as an event when its value is
# <= the threshold. Presumably in Kelvin, going by the `_K` suffix - confirm.
CSI_THRESHOLDS_K = [230.0, 220.0, 210.0]
IMG_SIZE = 256 # All these models were trained on 256x256

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """Paired ``*_predictor.npy`` / ``*_target.npy`` samples, z-score normalised on load.

    Each predictor file found in ``data_dir / split`` is paired with the target
    file that shares its name prefix. Normalisation statistics are read from
    ``data_dir / 'normalization_stats_multi_variable.joblib'`` and must provide
    the keys ``predictor_mean``, ``predictor_std``, ``target_mean`` and
    ``target_std``. The loaded statistics stay available as ``self.stats``.

    Args:
        data_dir: Dataset root containing the split folders and the stats file.
        split: Name of the sub-folder to read samples from.

    Raises:
        FileNotFoundError: If the split folder contains no predictor files.
    """

    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        # NOTE: joblib.load unpickles the file; only point this at trusted data.
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if not self.predictor_files:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor, target)`` as float32 tensors; the target gains a leading axis."""
        pred_path = Path(self.predictor_files[idx])
        # Swap the suffix on the file name only (last occurrence), so a parent
        # directory that happens to contain '_predictor.npy' is left untouched.
        target_name = '_target.npy'.join(pred_path.name.rsplit('_predictor.npy', 1))
        targ_path = pred_path.with_name(target_name)

        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)

        predictor_norm = (predictor_data - self.stats['predictor_mean'][:, None, None]) / (self.stats['predictor_std'][:, None, None] + 1e-8)
        target_norm = (target_data - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)

        # float64 statistics promote the results to float64; cast back so the
        # tensors always match the float32 model weights.
        predictor_norm = np.asarray(predictor_norm, dtype=np.float32)
        target_norm = np.asarray(target_norm, dtype=np.float32)
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_norm).unsqueeze(0)

class SimplifiedAttention(nn.Module):
    """Channel gate followed by a spatial gate, both sigmoid-activated.

    The channel gate squeezes the feature map to one value per channel, passes
    it through a two-layer 1x1-conv bottleneck (reduction factor 16) and
    rescales the channels. The spatial gate is a single 7x7 convolution that
    produces one weight per pixel.

    Args:
        channels: Number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Keep at least one hidden channel: `channels // 16` is 0 for fewer than
        # 16 channels, which would create an empty bottleneck. For
        # channels >= 16 the layer shapes (and saved checkpoints) are unchanged.
        hidden = max(channels // 16, 1)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        """Apply the channel gate, then the spatial gate computed from the gated map."""
        x = x * self.channel_att(x)
        return x * self.spatial_att(x)

class OriginalUNet(nn.Module):
    """U-Net style encoder/decoder with an optional attention block at the bottleneck.

    Args:
        input_channels: Number of channels of the input tensor.
        base_channels: Width of the first encoder stage; stage ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: Number of encoder stages (and of decoder stages).
        use_attention: Whether a ``SimplifiedAttention`` block follows the bottleneck.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]

        # Encoder: one conv block per stage, with a strided conv between stages.
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        last_stage = len(self.channels) - 1
        prev_ch = input_channels
        for stage, width in enumerate(self.channels):
            self.encoders.append(self._conv_block(prev_ch, width))
            if stage < last_stage:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            prev_ch = width

        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # Decoder: walk the stages from deepest to shallowest.
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        for stage in reversed(range(depth)):
            up_in = bottleneck_ch if stage == depth - 1 else self.channels[stage + 1]
            up_out = self.channels[stage]
            self.upsamplers.append(nn.ConvTranspose2d(up_in, up_out, 2, stride=2))
            # The decoder block receives the upsampled map concatenated with its skip.
            self.decoders.append(self._conv_block(up_out * 2, up_out))

        head_ch = self.channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(self.channels[0], head_ch, 3, padding=1),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_ch, 1, 1),
        )

    def _conv_block(self, in_ch, out_ch):
        """Two 3x3 conv + BatchNorm + ReLU layers mapping ``in_ch`` to ``out_ch`` channels."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return a single-channel map with the same spatial size as ``x``."""
        skips = []
        n_down = len(self.downsamplers)
        for stage, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if stage < n_down:
                x = self.downsamplers[stage](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Decoder stage k is paired with the k-th skip counted from the deepest.
        # The deepest skip was never downsampled, so the first upsampled map is
        # larger than its skip and gets resized back by the interpolate branch.
        for up, dec, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = up(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = dec(torch.cat([x, skip], dim=1))
        return self.final_conv(x)

class ResNetBlock(nn.Module):
    """Residual block: two bias-free 3x3 conv + BatchNorm layers with an identity shortcut.

    Args:
        channels: Number of input and output channels (the block keeps the width).
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        shortcut = x
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y += shortcut
        return self.relu(y)

class ResNetUNet(nn.Module):
    """U-Net whose encoder, bottleneck and decoder stages are built from ``ResNetBlock`` pairs.

    Args:
        input_channels: Number of channels of the input tensor.
        base_channels: Width of the first stage; stage ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: Number of encoder stages; the decoder has ``depth - 1`` stages.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4):
        super().__init__()
        self.depth = depth
        channels = [base_channels * min(2**i, 8) for i in range(depth)]

        # Stem that lifts the input to the width of the first stage.
        self.initial_conv = nn.Conv2d(input_channels, channels[0], 3, padding=1)

        # Encoder: the strided conv between stages also widens the features.
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        last_stage = len(channels) - 1
        for stage, width in enumerate(channels):
            self.encoders.append(nn.Sequential(ResNetBlock(width), ResNetBlock(width)))
            if stage < last_stage:
                self.downsamplers.append(nn.Conv2d(width, channels[stage + 1], 3, stride=2, padding=1))

        deepest = channels[-1]
        self.bottleneck = nn.Sequential(ResNetBlock(deepest), ResNetBlock(deepest), ResNetBlock(deepest))

        # Decoder: upsample, then fuse with the skip in two residual blocks and
        # a 1x1 conv that halves the concatenated width again.
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        for stage in range(depth - 1, 0, -1):
            out_width = channels[stage - 1]
            fused = out_width * 2
            self.upsamplers.append(nn.ConvTranspose2d(channels[stage], out_width, 2, stride=2))
            self.decoders.append(nn.Sequential(ResNetBlock(fused), ResNetBlock(fused), nn.Conv2d(fused, out_width, 1)))

        head_ch = channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(channels[0], head_ch, 3, padding=1),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_ch, 1, 1),
        )

    def forward(self, x):
        """Return a single-channel map with the same spatial size as ``x``."""
        x = self.initial_conv(x)
        skips = []
        # zip stops at the shorter list, so the last encoder is handled below.
        for encoder, downsampler in zip(self.encoders, self.downsamplers):
            x = encoder(x)
            skips.append(x)
            x = downsampler(x)
        x = self.encoders[-1](x)
        skips.append(x)
        x = self.bottleneck(x)

        # The deepest skip is not concatenated; decoder stage k uses the skip
        # one level above the previous one, starting at the second deepest.
        for upsampler, decoder, skip in zip(self.upsamplers, self.decoders, reversed(skips[:-1])):
            x = upsampler(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decoder(torch.cat([x, skip], dim=1))
        return self.final_conv(x)

class LightweightCNN(nn.Module):
    """Small plain encoder/decoder CNN without skip connections.

    Three conv + max-pool stages reduce the resolution by a factor of 8, a
    fourth conv stage widens to 256 channels, and three transposed convs bring
    the resolution back up before a final 3x3 conv emits one channel.

    Args:
        input_channels: Number of channels of the input tensor.
    """

    def __init__(self, input_channels=5):
        super().__init__()
        layers = []
        width = input_channels

        # (output channels, kernel size, followed by 2x2 max-pool?)
        for out_ch, kernel, pooled in ((32, 7, True), (64, 5, True), (128, 3, True), (256, 3, False)):
            layers += [
                nn.Conv2d(width, out_ch, kernel, padding=kernel // 2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
            if pooled:
                layers.append(nn.MaxPool2d(2))
            width = out_ch

        # Each transposed conv doubles the spatial resolution.
        for out_ch in (128, 64, 32):
            layers += [nn.ConvTranspose2d(width, out_ch, 4, stride=2, padding=1), nn.ReLU(inplace=True)]
            width = out_ch

        layers.append(nn.Conv2d(width, 1, 3, padding=1))
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Return a single-channel map resized, if needed, to the spatial size of ``x``."""
        target_size = x.shape[-2:]
        y = self.features(x)
        if y.shape[-2:] == target_size:
            return y
        # Sizes not divisible by 8 come back smaller from the pool/upsample round trip.
        return F.interpolate(y, size=target_size, mode='bilinear', align_corners=False)

# --- 3. METRIC & MODEL LOADING FUNCTIONS ---
def calculate_detailed_csi(pred_k, true_k, threshold_k):
    """Return the Critical Success Index and the observed event frequency.

    A pixel is an *event* when its value is ``<= threshold_k``.
    ``CSI = hits / (hits + misses + false_alarms)``.

    Args:
        pred_k: Predicted field (array-like).
        true_k: Reference field (array-like), broadcastable against ``pred_k``.
        threshold_k: Event threshold, in the same units as the fields.

    Returns:
        ``(csi, event_frequency)`` as Python floats. ``csi`` is 0.0 when there
        is neither a predicted nor an observed event (the score is undefined
        there); ``event_frequency`` is the fraction of reference pixels that
        are events, or 0.0 for empty input.
    """
    # np.asarray lets lists and other array-likes through, not just ndarrays.
    pred_mask = np.asarray(pred_k) <= threshold_k
    true_mask = np.asarray(true_k) <= threshold_k

    hits = int(np.count_nonzero(pred_mask & true_mask))
    misses = int(np.count_nonzero(~pred_mask & true_mask))
    false_alarms = int(np.count_nonzero(pred_mask & ~true_mask))

    scored = hits + misses + false_alarms
    # NOTE(review): a scene with no events at all scores 0.0, the same as a
    # complete miss, which lowers any mean taken over per-sample CSI values.
    csi = hits / scored if scored > 0 else 0.0

    # Guard the empty case explicitly instead of dividing 0 by 0.
    n_pixels = true_mask.size
    event_frequency = int(np.count_nonzero(true_mask)) / n_pixels if n_pixels > 0 else 0.0
    return csi, event_frequency

def calculate_gradient_magnitude(image):
    """Return the mean Sobel gradient magnitude of a 2-D field (a simple sharpness score).

    Args:
        image: 2-D array-like. Integer input is converted to float64 first,
            because ``sobel`` writes its result in the input dtype and integer
            gradients would wrap around; floating-point input keeps its dtype.

    Returns:
        The mean of ``sqrt(g0**2 + g1**2)`` over all pixels, where ``g0`` and
        ``g1`` are the Sobel derivatives along axis 0 and axis 1.
    """
    field = np.asarray(image)
    if not np.issubdtype(field.dtype, np.floating):
        field = field.astype(np.float64)
    grad_axis0 = sobel(field, axis=0)
    grad_axis1 = sobel(field, axis=1)
    return np.sqrt(grad_axis0**2 + grad_axis1**2).mean()

def calculate_all_metrics(pred_norm, true_norm, stats):
    """Compute all per-sample evaluation metrics in physical (de-normalised) units.

    Args:
        pred_norm: Normalised prediction array.
        true_norm: Normalised reference array.
        stats: Mapping providing ``'target_mean'`` and ``'target_std'``.

    Returns:
        Dict with ``rmse``, ``mae``, ``sharpness`` (of the prediction),
        ``distribution_dist`` (Wasserstein distance between the value
        distributions) and, for every threshold in ``CSI_THRESHOLDS_K``, the
        entries ``csi_<T>K`` and ``freq_<T>K``.
    """
    # Undo the z-score normalisation with the same epsilon used when normalising.
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    pred_k = pred_norm * scale + offset
    true_k = true_norm * scale + offset

    metrics = {}
    metrics['rmse'] = np.sqrt(np.mean((pred_k - true_k)**2))
    metrics['mae'] = np.mean(np.abs(pred_k - true_k))
    metrics['sharpness'] = calculate_gradient_magnitude(pred_k)
    metrics['distribution_dist'] = wasserstein_distance(pred_k.flatten(), true_k.flatten())

    for thr in CSI_THRESHOLDS_K:
        label = int(thr)
        csi, freq = calculate_detailed_csi(pred_k, true_k, thr)
        metrics[f'csi_{label}K'] = csi
        metrics[f'freq_{label}K'] = freq
    return metrics

def get_model_class_from_name(class_name):
    """Map a class name stored in the metadata JSON back to the model class.

    Raises:
        ValueError: If ``class_name`` is not one of the known model classes.
    """
    # Thunks keep the lookup lazy: only the requested class name is resolved.
    resolvers = {
        'OriginalUNet': lambda: OriginalUNet,
        'ResNetUNet': lambda: ResNetUNet,
        'LightweightCNN': lambda: LightweightCNN,
    }
    resolver = resolvers.get(class_name)
    if resolver is None:
        raise ValueError(f"Unknown model class name: {class_name}")
    return resolver()

# --- 4. MAIN EVALUATION SCRIPT ---
def _load_study_models(metadata_files, device):
    """Instantiate and load every model described by the given metadata files.

    Each metadata JSON must provide ``name``, ``class`` and ``params``. The
    weights are read from the file next to it with the same stem and a ``.pt``
    suffix; entries whose weight file is missing are skipped with a warning.

    Args:
        metadata_files: Iterable of paths to ``arch_*.json`` metadata files.
        device: Torch device the models are moved to.

    Returns:
        Dict mapping model name -> model in eval mode, in the order of
        ``metadata_files``.

    Raises:
        ValueError: Propagated from ``get_model_class_from_name`` for an
            unknown class name.
    """
    loaded_models = {}
    for meta_file in metadata_files:
        with open(meta_file, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        name = metadata['name']
        model_class_name = metadata['class']
        params = metadata['params']
        model_path = meta_file.with_suffix('.pt')

        if not model_path.exists():
            print(f"⚠️ Model file not found for '{name}'. Skipping.")
            continue

        print(f"Loading model: '{name}' (Class: {model_class_name})")
        ModelClass = get_model_class_from_name(model_class_name)
        model = ModelClass(**params).to(device)
        # NOTE(review): torch.load unpickles the checkpoint, so only load files
        # from a trusted source (consider weights_only=True if the installed
        # torch version supports it).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        loaded_models[name] = model
    return loaded_models


def _evaluate_model(name, model, test_loader, stats, device):
    """Run one model over the test loader and summarise its per-sample metrics.

    Returns:
        Dict with ``"Model"`` plus ``<metric>_mean`` and ``<metric>_std`` for
        every metric returned by ``calculate_all_metrics``.
    """
    print(f"\n--- Evaluating {name} ---")
    model_metrics = []
    with torch.no_grad():
        for predictor, target in tqdm(test_loader, desc=f"Predicting with {name}"):
            predictor = predictor.to(device)
            # squeeze() drops singleton dims (the loader uses batch_size=1;
            # presumably leaving a 2-D field per sample — TODO confirm).
            pred_norm = model(predictor).cpu().numpy().squeeze()
            target_norm = target.numpy().squeeze()
            model_metrics.append(calculate_all_metrics(pred_norm, target_norm, stats))

    df = pd.DataFrame(model_metrics)
    summary = {"Model": name}
    for col in df.columns:
        summary[f"{col}_mean"] = df[col].mean()
        summary[f"{col}_std"] = df[col].std()
    return summary


def _print_evaluation_report(results_df):
    """Print the overall-metrics table and the per-threshold CSI table."""
    print("\n\n" + "="*80)
    print("🏛️ FINAL ARCHITECTURE EVALUATION REPORT 🏛️")
    print("="*80)
    print("\n--- I. Overall Performance Metrics ---")
    std_metrics_df = results_df[['rmse_mean', 'mae_mean', 'sharpness_mean', 'distribution_dist_mean']]
    print(std_metrics_df.round(4))

    print("\n--- II. Critical Success Index (CSI) by Threshold ---")

    # Select only the '_mean' columns and give them two-level display headers
    # (threshold on top, CSI / event frequency below).
    csi_cols_ordered = []
    pretty_cols = []
    for thr in CSI_THRESHOLDS_K:
        csi_cols_ordered.append(f'csi_{int(thr)}K_mean')
        csi_cols_ordered.append(f'freq_{int(thr)}K_mean')
        pretty_cols.append((f"T <= {int(thr)}K", "CSI"))
        pretty_cols.append((f"T <= {int(thr)}K", "Event Freq."))

    csi_df = results_df[csi_cols_ordered]
    csi_df.columns = pd.MultiIndex.from_tuples(pretty_cols)
    csi_df = csi_df.sort_index(axis=1)

    print(csi_df.round(4))
    print("\n" + "="*80)


def main():
    """Evaluate every trained architecture on the test split and report the results.

    Discovers ``arch_*.json`` metadata files in ``MODEL_DIR``, loads the
    matching models, computes per-sample metrics on the ``'test'`` split,
    prints summary tables, and writes the full summary to
    ``architecture_study_evaluation.csv`` in ``OUTPUT_DIR``. Returns early
    (after printing a message) when no metadata files or no loadable models
    are found.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔧 Starting evaluation on device: {device}")

    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    stats = test_dataset.stats

    # --- Find and Load Models Automatically ---
    metadata_files = sorted(MODEL_DIR.glob('arch_*.json'))
    if not metadata_files:
        print(f"❌ No model metadata files found in {MODEL_DIR}. Cannot evaluate.")
        return

    loaded_models = _load_study_models(metadata_files, device)

    # --- Evaluate Each Model ---
    all_results = [
        _evaluate_model(name, model, test_loader, stats, device)
        for name, model in loaded_models.items()
    ]

    if not all_results:
        print("❌ No models were evaluated. Exiting.")
        return

    results_df = pd.DataFrame(all_results).set_index("Model")

    # --- Reporting ---
    _print_evaluation_report(results_df)

    csv_path = OUTPUT_DIR / 'architecture_study_evaluation.csv'
    results_df.to_csv(csv_path)
    print(f"\n💾 Detailed results saved to: {csv_path}")

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


🔧 Starting evaluation on device: cuda
Loaded 'test' dataset with 150 samples.
Loading model: 'attention_unet' (Class: OriginalUNet)
Loading model: 'lightweight_cnn' (Class: LightweightCNN)
Loading model: 'original_unet' (Class: OriginalUNet)
Loading model: 'resnet_unet' (Class: ResNetUNet)

--- Evaluating attention_unet ---


Predicting with attention_unet: 100%|██████████| 150/150 [00:08<00:00, 16.69it/s]



--- Evaluating lightweight_cnn ---


Predicting with lightweight_cnn: 100%|██████████| 150/150 [00:07<00:00, 19.64it/s]



--- Evaluating original_unet ---


Predicting with original_unet: 100%|██████████| 150/150 [00:08<00:00, 16.91it/s]



--- Evaluating resnet_unet ---


Predicting with resnet_unet: 100%|██████████| 150/150 [00:10<00:00, 13.97it/s]



🏛️ FINAL ARCHITECTURE EVALUATION REPORT 🏛️

--- I. Overall Performance Metrics ---
                 rmse_mean  mae_mean  sharpness_mean  distribution_dist_mean
Model                                                                       
attention_unet     15.0333   13.0326          2.4606                 11.6406
lightweight_cnn     8.5555    7.2118          0.9943                  5.8837
original_unet       8.7593    7.4865          1.1162                  6.3254
resnet_unet        11.1532    9.6379          3.7013                  7.8701

--- II. Critical Success Index (CSI) by Threshold ---
                T <= 210K             T <= 220K             T <= 230K  \
                      CSI Event Freq.       CSI Event Freq.       CSI   
Model                                                                   
attention_unet     0.0630      0.0626    0.1979      0.1962    0.4555   
lightweight_cnn    0.0037      0.0626    0.1354      0.1962    0.4035   
original_unet      0.0107      0.


