In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, Dataset
from pathlib import Path
import gc
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import joblib
import warnings

warnings.filterwarnings('ignore')

# --- DATASET AND MODEL DEFINITIONS (SELF-CONTAINED) ---

class MultiVariableARDataset(Dataset):
    """Paired predictor/target ``.npy`` samples for one data split, normalised on load.

    Each sample is a ``<stem>_predictor.npy`` file in ``data_dir / split`` plus the
    ``<stem>_target.npy`` file beside it. Normalisation statistics are read from
    ``data_dir / 'normalization_stats_multi_variable.joblib'`` and kept in ``self.stats``
    (keys used: ``predictor_mean``, ``predictor_std``, ``target_mean``, ``target_std``).

    Args:
        data_dir: root directory holding the split sub-directories and the stats file.
        split: name of the split sub-directory (e.g. ``'train'`` or ``'val'``).

    Raises:
        FileNotFoundError: if the split directory contains no ``*_predictor.npy`` files.
    """

    def __init__(self, data_dir: Path, split: str = 'train'):
        self.split_dir = data_dir / split
        self.predictor_files = sorted(self.split_dir.glob('*_predictor.npy'))
        # Check the split before loading the stats file so an empty/missing split is
        # reported as such rather than as an unrelated stats-file error.
        if not self.predictor_files:
            raise FileNotFoundError(f"No predictor files in {self.split_dir}")
        self.stats = joblib.load(data_dir / 'normalization_stats_multi_variable.joblib')

    def __len__(self):
        return len(self.predictor_files)

    def __getitem__(self, idx):
        """Return ``(predictor_norm, target_norm)`` float32 tensors for sample ``idx``.

        The target gets an extra leading axis via ``unsqueeze(0)``.
        """
        pred_path = self.predictor_files[idx]
        # Swap only the file-name suffix; a string replace over the whole path would also
        # rewrite any parent directory whose name contains '_predictor.npy'.
        stem = pred_path.name[:-len('_predictor.npy')]
        targ_path = pred_path.with_name(stem + '_target.npy')

        stats = self.stats
        predictor_data = np.load(pred_path).astype(np.float32)
        target_data = np.load(targ_path).astype(np.float32)

        # Per-channel stats are broadcast over the two trailing axes; presumably the
        # predictor files are (C, H, W) with 1-D per-channel stats -- TODO confirm.
        pred_mean = stats['predictor_mean'][:, None, None]
        pred_std = stats['predictor_std'][:, None, None] + 1e-8
        # Cast back to float32: float64 stats would otherwise promote the result (NumPy 2
        # promotion rules), and the float32 model cannot consume float64 input.
        predictor_norm = ((predictor_data - pred_mean) / pred_std).astype(np.float32, copy=False)
        target_norm = (
            (target_data - stats['target_mean']) / (stats['target_std'] + 1e-8)
        ).astype(np.float32, copy=False)
        return torch.from_numpy(predictor_norm), torch.from_numpy(target_norm).unsqueeze(0)

class SimplifiedAttention(nn.Module):
    """Channel attention followed by a single-map spatial attention gate.

    Channel branch: global average pool -> 1x1 conv (reduce) -> ReLU -> 1x1 conv (expand)
    -> sigmoid, multiplied onto the input. Spatial branch: a 7x7 conv collapsing all
    channels into one map -> sigmoid, multiplied onto the channel-reweighted features.
    Output shape equals input shape.

    Args:
        channels: number of input/output channels.
        reduction: channel reduction factor of the squeeze bottleneck (default 16).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Clamp to >= 1. With channels < reduction, ``channels // reduction`` is 0; the second
        # conv then has fan_in 0, PyTorch skips its bias init, and the attention output is
        # driven by an uninitialised bias. Unchanged for channels >= reduction.
        hidden = max(channels // reduction, 1)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid(),
        )
        self.spatial_att = nn.Sequential(nn.Conv2d(channels, 1, 7, padding=3), nn.Sigmoid())

    def forward(self, x):
        x = x * self.channel_att(x)
        return x * self.spatial_att(x)

class AttentionUNet(nn.Module):
    """U-Net encoder/decoder with an optional attention block after the bottleneck.

    Maps an ``(N, input_channels, H, W)`` tensor to an ``(N, 1, H, W)`` tensor. Decoder
    features are resized to each skip connection when sizes differ, so odd H/W are accepted.

    Args:
        input_channels: number of channels fed to the first encoder stage.
        base_channels: width of the first stage; stage ``i`` has
            ``base_channels * min(2**i, 8)`` channels.
        depth: number of encoder stages (and skip connections).
        use_attention: if True, apply ``SimplifiedAttention`` to the bottleneck output.
    """

    def __init__(self, input_channels=4, base_channels=64, depth=4, use_attention=True):
        super().__init__()
        self.use_attention, self.depth = use_attention, depth
        # Width doubles per stage, capped at 8x base_channels.
        self.channels = [base_channels * min(2**i, 8) for i in range(depth)]

        # Encoder: a conv block per stage; a stride-2 conv downsamples between stages
        # (depth - 1 downsamplers, none after the deepest stage).
        self.encoders, self.downsamplers = nn.ModuleList(), nn.ModuleList()
        in_ch = input_channels
        for i, out_ch in enumerate(self.channels):
            self.encoders.append(self._conv_block(in_ch, out_ch))
            if i < len(self.channels) - 1:
                self.downsamplers.append(nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1))
            in_ch = out_ch

        bottleneck_ch = self.channels[-1]
        self.bottleneck = self._conv_block(self.channels[-1], bottleneck_ch)
        if self.use_attention:
            self.attention = SimplifiedAttention(bottleneck_ch)

        # Decoder: `depth` (upsample, conv block) pairs, deepest stage first.
        # NOTE(review): there are `depth` upsamplers but only `depth - 1` downsamplers, so the
        # first upsampler doubles the bottleneck resolution and forward() immediately resizes
        # it back to the deepest skip via F.interpolate. Left as-is to keep state_dict
        # compatibility with existing checkpoints.
        self.upsamplers, self.decoders = nn.ModuleList(), nn.ModuleList()
        for i in range(depth - 1, -1, -1):
            in_ch = bottleneck_ch if i == depth - 1 else self.channels[i + 1]
            out_ch = self.channels[i]
            self.upsamplers.append(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2))
            # Input = upsampled features concatenated with the matching skip (2 * out_ch).
            self.decoders.append(self._conv_block(out_ch * 2, out_ch))

        mid_ch = self.channels[0] // 2
        self.final_conv = nn.Sequential(
            nn.Conv2d(self.channels[0], mid_ch, 3, padding=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, 1, 1),
        )

    def _conv_block(self, in_ch, out_ch):
        """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        skips = []
        for i, encoder in enumerate(self.encoders):
            x = encoder(x)
            skips.append(x)
            if i < len(self.downsamplers):
                x = self.downsamplers[i](x)

        x = self.bottleneck(x)
        if self.use_attention:
            x = self.attention(x)

        # Pair each decoder stage with the skips from deepest to shallowest.
        for up, dec, skip in zip(self.upsamplers, self.decoders, reversed(skips)):
            x = up(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
            x = dec(torch.cat([x, skip], dim=1))
        return self.final_conv(x)


# --- 1. CONFIGURATION ---
# Input data lives under the project root on the mounted Drive.
PROJECT_PATH = Path('/content/drive/My Drive/AR_Downscaling')
DATA_DIR = PROJECT_PATH / 'final_dataset_multi_variable'

# Outputs of this experiment; the directory is created eagerly when this cell runs.
OUTPUT_DIR = PROJECT_PATH.joinpath('publication_experiments', 'attention_unet_final')
MODEL_SAVE_PATH = OUTPUT_DIR / 'attention_unet_final_model.pth'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Cached sampler weights are stored under a separate directory that this
# configuration block does not create.
WEIGHTS_SAVE_PATH = PROJECT_PATH.joinpath('hip_eef_smart_sampling_model', 'sampler_weights.pt')

# --- Training Hyperparameters ---
EPOCHS = 50
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EARLY_STOPPING_PATIENCE = 10  # epochs without validation improvement before stopping
GRADIENT_CLIP = 1.0           # max gradient norm passed to clip_grad_norm_

# --- Variable & Loss Configuration ---
# Presumably the channel order of the predictor files (VARIABLE_INDICES selects along
# axis 1 of each predictor batch); 'IVT' at position 0 is not used as a model input.
ALL_VARIABLES = ['IVT', 'T500', 'T850', 'RH700', 'W500']
VARIABLE_NAMES = ['T500', 'T850', 'RH700', 'W500']
VARIABLE_INDICES = list(map(ALL_VARIABLES.index, VARIABLE_NAMES))
INPUT_CHANNELS = len(VARIABLE_INDICES)
# threshold -> loss weight for TieredWeightedMSELoss, in de-normalised target units.
LOSS_THRESHOLDS = {220.0: 10.0, 210.0: 25.0}
# Each training sample is weighted 0.1 + fraction of target pixels at or below this value.
SAMPLING_WEIGHT_THRESHOLD = 220.0

# --- 2. TIERED WEIGHTED LOSS ---
class TieredWeightedMSELoss(nn.Module):
    """Element-wise MSE up-weighted where the de-normalised target is at or below thresholds.

    ``thresholds`` maps a threshold (in de-normalised target units) to a loss weight. An
    element whose de-normalised target is ``<=`` several thresholds gets the weight of the
    LOWEST one; elements above every threshold keep weight 1. The result is
    ``mean(weight * (prediction_norm - target_norm) ** 2)`` over all elements.

    Args:
        thresholds: dict ``{threshold: weight}``; insertion order does not matter.
    """

    def __init__(self, thresholds: dict):
        super().__init__()
        # Highest threshold first, so each later (lower) threshold overwrites the weight of
        # its colder subset and the strictest tier wins. Ascending order did the opposite:
        # with {220: 10, 210: 25}, targets <= 210 ended up with weight 10 instead of 25.
        self.thresholds = sorted(thresholds.items(), key=lambda item: item[0], reverse=True)
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, prediction_norm, target_norm, stats):
        """Compute the weighted loss.

        Args:
            prediction_norm: model output in normalised units.
            target_norm: target in normalised units, same shape as ``prediction_norm``.
            stats: mapping with ``'target_mean'`` and ``'target_std'`` used to undo the
                target normalisation (scalars or broadcastable arrays).

        Returns:
            Scalar loss tensor.
        """
        loss = self.mse(prediction_norm, target_norm)
        with torch.no_grad():
            # as_tensor keeps numpy/float stats on the same device and dtype as the target.
            scale = torch.as_tensor(
                stats['target_std'] + 1e-8, dtype=target_norm.dtype, device=target_norm.device
            )
            offset = torch.as_tensor(
                stats['target_mean'], dtype=target_norm.dtype, device=target_norm.device
            )
            target_phys = target_norm * scale + offset
            weights = torch.ones_like(target_phys)
            for threshold, weight in self.thresholds:
                weights[target_phys <= threshold] = weight
        return torch.mean(loss * weights)

# --- 3. SAMPLER WEIGHT CALCULATION ---
def get_sampler_weights(dataset, stats, cache_path=None, threshold=None):
    """Return one WeightedRandomSampler weight per sample of ``dataset``, cached on disk.

    Each weight is ``0.1 + fraction of target pixels <= threshold``, computed after undoing
    the target normalisation with ``stats``. Samples with more such pixels are therefore
    drawn more often, and every sample keeps a non-zero probability.

    Args:
        dataset: sized iterable yielding ``(predictor, target_norm)`` pairs.
        stats: mapping with ``'target_mean'`` and ``'target_std'``.
        cache_path: cache file for the weights; defaults to ``WEIGHTS_SAVE_PATH``.
        threshold: pixel threshold in de-normalised units; defaults to
            ``SAMPLING_WEIGHT_THRESHOLD``.

    Returns:
        1-D tensor with ``len(dataset)`` entries (float32 when freshly computed).
    """
    if cache_path is None:
        cache_path = WEIGHTS_SAVE_PATH
    if threshold is None:
        threshold = SAMPLING_WEIGHT_THRESHOLD
    cache_path = Path(cache_path)

    if cache_path.exists():
        cached = torch.load(cache_path)
        # A cache built for a different dataset would make the sampler draw out-of-range or
        # mismatched indices. Only the length is checked, so a same-sized but stale cache
        # is still reused.
        if len(cached) == len(dataset):
            print(f"Loading cached sampler weights from {cache_path}")
            return cached
        print(f"Cached sampler weights at {cache_path} have {len(cached)} entries, "
              f"dataset has {len(dataset)}; recomputing.")

    print("Pre-computing sampler weights...")
    scale = stats['target_std'] + 1e-8
    offset = stats['target_mean']
    weights = []
    for _, target_norm in tqdm(dataset):
        target_k = target_norm.numpy().squeeze() * scale + offset
        weights.append(0.1 + float(np.mean(target_k <= threshold)))
    weights = torch.tensor(weights, dtype=torch.float)
    # The cache directory is not created anywhere else, so make sure it exists.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(weights, cache_path)
    return weights

# --- 4. MAIN TRAINING SCRIPT ---
def _train_one_epoch(model, loader, optimizer, scaler, criterion, stats, device, amp_enabled, desc):
    """Run one optimisation pass over ``loader``; shows the running batch loss via tqdm."""
    model.train()
    pbar = tqdm(loader, desc=desc)
    for predictor, target_norm in pbar:
        # Keep only the configured input variables (channel axis 1).
        predictor_subset = predictor[:, VARIABLE_INDICES, :, :].to(device, non_blocking=True)
        target_norm = target_norm.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            prediction_norm = model(predictor_subset)
            loss = criterion(prediction_norm, target_norm, stats)
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here. Unscale them first so
        # GRADIENT_CLIP bounds the true gradient norm; otherwise clipping fires on almost
        # every step and caps the real norm at GRADIENT_CLIP / scale. No-op without AMP.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
        scaler.step(optimizer)
        scaler.update()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})


def _evaluate(model, loader, criterion, stats, device):
    """Return the validation loss: per-batch criterion values weighted by batch size."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for predictor, target_norm in loader:
            predictor_subset = predictor[:, VARIABLE_INDICES, :, :].to(device, non_blocking=True)
            target_norm = target_norm.to(device, non_blocking=True)
            prediction_norm = model(predictor_subset)
            loss = criterion(prediction_norm, target_norm, stats)
            # Weight by batch size so a smaller final batch does not skew the average.
            batch_size = predictor.size(0)
            total += loss.item() * batch_size
            count += batch_size
    return total / count


def train_dl_baseline():
    """Train the Attention U-Net baseline with weighted sampling and early stopping.

    Uses the module-level configuration (paths, hyper-parameters, variable subset, loss
    thresholds). The state dict with the lowest validation loss is saved to
    MODEL_SAVE_PATH. Training stops after EARLY_STOPPING_PATIENCE epochs without
    improvement.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    amp_enabled = device.type == 'cuda'
    print(f"--- Training Strongest DL Baseline (Attention U-Net) with Smart Sampling ---")

    train_dataset = MultiVariableARDataset(DATA_DIR, 'train')
    val_dataset = MultiVariableARDataset(DATA_DIR, 'val')
    sampler_weights = get_sampler_weights(train_dataset, train_dataset.stats)
    sampler = WeightedRandomSampler(sampler_weights, num_samples=len(sampler_weights), replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = AttentionUNet(input_channels=INPUT_CHANNELS).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    criterion = TieredWeightedMSELoss(thresholds=LOSS_THRESHOLDS)
    print(f"Training with TieredWeightedMSELoss.")

    best_val_loss = float('inf')
    patience_counter = 0
    for epoch in range(EPOCHS):
        _train_one_epoch(
            model, train_loader, optimizer, scaler, criterion, train_dataset.stats,
            device, amp_enabled, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]",
        )
        avg_val_loss = _evaluate(model, val_loader, criterion, val_dataset.stats, device)
        print(f"Epoch {epoch+1} | Validation Loss: {avg_val_loss:.6f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"   ✅ Best model saved to {MODEL_SAVE_PATH}")
            patience_counter = 0
        else:
            patience_counter += 1
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"   🛑 Early stopping triggered.")
            break

    print(f"\n--- ✅ Strong DL Baseline Training Finished. Best Validation Loss: {best_val_loss:.6f} ---")


if __name__ == '__main__':
    train_dl_baseline()
