In [1]:
import numpy as np
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from lightning.pytorch.callbacks import ModelCheckpoint, Callback


from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

In [2]:
# Location of the processed (corrupted-ssp245 variant) zarr store used below.
data_path = "data/processed_data_cse151b_v2_corrupted_ssp245.zarr"

In [3]:
# Experiment configuration (data section only): which variables are read as
# inputs / predicted as outputs, which SSP scenarios feed training vs. test,
# the ensemble member, and DataLoader settings.
config = OmegaConf.create(
    dict(
        data=dict(
            path=data_path,
            input_vars=["CO2", "SO2", "CH4", "BC", "rsdt"],
            output_vars=["tas", "pr"],
            train_ssps=["ssp126", "ssp370", "ssp585"],
            test_ssp="ssp245",
            target_member_id=0,
            batch_size=32,
            num_workers=39,
        )
    )
)

In [4]:
# Build the data module from the `data` config section and run its setup step.
data_module = ClimateEmulationDataModule(**config["data"])
data_module.setup()

In [5]:
class ClimateRNNNet(nn.Module):
    """CNN encoder -> per-pixel RNN over input channels -> CNN decoder.

    Each of the ``input_channels`` input maps is encoded independently by the
    same small 2-D CNN. At every grid cell the resulting per-channel feature
    vectors are read as a length-``input_channels`` sequence by an ``nn.RNN``
    (the channels play the role of sequence steps; presumably these are the
    config's ``input_vars`` — confirm against the data module). The RNN's last
    hidden state is projected back to a feature map and decoded into
    ``output_channels`` maps of the same spatial size.

    Args:
        input_channels: number of channels ``C`` that ``forward`` expects.
        conv_channels: width of the shared per-channel CNN encoder.
        rnn_hidden: hidden size of the RNN.
        output_channels: number of maps produced by the decoder.
        dropout_prob: dropout probability after the RNN and in the decoder.
    """

    def __init__(self, input_channels=5, conv_channels=16, rnn_hidden=64, output_channels=2,
                 dropout_prob=0.2):
        super().__init__()
        self.input_channels = input_channels
        self.rnn_hidden = rnn_hidden

        # Shared encoder: applied to one single-channel slice at a time, hence
        # in_channels=1.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, conv_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),
            nn.Conv2d(conv_channels, conv_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU()
        )

        self.rnn = nn.RNN(input_size=conv_channels, hidden_size=rnn_hidden, batch_first=True)

        # Dropout on the RNN's final hidden state.
        self.dropout = nn.Dropout(dropout_prob)

        # Project the RNN output back to conv_channels features per pixel.
        self.proj = nn.Sequential(
            nn.Linear(rnn_hidden, conv_channels),
            nn.BatchNorm1d(conv_channels),
            nn.ReLU()
        )

        # Decoder producing the final output maps.
        self.decoder = nn.Sequential(
            nn.Conv2d(conv_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(dropout_prob),
            nn.Conv2d(32, output_channels, kernel_size=1)
        )

    def forward(self, x):
        """Map ``x`` of shape [B, input_channels, H, W] to [B, output_channels, H, W].

        Raises:
            ValueError: if ``x`` does not have ``self.input_channels`` channels.
        """
        B, C, H, W = x.shape
        # Explicit check (not `assert`) so it is not stripped under `python -O`.
        if C != self.input_channels:
            raise ValueError(
                f"ClimateRNNNet expected {self.input_channels} input channels, got {C}"
            )

        # Encode each channel separately with the shared CNN. Kept as a loop
        # (rather than folding channels into the batch) so BatchNorm batch
        # statistics are computed per channel, as in the original design.
        encoded = [self.encoder(x[:, c:c + 1]) for c in range(C)]  # C x [B, conv, H, W]

        # Stack along the channel/sequence axis, then make every pixel its own
        # sequence for the RNN.
        x_seq = torch.stack(encoded, dim=1)                          # [B, C, conv, H, W]
        x_seq = x_seq.permute(0, 3, 4, 1, 2).reshape(B * H * W, C, -1)  # [B*H*W, C, conv]

        # No explicit h0: nn.RNN defaults it to zeros with the input's dtype and
        # device, so the model also works in float64 / half precision.
        rnn_out, _ = self.rnn(x_seq)                                 # [B*H*W, C, rnn_hidden]
        final = self.dropout(rnn_out[:, -1, :])                      # [B*H*W, rnn_hidden]

        # Back to a [B, conv, H, W] feature map, then decode.
        out = self.proj(final)                                       # [B*H*W, conv]
        out = out.view(B, H, W, -1).permute(0, 3, 1, 2)              # [B, conv, H, W]
        return self.decoder(out)                                     # [B, output_channels, H, W]


In [6]:
# Seed Python / NumPy / torch RNGs (and DataLoader workers) before building the
# model, so weight initialisation and the subsequent training run are
# reproducible across kernel restarts.
pl.seed_everything(42, workers=True)

model = ClimateRNNNet()
lightning_module = ClimateEmulationModule(
    model=model,
    learning_rate=5e-4,
)

In [7]:
# Lightning trainer: fixed 20-epoch budget, accelerator chosen automatically.
trainer = pl.Trainer(accelerator="auto", max_epochs=20)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [8]:
trainer.fit(model=lightning_module, datamodule=data_module)

2025-05-16 00:11:32.887910: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747354292.914980    4741 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747354292.923460    4741 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-16 00:11:32.953109: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type          | Params | Mode 
-----------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


In [9]:
# Disabled: local scoring on the test split. If re-enabled, this runs
# `trainer.test`, reshapes the returned metric values into a (3, 2) array
# (presumably 3 metrics x 2 output variables — verify the key order of the
# dict the LightningModule logs), multiplies by per-metric weights, sums over
# metrics and combines the two variables with equal weight into one score.
# NOTE(review): commented-out code in a finished notebook — re-enable or delete.
# test = trainer.test(lightning_module, data_module)[0]
# test_vals = np.array([val for val in test.values()]).reshape(2, 3).T
# var_weights = [0.5, 0.5]
# metric_weights = np.array([
#     [0.1, 1.0, 1.0],
#     [0.1, 1.0, 0.75]
# ]).T
# np.dot((test_vals * metric_weights).sum(axis=0), var_weights)

## Produce Test Predictions

In [10]:
# Run the trained network over the test split in eval mode (no dropout, frozen
# BatchNorm statistics), collect predictions on the CPU and map them back to
# physical units with the data module's normalizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

all_preds = []
with torch.no_grad():
    for x, _ in data_module.test_dataloader():
        # Keep the batch axis. The previous `.squeeze(0)` would drop it for a
        # final batch of size one and break the concatenation below.
        y_pred = model(x.to(device))
        all_preds.append(y_pred.cpu().numpy())

y_pred_np = np.concatenate(all_preds, axis=0)
y_pred_output = data_module.normalizer.inverse_transform_output(y_pred_np)

lat_coords, lon_coords = data_module.get_coords()
time_coords = np.arange(y_pred_np.shape[0])
var_names = config.data['output_vars']
var_names = config.data['output_vars']

In [11]:
# Convert the inverse-transformed predictions plus their coordinates into the
# submission table using the project helper.
submission_df = convert_predictions_to_kaggle_format(y_pred_output, time_coords, lat_coords, lon_coords, var_names)

In [12]:
# Sanity check of the submission size, shown as the cell's output value.
submission_df.shape

(2488320, 2)


In [13]:
# Write the submission without the DataFrame index column.
submission_df.to_csv(path_or_buf="submission.csv", index=False)