In [None]:
# --- COMPREHENSIVE TEST SET EVALUATION ---
# This script loads ALL your trained models and evaluates them on the same test set
# with consistent metrics for fair comparison

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from skimage.metrics import structural_similarity as ssim
from scipy import stats
import json
import joblib
import warnings
from tqdm import tqdm
from google.colab import drive

# Suppress every Python warning so the evaluation log stays readable.
warnings.filterwarnings('ignore')
# Mount Google Drive (Colab-only API) so the project folders below are reachable.
drive.mount('/content/drive')

# --- CONFIGURATION ---
# All inputs and outputs are resolved relative to one project folder on the mounted Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'    # dataset splits + normalization stats
MODELS_DIR = PROJECT_PATH / 'publication_experiments'        # trained checkpoints, one subfolder per experiment
OUTPUT_DIR = PROJECT_PATH / 'final_evaluation_results'       # CSV tables and figures are written here
# Create the results folder if needed; no error if it already exists.
# parents=False (the default), so PROJECT_PATH itself must already exist.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- DATASET CLASS ---
class MultiVariableARDataset(Dataset):
    """Predictor/target sample pairs for one dataset split, normalised on load.

    Samples are discovered as ``<split>/*_predictor.npy``; the matching target
    is the same path with ``_predictor.npy`` replaced by ``_target.npy``.
    Normalisation statistics are read from
    ``normalization_stats_multi_variable.joblib`` in ``data_dir``.

    Raises:
        FileNotFoundError: if the split folder contains no predictor files.
    """

    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        # Sorted so that sample order is reproducible across runs.
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if len(self.predictor_files) == 0:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor, target)`` tensors; the target gains a leading axis."""
        predictor_path = self.predictor_files[idx]
        target_path = Path(str(predictor_path).replace('_predictor.npy', '_target.npy'))

        predictors = np.load(predictor_path).astype(np.float32)
        target = np.load(target_path).astype(np.float32)

        # Per-channel statistics are broadcast over the two trailing axes;
        # the epsilon guards against a zero standard deviation.
        channel_mean = self.stats['predictor_mean'][:, None, None]
        channel_std = self.stats['predictor_std'][:, None, None]
        predictors = (predictors - channel_mean) / (channel_std + 1e-8)
        target = (target - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)

        return torch.from_numpy(predictors), torch.from_numpy(target).unsqueeze(0)

# --- MODEL ARCHITECTURES (copied from your working code) ---
class SimplifiedAttention(nn.Module):
    """Channel gate followed by a spatial gate, both sigmoid-valued.

    The channel gate pools each feature map to a single value, squeezes to
    ``channels // 16`` features with a 1x1 convolution and expands back.
    The spatial gate is a single 7x7 convolution producing one map.
    Sub-module names and ordering are kept stable because saved
    checkpoints address parameters by these names.
    """

    def __init__(self, channels):
        super().__init__()
        squeezed = channels // 16
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, 1, 7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # The spatial gate is computed from the already channel-gated features.
        channel_gated = x * self.channel_att(x)
        return channel_gated * self.spatial_att(channel_gated)

class MultiVariableGenerator(nn.Module):
    """U-Net-style encoder/decoder producing a single Tanh-bounded output map.

    Args:
        input_channels: number of predictor channels.
        base_channels: width of the first encoder level; level ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder levels (and of decoder stages).
        use_attention: apply ``SimplifiedAttention`` after the bottleneck.

    NOTE: there are ``depth`` upsamplers but only ``depth - 1`` downsamplers,
    so the first decoder stage's upsampled tensor is larger than its skip and
    is resized back by the interpolation guard in ``forward``. Attribute
    names, layer order and this layout must stay as they are: saved
    checkpoints are keyed on them.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        self.depth = depth
        widths = [base_channels * min(2 ** level, 8) for level in range(depth)]

        # Encoder: one conv block per level, a strided conv between levels.
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        feeding = input_channels
        deepest_level = len(widths) - 1
        for level, width in enumerate(widths):
            self.encoders.append(self._conv_block(feeding, width))
            if level != deepest_level:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            feeding = width

        bottleneck_ch = widths[-1]
        self.bottleneck = self._conv_block(bottleneck_ch, bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # Decoder: walk the levels deepest-first; each stage consumes the
        # previous stage's output width (the bottleneck width for the first).
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        feeding = bottleneck_ch
        for width in reversed(widths):
            self.upsamplers.append(nn.ConvTranspose2d(feeding, width, 2, stride=2))
            self.decoders.append(self._conv_block(width * 2, width))
            feeding = width

        self.final_conv = nn.Sequential(nn.Conv2d(widths[0], 1, 1), nn.Tanh())

    def _conv_block(self, in_ch, out_ch):
        """3x3 conv (no bias) -> batch norm -> LeakyReLU(0.2)."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        n_down = len(self.downsamplers)
        for level, encode in enumerate(self.encoders):
            x = encode(x)
            skips.append(x)
            if level < n_down:
                x = self.downsamplers[level](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Skips are consumed deepest-first, one per decoder stage.
        for upsample, decode, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))

        return self.final_conv(x)

# ResNet, Attention, and Lightweight architectures (copy from your architecture code)
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with an identity shortcut.

    The channel count and spatial size are unchanged, so the input can be
    added directly to the transformed features before the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))
        # In-place add on the branch tensor; the caller's input is not modified.
        branch += x
        return self.relu(branch)

class ResNetUNet(nn.Module):
    """U-Net-style network whose encoder, bottleneck and decoder use ``ResNetBlock``.

    Args:
        input_channels: number of predictor channels.
        base_channels: width of the first level; level ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder levels; there are ``depth - 1`` decoder stages.

    Attribute names and layer order are kept stable because saved
    checkpoints address parameters by these names.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4):
        super().__init__()
        self.depth = depth
        widths = [base_channels * min(2 ** level, 8) for level in range(depth)]
        self.initial_conv = nn.Conv2d(input_channels, widths[0], 3, padding=1)

        # Encoder: two residual blocks per level; the strided conv between
        # levels also changes the channel count to the next level's width.
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        for level, width in enumerate(widths):
            self.encoders.append(nn.Sequential(ResNetBlock(width), ResNetBlock(width)))
            if level + 1 < len(widths):
                self.downsamplers.append(
                    nn.Conv2d(width, widths[level + 1], 3, stride=2, padding=1))

        deepest = widths[-1]
        self.bottleneck = nn.Sequential(
            ResNetBlock(deepest), ResNetBlock(deepest), ResNetBlock(deepest)
        )

        # Decoder: deepest level first; each stage halves the concatenated
        # width back down with a 1x1 conv.
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in reversed(range(1, depth)):
            wide, narrow = widths[level], widths[level - 1]
            merged = narrow * 2
            self.upsamplers.append(nn.ConvTranspose2d(wide, narrow, 2, stride=2))
            self.decoders.append(nn.Sequential(
                ResNetBlock(merged), ResNetBlock(merged),
                nn.Conv2d(merged, narrow, 1)
            ))

        head_width = widths[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(widths[0], head_width, 3, padding=1),
            nn.BatchNorm2d(head_width), nn.ReLU(inplace=True),
            nn.Conv2d(head_width, 1, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.initial_conv(x)

        # zip stops at the shorter list, so only the levels that have a
        # downsampler are visited here; the deepest encoder runs afterwards.
        skips = []
        for encode, shrink in zip(self.encoders, self.downsamplers):
            x = encode(x)
            skips.append(x)
            x = shrink(x)

        # The deepest encoder output is not used as a skip connection.
        x = self.encoders[-1](x)
        x = self.bottleneck(x)

        for upsample, decode, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))

        return self.final_conv(x)

class AttentionUNet(nn.Module):
    """U-Net-style network with double-conv blocks and attention in the bottleneck.

    Args:
        input_channels: number of predictor channels.
        base_channels: width of the first level; level ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder levels (and of decoder stages).

    NOTE: as there are ``depth`` upsamplers but ``depth - 1`` downsamplers,
    the first decoder stage relies on the interpolation guard in ``forward``
    to match its skip's size. Attribute names and layer order are kept
    stable because saved checkpoints address parameters by these names.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4):
        super().__init__()
        self.depth = depth
        widths = [base_channels * min(2 ** level, 8) for level in range(depth)]

        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        feeding = input_channels
        deepest_level = len(widths) - 1
        for level, width in enumerate(widths):
            self.encoders.append(self._conv_block(feeding, width))
            if level != deepest_level:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            feeding = width

        # Bottleneck: conv block -> attention -> conv block, all at full width.
        bottleneck_ch = widths[-1]
        self.bottleneck_conv = self._conv_block(bottleneck_ch, bottleneck_ch)
        self.attention = SimplifiedAttention(bottleneck_ch)
        self.bottleneck_final = self._conv_block(bottleneck_ch, bottleneck_ch)

        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        feeding = bottleneck_ch
        for width in reversed(widths):
            self.upsamplers.append(nn.ConvTranspose2d(feeding, width, 2, stride=2))
            self.decoders.append(self._conv_block(width * 2, width))
            feeding = width

        head_width = widths[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(widths[0], head_width, 3, padding=1),
            nn.BatchNorm2d(head_width), nn.ReLU(inplace=True),
            nn.Conv2d(head_width, 1, 1), nn.Tanh()
        )

    def _conv_block(self, in_ch, out_ch):
        """Two consecutive (3x3 conv without bias -> batch norm -> ReLU) stages."""
        layers = []
        for source in (in_ch, out_ch):
            layers.append(nn.Conv2d(source, out_ch, 3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        skips = []
        n_down = len(self.downsamplers)
        for level, encode in enumerate(self.encoders):
            x = encode(x)
            skips.append(x)
            if level < n_down:
                x = self.downsamplers[level](x)

        x = self.bottleneck_final(self.attention(self.bottleneck_conv(x)))

        # Skips are consumed deepest-first, one per decoder stage.
        for upsample, decode, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))

        return self.final_conv(x)

class LightweightCNN(nn.Module):
    """Small conv/pool encoder with a transposed-conv decoder and Tanh output.

    Three conv + max-pool stages reduce the resolution by 8x, a fourth conv
    widens to 256 channels, and three stride-2 transposed convs scale back
    up. Everything lives in one ``nn.Sequential`` named ``features``; the
    layer order (and hence each layer's index) is preserved because saved
    checkpoints address parameters by index.
    """

    def __init__(self, input_channels=5):
        super().__init__()
        layers = []

        # Down path: (out_channels, kernel) per stage, "same" padding = kernel // 2.
        feeding = input_channels
        for width, kernel in ((32, 7), (64, 5), (128, 3)):
            layers += [
                nn.Conv2d(feeding, width, kernel, padding=kernel // 2),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
            feeding = width

        layers += [nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True)]

        # Up path: each transposed conv doubles the spatial size.
        for wide, narrow in ((256, 128), (128, 64), (64, 32)):
            layers += [nn.ConvTranspose2d(wide, narrow, 4, stride=2, padding=1), nn.ReLU(inplace=True)]

        layers += [nn.Conv2d(32, 1, 3, padding=1), nn.Tanh()]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        wanted = x.shape[-2:]
        y = self.features(x)
        # Pooling floors odd sizes, so the decoder may not land on the input size.
        if y.shape[-2:] == wanted:
            return y
        return F.interpolate(y, size=wanted, mode='bilinear', align_corners=False)

# --- FLEXIBLE DATASET FOR VARIABLE SUBSETS ---
class FlexibleMultiVariableDataset(Dataset):
    """Dataset variant that exposes only a chosen subset of predictor variables.

    ``variable_subset`` is a list of names drawn from
    ``['IVT', 'T500', 'T850', 'RH700', 'W500']``; each name is mapped to its
    position in that list and used as a channel index into the stored
    predictor array. An unknown name raises ``ValueError`` (from
    ``list.index``). Channels are returned in the order given.
    """

    def __init__(self, data_dir: Path, split: str, variable_subset: list):
        all_variables = ['IVT', 'T500', 'T850', 'RH700', 'W500']
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        self.variable_subset = variable_subset
        self.variable_indices = [all_variables.index(name) for name in variable_subset]
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor_subset, target)`` tensors; the target gains a leading axis."""
        predictor_path = self.predictor_files[idx]
        target_path = Path(str(predictor_path).replace('_predictor.npy', '_target.npy'))

        all_channels = np.load(predictor_path).astype(np.float32)
        target = np.load(target_path).astype(np.float32)

        # Pick the requested channels and the statistics that belong to them.
        chosen = self.variable_indices
        subset = all_channels[chosen]
        subset_mean = self.stats['predictor_mean'][chosen, None, None]
        subset_std = self.stats['predictor_std'][chosen, None, None]

        subset = (subset - subset_mean) / (subset_std + 1e-8)
        target = (target - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)

        return torch.from_numpy(subset), torch.from_numpy(target).unsqueeze(0)

# --- FLEXIBLE GENERATOR FOR ABLATION MODELS ---
class FlexibleGenerator(nn.Module):
    """Generator for ablation models: same layout as ``MultiVariableGenerator``.

    The only difference visible here is that ``input_channels`` has no
    default, so the caller must state how many predictor variables are fed in.

    Args:
        input_channels: number of predictor channels.
        base_channels: width of the first level; level ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder levels (and of decoder stages).
        use_attention: apply ``SimplifiedAttention`` after the bottleneck.

    Attribute names and layer order are kept stable because saved
    checkpoints address parameters by these names.
    """

    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        self.depth = depth
        level_widths = [base_channels * min(2 ** k, 8) for k in range(depth)]
        n_levels = len(level_widths)

        # Encoder levels with a strided conv between consecutive levels.
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        source_ch = input_channels
        for k, width in enumerate(level_widths):
            self.encoders.append(self._conv_block(source_ch, width))
            if k + 1 < n_levels:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            source_ch = width

        bottleneck_ch = level_widths[-1]
        self.bottleneck = self._conv_block(bottleneck_ch, bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # Decoder stages, deepest level first. The first stage is fed by the
        # bottleneck, every later stage by the previous stage's output width.
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        source_ch = bottleneck_ch
        for width in level_widths[::-1]:
            self.upsamplers.append(nn.ConvTranspose2d(source_ch, width, 2, stride=2))
            self.decoders.append(self._conv_block(width * 2, width))
            source_ch = width

        self.final_conv = nn.Sequential(nn.Conv2d(level_widths[0], 1, 1), nn.Tanh())

    def _conv_block(self, in_ch, out_ch):
        """3x3 conv (no bias) -> batch norm -> LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        norm = nn.BatchNorm2d(out_ch)
        return nn.Sequential(conv, norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        skips = []
        downsamplers = list(self.downsamplers)
        for k, encode in enumerate(self.encoders):
            x = encode(x)
            skips.append(x)
            if k < len(downsamplers):
                x = downsamplers[k](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Pair each decoder stage with the matching skip, deepest first.
        for stage, (upsample, decode) in enumerate(zip(self.upsamplers, self.decoders)):
            skip = skips[-1 - stage]
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))

        return self.final_conv(x)

# --- EVALUATION METRICS ---
def denormalize(data, stats):
    """Undo target normalisation: ``data * (target_std + 1e-8) + target_mean``.

    Missing keys fall back to ``target_std = 1.0`` and ``target_mean = 0.0``.
    The epsilon mirrors the one added when the data was normalised.
    """
    scale = stats.get('target_std', 1.0) + 1e-8
    offset = stats.get('target_mean', 0.0)
    return data * scale + offset

def calculate_ssim(pred, target):
    """Structural similarity of ``pred`` against ``target`` with a 7-pixel window.

    The dynamic range is taken from ``target``. A constant target has zero
    range, for which the result is defined here as 1.0 when the two fields
    are element-wise identical and 0.0 otherwise.
    """
    data_range = target.max() - target.min()
    if data_range != 0:
        return ssim(target, pred, data_range=data_range, win_size=7)
    if np.all(pred == target):
        return 1.0
    return 0.0

def calculate_csi(pred, target, threshold=220.0):
    """Critical Success Index: ``hits / (hits + misses + false_alarms)``.

    A grid point counts as an event when its value is at or below
    ``threshold``. Returns 0.0 when neither field contains any event.
    """
    forecast = pred <= threshold
    observed = target <= threshold

    hits = np.sum(forecast & observed)
    misses = np.sum(observed & ~forecast)
    false_alarms = np.sum(forecast & ~observed)

    total = hits + misses + false_alarms
    if total > 0:
        return hits / total
    return 0.0

def calculate_fss(pred, target, threshold=220.0, window_size=11):
    """Fractions Skill Score of ``pred`` against ``target``.

    Both fields are turned into binary event masks (value at or below
    ``threshold``), smoothed with a ``window_size`` uniform filter to get
    neighbourhood event fractions, and compared as
    ``1 - MSE(fractions) / (mean(pred_frac**2) + mean(target_frac**2))``.
    Returns 1.0 when the reference term is zero (no events in either field).
    """
    from scipy.ndimage import uniform_filter

    def event_fraction(field):
        mask = (field <= threshold).astype(float)
        return uniform_filter(mask, size=window_size)

    forecast_frac = event_fraction(pred)
    observed_frac = event_fraction(target)

    mse = np.mean((forecast_frac - observed_frac) ** 2)
    reference = np.mean(forecast_frac ** 2) + np.mean(observed_frac ** 2)
    if reference > 0:
        return 1 - (mse / reference)
    return 1.0

def calculate_all_metrics(pred, target, stats):
    """Compute every evaluation metric for one normalised prediction/target pair.

    Both fields are first mapped back to physical units with ``denormalize``.
    Point-wise scores work on the flattened fields; SSIM, CSI and FSS receive
    the fields in their original shape.

    Returns:
        dict with keys ``rmse``, ``mae``, ``r2``, ``correlation``, ``ssim``,
        ``csi`` and ``fss`` (in that order).
    """
    pred_dn = denormalize(pred, stats)
    target_dn = denormalize(target, stats)

    metrics = {}
    metrics['rmse'] = np.sqrt(mean_squared_error(target_dn.flatten(), pred_dn.flatten()))
    metrics['mae'] = mean_absolute_error(target_dn.flatten(), pred_dn.flatten())
    metrics['r2'] = r2_score(target_dn.flatten(), pred_dn.flatten())
    # Off-diagonal entry of the 2x2 correlation matrix.
    metrics['correlation'] = np.corrcoef(target_dn.flatten(), pred_dn.flatten())[0, 1]
    metrics['ssim'] = calculate_ssim(pred_dn, target_dn)
    metrics['csi'] = calculate_csi(pred_dn, target_dn)
    metrics['fss'] = calculate_fss(pred_dn, target_dn)
    return metrics

# --- MODEL EVALUATION CLASS ---
class ModelEvaluator:
    def __init__(self, models_dir, data_dir, output_dir):
        """Prepare the test data, normalisation stats and device for evaluation.

        Args:
            models_dir: folder containing the per-experiment checkpoint subfolders.
            data_dir: dataset root holding the ``test`` split and
                ``normalization_stats_multi_variable.joblib``.
            output_dir: folder that receives CSV tables and figures; created
                if it does not exist.

        All three may be given as ``str`` or ``Path``.
        """
        self.models_dir = Path(models_dir)
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        # Reports are written here later; make sure the folder is present.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load test dataset. The Path-coerced attribute is used (not the raw
        # argument) so that a plain string works with the ``/`` joins below.
        self.test_dataset = MultiVariableARDataset(self.data_dir, 'test')
        # batch_size=1 and no shuffling: metrics are computed per sample in a fixed order.
        self.test_loader = DataLoader(self.test_dataset, batch_size=1, shuffle=False)

        # Load normalization stats
        self.stats = joblib.load(self.data_dir / 'normalization_stats_multi_variable.joblib')

        print(f"📊 Test dataset: {len(self.test_dataset)} samples")
        print(f"🔧 Using device: {self.device}")

    def create_model_by_config(self, model_config):
        """Create model based on configuration."""
        model_type = model_config.get('type', 'original')
        input_channels = model_config.get('input_channels', 5)

        if model_type == 'original' or model_type == 'multivariable':
            return MultiVariableGenerator(input_channels=input_channels)
        elif model_type == 'resnet':
            return ResNetUNet(input_channels=input_channels)
        elif model_type == 'attention':
            return AttentionUNet(input_channels=input_channels)
        elif model_type == 'lightweight':
            return LightweightCNN(input_channels=input_channels)
        elif model_type == 'flexible':
            return FlexibleGenerator(input_channels=input_channels)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

    def find_all_models(self):
        """Find all trained model files."""
        models = []

        # Architecture comparison models
        arch_dir = self.models_dir / 'architecture_comparison'
        if arch_dir.exists():
            for model_file in arch_dir.glob('*_final.pt'):
                model_name = model_file.stem.replace('_final', '')
                models.append({
                    'name': f'arch_{model_name}',
                    'path': model_file,
                    'category': 'architecture',
                    'description': f'Architecture: {model_name}',
                    'type': model_name.replace('_unet', '').replace('_cnn', ''),
                    'input_channels': 5
                })

        # Ablation study models
        ablation_dir = self.models_dir / 'ablation_study'
        if ablation_dir.exists():
            for model_file in ablation_dir.glob('*_final.pt'):
                model_name = model_file.stem.replace('_final', '')

                # Determine input channels from model name
                if 'single_' in model_name:
                    input_channels = 1
                elif 'pair_' in model_name:
                    input_channels = 2
                elif 'triplet_' in model_name:
                    input_channels = 3
                elif 'remove_' in model_name:
                    input_channels = 4
                else:
                    input_channels = 5

                models.append({
                    'name': model_name,
                    'path': model_file,
                    'category': 'ablation',
                    'description': f'Variables: {model_name}',
                    'type': 'flexible',
                    'input_channels': input_channels
                })

        # No GAN model
        no_gan_dir = self.models_dir / 'no_gan_baseline'
        if no_gan_dir.exists():
            for i in range(5):
                no_gan_file = no_gan_dir / f'no_gan_baseline_epoch_{(i+1)*5}.pt'
                if no_gan_file.exists():
                  models.append({
                      'name': f'no_gan_baseline_epoch_{(i+1)*5}',
                      'path': no_gan_file,
                      'category': 'baseline',
                      'description': f'No GAN Baseline {(i+1)*5}',
                      'type': 'multivariable',
                      'input_channels': 5
                  })

        # Fair baseline
        baseline_dir = self.models_dir / 'fair_baseline_scratch'
        if baseline_dir.exists():
            baseline_file = baseline_dir / 'fair_baseline_final_model.pt'
            if baseline_file.exists():
                models.append({
                    'name': 'fair_baseline',
                    'path': baseline_file,
                    'category': 'baseline',
                    'description': 'Fair Baseline (From Scratch)',
                    'type': 'multivariable',
                    'input_channels': 5
                })

        print(f"🔍 Found {len(models)} trained models:")
        for model in models:
            print(f"   - {model['name']}: {model['description']}")

        # transfer learning models
        transfer_dir = Path('/content/drive/MyDrive/AR_Downscaling/model_output_final_multi_variable')
        if transfer_dir.exists():
            for i in range(5):
                transfer_file = transfer_dir / f'final_multi_var_gan_epoch_{(i+1)*5}.pt'
                if transfer_file.exists():
                    models.append({
                        'name': 'transfer_learning_gan',
                        'path': transfer_file,
                        'category': 'transfer_learning',
                        'description': f'Transfer Learning GAN {(i+1)*5}',
                        'type': 'multivariable',
                        'input_channels': 5
                    })

        return models

    def load_model(self, model_config):
        """Load a trained model."""
        try:
            # Load checkpoint
            checkpoint = torch.load(model_config['path'], map_location=self.device, weights_only=False)

            # Create model
            model = self.create_model_by_config(model_config)

            # Load state dict
            if 'generator_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['generator_state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            return model, True

        except Exception as e:
            print(f"⚠️ Failed to load {model_config['name']}: {e}")
            return None, False

    def evaluate_model(self, model, model_config):
        """Evaluate a single model on test set."""
        print(f"🎯 Evaluating {model_config['name']}...")

        # Determine if we need flexible dataset
        if model_config['input_channels'] < 5:
            # Need to determine which variables this model uses
            # For now, use all variables and let the model handle it
            # In practice, you'd need to store variable info with the model
            test_loader = self.test_loader  # Use default 5-variable loader
        else:
            test_loader = self.test_loader

        all_metrics = []

        with torch.no_grad():
            for predictor, target in tqdm(test_loader, desc=f"Evaluating {model_config['name']}"):
                predictor = predictor.to(self.device)
                target = target.cpu().numpy().squeeze()

                # Handle variable subset for ablation models
                if model_config['input_channels'] < 5:
                    # For ablation models, we need the right subset
                    # This is a simplified approach - in practice you'd store this info
                    predictor = predictor[:, :model_config['input_channels']]

                # Generate prediction
                try:
                    prediction = model(predictor).cpu().numpy().squeeze()

                    # Calculate metrics
                    metrics = calculate_all_metrics(prediction, target, self.stats)
                    all_metrics.append(metrics)

                except Exception as e:
                    print(f"Error in prediction for {model_config['name']}: {e}")
                    continue

        if not all_metrics:
            return None

        # Calculate summary statistics
        summary_metrics = {}
        for metric_name in all_metrics[0].keys():
            values = [m[metric_name] for m in all_metrics if not np.isnan(m[metric_name])]
            if values:
                summary_metrics[metric_name] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'median': np.median(values),
                    'min': np.min(values),
                    'max': np.max(values),
                    'count': len(values)
                }

        return summary_metrics

    def evaluate_all_models(self):
        """Evaluate all found models."""
        models = self.find_all_models()
        results = []

        for model_config in models:
            model, success = self.load_model(model_config)

            if not success:
                results.append({
                    'name': model_config['name'],
                    'category': model_config['category'],
                    'description': model_config['description'],
                    'status': 'failed_to_load',
                    'input_channels': model_config['input_channels']
                })
                continue

            metrics = self.evaluate_model(model, model_config)

            if metrics is None:
                results.append({
                    'name': model_config['name'],
                    'category': model_config['category'],
                    'description': model_config['description'],
                    'status': 'failed_evaluation',
                    'input_channels': model_config['input_channels']
                })
                continue

            # Add model info to results
            result = {
                'name': model_config['name'],
                'category': model_config['category'],
                'description': model_config['description'],
                'status': 'success',
                'input_channels': model_config['input_channels'],
                **{f"{metric_name}_{stat}": value for metric_name, stats in metrics.items()
                   for stat, value in stats.items()}
            }
            results.append(result)

            print(f"✅ {model_config['name']} completed")
            print(f"   RMSE: {metrics['rmse']['mean']:.3f} ± {metrics['rmse']['std']:.3f}")
            print(f"   CSI: {metrics['csi']['mean']:.3f} ± {metrics['csi']['std']:.3f}")
            print(f"   FSS: {metrics['fss']['mean']:.3f} ± {metrics['fss']['std']:.3f}")

        return results

    def generate_comparison_report(self, results):
        """Generate comprehensive comparison report."""
        # Convert to DataFrame
        df = pd.DataFrame(results)
        successful_df = df[df['status'] == 'success'].copy()

        if len(successful_df) == 0:
            print("❌ No successful evaluations to compare")
            return

        print(f"\n📊 Successfully evaluated {len(successful_df)} models")

        # Key metrics for comparison
        key_metrics = ['rmse_mean', 'mae_mean', 'r2_mean', 'ssim_mean', 'csi_mean', 'fss_mean']

        # Sort by CSI (Critical Success Index) - higher is better
        if 'csi_mean' in successful_df.columns:
            successful_df_sorted = successful_df.sort_values('csi_mean', ascending=False)
        else:
            successful_df_sorted = successful_df.sort_values('rmse_mean', ascending=True)

        # Generate ranking table
        print("\n🏆 Model Performance Ranking (by CSI):")
        print("=" * 80)

        ranking_data = []
        for i, (_, row) in enumerate(successful_df_sorted.iterrows()):
            ranking_data.append({
                'Rank': i + 1,
                'Model': row['name'],
                'Category': row['category'],
                'Description': row['description'],
                'Channels': row['input_channels'],
                'RMSE': f"{row.get('rmse_mean', 0):.3f} ± {row.get('rmse_std', 0):.3f}",
                'CSI': f"{row.get('csi_mean', 0):.3f} ± {row.get('csi_std', 0):.3f}",
                'FSS': f"{row.get('fss_mean', 0):.3f} ± {row.get('fss_std', 0):.3f}",
                'SSIM': f"{row.get('ssim_mean', 0):.3f} ± {row.get('ssim_std', 0):.3f}"
            })

        ranking_df = pd.DataFrame(ranking_data)
        print(ranking_df.to_string(index=False))

        # Save detailed results
        successful_df.to_csv(self.output_dir / 'detailed_test_results.csv', index=False)
        ranking_df.to_csv(self.output_dir / 'model_ranking.csv', index=False)

        # Generate analysis by category
        self.analyze_by_category(successful_df)

        # Generate visualizations
        self.create_visualizations(successful_df)

        print(f"\n💾 Results saved to {self.output_dir}")

    def analyze_by_category(self, df):
        """Analyze results by model category."""
        print(f"\n📈 Analysis by Category:")
        print("-" * 40)

        categories = df['category'].unique()

        for category in categories:
            cat_df = df[df['category'] == category]
            print(f"\n{category.upper()} Models ({len(cat_df)} models):")

            if 'csi_mean' in cat_df.columns:
                best_model = cat_df.loc[cat_df['csi_mean'].idxmax()]
                print(f"  Best: {best_model['name']} (CSI: {best_model['csi_mean']:.3f})")

                avg_csi = cat_df['csi_mean'].mean()
                print(f"  Average CSI: {avg_csi:.3f}")

            if 'rmse_mean' in cat_df.columns:
                avg_rmse = cat_df['rmse_mean'].mean()
                print(f"  Average RMSE: {avg_rmse:.3f}")

    def create_visualizations(self, df):
        """Save a 2x2 comparison figure and then the category-wise plot.

        Panels: horizontal bar charts of CSI, RMSE and FSS (each sorted by its
        own mean, with the standard deviation as error bar) and a scatter of
        CSI against the number of input channels. A panel is left empty when
        its column is missing from ``df``. The figure is written to
        ``model_comparison_plots.png`` in ``self.output_dir``.
        """
        plt.style.use('default')
        sns.set_palette("husl")

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Model Performance Comparison on Test Set', fontsize=16, fontweight='bold')

        # The three bar panels differ only in position, columns and labels.
        bar_panels = (
            (axes[0, 0], 'csi_mean', 'csi_std',
             'Critical Success Index (CSI)', 'Critical Success Index (Higher = Better)'),
            (axes[0, 1], 'rmse_mean', 'rmse_std',
             'Root Mean Square Error (RMSE)', 'RMSE (Lower = Better)'),
            (axes[1, 0], 'fss_mean', 'fss_std',
             'Fractions Skill Score (FSS)', 'FSS (Higher = Better)'),
        )
        for ax, mean_col, std_col, x_label, title in bar_panels:
            if mean_col not in df.columns:
                continue
            ordered = df.sort_values(mean_col, ascending=True)
            positions = range(len(ordered))
            ax.barh(positions, ordered[mean_col], xerr=ordered[std_col], capsize=3)
            ax.set_yticks(positions)
            ax.set_yticklabels(ordered['name'], fontsize=8)
            ax.set_xlabel(x_label)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)

        # Fourth panel: CSI against input-channel count, coloured by row index.
        scatter_ax = axes[1, 1]
        if 'csi_mean' in df.columns and 'input_channels' in df.columns:
            scatter_ax.scatter(df['input_channels'], df['csi_mean'],
                               s=100, alpha=0.7, c=df.index)
            scatter_ax.set_xlabel('Number of Input Channels')
            scatter_ax.set_ylabel('Critical Success Index (CSI)')
            scatter_ax.set_title('Performance vs Model Complexity')
            scatter_ax.grid(True, alpha=0.3)

            # Label each point with the first 10 characters of the model name.
            for _, row in df.iterrows():
                scatter_ax.annotate(row['name'][:10],
                                    (row['input_channels'], row['csi_mean']),
                                    xytext=(5, 5), textcoords='offset points',
                                    fontsize=6, alpha=0.8)

        plt.tight_layout()
        plt.savefig(self.output_dir / 'model_comparison_plots.png', dpi=300, bbox_inches='tight')
        plt.close()

        self.create_category_plots(df)

    def create_category_plots(self, df):
        """Save a box plot of ``csi_mean`` per model category.

        Nothing is drawn unless ``df`` contains more than one category. The
        figure is written to ``category_comparison.png`` in ``self.output_dir``.
        """
        categories = df['category'].unique()
        if len(categories) <= 1:
            return

        fig, ax = plt.subplots(figsize=(12, 8))

        # One array of CSI values per category, in order of first appearance.
        category_data = [
            df[df['category'] == category]['csi_mean'].values
            for category in categories
            if 'csi_mean' in df.columns
        ]
        if not category_data:
            # NOTE(review): the figure created above is not closed on this path.
            return

        ax.boxplot(category_data, labels=categories)
        ax.set_ylabel('Critical Success Index (CSI)')
        ax.set_title('Performance Distribution by Model Category')
        ax.grid(True, alpha=0.3)

        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(self.output_dir / 'category_comparison.png', dpi=300, bbox_inches='tight')
        plt.close()

# --- MAIN EXECUTION ---
def main():
    """Evaluate every trained model on the test set and print a summary.

    Builds a ``ModelEvaluator`` from the module-level directories, runs the
    evaluation, writes the comparison report, and finally prints the model
    with the highest ``csi_mean`` among the results whose status is
    ``'success'``.
    """
    print("🎯 Starting Comprehensive Test Set Evaluation")
    print("=" * 60)

    evaluator = ModelEvaluator(models_dir=MODELS_DIR, data_dir=DATA_DIR, output_dir=OUTPUT_DIR)

    # Evaluate all models, then turn the raw results into the report.
    results = evaluator.evaluate_all_models()
    evaluator.generate_comparison_report(results)

    print("\n🎉 Comprehensive evaluation completed!")
    print(f"📊 Results and visualizations saved to: {OUTPUT_DIR}")

    # Final summary: only models that evaluated successfully are candidates.
    successful_results = [entry for entry in results if entry['status'] == 'success']
    if not successful_results:
        return

    def csi_of(entry):
        # A missing CSI counts as 0 for ranking purposes.
        return entry.get('csi_mean', 0)

    best_overall = max(successful_results, key=csi_of)
    best_csi = best_overall.get('csi_mean', 0)
    best_rmse = best_overall.get('rmse_mean', 0)
    print(f"\n🏆 Best Overall Model: {best_overall['name']}")
    print(f"   Description: {best_overall['description']}")
    print(f"   CSI: {best_csi:.3f}")
    print(f"   RMSE: {best_rmse:.3f}")

# Entry point: run the full evaluation only when this code is executed
# directly (script or notebook cell), not when it is imported.
if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
🎯 Starting Comprehensive Test Set Evaluation
📊 Test dataset: 150 samples
🔧 Using device: cuda
🔍 Found 29 trained models:
   - arch_original_unet: Architecture: original_unet
   - arch_resnet_unet: Architecture: resnet_unet
   - arch_attention_unet: Architecture: attention_unet
   - arch_lightweight_cnn: Architecture: lightweight_cnn
   - single_IVT: Variables: single_IVT
   - single_T500: Variables: single_T500
   - single_T850: Variables: single_T850
   - single_RH700: Variables: single_RH700
   - single_W500: Variables: single_W500
   - pair_IVT_T500: Variables: pair_IVT_T500
   - pair_IVT_RH700: Variables: pair_IVT_RH700
   - pair_T500_T850: Variables: pair_T500_T850
   - pair_RH700_W500: Variables: pair_RH700_W500
   - pair_IVT_W500: Variables: pair_IVT_W500
   - pair_T500_W500: Variables: pair_T500_W500
   - triplet_IVT_T500_RH700: Variables: triplet_IVT

Evaluating arch_original_unet: 100%|██████████| 150/150 [00:21<00:00,  7.10it/s]


✅ arch_original_unet completed
   RMSE: 8.501 ± 3.308
   CSI: 0.142 ± 0.171
   FSS: 0.398 ± 0.351
🎯 Evaluating arch_resnet_unet...


Evaluating arch_resnet_unet: 100%|██████████| 150/150 [00:06<00:00, 23.78it/s]


✅ arch_resnet_unet completed
   RMSE: 8.241 ± 3.349
   CSI: 0.088 ± 0.138
   FSS: 0.327 ± 0.370
🎯 Evaluating arch_attention_unet...


Evaluating arch_attention_unet: 100%|██████████| 150/150 [00:03<00:00, 37.74it/s]


✅ arch_attention_unet completed
   RMSE: 8.220 ± 3.217
   CSI: 0.126 ± 0.172
   FSS: 0.384 ± 0.368
🎯 Evaluating arch_lightweight_cnn...


Evaluating arch_lightweight_cnn: 100%|██████████| 150/150 [00:03<00:00, 42.61it/s]


✅ arch_lightweight_cnn completed
   RMSE: 8.664 ± 3.514
   CSI: 0.103 ± 0.137
   FSS: 0.332 ± 0.345
🎯 Evaluating single_IVT...


Evaluating single_IVT: 100%|██████████| 150/150 [00:03<00:00, 42.62it/s]


✅ single_IVT completed
   RMSE: 8.939 ± 3.413
   CSI: 0.030 ± 0.060
   FSS: 0.238 ± 0.367
🎯 Evaluating single_T500...


Evaluating single_T500: 100%|██████████| 150/150 [00:03<00:00, 43.29it/s]


✅ single_T500 completed
   RMSE: 10.083 ± 4.124
   CSI: 0.014 ± 0.033
   FSS: 0.032 ± 0.068
🎯 Evaluating single_T850...


Evaluating single_T850: 100%|██████████| 150/150 [00:03<00:00, 43.09it/s]


✅ single_T850 completed
   RMSE: 11.142 ± 4.476
   CSI: 0.000 ± 0.000
   FSS: 0.173 ± 0.379
🎯 Evaluating single_RH700...


Evaluating single_RH700: 100%|██████████| 150/150 [00:03<00:00, 43.39it/s]


✅ single_RH700 completed
   RMSE: 10.594 ± 4.395
   CSI: 0.000 ± 0.001
   FSS: 0.174 ± 0.378
🎯 Evaluating single_W500...


Evaluating single_W500: 100%|██████████| 150/150 [00:03<00:00, 41.51it/s]


✅ single_W500 completed
   RMSE: 12.636 ± 6.018
   CSI: 0.000 ± 0.000
   FSS: 0.173 ± 0.379
🎯 Evaluating pair_IVT_T500...


Evaluating pair_IVT_T500: 100%|██████████| 150/150 [00:03<00:00, 41.90it/s]


✅ pair_IVT_T500 completed
   RMSE: 8.660 ± 3.220
   CSI: 0.047 ± 0.074
   FSS: 0.269 ± 0.360
🎯 Evaluating pair_IVT_RH700...


Evaluating pair_IVT_RH700: 100%|██████████| 150/150 [00:03<00:00, 37.53it/s]


✅ pair_IVT_RH700 completed
   RMSE: 8.698 ± 3.042
   CSI: 0.063 ± 0.098
   FSS: 0.293 ± 0.361
🎯 Evaluating pair_T500_T850...


Evaluating pair_T500_T850: 100%|██████████| 150/150 [00:03<00:00, 43.23it/s]


✅ pair_T500_T850 completed
   RMSE: 10.229 ± 4.165
   CSI: 0.002 ± 0.009
   FSS: 0.177 ± 0.377
🎯 Evaluating pair_RH700_W500...


Evaluating pair_RH700_W500: 100%|██████████| 150/150 [00:03<00:00, 41.95it/s]


✅ pair_RH700_W500 completed
   RMSE: 10.444 ± 3.945
   CSI: 0.001 ± 0.010
   FSS: 0.176 ± 0.378
🎯 Evaluating pair_IVT_W500...


Evaluating pair_IVT_W500: 100%|██████████| 150/150 [00:03<00:00, 42.68it/s]


✅ pair_IVT_W500 completed
   RMSE: 10.373 ± 3.956
   CSI: 0.000 ± 0.001
   FSS: 0.160 ± 0.367
🎯 Evaluating pair_T500_W500...


Evaluating pair_T500_W500: 100%|██████████| 150/150 [00:03<00:00, 39.97it/s]


✅ pair_T500_W500 completed
   RMSE: 11.496 ± 4.263
   CSI: 0.000 ± 0.003
   FSS: 0.154 ± 0.360
🎯 Evaluating triplet_IVT_T500_RH700...


Evaluating triplet_IVT_T500_RH700: 100%|██████████| 150/150 [00:03<00:00, 37.91it/s]


✅ triplet_IVT_T500_RH700 completed
   RMSE: 8.815 ± 3.234
   CSI: 0.036 ± 0.066
   FSS: 0.246 ± 0.365
🎯 Evaluating triplet_T500_T850_W500...


Evaluating triplet_T500_T850_W500: 100%|██████████| 150/150 [00:03<00:00, 41.76it/s]


✅ triplet_T500_T850_W500 completed
   RMSE: 11.056 ± 4.699
   CSI: 0.000 ± 0.000
   FSS: 0.153 ± 0.360
🎯 Evaluating triplet_IVT_RH700_W500...


Evaluating triplet_IVT_RH700_W500: 100%|██████████| 150/150 [00:03<00:00, 42.28it/s]


✅ triplet_IVT_RH700_W500 completed
   RMSE: 8.718 ± 3.430
   CSI: 0.001 ± 0.009
   FSS: 0.176 ± 0.378
🎯 Evaluating triplet_IVT_T500_W500...


Evaluating triplet_IVT_T500_W500: 100%|██████████| 150/150 [00:03<00:00, 43.21it/s]


✅ triplet_IVT_T500_W500 completed
   RMSE: 9.945 ± 4.292
   CSI: 0.000 ± 0.000
   FSS: 0.153 ± 0.360
🎯 Evaluating remove_IVT...


Evaluating remove_IVT: 100%|██████████| 150/150 [00:03<00:00, 41.48it/s]


✅ remove_IVT completed
   RMSE: 9.496 ± 3.853
   CSI: 0.006 ± 0.031
   FSS: 0.152 ± 0.347
🎯 Evaluating remove_T500...


Evaluating remove_T500: 100%|██████████| 150/150 [00:04<00:00, 35.68it/s]


✅ remove_T500 completed
   RMSE: 9.636 ± 3.904
   CSI: 0.001 ± 0.005
   FSS: 0.168 ± 0.372
🎯 Evaluating remove_T850...


Evaluating remove_T850: 100%|██████████| 150/150 [00:03<00:00, 40.42it/s]


✅ remove_T850 completed
   RMSE: 9.182 ± 3.601
   CSI: 0.001 ± 0.010
   FSS: 0.176 ± 0.378
🎯 Evaluating remove_RH700...


Evaluating remove_RH700: 100%|██████████| 150/150 [00:03<00:00, 42.33it/s]


✅ remove_RH700 completed
   RMSE: 10.599 ± 4.240
   CSI: 0.001 ± 0.012
   FSS: 0.109 ± 0.309
🎯 Evaluating remove_W500...


Evaluating remove_W500: 100%|██████████| 150/150 [00:03<00:00, 40.76it/s]


✅ remove_W500 completed
   RMSE: 8.761 ± 3.504
   CSI: 0.013 ± 0.030
   FSS: 0.201 ± 0.371
🎯 Evaluating all_variables...


Evaluating all_variables: 100%|██████████| 150/150 [00:03<00:00, 41.26it/s]


✅ all_variables completed
   RMSE: 8.373 ± 3.337
   CSI: 0.108 ± 0.152
   FSS: 0.362 ± 0.366
🎯 Evaluating no_gan_baseline_epoch_10...


Evaluating no_gan_baseline_epoch_10: 100%|██████████| 150/150 [00:03<00:00, 37.82it/s]


✅ no_gan_baseline_epoch_10 completed
   RMSE: 8.386 ± 3.220
   CSI: 0.046 ± 0.095
   FSS: 0.257 ± 0.374
🎯 Evaluating no_gan_baseline_epoch_20...


Evaluating no_gan_baseline_epoch_20: 100%|██████████| 150/150 [00:03<00:00, 41.98it/s]


✅ no_gan_baseline_epoch_20 completed
   RMSE: 8.294 ± 3.198
   CSI: 0.079 ± 0.135
   FSS: 0.311 ± 0.374
🎯 Evaluating no_gan_baseline_epoch_25...


Evaluating no_gan_baseline_epoch_25: 100%|██████████| 150/150 [00:04<00:00, 35.91it/s]


✅ no_gan_baseline_epoch_25 completed
   RMSE: 8.288 ± 3.157
   CSI: 0.091 ± 0.157
   FSS: 0.325 ± 0.381
🎯 Evaluating fair_baseline...


Evaluating fair_baseline: 100%|██████████| 150/150 [00:03<00:00, 43.11it/s]


✅ fair_baseline completed
   RMSE: 8.417 ± 3.745
   CSI: 0.117 ± 0.152
   FSS: 0.373 ± 0.358
🎯 Evaluating transfer_learning_gan...


Evaluating transfer_learning_gan: 100%|██████████| 150/150 [00:03<00:00, 38.59it/s]


✅ transfer_learning_gan completed
   RMSE: 8.411 ± 3.189
   CSI: 0.148 ± 0.174
   FSS: 0.416 ± 0.350
🎯 Evaluating transfer_learning_gan...


Evaluating transfer_learning_gan: 100%|██████████| 150/150 [00:03<00:00, 42.10it/s]


✅ transfer_learning_gan completed
   RMSE: 8.246 ± 3.127
   CSI: 0.115 ± 0.157
   FSS: 0.372 ± 0.365
🎯 Evaluating transfer_learning_gan...


Evaluating transfer_learning_gan: 100%|██████████| 150/150 [00:03<00:00, 43.17it/s]


✅ transfer_learning_gan completed
   RMSE: 8.234 ± 2.978
   CSI: 0.110 ± 0.151
   FSS: 0.357 ± 0.361
🎯 Evaluating transfer_learning_gan...


Evaluating transfer_learning_gan: 100%|██████████| 150/150 [00:03<00:00, 42.75it/s]


✅ transfer_learning_gan completed
   RMSE: 8.250 ± 3.135
   CSI: 0.108 ± 0.170
   FSS: 0.351 ± 0.380
🎯 Evaluating transfer_learning_gan...


Evaluating transfer_learning_gan: 100%|██████████| 150/150 [00:03<00:00, 38.68it/s]


✅ transfer_learning_gan completed
   RMSE: 8.348 ± 3.011
   CSI: 0.147 ± 0.176
   FSS: 0.412 ± 0.353

📊 Successfully evaluated 34 models

🏆 Model Performance Ranking (by CSI):
 Rank                    Model          Category                       Description  Channels           RMSE           CSI           FSS          SSIM
    1    transfer_learning_gan transfer_learning           Transfer Learning GAN 5         5  8.411 ± 3.189 0.148 ± 0.174 0.416 ± 0.350 0.630 ± 0.111
    2    transfer_learning_gan transfer_learning          Transfer Learning GAN 25         5  8.348 ± 3.011 0.147 ± 0.176 0.412 ± 0.353 0.621 ± 0.105
    3       arch_original_unet      architecture       Architecture: original_unet         5  8.501 ± 3.308 0.142 ± 0.171 0.398 ± 0.351 0.637 ± 0.120
    4      arch_attention_unet      architecture      Architecture: attention_unet         5  8.220 ± 3.217 0.126 ± 0.172 0.384 ± 0.368 0.651 ± 0.115
    5            fair_baseline          baseline      Fair Baseline (From 

In [None]:
# --- PUBLICATION-QUALITY VISUALIZATION SCRIPT (V14 - The Definitive Figure) ---
# This definitive version incorporates professional meteorological visualization techniques
# to finally do justice to the GAN's performance. It uses:
#   1. A satellite-style colormap ('gist_gray_r') to make storms pop.
#   2. Non-linear color scaling (PowerNorm) to reveal details in cold cloud tops.
#   3. A magnified inset ("microscope view") for undeniable proof of sharpness.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
from tqdm import tqdm
import warnings
import joblib
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from scipy.ndimage import uniform_filter, zoom, sobel
from google.colab import drive

# Silence all Python warnings for the rest of the cell.
warnings.filterwarnings('ignore')
# Mount Google Drive, forcing a remount even if it is already mounted.
drive.mount('/content/drive', force_remount=True)

# --- CONFIGURATION ---
# All project paths are derived from PROJECT_PATH on the mounted Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
MODELS_DIR = PROJECT_PATH / 'publication_experiments'
OUTPUT_DIR = PROJECT_PATH / 'final_publication_figures'
# Create the figure output directory if it does not exist yet.
OUTPUT_DIR.mkdir(exist_ok=True)

# Timestamps to visualize; main() uses each one as the filename prefix of
# '<timestamp>_predictor.npy' / '<timestamp>_target.npy' in DATA_DIR / 'test'.
DISASTER_DAYS = ['20230807_1200', '20230807_1800']
# (height, width) that main() reshapes the flat Random-Forest output to.
TARGET_SHAPE = (256, 256)
# Normalization statistics; the helpers below read the keys 'predictor_mean',
# 'predictor_std', 'target_mean' and 'target_std' from this object.
stats = joblib.load(DATA_DIR / 'normalization_stats_multi_variable.joblib')

# --- HELPER & MODEL FUNCTIONS (unchanged from the V12/V13 script) ---
# The data-loading helper, model architectures, normalization helpers, CSI metric
# and Random-Forest feature extraction are defined in full below so that this
# cell can run on its own.
def load_and_crop_data(file_path, target_shape=(256, 256)):
    """Load a ``.npy`` array as float32 and fit it to ``target_shape``.

    A 3-D array is treated as (channels, height, width); anything else is
    unpacked as (height, width). Arrays larger than the target are
    center-cropped; any axis that is still too small afterwards is
    zero-padded at the bottom/right.

    Args:
        file_path: Path-like object with ``.exists()`` pointing at the file.
        target_shape: Desired (height, width) of the spatial axes.

    Returns:
        The float32 array with the requested spatial size, or ``None`` when
        ``file_path`` does not exist.
    """
    if not file_path.exists():
        return None

    data = np.load(file_path).astype(np.float32)
    has_channels = data.ndim == 3
    if has_channels:
        height, width = data.shape[1], data.shape[2]
    else:
        height, width = data.shape

    target_h, target_w = target_shape
    if height == target_h and width == target_w:
        return data

    # Center crop; the offset is clamped to 0 when the source is too small.
    top = max(0, (height - target_h) // 2)
    left = max(0, (width - target_w) // 2)
    rows = slice(top, top + target_h)
    cols = slice(left, left + target_w)
    window = data[:, rows, cols] if has_channels else data[rows, cols]

    if not (window.shape[-2:] != target_shape):
        return window

    # The crop came out short on at least one axis: paste it into the
    # top-left corner of a zero canvas of the requested size.
    if has_channels:
        canvas = np.zeros((data.shape[0], *target_shape), dtype=np.float32)
        canvas[:, :window.shape[1], :window.shape[2]] = window
    else:
        canvas = np.zeros(target_shape, dtype=np.float32)
        canvas[:window.shape[0], :window.shape[1]] = window
    return canvas
class SimplifiedAttention(nn.Module):
    """Channel attention followed by spatial attention on a feature map.

    The channel branch pools each channel to a single value, passes it
    through a two-layer 1x1-conv bottleneck (reduction factor 16) and a
    sigmoid; the spatial branch is a single 7x7 convolution to one channel
    followed by a sigmoid. Both gates are applied multiplicatively.

    Args:
        channels: Number of channels of the input feature map. The
            bottleneck has ``channels // 16`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        reduced = channels // 16
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, reduced, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, 1, 7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` gated first per channel, then per spatial location."""
        channel_gated = x * self.channel_att(x)
        # The spatial gate is computed from the channel-gated features.
        return channel_gated * self.spatial_att(channel_gated)
class MultiVariableGenerator(nn.Module):
    """Encoder/decoder generator with skip connections and a Tanh output.

    The encoder has ``depth`` conv blocks whose widths are
    ``base_channels * min(2**i, 8)``, separated by stride-2 convolutions.
    A bottleneck block (optionally followed by ``SimplifiedAttention``)
    feeds ``depth`` decoder stages; each stage upsamples with a transposed
    convolution, concatenates the matching encoder output and applies a
    conv block. The final 1x1 convolution maps to a single channel.

    Args:
        input_channels: Number of channels of the input tensor.
        base_channels: Width of the first encoder block.
        depth: Number of encoder (and decoder) stages.
        use_attention: Whether to apply attention after the bottleneck.
    """

    def __init__(self, input_channels=5, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        self.depth = depth

        widths = [base_channels * min(2 ** level, 8) for level in range(depth)]
        deepest = len(widths) - 1

        # Encoder: one conv block per level, with a stride-2 conv between levels.
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        previous = input_channels
        for level, width in enumerate(widths):
            self.encoders.append(self._conv_block(previous, width))
            if level < deepest:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            previous = width

        bottleneck_width = widths[-1]
        self.bottleneck = self._conv_block(widths[-1], bottleneck_width)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_width)

        # Decoder: built from the deepest level back up to level 0.
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in reversed(range(depth)):
            if level == depth - 1:
                incoming = bottleneck_width
            else:
                incoming = widths[level + 1]
            outgoing = widths[level]
            self.upsamplers.append(nn.ConvTranspose2d(incoming, outgoing, 2, stride=2))
            # The decoder block sees the upsampled features plus the skip tensor.
            self.decoders.append(self._conv_block(outgoing * 2, outgoing))

        self.final_conv = nn.Sequential(
            nn.Conv2d(widths[0], 1, 1),
            nn.Tanh(),
        )

    def _conv_block(self, in_ch, out_ch):
        """Return a 3x3 conv (no bias) + BatchNorm + LeakyReLU(0.2) block."""
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator and return a single-channel map in [-1, 1]."""
        skips = []
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if level < len(self.downsamplers):
                x = self.downsamplers[level](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Walk the skip tensors from deepest to shallowest alongside the decoder.
        for upsample, decode, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                # Resize to the skip tensor's spatial size before concatenation.
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))
        return self.final_conv(x)
def denormalize(data, is_target=True):
    """Undo the standardization of ``data`` using the module-level ``stats``.

    Args:
        data: Standardized array.
        is_target: When true, the scalar-style ``target_mean``/``target_std``
            entries are used; otherwise the per-channel predictor statistics
            are broadcast over the last two axes.

    Returns:
        ``data * (std + 1e-8) + mean`` with the selected statistics.
    """
    if not is_target:
        # Add two trailing axes so the per-channel stats broadcast over (H, W).
        mean = stats['predictor_mean'][:, None, None]
        std = stats['predictor_std'][:, None, None]
        return data * (std + 1e-8) + mean
    return data * (stats['target_std'] + 1e-8) + stats['target_mean']
def load_neural_model(model_path, model_class, device, input_channels=5):
    """Instantiate ``model_class`` and load its weights from a checkpoint.

    The checkpoint may store the weights under ``'generator_state_dict'``
    or ``'model_state_dict'``, or be a bare state dict (the result of
    ``torch.save(model.state_dict(), path)``).

    Args:
        model_path: Path-like object pointing at the checkpoint file.
        model_class: Callable accepting ``input_channels`` and returning an
            ``nn.Module``.
        device: Device to map the checkpoint to and move the model onto.
        input_channels: Forwarded to ``model_class``; defaults to 5, the
            value that was previously hard-coded.

    Returns:
        The model in eval mode on ``device``, or ``None`` when
        ``model_path`` does not exist.

    Raises:
        RuntimeError: From ``load_state_dict`` when the stored weights do
            not match the model architecture.
    """
    if not model_path.exists():
        return None

    print(f"   Loading {model_path.name}...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    for key in ('generator_state_dict', 'model_state_dict'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break
    else:
        # Neither wrapper key is present: treat the checkpoint itself as the
        # state dict instead of failing with an opaque KeyError.
        state_dict = checkpoint

    model = model_class(input_channels=input_channels)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def predict_neural_model(model, predictor_data, device):
    """Run ``model`` on one un-normalized predictor array.

    The predictors are standardized with the per-channel statistics from
    the module-level ``stats``, a batch axis is added, and the model output
    is squeezed and mapped back to physical units with ``denormalize``.

    Args:
        model: Callable network taking a batched tensor.
        predictor_data: NumPy array whose first axis is the channel axis.
        device: Device the input tensor is moved to.

    Returns:
        The denormalized prediction as a NumPy array.
    """
    channel_mean = stats['predictor_mean'][:, None, None]
    channel_std = stats['predictor_std'][:, None, None]
    with torch.no_grad():
        standardized = (predictor_data - channel_mean) / (channel_std + 1e-8)
        batch = torch.from_numpy(standardized).unsqueeze(0).to(device)
        output = model(batch)
        # squeeze() drops every singleton axis, including batch and channel.
        output_norm = output.cpu().numpy().squeeze()
        return denormalize(output_norm, is_target=True)
def calculate_csi(pred, target, threshold=220.0):
    """Return the Critical Success Index for values at or below ``threshold``.

    An element counts as an event when its value is ``<= threshold``.
    CSI is ``hits / (hits + misses + false_alarms)``.

    Args:
        pred: Predicted array.
        target: Reference array of the same shape.
        threshold: Event threshold applied to both arrays.

    Returns:
        The CSI, or ``0.0`` when neither array contains an event.
    """
    predicted = pred <= threshold
    observed = target <= threshold

    hits = np.sum(predicted & observed)
    # hits + misses + false alarms is exactly the number of elements where
    # at least one of the two fields has an event.
    any_event = np.sum(predicted | observed)

    if any_event > 0:
        return hits / any_event
    return 0.0
def extract_multivar_features_for_rf(predictor_norm, window_size=5):
    """Build a per-pixel feature matrix from a (channels, H, W) array.

    For ``C`` channels each pixel gets ``3 * C`` features, laid out as
    ``[values | local means | local standard deviations]``, where the local
    statistics come from a ``window_size`` x ``window_size`` uniform filter.
    Pixels are ordered row-major.

    Args:
        predictor_norm: Array of shape (channels, height, width).
        window_size: Side length of the uniform-filter window.

    Returns:
        float32 array of shape ``(height * width, 3 * channels)``.
    """
    num_channels, h, w = predictor_norm.shape
    n_pixels = h * w

    features = np.zeros((n_pixels, 3 * num_channels), dtype=np.float32)
    # Column groups are views into `features`, so writes land in place.
    value_cols = features[:, :num_channels]
    mean_cols = features[:, num_channels:2 * num_channels]
    std_cols = features[:, 2 * num_channels:]

    # Move channels last so each row holds all channel values of one pixel.
    value_cols[:] = predictor_norm.transpose(1, 2, 0).reshape(n_pixels, num_channels)

    for idx, band in enumerate(predictor_norm):
        mean = uniform_filter(band, size=window_size)
        mean_of_squares = uniform_filter(band ** 2, size=window_size)
        # Var = E[x^2] - E[x]^2; clamp at 0 to absorb rounding before the sqrt.
        variance = np.maximum(mean_of_squares - mean ** 2, 0)
        mean_cols[:, idx] = mean.ravel()
        std_cols[:, idx] = np.sqrt(variance).ravel()
    return features

# --- THE DEFINITIVE VISUALIZATION FUNCTION ---

def _draw_panel_with_inset(ax, image, title, box_color, zoom, cmap, norm):
    """Draw ``image`` on ``ax`` with a magnified inset of the ``zoom`` window."""
    zoom_x, zoom_y, zoom_size = zoom

    ax.imshow(image, cmap=cmap, norm=norm, origin='upper')
    ax.set_title(title, fontsize=18, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

    patch = image[zoom_y:zoom_y + zoom_size, zoom_x:zoom_x + zoom_size]
    if patch.size == 0:
        # The zoom window lies outside this image; there is nothing to magnify.
        return

    axins = inset_axes(ax, width="40%", height="40%", loc='lower right')
    # Give the inset the same pixel coordinates and the same row orientation
    # as the parent image. mark_inset() boxes the inset's data limits on the
    # parent axes, so without this extent it would mark pixels 0..zoom_size
    # in the corner, and a different origin would show the crop upside down.
    extent = (
        zoom_x - 0.5,
        zoom_x + patch.shape[1] - 0.5,
        zoom_y + patch.shape[0] - 0.5,
        zoom_y - 0.5,
    )
    axins.imshow(patch, cmap=cmap, norm=norm, origin='upper', extent=extent)
    axins.set_xticks([])
    axins.set_yticks([])

    # mark_inset draws both the box around the zoomed region and the two
    # connector lines, so no separately imported Rectangle patch is needed.
    region_box, _, _ = mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec=box_color, lw=1)
    region_box.set_linewidth(2)


def create_definitive_figure(timestamp, ground_truth, predictions, metrics):
    """
    Creates the final, definitive figure with a reversed grayscale colormap,
    non-linear scaling, and a magnified inset in every panel.

    The first panel shows the ground truth; one further panel is added for
    each of 'Strong_RF_Baseline' and 'Transfer_Learning_GAN' that is present
    in ``predictions`` (in that order). The figure is saved to
    ``OUTPUT_DIR / f'definitive_figure_{timestamp}.png'``.

    Args:
        timestamp: Event label used in the title and the output filename.
        ground_truth: 2-D array shown in the first panel.
        predictions: Mapping from model name to its 2-D prediction array.
        metrics: Mapping from model name to a dict with a ``'csi'`` entry;
            the CSI is appended to the panel title when available.

    Returns:
        None. If ``predictions`` contains none of the known models, a
        message is printed and no figure is created.
    """
    # (predictions key, panel label, color of the zoom box), in panel order.
    panel_specs = (
        ('Strong_RF_Baseline', 'RF Baseline', 'red'),
        ('Transfer_Learning_GAN', 'GAN Prediction', 'lime'),
    )
    available = [spec for spec in panel_specs if spec[0] in predictions]
    if not available:
        print(f"   ⚠️ No model predictions available for {timestamp}, figure skipped.")
        return

    # --- KEY VISUALIZATION SETTINGS ---
    # 1. Reversed grayscale colormap.
    TBB_CMAP = 'gist_gray_r'
    # 2. Non-linear color scale: gamma < 1 stretches the low end of the range.
    TBB_NORM = colors.PowerNorm(gamma=0.5, vmin=190, vmax=310)
    # 3. Zoom window for the insets as (x, y, size) in pixels.
    zoom = (120, 80, 64)

    # Reset the style before the figure exists so it actually applies to it.
    plt.style.use('default')
    n_panels = 1 + len(available)
    fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 7),
                             facecolor='white', squeeze=False)
    axes = axes[0]
    try:
        fig.suptitle(f'Model Comparison for Event: {timestamp}', fontsize=24, fontweight='bold')

        _draw_panel_with_inset(axes[0], ground_truth, 'A) Ground Truth (TBB)',
                               'cyan', zoom, TBB_CMAP, TBB_NORM)

        for ax, letter, (key, label, color) in zip(axes[1:], 'BC', available):
            csi = metrics.get(key, {}).get('csi')
            title = f'{letter}) {label}'
            if csi is not None:
                title += f' (CSI: {csi:.3f})'
            _draw_panel_with_inset(ax, predictions[key], title, color,
                                   zoom, TBB_CMAP, TBB_NORM)

        fig.tight_layout(rect=[0, 0, 1, 0.93])
        filename = f'definitive_figure_{timestamp}.png'
        fig.savefig(OUTPUT_DIR / filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"   📊 Saved new Definitive Figure: {filename}")

# --- MAIN ANALYSIS ---
def main():
    """Load the GAN and RF baseline, predict each event and plot the figure.

    For every timestamp in ``DISASTER_DAYS`` the predictor/target pair is
    loaded from ``DATA_DIR / 'test'``, each loaded model produces a
    prediction, its CSI against the ground truth is computed, and
    ``create_definitive_figure`` is called when at least one prediction
    exists. Models whose files are missing are skipped.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gan_checkpoint = (MODELS_DIR.parent / 'model_output_final_multi_variable'
                      / 'final_multi_var_gan_epoch_25.pt')
    rf_artifact = MODELS_DIR / 'strong_baseline_rf' / 'strong_baseline_rf.joblib'
    models_to_load = {
        'Transfer_Learning_GAN': (gan_checkpoint, MultiVariableGenerator),
        'Strong_RF_Baseline': (rf_artifact, None),
    }

    loaded_models = {}
    for name, (path, model_class) in models_to_load.items():
        if name != 'Strong_RF_Baseline':
            network = load_neural_model(path, model_class, device)
            if network:
                loaded_models[name] = network
            continue
        # The Random Forest is a joblib artifact rather than a torch checkpoint.
        if path.exists():
            loaded_models[name] = joblib.load(path)
            print(f"   ✅ Loaded {name}")

    evaluation_order = ['Strong_RF_Baseline', 'Transfer_Learning_GAN']
    for timestamp in DISASTER_DAYS:
        print(f"\n🌊 Analyzing timestamp: {timestamp}")
        test_dir = DATA_DIR / 'test'
        predictor_data = load_and_crop_data(test_dir / f'{timestamp}_predictor.npy')
        ground_truth_norm = load_and_crop_data(test_dir / f'{timestamp}_target.npy')
        if predictor_data is None or ground_truth_norm is None:
            print(f"   ⚠️ Data not found for {timestamp}, skipping.")
            continue

        # NOTE(review): this assumes the stored *_target.npy is already
        # normalized. The predictor from the same directory is normalized
        # explicitly below, so confirm the target is not denormalized twice.
        ground_truth = denormalize(ground_truth_norm, is_target=True)
        predictions = {}
        metrics = {}

        for name in evaluation_order:
            if name not in loaded_models:
                continue
            model = loaded_models[name]

            if name != 'Strong_RF_Baseline':
                prediction = predict_neural_model(model, predictor_data, device)
            else:
                # The RF predicts one value per pixel from hand-built features.
                channel_mean = stats['predictor_mean'][:, None, None]
                channel_std = stats['predictor_std'][:, None, None]
                predictor_norm = (predictor_data - channel_mean) / (channel_std + 1e-8)
                features = extract_multivar_features_for_rf(predictor_norm)
                pred_norm = model.predict(features).reshape(TARGET_SHAPE)
                prediction = denormalize(pred_norm, is_target=True)

            predictions[name] = prediction
            metrics[name] = {'csi': calculate_csi(prediction, ground_truth)}
            print(f"   ✅ Generated prediction for {name}")

        if predictions:
            create_definitive_figure(timestamp, ground_truth, predictions, metrics)

# Entry point: generate the figures only when this code is executed directly
# (script or notebook cell), not when it is imported.
if __name__ == "__main__":
    main()

Mounted at /content/drive
   Loading final_multi_var_gan_epoch_25.pt...
   ✅ Loaded Strong_RF_Baseline

🌊 Analyzing timestamp: 20230807_1200
   ✅ Generated prediction for Strong_RF_Baseline
   ✅ Generated prediction for Transfer_Learning_GAN


[Parallel(n_jobs=12)]: Using backend ThreadingBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Done 100 out of 100 | elapsed:    0.0s finished


   📊 Saved new Definitive Figure: definitive_figure_20230807_1200.png

🌊 Analyzing timestamp: 20230807_1800
   ✅ Generated prediction for Strong_RF_Baseline
   ✅ Generated prediction for Transfer_Learning_GAN


[Parallel(n_jobs=12)]: Using backend ThreadingBackend with 12 concurrent workers.
[Parallel(n_jobs=12)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=12)]: Done 100 out of 100 | elapsed:    0.0s finished


   📊 Saved new Definitive Figure: definitive_figure_20230807_1800.png
