In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Re-using the ResBlock, Attention, and DecoderBlock from previous model ---
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a skip connection.

    The skip path is the identity when ``in_channels == out_channels``; otherwise a
    1x1 conv + BatchNorm projects the input to ``out_channels``. Spatial size is kept
    (the 3x3 convs use padding=1). Output is ReLU(main_path + skip).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Main path. Registration order is kept stable so seeded initialisation is reproducible.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Skip path: only project when the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + skip)

class SimplifiedAttention(nn.Module):
    """Per-pixel channel gate: ``x * sigmoid(conv1x1(relu(conv1x1(x))))``.

    The gate is computed independently at each pixel (no global pooling) through a
    bottleneck of ``max(1, channels // reduction)`` channels. The output has the same
    shape as ``x``, and every element is scaled by a factor in (0, 1).

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck reduction factor (default 8, the original value).
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Never let the bottleneck collapse to zero channels (channels < reduction);
        # a 0-wide bottleneck makes the gate a constant that ignores the input.
        hidden = max(1, channels // reduction)
        self.attention = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.attention(x)

class DecoderBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsample, concat with skip, ResBlock.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        skip_channels: channels of the encoder skip connection.
        out_channels: channels produced by the upsampler and by the stage.
        use_attention: if True, apply SimplifiedAttention to the stage output.
    """

    def __init__(self, in_channels, skip_channels, out_channels, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.res_block = ResBlock(out_channels + skip_channels, out_channels)
        self.attention = SimplifiedAttention(out_channels) if use_attention else nn.Identity()

    def forward(self, x, skip):
        x = self.up(x)
        # Only the spatial size has to agree for concatenation along dim=1; the channel
        # counts of x and skip legitimately differ when out_channels != skip_channels.
        if x.shape[-2:] != skip.shape[-2:]:
            # e.g. an odd-sized encoder map was floored by pooling, so 2x upsampling
            # comes out a pixel short; resize to the skip's exact spatial size.
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        return self.attention(self.res_block(x))

class HiP_EEF_Direct(nn.Module):
    """
    Final HiP-EEF architecture for Direct Residual Prediction.
    This version removes the Gating Head for simpler, direct supervision.

    A shared 4-level ResBlock encoder (widths c, 2c, 4c, 8c with c = base_c) and a 16c
    bottleneck feed two independent decoders: a "continuity" head and an "extreme event"
    head (the latter applies attention at its 4c stage). Each head ends in a 1x1 conv to
    one channel, and the fused prediction is their sum.

    forward(x) returns (final_pred, cont_pred, ext_pred).
    """

    def __init__(self, n_channels=4, base_c=32):
        super().__init__()
        c = base_c
        # Shared encoder. Modules are created in the same order as always so that seeded
        # initialisation (and state_dict key order) is unchanged.
        self.enc1 = ResBlock(n_channels, c)
        self.enc2 = ResBlock(c, 2 * c)
        self.enc3 = ResBlock(2 * c, 4 * c)
        self.enc4 = ResBlock(4 * c, 8 * c)
        self.down = nn.MaxPool2d(2)
        self.bottleneck = ResBlock(8 * c, 16 * c)

        # Continuity head decoder.
        self.up1_cont = DecoderBlock(16 * c, 8 * c, 8 * c)
        self.up2_cont = DecoderBlock(8 * c, 4 * c, 4 * c)
        self.up3_cont = DecoderBlock(4 * c, 2 * c, 2 * c)
        self.up4_cont = DecoderBlock(2 * c, c, c)
        self.out_cont = nn.Conv2d(c, 1, kernel_size=1)

        # Extreme-event head decoder (attention on the 4c stage).
        self.up1_ext = DecoderBlock(16 * c, 8 * c, 8 * c)
        self.up2_ext = DecoderBlock(8 * c, 4 * c, 4 * c, use_attention=True)
        self.up3_ext = DecoderBlock(4 * c, 2 * c, 2 * c)
        self.up4_ext = DecoderBlock(2 * c, c, c)
        self.out_ext = nn.Conv2d(c, 1, kernel_size=1)

    def forward(self, x):
        # Encoder: collect each level's features for the skip connections.
        feats = [self.enc1(x)]
        for encoder in (self.enc2, self.enc3, self.enc4):
            feats.append(encoder(self.down(feats[-1])))
        bottom = self.bottleneck(self.down(feats[-1]))
        skips = feats[::-1]  # deepest first, matching decoder order

        cont = bottom
        for decoder, skip in zip((self.up1_cont, self.up2_cont, self.up3_cont, self.up4_cont), skips):
            cont = decoder(cont, skip)
        cont_pred = self.out_cont(cont)

        ext = bottom
        for decoder, skip in zip((self.up1_ext, self.up2_ext, self.up3_ext, self.up4_ext), skips):
            ext = decoder(ext, skip)
        ext_pred = self.out_ext(ext)

        # Simple additive fusion of the two heads.
        return cont_pred + ext_pred, cont_pred, ext_pred


In [None]:
import torch
from torch.utils.data import Dataset
from pathlib import Path
import numpy as np
import joblib

class MultiVariableARDataset(Dataset):
    """
    A flexible dataset for the AR Downscaling project that loads all available
    predictor variables and the target variable.

    The selection of specific variables (e.g., for ablation or the HiP-EEF model)
    is handled within the training script, not in the dataset itself.

    Expected layout::

        <data_dir>/<split>/<case>_predictor.npy
        <data_dir>/<split>/<case>_target.npy
        <data_dir>/normalization_stats_multi_variable.joblib

    The stats dict must provide 'predictor_mean' / 'predictor_std' (one value per
    predictor channel) and 'target_mean' / 'target_std'.
    """

    def __init__(self, data_dir: Path, split: str):
        """
        Initializes the dataset.

        Args:
            data_dir (Path): The root directory of the dataset containing the splits.
            split (str): The dataset split to load ('train', 'val', or 'test').

        Raises:
            FileNotFoundError: if the split directory, its predictor files, or the
                normalization stats file is missing.
        """
        self.split_dir = Path(data_dir) / split
        if not self.split_dir.exists():
            raise FileNotFoundError(f"Dataset split directory not found: {self.split_dir}")

        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        if not self.predictor_files:
            raise FileNotFoundError(f"No predictor files found in {self.split_dir}")

        stats_path = Path(data_dir) / 'normalization_stats_multi_variable.joblib'
        if not stats_path.exists():
            raise FileNotFoundError(f"Normalization stats file not found: {stats_path}")

        # NOTE: joblib.load unpickles the file; only point this at trusted data.
        self.stats = joblib.load(stats_path)
        print(f"Loaded '{split}' dataset with {len(self.predictor_files)} samples.")

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """
        Retrieves a single sample (predictor, target, case_name) from the dataset.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: A tuple containing:
                - torch.Tensor (float32): all normalized predictor channels, (C, H, W).
                - torch.Tensor (float32): the normalized target with a leading channel axis.
                - str: The name of the case for identification.
        """
        pred_path = self.predictor_files[idx]
        # Derive the case name and sibling target file from the file *name* only;
        # str.replace on the whole path could also rewrite a matching directory name.
        case_name = pred_path.name[:-len('_predictor.npy')]
        targ_path = pred_path.with_name(f"{case_name}_target.npy")

        full_predictor = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)

        # Per-channel stats get (C, 1, 1) axes to broadcast over the spatial dims.
        stats = self.stats
        pred_mean = np.asarray(stats['predictor_mean'])[:, None, None]
        pred_std = np.asarray(stats['predictor_std'])[:, None, None]
        predictor_norm = (full_predictor - pred_mean) / (pred_std + 1e-8)
        target_norm = (target_data - stats['target_mean']) / (stats['target_std'] + 1e-8)

        # Stats stored as float64 would promote the result to float64; cast back so the
        # model always gets float32 (float64 input fails in float32 convs outside autocast).
        return (
            torch.from_numpy(predictor_norm.astype(np.float32, copy=False)),
            torch.from_numpy(np.asarray(target_norm).astype(np.float32, copy=False)).unsqueeze(0),
            case_name,
        )

def calculate_validation_csi(model, val_loader, device, stats, threshold_k=220.0, variable_indices=None):
    """Mean per-sample Critical Success Index (CSI) of the model's fused prediction.

    A pixel counts as an event when its de-normalized value (Kelvin) is <= ``threshold_k``.
    For each sample, CSI = hits / (hits + misses + false_alarms). Samples with no event
    in either the prediction or the target are skipped.

    Args:
        model: network whose forward returns (final_pred, cont_pred, ext_pred).
        val_loader: iterable of (predictor, target_norm, case_name) batches.
        device: device the predictor batch is moved to before the forward pass.
        stats: dict with 'target_mean' and 'target_std' used to de-normalize.
        threshold_k: event threshold in Kelvin.
        variable_indices: predictor channels fed to the model. Defaults to the
            module-level ``VARIABLE_INDICES`` from the training configuration.

    Returns:
        Mean CSI over the scored samples, or 0.0 if no sample was scored.
        The model's train/eval mode is restored before returning.
    """
    if variable_indices is None:
        variable_indices = VARIABLE_INDICES  # defined by the training configuration cell
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']

    was_training = model.training
    model.eval()
    all_csi = []
    try:
        with torch.no_grad():
            for predictor, target_norm, _ in val_loader:
                final_pred, _, _ = model(predictor[:, variable_indices, :, :].to(device))

                # Convert to Kelvin on the CPU; the target never needs to visit the device.
                pred_k = final_pred.cpu().numpy() * scale + offset
                target_k = target_norm.cpu().numpy() * scale + offset

                # One row per sample, so contingency counts are computed for the whole batch.
                batch = pred_k.shape[0]
                pred_mask = (pred_k <= threshold_k).reshape(batch, -1)
                target_mask = (target_k <= threshold_k).reshape(batch, -1)
                hits = np.sum(pred_mask & target_mask, axis=1)
                misses = np.sum(~pred_mask & target_mask, axis=1)
                false_alarms = np.sum(pred_mask & ~target_mask, axis=1)

                total = hits + misses + false_alarms
                scored = total > 0
                all_csi.extend((hits[scored] / total[scored]).tolist())
    finally:
        model.train(was_training)

    return np.mean(all_csi) if all_csi else 0.0

In [None]:
# Conservative Training with simple MSE loss

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from pathlib import Path
import gc
from tqdm import tqdm
import numpy as np


# --- 1. CONFIGURATION ---
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
OUTPUT_DIR = PROJECT_PATH.joinpath('hip_eef_direct_model')
# No parents=True: if PROJECT_PATH does not exist (e.g. Drive not mounted) this raises
# instead of silently creating a local directory tree.
OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_SAVE_PATH = OUTPUT_DIR.joinpath('hip_eef_direct_model.pth')
# The sampler-weight cache lives in the smart-sampling experiment's folder, not OUTPUT_DIR.
WEIGHTS_SAVE_PATH = PROJECT_PATH.joinpath('hip_eef_smart_sampling_model', 'sampler_weights.pt')

# --- Training Hyperparameters ---
EPOCHS = 50
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 10  # epochs without val-loss improvement before stopping
GRADIENT_CLIP = 1.0           # max global gradient norm passed to clip_grad_norm_

# --- HiP-EEF Variable & Loss Configuration ---
# ALL_VARIABLES is assumed to match the channel order of the saved predictor arrays;
# the model is fed every variable except IVT.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLE_NAMES = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = list(map(ALL_VARIABLES.index, VARIABLE_NAMES))
INPUT_CHANNELS = len(VARIABLE_INDICES)

# --- Loss Weights for Direct Supervision ---
ALPHA = 0.5  # Continuity Head (predicting capped background)
BETA = 1.5   # Extreme Head (predicting storm residual)
DELTA = 0.4  # Final Fused Prediction (overall accuracy)

# --- Direct Supervision Parameters ---
RESIDUAL_THRESHOLD_K = 225.0       # background / residual split point (K)
SAMPLING_WEIGHT_THRESHOLD = 220.0  # "cold pixel" threshold (K) for sampler weights

# --- 2. GROUND TRUTH DECOMPOSITION & SAMPLER ---
def decompose_ground_truth(target_k, threshold_k):
    """Split a field into a capped background and a non-positive residual.

    Returns ``(background, residual)``, where ``background`` is ``target_k`` clamped from
    below at ``threshold_k`` and ``residual = target_k - background``. The residual is
    <= 0 and nonzero only where ``target_k < threshold_k``; the two always sum back to
    ``target_k``.
    """
    background = target_k.clamp(min=threshold_k)
    return background, target_k - background

def get_sampler_weights(dataset, stats, cache_path=None, threshold_k=None):
    """Per-sample weights for WeightedRandomSampler, favouring samples with many cold pixels.

    Each weight is ``0.1 + fraction of target pixels <= threshold_k`` (Kelvin). The 0.1
    floor keeps every sample drawable. Weights are cached with ``torch.save``. A cached
    file is reused only if its length matches ``len(dataset)``; otherwise it is recomputed,
    so a cache built for a different dataset cannot silently mis-weight samples.

    Args:
        dataset: dataset yielding (predictor, target_norm, case_name).
        stats: dict with 'target_mean' and 'target_std' used to de-normalize targets.
        cache_path: cache file; defaults to the module-level WEIGHTS_SAVE_PATH.
        threshold_k: cold-pixel threshold; defaults to SAMPLING_WEIGHT_THRESHOLD.

    Returns:
        1-D float tensor with one weight per dataset sample.
    """
    cache_path = Path(WEIGHTS_SAVE_PATH if cache_path is None else cache_path)
    threshold_k = SAMPLING_WEIGHT_THRESHOLD if threshold_k is None else threshold_k

    if cache_path.exists():
        cached = torch.load(cache_path)
        if len(cached) == len(dataset):
            print(f"Loading cached sampler weights from {cache_path}")
            return cached
        print(f"Cached sampler weights at {cache_path} have {len(cached)} entries but the "
              f"dataset has {len(dataset)} samples; recomputing.")

    print("Pre-computing sampler weights...")
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    weights = []
    for _, target_norm, _ in tqdm(dataset):
        target_k = target_norm.numpy().squeeze() * scale + offset
        weights.append(0.1 + np.mean(target_k <= threshold_k))
    weights = torch.tensor(weights, dtype=torch.float)

    # The cache folder belongs to another experiment and may not exist yet.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(weights, cache_path)
    return weights

# --- 3. MAIN TRAINING SCRIPT ---
def _normalized_head_targets(target_norm, stats):
    """Build the normalized supervision targets for the continuity and extreme heads.

    Args:
        target_norm: normalized target tensor as produced by the dataset.
        stats: dict with 'target_mean' and 'target_std'.

    Returns:
        (background_gt_norm, extreme_residual_gt_norm):
            - the target clamped from below at RESIDUAL_THRESHOLD_K, re-normalized with
              the target mean/std;
            - the remaining (<= 0) residual, scaled by the target std only.
        Their sum reconstructs ``target_norm`` up to float error, mirroring
        final_pred = cont_pred + ext_pred.
    """
    scale = stats['target_std'] + 1e-8
    target_k = target_norm * scale + stats['target_mean']
    background_gt_k, extreme_residual_gt_k = decompose_ground_truth(target_k, RESIDUAL_THRESHOLD_K)
    background_gt_norm = (background_gt_k - stats['target_mean']) / scale
    # The residual is a temperature difference: scale it, but don't shift by the mean.
    extreme_residual_gt_norm = extreme_residual_gt_k / scale
    return background_gt_norm, extreme_residual_gt_norm


def train_model():
    """Train HiP_EEF_Direct with direct per-head MSE supervision.

    Loss = ALPHA * MSE(cont head, capped background)
         + BETA  * MSE(ext head, cold residual)
         + DELTA * MSE(fused prediction, target).
    Uses AMP when CUDA is available and a WeightedRandomSampler over the training set.
    Early stopping is based on the sample-weighted mean validation loss, and the best
    state_dict is saved to MODEL_SAVE_PATH.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = torch.cuda.is_available()
    print(f"--- Starting HiP-EEF Training with DIRECT RESIDUAL PREDICTION on {device} ---")

    train_dataset = MultiVariableARDataset(DATA_DIR, 'train')
    val_dataset = MultiVariableARDataset(DATA_DIR, 'val')
    sampler = WeightedRandomSampler(get_sampler_weights(train_dataset, train_dataset.stats), len(train_dataset), True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = HiP_EEF_Direct(n_channels=INPUT_CHANNELS).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.MSELoss()
    print(f"Loss Weights: α={ALPHA} (Cont), β={BETA} (Ext), δ={DELTA} (Final)")

    best_val_loss = float('inf')
    patience_counter = 0
    for epoch in range(EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for predictor, target_norm, _ in pbar:
            predictor_subset = predictor[:, VARIABLE_INDICES, :, :].to(device)
            target_norm = target_norm.to(device)
            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                final_pred, cont_pred, ext_pred = model(predictor_subset)
                with torch.no_grad():
                    background_gt_norm, extreme_residual_gt_norm = _normalized_head_targets(
                        target_norm, train_dataset.stats)

                loss_cont = criterion(cont_pred, background_gt_norm)
                loss_ext = criterion(ext_pred, extreme_residual_gt_norm)
                loss_final = criterion(final_pred, target_norm)
                combined_loss = (ALPHA * loss_cont) + (BETA * loss_ext) + (DELTA * loss_final)

            scaler.scale(combined_loss).backward()
            # Gradients are still multiplied by the loss scale here. Unscale first so
            # GRADIENT_CLIP bounds the true gradient norm (a no-op when AMP is disabled).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            scaler.step(optimizer)
            scaler.update()
            pbar.set_postfix({'loss': f"{combined_loss.item():.4f}"})

        # --- Validation Loop ---
        model.eval()
        val_loss_sum = 0.0
        val_samples = 0
        with torch.no_grad():
            for predictor, target_norm, _ in val_loader:
                predictor_subset = predictor[:, VARIABLE_INDICES, :, :].to(device)
                target_norm = target_norm.to(device)
                final_pred, cont_pred, ext_pred = model(predictor_subset)
                background_gt_norm, extreme_residual_gt_norm = _normalized_head_targets(
                    target_norm, val_dataset.stats)

                loss_cont = criterion(cont_pred, background_gt_norm)
                loss_ext = criterion(ext_pred, extreme_residual_gt_norm)
                loss_final = criterion(final_pred, target_norm)
                batch_loss = (ALPHA * loss_cont) + (BETA * loss_ext) + (DELTA * loss_final)
                # Weight by batch size so a short final batch doesn't count as a full one.
                batch_size = target_norm.shape[0]
                val_loss_sum += batch_loss.item() * batch_size
                val_samples += batch_size

        avg_val_loss = val_loss_sum / max(val_samples, 1)
        print(f"Epoch {epoch+1} | Validation Loss: {avg_val_loss:.6f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   ✅ Best model saved to {MODEL_SAVE_PATH}")
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"   🛑 Early stopping triggered.")
            break
    print(f"\n--- 🔬 Direct Residual Prediction Training Finished. Best Val Loss: {best_val_loss:.6f} ---")

if __name__ == '__main__':
    train_model()



In [None]:
# Training with tiered loss

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from pathlib import Path
import gc
from tqdm import tqdm
import numpy as np

# --- FINAL TUNED Loss Weights ---
# BETA scales the storm-specialist (Extreme) head's tiered loss relative to the other terms.
ALPHA = 0.5  # Continuity Head (simple MSE)
BETA = 1.5   # Extreme Head (tiered, weighted MSE)
DELTA = 0.4  # Final Fused Prediction (simple MSE)


# --- 1. CONFIGURATION ---
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH.joinpath('final_dataset_multi_variable')
OUTPUT_DIR = PROJECT_PATH.joinpath('hip_eef_final_model')  # Final model directory
# No parents=True: fail here if PROJECT_PATH is missing instead of creating it locally.
OUTPUT_DIR.mkdir(exist_ok=True)
# The checkpoint name records BETA so runs with different Extreme-head weights don't collide.
MODEL_SAVE_PATH = OUTPUT_DIR.joinpath(f'hip_eef_final_model_BETA_{BETA}.pth')
WEIGHTS_SAVE_PATH = PROJECT_PATH.joinpath('hip_eef_smart_sampling_model', 'sampler_weights.pt')

# --- Training Hyperparameters ---
EPOCHS = 50
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 10  # epochs without val-loss improvement before stopping
GRADIENT_CLIP = 1.0           # max global gradient norm passed to clip_grad_norm_

# --- HiP-EEF Variable & Loss Configuration ---
# ALL_VARIABLES is assumed to match the channel order of the saved predictor arrays.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
DL_VARIABLES = ['T500', 'T850', 'RH700', 'W500']
DL_INDICES = list(map(ALL_VARIABLES.index, DL_VARIABLES))
INPUT_CHANNELS = len(DL_INDICES)


# --- Direct Supervision & Tiered Loss Parameters ---
RESIDUAL_THRESHOLD_K = 225.0       # background / residual split point (K)
SAMPLING_WEIGHT_THRESHOLD = 220.0  # "cold pixel" threshold (K) for sampler weights
# Extreme-head loss tiers for TieredWeightedMSELoss: target temperature threshold (K)
# -> weight for pixels at or below it. Colder tiers carry the larger weight.
LOSS_THRESHOLDS = {220.0: 10.0, 210.0: 25.0}

# --- 2. TIERED WEIGHTED LOSS (For Extreme Head ONLY) ---
class TieredWeightedMSELoss(nn.Module):
    """Pixel-wise MSE re-weighted by how cold the ORIGINAL (undecomposed) target is.

    ``thresholds`` maps a temperature threshold (Kelvin) to a loss weight. A pixel whose
    de-normalized original target is <= a threshold receives that threshold's weight. When
    a pixel satisfies several thresholds, the coldest (lowest) one wins. All other pixels
    get weight 1. The loss is the mean of the weighted squared errors.
    """

    def __init__(self, thresholds: dict):
        super().__init__()
        # Warmest first: colder tiers are applied later and overwrite milder ones, so each
        # pixel ends up with the weight of the coldest tier it falls in. Ascending order would
        # let e.g. the 220 K tier (10) overwrite the 210 K tier (25) on the coldest pixels.
        self.thresholds = sorted(thresholds.items(), key=lambda item: item[0], reverse=True)
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, prediction_norm, target_residual_norm, original_target_norm, stats):
        """Weighted MSE between ``prediction_norm`` and ``target_residual_norm``.

        The weights come from ``original_target_norm`` de-normalized with
        ``stats['target_mean']`` / ``stats['target_std']``.
        """
        loss = self.mse(prediction_norm, target_residual_norm)
        with torch.no_grad():
            # Denormalize the ORIGINAL ground truth to find where the cold pixels are
            target_k = original_target_norm * (stats['target_std'] + 1e-8) + stats['target_mean']
            weights = torch.ones_like(target_k)
            for temp_k, weight in self.thresholds:
                weights[target_k <= temp_k] = weight
        return torch.mean(loss * weights)

# --- 3. GROUND TRUTH DECOMPOSITION & SAMPLER (Unchanged) ---
def decompose_ground_truth(target_k, threshold_k):
    """Return ``(background, residual)`` for a field and a floor threshold.

    ``background`` is ``target_k`` clamped from below at ``threshold_k``. ``residual`` is
    ``target_k - background``, which is <= 0 and nonzero only below the threshold, so
    ``background + residual == target_k``.
    """
    floored = torch.clamp(target_k, min=threshold_k)
    remainder = target_k - floored
    return floored, remainder

def get_sampler_weights(dataset, stats, cache_path=None, threshold_k=None):
    """Per-sample weights for WeightedRandomSampler, favouring samples with many cold pixels.

    Each weight is ``0.1 + fraction of target pixels <= threshold_k`` (Kelvin). The 0.1
    floor keeps every sample drawable. Weights are cached with ``torch.save``. A cached
    file is reused only if its length matches ``len(dataset)``; otherwise it is recomputed.
    The computation is always available here, so a missing cache no longer returns None.

    Args:
        dataset: dataset yielding (predictor, target_norm, case_name).
        stats: dict with 'target_mean' and 'target_std' used to de-normalize targets.
        cache_path: cache file; defaults to the module-level WEIGHTS_SAVE_PATH.
        threshold_k: cold-pixel threshold; defaults to SAMPLING_WEIGHT_THRESHOLD.

    Returns:
        1-D float tensor with one weight per dataset sample.
    """
    cache_path = Path(WEIGHTS_SAVE_PATH if cache_path is None else cache_path)
    threshold_k = SAMPLING_WEIGHT_THRESHOLD if threshold_k is None else threshold_k

    if cache_path.exists():
        cached = torch.load(cache_path)
        if len(cached) == len(dataset):
            print(f"Loading cached sampler weights from {cache_path}")
            return cached
        print(f"Cached sampler weights at {cache_path} have {len(cached)} entries but the "
              f"dataset has {len(dataset)} samples; recomputing.")

    print("Pre-computing sampler weights...")
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    weights = []
    for _, target_norm, _ in tqdm(dataset):
        target_k = target_norm.numpy().squeeze() * scale + offset
        weights.append(0.1 + np.mean(target_k <= threshold_k))
    weights = torch.tensor(weights, dtype=torch.float)

    # The cache folder belongs to another experiment and may not exist yet.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(weights, cache_path)
    return weights

# --- 4. MAIN TRAINING SCRIPT ---
def _normalized_head_targets(target_norm, stats):
    """Build the normalized supervision targets for the continuity and extreme heads.

    Args:
        target_norm: normalized target tensor as produced by the dataset.
        stats: dict with 'target_mean' and 'target_std'.

    Returns:
        (background_gt_norm, extreme_residual_gt_norm):
            - the target clamped from below at RESIDUAL_THRESHOLD_K, re-normalized with
              the target mean/std;
            - the remaining (<= 0) residual, scaled by the target std only.
        Their sum reconstructs ``target_norm`` up to float error, mirroring
        final_pred = cont_pred + ext_pred.
    """
    scale = stats['target_std'] + 1e-8
    target_k = target_norm * scale + stats['target_mean']
    background_gt_k, extreme_residual_gt_k = decompose_ground_truth(target_k, RESIDUAL_THRESHOLD_K)
    background_gt_norm = (background_gt_k - stats['target_mean']) / scale
    # The residual is a temperature difference: scale it, but don't shift by the mean.
    extreme_residual_gt_norm = extreme_residual_gt_k / scale
    return background_gt_norm, extreme_residual_gt_norm


def train_model():
    """Train HiP_EEF_Direct with MSE heads plus a tiered-weighted Extreme-head loss.

    Loss = ALPHA * MSE(cont head, capped background)
         + BETA  * TieredWeightedMSE(ext head, cold residual; weighted by LOSS_THRESHOLDS)
         + DELTA * MSE(fused prediction, target).
    Uses AMP when CUDA is available and a WeightedRandomSampler over the training set.
    Early stopping is based on the sample-weighted mean validation loss, and the best
    state_dict is saved to MODEL_SAVE_PATH.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = torch.cuda.is_available()
    print(f"--- Starting FINAL HiP-EEF Training with Focused Tiered Loss on {device} ---")

    train_dataset = MultiVariableARDataset(DATA_DIR, 'train')
    val_dataset = MultiVariableARDataset(DATA_DIR, 'val')
    sampler = WeightedRandomSampler(get_sampler_weights(train_dataset, train_dataset.stats), len(train_dataset), True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = HiP_EEF_Direct(n_channels=INPUT_CHANNELS).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    criterion_mse = nn.MSELoss()
    criterion_tiered = TieredWeightedMSELoss(thresholds=LOSS_THRESHOLDS)

    print(f"Using FINAL Focused Loss Weights: α={ALPHA}, β={BETA}, δ={DELTA}")

    best_val_loss = float('inf')
    patience_counter = 0
    for epoch in range(EPOCHS):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
        for predictor, target_norm, _ in pbar:
            predictor_subset = predictor[:, DL_INDICES, :, :].to(device)
            target_norm = target_norm.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                final_pred, cont_pred, ext_pred = model(predictor_subset)
                with torch.no_grad():
                    background_gt_norm, extreme_residual_gt_norm = _normalized_head_targets(
                        target_norm, train_dataset.stats)

                # Plain MSE for the continuity/fused outputs; tiered loss only for the Extreme head.
                loss_cont = criterion_mse(cont_pred, background_gt_norm)
                loss_ext = criterion_tiered(ext_pred, extreme_residual_gt_norm, target_norm, train_dataset.stats)
                loss_final = criterion_mse(final_pred, target_norm)
                combined_loss = (ALPHA * loss_cont) + (BETA * loss_ext) + (DELTA * loss_final)

            scaler.scale(combined_loss).backward()
            # Gradients are still multiplied by the loss scale here. Unscale first so
            # GRADIENT_CLIP bounds the true gradient norm (a no-op when AMP is disabled).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
            scaler.step(optimizer)
            scaler.update()
            pbar.set_postfix({'loss': f"{combined_loss.item():.4f}"})

        # --- Validation Loop (same loss structure as training) ---
        model.eval()
        val_loss_sum = 0.0
        val_samples = 0
        with torch.no_grad():
            for predictor, target_norm, _ in val_loader:
                predictor_subset = predictor[:, DL_INDICES, :, :].to(device)
                target_norm = target_norm.to(device)
                final_pred, cont_pred, ext_pred = model(predictor_subset)
                background_gt_norm, extreme_residual_gt_norm = _normalized_head_targets(
                    target_norm, val_dataset.stats)

                loss_cont = criterion_mse(cont_pred, background_gt_norm)
                loss_ext = criterion_tiered(ext_pred, extreme_residual_gt_norm, target_norm, val_dataset.stats)
                loss_final = criterion_mse(final_pred, target_norm)
                batch_loss = (ALPHA * loss_cont) + (BETA * loss_ext) + (DELTA * loss_final)
                # Weight by batch size so a short final batch doesn't count as a full one.
                batch_size = target_norm.shape[0]
                val_loss_sum += batch_loss.item() * batch_size
                val_samples += batch_size

        avg_val_loss = val_loss_sum / max(val_samples, 1)
        print(f"Epoch {epoch+1} | Validation Loss: {avg_val_loss:.6f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   ✅ Best model saved to {MODEL_SAVE_PATH}")
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"   🛑 Early stopping triggered.")
            break
    print(f"\n--- 🔬 FINAL Focused Loss Training Finished. Best Val Loss: {best_val_loss:.6f} ---")

if __name__ == '__main__':
    train_model()