In [1]:
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import numpy as np

from main import ClimateEmulationDataModule, ClimateEmulationModule
from src.models import SimpleCNN

from _climate_kaggle_metric import score as kaggle_score
from src.utils import convert_predictions_to_kaggle_format

In [2]:
# Location of the processed zarr store, relative to this notebook.
data_path = "../data/processed_data_cse151b_v2_corrupted_ssp245.zarr"

In [3]:
# Data section: where the dataset lives, which variables go in/out, and
# how batches are loaded.
data_cfg = {
    "path": data_path,
    "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
    "output_vars": ["tas", "pr"],
    "train_ssps": ["ssp126", "ssp370", "ssp585"],
    "test_ssp": "ssp245",
    "target_member_id": 0,
    "batch_size": 4,
    "num_workers": 4,
}

# Training section: optimiser and trainer hyperparameters.
training_cfg = {
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "max_epochs": 3,
    "early_stopping_patience": 10,
    "gradient_clip_val": 1.0,
    "accumulate_grad_batches": 1,
}

config = OmegaConf.create({"data": data_cfg, "training": training_cfg})

# Convenience values derived from the config for the cells below.
inputs = len(config.data.input_vars)    # number of input channels
outputs = len(config.data.output_vars)  # number of output channels
lr = config.training.lr
weight_decay = config.training.weight_decay

In [12]:
# Build the data module from the "data" section of the config; every key in
# that section is forwarded as a keyword argument.
data_module = ClimateEmulationDataModule(**config.data)
# Call setup() explicitly so the module is ready before later cells use
# data_module.test_dataloader() / data_module.get_coords() directly.
# NOTE(review): presumably this loads the zarr store and builds the splits —
# confirm against ClimateEmulationDataModule in main.py.
data_module.setup()

In [15]:
# Instantiate the CNN with one input channel per input variable and one
# output channel per predicted variable (counts derived from the config).
model = SimpleCNN(n_input_channels=inputs, n_output_channels=outputs)

In [17]:
# Wrap the CNN in the Lightning module and create the trainer.
# Hyperparameters are read from `config` (instead of being repeated as
# literals) so the config cell stays the single source of truth.
lightning_model = ClimateEmulationModule(model, learning_rate=lr)
trainer = pl.Trainer(
    # Kept small in the config for quick test runs; raise it there for real training.
    max_epochs=config.training.max_epochs,
    accelerator="auto",
    devices="auto",
)
# NOTE(review): config.training also defines weight_decay,
# early_stopping_patience, gradient_clip_val and accumulate_grad_batches,
# none of which are passed to the module or the trainer here — confirm
# whether they are meant to take effect.

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [18]:
# Train the model; the data module supplies the train/validation loaders.
trainer.fit(model=lightning_model, datamodule=data_module)

2025-05-13 01:33:33.726173: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747100013.754112    3096 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747100013.762722    3096 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-13 01:33:33.792601: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type      | Params | Mode 
---------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [19]:
# Evaluate on the test split; the returned list of metric dicts is the cell output.
trainer.test(model=lightning_model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/tas/avg/monthly_rmse': 290.73590087890625,
  'test/tas/time_mean_rmse': 290.7022705078125,
  'test/tas/time_stddev_mae': 3.0436794757843018,
  'test/pr/avg/monthly_rmse': 3.855160713195801,
  'test/pr/time_mean_rmse': 3.7450969219207764,
  'test/pr/time_stddev_mae': 0.6865308880805969}]

In [33]:
# --- Run the trained network over the test split -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm

all_preds = []
all_trues = []

# Gradients are not needed for inference, so disable them for the whole loop.
with torch.no_grad():
    for x, y_true in data_module.test_dataloader():
        x = x.to(device)
        y_pred = model(x)
        # Collect both prediction and target as CPU numpy arrays.
        all_preds.append(y_pred.cpu().numpy())
        all_trues.append(y_true.cpu().numpy())

# Stack the per-batch arrays along the leading (batch/time) axis.
# Original note gave the expected shape as [360, 2, 48, 72] — TODO confirm.
y_pred_np = np.concatenate(all_preds, axis=0)
y_true_np = np.concatenate(all_trues, axis=0)

# NOTE(review): the raw model output is exported as-is. The displayed `tas`
# values are around -1.5, which looks like normalised units — confirm whether
# predictions need to be inverse-transformed before building the submission.

# --- Convert to the Kaggle submission layout -----------------------------
lat_coords, lon_coords = data_module.get_coords()
time_coords = np.arange(len(y_pred_np))  # one integer index per time step
var_names = config.data.output_vars

submission_df = convert_predictions_to_kaggle_format(
    y_pred_np,
    time_coords,
    lat_coords,
    lon_coords,
    var_names,
)
submission_df.shape

Unnamed: 0,ID,Prediction
0,t000_tas_-88.59_1.88,-1.537119
1,t000_tas_-88.59_6.88,-1.733186
2,t000_tas_-88.59_11.88,-1.678512
3,t000_tas_-88.59_16.88,-1.649922
4,t000_tas_-88.59_21.88,-1.648948


In [34]:
# Score the local predictions against the held-out targets using the
# competition metric, with the targets converted to the same Kaggle layout.
solution_df = convert_predictions_to_kaggle_format(
    y_true_np, time_coords, lat_coords, lon_coords, var_names
)
# Pass copies: the metric implementation is not visible from this notebook
# and may modify its inputs in place (e.g. drop the ID column). submission_df
# is written to disk later, so it must stay untouched.
kaggle_val_score = kaggle_score(solution_df.copy(), submission_df.copy(), "ID")
print("Kaggle metric score:", kaggle_val_score)

100%|██████████| 2/2 [00:01<00:00,  1.58it/s]

Kaggle metric score: 1.0043490249288263





In [36]:
# Write the Kaggle-format predictions to disk; index=False keeps the
# DataFrame's row index out of the CSV.
submission_df.to_csv(path_or_buf="submission.csv", index=False)