In [1]:
import os

import numpy as np
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from lightning.pytorch.callbacks import ModelCheckpoint, Callback

from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

In [2]:
# Preprocessed Zarr store consumed by the data module (path is relative to the
# notebook's working directory).
data_path = "data/processed_data_cse151b_v2_corrupted_ssp245.zarr"

In [3]:
# DataLoader worker count: keep the original 39, but never exceed the number of
# CPUs on this host (os.cpu_count() may return None, hence the `or 1`).
NUM_WORKERS = min(39, os.cpu_count() or 1)

# Experiment configuration; the `data` section is unpacked as keyword arguments
# into ClimateEmulationDataModule further down.
config = OmegaConf.create({
    "data": {
        "path": data_path,
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "target_member_id": 0,
        "batch_size": 32,
        "num_workers": NUM_WORKERS,
    }
})

In [4]:
# Build the data module from the `data` section of the config, then run its
# setup() hook so the dataloaders, normalizer and coordinates are available.
data_module = ClimateEmulationDataModule(**config["data"])
data_module.setup()

In [17]:
class Climate3DCNN(nn.Module):
    """3D-CNN mapping stacked input fields to gridded output fields.

    The ``input_channels`` input variables (dim 1 of the input; presumably the
    config's ``input_vars`` -- verify against the data module) are treated as
    the depth axis of a single-channel 3D volume. Two 3x3x3 convolutions mix
    neighbouring variables and grid cells. A Conv3d whose depth kernel spans
    all ``input_channels`` then collapses that axis, and a small 2D head
    produces ``output_channels`` maps.

    Args:
        input_channels: Number of input variables stacked along dim 1 of the
            input. Sets the depth of the collapsing Conv3d kernel.
        output_channels: Number of predicted output fields.

    Shapes:
        input:  ``[B, input_channels, H, W]``
        output: ``[B, output_channels, H, W]``
    """

    def __init__(self, input_channels=5, output_channels=2):
        super().__init__()
        self.input_channels = input_channels

        # padding=1 with 3x3x3 kernels keeps (depth, H, W) unchanged.
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
        )

        # decoder[0] is a Conv3d applied to the 5D encoder output; decoder[1:]
        # is the 2D head applied after the collapsed depth axis is squeezed.
        # The Sequential layout (state_dict keys decoder.0/.2/.4) is kept so
        # existing checkpoints still load.
        self.decoder = nn.Sequential(
            # Depth kernel == input_channels reduces the variable axis to size 1.
            nn.Conv3d(32, 16, kernel_size=(input_channels, 1, 1)),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, output_channels, kernel_size=1),
        )

    def forward(self, x):
        """Predict output fields.

        Args:
            x: Tensor of shape ``[B, input_channels, H, W]``.

        Returns:
            Tensor of shape ``[B, output_channels, H, W]``.

        Raises:
            ValueError: if ``x`` is not 4D or its dim 1 != ``input_channels``.
        """
        if x.dim() != 4 or x.shape[1] != self.input_channels:
            raise ValueError(
                f"expected input of shape [B, {self.input_channels}, H, W], "
                f"got {tuple(x.shape)}"
            )

        x = x.unsqueeze(1)          # [B, 1, C_in, H, W]
        x = self.encoder(x)         # [B, 32, C_in, H, W]
        x = self.decoder[0](x)      # [B, 16, 1, H, W]
        x = x.squeeze(2)            # [B, 16, H, W]
        return self.decoder[1:](x)  # [B, output_channels, H, W]


In [18]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) before the weights are
# initialised, so Restart & Run All gives the same initialisation and shuffling.
# GPU kernels may still be nondeterministic.
SEED = 42
LEARNING_RATE = 5e-4
pl.seed_everything(SEED, workers=True)

model = Climate3DCNN()
lightning_module = ClimateEmulationModule(
    model=model,
    learning_rate=LEARNING_RATE,
)

In [19]:
# Fixed-length run of 20 epochs; accelerator selection is left to Lightning.
trainer = pl.Trainer(accelerator="auto", max_epochs=20)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [20]:
trainer.fit(model=lightning_module, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type         | Params | Mode 
---------------------------------------------------
0 | model     | Climate3DCNN | 21.7 K | train
1 | criterion | MSELoss      | 0      | train
---------------------------------------------------
21.7 K    Trainable params
0         Non-trainable params
21.7 K    Total params
0.087     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


 92/92 [00:04<00:00, 22.10it/s, v_num=94, train/loss=0.461, val/loss=0.712, val/tas/avg/monthly_rmse=10.60, val/tas/time_mean_rmse=6.810, val/tas/time_stddev_mae=4.390, val/pr/avg/monthly_rmse=3.450, val/pr/time_mean_rmse=1.910, val/pr/time_stddev_mae=1.550]

In [21]:
# Disabled: local estimate of a weighted leaderboard-style score from the
# metrics returned by trainer.test (per-variable weights x per-metric weights).
# NOTE(review): reshape(2, 3).T assumes the returned dict holds exactly six
# values ordered variable-major (3 metrics for one var, then 3 for the other),
# matching output_vars — TODO confirm before re-enabling.
# test = trainer.test(lightning_module, data_module)[0]
# test_vals = np.array([val for val in test.values()]).reshape(2, 3).T
# var_weights = [0.5, 0.5]
# metric_weights = np.array([
#     [0.1, 1.0, 1.0],
#     [0.1, 1.0, 0.75]
# ]).T
# np.dot((test_vals * metric_weights).sum(axis=0), var_weights)

## Produce Test Predictions

In [22]:
# Run the trained network over the test split and map predictions back to
# physical units. The model may not be on `device` after fit, so move it
# explicitly before inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

var_names = config.data['output_vars']

all_preds = []
with torch.no_grad():
    for x, _ in data_module.test_dataloader():
        # Do not squeeze the output: a final batch of size 1 would lose its
        # batch axis and break the concatenation below.
        y_pred = model(x.to(device))
        all_preds.append(y_pred.cpu().numpy())

y_pred_np = np.concatenate(all_preds, axis=0)  # [N, n_output_vars, H, W]
assert y_pred_np.ndim == 4 and y_pred_np.shape[1] == len(var_names), y_pred_np.shape

y_pred_output = data_module.normalizer.inverse_transform_output(y_pred_np)

lat_coords, lon_coords = data_module.get_coords()
time_coords = np.arange(y_pred_np.shape[0])

In [None]:
# Convert the de-normalised predictions and their coordinates into the
# submission table written out below.
submission_df = convert_predictions_to_kaggle_format(
    y_pred_output,
    time_coords,
    lat_coords,
    lon_coords,
    var_names,
)

In [None]:
# Sanity-check the submission size; the last expression is displayed by Jupyter.
submission_df.shape

In [None]:
# Write the submission without the DataFrame index column.
SUBMISSION_PATH = "submission.csv"
submission_df.to_csv(SUBMISSION_PATH, index=False)