In [None]:
#
# This script evaluates every trained ablation model on the test split, reports
# per-model metrics (RMSE, MAE, sharpness, Wasserstein distance, CSI), and ranks
# input-variable importance from the "leave-one-out" models.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import pandas as pd
from scipy.stats import wasserstein_distance
from scipy.ndimage import sobel
import json
import math

# Suppress every warning category for the rest of the process.
warnings.filterwarnings('ignore')

# --- 1. CONFIGURATION ---
# Root folder of the project.
# NOTE(review): the path looks like a Google Drive mount under /content/drive — confirm it is mounted before running.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
# Holds one sub-folder per split plus 'normalization_stats_multi_variable.joblib'
# (both are read by MultiVariableARDataset).
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'
# main() scans this folder for 'ablation_*.json' metadata and loads the '.pt'
# file that shares each metadata file's stem.
MODEL_DIR = PROJECT_PATH / 'ablation_study_models' # Directory where ablation models were saved
# Destination of 'ablation_study_evaluation.csv'.
OUTPUT_DIR = PROJECT_PATH / 'final_evaluation_results'
# Only the leaf folder is created (no parents=True), so this raises at import
# time if PROJECT_PATH itself does not exist.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Evaluation Configuration ---
# Thresholds for the CSI metric; a pixel counts as an "event" when its value is
# <= the threshold (see calculate_detailed_csi). The `_K` suffix suggests Kelvin.
CSI_THRESHOLDS_K = [230.0, 220.0, 210.0]
# Position in this list is used as the channel index when selecting a model's
# input subset from the predictor tensor.
# NOTE(review): presumably matches the channel order of the saved predictor arrays — confirm.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']

# --- 2. DATASET & MODEL ARCHITECTURES ---
class MultiVariableARDataset(Dataset):
    """Serve normalized (predictor, target) tensor pairs stored as ``.npy`` files.

    Every ``<name>_predictor.npy`` file found in ``data_dir / split`` is paired
    with the ``<name>_target.npy`` file next to it. Both arrays are standardized
    with the statistics stored in
    ``data_dir / 'normalization_stats_multi_variable.joblib'`` using the keys
    ``predictor_mean``, ``predictor_std``, ``target_mean`` and ``target_std``.

    The per-channel predictor statistics are broadcast over the two trailing
    axes, so the predictor array is expected to have its channel axis first.

    Args:
        data_dir: Folder containing the split sub-folders and the stats file.
        split: Name of the sub-folder to read (default ``'test'``).

    Raises:
        FileNotFoundError: If the split folder contains no predictor files.
    """

    _PREDICTOR_SUFFIX = '_predictor.npy'
    _TARGET_SUFFIX = '_target.npy'
    _EPS = 1e-8  # Guards against division by a zero standard deviation.

    def __init__(self, data_dir: Path, split: str = 'test'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob(f'*{self._PREDICTOR_SUFFIX}'))
        # Fail fast on an empty split before touching the stats file.
        if not self.predictor_files:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor, target)`` as float32 tensors for sample ``idx``.

        The target gains a leading singleton axis via ``unsqueeze(0)``.
        """
        pred_path = Path(self.predictor_files[idx])
        # Rewrite only the file name so a directory that happens to contain the
        # suffix text is never altered.
        sample_stem = pred_path.name[:-len(self._PREDICTOR_SUFFIX)]
        targ_path = pred_path.with_name(sample_stem + self._TARGET_SUFFIX)

        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)

        pred_mean = np.asarray(self.stats['predictor_mean'])[:, None, None]
        pred_std = np.asarray(self.stats['predictor_std'])[:, None, None]
        predictor_norm = (predictor_data - pred_mean) / (pred_std + self._EPS)
        target_norm = (target_data - self.stats['target_mean']) / (self.stats['target_std'] + self._EPS)

        # float64 statistics would otherwise promote the result to float64 and
        # yield double tensors that a float32 model cannot consume.
        predictor_norm = np.asarray(predictor_norm, dtype=np.float32)
        target_norm = np.asarray(target_norm, dtype=np.float32)
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_norm).unsqueeze(0)

class SimplifiedAttention(nn.Module):
    """Gate a feature map first per channel, then per spatial location.

    The channel gate pools the map to 1x1, squeezes it to ``channels // 16``
    channels, expands it back, and applies a sigmoid. The spatial gate is a
    single 7x7 convolution to one channel followed by a sigmoid, and it is
    computed from the already channel-gated features.
    """

    def __init__(self, channels):
        super().__init__()
        squeezed = channels // 16
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, 1, 7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by the channel gate and then by the spatial gate."""
        channel_gated = x * self.channel_att(x)
        return channel_gated * self.spatial_att(channel_gated)

class FlexibleAblationModel(nn.Module):
    """Encoder/decoder regression network with a configurable input width.

    Args:
        input_channels: Number of channels in the input tensor.
        base_channels: Width of the first encoder level. Level ``i`` uses
            ``base_channels * min(2**i, 8)`` channels.
        depth: Number of encoder levels (and of decoder levels).
        use_attention: When true, a ``SimplifiedAttention`` module is applied
            to the bottleneck output.

    The forward pass returns a single-channel map with the same spatial size
    as the input.

    NOTE: there is no downsampling after the last encoder, so the bottleneck
    and the deepest skip share one resolution. The first transposed
    convolution therefore always overshoots and its output is resized back to
    the skip's size by the interpolation step in ``forward``.
    """

    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention = use_attention
        self.depth = depth
        self.channels = [base_channels * min(2**level, 8) for level in range(depth)]

        # --- Encoder: one conv block per level, strided conv between levels ---
        self.encoders = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        deepest_level = len(self.channels) - 1
        incoming = input_channels
        for level, width in enumerate(self.channels):
            self.encoders.append(self._conv_block(incoming, width))
            if level < deepest_level:
                self.downsamplers.append(nn.Conv2d(width, width, 3, stride=2, padding=1))
            incoming = width

        # --- Bottleneck (same width as the deepest encoder level) ---
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # --- Decoder: built from the deepest level back up to level 0 ---
        self.upsamplers = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in reversed(range(depth)):
            if level == depth - 1:
                up_in = bottleneck_ch
            else:
                up_in = self.channels[level + 1]
            up_out = self.channels[level]
            self.upsamplers.append(nn.ConvTranspose2d(up_in, up_out, 2, stride=2))
            # The decoder block sees the upsampled map concatenated with the skip.
            self.decoders.append(self._conv_block(up_out * 2, up_out))

        # --- Output head: halve the width, then project to a single channel ---
        head_width = self.channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(self.channels[0], head_width, 3, padding=1),
            nn.BatchNorm2d(head_width),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_width, 1, 1),
        )

    def _conv_block(self, in_ch, out_ch):
        """Build two 3x3 conv + batch-norm + ReLU stages mapping in_ch to out_ch."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode ``x``, pass it through the bottleneck, and decode with skips."""
        skips = []
        num_down = len(self.downsamplers)
        for level, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if level < num_down:
                x = self.downsamplers[level](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Skips are consumed deepest-first, mirroring the decoder build order.
        for upsample, decode, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = decode(torch.cat([x, skip], dim=1))
        return self.final_conv(x)

# --- 3. METRIC FUNCTIONS ---
def calculate_detailed_csi(pred_k, true_k, threshold_k, no_event_value=0.0):
    """Compute the Critical Success Index and the observed event frequency.

    A pixel is an "event" when its value is less than or equal to
    ``threshold_k``. CSI is ``hits / (hits + misses + false_alarms)``.

    Args:
        pred_k: Predicted field (array-like).
        true_k: Reference field (array-like) with the same shape as ``pred_k``.
        threshold_k: Event threshold, in the same units as the fields.
        no_event_value: CSI value returned when neither field contains an
            event, i.e. the ratio is undefined. Defaults to ``0.0``; pass
            ``float('nan')`` to let NaN-aware aggregations skip such samples.

    Returns:
        Tuple ``(csi, event_frequency)`` of Python floats. ``event_frequency``
        is the fraction of reference pixels that are events, or ``0.0`` for an
        empty reference field.
    """
    pred_mask = np.asarray(pred_k) <= threshold_k
    true_mask = np.asarray(true_k) <= threshold_k

    # Plain ints keep the ratio below a Python float in every branch.
    hits = int(np.count_nonzero(pred_mask & true_mask))
    misses = int(np.count_nonzero(~pred_mask & true_mask))
    false_alarms = int(np.count_nonzero(pred_mask & ~true_mask))

    denominator = hits + misses + false_alarms
    csi = hits / denominator if denominator > 0 else float(no_event_value)

    # Empty input would otherwise produce 0/0 (NaN plus a runtime warning).
    if true_mask.size:
        event_frequency = int(np.count_nonzero(true_mask)) / true_mask.size
    else:
        event_frequency = 0.0
    return csi, event_frequency

def calculate_gradient_magnitude(image):
    """Return the mean Sobel gradient magnitude of a 2-D ``image``.

    The Sobel filter is applied along axis 0 and axis 1 separately; the two
    responses are combined as a Euclidean norm per pixel and then averaged.
    """
    response_axis0 = sobel(image, axis=0)
    response_axis1 = sobel(image, axis=1)
    magnitude = np.sqrt(response_axis0**2 + response_axis1**2)
    return magnitude.mean()

def calculate_all_metrics(pred_norm, true_norm, stats):
    """Compute every evaluation metric for one prediction/reference pair.

    Both inputs are first mapped back from normalized space using
    ``stats['target_std']`` (plus the same ``1e-8`` epsilon used when
    normalizing) and ``stats['target_mean']``.

    Returns:
        Dict with ``rmse``, ``mae``, ``sharpness`` (mean Sobel gradient
        magnitude of the prediction), ``distribution_dist`` (Wasserstein
        distance between the flattened fields), and one ``csi_<T>K`` /
        ``freq_<T>K`` pair per entry of ``CSI_THRESHOLDS_K``.
    """
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    pred_k = pred_norm * scale + offset
    true_k = true_norm * scale + offset

    metrics = {}
    metrics['rmse'] = np.sqrt(np.mean((pred_k - true_k)**2))
    metrics['mae'] = np.mean(np.abs(pred_k - true_k))
    metrics['sharpness'] = calculate_gradient_magnitude(pred_k)
    metrics['distribution_dist'] = wasserstein_distance(pred_k.flatten(), true_k.flatten())

    for threshold in CSI_THRESHOLDS_K:
        label = f'{int(threshold)}K'
        csi_value, event_freq = calculate_detailed_csi(pred_k, true_k, threshold)
        metrics[f'csi_{label}'] = csi_value
        metrics[f'freq_{label}'] = event_freq
    return metrics

# --- 4. MAIN EVALUATION SCRIPT ---
def _evaluate_model(model, test_loader, variable_indices, stats, device, name):
    """Run ``model`` over ``test_loader`` and return one metrics dict per batch."""
    per_sample_metrics = []
    with torch.no_grad():
        for predictor, target in tqdm(test_loader, desc=f"Predicting with {name}"):
            # Keep only the channels listed in this model's metadata.
            predictor_subset = predictor[:, variable_indices, :, :].to(device)

            pred_norm = model(predictor_subset).cpu().numpy().squeeze()
            target_norm = target.numpy().squeeze()
            per_sample_metrics.append(calculate_all_metrics(pred_norm, target_norm, stats))
    return per_sample_metrics


def _print_variable_importance(results_df, metric='csi_220K_mean', control_name='all_variables'):
    """Print the relative drop in ``metric`` for each 'remove_<var>' model.

    Prints an explanatory message and returns without raising when the control
    model is missing or its score cannot be used as a denominator.
    """
    if control_name not in results_df.index or metric not in results_df.columns:
        print(f"⚠️ Control model '{control_name}' or metric '{metric}' is unavailable. "
              "Skipping variable importance analysis.")
        return

    control_score = results_df.loc[control_name, metric]
    if not np.isfinite(control_score) or control_score == 0:
        print(f"⚠️ Control model score is {control_score}; relative drops are undefined. "
              "Skipping variable importance analysis.")
        return

    importance = {}
    for var in ALL_VARIABLES:
        model_name = f'remove_{var}'
        if model_name in results_df.index:
            model_score = results_df.loc[model_name, metric]
            importance[var] = (control_score - model_score) / control_score * 100

    print("Performance Drop (%) when variable is REMOVED (Higher = More Important):")
    for var, drop in sorted(importance.items(), key=lambda item: item[1], reverse=True):
        print(f"  - {var}: {drop:.2f}%")


def main():
    """Evaluate every ablation model found in ``MODEL_DIR`` on the test split.

    For each ``ablation_*.json`` metadata file, the ``.pt`` weights with the
    same stem are loaded, per-sample metrics are computed, and their mean and
    standard deviation are collected. The summary table is written to
    ``OUTPUT_DIR / 'ablation_study_evaluation.csv'`` before the report is
    printed, so an error in the reporting step cannot discard the results.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🔧 Starting evaluation on device: {device}")

    test_dataset = MultiVariableARDataset(DATA_DIR, 'test')
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    stats = test_dataset.stats

    # --- Find and Load Models Automatically ---
    metadata_files = sorted(MODEL_DIR.glob('ablation_*.json'))
    if not metadata_files:
        print(f"❌ No model metadata files found in {MODEL_DIR}. Cannot evaluate.")
        return

    all_results = []
    for meta_file in metadata_files:
        with open(meta_file, 'r') as f:
            metadata = json.load(f)

        name = metadata['name']
        variables = metadata['variables']

        # One bad metadata file should not abort the whole (slow) evaluation.
        unknown = [v for v in variables if v not in ALL_VARIABLES]
        if unknown:
            print(f"⚠️ Unknown variables {unknown} in metadata for '{name}'. Skipping.")
            continue

        variable_indices = [ALL_VARIABLES.index(v) for v in variables]
        input_channels = len(variable_indices)
        model_path = meta_file.with_suffix('.pt')

        if not model_path.exists():
            print(f"⚠️ Model file not found for '{name}'. Skipping.")
            continue

        print(f"\n--- Evaluating {name} ({input_channels} variables) ---")
        model = FlexibleAblationModel(input_channels=input_channels).to(device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        model_metrics = _evaluate_model(model, test_loader, variable_indices, stats, device, name)

        df = pd.DataFrame(model_metrics)
        summary = {"Model": name, "Input Channels": input_channels, "Variables": ", ".join(variables)}
        for col in df.columns:
            summary[f"{col}_mean"] = df[col].mean()
            summary[f"{col}_std"] = df[col].std()
        all_results.append(summary)

    if not all_results:
        print("❌ No models were evaluated. Exiting.")
        return

    results_df = pd.DataFrame(all_results).set_index("Model")

    # Sort by the most critical CSI score for ranking
    ranking_metric = 'csi_220K_mean'
    if ranking_metric in results_df.columns:
        results_df = results_df.sort_values(by=ranking_metric, ascending=False)

    # Persist first: everything below is presentation only.
    csv_path = OUTPUT_DIR / 'ablation_study_evaluation.csv'
    results_df.to_csv(csv_path)

    # --- Reporting ---
    print("\n\n" + "="*80)
    print("🔬 FINAL ABLATION STUDY EVALUATION REPORT 🔬")
    print("="*80)

    print("\n--- Performance Ranking (Sorted by CSI @ 220K) ---")
    report_cols = ['Input Channels', 'csi_230K_mean', 'csi_220K_mean', 'csi_210K_mean', 'rmse_mean', 'sharpness_mean']
    available_cols = [col for col in report_cols if col in results_df.columns]
    print(results_df[available_cols].round(4))

    print("\n--- Variable Importance Analysis (Based on 'Leave-One-Out' models) ---")
    _print_variable_importance(results_df, metric=ranking_metric, control_name='all_variables')

    print("\n" + "="*80)

    print(f"\n💾 Detailed results saved to: {csv_path}")

if __name__ == "__main__":
    main()


🔧 Starting evaluation on device: cuda
Loaded 'test' dataset with 150 samples.

--- Evaluating all_variables (5 variables) ---


Predicting with all_variables: 100%|██████████| 150/150 [01:57<00:00,  1.28it/s]



--- Evaluating pair_IVT_RH700 (2 variables) ---


Predicting with pair_IVT_RH700: 100%|██████████| 150/150 [00:16<00:00,  9.10it/s]



--- Evaluating pair_IVT_T500 (2 variables) ---


Predicting with pair_IVT_T500: 100%|██████████| 150/150 [00:15<00:00,  9.73it/s]



--- Evaluating pair_IVT_W500 (2 variables) ---


Predicting with pair_IVT_W500: 100%|██████████| 150/150 [00:17<00:00,  8.78it/s]



--- Evaluating pair_RH700_W500 (2 variables) ---


Predicting with pair_RH700_W500: 100%|██████████| 150/150 [00:16<00:00,  8.86it/s]



--- Evaluating pair_T500_T850 (2 variables) ---


Predicting with pair_T500_T850: 100%|██████████| 150/150 [00:16<00:00,  8.96it/s]



--- Evaluating pair_T500_W500 (2 variables) ---


Predicting with pair_T500_W500: 100%|██████████| 150/150 [00:16<00:00,  9.37it/s]



--- Evaluating remove_IVT (4 variables) ---


Predicting with remove_IVT: 100%|██████████| 150/150 [00:16<00:00,  8.91it/s]



--- Evaluating remove_RH700 (4 variables) ---


Predicting with remove_RH700: 100%|██████████| 150/150 [00:15<00:00,  9.93it/s]



--- Evaluating remove_T500 (4 variables) ---


Predicting with remove_T500: 100%|██████████| 150/150 [00:14<00:00, 10.44it/s]



--- Evaluating remove_T850 (4 variables) ---


Predicting with remove_T850: 100%|██████████| 150/150 [00:16<00:00,  9.17it/s]



--- Evaluating remove_W500 (4 variables) ---


Predicting with remove_W500: 100%|██████████| 150/150 [00:16<00:00,  9.32it/s]



--- Evaluating single_IVT (1 variables) ---


Predicting with single_IVT: 100%|██████████| 150/150 [00:15<00:00,  9.75it/s]



--- Evaluating single_RH700 (1 variables) ---


Predicting with single_RH700: 100%|██████████| 150/150 [00:14<00:00, 10.25it/s]



--- Evaluating single_T500 (1 variables) ---


Predicting with single_T500: 100%|██████████| 150/150 [00:16<00:00,  9.26it/s]



--- Evaluating single_T850 (1 variables) ---


Predicting with single_T850: 100%|██████████| 150/150 [00:14<00:00, 10.26it/s]



--- Evaluating single_W500 (1 variables) ---


Predicting with single_W500: 100%|██████████| 150/150 [00:15<00:00,  9.60it/s]



--- Evaluating triplet_IVT_RH700_W500 (3 variables) ---


Predicting with triplet_IVT_RH700_W500: 100%|██████████| 150/150 [00:15<00:00,  9.61it/s]



--- Evaluating triplet_IVT_T500_RH700 (3 variables) ---


Predicting with triplet_IVT_T500_RH700: 100%|██████████| 150/150 [00:15<00:00,  9.76it/s]



--- Evaluating triplet_IVT_T500_W500 (3 variables) ---


Predicting with triplet_IVT_T500_W500: 100%|██████████| 150/150 [00:16<00:00,  8.96it/s]



--- Evaluating triplet_T500_T850_W500 (3 variables) ---


Predicting with triplet_T500_T850_W500: 100%|██████████| 150/150 [00:15<00:00,  9.85it/s]



🔬 FINAL ABLATION STUDY EVALUATION REPORT 🔬

--- Performance Ranking (Sorted by CSI @ 220K) ---
                        Input Channels  csi_230K_mean  csi_220K_mean  \
Model                                                                  
pair_T500_W500                       2         0.4115         0.1756   
single_T500                          1         0.4247         0.1696   
remove_IVT                           4         0.4293         0.1607   
remove_T500                          4         0.4456         0.1511   
triplet_T500_T850_W500               3         0.4189         0.1231   
remove_T850                          4         0.4152         0.1181   
triplet_IVT_T500_W500                3         0.4201         0.0843   
triplet_IVT_RH700_W500               3         0.4429         0.0671   
single_W500                          1         0.4008         0.0578   
pair_IVT_RH700                       2         0.4351         0.0499   
remove_W500                          4 


