In [2]:
import numpy as np
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from lightning.pytorch.callbacks import ModelCheckpoint

from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

In [3]:
# Processed Zarr dataset, given relative to the notebook's working directory.
data_path = "data/processed_data_cse151b_v2_corrupted_ssp245.zarr"

In [4]:
# Experiment configuration: `data` is forwarded verbatim to the datamodule,
# `training` holds the optimisation hyperparameters for the model/trainer cells.
# lr / max_epochs match the values used for the recorded training run
# (learning_rate=5e-4, `max_epochs=10` reached).
SEED = 42
pl.seed_everything(SEED)  # seed python/numpy/torch before any model is built

config = OmegaConf.create({
    "data": {
        "path": data_path,
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "target_member_id": 0,
    },
    "training": {
        "lr": 5e-4,
        # NOTE(review): weight_decay and early_stopping_patience are not read by
        # any cell in this notebook — wire them up or drop them.
        "weight_decay": 1e-5,
        "max_epochs": 10,
        "early_stopping_patience": 10,
        "gradient_clip_val": 1.0,
        "accumulate_grad_batches": 1,
    },
})

In [5]:
# Build the datamodule from the `data` section of the config and run setup()
# eagerly — presumably what makes `normalizer` / `get_coords()` usable in the
# prediction cells later on (TODO confirm against main.py).
data_module = ClimateEmulationDataModule(**config["data"])
data_module.setup()

In [6]:
class BaseRNN(nn.Module):
    """Per-pixel RNN emulator.

    Every grid cell is processed independently: its ``C`` input channels are fed
    to a single-layer ``nn.RNN`` as a length-``C`` sequence of scalars, and the
    output after the last step is projected to ``output_channels`` values.
    No information is shared between pixels.

    NOTE(review): the channel order (i.e. the order of ``input_vars``) acts as
    the RNN's "time" axis, which is an arbitrary ordering of the forcings.

    Args:
        input_channels: expected number of input channels ``C``; enforced in
            ``forward``.
        hidden_size: width of the RNN hidden state.
        output_channels: number of predicted variables per pixel.
    """

    def __init__(self, input_channels=5, hidden_size=64, output_channels=2):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.output_channels = output_channels

        # input_size=1: each sequence step is a single channel's scalar value.
        self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_channels)

    def forward(self, x):
        """Map ``x`` of shape [B, C, H, W] to [B, output_channels, H, W].

        Raises:
            ValueError: if ``x`` is not 4-D or ``C != input_channels``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D [B, C, H, W] tensor, got shape {tuple(x.shape)}"
            )
        B, C, H, W = x.shape
        if C != self.input_channels:
            raise ValueError(
                f"expected {self.input_channels} input channels, got {C}"
            )

        # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, C, 1]: one length-C sequence per pixel.
        seq = x.permute(0, 2, 3, 1).reshape(B * H * W, C, 1)

        # out: [B*H*W, C, hidden_size]; keep the output after the final step.
        out, _ = self.rnn(seq)
        last = out[:, -1, :]

        # [B*H*W, out] -> [B, H, W, out] -> [B, out, H, W]
        pred = self.fc(last)
        return pred.reshape(B, H, W, self.output_channels).permute(0, 3, 1, 2)


In [7]:
# Channel counts and learning rate come from the config cell instead of being
# repeated as literals, so the config stays the single source of truth.
model = BaseRNN(
    input_channels=len(config.data.input_vars),
    output_channels=len(config.data.output_vars),
)
lightning_module = ClimateEmulationModule(
    model=model,
    learning_rate=config.training.lr,
)

In [8]:
# Trainer settings are read from `config.training`; previously max_epochs was a
# literal and the configured gradient clipping / accumulation were never applied.
# NOTE(review): gradient_clip_val requires automatic optimisation in the
# LightningModule — confirm ClimateEmulationModule does not use manual optimisation.
trainer = pl.Trainer(
    max_epochs=config.training.max_epochs,
    gradient_clip_val=config.training.gradient_clip_val,
    accumulate_grad_batches=config.training.accumulate_grad_batches,
    accelerator="auto",
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Train; passing the datamodule by keyword makes it supply both train and val loaders.
trainer.fit(model=lightning_module, datamodule=data_module)

2025-05-15 20:39:27.758023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747341567.783793    1339 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747341567.792417    1339 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-15 20:39:27.821137: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
-----------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/a4shi/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=39` in the `DataLoader` to improve performance.
/home/a4shi/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=39` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Produce Test Predictions

In [15]:
# Run the trained network over the test loader and map predictions back to
# physical units for the submission.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

all_preds = []
# inference_mode disables autograd tracking for the whole loop at once.
with torch.inference_mode():
    for x, _ in data_module.test_dataloader():
        # Keep the batch dimension: a `.squeeze(0)` here would collapse any
        # single-sample batch to [C, H, W] and break the axis-0 concatenation.
        y_pred = model(x.to(device))
        all_preds.append(y_pred.cpu().numpy())

y_pred_np = np.concatenate(all_preds, axis=0)  # [N, n_output_vars, H, W]
var_names = config.data["output_vars"]
assert y_pred_np.ndim == 4 and y_pred_np.shape[1] == len(var_names), y_pred_np.shape

y_pred_output = data_module.normalizer.inverse_transform_output(y_pred_np)

lat_coords, lon_coords = data_module.get_coords()
time_coords = np.arange(y_pred_np.shape[0])

In [16]:
# Reshape the physical-unit predictions into the Kaggle submission table using
# the project helper (argument order: values, time, lat, lon, variable names).
submission_df = convert_predictions_to_kaggle_format(
    y_pred_output,
    time_coords,
    lat_coords,
    lon_coords,
    var_names,
)

In [17]:
# Sanity-check the (rows, columns) size of the submission before writing it.
submission_shape = submission_df.shape
print(submission_shape)

(2488320, 2)


In [18]:
# Persist the submission without the pandas index column.
submission_df.to_csv(path_or_buf="submission.csv", index=False)