In [1]:
import numpy as np
import pandas as pd
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

from temporal_data_module import TemporalDataModule

In [2]:
# Location of the processed zarr store; the config cells below reuse this path.
data_path = "data/processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr"
# Open the store as an xarray Dataset for ad-hoc inspection.
data = xr.open_zarr(store=data_path)

### Plotting

In [3]:
def plot_batch_losses(train_losses, val_losses):
    """Draw per-batch training and validation loss curves on a single figure.

    Args:
        train_losses: Sequence of training loss values, plotted against index.
        val_losses: Sequence of validation loss values, plotted against index.
    """
    line_style = {"linewidth": 0.8, "alpha": 0.8}
    curves = (
        (train_losses, "Train Loss (per batch)"),
        (val_losses, "Validation Loss (per batch)"),
    )

    plt.figure(figsize=(10, 5))

    # Both curves share the same thin, slightly transparent line style.
    for series, label in curves:
        plt.plot(series, label=label, **line_style)

    plt.xlabel("Batch (time)")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Over Time")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_epoch_losses(train_epoch_losses, val_epoch_losses):
    """Plot per-epoch training and validation losses against a 1-based epoch axis.

    Args:
        train_epoch_losses: Sequence of mean training losses, one per epoch.
            Its length defines the epoch axis.
        val_epoch_losses: Sequence of mean validation losses; must have the
            same length as ``train_epoch_losses`` for matplotlib to accept it.
    """
    n_epochs = len(train_epoch_losses)
    epoch_axis = range(1, n_epochs + 1)
    series = (
        (train_epoch_losses, 'Train Loss (per epoch)', 'o'),
        (val_epoch_losses, 'Validation Loss (per epoch)', 's'),
    )

    plt.figure(figsize=(8, 5))
    for values, label, marker in series:
        plt.plot(epoch_axis, values, label=label, marker=marker)

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss per Epoch")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_top_mse_grid_points(y_pred_np, y_true_np, lat_coords, lon_coords, var_names, top_k=3):
    """Scatter the grid cells with the largest squared error for each output variable.

    For every channel, the ``top_k`` largest squared errors over all
    (time, row, column) entries are located and their (column, row) indices
    are drawn on an H x W canvas, colour-coded per variable.

    Args:
        y_pred_np: Predictions; must be broadcast-compatible with ``y_true_np``.
        y_true_np: Targets, numpy array of shape [T, C, H, W].
        lat_coords: Unused by the plot (axes are index-based); kept so
            existing call sites keep working.
        lon_coords: Unused by the plot; kept for the same reason.
        var_names: Sequence of variable names indexed by channel (length >= C).
        top_k: Number of worst points to mark per variable. Values above
            T * H * W are clamped; values <= 0 mark no points.
    """
    T, C, H, W = y_true_np.shape
    squared_error = (y_pred_np - y_true_np) ** 2

    # Clamp k: argpartition raises when k exceeds the array size, and a slice
    # of ``[-0:]`` would silently select *every* element instead of none.
    k = max(0, min(int(top_k), T * H * W))

    top_coords = []
    if k > 0:
        for c in range(C):
            var = var_names[c]
            var_error = squared_error[:, c, :, :]
            flat_error = var_error.reshape(-1)

            # Partial selection of the k largest entries, then order them
            # from largest to smallest error.
            worst_idx = np.argpartition(flat_error, -k)[-k:]
            worst_idx = worst_idx[np.argsort(flat_error[worst_idx])[::-1]]

            for idx in worst_idx:
                t, h, w = np.unravel_index(idx, (T, H, W))
                top_coords.append({
                    "var": var,
                    "t": int(t),
                    "h": int(h),
                    "w": int(w),
                    # Squared error of this single (t, h, w) entry.
                    "mse": var_error[t, h, w].item()
                })

    plt.figure(figsize=(10, 6))
    # Faint blank canvas so the axes span the full H x W grid.
    plt.imshow(np.zeros((H, W)), cmap='Greys', alpha=0.1)

    # Cycle through the palette so more than four variables do not raise a
    # KeyError when their colour is looked up.
    palette = ['red', 'blue', 'green', 'orange']
    color_map = {var: palette[i % len(palette)] for i, var in enumerate(var_names)}

    for entry in top_coords:
        plt.scatter(entry["w"], entry["h"], color=color_map[entry["var"]],
                    s=100, edgecolors='black')

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label=var,
               markerfacecolor=color_map[var], markersize=10, markeredgecolor='black')
        for var in var_names
    ]
    plt.legend(handles=legend_elements)
    plt.title(f"Top {top_k} Highest MSE Grid Locations per Output Variable")
    plt.xlabel("Longitude Index (W)")
    plt.ylabel("Latitude Index (H)")
    plt.xlim(0, W - 1)
    # Inverted y-limits put row 0 at the top, matching imshow's orientation.
    plt.ylim(H - 1, 0)
    plt.grid(True)
    plt.tight_layout()
    plt.show()


### Model

In [4]:
def area_weighted_climate_loss(y_pred, y_true, latitudes, var_type="tas"):
    """Differentiable, latitude-weighted loss combining three error terms.

    The loss is ``0.1 * RMSE + 1.0 * RMSE(batch mean) + w_std * MAE(batch std)``
    where every spatial reduction is weighted by ``cos(latitude)`` and
    ``w_std`` is 1.0 for ``"tas"`` and 0.75 for ``"pr"``.

    Args:
        y_pred: Predictions, tensor of shape [B, C, H, W].
        y_true: Targets, tensor of shape [B, C, H, W].
        latitudes: 1D tensor of shape [H] with the latitude (degrees) of each
            row. It is moved to ``y_true``'s device and dtype before use.
        var_type: ``"tas"`` or ``"pr"``; selects the weight of the std term.

    Returns:
        Scalar tensor with the combined loss.

    Raises:
        ValueError: If ``var_type`` is neither ``"tas"`` nor ``"pr"``.
    """
    if var_type == "tas":
        std_weight = 1.0
    elif var_type == "pr":
        std_weight = 0.75
    else:
        raise ValueError("var_type must be 'tas' or 'pr'")

    B, C, H, W = y_true.shape

    # 1. Latitude weights [H] -> [1, 1, H, 1].
    # Normalise to MEAN 1 (not sum 1): the reductions below use ``.mean()``,
    # so mean-1 weights yield a true weighted average. Sum-1 weights would
    # shrink every term by an extra factor of 1/H and skew the balance between
    # the sqrt-ed RMSE terms and the linear MAE term.
    latitudes = latitudes.to(device=y_true.device, dtype=y_true.dtype)
    lat_weights = torch.cos(torch.deg2rad(latitudes))
    lat_weights = lat_weights / lat_weights.mean()
    lat_weights = lat_weights.view(1, 1, H, 1)  # broadcast to [B, C, H, W]

    # 2. Area-weighted MSE over every sample and grid cell.
    mse = F.mse_loss(y_pred, y_true, reduction='none')  # [B, C, H, W]
    area_weighted_mse = (mse * lat_weights).mean()

    # 3. Error of the batch-mean field (mean over dim 0, then spatial MSE).
    pred_mean = y_pred.mean(dim=0)  # [C, H, W]
    true_mean = y_true.mean(dim=0)  # [C, H, W]
    mean_mse = ((pred_mean - true_mean) ** 2 * lat_weights[0]).mean()

    # 4. Absolute error of the per-location std over dim 0.
    # The sample std of a single element is NaN, so the term is dropped when
    # the batch holds only one sample.
    if B > 1:
        pred_std = y_pred.std(dim=0)  # [C, H, W]
        true_std = y_true.std(dim=0)  # [C, H, W]
        std_mae = (torch.abs(pred_std - true_std) * lat_weights[0]).mean()
    else:
        std_mae = y_pred.new_zeros(())

    # 5. Combine with the variable-specific std weight.
    return 0.1 * torch.sqrt(area_weighted_mse) + 1.0 * torch.sqrt(mean_mse) + std_weight * std_mae


In [5]:
class ConvBlock3D(nn.Module):
    """Two stacked Conv3d -> BatchNorm3d -> LeakyReLU(0.2) -> Dropout3d stages.

    Both convolutions use kernel size 3 with padding 1, so the block changes
    only the channel count and leaves the (T, H, W) extent untouched.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        dropout: Drop probability passed to each ``nn.Dropout3d``.
    """

    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout3d(dropout),

            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout3d(dropout),
        )

    def forward(self, x):
        """Apply the two conv stages to ``x`` ([B, C, T, H, W])."""
        return self.block(x)

class ClimateUNet3D(nn.Module):
    """Three-level 3D U-Net that predicts a field for the last time step.

    Pooling and transposed convolutions act only on the two spatial axes
    (kernel/stride ``(1, 2, 2)``); the time axis keeps its length throughout
    and is mixed only by the 3x3x3 convolutions inside ``ConvBlock3D``.

    Args:
        in_channels: Number of input channels per time step.
        out_channels: Number of predicted channels.
        base_channels: Channel width of the first encoder level; deeper levels
            use 2x, 4x and 8x this value.
        dropout: Drop probability forwarded to every ``ConvBlock3D``.
    """

    def __init__(self, in_channels=5, out_channels=2, base_channels=32, dropout=0.1):
        super().__init__()

        # Encoder
        self.enc1 = ConvBlock3D(in_channels, base_channels, dropout)
        self.pool1 = nn.MaxPool3d((1, 2, 2))

        self.enc2 = ConvBlock3D(base_channels, base_channels * 2, dropout)
        self.pool2 = nn.MaxPool3d((1, 2, 2))

        self.enc3 = ConvBlock3D(base_channels * 2, base_channels * 4, dropout)
        self.pool3 = nn.MaxPool3d((1, 2, 2))

        # Bottleneck
        self.bottleneck = ConvBlock3D(base_channels * 4, base_channels * 8, dropout)

        # Decoder: each dec block receives the upsampled features concatenated
        # with the matching encoder output, hence the doubled input width.
        self.up3 = nn.ConvTranspose3d(base_channels * 8, base_channels * 4, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.dec3 = ConvBlock3D(base_channels * 8, base_channels * 4, dropout)

        self.up2 = nn.ConvTranspose3d(base_channels * 4, base_channels * 2, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.dec2 = ConvBlock3D(base_channels * 4, base_channels * 2, dropout)

        self.up1 = nn.ConvTranspose3d(base_channels * 2, base_channels, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.dec1 = ConvBlock3D(base_channels * 2, base_channels, dropout)

        # Final output
        self.out_conv = nn.Conv3d(base_channels, out_channels, kernel_size=1)

    @staticmethod
    def _pad_to_match(upsampled, skip):
        """Zero-pad ``upsampled`` on the bottom/right so its (H, W) equals ``skip``'s.

        Max-pooling floors odd sizes, so after 2x upsampling a feature map can
        be one row/column short of the encoder output it is concatenated with.
        When the sizes already agree the tensor is returned unchanged.
        """
        pad_h = skip.shape[-2] - upsampled.shape[-2]
        pad_w = skip.shape[-1] - upsampled.shape[-1]
        if pad_h or pad_w:
            # F.pad lists pads from the last axis backwards: (W_left, W_right, H_top, H_bottom).
            upsampled = F.pad(upsampled, (0, pad_w, 0, pad_h))
        return upsampled

    def forward(self, x):
        """Run the U-Net and return the prediction for the last time step.

        Args:
            x: Tensor laid out as [B, t, C, H, W]. H and W do not need to be
                multiples of 8; odd intermediate sizes are zero-padded before
                each skip concatenation.

        Returns:
            Tensor of shape [B, out_channels, H, W].
        """
        x = x.permute(0, 2, 1, 3, 4)  # [B, C, t, H, W]

        # Encoder
        e1 = self.enc1(x)
        p1 = self.pool1(e1)

        e2 = self.enc2(p1)
        p2 = self.pool2(e2)

        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        # Bottleneck
        b = self.bottleneck(p3)

        # Decoder
        u3 = self._pad_to_match(self.up3(b), e3)
        u3 = torch.cat([u3, e3], dim=1)
        d3 = self.dec3(u3)

        u2 = self._pad_to_match(self.up2(d3), e2)
        u2 = torch.cat([u2, e2], dim=1)
        d2 = self.dec2(u2)

        u1 = self._pad_to_match(self.up1(d2), e1)
        u1 = torch.cat([u1, e1], dim=1)
        d1 = self.dec1(u1)

        out = self.out_conv(d1)
        # Keep only the final index of the time axis: [B, C_out, t, H, W] -> [B, C_out, H, W].
        out = out[:, :, -1]

        return out


In [6]:
class TrainingMetrics(Callback):
    """Lightning callback that records the mean train/validation loss per epoch.

    Per-batch losses are buffered in ``_train_losses`` / ``_val_losses`` and
    averaged into ``train_epoch_losses`` / ``val_epoch_losses`` when the
    corresponding epoch ends, after which the buffer is cleared.
    """

    def __init__(self):
        super().__init__()
        # One averaged value per completed epoch.
        self.train_epoch_losses = []
        self.val_epoch_losses = []

        # Per-batch buffers for the epoch currently in progress.
        self._train_losses = []
        self._val_losses = []

    @staticmethod
    def _loss_value(outputs):
        """Return the step's loss as a Python float, or ``None`` if the step returned nothing."""
        if outputs is None:
            return None
        loss = outputs['loss'] if isinstance(outputs, dict) else outputs
        return loss.item()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Buffer the loss of the training batch that just finished."""
        value = self._loss_value(outputs)
        if value is not None:
            self._train_losses.append(value)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Buffer the loss of the validation batch that just finished."""
        # The trainer runs a few validation batches before training starts
        # (the sanity check). Recording them would add a spurious leading
        # entry to ``val_epoch_losses`` and shift it against the train curve.
        if trainer.sanity_checking:
            return
        value = self._loss_value(outputs)
        if value is not None:
            self._val_losses.append(value)

    def on_train_epoch_end(self, trainer, pl_module):
        """Average and store the buffered training losses, then reset the buffer."""
        if self._train_losses:
            avg = sum(self._train_losses) / len(self._train_losses)
            self.train_epoch_losses.append(avg)
            self._train_losses.clear()

    def on_validation_epoch_end(self, trainer, pl_module):
        """Average and store the buffered validation losses, then reset the buffer."""
        if self._val_losses:
            avg = sum(self._val_losses) / len(self._val_losses)
            self.val_epoch_losses.append(avg)
            self._val_losses.clear()


### Training

In [7]:
# Data configuration for the temperature ("tas") model.
config_temp = OmegaConf.create({
    "data": dict(
        path=data_path,
        input_vars=["CO2", "CH4", "BC", "rsdt"],
        output_vars=["tas"],  # temperature is the only target here
        train_ssps=["ssp126", "ssp370", "ssp585"],
        test_ssp="ssp245",
        target_member_id=0,
        batch_size=32,
        num_workers=39,
    ),
})

In [8]:
# Data configuration for the precipitation ("pr") model.
config_precip = OmegaConf.create({
    "data": dict(
        path=data_path,
        input_vars=["CO2", "CH4", "BC", "rsdt"],
        output_vars=["pr"],  # precipitation is the only target here
        train_ssps=["ssp126", "ssp370", "ssp585"],
        test_ssp="ssp245",
        target_member_id=0,
        batch_size=32,
        num_workers=39,
    ),
})

In [9]:
# Build the temperature data module from the "data" section of config_temp
# (each key is passed as a keyword argument) and run its setup step.
data_module_temp = TemporalDataModule(**config_temp.data)
data_module_temp.setup()

In [10]:
# Build the precipitation data module from the "data" section of config_precip
# (each key is passed as a keyword argument) and run its setup step.
data_module_precip = TemporalDataModule(**config_precip.data)
data_module_precip.setup()

In [11]:
# Temperature model: a 3D U-Net wrapped in the Lightning training module.
# NOTE(review): the extra "+ 1" input channel is presumably added by
# TemporalDataModule on top of the configured input variables; confirm.
in_channels_temp = 1 + len(config_temp.data.input_vars)
out_channels_temp = len(config_temp.data.output_vars)

model_temp = ClimateUNet3D(
    in_channels=in_channels_temp,
    out_channels=out_channels_temp,
    base_channels=64,
)

lightning_module_temp = ClimateEmulationModule(model=model_temp, learning_rate=5e-4)

In [None]:
# Precipitation model: a 3D U-Net wrapped in the Lightning training module.
# NOTE(review): the extra "+ 1" input channel is presumably added by
# TemporalDataModule on top of the configured input variables; confirm.
in_channels_precip = 1 + len(config_precip.data.input_vars)
out_channels_precip = len(config_precip.data.output_vars)

model_precip = ClimateUNet3D(
    in_channels=in_channels_precip,
    out_channels=out_channels_precip,
    base_channels=64,
)

lightning_module_precip = ClimateEmulationModule(model=model_precip, learning_rate=5e-4)

In [None]:
# Callback that collects the per-epoch train/validation losses of this run.
logger_temp = TrainingMetrics()
# Train for at most 10 epochs on whichever accelerator Lightning selects.
trainer_temp = pl.Trainer(
    max_epochs=10,
    accelerator='auto',
    callbacks=[logger_temp]
)

# Fit the temperature model on the temperature data module.
trainer_temp.fit(lightning_module_temp, data_module_temp)

In [None]:
# Callback that collects the per-epoch train/validation losses of this run.
logger_precip = TrainingMetrics()
# Train for at most 10 epochs on whichever accelerator Lightning selects.
trainer_precip = pl.Trainer(
    max_epochs=10,
    accelerator='auto',
    callbacks=[logger_precip]
)

# Fit the precipitation model on the precipitation data module.
trainer_precip.fit(lightning_module_precip, data_module_precip)


In [None]:
# Temperature run: trim both loss histories to a common length, then plot.
train_losses, val_losses = logger_temp.train_epoch_losses, logger_temp.val_epoch_losses
min_len = min(map(len, (train_losses, val_losses)))

plot_epoch_losses(train_losses[:min_len], val_losses[:min_len])

# trainer_temp.test(lightning_module_temp, data_module_temp)

In [None]:
# Precipitation run: trim both loss histories to a common length, then plot.
train_losses, val_losses = logger_precip.train_epoch_losses, logger_precip.val_epoch_losses
min_len = min(map(len, (train_losses, val_losses)))

plot_epoch_losses(train_losses[:min_len], val_losses[:min_len])

# trainer_precip.test(lightning_module_precip, data_module_precip)

### Produce Submission

In [None]:
def _predict_batches(model, dataloader, device):
    """Run ``model`` over every batch of ``dataloader`` and stack the outputs.

    Args:
        model: Module called as ``model(x)`` on each input batch.
        dataloader: Iterable yielding ``(x, y)`` pairs; ``y`` is ignored.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Numpy array with all batch outputs concatenated along axis 0.
    """
    outputs = []
    with torch.no_grad():
        for x, _ in dataloader:
            # Keep the batch axis as-is. Squeezing axis 0 would drop it for a
            # batch of one (e.g. a trailing partial batch), which would break
            # or corrupt the concatenation along axis 0 below.
            outputs.append(model(x.to(device)).cpu().numpy())
    return np.concatenate(outputs, axis=0)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Temperature ---
lightning_module_temp.eval()
model_temp = lightning_module_temp.model.to(device)

y_pred_np_temp = _predict_batches(model_temp, data_module_temp.test_dataloader(), device)
# Map the outputs back through the data module's normalizer.
y_pred_output_temp = data_module_temp.normalizer.inverse_transform_output(y_pred_np_temp)

lat_coords, lon_coords = data_module_temp.get_coords()
# Time axis is a plain 0..N-1 index over the stacked predictions.
time_coords = np.arange(y_pred_np_temp.shape[0])
var_names_temp = config_temp.data['output_vars']

submission_temp = convert_predictions_to_kaggle_format(
    y_pred_output_temp, time_coords, lat_coords, lon_coords, var_names_temp
)

submission_temp.to_csv("submission_temp.csv", index=False)

# --- Precipitation ---
lightning_module_precip.eval()
model_precip = lightning_module_precip.model.to(device)

y_pred_np_precip = _predict_batches(model_precip, data_module_precip.test_dataloader(), device)
# Map the outputs back through the data module's normalizer.
y_pred_output_precip = data_module_precip.normalizer.inverse_transform_output(y_pred_np_precip)

lat_coords, lon_coords = data_module_precip.get_coords()
time_coords = np.arange(y_pred_np_precip.shape[0])
var_names_precip = config_precip.data['output_vars']

submission_precip = convert_predictions_to_kaggle_format(
    y_pred_output_precip, time_coords, lat_coords, lon_coords, var_names_precip
)

submission_precip.to_csv("submission_precip.csv", index=False)


In [None]:
# Stack the temperature rows on top of the precipitation rows and write the
# combined submission file (the DataFrame index is not written).
submission_final = pd.concat(objs=[submission_temp, submission_precip], axis="index")
submission_final.to_csv("submission_final.csv", index=False)
