In [None]:
#
# This script trains the Regression Baseline model on 21 selected predictor-
# variable combinations for an ablation study: all 5 single variables, 6 of the
# 10 possible pairs, 4 of the 10 possible triplets, all 5 leave-one-out sets,
# and the full 5-variable set.
# It uses the same AdvancedLoss as the main regression model for fair comparison.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import joblib
import warnings
from tqdm import tqdm
import json
import gc

# Silence every library warning so the per-epoch progress output stays readable.
warnings.filterwarnings(action='ignore')

# --- 1. CONFIGURATION ---
# Project root (presumably a mounted Google Drive in Colab); data is read from,
# and trained models are written under, this directory.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
OUTPUT_DIR = PROJECT_PATH.joinpath('ablation_study_models')
# Only the leaf directory is created: PROJECT_PATH itself must already exist,
# otherwise this raises FileNotFoundError.
OUTPUT_DIR.mkdir(exist_ok=True)

# --- Training Hyperparameters (MUST remain constant for all models) ---
EPOCHS = 25                   # maximum number of epochs per run
BATCH_SIZE = 8                # samples per train/val batch
LEARNING_RATE = 1e-4          # AdamW learning rate
WARMUP_EPOCHS = 3             # warmup length, in epochs
EARLY_STOPPING_PATIENCE = 7   # epochs without val improvement before stopping
GRADIENT_CLIP = 1.0           # maximum global gradient norm
# --- 2. ABLATION STUDY CONFIGURATION ---
# Order matters: a variable's position here is its channel index in the
# stacked predictor arrays (looked up via ALL_VARIABLES.index).
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']

# 21 runs in total. The pairs and triplets are a hand-picked subset of the
# 10 possible pairs / 10 possible triplets, not an exhaustive enumeration.
ABLATION_CONFIGS = {
    # --- Single-variable models (5) ---
    'single_IVT': ['IVT'],
    'single_T500': ['T500'],
    'single_T850': ['T850'],
    'single_RH700': ['RH700'],
    'single_W500': ['W500'],
    # --- Two-variable models (6 selected pairs) ---
    'pair_IVT_T500': ['IVT', 'T500'],
    'pair_IVT_RH700': ['IVT', 'RH700'],
    'pair_T500_T850': ['T500', 'T850'],
    'pair_RH700_W500': ['RH700', 'W500'],
    'pair_IVT_W500': ['IVT', 'W500'],
    'pair_T500_W500': ['T500', 'W500'],
    # --- Three-variable models (4 selected triplets) ---
    'triplet_IVT_T500_RH700': ['IVT', 'T500', 'RH700'],
    'triplet_T500_T850_W500': ['T500', 'T850', 'W500'],
    'triplet_IVT_RH700_W500': ['IVT', 'RH700', 'W500'],
    'triplet_IVT_T500_W500': ['IVT', 'T500', 'W500'],
    # --- Four-variable "leave-one-out" models (5) ---
    'remove_IVT': [v for v in ALL_VARIABLES if v != 'IVT'],
    'remove_T500': [v for v in ALL_VARIABLES if v != 'T500'],
    'remove_T850': [v for v in ALL_VARIABLES if v != 'T850'],
    'remove_RH700': [v for v in ALL_VARIABLES if v != 'RH700'],
    'remove_W500': [v for v in ALL_VARIABLES if v != 'W500'],
    # --- Five-variable model (1) ---
    # A copy, so mutating this run's variable list can never alter
    # ALL_VARIABLES, which defines the name -> channel-index mapping.
    'all_variables': list(ALL_VARIABLES),
}


def validate_ablation_configs(configs, known_variables):
    """Fail fast, before any training starts, on malformed ablation configs.

    Args:
        configs: Mapping of run name -> list of variable names.
        known_variables: The variable names that have a predictor channel.

    Raises:
        ValueError: If a config is empty, references a variable not in
            ``known_variables``, or lists the same variable twice.
    """
    known = set(known_variables)
    for name, variables in configs.items():
        if not variables:
            raise ValueError(f"Ablation config '{name}' has no variables.")
        unknown = [v for v in variables if v not in known]
        if unknown:
            raise ValueError(f"Ablation config '{name}' references unknown variables: {unknown}")
        if len(set(variables)) != len(variables):
            raise ValueError(f"Ablation config '{name}' lists a variable more than once: {variables}")


# Checked at import time so a typo cannot surface only after hours of training.
validate_ablation_configs(ABLATION_CONFIGS, ALL_VARIABLES)

# --- 3. FLEXIBLE DATASET & MODEL ---
class FlexibleAblationDataset(Dataset):
    """Paired predictor/target samples with a selectable subset of predictor channels.

    Each sample is a pair of ``.npy`` files in ``data_dir / split``:
    ``<stem>_predictor.npy`` (indexed along its first axis by
    ``variable_indices``) and the matching ``<stem>_target.npy``. Both are
    standardised with the statistics stored in
    ``data_dir / 'normalization_stats_multi_variable.joblib'``.

    Args:
        data_dir: Root directory of the processed dataset.
        split: Sub-directory to read, e.g. ``'train'`` or ``'val'``.
        variable_indices: Indices of the predictor channels to keep.

    Raises:
        FileNotFoundError: If the split directory holds no predictor files.
    """

    _PREDICTOR_SUFFIX = '_predictor.npy'
    _TARGET_SUFFIX = '_target.npy'

    def __init__(self, data_dir: Path, split: str, variable_indices: list):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*' + self._PREDICTOR_SUFFIX))
        if not self.predictor_files:
            # An empty split would otherwise fail later and less clearly
            # (DataLoader sampler error or a division by zero in validation).
            raise FileNotFoundError(f"No '*{self._PREDICTOR_SUFFIX}' files found in {self.split_dir}")
        # NOTE(review): joblib.load unpickles its input; only use trusted files.
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')
        self.variable_indices = variable_indices

    def __len__(self):
        return len(self.predictor_files)

    def _target_path(self, pred_path):
        """Return the target file paired with ``pred_path``, swapping only the filename suffix."""
        stem = pred_path.name[:-len(self._PREDICTOR_SUFFIX)]
        return pred_path.with_name(stem + self._TARGET_SUFFIX)

    def __getitem__(self, idx):
        """Return ``(predictor, target)`` float32 tensors; the target gains a leading channel axis."""
        pred_path = self.predictor_files[idx]
        full_predictor = np.load(pred_path).astype(np.float32)
        target_data = np.load(self._target_path(pred_path)).astype(np.float32)

        predictor_subset = full_predictor[self.variable_indices, :, :]
        # [indices, None, None] broadcasts per-channel stats over the two spatial axes.
        mean_subset = self.stats['predictor_mean'][self.variable_indices, None, None]
        std_subset = self.stats['predictor_std'][self.variable_indices, None, None]
        predictor_norm = (predictor_subset - mean_subset) / (std_subset + 1e-8)
        target_norm = (target_data - self.stats['target_mean']) / (self.stats['target_std'] + 1e-8)

        # The stats may be float64 (or NumPy-2 float64 scalars), which would
        # promote the result; force float32 so tensors always match the model weights.
        predictor_tensor = torch.from_numpy(np.ascontiguousarray(predictor_norm, dtype=np.float32))
        target_tensor = torch.from_numpy(np.ascontiguousarray(target_norm, dtype=np.float32)).unsqueeze(0)
        return predictor_tensor, target_tensor

class SimplifiedAttention(nn.Module):
    """Channel gate followed by a spatial gate; output has the same shape as the input.

    The channel branch is squeeze-and-excitation style: global average pool,
    1x1 conv down to ``channels // reduction`` units, ReLU, 1x1 conv back up,
    sigmoid. The spatial branch is a 7x7 conv to one map plus sigmoid. Both
    gates multiply the features.

    Args:
        channels: Number of input (and output) channels.
        reduction: Channel reduction ratio of the hidden layer (default 16).
            The hidden width is clamped to at least 1, so small channel counts
            (< ``reduction``) still get a working gate instead of a zero-width
            layer. For ``channels >= reduction`` the layers are identical to
            the unclamped form.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        # The spatial gate is computed on the already channel-gated features.
        x = x * self.channel_att(x)
        return x * self.spatial_att(x)

class FlexibleAblationModel(nn.Module):
    """U-Net-style encoder/decoder whose input channel count varies per ablation run.

    Encoder: ``depth`` conv blocks with widths ``base_channels * min(2**i, 8)``
    (e.g. 64, 128, 256, 512 for the defaults), separated by ``depth - 1``
    stride-2 3x3 convolutions. A bottleneck conv block, optionally followed by
    ``SimplifiedAttention``, feeds ``depth`` decoder stages. Each stage does a
    2x transposed-conv upsample, concatenates the matching encoder skip, and
    applies a conv block. A final head maps to a single output channel at the
    resolution of the first encoder block (i.e. the input resolution, since
    every encoder conv uses padding=1 and stride 1).

    Args:
        input_channels: Number of predictor channels fed to the first encoder block.
        base_channels: Width of the first encoder block.
        depth: Number of encoder blocks and decoder stages.
        use_attention: Apply ``SimplifiedAttention`` after the bottleneck.
    """
    def __init__(self, input_channels, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        # Channel width doubles per level but is capped at 8x base_channels.
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            # No downsampler after the last encoder block, so the bottleneck runs at
            # the same resolution as the deepest skip connection.
            if i < len(self.channels) - 1: self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch
        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention: self.attention = SimplifiedAttention(bottleneck_ch)
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        # Decoder stages are built deepest-first, matching the order forward() consumes them.
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i+1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input is the upsampled features concatenated with a same-width skip.
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))
        self.final_conv = nn.Sequential(nn.Conv2d(self.channels[0], self.channels[0] // 2, 3, padding=1), nn.BatchNorm2d(self.channels[0] // 2), nn.ReLU(inplace=True), nn.Conv2d(self.channels[0] // 2, 1, 1))

    # Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved (padding=1).
    def _conv_block(self, in_ch, out_ch): return nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))

    def forward(self, x):
        """Map a (N, input_channels, H, W) batch to a (N, 1, H, W) prediction."""
        skips = []
        for i in range(len(self.encoders)):
            x = self.encoders[i](x); skips.append(x)
            if i < len(self.downsamplers): x = self.downsamplers[i](x)
        x = self.bottleneck(x)
        if self.use_attention: x = self.attention(x)
        for i, (up, dec) in enumerate(zip(self.upsamplers, self.decoders)):
            # Pair decoder stage i with the skips in reverse (deepest first).
            x = up(x); skip = skips[len(skips) - 1 - i]
            # NOTE(review): in the first stage the bottleneck is already at the
            # deepest skip's resolution, so up() overshoots by 2x and this
            # interpolation scales it back down. Later stages only hit this for
            # odd spatial sizes. Changing it would alter the architecture and the
            # saved checkpoints.
            if x.shape[-2:] != skip.shape[-2:]: x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1); x = dec(x)
        return self.final_conv(x)

# --- 4. ADVANCED LOSS FUNCTION (for consistency) ---
class AdvancedLoss(nn.Module):
    """Weighted sum of MSE, (1 - SSIM) and Sobel-gradient L1 losses.

    ``loss = mse_weight * MSE + ssim_weight * (1 - SSIM) + gradient_weight * GradL1``

    All terms are evaluated in float32 with autocast disabled. Under AMP the
    model output arrives as float16. Without this, SSIM's local means are taken
    in float16 while ``**`` promotes to float32, so the variance estimates
    ``E[x^2] - E[x]^2`` subtract a low-precision mean. That cancellation can
    make the SSIM denominator numerically unstable. The formulas and default
    weights are unchanged from the original implementation.

    Args:
        mse_weight: Weight of the mean-squared-error term.
        ssim_weight: Weight of the ``1 - SSIM`` term.
        gradient_weight: Weight of the Sobel-gradient L1 term.
    """

    def __init__(self, mse_weight=0.6, ssim_weight=0.2, gradient_weight=0.2):
        super().__init__()
        self.mse_weight = mse_weight
        self.ssim_weight = ssim_weight
        self.gradient_weight = gradient_weight
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def ssim_loss(self, pred, target, window_size=11):
        """Return ``1 - mean(SSIM map)`` using a uniform ``window_size`` box window.

        Local statistics come from ``avg_pool2d`` with zero padding (padded
        zeros are counted at the borders), not the Gaussian window of the
        original SSIM formulation.
        """
        pad = window_size // 2

        def local_mean(t):
            return F.avg_pool2d(t, window_size, 1, padding=pad)

        mu1 = local_mean(pred)
        mu2 = local_mean(target)
        mu1_sq, mu2_sq, mu1_mu2 = mu1 ** 2, mu2 ** 2, mu1 * mu2
        # Local (co)variances via E[xy] - E[x]E[y]; this subtraction is the
        # precision-sensitive step that forward() protects by using float32.
        sigma1_sq = local_mean(pred ** 2) - mu1_sq
        sigma2_sq = local_mean(target ** 2) - mu2_sq
        sigma12 = local_mean(pred * target) - mu1_mu2
        # Conventional SSIM stabilisers (K1=0.01, K2=0.03 with dynamic range 1).
        # NOTE(review): inputs are standardised, not in [0, 1]; kept for consistency.
        c1, c2 = 0.01 ** 2, 0.03 ** 2
        numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
        denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
        return 1 - (numerator / denominator).mean()

    def gradient_loss(self, pred, target):
        """Return the summed L1 distances between Sobel x- and y-gradients.

        The 1x1x3x3 kernels assume single-channel (N, 1, H, W) inputs.
        """
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=pred.dtype, device=pred.device).view(1, 1, 3, 3)
        pred_grad_x = F.conv2d(pred, sobel_x, padding=1)
        pred_grad_y = F.conv2d(pred, sobel_y, padding=1)
        target_grad_x = F.conv2d(target, sobel_x, padding=1)
        target_grad_y = F.conv2d(target, sobel_y, padding=1)
        return self.l1_loss(pred_grad_x, target_grad_x) + self.l1_loss(pred_grad_y, target_grad_y)

    def forward(self, pred, target):
        # Disable autocast and upcast so every term (including the convolutions
        # in gradient_loss) runs in float32, whatever dtype the model produced.
        with torch.autocast(device_type=pred.device.type, enabled=False):
            pred = pred.float()
            target = target.float()
            mse = self.mse_loss(pred, target)
            ssim = self.ssim_loss(pred, target)
            grad = self.gradient_loss(pred, target)
            return self.mse_weight * mse + self.ssim_weight * ssim + self.gradient_weight * grad

# --- 5. CONSISTENT TRAINING FUNCTION ---
def _linear_warmup_factor(warmup_steps):
    """Return a LambdaLR multiplier ramping linearly from 1/warmup_steps up to 1."""
    def factor(step):
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, (step + 1) / warmup_steps)
    return factor


def _train_one_epoch(model, loader, optimizer, scheduler, scaler, loss_fn, device, use_amp, desc):
    """Run one optimisation pass over ``loader`` with AMP, gradient clipping and LR scheduling."""
    model.train()
    pbar = tqdm(loader, desc=desc)
    for predictor, target in pbar:
        predictor = predictor.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            prediction = model(predictor)
            loss = loss_fn(prediction, target)
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale at this point; unscale
        # them first so GRADIENT_CLIP bounds the true gradient norm. (No-op when
        # the scaler is disabled, e.g. on CPU.)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})


def _evaluate(model, loader, loss_fn, device):
    """Return the unweighted mean of per-batch losses over ``loader`` (eval mode, no grad)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for predictor, target in loader:
            predictor = predictor.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            total += loss_fn(model(predictor), target).item()
    return total / len(loader)


def train_one_model(config):
    """Train one ablation model, checkpointing the best validation epoch.

    The model uses only the predictor channels named in ``config['variables']``.
    Training uses AdamW with a linear LR warmup over ``WARMUP_EPOCHS`` epochs,
    AMP on CUDA, gradient-norm clipping, and early stopping after
    ``EARLY_STOPPING_PATIENCE`` epochs without validation improvement. On
    every improvement the state dict is written to
    ``OUTPUT_DIR/ablation_<name>.pt``. The config, augmented with
    ``best_val_loss`` and ``best_epoch``, goes to ``OUTPUT_DIR/ablation_<name>.json``.

    Args:
        config: Dict with ``'name'`` (run identifier) and ``'variables'``
            (variable names drawn from ``ALL_VARIABLES``).

    Returns:
        float: The best validation loss reached (``inf`` if it never improved).

    Raises:
        ValueError: If the validation split yields no batches.
    """
    model_name = config['name']
    variable_names = config['variables']
    variable_indices = [ALL_VARIABLES.index(v) for v in variable_names]
    num_channels = len(variable_indices)

    print("\n" + "="*80)
    print(f"🔬 STARTING ABLATION RUN: {model_name}")
    print(f"   Variables ({num_channels}): {', '.join(variable_names)}")
    print("="*80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"

    train_dataset = FlexibleAblationDataset(DATA_DIR, 'train', variable_indices)
    val_dataset = FlexibleAblationDataset(DATA_DIR, 'val', variable_indices)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    if len(val_loader) == 0:
        raise ValueError(f"Validation split for '{model_name}' is empty; cannot select a best checkpoint.")

    model = FlexibleAblationModel(input_channels=num_channels).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # WARMUP_EPOCHS is a declared study hyperparameter; apply it as a per-step
    # linear ramp so every run shares the same schedule.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=_linear_warmup_factor(WARMUP_EPOCHS * len(train_loader)))
    loss_fn = AdvancedLoss()  # same loss as the main regression model
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    model_path = OUTPUT_DIR / f"ablation_{model_name}.pt"
    metadata_path = OUTPUT_DIR / f"ablation_{model_name}.json"
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(EPOCHS):
        _train_one_epoch(model, train_loader, optimizer, scheduler, scaler, loss_fn,
                         device, use_amp, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        avg_val_loss = _evaluate(model, val_loader, loss_fn, device)
        print(f"Epoch {epoch+1} | Val Loss: {avg_val_loss:.6f}")

        # A NaN validation loss never compares smaller, so it counts toward early stopping.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), model_path)
            metadata = {**config, 'best_val_loss': best_val_loss, 'best_epoch': epoch + 1}
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)
            print(f"   ✅ Best model saved to {model_path}")
        else:
            patience_counter += 1

        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print("   🛑 Early stopping triggered.")
            break

    print(f"🔬 FINISHED ABLATION RUN: {model_name} | Best Val Loss: {best_val_loss:.6f}")
    # Drop the references first, otherwise gc/empty_cache cannot reclaim the
    # model and optimizer state before the next run starts.
    del model, optimizer, scheduler, scaler, loss_fn, train_loader, val_loader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return best_val_loss

# --- 6. MAIN EXECUTION SCRIPT ---
def main():
    """Train every run in ABLATION_CONFIGS, one after another, in definition order."""
    for run_name in ABLATION_CONFIGS:
        run_variables = ABLATION_CONFIGS[run_name]
        train_one_model({"name": run_name, "variables": run_variables})

    # Reached only if no run raised.
    print("\n\n🎉🎉🎉 All ablation study models trained successfully! 🎉🎉🎉")


if __name__ == "__main__":
    main()



🔬 STARTING ABLATION RUN: single_IVT
   Variables (1): IVT


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.02it/s, loss=0.7096]


Epoch 1 | Val Loss: 0.710103
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_IVT.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.44it/s, loss=0.5540]


Epoch 2 | Val Loss: 0.697578
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_IVT.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.48it/s, loss=0.7690]


Epoch 3 | Val Loss: 0.685988
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_IVT.pt


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.42it/s, loss=0.5994]


Epoch 4 | Val Loss: 0.683127
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_IVT.pt


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.45it/s, loss=0.7380]


Epoch 5 | Val Loss: 0.798299


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.56it/s, loss=0.6967]


Epoch 6 | Val Loss: 0.709325


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.60it/s, loss=0.6450]


Epoch 7 | Val Loss: 0.714092


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.64it/s, loss=1.0187]


Epoch 8 | Val Loss: 0.718896


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.68it/s, loss=0.6853]


Epoch 9 | Val Loss: 0.673902
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_IVT.pt


Epoch 10/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.58it/s, loss=0.5529]


Epoch 10 | Val Loss: 1.153316


Epoch 11/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=0.6762]


Epoch 11 | Val Loss: 1.613254


Epoch 12/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=0.6383]


Epoch 12 | Val Loss: 0.689770


Epoch 13/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=0.8160]


Epoch 13 | Val Loss: 0.691979


Epoch 14/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=0.8393]


Epoch 14 | Val Loss: 0.690672


Epoch 15/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=0.8307]


Epoch 15 | Val Loss: 0.682033


Epoch 16/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=0.4907]


Epoch 16 | Val Loss: 0.687442
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: single_IVT | Best Val Loss: 0.673902

🔬 STARTING ABLATION RUN: single_T500
   Variables (1): T500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:48<00:00,  3.06it/s, loss=0.7448]


Epoch 1 | Val Loss: 0.821180
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_T500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.70it/s, loss=2.3314]


Epoch 2 | Val Loss: 19.245550


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.5221]


Epoch 3 | Val Loss: 2.034605


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=3.2485]


Epoch 4 | Val Loss: 9.741885


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.77it/s, loss=3.1197]


Epoch 5 | Val Loss: 12.636429


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.78it/s, loss=2.7920]


Epoch 6 | Val Loss: 11.607461


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.78it/s, loss=4.0227]


Epoch 7 | Val Loss: 20.019143


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.77it/s, loss=5.1341]


Epoch 8 | Val Loss: 14.440988
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: single_T500 | Best Val Loss: 0.821180

🔬 STARTING ABLATION RUN: single_T850
   Variables (1): T850


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.04it/s, loss=0.7433]


Epoch 1 | Val Loss: 0.782613
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_T850.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.70it/s, loss=1.1074]


Epoch 2 | Val Loss: 17.381653


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.1369]


Epoch 3 | Val Loss: 20.566318


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.1377]


Epoch 4 | Val Loss: 5.739299


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.3106]


Epoch 5 | Val Loss: 3.532061


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=3.3577]


Epoch 6 | Val Loss: 5.234361


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.1119]


Epoch 7 | Val Loss: 4.927393


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.6561]


Epoch 8 | Val Loss: 5.205822
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: single_T850 | Best Val Loss: 0.782613

🔬 STARTING ABLATION RUN: single_RH700
   Variables (1): RH700


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.03it/s, loss=0.5124]


Epoch 1 | Val Loss: 0.702754
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_RH700.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=1.5070]


Epoch 2 | Val Loss: 6.304401


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=1.9986]


Epoch 3 | Val Loss: 248.524984


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.0409]


Epoch 4 | Val Loss: 3.999715


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=1.9011]


Epoch 5 | Val Loss: 5.067943


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.4132]


Epoch 6 | Val Loss: 3.375730


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.2272]


Epoch 7 | Val Loss: 3.440747


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=1.7030]


Epoch 8 | Val Loss: 2.660444
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: single_RH700 | Best Val Loss: 0.702754

🔬 STARTING ABLATION RUN: single_W500
   Variables (1): W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.00it/s, loss=0.7153]


Epoch 1 | Val Loss: 0.784482
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.54it/s, loss=0.8083]


Epoch 2 | Val Loss: 0.762139
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_W500.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.45it/s, loss=1.0019]


Epoch 3 | Val Loss: 0.725087
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_W500.pt


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:34<00:00,  4.40it/s, loss=0.6977]


Epoch 4 | Val Loss: 0.737175


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.54it/s, loss=0.6285]


Epoch 5 | Val Loss: 0.788077


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.55it/s, loss=0.6645]


Epoch 6 | Val Loss: 0.743097


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.55it/s, loss=0.9149]


Epoch 7 | Val Loss: 0.653913
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_W500.pt


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.42it/s, loss=0.6052]


Epoch 8 | Val Loss: 0.623570
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_single_W500.pt


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.50it/s, loss=0.7404]


Epoch 9 | Val Loss: 0.673904


Epoch 10/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.62it/s, loss=0.5455]


Epoch 10 | Val Loss: 0.688375


Epoch 11/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.62it/s, loss=0.8890]


Epoch 11 | Val Loss: 0.634550


Epoch 12/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.62it/s, loss=0.6401]


Epoch 12 | Val Loss: 0.637561


Epoch 13/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.62it/s, loss=0.6945]


Epoch 13 | Val Loss: 0.703280


Epoch 14/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.62it/s, loss=0.8158]


Epoch 14 | Val Loss: 0.637165


Epoch 15/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.61it/s, loss=0.6922]


Epoch 15 | Val Loss: 0.714478
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: single_W500 | Best Val Loss: 0.623570

🔬 STARTING ABLATION RUN: pair_IVT_T500
   Variables (2): IVT, T500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:50<00:00,  2.94it/s, loss=0.6197]


Epoch 1 | Val Loss: 0.660718
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_IVT_T500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.67it/s, loss=0.6421]


Epoch 2 | Val Loss: 0.671360


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=0.6065]


Epoch 3 | Val Loss: 0.674736


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.0932]


Epoch 4 | Val Loss: 0.686084


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.1762]


Epoch 5 | Val Loss: 1.356601


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.0823]


Epoch 6 | Val Loss: 1.718338


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.3448]


Epoch 7 | Val Loss: 1.587123


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.1159]


Epoch 8 | Val Loss: 2.185878
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_IVT_T500 | Best Val Loss: 0.660718

🔬 STARTING ABLATION RUN: pair_IVT_RH700
   Variables (2): IVT, RH700


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.02it/s, loss=0.8475]


Epoch 1 | Val Loss: 0.680633
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_IVT_RH700.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.69it/s, loss=0.9457]


Epoch 2 | Val Loss: 2.509288


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.5341]


Epoch 3 | Val Loss: 1.812392


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.8711]


Epoch 4 | Val Loss: 6.802514


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=3.9805]


Epoch 5 | Val Loss: 4.370924


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.2489]


Epoch 6 | Val Loss: 4.295389


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.3014]


Epoch 7 | Val Loss: 3.510108


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.0500]


Epoch 8 | Val Loss: 2.655897
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_IVT_RH700 | Best Val Loss: 0.680633

🔬 STARTING ABLATION RUN: pair_T500_T850
   Variables (2): T500, T850


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.00it/s, loss=0.5867]


Epoch 1 | Val Loss: 0.758740
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_T500_T850.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.69it/s, loss=0.8661]


Epoch 2 | Val Loss: 1.704326


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=1.0091]


Epoch 3 | Val Loss: 0.801034


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.4756]


Epoch 4 | Val Loss: 3.366870


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.1512]


Epoch 5 | Val Loss: 4.133761


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.1304]


Epoch 6 | Val Loss: 3.517705


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.8634]


Epoch 7 | Val Loss: 4.247675


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.8421]


Epoch 8 | Val Loss: 25.037043
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_T500_T850 | Best Val Loss: 0.758740

🔬 STARTING ABLATION RUN: pair_RH700_W500
   Variables (2): RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.01it/s, loss=0.9109]


Epoch 1 | Val Loss: 0.717165
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_RH700_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.64it/s, loss=0.7566]


Epoch 2 | Val Loss: 0.671729
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_RH700_W500.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.60it/s, loss=0.5290]


Epoch 3 | Val Loss: 0.816654


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=1.6699]


Epoch 4 | Val Loss: 6.882489


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.3647]


Epoch 5 | Val Loss: 9.382259


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.6850]


Epoch 6 | Val Loss: 13.336525


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=4.5270]


Epoch 7 | Val Loss: 6.077782


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=3.7399]


Epoch 8 | Val Loss: 5.345804


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.4660]


Epoch 9 | Val Loss: 4.760293
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_RH700_W500 | Best Val Loss: 0.671729

🔬 STARTING ABLATION RUN: pair_IVT_W500
   Variables (2): IVT, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:50<00:00,  3.00it/s, loss=0.9657]


Epoch 1 | Val Loss: 0.683011
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_IVT_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.57it/s, loss=1.1099]


Epoch 2 | Val Loss: 0.686459


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.60it/s, loss=0.5567]


Epoch 3 | Val Loss: 0.661175
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_IVT_W500.pt


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.47it/s, loss=0.7223]


Epoch 4 | Val Loss: 0.664928


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.61it/s, loss=0.8133]


Epoch 5 | Val Loss: 0.707956


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.64it/s, loss=0.6563]


Epoch 6 | Val Loss: 0.644150
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_IVT_W500.pt


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.54it/s, loss=0.7521]


Epoch 7 | Val Loss: 0.656658


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.69it/s, loss=2.1189]


Epoch 8 | Val Loss: 4.396570


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=1.7595]


Epoch 9 | Val Loss: 950.874237


Epoch 10/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=1.8197]


Epoch 10 | Val Loss: 2.087349


Epoch 11/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.3621]


Epoch 11 | Val Loss: 2.836181


Epoch 12/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.2411]


Epoch 12 | Val Loss: 3.526403


Epoch 13/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.6281]


Epoch 13 | Val Loss: 2.625594
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_IVT_W500 | Best Val Loss: 0.644150

🔬 STARTING ABLATION RUN: pair_T500_W500
   Variables (2): T500, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.04it/s, loss=0.5986]


Epoch 1 | Val Loss: 0.699504
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_pair_T500_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.63it/s, loss=0.6876]


Epoch 2 | Val Loss: 0.740923


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.68it/s, loss=0.7100]


Epoch 3 | Val Loss: 5.644346


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=2.2283]


Epoch 4 | Val Loss: 45.982360


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.7371]


Epoch 5 | Val Loss: 18.541214


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.9494]


Epoch 6 | Val Loss: 7.781294


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.2055]


Epoch 7 | Val Loss: 4.383222


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.4952]


Epoch 8 | Val Loss: 2.560978
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: pair_T500_W500 | Best Val Loss: 0.699504

🔬 STARTING ABLATION RUN: triplet_IVT_T500_RH700
   Variables (3): IVT, T500, RH700


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.01it/s, loss=0.7145]


Epoch 1 | Val Loss: 0.661056
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_T500_RH700.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.69it/s, loss=0.5254]


Epoch 2 | Val Loss: 0.658044
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_T500_RH700.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.58it/s, loss=0.5990]


Epoch 3 | Val Loss: 0.685095


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.70it/s, loss=1.2228]


Epoch 4 | Val Loss: 7.232293


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=2.5316]


Epoch 5 | Val Loss: 1.618143


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.4862]


Epoch 6 | Val Loss: 4.074911


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.8133]


Epoch 7 | Val Loss: 3.397140


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.9587]


Epoch 8 | Val Loss: 4.103336


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=3.1178]


Epoch 9 | Val Loss: 4.282472
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: triplet_IVT_T500_RH700 | Best Val Loss: 0.658044

🔬 STARTING ABLATION RUN: triplet_T500_T850_W500
   Variables (3): T500, T850, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.02it/s, loss=0.8498]


Epoch 1 | Val Loss: 0.785225
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_T500_T850_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.65it/s, loss=0.8845]


Epoch 2 | Val Loss: 0.690116
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_T500_T850_W500.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.56it/s, loss=0.6439]


Epoch 3 | Val Loss: 0.654452
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_T500_T850_W500.pt


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:33<00:00,  4.54it/s, loss=1.6068]


Epoch 4 | Val Loss: 3.225465


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=2.2715]


Epoch 5 | Val Loss: 3.120574


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=3.4636]


Epoch 6 | Val Loss: 2.137218


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.0235]


Epoch 7 | Val Loss: 3.382535


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.9960]


Epoch 8 | Val Loss: 2.839100


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=3.1150]


Epoch 9 | Val Loss: 2.723120


Epoch 10/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.7561]


Epoch 10 | Val Loss: 3.188538
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: triplet_T500_T850_W500 | Best Val Loss: 0.654452

🔬 STARTING ABLATION RUN: triplet_IVT_RH700_W500
   Variables (3): IVT, RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.02it/s, loss=0.7269]


Epoch 1 | Val Loss: 0.658414
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_RH700_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.65it/s, loss=0.8667]


Epoch 2 | Val Loss: 0.704640


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=1.2407]


Epoch 3 | Val Loss: 1.543323


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=3.1639]


Epoch 4 | Val Loss: 109.433516


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=1.9100]


Epoch 5 | Val Loss: 2.748820


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.4863]


Epoch 6 | Val Loss: 2.708580


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.2969]


Epoch 7 | Val Loss: 2.693905


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.8203]


Epoch 8 | Val Loss: 3.877355
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: triplet_IVT_RH700_W500 | Best Val Loss: 0.658414

🔬 STARTING ABLATION RUN: triplet_IVT_T500_W500
   Variables (3): IVT, T500, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.01it/s, loss=0.7288]


Epoch 1 | Val Loss: 0.888294
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_T500_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.67it/s, loss=0.6513]


Epoch 2 | Val Loss: 0.660102
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_T500_W500.pt


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.57it/s, loss=0.7851]


Epoch 3 | Val Loss: 0.651884
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_triplet_IVT_T500_W500.pt


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.56it/s, loss=0.6684]


Epoch 4 | Val Loss: 0.672337


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.2615]


Epoch 5 | Val Loss: 1.602244


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.6130]


Epoch 6 | Val Loss: 91.184482


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=1.8249]


Epoch 7 | Val Loss: 6.640826


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.4763]


Epoch 8 | Val Loss: 5.215681


Epoch 9/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=2.1976]


Epoch 9 | Val Loss: 5.886581


Epoch 10/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=1.7486]


Epoch 10 | Val Loss: 5.067738
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: triplet_IVT_T500_W500 | Best Val Loss: 0.651884

🔬 STARTING ABLATION RUN: remove_IVT
   Variables (4): T500, T850, RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.03it/s, loss=0.5137]


Epoch 1 | Val Loss: 0.708407
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_remove_IVT.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.69it/s, loss=1.4312]


Epoch 2 | Val Loss: 13.348461


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=1.5840]


Epoch 3 | Val Loss: 2.392385


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=2.0768]


Epoch 4 | Val Loss: 2.871944


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.0143]


Epoch 5 | Val Loss: 5.340636


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.76it/s, loss=4.4285]


Epoch 6 | Val Loss: 5.354775


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=4.5680]


Epoch 7 | Val Loss: 4.546124


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=4.1496]


Epoch 8 | Val Loss: 3.792282
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: remove_IVT | Best Val Loss: 0.708407

🔬 STARTING ABLATION RUN: remove_T500
   Variables (4): IVT, T850, RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.03it/s, loss=0.4924]


Epoch 1 | Val Loss: 0.868993
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_remove_T500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.68it/s, loss=1.0038]


Epoch 2 | Val Loss: 1.287481


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.70it/s, loss=2.5295]


Epoch 3 | Val Loss: 34.307343


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=3.0544]


Epoch 4 | Val Loss: 3.047592


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=3.3059]


Epoch 5 | Val Loss: 3.038998


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.9043]


Epoch 6 | Val Loss: 3.434396


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.4362]


Epoch 7 | Val Loss: 2.564261


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.7750]


Epoch 8 | Val Loss: 2.596969
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: remove_T500 | Best Val Loss: 0.868993

🔬 STARTING ABLATION RUN: remove_T850
   Variables (4): IVT, T500, RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:48<00:00,  3.07it/s, loss=0.6342]


Epoch 1 | Val Loss: 0.630803
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_remove_T850.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.68it/s, loss=1.9743]


Epoch 2 | Val Loss: 2.174740


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.71it/s, loss=2.9460]


Epoch 3 | Val Loss: 7.833229


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=3.0284]


Epoch 4 | Val Loss: 5.611734


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.9602]


Epoch 5 | Val Loss: 6.338569


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.6385]


Epoch 6 | Val Loss: 16.819782


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.4896]


Epoch 7 | Val Loss: 3.893343


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=3.3024]


Epoch 8 | Val Loss: 4.134728
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: remove_T850 | Best Val Loss: 0.630803

🔬 STARTING ABLATION RUN: remove_RH700
   Variables (4): IVT, T500, T850, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:49<00:00,  3.00it/s, loss=0.5872]


Epoch 1 | Val Loss: 0.657720
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_remove_RH700.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.66it/s, loss=1.1775]


Epoch 2 | Val Loss: 9.866232


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.9426]


Epoch 3 | Val Loss: 5.674184


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=1.9806]


Epoch 4 | Val Loss: 1.420342


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.5903]


Epoch 5 | Val Loss: 1.875696


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.6689]


Epoch 6 | Val Loss: 2.867537


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.5292]


Epoch 7 | Val Loss: 6.476830


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.5473]


Epoch 8 | Val Loss: 2.078914
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: remove_RH700 | Best Val Loss: 0.657720

🔬 STARTING ABLATION RUN: remove_W500
   Variables (4): IVT, T500, T850, RH700


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:50<00:00,  2.98it/s, loss=0.7437]


Epoch 1 | Val Loss: 0.670206
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_remove_W500.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.69it/s, loss=0.5821]


Epoch 2 | Val Loss: 0.676756


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.2897]


Epoch 3 | Val Loss: 2.673434


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.6092]


Epoch 4 | Val Loss: 2.473910


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.6397]


Epoch 5 | Val Loss: 2.842528


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=2.1626]


Epoch 6 | Val Loss: 2.214226


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.7676]


Epoch 7 | Val Loss: 2.774950


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.75it/s, loss=2.5092]


Epoch 8 | Val Loss: 3.700891
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: remove_W500 | Best Val Loss: 0.670206

🔬 STARTING ABLATION RUN: all_variables
   Variables (5): IVT, T500, T850, RH700, W500


Epoch 1/25 [Train]: 100%|██████████| 150/150 [00:50<00:00,  3.00it/s, loss=0.4587]


Epoch 1 | Val Loss: 0.725073
   ✅ Best model saved to /content/drive/My Drive/AR_Downscaling/ablation_study_models/ablation_all_variables.pt


Epoch 2/25 [Train]: 100%|██████████| 150/150 [00:32<00:00,  4.68it/s, loss=0.9583]


Epoch 2 | Val Loss: 2.974386


Epoch 3/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.70it/s, loss=1.0732]


Epoch 3 | Val Loss: 2.372182


Epoch 4/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.72it/s, loss=1.8062]


Epoch 4 | Val Loss: 5.312698


Epoch 5/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=1.7034]


Epoch 5 | Val Loss: 2.382265


Epoch 6/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.0423]


Epoch 6 | Val Loss: 4.105523


Epoch 7/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.74it/s, loss=2.4264]


Epoch 7 | Val Loss: 4.712220


Epoch 8/25 [Train]: 100%|██████████| 150/150 [00:31<00:00,  4.73it/s, loss=1.5335]


Epoch 8 | Val Loss: 1.607224
   🛑 Early stopping triggered.
🔬 FINISHED ABLATION RUN: all_variables | Best Val Loss: 0.725073


🎉🎉🎉 All ablation study models trained successfully! 🎉🎉🎉
