# 🌎 Welcome to the CSE151B Spring 2025 Climate Emulation Competition!

Thank you for participating in this exciting challenge focused on building machine learning models to emulate complex climate systems.  
This notebook is provided as a **starter template** to help you:

- Understand how to load and preprocess the dataset  
- Construct a baseline model  
- Train and evaluate predictions using a PyTorch Lightning pipeline  
- Format your predictions for submission to the leaderboard  

You're encouraged to:
- Build on this structure or replace it entirely
- Try more advanced models and training strategies
- Incorporate your own ideas to push the boundaries of what's possible

If you're interested in developing within a repository structure and/or use helpful tools like configuration management (based on Hydra) and logging (with Weights & Biases), we recommend checking out the following Github repo. Such a structure can be useful when running multiple experiments and trying various research ideas.

👉 [https://github.com/salvaRC/cse151b-spring2025-competition](https://github.com/salvaRC/cse151b-spring2025-competition)

Good luck, have fun, and we hope you learn a lot through this process!


### 📦 Install Required Libraries
We install the necessary Python packages for data loading, deep learning, and visualization.


In [1]:
!pip install xarray zarr dask lightning matplotlib wandb cftime einops xr --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.4/205.4 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.0/819.0 kB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m59.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m104.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.5/65.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [23]:
import os
from datetime import datetime
import numpy as np
import xarray as xr
import dask.array as da
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl
import math
import torch.nn.functional as F
from torchvision import transforms
from lightning.pytorch.callbacks import ModelCheckpoint

### ⚙️ Configuration Setup  
Define all model, data, and training hyperparameters in one place for easy control and reproducibility.

### 📊 Data Configuration

We define the dataset settings used for training and evaluation. This includes:

- **`path`**: Path to the `.zarr` dataset containing monthly climate variables from CMIP6 simulations.
- **`input_vars`**: Climate forcing variables (e.g., CO₂, CH₄) used as model inputs.
- **`output_vars`**: Target variables to predict — surface air temperature (`tas`) and precipitation (`pr`).
- **`target_member_id`**: Ensemble member to use from the simulations (each SSP has 3) for target variables.
- **`train_ssps`**: SSP scenarios used for training (low to high emissions).
- **`test_ssp`**: Scenario held out for evaluation (Must be set to SSP245).
- **`test_months`**: Number of months to include in the test split (Must be set to 120).
- **`batch_size`** and **`num_workers`**: Data loading parameters for PyTorch training.

These settings reflect how the challenge is structured: models must learn from some emission scenarios and generalize to unseen ones.

> ⚠️ **Important:** Do **not modify** the following test settings:
>
> - `test_ssp` must remain **`ssp245`**, which is the held-out evaluation scenario.
> - `test_months` must be **`120`**, corresponding to the last 10 years (monthly resolution) of the scenario.



In [32]:
#NOTE Change the data directory according to where you have your zarr files stored
config = {
    "data": {
        "path": "/kaggle/input/cse151b-spring2025-competition/processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "target_member_id": [0,1,2],
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "test_months": 360,
        "batch_size": 16,
        "num_workers": 4,
        "seq_len": 12,
        "predict_seq": False,
    },
    "model": {
        "type": "simple_cnn",
        "kernel_size": 3,
        "init_dim": 64,
        "depth": 4,
        "dropout_rate": 0.1,
    },
    "training": {
        "lr": 1e-3,
    },
    "trainer": {
        "max_epochs": 200,
        "accelerator": "gpu",
        "devices": "auto",
        "precision": 32,
        "deterministic": True,
        "num_sanity_val_steps": 0,
    },
    "seed": 42,
}
pl.seed_everything(config["seed"])  # Set seed for reproducibility

INFO: Seed set to 42


42

### 🔧 Spatial Weighting Utility Function

This cell sets up utility functions for reproducibility and spatial weighting:

- **`get_lat_weights(latitude_values)`**: Computes cosine-based area weights for each latitude, accounting for the Earth's curvature. This is critical for evaluating global climate metrics fairly — grid cells near the equator represent larger surface areas than those near the poles.


In [4]:
def get_lat_weights(latitude_values):
    lat_rad = np.deg2rad(latitude_values)
    weights = np.cos(lat_rad)
    return weights / np.mean(weights)

### 🧠 SimpleCNN: A Residual Convolutional Baseline

This is a lightweight baseline model designed to capture spatial patterns in global climate data using convolutional layers.

- The architecture starts with a **convolution + batch norm + ReLU** block to process the input channels.
- It then applies a series of **residual blocks** to extract increasingly abstract spatial features. These help preserve gradient flow during training.
- Finally, a few convolutional layers reduce the feature maps down to the desired number of output channels (`tas` and `pr`).

This model only serves as a **simple baseline for climate emulation**. 

We encourage you to build and experiment with your own models and ideas.


In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.skip(identity)
        return self.relu(out)

class SimpleCNN(nn.Module):
    def __init__(self, n_input_channels, n_output_channels, kernel_size=3, init_dim=64, depth=4, dropout_rate=0.2):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(n_input_channels, init_dim, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(init_dim),
            nn.ReLU(inplace=True),
        )
        self.res_blocks = nn.ModuleList()
        current_dim = init_dim
        for i in range(depth):
            out_dim = current_dim * 2 if i < depth - 1 else current_dim
            self.res_blocks.append(ResidualBlock(current_dim, out_dim))
            if i < depth - 1:
                current_dim *= 2
        self.dropout = nn.Dropout2d(dropout_rate)
        self.final = nn.Sequential(
            nn.Conv2d(current_dim, current_dim // 2, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(current_dim // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(current_dim // 2, n_output_channels, kernel_size=1),
        )

    def forward(self, x):
        x = self.initial(x)
        for res_block in self.res_blocks:
            x = res_block(x)
        return self.final(self.dropout(x))


### 📐 Normalizer: Z-Score Scaling for Climate Inputs & Outputs

This class handles **Z-score normalization**, a crucial preprocessing step for stable and efficient neural network training:

- **`set_input_statistics(mean, std)` / `set_output_statistics(...)`**: Store the mean and standard deviation computed from the training data for later use.
- **`normalize(data, data_type)`**: Standardizes the data using `(x - mean) / std`. This is applied separately to inputs and outputs.
- **`inverse_transform_output(data)`**: Converts model predictions back to the original physical units (e.g., Kelvin for temperature, mm/day for precipitation).

Normalizing the data ensures the model sees inputs with similar dynamic ranges and avoids biases caused by different variable scales.


In [5]:
class Normalizer:
    def __init__(self):
        self.mean_in, self.std_in = None, None
        self.mean_out, self.std_out = None, None

    def set_input_statistics(self, mean, std):
        self.mean_in = mean
        self.std_in = std

    def set_output_statistics(self, mean, std):
        self.mean_out = mean
        self.std_out = std

    def normalize(self, data, data_type):
        if data_type == "input":
            return (data - self.mean_in) / self.std_in
        elif data_type == "output":
            return (data - self.mean_out) / self.std_out

    def inverse_transform_output(self, data):
        return data * self.std_out + self.mean_out


### 🌍 Data Module: Loading, Normalization, and Splitting

This section handles the entire data pipeline, from loading the `.zarr` dataset to preparing PyTorch-ready DataLoaders.

#### `ClimateDataset`
- A simple PyTorch `Dataset` wrapper that preloads the entire (normalized) dataset into memory using Dask.
- Converts the data to PyTorch tensors and handles any `NaN` checks up front.

#### `ClimateDataModule`
A PyTorch Lightning `DataModule` that handles:
- ✅ **Loading data** from different SSP scenarios and ensemble members
- ✅ **Broadcasting non-spatial inputs** (like CO₂) to match spatial grid size
- ✅ **Normalization** using mean/std computed from training data only
- ✅ **Splitting** into training, validation, and test sets:
  - Training: All months from selected SSPs (except last 10 years of SSP370)
  - Validation: Last 10 years (120 months) of SSP370
  - Test: Last 10 years of SSP245 (unseen scenario)
- ✅ **Batching** and parallelized data loading via PyTorch `DataLoader`s
- ✅ **Latitude-based area weighting** for fair climate metric evaluation
- Shape of the inputs are Batch_Size X 5 (num_input_variables) X 48 X 72
- Shape of ouputputs are Batch_Size X 2 (num_output_variables) X 48 X 72

> ℹ️ **Note:** You likely won’t need to modify this class but feel free to make modifications if you want to inlcude different ensemble mebers to feed more data to your models


In [34]:
class ClimateDataset(Dataset):
    def __init__(self, inputs_dask, outputs_dask, output_is_normalized=True, seq_len=12, predict_seq=True):
        self.size = inputs_dask.shape[0] - seq_len + 1
        print(f"Creating dataset with {self.size} samples...")

        inputs_np = inputs_dask.compute()
        outputs_np = outputs_dask.compute()
        
        self.inputs = torch.from_numpy(inputs_np).float()
        self.outputs = torch.from_numpy(outputs_np).float()

        self.seq_len = seq_len
        self.predict_seq = predict_seq

        if torch.isnan(self.inputs).any() or torch.isnan(self.outputs).any():
            raise ValueError("NaNs found in dataset")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        x = self.inputs[idx:idx+self.seq_len]      # [T, C, H, W]
        y = self.outputs[idx:idx+self.seq_len] if self.predict_seq else self.outputs[idx+self.seq_len-1]
        return x, y

class AugmentedDataset(Dataset):
    def __init__(self, base_dataset, transform):
        self.base = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        x = self.transform(x)
        return x, y


class ClimateDataModule(pl.LightningDataModule):
    def __init__(
        self,
        path,
        input_vars,
        output_vars,
        train_ssps,
        test_ssp,
        target_member_id,
        val_split=0.1,
        test_months=120,
        batch_size=32,
        num_workers=0,
        seed=42,
        augmentations=None,
        seq_len=12, 
        predict_seq=True
    ):
        super().__init__()
        self.path = path
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.train_ssps = train_ssps
        self.test_ssp = test_ssp
        self.target_member_id = target_member_id
        self.val_split = val_split
        self.test_months = test_months
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.normalizer = Normalizer()
        self.augmentations=augmentations
        self.seq_len = seq_len
        self.predict_seq = predict_seq

    def prepare_data(self):
        assert os.path.exists(self.path), f"Data path not found: {self.path}"

    def setup(self, stage=None):
        ds = xr.open_zarr(self.path, consolidated=False, chunks={"time": 24})
        spatial_template = ds["rsdt"].isel(time=0, ssp=0, drop=True)

        def load_ssp(ssp, member_id=0):
            input_dask, output_dask = [], []
            for var in self.input_vars:
                da_var = ds[var].sel(ssp=ssp)
                if "latitude" in da_var.dims:
                    da_var = da_var.rename({"latitude": "y", "longitude": "x"})
                if "member_id" in da_var.dims:
                    # da_var = da_var.sel(member_id=self.target_member_id)
                    da_var = da_var.sel(member_id=member_id)
                if set(da_var.dims) == {"time"}:
                    da_var = da_var.broadcast_like(spatial_template).transpose("time", "y", "x")
                input_dask.append(da_var.data)

            for var in self.output_vars:
                # da_out = ds[var].sel(ssp=ssp, member_id=self.target_member_id)
                da_out = ds[var].sel(ssp=ssp, member_id=member_id)
                if "latitude" in da_out.dims:
                    da_out = da_out.rename({"latitude": "y", "longitude": "x"})
                output_dask.append(da_out.data)

            return da.stack(input_dask, axis=1), da.stack(output_dask, axis=1)

        train_input, train_output, val_input, val_output = [], [], None, None

        for ssp in self.train_ssps:
            for member_id in self.target_member_id:
                x, y = load_ssp(ssp, member_id=member_id)
                if ssp == "ssp370":
                    #val_input = x[-self.test_months:]
                    #val_output = y[-self.test_months:]
                    # train_input.append(x[:-self.test_months])
                    # train_output.append(y[:-self.test_months])
    
                    if member_id == 0:
                        val_input = x[-(self.test_months + self.seq_len - 1):]
                        val_output = y[-(self.test_months + self.seq_len - 1):]
                    train_input.append(x[:-self.test_months])
                    train_output.append(y[:-self.test_months])
                else:
                    train_input.append(x)
                    train_output.append(y)

        train_input = da.concatenate(train_input, axis=0)
        train_output = da.concatenate(train_output, axis=0)

        self.normalizer.set_input_statistics(
            mean=da.nanmean(train_input, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_input, axis=(0, 2, 3), keepdims=True).compute(),
        )
        self.normalizer.set_output_statistics(
            mean=da.nanmean(train_output, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_output, axis=(0, 2, 3), keepdims=True).compute(),
        )

        train_input_norm = self.normalizer.normalize(train_input, "input")
        train_output_norm = self.normalizer.normalize(train_output, "output")
        val_input_norm = self.normalizer.normalize(val_input, "input")
        val_output_norm = self.normalizer.normalize(val_output, "output")

        test_input, test_output = load_ssp(self.test_ssp)
        
        #test_input = test_input[-self.test_months:]
        #test_output = test_output[-self.test_months:]
        test_input = test_input[-(self.test_months + self.seq_len - 1):]  # Need full context
        test_output = test_output[-(self.test_months + self.seq_len - 1):]  # Only evaluating these months

        test_input_norm = self.normalizer.normalize(test_input, "input")

        train_dataset = ClimateDataset(train_input_norm, train_output_norm, seq_len=self.seq_len, predict_seq=self.predict_seq)
        if self.augmentations is not None:
            self.train_dataset = AugmentedDataset(train_dataset, self.augmentations)
        else:
            self.train_dataset = train_dataset
        self.val_dataset = ClimateDataset(val_input_norm, val_output_norm, seq_len=self.seq_len, predict_seq=self.predict_seq)
        self.test_dataset = ClimateDataset(test_input_norm, test_output, output_is_normalized=False, seq_len=self.seq_len, predict_seq=self.predict_seq)

        self.lat = spatial_template.y.values
        self.lon = spatial_template.x.values
        self.area_weights = xr.DataArray(get_lat_weights(self.lat), dims=["y"], coords={"y": self.lat})

        # Add PyTorch-ready weights
        self.lat_weights_torch = torch.tensor(get_lat_weights(self.lat), dtype=torch.float32)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def get_lat_weights(self):
        return self.area_weights

    def get_lat_weights_torch(self):
        return self.lat_weights_torch

    def get_coords(self):
        return self.lat, self.lon

In [33]:
class WeightedMSELoss(nn.Module):
    def __init__(self, lat_weights: torch.Tensor):
        super().__init__()
        self.register_buffer("lat_weights", lat_weights.view(1, 1, -1, 1))  # [1, 1, H, 1]

    def forward(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        se = (y_hat - y) ** 2                      # [B, C, H, W]
        weighted_se = se * self.lat_weights        # [B, C, H, W] × [1, 1, H, 1]
        return weighted_se.mean()                  # Final loss


### ⚡ ClimateEmulationModule: Lightning Wrapper for Climate Model Emulation

This is the core model wrapper built with **PyTorch Lightning**, which organizes the training, validation, and testing logic for the climate emulation task. Lightning abstracts away much of the boilerplate code in PyTorch-based deep learning workflows, making it easier to scale models.

#### ✅ Key Features

- **`training_step` / `validation_step` / `test_step`**: Standard Lightning hooks for computing loss and predictions at each stage. The loss used is **Mean Squared Error (MSE)**.

- **Normalization-aware outputs**:
  - During validation and testing, predictions and targets are denormalized before evaluation using stored mean/std statistics.
  - This ensures evaluation is done in real-world units (Kelvin and mm/day).

- **Metric Evaluation** via `_evaluate()`:
  For each variable (`tas`, `pr`), it calculates:
  - **Monthly Area-Weighted RMSE**
  - **Time-Mean RMSE** (RMSE on 10-year average's)
  - **Time-Stddev MAE** (MAE on 10-year standard deviation; a measure of temporal variability)
    
  These metrics reflect the competition's evaluation criteria and are logged and printed.

- **Kaggle Submission Writer**:
  After testing, predictions are saved to a `.csv` file in the required Kaggle format via `_save_submission()`.

- **Saving Predictions for Visualization**:
  - Validation predictions are saved tao `val_preds.npy` and `val_trues.npy`
  - These can be loaded later for visual inspection of the model's performance.

 🔧 **Feel free to modify any part of this module** (loss functions, evaluation, training logic) to better suit your model or training pipeline / Use pure PyTorch etc.

⚠️ The **final submission `.csv` file must strictly follow the format and naming convention used in `_save_submission()`**, as these `ID`s are used to match predictions to the hidden test set during evaluation.



In [40]:
import pandas as pd

class ClimateEmulationModule(pl.LightningModule):
    def __init__(self, model, lat_weights: torch.Tensor, learning_rate=1e-4):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=['model']) # Save all hyperparameters except the model to self.hparams.<param_name>
        #self.criterion = nn.MSELoss()
        #self.criterion = None # will be set in `on_fit_start`
        self.criterion = WeightedMSELoss(lat_weights)
        self.normalizer = None
        self.val_preds, self.val_targets = [], []
        self.test_preds, self.test_targets = [], []

    def forward(self, x):
        return self.model(x)

    def on_fit_start(self):
        self.normalizer = self.trainer.datamodule.normalizer  # Get the normalizer from the datamodule (see above)
        #lat_weights = self.trainer.datamodule.get_lat_weights_torch().to(self.device)
        #self.criterion = WeightedMSELoss(lat_weights)

    def training_step(self, batch, batch_idx):
        x, y = batch # Unpack inputs and targets (this is the output of the _getitem_ method in the Dataset above)
        y_hat = self(x)   # Forward pass
        loss = self.criterion(y_hat, y)  # Calculate loss
        self.log("train/loss", loss)  # Log loss for tracking
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val/loss", loss)

        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np = self.normalizer.inverse_transform_output(y.detach().cpu().numpy())

        """B, T, C, H, W = y_hat.shape

        y_hat_np = self.normalizer.inverse_transform_output(
            y_hat.detach().cpu().numpy().reshape(-1, C, H, W)
        ).reshape(B, T, C, H, W)
        
        y_np = self.normalizer.inverse_transform_output(
            y.detach().cpu().numpy().reshape(-1, C, H, W)
        ).reshape(B, T, C, H, W)"""

        
        
        self.val_preds.append(y_hat_np)
        self.val_targets.append(y_np)

        return loss

    def on_validation_epoch_end(self):
        # Concatenate all predictions and ground truths from each val step/batch into one array
        preds = np.concatenate(self.val_preds, axis=0)
        trues = np.concatenate(self.val_targets, axis=0)
        self._evaluate(preds, trues, phase="val")
        np.save("val_preds.npy", preds)
        np.save("val_trues.npy", trues)
        self.val_preds.clear()
        self.val_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np = y.detach().cpu().numpy()

        """B, T, C, H, W = y_hat.shape

        y_hat_np = self.normalizer.inverse_transform_output(
            y_hat.detach().cpu().numpy().reshape(-1, C, H, W)
        ).reshape(B, T, C, H, W)
        
        #y_np = self.normalizer.inverse_transform_output(
        #    y.detach().cpu().numpy().reshape(-1, C, H, W)
        #).reshape(B, T, C, H, W)
        y_np = y.detach().cpu().numpy()"""
        
        self.test_preds.append(y_hat_np)
        self.test_targets.append(y_np)

    def on_test_epoch_end(self):
        # Concatenate all predictions and ground truths from each test step/batch into one array
        preds = np.concatenate(self.test_preds, axis=0)
        trues = np.concatenate(self.test_targets, axis=0)
        self._evaluate(preds, trues, phase="test")
        self._save_submission(preds)
        self.test_preds.clear()
        self.test_targets.clear()

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-2)

    def _evaluate(self, preds, trues, phase="val"):
        datamodule = self.trainer.datamodule
        area_weights = datamodule.get_lat_weights()
        lat, lon = datamodule.get_coords()
        output_vars = datamodule.output_vars

        # Flatten batch and time dims: [B, T, C, H, W] -> [B*T, C, H, W]
        #preds = preds.reshape(-1, preds.shape[2], preds.shape[3], preds.shape[4])
        #trues = trues.reshape(-1, trues.shape[2], trues.shape[3], trues.shape[4])
        time = np.arange(preds.shape[0])
        var_weights = {"tas": 0.5, "pr": 0.5}
        metric_var_weights = {
            "tas": {"monthly_rmse": 0.1, "time_mean": 1.0, "time_std": 1.0},
            "pr": {"monthly_rmse": 0.1, "time_mean": 1.0, "time_std": 0.75},
        }
        var_scores = {}
        
        for i, var in enumerate(output_vars):
            p = preds[:, i]
            t = trues[:, i]
            p_xr = xr.DataArray(p, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
            t_xr = xr.DataArray(t, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

            # RMSE
            rmse = np.sqrt(((p_xr - t_xr) ** 2).weighted(area_weights).mean(("time", "y", "x")).item())
            # RMSE of time-mean
            mean_rmse = np.sqrt(((p_xr.mean("time") - t_xr.mean("time")) ** 2).weighted(area_weights).mean(("y", "x")).item())
            # MAE of time-stddev
            std_mae = np.abs(p_xr.std("time") - t_xr.std("time")).weighted(area_weights).mean(("y", "x")).item()

            print(f"[{phase.upper()}] {var}: RMSE={rmse:.4f}, Time-Mean RMSE={mean_rmse:.4f}, Time-Stddev MAE={std_mae:.4f}")
            self.log_dict({
                f"{phase}/{var}/rmse": rmse,
                f"{phase}/{var}/time_mean_rmse": mean_rmse,
                f"{phase}/{var}/time_std_mae": std_mae,
            })
            
            weights = metric_var_weights[var]
            var_score = (
                weights["monthly_rmse"] * rmse
                + weights["time_mean"] * mean_rmse
                + weights["time_std"] * std_mae
            )

            var_scores[var] = var_score
        final_score = sum(var_weights[var] * var_scores[var] for var in output_vars)
        print(final_score)
        self.log_dict({
            "final_score": final_score,
        })

    def _save_submission(self, predictions):
        datamodule = self.trainer.datamodule
        lat, lon = datamodule.get_coords()
        output_vars = datamodule.output_vars

        # Flatten batch and time dims: [B, T, C, H, W] -> [B*T, C, H, W]
        #predictions = predictions.reshape(-1, predictions.shape[2], predictions.shape[3], predictions.shape[4])
        
        # Only keep the last timestep of each sequence
        #predictions = predictions[:, -1]  # shape: [B, C, H, W]
        time = np.arange(predictions.shape[0])

        rows = []
        for t_idx, t in enumerate(time):
            for var_idx, var in enumerate(output_vars):
                for y_idx, y in enumerate(lat):
                    for x_idx, x in enumerate(lon):
                        row_id = f"t{t_idx:03d}_{var}_{y:.2f}_{x:.2f}"
                        pred = predictions[t_idx, var_idx, y_idx, x_idx]
                        rows.append({"ID": row_id, "Prediction": pred})

        df = pd.DataFrame(rows)
        os.makedirs("submissions", exist_ok=True)
        filepath = f"submissions/kaggle_submission_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filepath, index=False)
        print(f"✅ Submission saved to: {filepath}")

### ⚡ Training & Evaluation with PyTorch Lightning

This block sets up and runs the training and testing pipeline using **PyTorch Lightning’s `Trainer`**, which abstracts away much of the boilerplate in deep learning workflows.

- **Modular Setup**:
  - `datamodule`: Handles loading, normalization, and batching of climate data.
  - `model`: A convolutional neural network that maps climate forcings to predicted outputs.
  - `lightning_module`: Wraps the model with training/validation/test logic and metric evaluation.

- **Trainer Flexibility**:
  The `Trainer` accepts a wide range of configuration options from `config["trainer"]`, including:
  - Number of epochs
  - Precision (e.g., 16-bit or 32-bit)
  - Device configuration (CPU, GPU, or TPU)
  - Determinism, logging, callbacks, and more

In [9]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1
    ):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self._reset_parameters()
        self.d_model = d_model
        self.pos_encoder = PositionalEncoding(d_model, dropout)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self,
                src,
                tgt,
                src_mask=None,
                tgt_mask=None,
                memory_mask=None):
        src = self.pos_encoder(src)
        tgt = self.pos_encoder(tgt)
        return self.transformer(src, tgt,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=memory_mask)


class ClimateTransformer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        d_model=128,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        dropout=0.1,
    ):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, d_model, kernel_size=1)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.head = nn.Linear(d_model, out_channels)

        self.out_channels = out_channels

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).permute(2, 0, 1)
        x = self.pos_encoder(x)
        x = self.transformer(x)

        x = x.permute(1, 0, 2)
        x = self.head(x)
        x = x.permute(0, 2, 1)
        x = x.view(B, self.out_channels, H, W)
        return x

In [10]:
class Relative2DPositionalEncoding(nn.Module):
    def __init__(self, num_rows: int, num_cols: int, d_model: int):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even"
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.d_model = d_model
        d_half = d_model // 2

        self.row_embed = nn.Parameter(torch.randn(num_rows, d_half))
        self.col_embed = nn.Parameter(torch.randn(num_cols, d_half))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [seq_len, B, d_model] where seq_len = num_rows * num_cols
        """
        device = x.device
        row_embed = self.row_embed.to(device)
        col_embed = self.col_embed.to(device)

        pe = torch.cat([
            row_embed[:, None, :].expand(self.num_rows, self.num_cols, -1),
            col_embed[None, :, :].expand(self.num_rows, self.num_cols, -1)
        ], dim=-1).reshape(-1, self.d_model)  # [seq_len, d_model]

        return x + pe.unsqueeze(1)  # [seq_len, B, d_model]


class HierarchicalPatchTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        d_model: int = 128,
        nhead: int = 8,
        depths: tuple[int, int] = (2, 2),
        patch_sizes: tuple[int, int] = (4, 2),
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        height: int = 48,
        width: int = 72
    ):
        super().__init__()
        p1, p2 = patch_sizes
        H1, W1 = height // p1, width // p1
        H2, W2 = H1 // p2, W1 // p2
        self.d_model = d_model

        self.patch_embed1 = nn.Conv2d(in_channels, d_model, kernel_size=p1, stride=p1)
        self.pos1 = Relative2DPositionalEncoding(H1, W1, d_model)
        enc1 = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation='gelu')
        self.encoder1 = nn.TransformerEncoder(enc1, num_layers=depths[0])

        self.patch_embed2 = nn.Conv2d(d_model, d_model, kernel_size=p2, stride=p2)
        self.pos2 = Relative2DPositionalEncoding(H2, W2, d_model)
        enc2 = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation='gelu')
        self.encoder2 = nn.TransformerEncoder(enc2, num_layers=depths[1])

        self.up1 = nn.ConvTranspose2d(d_model, d_model, kernel_size=p2, stride=p2)
        self.up2 = nn.ConvTranspose2d(d_model, out_channels, kernel_size=p1, stride=p1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        p1 = self.patch_embed1.kernel_size[0]
        p2 = self.patch_embed2.kernel_size[0]

        x1 = self.patch_embed1(x)
        seq1 = x1.flatten(2).permute(2, 0, 1)
        seq1 = self.pos1(seq1)
        seq1 = self.encoder1(seq1)
        x1 = seq1.permute(1, 2, 0).view(B, self.d_model, H // p1, W // p1)

        x2 = self.patch_embed2(x1)
        seq2 = x2.flatten(2).permute(2, 0, 1)
        seq2 = self.pos2(seq2)
        seq2 = self.encoder2(seq2)
        x2 = seq2.permute(1, 2, 0).view(B, self.d_model, H // (p1*p2), W // (p1*p2))

        u1 = self.up1(x2)
        out = self.up2(u1)
        return out


In [11]:
class TemporalHierarchicalTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        seq_len: int = 12,
        d_model: int = 128,
        nhead: int = 8,
        depths: tuple[int, int] = (2, 2),
        patch_sizes: tuple[int, int] = (4, 2),
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        height: int = 48,
        width: int = 72
    ):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        p1, p2 = patch_sizes
        H1, W1 = height // p1, width // p1
        H2, W2 = H1 // p2, W1 // p2
        self.spatial_size = H2 * W2
        self.H2, self.W2 = H2, W2

        # Shared patch encoder (like before)
        self.patch_embed1 = nn.Conv2d(in_channels, d_model, kernel_size=p1, stride=p1)
        self.patch_embed2 = nn.Conv2d(d_model, d_model, kernel_size=p2, stride=p2)

        # Positional encodings
        self.temporal_pos = nn.Parameter(torch.randn(seq_len, 1, d_model))
        self.spatial_pos = nn.Parameter(torch.randn(H2 * W2, d_model))

        # Transformer encoder
        enc = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation='gelu')
        self.transformer = nn.TransformerEncoder(enc, num_layers=sum(depths))

        # Decoder
        self.up1 = nn.ConvTranspose2d(d_model, d_model, kernel_size=p2, stride=p2)
        self.up2 = nn.ConvTranspose2d(d_model, out_channels, kernel_size=p1, stride=p1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, C, H, W]
        B, T, C, H, W = x.shape
        p1 = self.patch_embed1.kernel_size[0]
        p2 = self.patch_embed2.kernel_size[0]
        #H2, W2 = H // (p1 * p2), W // (p1 * p2)
        H2, W2 = self.H2, self.W2

        # Step 1: apply shared spatial patching per timestep
        x = x.view(B * T, C, H, W)
        x = self.patch_embed1(x)
        x = self.patch_embed2(x)  # [B*T, d_model, H2, W2]
        x = x.view(B, T, self.d_model, H2, W2)

        # Step 2: flatten spatial dims and add positional encodings
        x = x.permute(0, 1, 3, 4, 2).reshape(B, T * H2 * W2, self.d_model)  # [B, T*P, d_model]

        # Temporal-spatial positional encoding
        t_pos = self.temporal_pos[:T].expand(-1, H2 * W2, -1).reshape(T * H2 * W2, 1, self.d_model)
        s_pos = self.spatial_pos[:H2 * W2].repeat(T, 1).reshape(T * H2 * W2, self.d_model)
        x = x.permute(1, 0, 2)  # [T*P, B, d_model]
        x = x + t_pos.to(x.device) + s_pos.unsqueeze(1).to(x.device)

        # Step 3: Transformer encoding
        x = self.transformer(x)  # [T*P, B, d_model]

        # Step 4: reshape and decode per timestep
        """
        x = x.permute(1, 0, 2).reshape(B, T, H2, W2, self.d_model)
        x = x.permute(0, 1, 4, 2, 3).reshape(B * T, self.d_model, H2, W2)

        # Step 5: upsample
        x = self.up1(x)
        x = self.up2(x)
        x = x.view(B, T, -1, H, W)  # [B, T, 2, H, W]
        return x
        """

        x = x.permute(1, 0, 2)
        x = x.view(B, T, H2 * W2, self.d_model)
        x_last = x[:, -1]  # [B, H2*W2, d_model]
        x_last = x_last.reshape(B, self.d_model, H2, W2)

        x = self.up1(x_last)
        x = self.up2(x)
        return x  # [B, 2, H, W]




In [12]:
import torch
import torch.nn as nn


class TemporalHierarchicalTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        seq_len: int = 12,
        d_model: int = 128,
        nhead: int = 8,
        depths: tuple[int, int] = (2, 2),
        patch_sizes: tuple[int, int] = (4, 2),
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        height: int = 48,
        width: int = 72
    ):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        p1, p2 = patch_sizes
        H1, W1 = height // p1, width // p1
        H2, W2 = H1 // p2, W1 // p2
        self.spatial_size = H2 * W2
        self.H2, self.W2 = H2, W2
        self.H1, self.W1 = H1, H2

        # ----------------------------------------
        # 1) Shared patch encoder + LayerNorm
        # ----------------------------------------
        # First conv: in_channels → d_model, patch size = p1
        self.patch_embed1 = nn.Conv2d(in_channels, d_model, kernel_size=p1, stride=p1)
        # Immediately apply LayerNorm over flattened “(H1*W1)×d_model” tokens
        self.patch_norm1 = nn.LayerNorm(d_model)

        # Second conv: d_model → d_model, patch size = p2
        self.patch_embed2 = nn.Conv2d(d_model, d_model, kernel_size=p2, stride=p2)
        # Then normalize those “(H2*W2)×d_model” tokens
        self.patch_norm2 = nn.LayerNorm(d_model)

        # ----------------------------------------
        # 2) Positional encodings (learnable)
        # ----------------------------------------
        self.temporal_pos = nn.Parameter(torch.randn(seq_len, 1, d_model))
        self.spatial_pos  = nn.Parameter(torch.randn(H2 * W2, d_model))

        # ----------------------------------------
        # 3) TransformerEncoder (pre-norm)
        #    using norm_first=True for a “pre-norm” block
        # ----------------------------------------
        # Build one layer with norm_first=True (pre-norm)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='gelu',
            norm_first=True
        )
        # Stack `sum(depths)` of these
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=sum(depths))

        # ----------------------------------------
        # 4) Decoder / Upsampling
        # ----------------------------------------
        # First upsample from (H2, W2) back to (H1, W1)
        self.up1 = nn.ConvTranspose2d(d_model, d_model, kernel_size=p2, stride=p2)
        # Then upsample from (H1, W1) back to (H, W)
        self.up2 = nn.ConvTranspose2d(d_model, out_channels, kernel_size=p1, stride=p1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Input:
            x: [B, T, C, H, W]
        Output:
            out: [B, out_channels, H, W]
            
        (We predict only the last frame’s output.)
        """
        B, T, C, H, W = x.shape
        H1, W1 = self.H1, self.W1
        H2, W2 = self.H2, self.W2
        p1 = self.patch_embed1.kernel_size[0]  # first patch size
        p2 = self.patch_embed2.kernel_size[0]  # second patch size

        # ----------------------------------------
        # Step 1: Shared spatial patching per timestep
        # ----------------------------------------
        # Merge batch & time dims: [B*T, C, H, W]
        x = x.view(B * T, C, H, W)

        # 1.1 First Conv → [B*T, d_model, H1, W1]
        x = self.patch_embed1(x)
        # Flatten spatial → [B*T, d_model, H1*W1], then transpose → [B*T, H1*W1, d_model]
        x = x.flatten(2).transpose(1, 2)
        # Apply LayerNorm across the last dimension (d_model)
        x = self.patch_norm1(x)
        # Restore to [B*T, d_model, H1, W1]
        x = x.transpose(1, 2).view(B * T, self.d_model, H1, W1)

        # 1.2 Second Conv → [B*T, d_model, H2, W2]
        x = self.patch_embed2(x)
        # Flatten again → [B*T, d_model, H2*W2] → transpose → [B*T, H2*W2, d_model]
        x = x.flatten(2).transpose(1, 2)
        # Apply LayerNorm again
        x = self.patch_norm2(x)
        # Reshape to [B, T, d_model, H2, W2]
        x = x.transpose(1, 2).view(B, T, self.d_model, H2, W2)

        # ----------------------------------------
        # Step 2: Flatten spatial dims & add positional encodings
        # ----------------------------------------
        # Permute so tokens are contiguous: [B, T, H2, W2, d_model]
        x = x.permute(0, 1, 3, 4, 2).reshape(B, T * H2 * W2, self.d_model)
        # Now x is [B, (T·H2·W2), d_model]

        # Build temporal + spatial positional biases:
        #  - temporal_pos: [T, 1, d_model] → expand over H2*W2 → [T, H2*W2, d_model]
        t_pos = self.temporal_pos[:T]                       # [T, 1, d_model]
        t_pos = t_pos.expand(-1, H2 * W2, -1)               # [T, H2*W2, d_model]
        t_pos = t_pos.reshape(T * H2 * W2, 1, self.d_model) # [T·P, 1, d_model]

        #  - spatial_pos: [H2*W2, d_model] → repeat for each t → [T·(H2·W2), d_model]
        s_pos = self.spatial_pos[: H2 * W2]                 # [P, d_model]
        s_pos = s_pos.repeat(T, 1).reshape(T * H2 * W2, self.d_model)  # [T·P, d_model]

        # Permute x → [T·P, B, d_model] for Transformer
        x = x.permute(1, 0, 2)  # (seq_len = T·P,  B,  d_model)
        # Add both positional terms
        x = x + t_pos.to(x.device) + s_pos.unsqueeze(1).to(x.device)

        # ----------------------------------------
        # Step 3: TransformerEncoder (pre-norm)
        # ----------------------------------------
        # Now x stays [T·P, B, d_model] throughout.
        x = self.transformer(x)  # → [T·P, B, d_model]

        # ----------------------------------------
        # Step 4: Reshape & take only “last‐frame” tokens
        # ----------------------------------------
        # Bring back to [B, T·P, d_model]
        x = x.permute(1, 0, 2).reshape(B, T, H2 * W2, self.d_model)
        # We only decode the last timestep (index = T-1)
        x_last = x[:, -1]                # [B, H2·W2, d_model]
        x_last = x_last.reshape(B, self.d_model, H2, W2)

        # ----------------------------------------
        # Step 5: Upsample back to original H×W
        # ----------------------------------------
        x_out = self.up1(x_last)         # [B, d_model, H1, W1]
        x_out = self.up2(x_out)          # [B, out_channels, H, W]

        return x_out


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemporalHierarchicalTransformerPlus(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        seq_len: int = 12,
        d_model: int = 128,
        nhead: int = 8,
        depths: tuple[int,int] = (2,2),
        patch_sizes: tuple[int,int] = (4,2),
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        height: int = 48,
        width: int = 72,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        # 1) patch sizes and spatial dims
        p1, p2 = patch_sizes
        H1, W1 = height // p1, width // p1
        H2, W2 = H1 // p2, W1 // p2
        self.H2, self.W2 = H2, W2
        self.spatial_size = H2 * W2

        # 2) spatial patch embed
        self.patch_embed1 = nn.Conv2d(in_channels, d_model, kernel_size=p1, stride=p1)
        self.patch_embed2 = nn.Conv2d(d_model, d_model, kernel_size=p2, stride=p2)

        # 3) temporal conv stem
        self.temporal_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1, bias=False)
        self.temporal_bn   = nn.BatchNorm1d(d_model)

        # 4) positional encodings
        self.temporal_pos = nn.Parameter(torch.randn(seq_len, 1, d_model))
        self.spatial_pos  = nn.Parameter(torch.randn(H2 * W2, d_model))

        # 5) transformer
        total_layers      = sum(depths)
        encoder_layer     = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation='gelu'
        )
        self.transformer  = nn.TransformerEncoder(encoder_layer, num_layers=total_layers)
        self.trans_norm   = nn.LayerNorm(d_model)

        # 6) decoder up-modules + skip projections
        self.up1        = nn.ConvTranspose2d(d_model, d_model, kernel_size=p2, stride=p2)
        self.up2        = nn.ConvTranspose2d(d_model, out_channels, kernel_size=p1, stride=p1)
        self.skip_proj1 = nn.ConvTranspose2d(d_model, d_model, kernel_size=p2, stride=p2)
        self.skip_proj2 = nn.ConvTranspose2d(d_model, d_model, kernel_size=p1, stride=p1)
        self.skip_out   = nn.Conv2d(d_model, out_channels, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape
        assert T == self.seq_len, f"Expected T={self.seq_len}, got {T}"
        p1 = self.patch_embed1.kernel_size[0]
        p2 = self.patch_embed2.kernel_size[0]
        H2, W2 = self.H2, self.W2
        d = self.d_model

        # A) apply patch_embed1 + patch_embed2
        z1 = x.view(B * T, C, H, W)
        z2 = self.patch_embed1(z1)
        z3 = self.patch_embed2(z2)
        z4 = z3.view(B, T, d, H2, W2)

        # B) skip_early
        skip_early = z4[:, 0].clone()

        # C) temporal‐conv stem
        stem = z4.permute(0, 3, 4, 2, 1).contiguous()
        stem = stem.view(B * H2 * W2, d, T)
        stem = self.temporal_conv(stem)
        stem = self.temporal_bn(stem)
        stem = F.gelu(stem)
        stem = stem.view(B, H2, W2, d, T).permute(0, 4, 3, 1, 2).contiguous()
        x3 = stem

        # D) flatten + positional encoding
        x4 = x3.permute(0, 1, 3, 4, 2).reshape(B, T * (H2 * W2), d)
        x4 = x4.permute(1, 0, 2).contiguous()
        tp = self.temporal_pos[:T].expand(-1, H2*W2, -1).reshape(T*H2*W2, 1, d)
        sp = self.spatial_pos[:H2*W2].repeat(T, 1).reshape(T*H2*W2, d).unsqueeze(1)
        x4 = x4 + tp.to(x4.device) + sp.to(x4.device)

        # E) transformer + residual + LayerNorm
        tr_out = self.transformer(x4)
        alpha = 0.5
        x5 = x4 + alpha * tr_out
        x5 = self.trans_norm(x5)

        # F) reshape back to [B,T,d,H2,W2]
        x5 = x5.permute(1, 0, 2).contiguous()
        x5 = x5.view(B, T, H2, W2, d)
        x5 = x5.permute(0, 1, 4, 2, 3).contiguous()

        # G) last time slice
        x_last = x5[:, -1]

        # H) decoder up1 + skip fusion
        u1 = self.up1(x_last)
        se = self.skip_proj1(skip_early)
        u1 = u1 + se
        u1 = F.gelu(u1)

        # I) final up2 + long skip
        u2 = self.up2(u1)
        se2 = self.skip_proj1(skip_early)
        se2 = F.gelu(se2)
        se2 = self.skip_proj2(se2)
        se2 = F.gelu(se2)
        se2 = self.skip_out(se2)

        out = u2 + se2
        return out


In [36]:
datamodule = ClimateDataModule(**config["data"])
"""model = HierarchicalPatchTransformer(
    in_channels=len(config["data"]["input_vars"]),
    out_channels=len(config["data"]["output_vars"]),
    dropout=0.1,
)"""

model = TemporalHierarchicalTransformerPlus(
    in_channels=len(config["data"]["input_vars"]),
    out_channels=len(config["data"]["output_vars"]),
    dropout=0.1,
    seq_len=config['data']['seq_len'],
    patch_sizes = (2,4),
)

lightning_module = ClimateEmulationModule(model, learning_rate=config["training"]["lr"])

checkpoint_callback = ModelCheckpoint(
    monitor="final_score",        # 🧠 This must match the exact string logged in `self.log()`
    mode="min",                    # ✅ Because lower RMSE is better
    save_top_k=3,                  # Save only the best
    filename="best-final-score-{epoch:02d}-{final_score:.4f}",
    auto_insert_metric_name=False,
    save_weights_only=False       # or True if you only want weights
)

trainer = pl.Trainer(callbacks=[checkpoint_callback], **config["trainer"])
trainer.fit(lightning_module, datamodule=datamodule)   # Training



Creating dataset with 8098 samples...
Creating dataset with 360 samples...
Creating dataset with 360 samples...


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type                                | Params | Mode 
----------------------------------------------------------------------
0 | model | TemporalHierarchicalTransformerPlus | 1.4 M  | train
----------------------------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.778     Total estimated model params size (MB)
53        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7299, Time-Mean RMSE=0.8899, Time-Stddev MAE=0.3999
[VAL] pr: RMSE=1.9937, Time-Mean RMSE=0.3661, Time-Stddev MAE=0.7352
1.2898792600575721


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5789, Time-Mean RMSE=0.7626, Time-Stddev MAE=0.3443
[VAL] pr: RMSE=1.9526, Time-Mean RMSE=0.2504, Time-Stddev MAE=0.7636
1.1415929836313135


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5511, Time-Mean RMSE=0.7358, Time-Stddev MAE=0.4093
[VAL] pr: RMSE=1.9559, Time-Mean RMSE=0.2977, Time-Stddev MAE=0.7904
1.1931340367349315


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5487, Time-Mean RMSE=0.7825, Time-Stddev MAE=0.3019
[VAL] pr: RMSE=1.9568, Time-Mean RMSE=0.2731, Time-Stddev MAE=0.8095
1.1575645889567125


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5105, Time-Mean RMSE=0.7769, Time-Stddev MAE=0.3105
[VAL] pr: RMSE=1.9510, Time-Mean RMSE=0.2757, Time-Stddev MAE=0.7072
1.1198030138303534


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4027, Time-Mean RMSE=0.6041, Time-Stddev MAE=0.2447
[VAL] pr: RMSE=1.9432, Time-Mean RMSE=0.2340, Time-Stddev MAE=0.7080
0.9741961104814097


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3888, Time-Mean RMSE=0.5969, Time-Stddev MAE=0.2713
[VAL] pr: RMSE=1.9406, Time-Mean RMSE=0.2612, Time-Stddev MAE=0.7430
1.0097943246794197


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3592, Time-Mean RMSE=0.5505, Time-Stddev MAE=0.2660
[VAL] pr: RMSE=1.9424, Time-Mean RMSE=0.2686, Time-Stddev MAE=0.7688
0.995959747896233


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4053, Time-Mean RMSE=0.6733, Time-Stddev MAE=0.2929
[VAL] pr: RMSE=1.9366, Time-Mean RMSE=0.2525, Time-Stddev MAE=0.7325
1.0511549666929003


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3521, Time-Mean RMSE=0.5575, Time-Stddev MAE=0.2352
[VAL] pr: RMSE=1.9430, Time-Mean RMSE=0.2540, Time-Stddev MAE=0.7022
0.951430415884192


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3585, Time-Mean RMSE=0.5995, Time-Stddev MAE=0.2083
[VAL] pr: RMSE=1.9420, Time-Mean RMSE=0.2597, Time-Stddev MAE=0.7278
0.9717202925318557


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3031, Time-Mean RMSE=0.5182, Time-Stddev MAE=0.2235
[VAL] pr: RMSE=1.9458, Time-Mean RMSE=0.2901, Time-Stddev MAE=0.7279
0.9513258789668428


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2900, Time-Mean RMSE=0.4754, Time-Stddev MAE=0.2463
[VAL] pr: RMSE=1.9334, Time-Mean RMSE=0.2175, Time-Stddev MAE=0.7206
0.900999168783357


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3111, Time-Mean RMSE=0.5123, Time-Stddev MAE=0.2251
[VAL] pr: RMSE=1.9519, Time-Mean RMSE=0.3223, Time-Stddev MAE=0.7638
0.9794314428518623


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3035, Time-Mean RMSE=0.5272, Time-Stddev MAE=0.2072
[VAL] pr: RMSE=1.9373, Time-Mean RMSE=0.2525, Time-Stddev MAE=0.7368
0.9318120566116957


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3410, Time-Mean RMSE=0.5501, Time-Stddev MAE=0.3376
[VAL] pr: RMSE=1.9338, Time-Mean RMSE=0.2359, Time-Stddev MAE=0.7314
0.9998360104441666


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4706, Time-Mean RMSE=0.8063, Time-Stddev MAE=0.2386
[VAL] pr: RMSE=1.9458, Time-Mean RMSE=0.2918, Time-Stddev MAE=0.7669
1.1267587701184245


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3451, Time-Mean RMSE=0.5563, Time-Stddev MAE=0.3352
[VAL] pr: RMSE=1.9418, Time-Mean RMSE=0.2917, Time-Stddev MAE=0.7332
1.03093416569786


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5430, Time-Mean RMSE=0.9685, Time-Stddev MAE=0.2434
[VAL] pr: RMSE=1.9334, Time-Mean RMSE=0.2305, Time-Stddev MAE=0.7554
1.1783024497404133


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3166, Time-Mean RMSE=0.5452, Time-Stddev MAE=0.2661
[VAL] pr: RMSE=1.9370, Time-Mean RMSE=0.2490, Time-Stddev MAE=0.7628
0.9788933998131315


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2659, Time-Mean RMSE=0.4189, Time-Stddev MAE=0.2980
[VAL] pr: RMSE=1.9310, Time-Mean RMSE=0.2156, Time-Stddev MAE=0.7633
0.9123813562269103


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3114, Time-Mean RMSE=0.5584, Time-Stddev MAE=0.2082
[VAL] pr: RMSE=1.9322, Time-Mean RMSE=0.2076, Time-Stddev MAE=0.7502
0.9306055580408753


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3792, Time-Mean RMSE=0.6625, Time-Stddev MAE=0.2057
[VAL] pr: RMSE=1.9335, Time-Mean RMSE=0.2321, Time-Stddev MAE=0.7173
0.9847871987693134


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2775, Time-Mean RMSE=0.4667, Time-Stddev MAE=0.3047
[VAL] pr: RMSE=1.9287, Time-Mean RMSE=0.2214, Time-Stddev MAE=0.7560
0.9401692794163892


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2312, Time-Mean RMSE=0.3643, Time-Stddev MAE=0.1988
[VAL] pr: RMSE=1.9239, Time-Mean RMSE=0.1839, Time-Stddev MAE=0.7606
0.8164626336763929


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2838, Time-Mean RMSE=0.4830, Time-Stddev MAE=0.2083
[VAL] pr: RMSE=1.9339, Time-Mean RMSE=0.2191, Time-Stddev MAE=0.7607
0.9013498623673559


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2551, Time-Mean RMSE=0.4490, Time-Stddev MAE=0.2133
[VAL] pr: RMSE=1.9258, Time-Mean RMSE=0.2063, Time-Stddev MAE=0.7239
0.8648230910670465


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3387, Time-Mean RMSE=0.5886, Time-Stddev MAE=0.3077
[VAL] pr: RMSE=1.9315, Time-Mean RMSE=0.2178, Time-Stddev MAE=0.7578
1.0047739838072247


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2540, Time-Mean RMSE=0.4419, Time-Stddev MAE=0.1767
[VAL] pr: RMSE=1.9293, Time-Mean RMSE=0.2111, Time-Stddev MAE=0.7004
0.8366512198921335


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2691, Time-Mean RMSE=0.4483, Time-Stddev MAE=0.1820
[VAL] pr: RMSE=1.9321, Time-Mean RMSE=0.2282, Time-Stddev MAE=0.7410
0.8672138290760321


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2366, Time-Mean RMSE=0.3992, Time-Stddev MAE=0.1979
[VAL] pr: RMSE=1.9263, Time-Mean RMSE=0.2007, Time-Stddev MAE=0.7123
0.8241639793862998


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2845, Time-Mean RMSE=0.5188, Time-Stddev MAE=0.2274
[VAL] pr: RMSE=1.9328, Time-Mean RMSE=0.2023, Time-Stddev MAE=0.6781
0.8894290298619705


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2708, Time-Mean RMSE=0.4855, Time-Stddev MAE=0.2631
[VAL] pr: RMSE=1.9276, Time-Mean RMSE=0.2059, Time-Stddev MAE=0.7199
0.9071500966306081


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3097, Time-Mean RMSE=0.5089, Time-Stddev MAE=0.1978
[VAL] pr: RMSE=1.9277, Time-Mean RMSE=0.1984, Time-Stddev MAE=0.7161
0.8829417568516083


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2444, Time-Mean RMSE=0.4025, Time-Stddev MAE=0.2177
[VAL] pr: RMSE=1.9298, Time-Mean RMSE=0.2204, Time-Stddev MAE=0.7387
0.8559982815214408


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2682, Time-Mean RMSE=0.4430, Time-Stddev MAE=0.2269
[VAL] pr: RMSE=1.9349, Time-Mean RMSE=0.2166, Time-Stddev MAE=0.7416
0.8814367461184576


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3169, Time-Mean RMSE=0.5986, Time-Stddev MAE=0.2376
[VAL] pr: RMSE=1.9318, Time-Mean RMSE=0.2316, Time-Stddev MAE=0.7606
0.9815153957612607


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2080, Time-Mean RMSE=0.3256, Time-Stddev MAE=0.2199
[VAL] pr: RMSE=1.9309, Time-Mean RMSE=0.2149, Time-Stddev MAE=0.7056
0.8017632252258386


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2517, Time-Mean RMSE=0.4533, Time-Stddev MAE=0.2388
[VAL] pr: RMSE=1.9267, Time-Mean RMSE=0.2033, Time-Stddev MAE=0.7269
0.8791851588428403


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3076, Time-Mean RMSE=0.5644, Time-Stddev MAE=0.2645
[VAL] pr: RMSE=1.9267, Time-Mean RMSE=0.1938, Time-Stddev MAE=0.7722
0.962669171150485


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2679, Time-Mean RMSE=0.4778, Time-Stddev MAE=0.2212
[VAL] pr: RMSE=1.9257, Time-Mean RMSE=0.2138, Time-Stddev MAE=0.7341
0.8913504254569613


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2313, Time-Mean RMSE=0.3931, Time-Stddev MAE=0.2020
[VAL] pr: RMSE=1.9255, Time-Mean RMSE=0.1948, Time-Stddev MAE=0.7014
0.8158275574984102


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2985, Time-Mean RMSE=0.5675, Time-Stddev MAE=0.2291
[VAL] pr: RMSE=1.9271, Time-Mean RMSE=0.2145, Time-Stddev MAE=0.7379
0.9435303845282048


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2651, Time-Mean RMSE=0.4841, Time-Stddev MAE=0.2175
[VAL] pr: RMSE=1.9375, Time-Mean RMSE=0.2595, Time-Stddev MAE=0.7489
0.9215241753263816


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2537, Time-Mean RMSE=0.4244, Time-Stddev MAE=0.1766
[VAL] pr: RMSE=1.9233, Time-Mean RMSE=0.1781, Time-Stddev MAE=0.7434
0.827204781918625


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2323, Time-Mean RMSE=0.4036, Time-Stddev MAE=0.1797
[VAL] pr: RMSE=1.9240, Time-Mean RMSE=0.1806, Time-Stddev MAE=0.7064
0.8046652674410852


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2302, Time-Mean RMSE=0.4151, Time-Stddev MAE=0.1981
[VAL] pr: RMSE=1.9327, Time-Mean RMSE=0.2440, Time-Stddev MAE=0.7606
0.8719880704471941


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2142, Time-Mean RMSE=0.3377, Time-Stddev MAE=0.2413
[VAL] pr: RMSE=1.9295, Time-Mean RMSE=0.2208, Time-Stddev MAE=0.7291
0.8304521208065451


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2966, Time-Mean RMSE=0.5596, Time-Stddev MAE=0.1934
[VAL] pr: RMSE=1.9275, Time-Mean RMSE=0.2076, Time-Stddev MAE=0.7354
0.9172728135003714


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3269, Time-Mean RMSE=0.6025, Time-Stddev MAE=0.1830
[VAL] pr: RMSE=1.9311, Time-Mean RMSE=0.2211, Time-Stddev MAE=0.7500
0.9474896632864882


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3046, Time-Mean RMSE=0.5535, Time-Stddev MAE=0.2646
[VAL] pr: RMSE=1.9325, Time-Mean RMSE=0.2406, Time-Stddev MAE=0.7604
0.9763707828007593


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2293, Time-Mean RMSE=0.3897, Time-Stddev MAE=0.1962
[VAL] pr: RMSE=1.9288, Time-Mean RMSE=0.2086, Time-Stddev MAE=0.7163
0.8238068959858635


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2457, Time-Mean RMSE=0.4443, Time-Stddev MAE=0.2156
[VAL] pr: RMSE=1.9280, Time-Mean RMSE=0.2149, Time-Stddev MAE=0.7432
0.8747964793929395


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2154, Time-Mean RMSE=0.3388, Time-Stddev MAE=0.2058
[VAL] pr: RMSE=1.9270, Time-Mean RMSE=0.1921, Time-Stddev MAE=0.7107
0.791968720725323


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2299, Time-Mean RMSE=0.3986, Time-Stddev MAE=0.2394
[VAL] pr: RMSE=1.9276, Time-Mean RMSE=0.2093, Time-Stddev MAE=0.7457
0.8611396029387024


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2257, Time-Mean RMSE=0.3682, Time-Stddev MAE=0.2222
[VAL] pr: RMSE=1.9271, Time-Mean RMSE=0.1935, Time-Stddev MAE=0.7592
0.8342981183958458


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2328, Time-Mean RMSE=0.4148, Time-Stddev MAE=0.2202
[VAL] pr: RMSE=1.9265, Time-Mean RMSE=0.1959, Time-Stddev MAE=0.7282
0.8464976088668614


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2210, Time-Mean RMSE=0.3473, Time-Stddev MAE=0.2248
[VAL] pr: RMSE=1.9326, Time-Mean RMSE=0.2208, Time-Stddev MAE=0.7941
0.8519347230261554


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3960, Time-Mean RMSE=0.7375, Time-Stddev MAE=0.1986
[VAL] pr: RMSE=1.9264, Time-Mean RMSE=0.2015, Time-Stddev MAE=0.7124
1.0020433614338797


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2884, Time-Mean RMSE=0.5279, Time-Stddev MAE=0.2267
[VAL] pr: RMSE=1.9332, Time-Mean RMSE=0.2471, Time-Stddev MAE=0.7402
0.939453990491234


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2389, Time-Mean RMSE=0.4350, Time-Stddev MAE=0.1798
[VAL] pr: RMSE=1.9261, Time-Mean RMSE=0.1982, Time-Stddev MAE=0.7378
0.8414436644597729


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2225, Time-Mean RMSE=0.3724, Time-Stddev MAE=0.2567
[VAL] pr: RMSE=1.9259, Time-Mean RMSE=0.1868, Time-Stddev MAE=0.7381
0.8421257983633665


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2480, Time-Mean RMSE=0.4092, Time-Stddev MAE=0.1877
[VAL] pr: RMSE=1.9275, Time-Mean RMSE=0.2062, Time-Stddev MAE=0.7770
0.8516956123593403


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2508, Time-Mean RMSE=0.4628, Time-Stddev MAE=0.2212
[VAL] pr: RMSE=1.9265, Time-Mean RMSE=0.2065, Time-Stddev MAE=0.7418
0.8822797269370688


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3100, Time-Mean RMSE=0.5684, Time-Stddev MAE=0.1828
[VAL] pr: RMSE=1.9286, Time-Mean RMSE=0.2211, Time-Stddev MAE=0.7422
0.9264119239160737


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2619, Time-Mean RMSE=0.4882, Time-Stddev MAE=0.2109
[VAL] pr: RMSE=1.9340, Time-Mean RMSE=0.2455, Time-Stddev MAE=0.7625
0.917988707345506


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2405, Time-Mean RMSE=0.4395, Time-Stddev MAE=0.2069
[VAL] pr: RMSE=1.9263, Time-Mean RMSE=0.2003, Time-Stddev MAE=0.7457
0.861352677673032


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2561, Time-Mean RMSE=0.4825, Time-Stddev MAE=0.1906
[VAL] pr: RMSE=1.9284, Time-Mean RMSE=0.2042, Time-Stddev MAE=0.7236
0.8692290247265875


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2556, Time-Mean RMSE=0.4638, Time-Stddev MAE=0.2656
[VAL] pr: RMSE=1.9311, Time-Mean RMSE=0.2376, Time-Stddev MAE=0.7766
0.9340501148831956


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2268, Time-Mean RMSE=0.3627, Time-Stddev MAE=0.2145
[VAL] pr: RMSE=1.9306, Time-Mean RMSE=0.2264, Time-Stddev MAE=0.7585
0.8440889501926178


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2636, Time-Mean RMSE=0.5019, Time-Stddev MAE=0.2312
[VAL] pr: RMSE=1.9267, Time-Mean RMSE=0.1958, Time-Stddev MAE=0.7344
0.8993627238844828


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2246, Time-Mean RMSE=0.3758, Time-Stddev MAE=0.2362
[VAL] pr: RMSE=1.9325, Time-Mean RMSE=0.2290, Time-Stddev MAE=0.7362
0.8544139486442246


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2362, Time-Mean RMSE=0.4146, Time-Stddev MAE=0.1962
[VAL] pr: RMSE=1.9266, Time-Mean RMSE=0.1984, Time-Stddev MAE=0.7268
0.8353035254256973


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2287, Time-Mean RMSE=0.3958, Time-Stddev MAE=0.1768
[VAL] pr: RMSE=1.9324, Time-Mean RMSE=0.2488, Time-Stddev MAE=0.7252
0.8407381097154327


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2457, Time-Mean RMSE=0.4293, Time-Stddev MAE=0.1849
[VAL] pr: RMSE=1.9284, Time-Mean RMSE=0.2180, Time-Stddev MAE=0.7466
0.8547983000372561


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2293, Time-Mean RMSE=0.3605, Time-Stddev MAE=0.1780
[VAL] pr: RMSE=1.9252, Time-Mean RMSE=0.1965, Time-Stddev MAE=0.7419
0.803409053175399


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2980, Time-Mean RMSE=0.5712, Time-Stddev MAE=0.1855
[VAL] pr: RMSE=1.9252, Time-Mean RMSE=0.2011, Time-Stddev MAE=0.7465
0.9199843099564248


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2232, Time-Mean RMSE=0.3634, Time-Stddev MAE=0.2377
[VAL] pr: RMSE=1.9273, Time-Mean RMSE=0.1986, Time-Stddev MAE=0.7504
0.8387563364377086


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2345, Time-Mean RMSE=0.4267, Time-Stddev MAE=0.1995
[VAL] pr: RMSE=1.9289, Time-Mean RMSE=0.2000, Time-Stddev MAE=0.7537
0.8538715180459944


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2744, Time-Mean RMSE=0.5114, Time-Stddev MAE=0.2086
[VAL] pr: RMSE=1.9305, Time-Mean RMSE=0.2338, Time-Stddev MAE=0.7631
0.9233044991999215


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2445, Time-Mean RMSE=0.4352, Time-Stddev MAE=0.1942
[VAL] pr: RMSE=1.9277, Time-Mean RMSE=0.2144, Time-Stddev MAE=0.7516
0.8623274938950236


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2502, Time-Mean RMSE=0.4096, Time-Stddev MAE=0.2728
[VAL] pr: RMSE=1.9326, Time-Mean RMSE=0.2164, Time-Stddev MAE=0.7123
0.8756406378407178


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2140, Time-Mean RMSE=0.3470, Time-Stddev MAE=0.1930
[VAL] pr: RMSE=1.9273, Time-Mean RMSE=0.2087, Time-Stddev MAE=0.7305
0.8053728157411871


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3945, Time-Mean RMSE=0.7300, Time-Stddev MAE=0.1889
[VAL] pr: RMSE=1.9296, Time-Mean RMSE=0.2134, Time-Stddev MAE=0.7421
1.0106521072344763


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2196, Time-Mean RMSE=0.3626, Time-Stddev MAE=0.1836
[VAL] pr: RMSE=1.9267, Time-Mean RMSE=0.1835, Time-Stddev MAE=0.7575
0.8062497255606812


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2493, Time-Mean RMSE=0.4300, Time-Stddev MAE=0.1831
[VAL] pr: RMSE=1.9240, Time-Mean RMSE=0.1913, Time-Stddev MAE=0.7325
0.8355791116785416


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2236, Time-Mean RMSE=0.3642, Time-Stddev MAE=0.2753
[VAL] pr: RMSE=1.9261, Time-Mean RMSE=0.1921, Time-Stddev MAE=0.7537
0.8559377028706142


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.2120, Time-Mean RMSE=0.3537, Time-Stddev MAE=0.2216
[VAL] pr: RMSE=1.9281, Time-Mean RMSE=0.2099, Time-Stddev MAE=0.7475
0.8299111460850972


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3055, Time-Mean RMSE=0.5850, Time-Stddev MAE=0.1827
[VAL] pr: RMSE=1.9277, Time-Mean RMSE=0.1998, Time-Stddev MAE=0.7132
0.9128885030615255


NameError: name 'exit' is not defined

# Test model

**IMPORTANT:** Please note that the test metrics will be bad because the test targets have been corrupted on the public Kaggle dataset.
The purpose of testing below is to generate the Kaggle submission file based on your model's predictions, which you can submit to the competition.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# 1) Switch to eval and pull out model + dataloader
lightning_module.eval()
model = lightning_module.model  # your ClimateTransformer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

train_loader = datamodule.train_dataloader()
errors = []

# 2) Loop and record per-sample error + raw arrays
with torch.no_grad():
    for idx, batch in enumerate(train_loader):
        x, y = batch[0].to(device), batch[1].to(device)
        y_hat = model(x)
        # compute MSE map, then mean over channels + spatial dims
        mse_map = F.mse_loss(y_hat, y, reduction='none')      # [B,C,H,W]
        per_sample = mse_map.mean(dim=[1,2,3])                 # [B]
        for b in range(x.size(0)):
            errors.append({
                'err': per_sample[b].item(),
                'x'  : x[b].cpu().numpy(),     # [C,H,W]
                'y'  : y[b].cpu().numpy(),
                'ŷ' : y_hat[b].cpu().numpy()
            })

# 3) Pick top-K
K = 3
errors = sorted(errors, key=lambda d: d['err'], reverse=True)[:K]

# 4) Plot them
fig, axes = plt.subplots(K, 3, figsize=(12, 4*K), constrained_layout=True)

for i, sample in enumerate(errors):
    x, y_true, y_pred, e = sample['x'], sample['y'], sample['ŷ'], sample['err']
    # collapse multi-channel by mean:
    err_map = np.mean((y_pred - y_true)**2, axis=0)  # [H,W]

    # show channel 0 of input & truth (or pick whichever)
    axes[i,0].imshow(x[0], cmap='viridis')
    axes[i,0].set_title(f"Input #{i} (err={e:.3f})")
    axes[i,0].axis('off')

    axes[i,1].imshow(y_true[0], cmap='viridis')
    axes[i,1].set_title("Ground Truth")
    axes[i,1].axis('off')

    im = axes[i,2].imshow(err_map, cmap='magma')
    axes[i,2].set_title("Per-pixel MSE")
    axes[i,2].axis('off')
    fig.colorbar(im, ax=axes[i,2], fraction=0.046, pad=0.04)

plt.show()


In [11]:
trainer.test(lightning_module, datamodule=datamodule) 

trainer.save_checkpoint("transformer_v69.ckpt")

Creating dataset with 8097 samples...
Creating dataset with 360 samples...
Creating dataset with 360 samples...


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[TEST] tas: RMSE=290.6458, Time-Mean RMSE=290.6026, Time-Stddev MAE=3.6575
[TEST] pr: RMSE=4.3190, Time-Mean RMSE=3.8195, Time-Stddev MAE=1.4322
164.3251006995555
✅ Submission saved to: submissions/kaggle_submission_20250602_034045.csv


### Plotting Utils


In [None]:
def plot_comparison(true_xr, pred_xr, title, cmap='viridis', diff_cmap='RdBu_r', metric=None):
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    vmin = min(true_xr.min().item(), pred_xr.min().item())
    vmax = max(true_xr.max().item(), pred_xr.max().item())

    # Ground truth
    true_xr.plot(ax=axs[0], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[0].set_title(f"{title} (Ground Truth)")

    # Prediction
    pred_xr.plot(ax=axs[1], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[1].set_title(f"{title} (Prediction)")

    # Difference
    diff = pred_xr - true_xr
    abs_max = np.max(np.abs(diff))
    diff.plot(ax=axs[2], cmap=diff_cmap, vmin=-abs_max, vmax=abs_max, add_colorbar=True)
    axs[2].set_title(f"{title} (Difference) {f'- {metric:.4f}' if metric else ''}")

    plt.tight_layout()
    plt.show()


### 🖼️ Visualizing Validation Predictions

This cell loads saved validation predictions and compares them to the ground truth using spatial plots. These visualizations help you qualitatively assess your model's performance.

For each output variable (`tas`, `pr`), we visualize:

- **📈 Time-Mean Map**: The 10-year average spatial pattern for both prediction and ground truth. Helps identify long-term biases or spatial shifts.
- **📊 Time-Stddev Map**: Shows the standard deviation across time for each grid cell — useful for assessing how well the model captures **temporal variability** at each location.
- **🕓 Random Timestep Sample**: Visual comparison of prediction vs ground truth for a single month. Useful for spotting fine-grained anomalies or errors in specific months.

> These plots provide intuition beyond metrics and are useful for debugging spatial or temporal model failures.


In [None]:
# Load validation predictions
# make sure to have run the validation loop at least once
val_preds = np.load("val_preds.npy")
val_trues = np.load("val_trues.npy")

lat, lon = datamodule.get_coords()
output_vars = datamodule.output_vars
time = np.arange(val_preds.shape[0])

for i, var in enumerate(output_vars):
    pred_xr = xr.DataArray(val_preds[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
    true_xr = xr.DataArray(val_trues[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

    # --- Time Mean ---
    plot_comparison(true_xr.mean("time"), pred_xr.mean("time"), f"{var} Val Time-Mean")

    # --- Time Stddev ---
    plot_comparison(true_xr.std("time"), pred_xr.std("time"), f"{var} Val Time-Stddev", cmap="plasma")

    # --- Random timestep ---
    t_idx = np.random.randint(0, len(time))
    plot_comparison(true_xr.isel(time=t_idx), pred_xr.isel(time=t_idx), f"{var} Val Sample Timestep {t_idx}")


## 🧪 Final Notes

This notebook is meant to serve as a **baseline template** — a starting point to help you get up and running quickly with the climate emulation challenge.

You are **not** required to stick to this exact setup. In fact, we **encourage** you to:

- 🔁 Build on top of the provided `DataModule`. 
- 🧠 Use your own model architectures or training pipelines that you’re more comfortable with 
- ⚗️ Experiment with ideas  
- 🥇 Compete creatively to climb the Kaggle leaderboard  
- 🙌 Most importantly: **have fun** and **learn as much as you can** along the way

This challenge simulates a real-world scientific problem, and there’s no single "correct" approach — so be curious, experiment boldly, and make it your own!


# Load And Test Pretrained Model

In [41]:
datamodule2 = ClimateDataModule(**config["data"])

lat_weights = datamodule.get_lat_weights_torch()

model2 = TemporalHierarchicalTransformerPlus(
    in_channels=len(config["data"]["input_vars"]),
    out_channels=len(config["data"]["output_vars"]),
    dropout=0.1,
    seq_len=config['data']['seq_len'],
    patch_sizes = (2,4),
)


restored = ClimateEmulationModule.load_from_checkpoint(
    "/kaggle/working/lightning_logs/version_5/checkpoints/best-final-score-53-0.7920.ckpt", # change to path where checkpoint is stored
    model=model2,
    lat_weights=lat_weights,
    learning_rate=config["training"]["lr"],
)

restored.eval()
restored.freeze()
model2 = restored.model

trainer2 = pl.Trainer(**config["trainer"])

restored.normalizer = datamodule2.normalizer

trainer2.test(restored, datamodule=datamodule2)

Creating dataset with 8098 samples...
Creating dataset with 360 samples...
Creating dataset with 360 samples...


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[TEST] tas: RMSE=290.8431, Time-Mean RMSE=290.8013, Time-Stddev MAE=3.5732
[TEST] pr: RMSE=4.3201, Time-Mean RMSE=3.8042, Time-Stddev MAE=1.4468
164.39006621576434
✅ Submission saved to: submissions/kaggle_submission_20250602_061250.csv


[{'test/tas/rmse': 290.8431091308594,
  'test/tas/time_mean_rmse': 290.80126953125,
  'test/tas/time_std_mae': 3.573169469833374,
  'test/pr/rmse': 4.3201189041137695,
  'test/pr/time_mean_rmse': 3.804246664047241,
  'test/pr/time_std_mae': 1.4468493461608887,
  'final_score': 164.3900604248047}]

In [43]:
datamodule2 = ClimateDataModule(**config["data"])

lat_weights = datamodule.get_lat_weights_torch()

model2 = TemporalHierarchicalTransformerPlus(
    in_channels=len(config["data"]["input_vars"]),
    out_channels=len(config["data"]["output_vars"]),
    dropout=0.1,
    seq_len=config['data']['seq_len'],
    patch_sizes = (2,4),
)


restored = ClimateEmulationModule.load_from_checkpoint(
    "/kaggle/working/lightning_logs/version_5/checkpoints/best-final-score-37-0.8018.ckpt", # change to path where checkpoint is stored
    model=model2,
    lat_weights=lat_weights,
    learning_rate=config["training"]["lr"],
)

restored.eval()
restored.freeze()
model2 = restored.model

trainer2 = pl.Trainer(**config["trainer"])

restored.normalizer = datamodule2.normalizer

trainer2.test(restored, datamodule=datamodule2)



Creating dataset with 8098 samples...
Creating dataset with 360 samples...
Creating dataset with 360 samples...


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[TEST] tas: RMSE=290.8242, Time-Mean RMSE=290.7821, Time-Stddev MAE=3.5295
[TEST] pr: RMSE=4.3497, Time-Mean RMSE=3.8409, Time-Stddev MAE=1.4385
164.37439953041343
✅ Submission saved to: submissions/kaggle_submission_20250602_061854.csv


[{'test/tas/rmse': 290.82415771484375,
  'test/tas/time_mean_rmse': 290.7821350097656,
  'test/tas/time_std_mae': 3.529449939727783,
  'test/pr/rmse': 4.349699020385742,
  'test/pr/time_mean_rmse': 3.8409433364868164,
  'test/pr/time_std_mae': 1.4385148286819458,
  'final_score': 164.37440490722656}]

In [42]:
!ls /kaggle/working/lightning_logs/version_5/checkpoints/

best-final-score-37-0.8018.ckpt  best-final-score-75-0.8034.ckpt
best-final-score-53-0.7920.ckpt


In [45]:
!cp /kaggle/working/lightning_logs/version_5/checkpoints/best-final-score-53-0.7920.ckpt /kaggle/working/