In [None]:
#
# This script trains the Stage 2 "Diagnostic Network" for the decoupled
# "Forecast and Diagnose" pipeline.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import json
import gc

# Silence every warning so the notebook output stays readable.
# NOTE(review): a blanket 'ignore' also hides torch/numpy deprecation and UserWarnings.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
# Dataset root: contains the split folders and normalization_stats_multi_variable.joblib.
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
ABLATION_MODEL_DIR = PROJECT_PATH / 'ablation_study_models'
OUTPUT_DIR = PROJECT_PATH / 'final_diagnostic_model'
# No parents=True: PROJECT_PATH must already exist (e.g. Drive mounted).
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Path to the Stage 1 "Forecast Engine" ---
# Checkpoint for the Stage 1 variant whose inputs exclude IVT (see VARIABLES_TO_USE).
STAGE_1_MODEL_PATH = ABLATION_MODEL_DIR / 'ablation_remove_IVT.pt'

# --- Training Hyperparameters ---
EPOCHS = 30
BATCH_SIZE = 8
LEARNING_RATE = 1e-4

# --- Loss Function Weights ---
# Total loss in main(): LAMBDA_BCE * BCE(logits) + LAMBDA_DICE * Dice(logits).
LAMBDA_BCE = 1.0
LAMBDA_DICE = 3.0 # Heavier weight on spatial overlap

# --- Model & Data Configuration ---
# Channel order assumed for the stored *_predictor.npy arrays; indices are taken from this list.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLES_TO_USE = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = [ALL_VARIABLES.index(v) for v in VARIABLES_TO_USE]
INPUT_CHANNELS = len(VARIABLE_INDICES)
# Ground-truth pixels at or below this value (Kelvin) are labelled as storm cores.
DIAGNOSTIC_THRESHOLD_K = 220.0

# --- 2. DATASET FOR DIAGNOSTIC NETWORK ---
class DiagnosticDataset(Dataset):
    """
    Pair each cached Stage 1 prediction with a binary storm-core mask.

    Input:  a ``pred_<case>.npy`` file written by ``generate_stage1_predictions``.
            Those files are stored in Kelvin; they are normalized here with the
            target statistics. This gives the network the same normalized
            representation that the Stage 1 model emits directly at evaluation
            time.
    Target: ``<case>_target.npy`` from ``data_dir / split``, thresholded to a
            {0, 1} mask that is 1 where the ground truth is <= the threshold.

    Args:
        data_dir: dataset root with the split folders and
            ``normalization_stats_multi_variable.joblib``.
        split: split sub-directory name, e.g. ``'train'`` or ``'val'``.
        stage1_predictions_dir: root holding ``<split>/pred_*.npy`` files.
        threshold_k: mask threshold in Kelvin; ``None`` uses
            ``DIAGNOSTIC_THRESHOLD_K``.

    Raises:
        FileNotFoundError: if no Stage 1 prediction files are found for the split.
    """
    def __init__(self, data_dir: Path, split: str, stage1_predictions_dir: Path, threshold_k=None):
        self.split_dir = data_dir / split
        self.stage1_dir = stage1_predictions_dir / split
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        self.threshold_k = threshold_k

        self.stage1_files = sorted(self.stage1_dir.glob('*.npy'))
        if not self.stage1_files:
            # Fail early instead of letting DataLoader/CSI averaging break later.
            raise FileNotFoundError(f"No Stage 1 prediction files in {self.stage1_dir}")
        self.target_files = [self.split_dir / f"{self._case_name(p)}_target.npy" for p in self.stage1_files]

        print(f"Loaded {len(self.stage1_files)} Stage 1 predictions for '{split}' split.")

    @staticmethod
    def _case_name(path: Path) -> str:
        """Return the file stem without the leading ``pred_`` prefix (only the prefix)."""
        stem = path.stem
        return stem[len('pred_'):] if stem.startswith('pred_') else stem

    def __len__(self):
        return len(self.stage1_files)

    def __getitem__(self, idx):
        # Stage 1 predictions are cached in Kelvin; map them back into the
        # normalized target space (same 1e-8 guard as the rest of the pipeline).
        pred_k = np.load(self.stage1_files[idx]).astype(np.float32)
        input_data = ((pred_k - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)).astype(np.float32)

        # Target is the real ground truth, converted to a {0, 1} storm-core mask.
        target_data = np.load(self.target_files[idx]).astype(np.float32)
        threshold = DIAGNOSTIC_THRESHOLD_K if self.threshold_k is None else self.threshold_k
        target_mask = (target_data <= threshold).astype(np.float32)

        # Add a leading channel axis to both input and mask.
        return torch.from_numpy(input_data).unsqueeze(0), torch.from_numpy(target_mask).unsqueeze(0)

# --- 3. MODEL ARCHITECTURES ---
class SimplifiedAttention(nn.Module):
    """
    Channel gate followed by a spatial gate.

    ``channel_att`` pools each channel to 1x1 and produces per-channel
    sigmoid weights through a ``channels // 16`` bottleneck. ``spatial_att``
    then produces a single-channel sigmoid map (7x7 conv) from the
    channel-reweighted features. Submodule names and layer order are kept
    as-is because they are the state_dict keys of saved checkpoints.
    """
    def __init__(self, channels):
        super().__init__()
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        gated = x * self.channel_att(x)
        return gated * self.spatial_att(gated)

class FlexibleAblationModel(nn.Module): # Stage 1 Model
    """
    Stage 1 forecast U-Net: stacked predictor channels -> one output channel.

    Encoder: ``depth`` double-conv blocks with widths
    ``base_channels * min(2**i, 8)`` (doubling per level, capped at 8x; the
    defaults give [64, 128, 256, 512]). A stride-2 conv sits between
    consecutive levels, so there are ``depth - 1`` downsamplers. A double-conv
    bottleneck, optionally followed by ``SimplifiedAttention``, feeds a
    decoder of ``depth`` transpose-conv upsamplers. Each upsampler output is
    concatenated with the matching encoder skip and refined by a double-conv
    block. ``final_conv`` maps to a single channel with no output activation.

    Args:
        input_channels: number of predictor channels fed in.
        base_channels: width of the first encoder level.
        depth: number of encoder (and decoder) levels.
        use_attention: whether to apply ``SimplifiedAttention`` after the bottleneck.

    Attribute names are the state_dict keys: they must not change if saved
    checkpoints (loaded with ``load_state_dict``) are to keep loading.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # No downsampler after the deepest encoder level.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # Decoder modules are built deepest level first, matching the reversed skip order in forward().
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input is upsampled features concatenated with the skip: 2 * out_ch channels.
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps the spatial size.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        # x: presumably (batch, input_channels, H, W) normalized predictors — TODO confirm.
        skips = []
        # Encoder: every level's output (including the deepest) is kept as a skip.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        # Decoder: consume skips deepest-first.
        # NOTE(review): skips[-1] is already at bottleneck resolution, so the first
        # upsample doubles the size and the interpolate below shrinks it back.
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when sizes disagree: the first stage, or odd H/W that the
            # stride-2 convs rounded up (ceil(H/2) * 2 = H + 1).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

class DiagnosticUNet(nn.Module): # Stage 2 Model
    """
    Stage 2 diagnostic U-Net: one-channel Stage 1 field -> storm-core logits.

    Same encoder/decoder layout as ``FlexibleAblationModel`` but smaller
    (defaults ``base_channels=32``, ``depth=3`` give widths [32, 64, 128]),
    with no attention and a single 1x1 ``final_conv``. The output is raw
    logits; during training the sigmoid is applied inside
    ``BCEWithLogitsLoss`` and ``DiceLoss``.

    Args:
        input_channels: channels of the Stage 1 field fed in.
        output_channels: number of logit maps produced.
        base_channels: width of the first encoder level.
        depth: number of encoder (and decoder) levels.

    Attribute names are the state_dict keys saved to best_diagnostic_model.pt;
    keep them in sync with the class definition used to reload that checkpoint.
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, depth=3):
        super().__init__()
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # No downsampler after the deepest encoder level.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # Decoder modules are built deepest level first, matching the reversed skip order in forward().
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input is upsampled features concatenated with the skip: 2 * out_ch channels.
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        self.final_conv = nn.Conv2d(self.channels[0], output_channels, 1) # No Tanh, outputs logits
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps the spatial size.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        # x: presumably (batch, input_channels, H, W) — TODO confirm.
        skips = []
        # Encoder: every level's output (including the deepest) is kept as a skip.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        # NOTE(review): skips[-1] is already at bottleneck resolution, so the first
        # upsample doubles the size and the interpolate below shrinks it back.
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when sizes disagree: the first stage, or odd H/W rounded up by the stride-2 convs.
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 4. PRE-PROCESSING & CSI METRIC ---
def generate_stage1_predictions(stage1_model, data_dir, output_dir, variable_indices, device):
    """
    Run the frozen Stage 1 model over the train/val splits and cache its outputs.

    For every ``*_predictor.npy`` in ``data_dir/<split>``:
    1. Select the requested channels and normalize them with the stored
       predictor statistics.
    2. Pass them through ``stage1_model``.
    3. De-normalize the output to Kelvin with the target statistics.
    4. Write it to ``output_dir/<split>/pred_<case>.npy``.

    Args:
        stage1_model: trained Stage 1 network; switched to eval mode here.
        data_dir: dataset root with the split folders and
            ``normalization_stats_multi_variable.joblib``.
        output_dir: root for the cached predictions (split folders are created).
        variable_indices: predictor channel indices fed to Stage 1.
        device: torch device used for inference.
    """
    print("--- Generating Stage 1 Predictions for Diagnostic Network Training ---")
    # The statistics do not depend on the split: load and slice them once.
    stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
    mean_subset = stats['predictor_mean'][variable_indices, None, None]
    std_subset = stats['predictor_std'][variable_indices, None, None]

    # Inference only: BatchNorm must use its running statistics.
    stage1_model.eval()
    for split in ['train', 'val']:
        print(f"Processing '{split}' split...")
        split_dir = data_dir / split
        output_split_dir = output_dir / split
        output_split_dir.mkdir(parents=True, exist_ok=True)

        predictor_files = sorted(split_dir.glob('*_predictor.npy'))

        with torch.no_grad():
            for pred_file in tqdm(predictor_files, desc=f"Generating Stage 1 preds for {split}"):
                full_predictor = np.load(pred_file).astype(np.float32)
                predictor_subset = full_predictor[variable_indices, :, :]
                # Cast back to float32: float64 statistics would otherwise yield a
                # double tensor that the float32 model weights reject.
                predictor_norm = ((predictor_subset - mean_subset) / (std_subset + 1e-8)).astype(np.float32)

                model_input = torch.from_numpy(predictor_norm).unsqueeze(0).to(device)
                prediction_norm = stage1_model(model_input).cpu().numpy().squeeze()
                prediction_denorm = prediction_norm * (stats['target_std'] + 1e-8) + stats['target_mean']

                save_path = output_split_dir / f"pred_{pred_file.stem.replace('_predictor','')}.npy"
                np.save(save_path, prediction_denorm)
    print("--- Stage 1 Prediction Generation Complete ---")

def calculate_csi_from_mask(pred_mask, true_mask):
    """
    Critical Success Index of two boolean masks (torch or numpy).

    CSI = hits / (hits + misses + false_alarms). Returns 0.0 when all three
    counts are zero, i.e. the event is neither observed nor predicted.
    """
    n_hit = (pred_mask & true_mask).sum().item()
    n_miss = (true_mask & ~pred_mask).sum().item()
    n_false = (pred_mask & ~true_mask).sum().item()
    denom = n_hit + n_miss + n_false
    if denom <= 0:
        return 0.0
    return n_hit / denom

class DiceLoss(nn.Module):
    """
    Soft Dice loss on logits, computed over the whole batch as one volume.

    ``loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``
    with ``p = sigmoid(y_pred)`` and ``t = y_true``.
    """
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred, y_true):
        probs = torch.sigmoid(y_pred).reshape(-1)
        target = y_true.reshape(-1)
        overlap = (probs * target).sum()
        denominator = probs.sum() + target.sum() + self.smooth
        return 1 - (2. * overlap + self.smooth) / denominator

# --- 5. MAIN TRAINING SCRIPT ---
def main():
    """
    Train the Stage 2 diagnostic network.

    Steps:
    1. Load the frozen Stage 1 model and cache its train/val predictions.
    2. Train ``DiagnosticUNet`` with ``LAMBDA_BCE * BCE + LAMBDA_DICE * Dice``.
    3. After each epoch, score the 0.5-thresholded validation mask with CSI.

    CSI is computed from hits / misses / false alarms pooled over the whole
    validation split. The old per-batch average over-weighted the smaller
    final batch and scored storm-free batches as 0. The best-CSI weights are
    written to ``OUTPUT_DIR / 'best_diagnostic_model.pt'``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🎯 Training Stage 2 Diagnostic Network on device: {device}")
    use_cuda = device.type == 'cuda'

    # Frozen Stage 1 forecaster, used only to produce Stage 2 inputs.
    stage1_model = FlexibleAblationModel(input_channels=INPUT_CHANNELS).to(device)
    stage1_model.load_state_dict(torch.load(STAGE_1_MODEL_PATH, map_location=device))
    stage1_model.eval()

    stage1_preds_dir = OUTPUT_DIR / 'stage1_predictions_for_diagnostic'
    generate_stage1_predictions(stage1_model, DATA_DIR, stage1_preds_dir, VARIABLE_INDICES, device)

    train_dataset = DiagnosticDataset(DATA_DIR, 'train', stage1_preds_dir)
    val_dataset = DiagnosticDataset(DATA_DIR, 'val', stage1_preds_dir)
    # Pinned host memory only benefits host-to-CUDA copies.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=use_cuda)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=use_cuda)

    model = DiagnosticUNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    bce_loss = nn.BCEWithLogitsLoss()
    dice_loss = DiceLoss()

    # Start below any reachable CSI so epoch 1 always writes a checkpoint.
    # With 0.0, a model that never scored a hit would never be saved and the
    # evaluation step would have no file to load.
    best_csi = float('-inf')
    save_path = OUTPUT_DIR / 'best_diagnostic_model.pt'
    print(f"\nStarting Stage 2 Diagnostic training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Diagnostic Train]")
        for stage1_pred, true_mask in pbar:
            stage1_pred, true_mask = stage1_pred.to(device), true_mask.to(device)

            optimizer.zero_grad()
            pred_logits = model(stage1_pred)

            loss_bce = bce_loss(pred_logits, true_mask)
            loss_dice = dice_loss(pred_logits, true_mask)

            loss = (loss_bce * LAMBDA_BCE) + (loss_dice * LAMBDA_DICE)
            loss.backward()
            optimizer.step()

            pbar.set_postfix({'loss': f"{loss.item():.4f}"})

        # Validation: pool the contingency counts over the whole split so every
        # pixel carries equal weight regardless of batch size.
        model.eval()
        hits = misses = false_alarms = 0
        with torch.no_grad():
            for stage1_pred_val, true_mask_val in val_loader:
                stage1_pred_val = stage1_pred_val.to(device)
                true_val = true_mask_val.to(device).bool()
                pred_val = torch.sigmoid(model(stage1_pred_val)) > 0.5
                hits += (pred_val & true_val).sum().item()
                misses += (~pred_val & true_val).sum().item()
                false_alarms += (pred_val & ~true_val).sum().item()
        denom = hits + misses + false_alarms
        val_csi = hits / denom if denom > 0 else 0.0
        print(f"Epoch {epoch+1}/{EPOCHS} | Validation CSI @ {DIAGNOSTIC_THRESHOLD_K}K: {val_csi:.4f}")

        if val_csi > best_csi:
            best_csi = val_csi
            torch.save(model.state_dict(), save_path)
            print(f"✅ New best CSI! Model saved to {save_path}")
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()

    print(f"\n🎉 Diagnostic Network training complete! Best CSI achieved: {best_csi:.4f}")


if __name__ == "__main__":
    main()


🎯 Training Stage 2 Diagnostic Network on device: cuda
--- Generating Stage 1 Predictions for Diagnostic Network Training ---
Processing 'train' split...


Generating Stage 1 preds for train: 100%|██████████| 1200/1200 [04:58<00:00,  4.02it/s]


Processing 'val' split...


Generating Stage 1 preds for val: 100%|██████████| 150/150 [00:20<00:00,  7.27it/s]


--- Stage 1 Prediction Generation Complete ---
Loaded 1200 Stage 1 predictions for 'train' split.
Loaded 150 Stage 1 predictions for 'val' split.

Starting Stage 2 Diagnostic training for 30 epochs...


Epoch 1/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:52<00:00,  2.87it/s, loss=2.2938]


Epoch 1/30 | Validation CSI @ 220.0K: 0.1578
✅ New best CSI! Model saved to /content/drive/My Drive/AR_Downscaling/final_diagnostic_model/best_diagnostic_model.pt


Epoch 2/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.48it/s, loss=1.8196]


Epoch 2/30 | Validation CSI @ 220.0K: 0.1571


Epoch 3/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.47it/s, loss=2.6381]


Epoch 3/30 | Validation CSI @ 220.0K: 0.1557


Epoch 4/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.52it/s, loss=2.1981]


Epoch 4/30 | Validation CSI @ 220.0K: 0.1541


Epoch 5/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.52it/s, loss=2.5044]


Epoch 5/30 | Validation CSI @ 220.0K: 0.1584
✅ New best CSI! Model saved to /content/drive/My Drive/AR_Downscaling/final_diagnostic_model/best_diagnostic_model.pt


Epoch 6/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.52it/s, loss=2.1244]


Epoch 6/30 | Validation CSI @ 220.0K: 0.1553


Epoch 7/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=1.9735]


Epoch 7/30 | Validation CSI @ 220.0K: 0.1512


Epoch 8/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.49it/s, loss=2.3915]


Epoch 8/30 | Validation CSI @ 220.0K: 0.1509


Epoch 9/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.2511]


Epoch 9/30 | Validation CSI @ 220.0K: 0.1512


Epoch 10/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.7644]


Epoch 10/30 | Validation CSI @ 220.0K: 0.1370


Epoch 11/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=3.0110]


Epoch 11/30 | Validation CSI @ 220.0K: 0.1528


Epoch 12/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.1704]


Epoch 12/30 | Validation CSI @ 220.0K: 0.1455


Epoch 13/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=1.5507]


Epoch 13/30 | Validation CSI @ 220.0K: 0.1557


Epoch 14/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=3.2203]


Epoch 14/30 | Validation CSI @ 220.0K: 0.1653
✅ New best CSI! Model saved to /content/drive/My Drive/AR_Downscaling/final_diagnostic_model/best_diagnostic_model.pt


Epoch 15/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=2.3269]


Epoch 15/30 | Validation CSI @ 220.0K: 0.1464


Epoch 16/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.1238]


Epoch 16/30 | Validation CSI @ 220.0K: 0.1515


Epoch 17/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=1.7916]


Epoch 17/30 | Validation CSI @ 220.0K: 0.1562


Epoch 18/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.49it/s, loss=2.0061]


Epoch 18/30 | Validation CSI @ 220.0K: 0.1533


Epoch 19/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.49it/s, loss=2.1148]


Epoch 19/30 | Validation CSI @ 220.0K: 0.1495


Epoch 20/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.49it/s, loss=1.8967]


Epoch 20/30 | Validation CSI @ 220.0K: 0.1597


Epoch 21/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=1.9060]


Epoch 21/30 | Validation CSI @ 220.0K: 0.1464


Epoch 22/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=2.3870]


Epoch 22/30 | Validation CSI @ 220.0K: 0.1487


Epoch 23/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.51it/s, loss=2.2977]


Epoch 23/30 | Validation CSI @ 220.0K: 0.1535


Epoch 24/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.7834]


Epoch 24/30 | Validation CSI @ 220.0K: 0.1482


Epoch 25/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=2.7248]


Epoch 25/30 | Validation CSI @ 220.0K: 0.1496


Epoch 26/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.48it/s, loss=2.4848]


Epoch 26/30 | Validation CSI @ 220.0K: 0.1494


Epoch 27/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.0123]


Epoch 27/30 | Validation CSI @ 220.0K: 0.1583


Epoch 28/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:20<00:00,  7.50it/s, loss=1.9178]


Epoch 28/30 | Validation CSI @ 220.0K: 0.1492


Epoch 29/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=2.6102]


Epoch 29/30 | Validation CSI @ 220.0K: 0.1400


Epoch 30/30 [Diagnostic Train]: 100%|██████████| 150/150 [00:19<00:00,  7.50it/s, loss=1.7906]


Epoch 30/30 | Validation CSI @ 220.0K: 0.1556

🎉 Diagnostic Network training complete! Best CSI achieved: 0.1653


In [None]:
#
# This script performs the final, comprehensive evaluation and visualization
# of the complete two-stage "Forecast and Diagnose" pipeline.
# VERSION 2: Corrected ValueError during test set iteration.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import pandas as pd
from scipy.stats import wasserstein_distance
from scipy.ndimage import sobel
import json
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Silence every warning so the notebook output stays readable.
# NOTE(review): a blanket 'ignore' also hides torch/numpy deprecation and UserWarnings.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
# Dataset root: contains the split folders and normalization_stats_multi_variable.joblib.
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
ABLATION_MODEL_DIR = PROJECT_PATH / 'ablation_study_models'
DIAGNOSTIC_MODEL_DIR = PROJECT_PATH / 'final_diagnostic_model'
OUTPUT_DIR = PROJECT_PATH / 'final_evaluation_results'
# No parents=True: PROJECT_PATH must already exist (e.g. Drive mounted).
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Models to Evaluate ---
STAGE_1_MODEL_PATH = ABLATION_MODEL_DIR / 'ablation_remove_IVT.pt'
STAGE_2_MODEL_PATH = DIAGNOSTIC_MODEL_DIR / 'best_diagnostic_model.pt'

# --- Evaluation Configuration ---
# CSI is reported for the event "value <= threshold" (Kelvin). Metric keys use
# int(thr), so non-integer thresholds that truncate to the same int would collide.
CSI_THRESHOLDS_K = [230.0, 220.0, 210.0]
# Used to rank test cases for plotting (most pixels at or below this value).
DIAGNOSTIC_THRESHOLD_K = 220.0
NUM_CASES_TO_PLOT = 3
COLORMAP = 'CMRmap_r'

# --- Model & Data Configuration ---
# Channel order assumed for the stored *_predictor.npy arrays; indices are taken from this list.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLES_TO_USE = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = [ALL_VARIABLES.index(v) for v in VARIABLES_TO_USE]
INPUT_CHANNELS = len(VARIABLE_INDICES)

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """
    Yield ``(normalized predictors, normalized target, case name)`` per sample.

    Each ``*_predictor.npy`` in ``data_dir/<split>`` is paired with the file
    obtained by swapping the ``_predictor.npy`` suffix for ``_target.npy``.
    Predictors are z-scored per channel and the target with the target
    statistics, all read from ``normalization_stats_multi_variable.joblib``.

    Raises:
        FileNotFoundError: if the split folder contains no predictor files.
    """
    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if len(self.predictor_files) == 0:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        src = self.predictor_files[idx]
        dst = Path(str(src).replace('_predictor.npy', '_target.npy'))
        raw_x = np.load(src).astype(np.float32)
        raw_y = np.load(dst).astype(np.float32)

        s = self.stats
        x_norm = (raw_x - s['predictor_mean'][:, None, None]) / (s['predictor_std'][:, None, None] + 1e-8)
        y_norm = (raw_y - s['target_mean']) / (s['target_std'] + 1e-8)

        case_name = src.stem.replace('_predictor', '')
        return torch.from_numpy(x_norm), torch.from_numpy(y_norm).unsqueeze(0), case_name

class SimplifiedAttention(nn.Module):
    """
    Channel gate followed by a spatial gate.

    ``channel_att`` pools each channel to 1x1 and produces per-channel
    sigmoid weights through a ``channels // 16`` bottleneck. ``spatial_att``
    then produces a single-channel sigmoid map (7x7 conv) from the
    channel-reweighted features. Submodule names and layer order are kept
    as-is because they are the state_dict keys of saved checkpoints.
    """
    def __init__(self, channels):
        super().__init__()
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 16, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        gated = x * self.channel_att(x)
        return gated * self.spatial_att(gated)

class FlexibleAblationModel(nn.Module): # Stage 1 Model
    """
    Stage 1 forecast U-Net: stacked predictor channels -> one output channel.

    Encoder: ``depth`` double-conv blocks with widths
    ``base_channels * min(2**i, 8)`` (doubling per level, capped at 8x; the
    defaults give [64, 128, 256, 512]). A stride-2 conv sits between
    consecutive levels, so there are ``depth - 1`` downsamplers. A double-conv
    bottleneck, optionally followed by ``SimplifiedAttention``, feeds a
    decoder of ``depth`` transpose-conv upsamplers. Each upsampler output is
    concatenated with the matching encoder skip and refined by a double-conv
    block. ``final_conv`` maps to a single channel with no output activation.

    Args:
        input_channels: number of predictor channels fed in.
        base_channels: width of the first encoder level.
        depth: number of encoder (and decoder) levels.
        use_attention: whether to apply ``SimplifiedAttention`` after the bottleneck.

    Attribute names are the state_dict keys: they must match the ones in
    STAGE_1_MODEL_PATH for ``load_state_dict`` to succeed.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # No downsampler after the deepest encoder level.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # Decoder modules are built deepest level first, matching the reversed skip order in forward().
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input is upsampled features concatenated with the skip: 2 * out_ch channels.
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps the spatial size.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        # x: presumably (batch, input_channels, H, W) normalized predictors — TODO confirm.
        skips = []
        # Encoder: every level's output (including the deepest) is kept as a skip.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        # Decoder: consume skips deepest-first.
        # NOTE(review): skips[-1] is already at bottleneck resolution, so the first
        # upsample doubles the size and the interpolate below shrinks it back.
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when sizes disagree: the first stage, or odd H/W that the
            # stride-2 convs rounded up (ceil(H/2) * 2 = H + 1).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

class DiagnosticUNet(nn.Module): # Stage 2 Model
    """
    Stage 2 diagnostic U-Net: one-channel Stage 1 field -> storm-core logits.

    Same encoder/decoder layout as ``FlexibleAblationModel`` but smaller
    (defaults ``base_channels=32``, ``depth=3`` give widths [32, 64, 128]),
    with no attention and a single 1x1 ``final_conv`` returning raw logits.
    Callers in this cell apply ``torch.sigmoid`` and threshold at 0.5.

    Args:
        input_channels: channels of the Stage 1 field fed in.
        output_channels: number of logit maps produced.
        base_channels: width of the first encoder level.
        depth: number of encoder (and decoder) levels.

    Attribute names are the state_dict keys: they must match the ones in
    STAGE_2_MODEL_PATH for ``load_state_dict`` to succeed.
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, depth=3):
        super().__init__()
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # No downsampler after the deepest encoder level.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # Decoder modules are built deepest level first, matching the reversed skip order in forward().
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input is upsampled features concatenated with the skip: 2 * out_ch channels.
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # Raw logits: no output activation.
        self.final_conv = nn.Conv2d(self.channels[0], output_channels, 1)
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps the spatial size.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        # x: presumably (batch, input_channels, H, W) — TODO confirm.
        skips = []
        # Encoder: every level's output (including the deepest) is kept as a skip.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        # NOTE(review): skips[-1] is already at bottleneck resolution, so the first
        # upsample doubles the size and the interpolate below shrinks it back.
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when sizes disagree: the first stage, or odd H/W rounded up by the stride-2 convs.
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 3. METRIC & HELPER FUNCTIONS ---
def calculate_detailed_csi(pred_k, true_k, threshold_k):
    """
    CSI of the event "value <= ``threshold_k``" for two numpy fields in Kelvin.

    Returns hits / (hits + misses + false_alarms), or 0.0 when the event is
    neither observed nor forecast anywhere.
    """
    forecast_yes = pred_k <= threshold_k
    observed_yes = true_k <= threshold_k
    hits = (forecast_yes & observed_yes).sum()
    misses = (observed_yes & ~forecast_yes).sum()
    false_alarms = (forecast_yes & ~observed_yes).sum()
    total = hits + misses + false_alarms
    if total > 0:
        return hits / total
    return 0.0

def calculate_all_metrics(pred_norm, true_norm, stats):
    """
    Denormalize both fields to Kelvin and compute RMSE plus CSI per threshold.

    Returns a dict ``{'rmse': ..., 'csi_<thr>K': ...}`` with one CSI entry per
    value in ``CSI_THRESHOLDS_K``, in that order.
    """
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    pred_k = pred_norm * scale + offset
    true_k = true_norm * scale + offset

    rmse = np.sqrt(np.mean((pred_k - true_k) ** 2))
    csi_scores = {f'csi_{int(thr)}K': calculate_detailed_csi(pred_k, true_k, thr) for thr in CSI_THRESHOLDS_K}
    return {'rmse': rmse, **csi_scores}

def find_best_test_cases(dataset, threshold_k, num_cases):
    """
    Rank samples by how many ground-truth pixels are at or below ``threshold_k``.

    ``dataset[i]`` must yield ``(predictor, target_norm, name)`` and
    ``dataset.stats`` must hold the target statistics. Returns the indices of
    the ``num_cases`` highest-count samples, highest first; ties keep
    ascending index order because the sort is stable.
    """
    scale = dataset.stats['target_std'] + 1e-8
    offset = dataset.stats['target_mean']
    cold_pixel_counts = []
    for idx in tqdm(range(len(dataset)), desc="Finding best cases"):
        _, target_norm, _ = dataset[idx]
        target_k = target_norm.numpy().squeeze() * scale + offset
        cold_pixel_counts.append((target_k <= threshold_k).sum())
    ranked = sorted(range(len(cold_pixel_counts)), key=cold_pixel_counts.__getitem__, reverse=True)
    return ranked[:num_cases]

# --- 4. VISUALIZATION FUNCTION ---
def create_diagnostic_figure(case_data, stage1_model, stage2_model, stats, device):
    """
    Run both stages on one sample and save a three-panel PNG.

    Panels: (A) ground truth, (B) the Stage 1 forecast, both in Kelvin on a
    shared PowerNorm colour scale (200-275), and (C) the ground truth in grey
    overlaid with the Stage 2 probability map.

    Args:
        case_data: ``(predictor, target_norm, case_name)`` as returned by
            ``MultiVariableARDataset.__getitem__``.
        stage1_model, stage2_model: trained models; this function does not
            call ``.eval()``, so callers must do it.
        stats: normalization statistics with ``target_mean``/``target_std``.
        device: torch device used for inference.

    Writes ``OUTPUT_DIR / f"final_pipeline_figure_{case_name}.png"`` (300 dpi).
    """
    predictor, target_norm, case_name = case_data

    with torch.no_grad():
        # Select the Stage 1 channels and add a batch axis.
        predictor_subset = predictor[VARIABLE_INDICES, :, :].unsqueeze(0).to(device)
        stage1_pred_norm = stage1_model(predictor_subset)
        # NOTE(review): Stage 2 receives the *normalized* Stage 1 output; confirm this
        # matches the representation the diagnostic network was trained on.
        stage2_pred_logits = stage2_model(stage1_pred_norm)
        stage2_pred_prob = torch.sigmoid(stage2_pred_logits)

    # Undo the target normalization so both fields are in Kelvin.
    stage1_pred_k = stage1_pred_norm.cpu().numpy().squeeze() * (stats['target_std'] + 1e-8) + stats['target_mean']
    ground_truth_k = target_norm.numpy().squeeze() * (stats['target_std'] + 1e-8) + stats['target_mean']
    # Per-pixel probabilities in [0, 1], not a thresholded mask.
    threat_mask = stage2_pred_prob.cpu().numpy().squeeze()

    fig, axes = plt.subplots(1, 3, figsize=(21, 7), facecolor='white')
    fig.suptitle(f'Final Pipeline Evaluation for Case: {case_name}', fontsize=20, fontweight='bold')

    # One norm object shared by every Kelvin panel so their colours are comparable.
    norm = colors.PowerNorm(gamma=0.4, vmin=200, vmax=275)

    axes[0].imshow(ground_truth_k, cmap=COLORMAP, norm=norm)
    axes[0].set_title('A) Ground Truth TBB', fontsize=16)

    axes[1].imshow(stage1_pred_k, cmap=COLORMAP, norm=norm)
    axes[1].set_title('B) Stage 1: Quantitative Forecast', fontsize=16)

    axes[2].imshow(ground_truth_k, cmap='gray', norm=norm)
    # Opacity scales with probability (max 0.8), so low-confidence pixels fade out.
    # NOTE(review): array-valued alpha needs a Matplotlib release that accepts it in imshow — confirm version.
    axes[2].imshow(threat_mask, cmap='Reds', alpha=(threat_mask * 0.8), vmin=0, vmax=1)
    axes[2].set_title('C) Final Product: Threat Highlight', fontsize=16)

    for ax in axes:
        ax.set_xticks([]); ax.set_yticks([])

    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_path = OUTPUT_DIR / f"final_pipeline_figure_{case_name}.png"
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Closes the current figure (the one created above) to free memory in loops.
    plt.close()
    print(f"✅ Saved diagnostic figure to {save_path}")

# --- 5. MAIN SCRIPT ---
def main():
    """
    Evaluate the two-stage pipeline on the test split and plot showcase cases.

    For each test sample:
    - Stage 1 forecasts the normalized field.
    - Stage 2 turns that forecast into a threat mask (probability > 0.5).
    - Stage 1 is scored directly.
    - The "final pipeline" field keeps the Stage 1 Kelvin value inside the
      mask and 300 K outside it, and is scored the same way.

    Per-sample RMSE and CSI are averaged over the test set and saved to
    ``OUTPUT_DIR / 'final_pipeline_evaluation.csv'``. The table has one row
    per metric and one column per system. Figures are then saved for the
    ``NUM_CASES_TO_PLOT`` samples with the most ground-truth pixels at or
    below ``DIAGNOSTIC_THRESHOLD_K``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔧 Starting final evaluation on device: {device}")

    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    stats = test_dataset.stats
    # Inverse of the dataset's target normalization.
    target_scale = stats['target_std'] + 1e-8
    target_offset = stats['target_mean']

    stage1_model = FlexibleAblationModel(input_channels=INPUT_CHANNELS).to(device)
    stage1_model.load_state_dict(torch.load(STAGE_1_MODEL_PATH, map_location=device))
    stage1_model.eval()

    stage2_model = DiagnosticUNet().to(device)
    stage2_model.load_state_dict(torch.load(STAGE_2_MODEL_PATH, map_location=device))
    stage2_model.eval()

    # --- Comprehensive Evaluation ---
    stage1_metrics_list, final_pipeline_metrics_list = [], []
    with torch.no_grad():
        # The dataset yields (predictor, target, case_name); the name is not needed here.
        for predictor, target, _ in tqdm(test_loader, desc="Evaluating Test Set"):
            predictor_subset = predictor[:, VARIABLE_INDICES, :, :].to(device)
            target_np = target.numpy().squeeze()

            # Stage 1: quantitative forecast in normalized units.
            stage1_pred_norm = stage1_model(predictor_subset)
            stage1_np = stage1_pred_norm.cpu().numpy().squeeze()
            stage1_metrics_list.append(calculate_all_metrics(stage1_np, target_np, stats))

            # Stage 2: threat mask from the Stage 1 forecast.
            # NOTE(review): Stage 2 is fed the *normalized* Stage 1 output; confirm the
            # diagnostic network was trained on the same representation.
            stage2_pred_logits = stage2_model(stage1_pred_norm)
            stage2_pred_mask = (torch.sigmoid(stage2_pred_logits) > 0.5).cpu().numpy().squeeze()

            # Final product: Stage 1 Kelvin values inside the mask, 300 K elsewhere
            # (warmer than every CSI threshold, i.e. "no event"). The fill also enters the RMSE.
            final_pred_k = np.where(stage2_pred_mask, stage1_np * target_scale + target_offset, 300)
            true_k = target_np * target_scale + target_offset

            pipeline_metrics = {'rmse': np.sqrt(np.mean((final_pred_k - true_k)**2))}
            for thr in CSI_THRESHOLDS_K:
                pipeline_metrics[f'csi_{int(thr)}K'] = calculate_detailed_csi(final_pred_k, true_k, thr)
            final_pipeline_metrics_list.append(pipeline_metrics)

    # --- Reporting ---
    # Both systems share the same metric names, so align them on that index:
    # one row per metric, one column per system. Concatenating two
    # differently prefixed Series along axis=1 would give a NaN-padded table.
    report_df = pd.DataFrame({
        'Stage1': pd.DataFrame(stage1_metrics_list).mean(),
        'Final_Pipeline': pd.DataFrame(final_pipeline_metrics_list).mean(),
    })
    report_path = OUTPUT_DIR / 'final_pipeline_evaluation.csv'

    print("\n\n" + "="*80)
    print("🏆 FINAL PIPELINE EVALUATION REPORT 🏆")
    print("="*80)
    print(report_df.round(4))
    print("="*80)
    report_df.to_csv(report_path)
    print(f"\n💾 Detailed results saved to: {report_path}")

    # --- Visualization ---
    best_case_indices = find_best_test_cases(test_dataset, DIAGNOSTIC_THRESHOLD_K, NUM_CASES_TO_PLOT)
    print(f"\nGenerating {len(best_case_indices)} visualizations for the most intense cases...")
    for idx in best_case_indices:
        case_data = test_dataset[idx]
        create_diagnostic_figure(case_data, stage1_model, stage2_model, stats, device)

    print("\n🎉 Final evaluation and visualization complete!")


if __name__ == "__main__":
    main()


🔧 Starting final evaluation on device: cuda
Loaded 'test' dataset with 150 samples.


Evaluating Test Set: 100%|██████████| 150/150 [04:02<00:00,  1.62s/it]




🏆 FINAL PIPELINE EVALUATION REPORT 🏆
                              0       1
Stage1_rmse              9.0806     NaN
Stage1_csi_230K          0.4293     NaN
Stage1_csi_220K          0.1607     NaN
Stage1_csi_210K          0.0587     NaN
Final_Pipeline_rmse         NaN  9.0839
Final_Pipeline_csi_230K     NaN  0.4293
Final_Pipeline_csi_220K     NaN  0.1607
Final_Pipeline_csi_210K     NaN  0.0587

💾 Detailed results saved to: /content/drive/My Drive/AR_Downscaling/final_evaluation_results/final_pipeline_evaluation.csv


Finding best cases: 100%|██████████| 150/150 [00:01<00:00, 142.82it/s]



Generating 3 visualizations for the most intense cases...
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_evaluation_results/final_pipeline_figure_20230807_1200.png
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_evaluation_results/final_pipeline_figure_20230615_1800.png
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_evaluation_results/final_pipeline_figure_20230807_1800.png

🎉 Final evaluation and visualization complete!


In [None]:
#
# This script performs the final, comprehensive evaluation and visualization
# of the complete two-stage "Forecast and Diagnose" pipeline.
# VERSION 3: Corrected TypeError in visualization function.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import pandas as pd
from scipy.ndimage import sobel
import json
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# All inputs and outputs hang off one project root on the mounted Google Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
ABLATION_MODEL_DIR = PROJECT_PATH.joinpath('ablation_study_models')
DIAGNOSTIC_MODEL_DIR = PROJECT_PATH.joinpath('final_diagnostic_model')
OUTPUT_DIR = PROJECT_PATH.joinpath('final_publication_figures')
# Without parents=True this raises FileNotFoundError if PROJECT_PATH is missing
# (e.g. Drive not mounted) rather than creating a stray local directory tree.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Models to Evaluate ---
STAGE_1_MODEL_PATH = ABLATION_MODEL_DIR.joinpath('ablation_remove_IVT.pt')
STAGE_2_MODEL_PATH = DIAGNOSTIC_MODEL_DIR.joinpath('best_diagnostic_model.pt')

# --- Evaluation Configuration ---
DIAGNOSTIC_THRESHOLD_K = 220.0  # pixels at or below this TBB (K) count as "severe"
NUM_CASES_TO_PLOT = 3
COLORMAP = 'CMRmap_r'

# --- Model & Data Configuration ---
# Positions in ALL_VARIABLES are used as channel indices into the predictor arrays.
# IVT is excluded, matching the 'ablation_remove_IVT' Stage 1 checkpoint.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLES_TO_USE = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = list(map(ALL_VARIABLES.index, VARIABLES_TO_USE))
INPUT_CHANNELS = len(VARIABLE_INDICES)

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """Test-time dataset yielding (normalized predictors, raw target, case name).

    Samples come from ``<case>_predictor.npy`` / ``<case>_target.npy`` pairs in
    ``data_dir / split``. Predictors are standardized per channel (axis 0) with the
    ``predictor_mean`` / ``predictor_std`` entries of the stats file. The target is
    returned unnormalized so it can be compared directly against Kelvin thresholds.
    """
    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(list(self.split_dir.glob('*_predictor.npy')))
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if not self.predictor_files: raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self): return len(self.predictor_files)
    def __getitem__(self, idx):
        pred_path = self.predictor_files[idx]
        # Swap the suffix in the file name only; a '_predictor.npy' substring in a
        # parent directory must not leak into the target path.
        targ_path = pred_path.with_name(pred_path.name.replace('_predictor.npy', '_target.npy'))
        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)
        mean = self.stats['predictor_mean'][:, None, None]
        std = self.stats['predictor_std'][:, None, None]
        # Cast back to float32: float64 stats would otherwise promote the result to
        # float64, which does not match the (default float32) model weights.
        predictor_norm = ((predictor_data - mean) / (std + 1e-8)).astype(np.float32, copy=False)
        # Return the unnormalized target (as a tensor) for creating the ground truth mask later
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_data), pred_path.stem.replace('_predictor','')

class SimplifiedAttention(nn.Module):
    """Channel attention followed by spatial attention on a 4-D feature map.

    The channel gate is global average pooling, then a 1x1-conv bottleneck MLP, then
    a sigmoid. The spatial gate is a 7x7 conv to one channel, then a sigmoid. Output
    shape equals input shape.

    Args:
        channels: number of input/output channels.
        reduction: divisor for the channel-MLP hidden width. The width is clamped to
            at least 1 so that ``channels < reduction`` does not produce a
            zero-channel layer.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        # Submodule names and Sequential positions are unchanged so saved checkpoints still load.
        self.channel_att = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, hidden, 1), nn.ReLU(inplace=True), nn.Conv2d(hidden, channels, 1), nn.Sigmoid())
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())
    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

class FlexibleAblationModel(nn.Module): # Stage 1 Model
    """U-Net-style encoder/decoder used as the Stage 1 forecast network.

    Input is a 4-D ``(batch, input_channels, H, W)`` tensor. Output is a
    ``(batch, 1, H, W)`` map at the input resolution. This script treats the output
    as a normalized TBB field and de-normalizes it with ``stats['target_std']`` /
    ``stats['target_mean']``.

    Args:
        input_channels: number of predictor channels fed to the first encoder.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).
        use_attention: if True, apply ``SimplifiedAttention`` after the bottleneck.

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_1_MODEL_PATH.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        # Channel width per encoder stage, doubling per stage and capped at 8x base_channels.
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler doubles H/W beyond the deepest
        # skip; forward() interpolates it back to the skip's size.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Decoder input is the upsampled tensor concatenated with the skip (2x channels).
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # Output head: 3x3 conv to half width, BatchNorm + ReLU, then 1x1 conv to a single channel.
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

class DiagnosticUNet(nn.Module): # Stage 2 Model
    """Small U-Net used as the Stage 2 diagnostic network.

    Takes a 4-D ``(batch, input_channels, H, W)`` tensor (here: the Stage 1 output).
    Returns ``(batch, output_channels, H, W)`` raw logits at the input resolution.
    Callers in this script apply ``torch.sigmoid`` and threshold at 0.5 to get a mask.

    Args:
        input_channels: channels of the incoming map (1 for the Stage 1 output).
        output_channels: number of logit channels produced.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_2_MODEL_PATH.
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, depth=3):
        super().__init__()
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler overshoots the deepest skip's
        # size; forward() interpolates it back.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # 1x1 conv to logits (no activation; sigmoid is applied by the caller).
        self.final_conv = nn.Conv2d(self.channels[0], output_channels, 1)
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 3. HELPER FUNCTIONS ---
def find_best_test_cases(dataset, threshold_k, num_cases):
    """Return the indices of the ``num_cases`` samples with the most cold pixels.

    A pixel counts when its ground-truth value is ``<= threshold_k``. Samples with
    equal counts keep dataset order (lower index first).
    """
    cold_pixel_counts = []
    for sample_idx in tqdm(range(len(dataset)), desc="Finding best cases"):
        _, target, _ = dataset[sample_idx]
        cold_pixel_counts.append((target.numpy() <= threshold_k).sum())
    # Stable descending sort on the counts; ties stay in ascending index order.
    ranked = sorted(range(len(cold_pixel_counts)), key=cold_pixel_counts.__getitem__, reverse=True)
    return ranked[:num_cases]

def calculate_csi_from_mask(pred_mask, true_mask):
    """Critical Success Index of two boolean NumPy masks.

    CSI = hits / (hits + misses + false_alarms). Returns 0.0 when both masks are
    empty (no event predicted or observed).
    """
    hits, misses, false_alarms = (
        (pred_mask & true_mask).sum(),
        (~pred_mask & true_mask).sum(),
        (pred_mask & ~true_mask).sum(),
    )
    denominator = hits + misses + false_alarms
    if denominator > 0:
        return hits / denominator
    return 0.0

# --- 4. VISUALIZATION FUNCTION ---
def create_diagnostic_figure(case_data, stage1_model, stage2_model, stats, device):
    """Render and save a three-panel evaluation figure for one test case.

    Panels:
        A) ground-truth TBB;
        B) de-normalized Stage 1 forecast, titled with its CSI at DIAGNOSTIC_THRESHOLD_K;
        C) Stage 2 probability overlaid on the ground truth, titled with the CSI of
           its 0.5-thresholded mask.

    Args:
        case_data: ``(predictor_norm, ground_truth_k_tensor, case_name)`` as returned
            by ``MultiVariableARDataset.__getitem__``.
        stage1_model: Stage 1 network (expected to be in eval mode).
        stage2_model: Stage 2 network (expected to be in eval mode).
        stats: normalization stats; ``target_mean``/``target_std`` convert the
            Stage 1 output back to Kelvin.
        device: torch device the models live on.

    Side effects: writes ``final_pipeline_figure_<case_name>.png`` into OUTPUT_DIR.
    The figure is closed even if drawing or saving fails, so repeated calls in a
    loop do not accumulate open figures.
    """
    predictor_norm, ground_truth_k_tensor, case_name = case_data
    ground_truth_k = ground_truth_k_tensor.numpy()  # NumPy for plotting/masking

    with torch.no_grad():
        predictor_subset = predictor_norm[VARIABLE_INDICES, :, :].unsqueeze(0).to(device)
        stage1_pred_norm = stage1_model(predictor_subset)
        stage2_pred_prob = torch.sigmoid(stage2_model(stage1_pred_norm))

    stage1_pred_k = stage1_pred_norm.cpu().numpy().squeeze() * (stats['target_std'] + 1e-8) + stats['target_mean']
    threat_mask = stage2_pred_prob.cpu().numpy().squeeze()

    # The CSI helper uses bitwise ops, so every mask here is a NumPy boolean array.
    true_mask = ground_truth_k <= DIAGNOSTIC_THRESHOLD_K
    stage1_csi = calculate_csi_from_mask(stage1_pred_k <= DIAGNOSTIC_THRESHOLD_K, true_mask)
    final_csi = calculate_csi_from_mask(threat_mask > 0.5, true_mask)

    save_path = OUTPUT_DIR / f"final_pipeline_figure_{case_name}.png"
    fig, axes = plt.subplots(1, 3, figsize=(21, 7), facecolor='white')
    try:
        fig.suptitle(f'Final Pipeline Evaluation for Case: {case_name}', fontsize=20, fontweight='bold')
        norm = colors.PowerNorm(gamma=0.4, vmin=200, vmax=275)

        # Panel A: Ground Truth
        axes[0].imshow(ground_truth_k, cmap=COLORMAP, norm=norm)
        axes[0].set_title('A) Ground Truth TBB', fontsize=16)

        # Panel B: Stage 1 Forecast
        axes[1].imshow(stage1_pred_k, cmap=COLORMAP, norm=norm)
        axes[1].set_title(f'B) Stage 1 Forecast (CSI: {stage1_csi:.3f})', fontsize=16)

        # Panel C: Final "Threat Highlight" product; overlay opacity scales with probability.
        axes[2].imshow(ground_truth_k, cmap='gray', norm=norm)
        axes[2].imshow(threat_mask, cmap='Reds', alpha=(threat_mask * 0.7), vmin=0.2, vmax=1)
        axes[2].set_title(f'C) Final Product: Threat Highlight (CSI: {final_csi:.3f})', fontsize=16)

        for ax in axes:
            ax.set_xticks([]); ax.set_yticks([])

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"✅ Saved diagnostic figure to {save_path}")

# --- 5. MAIN SCRIPT ---
def main():
    """Load both pipeline stages, pick the most intense test cases, and plot them.

    Cases are ranked by how many ground-truth pixels are at or below
    DIAGNOSTIC_THRESHOLD_K; the top NUM_CASES_TO_PLOT each get a diagnostic figure.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _restore(model, checkpoint_path):
        # Load weights onto `device`, then switch to eval mode (BatchNorm uses running stats).
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()
        return model

    # --- Load Models ---
    stage1_model = _restore(FlexibleAblationModel(input_channels=INPUT_CHANNELS).to(device), STAGE_1_MODEL_PATH)
    print(f"✅ Loaded Stage 1 Model: {STAGE_1_MODEL_PATH.name}")
    stage2_model = _restore(DiagnosticUNet().to(device), STAGE_2_MODEL_PATH)
    print(f"✅ Loaded Stage 2 Model: {STAGE_2_MODEL_PATH.name}")

    # --- Find and Visualize Best Cases ---
    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    selected = find_best_test_cases(test_dataset, DIAGNOSTIC_THRESHOLD_K, NUM_CASES_TO_PLOT)

    print(f"\nGenerating {len(selected)} visualizations for the most intense cases...")
    for case_idx in selected:
        create_diagnostic_figure(test_dataset[case_idx], stage1_model, stage2_model, test_dataset.stats, device)

    print("\n🎉 Final visualization complete!")

# Script entry point. In a notebook cell __name__ is "__main__", so main() runs when the cell executes.
if __name__ == "__main__":
    main()


✅ Loaded Stage 1 Model: ablation_remove_IVT.pt
✅ Loaded Stage 2 Model: best_diagnostic_model.pt
Loaded 'test' dataset with 150 samples.


Finding best cases: 100%|██████████| 150/150 [00:01<00:00, 105.90it/s]



Generating 3 visualizations for the most intense cases...
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/final_pipeline_figure_20230807_1200.png
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/final_pipeline_figure_20230615_1800.png
✅ Saved diagnostic figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/final_pipeline_figure_20230807_1800.png

🎉 Final visualization complete!


In [None]:
#
# This script generates the final, publication-quality figures that demonstrate
# the operational output of the "Forecast and Diagnose" pipeline.
# VERSION 3: Simplified to a clean, two-panel operational figure.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# All inputs and outputs hang off one project root on the mounted Google Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
ABLATION_MODEL_DIR = PROJECT_PATH.joinpath('ablation_study_models')
DIAGNOSTIC_MODEL_DIR = PROJECT_PATH.joinpath('final_diagnostic_model')
OUTPUT_DIR = PROJECT_PATH.joinpath('final_publication_figures')
# Without parents=True this raises FileNotFoundError if PROJECT_PATH is missing
# (e.g. Drive not mounted) rather than creating a stray local directory tree.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Models to Visualize ---
STAGE_1_MODEL_PATH = ABLATION_MODEL_DIR.joinpath('ablation_remove_IVT.pt')
STAGE_2_MODEL_PATH = DIAGNOSTIC_MODEL_DIR.joinpath('best_diagnostic_model.pt')

# --- Visualization Configuration ---
NUM_CASES_TO_PLOT = 3
DIAGNOSTIC_THRESHOLD_K = 220.0  # used to rank cases by number of pixels at or below this TBB (K)
COLORMAP = 'CMRmap_r'

# --- Model & Data Configuration ---
# Positions in ALL_VARIABLES are used as channel indices into the predictor arrays.
# IVT is excluded, matching the 'ablation_remove_IVT' Stage 1 checkpoint.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLES_TO_USE = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = list(map(ALL_VARIABLES.index, VARIABLES_TO_USE))
INPUT_CHANNELS = len(VARIABLE_INDICES)

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """Test-time dataset yielding (normalized predictors, raw target, case name).

    Samples come from ``<case>_predictor.npy`` / ``<case>_target.npy`` pairs in
    ``data_dir / split``. Predictors are standardized per channel (axis 0) with the
    ``predictor_mean`` / ``predictor_std`` entries of the stats file. The target is
    returned unnormalized.
    """
    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(list(self.split_dir.glob('*_predictor.npy')))
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if not self.predictor_files: raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self): return len(self.predictor_files)
    def __getitem__(self, idx):
        pred_path = self.predictor_files[idx]
        # Swap the suffix in the file name only; a '_predictor.npy' substring in a
        # parent directory must not leak into the target path.
        targ_path = pred_path.with_name(pred_path.name.replace('_predictor.npy', '_target.npy'))
        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)
        mean = self.stats['predictor_mean'][:, None, None]
        std = self.stats['predictor_std'][:, None, None]
        # Cast back to float32: float64 stats would otherwise promote the result to
        # float64, which does not match the (default float32) model weights.
        predictor_norm = ((predictor_data - mean) / (std + 1e-8)).astype(np.float32, copy=False)
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_data), pred_path.stem.replace('_predictor','')

class SimplifiedAttention(nn.Module):
    """Channel attention followed by spatial attention on a 4-D feature map.

    The channel gate is global average pooling, then a 1x1-conv bottleneck MLP, then
    a sigmoid. The spatial gate is a 7x7 conv to one channel, then a sigmoid. Output
    shape equals input shape.

    Args:
        channels: number of input/output channels.
        reduction: divisor for the channel-MLP hidden width. The width is clamped to
            at least 1 so that ``channels < reduction`` does not produce a
            zero-channel layer.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        # Submodule names and Sequential positions are unchanged so saved checkpoints still load.
        self.channel_att = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, hidden, 1), nn.ReLU(inplace=True), nn.Conv2d(hidden, channels, 1), nn.Sigmoid())
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())
    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

class FlexibleAblationModel(nn.Module): # Stage 1 Model
    """U-Net-style encoder/decoder used as the Stage 1 forecast network.

    Input is a 4-D ``(batch, input_channels, H, W)`` tensor. Output is a
    ``(batch, 1, H, W)`` map at the input resolution. This script treats the output
    as a normalized TBB field and de-normalizes it with ``stats['target_std']`` /
    ``stats['target_mean']``.

    Args:
        input_channels: number of predictor channels fed to the first encoder.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).
        use_attention: if True, apply ``SimplifiedAttention`` after the bottleneck.

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_1_MODEL_PATH.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        # Channel width per encoder stage, doubling per stage and capped at 8x base_channels.
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler doubles H/W beyond the deepest
        # skip; forward() interpolates it back to the skip's size.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Decoder input is the upsampled tensor concatenated with the skip (2x channels).
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # Output head: 3x3 conv to half width, BatchNorm + ReLU, then 1x1 conv to a single channel.
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

class DiagnosticUNet(nn.Module): # Stage 2 Model
    """Small U-Net used as the Stage 2 diagnostic network.

    Takes a 4-D ``(batch, input_channels, H, W)`` tensor (here: the Stage 1 output).
    Returns ``(batch, output_channels, H, W)`` raw logits at the input resolution.
    Callers in this script apply ``torch.sigmoid`` and threshold at 0.5 for the overlay.

    Args:
        input_channels: channels of the incoming map (1 for the Stage 1 output).
        output_channels: number of logit channels produced.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_2_MODEL_PATH.
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, depth=3):
        super().__init__()
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler overshoots the deepest skip's
        # size; forward() interpolates it back.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # 1x1 conv to logits (no activation; sigmoid is applied by the caller).
        self.final_conv = nn.Conv2d(self.channels[0], output_channels, 1)
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 3. HELPER FUNCTIONS ---
def find_best_test_cases(dataset, threshold_k, num_cases):
    """Pick the ``num_cases`` dataset indices with the largest cold-pixel area.

    Each sample is scored by the number of ground-truth pixels ``<= threshold_k``.
    Equal scores are resolved in favour of the earlier index.
    """
    tallies = []
    for position in tqdm(range(len(dataset)), desc="Finding best cases"):
        _, truth, _ = dataset[position]
        tallies.append((truth.numpy() <= threshold_k).sum())
    # sorted() is stable even with reverse=True, so ties keep ascending index order.
    order = sorted(range(len(tallies)), key=tallies.__getitem__, reverse=True)
    return order[:num_cases]

# --- 4. VISUALIZATION FUNCTION ---
def create_operational_figure(case_data, stage1_model, stage2_model, stats, device):
    """Render and save the two-panel operational product for one test case.

    Panel A shows the de-normalized Stage 1 TBB forecast. Panel B shows the same
    forecast with the Stage 2 threat probability overlaid in red wherever it
    exceeds 0.5. The case's ground truth is not used.

    Args:
        case_data: ``(predictor_norm, target, case_name)`` as returned by
            ``MultiVariableARDataset.__getitem__``; the target is ignored.
        stage1_model: Stage 1 network (expected to be in eval mode).
        stage2_model: Stage 2 network (expected to be in eval mode).
        stats: normalization stats; ``target_mean``/``target_std`` de-normalize
            the Stage 1 output.
        device: torch device the models live on.

    Side effects: writes ``operational_figure_<case_name>.png`` into OUTPUT_DIR.
    The figure is closed even if drawing or saving fails, so repeated calls in a
    loop do not accumulate open figures.
    """
    predictor_norm, _, case_name = case_data

    with torch.no_grad():
        predictor_subset = predictor_norm[VARIABLE_INDICES, :, :].unsqueeze(0).to(device)
        stage1_pred_norm = stage1_model(predictor_subset)
        stage2_pred_prob = torch.sigmoid(stage2_model(stage1_pred_norm))

    stage1_pred_k = stage1_pred_norm.cpu().numpy().squeeze() * (stats['target_std'] + 1e-8) + stats['target_mean']
    threat_mask = stage2_pred_prob.cpu().numpy().squeeze()
    # Binary overlay opacity: 0.7 where probability > 0.5, fully transparent elsewhere.
    overlay_alpha = (threat_mask > 0.5) * 0.7

    save_path = OUTPUT_DIR / f"operational_figure_{case_name}.png"
    fig, axes = plt.subplots(1, 2, figsize=(14, 7), facecolor='white')
    try:
        fig.suptitle(f'Operational Forecast for Case: {case_name}', fontsize=20, fontweight='bold')
        norm = colors.PowerNorm(gamma=0.4, vmin=200, vmax=275)

        # Panel A: Stage 1 Forecast
        axes[0].imshow(stage1_pred_k, cmap=COLORMAP, norm=norm)
        axes[0].set_title('A) Stage 1: Quantitative TBB Forecast', fontsize=16)

        # Panel B: Final Operational Product
        axes[1].imshow(stage1_pred_k, cmap=COLORMAP, norm=norm)
        axes[1].imshow(threat_mask, cmap='Reds', alpha=overlay_alpha, vmin=0, vmax=1)
        axes[1].set_title('B) Final Product: Threat Highlight', fontsize=16)

        for ax in axes:
            ax.set_xticks([]); ax.set_yticks([])

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"✅ Saved operational figure to {save_path}")

# --- 5. MAIN SCRIPT ---
def main():
    """Load both pipeline stages and write operational figures for the top test cases.

    Cases are ranked by how many ground-truth pixels are at or below
    DIAGNOSTIC_THRESHOLD_K; the top NUM_CASES_TO_PLOT are rendered.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Load Models ---
    # Each entry: (log label, model factory, checkpoint path). Factories keep model
    # construction lazy so each model is built right before its weights are loaded.
    model_specs = (
        ("Stage 1 Model", lambda: FlexibleAblationModel(input_channels=INPUT_CHANNELS), STAGE_1_MODEL_PATH),
        ("Stage 2 Model", lambda: DiagnosticUNet(), STAGE_2_MODEL_PATH),
    )
    restored = []
    for label, build, weights_path in model_specs:
        net = build().to(device)
        net.load_state_dict(torch.load(weights_path, map_location=device))
        net.eval()
        print(f"✅ Loaded {label}: {weights_path.name}")
        restored.append(net)
    stage1_model, stage2_model = restored

    # --- Find and Visualize Best Cases ---
    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    chosen = find_best_test_cases(test_dataset, DIAGNOSTIC_THRESHOLD_K, NUM_CASES_TO_PLOT)

    print(f"\nGenerating {len(chosen)} operational visualizations...")
    for case_idx in chosen:
        create_operational_figure(test_dataset[case_idx], stage1_model, stage2_model, test_dataset.stats, device)

    print("\n🎉 Final visualization complete!")

# Script entry point. In a notebook cell __name__ is "__main__", so main() runs when the cell executes.
if __name__ == "__main__":
    main()


✅ Loaded Stage 1 Model: ablation_remove_IVT.pt
✅ Loaded Stage 2 Model: best_diagnostic_model.pt
Loaded 'test' dataset with 150 samples.


Finding best cases: 100%|██████████| 150/150 [00:01<00:00, 101.19it/s]



Generating 3 operational visualizations...
✅ Saved operational figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/operational_figure_20230807_1200.png
✅ Saved operational figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/operational_figure_20230615_1800.png
✅ Saved operational figure to /content/drive/My Drive/AR_Downscaling/final_publication_figures/operational_figure_20230807_1800.png

🎉 Final visualization complete!


In [None]:
#
# This script performs the final, comprehensive evaluation of the complete
# two-stage "Forecast and Diagnose" pipeline against the Stage 1 baseline.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import pandas as pd

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# All inputs and outputs hang off one project root on the mounted Google Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
ABLATION_MODEL_DIR = PROJECT_PATH.joinpath('ablation_study_models')
DIAGNOSTIC_MODEL_DIR = PROJECT_PATH.joinpath('final_diagnostic_model')
OUTPUT_DIR = PROJECT_PATH.joinpath('final_evaluation_results')
# Without parents=True this raises FileNotFoundError if PROJECT_PATH is missing
# (e.g. Drive not mounted) rather than creating a stray local directory tree.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Models to Evaluate ---
STAGE_1_MODEL_PATH = ABLATION_MODEL_DIR.joinpath('ablation_remove_IVT.pt')
STAGE_2_MODEL_PATH = DIAGNOSTIC_MODEL_DIR.joinpath('best_diagnostic_model.pt')

# --- Evaluation Configuration ---
# CSI is reported at each of these TBB thresholds (K); lower means colder/more severe.
CSI_THRESHOLDS_K = [230.0, 220.0, 210.0]

# --- Model & Data Configuration ---
# Positions in ALL_VARIABLES are used as channel indices into the predictor arrays.
# IVT is excluded, matching the 'ablation_remove_IVT' Stage 1 checkpoint.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLES_TO_USE = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = list(map(ALL_VARIABLES.index, VARIABLES_TO_USE))
INPUT_CHANNELS = len(VARIABLE_INDICES)

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """Evaluation dataset yielding (normalized predictors, raw target) pairs.

    Samples come from ``<case>_predictor.npy`` / ``<case>_target.npy`` pairs in
    ``data_dir / split``. Predictors are standardized per channel (axis 0) with the
    ``predictor_mean`` / ``predictor_std`` entries of the stats file. The target is
    returned unnormalized so ground-truth masks can be built against Kelvin thresholds.
    """
    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(list(self.split_dir.glob('*_predictor.npy')))
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        if not self.predictor_files: raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self): return len(self.predictor_files)
    def __getitem__(self, idx):
        pred_path = self.predictor_files[idx]
        # Swap the suffix in the file name only; a '_predictor.npy' substring in a
        # parent directory must not leak into the target path.
        targ_path = pred_path.with_name(pred_path.name.replace('_predictor.npy', '_target.npy'))
        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)
        mean = self.stats['predictor_mean'][:, None, None]
        std = self.stats['predictor_std'][:, None, None]
        # Cast back to float32: float64 stats would otherwise promote the result to
        # float64, which does not match the (default float32) model weights.
        predictor_norm = ((predictor_data - mean) / (std + 1e-8)).astype(np.float32, copy=False)
        # Return the unnormalized target for creating the ground truth mask
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_data)

class SimplifiedAttention(nn.Module):
    """Channel attention followed by spatial attention on a 4-D feature map.

    The channel gate is global average pooling, then a 1x1-conv bottleneck MLP, then
    a sigmoid. The spatial gate is a 7x7 conv to one channel, then a sigmoid. Output
    shape equals input shape.

    Args:
        channels: number of input/output channels.
        reduction: divisor for the channel-MLP hidden width. The width is clamped to
            at least 1 so that ``channels < reduction`` does not produce a
            zero-channel layer.
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(1, channels // reduction)
        # Submodule names and Sequential positions are unchanged so saved checkpoints still load.
        self.channel_att = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, hidden, 1), nn.ReLU(inplace=True), nn.Conv2d(hidden, channels, 1), nn.Sigmoid())
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())
    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

class FlexibleAblationModel(nn.Module): # Stage 1 Model
    """U-Net-style encoder/decoder used as the Stage 1 forecast network.

    Input is a 4-D ``(batch, input_channels, H, W)`` tensor. Output is a
    ``(batch, 1, H, W)`` map at the input resolution. This script treats the output
    as a normalized TBB field and de-normalizes it with ``stats['target_std']`` /
    ``stats['target_mean']``.

    Args:
        input_channels: number of predictor channels fed to the first encoder.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).
        use_attention: if True, apply ``SimplifiedAttention`` after the bottleneck.

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_1_MODEL_PATH.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        # Channel width per encoder stage, doubling per stage and capped at 8x base_channels.
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler doubles H/W beyond the deepest
        # skip; forward() interpolates it back to the skip's size.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Decoder input is the upsampled tensor concatenated with the skip (2x channels).
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # Output head: 3x3 conv to half width, BatchNorm + ReLU, then 1x1 conv to a single channel.
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

class DiagnosticUNet(nn.Module): # Stage 2 Model
    """Small U-Net used as the Stage 2 diagnostic network.

    Takes a 4-D ``(batch, input_channels, H, W)`` tensor (here: the Stage 1 output).
    Returns ``(batch, output_channels, H, W)`` raw logits at the input resolution.
    The evaluation loop applies ``torch.sigmoid`` and thresholds at 0.5.

    Args:
        input_channels: channels of the incoming map (1 for the Stage 1 output).
        output_channels: number of logit channels produced.
        base_channels: width of encoder stage 0; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and of decoder stages).

    NOTE: attribute names and ModuleList/Sequential positions determine the
    ``state_dict`` keys and must match the checkpoint at STAGE_2_MODEL_PATH.
    """
    def __init__(self, input_channels=1, output_channels=1, base_channels=32, depth=3):
        super().__init__()
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # Stride-2 conv halves H/W (rounding up) between stages; none after the last stage.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # One upsampler/decoder pair per encoder stage, deepest first. Only depth-1
        # downsamplers exist, so the first upsampler overshoots the deepest skip's
        # size; forward() interpolates it back.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        # 1x1 conv to logits (no activation; sigmoid is applied by the caller).
        self.final_conv = nn.Conv2d(self.channels[0], output_channels, 1)
    # Two (3x3 conv -> BatchNorm -> ReLU) layers; padding=1 keeps H/W unchanged.
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
    def forward(self, x):
        skips = []
        # Encoder: store each stage's output (before downsampling) as a skip connection.
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        # Decoder: upsampler i is paired with skips[-1 - i] (deepest skip first).
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            x = up(x); skip = skips[len(skips) - 1 - i]
            # Resize when the upsampled size differs from the skip (first stage, or odd input sizes).
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 3. METRIC FUNCTIONS ---
def calculate_csi(pred_mask, true_mask):
    """Return the Critical Success Index for two boolean NumPy masks.

    CSI = hits / (hits + misses + false_alarms); 0.0 when neither mask has any
    positive pixel.
    """
    observed_only = ~pred_mask & true_mask
    predicted_only = pred_mask & ~true_mask
    both = pred_mask & true_mask
    hits = both.sum()
    total = hits + observed_only.sum() + predicted_only.sum()
    return 0.0 if not total > 0 else hits / total

# --- 4. MAIN EVALUATION SCRIPT ---
def main():
    """Evaluate the two-stage "Forecast and Diagnose" pipeline on the test split.

    For every test sample this:
      1. runs the Stage 1 forecast model on the predictor channels selected by
         VARIABLE_INDICES, de-normalises its output to Kelvin with the
         dataset's target statistics, and scores ``pred <= thr`` against
         ``target <= thr`` for each threshold in CSI_THRESHOLDS_K;
      2. feeds the (still normalised) Stage 1 output to the Stage 2 diagnostic
         network, thresholds its sigmoid at 0.5, and scores that single mask
         against ``target <= thr`` for the same thresholds.

    Per-sample CSIs are averaged and reported as a table with one row per
    threshold (``csi_<thr>K``) and one column per stage (``Stage1``,
    ``Final_Pipeline``). The relative change from Stage 1 to the full pipeline
    is printed, and the table is written to
    ``OUTPUT_DIR / 'final_pipeline_evaluation.csv'``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔧 Starting final evaluation on device: {device}")

    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    stats = test_dataset.stats
    # Loop-invariant de-normalisation constants for the Stage 1 output.
    target_scale = stats['target_std'] + 1e-8
    target_offset = stats['target_mean']

    stage1_model = FlexibleAblationModel(input_channels=INPUT_CHANNELS).to(device)
    stage1_model.load_state_dict(torch.load(STAGE_1_MODEL_PATH, map_location=device))
    stage1_model.eval()

    stage2_model = DiagnosticUNet().to(device)
    stage2_model.load_state_dict(torch.load(STAGE_2_MODEL_PATH, map_location=device))
    stage2_model.eval()

    stage1_metrics_list = []
    final_pipeline_metrics_list = []

    # NOTE(review): calculate_csi returns 0.0 for a sample with neither a
    # predicted nor an observed event at a threshold, so such samples pull the
    # per-sample mean down. Pooling hits/misses/false alarms over the whole
    # test set would avoid that bias.
    with torch.no_grad():
        for predictor_norm, target_k_tensor in tqdm(test_loader, desc="Evaluating Test Set"):
            predictor_subset = predictor_norm[:, VARIABLE_INDICES, :, :].to(device)
            target_k = target_k_tensor.numpy().squeeze()

            # --- Stage 1 Evaluation ---
            stage1_pred_norm = stage1_model(predictor_subset)
            stage1_pred_k = stage1_pred_norm.cpu().numpy().squeeze() * target_scale + target_offset

            s1_metrics = {}
            for thr in CSI_THRESHOLDS_K:
                s1_metrics[f'csi_{int(thr)}K'] = calculate_csi(stage1_pred_k <= thr, target_k <= thr)
            stage1_metrics_list.append(s1_metrics)

            # --- Stage 2 / Final Pipeline Evaluation ---
            stage2_pred_logits = stage2_model(stage1_pred_norm)
            stage2_pred_mask = (torch.sigmoid(stage2_pred_logits) > 0.5).cpu().numpy().squeeze()

            # The diagnostic model is assumed to have been trained for a single
            # threshold (220K), so the same mask is scored at every threshold.
            final_metrics = {}
            for thr in CSI_THRESHOLDS_K:
                final_metrics[f'csi_{int(thr)}K'] = calculate_csi(stage2_pred_mask, target_k <= thr)
            final_pipeline_metrics_list.append(final_metrics)

    if not stage1_metrics_list:
        print("⚠️ Test set is empty; nothing to report.")
        return

    # --- Reporting ---
    # One row per threshold, one column per stage, so both scores for a
    # threshold sit side by side and can be looked up by label.
    report_df = pd.DataFrame({
        'Stage1': pd.DataFrame(stage1_metrics_list).mean(),
        'Final_Pipeline': pd.DataFrame(final_pipeline_metrics_list).mean(),
    })

    print("\n\n" + "="*80)
    print("🏆 FINAL PIPELINE EVALUATION REPORT 🏆")
    print("="*80)
    print(report_df.round(4))
    print("="*80)

    # Calculate and print the percentage improvement
    print("\n--- CSI Improvement from Stage 2 Diagnosis ---")
    for thr in CSI_THRESHOLDS_K:
        key = f'csi_{int(thr)}K'
        s1_score = report_df.at[key, 'Stage1']
        final_score = report_df.at[key, 'Final_Pipeline']
        if s1_score > 0:
            improvement = ((final_score - s1_score) / s1_score) * 100
            print(f"  - CSI @ {int(thr)}K: {s1_score:.4f} -> {final_score:.4f} ({improvement:+.2f}%)")
        else:
            print(f"  - CSI @ {int(thr)}K: {s1_score:.4f} -> {final_score:.4f} (Improvement cannot be calculated)")
    print("  Note: the Stage 2 mask does not depend on the threshold; only the threshold "
          "it was trained for is a like-for-like comparison.")

    print("\n" + "="*80)

    csv_path = OUTPUT_DIR / 'final_pipeline_evaluation.csv'
    report_df.to_csv(csv_path, index_label='metric')
    print(f"\n💾 Detailed results saved to: {csv_path}")

if __name__ == "__main__":
    main()


🔧 Starting final evaluation on device: cuda
Loaded 'test' dataset with 150 samples.


Evaluating Test Set: 100%|██████████| 150/150 [00:03<00:00, 39.95it/s]



🏆 FINAL PIPELINE EVALUATION REPORT 🏆
                              0       1
Stage1_csi_230K          0.4293     NaN
Stage1_csi_220K          0.1607     NaN
Stage1_csi_210K          0.0587     NaN
Final_Pipeline_csi_230K     NaN  0.4922
Final_Pipeline_csi_220K     NaN  0.1962
Final_Pipeline_csi_210K     NaN  0.0626

--- CSI Improvement from Stage 2 Diagnosis ---
  - CSI @ 230K: 0.4293 -> 0.4922 (+14.65%)
  - CSI @ 220K: 0.1607 -> 0.1962 (+22.12%)
  - CSI @ 210K: 0.0587 -> 0.0626 (+6.76%)


💾 Detailed results saved to: /content/drive/My Drive/AR_Downscaling/final_evaluation_results/final_pipeline_evaluation.csv



