In [1]:
import numpy as np
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from pytorch_lightning.callbacks import ModelCheckpoint

from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

In [2]:
data_path = 'data/processed_data_cse151b_v2_corrupted_ssp245.zarr'

In [3]:
config = OmegaConf.create({
    "data": {
        "path": data_path,
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "target_member_id": 0,
    },
    "training": {
        "lr": 1e-3,
        "weight_decay": 1e-5,
        "max_epochs": 3,
        "early_stopping_patience": 10,
        "gradient_clip_val": 1.0,
        "accumulate_grad_batches": 1
    }
})

In [4]:
data_module = ClimateEmulationDataModule(**config.data)
data_module.setup()

In [5]:
# model = SimpleCNN(
#     n_input_channels = len(config.data['input_vars']),
#     n_output_channels = len(config.data['output_vars'])
# )

In [6]:
class EncoderBlock3D(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=True, dropout_rate=0.1):
        super().__init__()
        self.downsample = downsample
        
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.dropout1 = nn.Dropout3d(p=dropout_rate)
        
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.dropout2 = nn.Dropout3d(p=dropout_rate)
        
        # Downsample conv (if needed)
        if downsample:
            self.down = nn.Conv3d(out_channels, out_channels, kernel_size=(1,3,3), 
                                padding=(0,1,1), stride=(1,2,2))
        
        # Residual connection
        self.residual = nn.Conv3d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
        
    def forward(self, x):
        identity = self.residual(x)
        
        # Main path
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.leaky_relu(x, negative_slope=0.2)  # LeakyReLU instead of ReLU
        x = self.dropout1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, negative_slope=0.2)
        x = self.dropout2(x)
        
        # Add residual
        x = x + identity
        
        # Store pre-downsample output for skip connection
        skip = x
        
        # Downsample if needed
        if self.downsample:
            x = self.down(x)
        
        return x, skip

class DecoderBlock3D(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=True, dropout_rate=0.1):
        super().__init__()
        self.upsample = upsample
        
        # Upsample conv (if needed)
        if upsample:
            self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=(1,3,3),
                                       padding=(0,1,1), stride=(1,2,2), output_padding=(0,1,1))
        
        # First convolution block (after concatenation)
        self.conv1 = nn.Conv3d(in_channels * 2, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.dropout1 = nn.Dropout3d(p=dropout_rate)
        
        # Second convolution block
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.dropout2 = nn.Dropout3d(p=dropout_rate)
        
    def forward(self, x, skip=None):
        # Upsample if needed
        if self.upsample:
            x = self.up(x)
        
        # Concatenate skip connection if provided
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        
        # Main path
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.leaky_relu(x, negative_slope=0.2)
        x = self.dropout1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, negative_slope=0.2)
        x = self.dropout2(x)
        
        return x

class EncoderDecoder3DCNN(nn.Module):
    def __init__(
        self,
        n_input_channels=5,
        n_output_channels=2,
        base_channels=64,
        dropout_rate=0.1
    ):
        super().__init__()
        
        # Store n_output_channels as instance variable
        self.n_output_channels = n_output_channels
        
        # Encoder
        self.enc1 = EncoderBlock3D(n_input_channels, base_channels, downsample=True, dropout_rate=dropout_rate)
        self.enc2 = EncoderBlock3D(base_channels, base_channels * 2, downsample=True, dropout_rate=dropout_rate)
        self.enc3 = EncoderBlock3D(base_channels * 2, base_channels * 4, downsample=False, dropout_rate=dropout_rate)
        
        # Enhanced Bottleneck with dropout
        self.bottleneck = nn.Sequential(
            nn.Conv3d(base_channels * 4, base_channels * 8, kernel_size=3, padding=1),
            nn.BatchNorm3d(base_channels * 8),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout3d(p=dropout_rate),
            
            nn.Conv3d(base_channels * 8, base_channels * 8, kernel_size=3, padding=1),
            nn.BatchNorm3d(base_channels * 8),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout3d(p=dropout_rate),
            
            nn.Conv3d(base_channels * 8, base_channels * 4, kernel_size=3, padding=1),
            nn.BatchNorm3d(base_channels * 4),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout3d(p=dropout_rate)
        )
        
        # Decoder
        self.dec3 = DecoderBlock3D(base_channels * 4, base_channels * 2, upsample=False, dropout_rate=dropout_rate)
        self.dec2 = DecoderBlock3D(base_channels * 2, base_channels, upsample=True, dropout_rate=dropout_rate)
        self.dec1 = DecoderBlock3D(base_channels, base_channels, upsample=True, dropout_rate=dropout_rate)
        
        # Final layers - simplified to ensure exact number of output channels
        self.final_conv = nn.Sequential(
            nn.Conv3d(base_channels, base_channels // 2, kernel_size=3, padding=1),
            nn.BatchNorm3d(base_channels // 2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Dropout3d(p=dropout_rate/2),
            nn.Conv3d(base_channels // 2, n_output_channels, kernel_size=1)
        )
        
    def forward(self, x):
        # Add batch dimension if not present
        if len(x.shape) == 4:
            x = x.unsqueeze(0)
            
        # Rearrange from [B, D, C, H, W] to [B, C, D, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        
        # Store input shape for debugging
        input_shape = x.shape
        
        # Encoder
        x, skip1 = self.enc1(x)
        x, skip2 = self.enc2(x)
        x, skip3 = self.enc3(x)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder (with skip connections)
        x = self.dec3(x, skip3)
        x = self.dec2(x, skip2)
        x = self.dec1(x, skip1)
        
        # Final convolutions
        x = self.final_conv(x)
        
        # Verify output channels match expected number
        assert x.shape[1] == self.n_output_channels, f"Expected {self.n_output_channels} output channels but got {x.shape[1]}"
        
        # Verify output spatial dimensions match input
        assert x.shape[-3:] == input_shape[-3:], f"Output shape {x.shape} doesn't match input shape {input_shape}"
        
        # Permute back to original format [B, D, C, H, W]
        x = x.permute(0, 2, 1, 3, 4)
        
        # Remove batch dimension if it was added
        if len(input_shape) == 4:
            x = x.squeeze(0)
            
        return x


In [7]:
class myLightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 1e-3,
        weight_decay: float = 0.01,
        scheduler_type: str = 'plateau'  # 'plateau' or 'cosine'
    ):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=["model"])
        self.criterion = nn.MSELoss()
        self.normalizer = None
        self.training_step_outputs = []
        self.current_epoch_losses = []
        self.validation_step_outputs = []
        
    def forward(self, x):
        return self.model(x)
    
    def on_fit_start(self):
        self.normalizer = self.trainer.datamodule.normalizer
    
    def training_step(self, batch, batch_idx):
        x, y_true_norm = batch
        y_pred_norm = self(x)
        loss = self.criterion(y_pred_norm, y_true_norm)
        
        # Store loss for epoch end logging
        self.current_epoch_losses.append(loss.item())
        
        # Log loss and learning rate
        self.log("train/loss", loss, prog_bar=True, batch_size=x.size(0))
        self.log("train/lr", self.optimizers().param_groups[0]['lr'], prog_bar=True)
        return loss
    
    def on_train_epoch_end(self):
        # Calculate and print average loss for the epoch
        avg_loss = np.mean(self.current_epoch_losses)
        print(f"\nEpoch {self.current_epoch} - Average training loss: {avg_loss:.6f}")
        self.current_epoch_losses = []  # Reset for next epoch
    
    def validation_step(self, batch, batch_idx):
        x, y_true_norm = batch
        y_pred_norm = self(x)
        loss = self.criterion(y_pred_norm, y_true_norm)
        
        # Store predictions for epoch end calculations
        self.validation_step_outputs.append({
            'loss': loss.item(),
            'y_pred': y_pred_norm.detach(),
            'y_true': y_true_norm.detach()
        })
        
        # Log validation loss
        self.log("val/loss", loss, prog_bar=True, batch_size=x.size(0), sync_dist=True)
        return loss
    
    def on_validation_epoch_end(self):
        # Calculate average validation loss
        val_losses = [x['loss'] for x in self.validation_step_outputs]
        avg_val_loss = np.mean(val_losses)
        
        # Log validation metrics
        self.log("val/avg_loss", avg_val_loss, prog_bar=True)
        
        # Clear stored outputs
        self.validation_step_outputs.clear()
    
    def configure_optimizers(self):
        # Create optimizer with weight decay
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        
        # Configure scheduler
        if self.hparams.scheduler_type == 'plateau':
            scheduler = {
                'scheduler': ReduceLROnPlateau(
                    optimizer,
                    mode='min',
                    factor=0.5,
                    patience=5,
                    verbose=True
                ),
                'monitor': 'val/loss',
                'interval': 'epoch',
                'frequency': 1
            }
        else:  # cosine
            scheduler = {
                'scheduler': CosineAnnealingWarmRestarts(
                    optimizer,
                    T_0=10,  # Restart every 10 epochs
                    T_mult=2,  # Double the restart interval after each restart
                    eta_min=1e-6  # Minimum learning rate
                ),
                'interval': 'epoch',
                'frequency': 1
            }
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        } 

In [8]:
model = EncoderDecoder3DCNN(
    n_input_channels=5,
    n_output_channels=2,
    base_channels=64,
    dropout_rate=0.15
)
lightning_module = myLightningModule(
    model=model,
    learning_rate=5e-4,
    weight_decay=0.01,
    scheduler_type='cosine'
)

In [9]:
trainer = pl.Trainer(
    max_epochs=5,
    accelerator='auto',
    devices=1,
    callbacks=[
        pl.callbacks.EarlyStopping(
            monitor='val/loss',
            patience=10,
            mode='min'
        )
    ]
)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(lightning_module, data_module)

You are using a CUDA device ('NVIDIA A30 MIG 2g.12gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-05-15 05:16:33.646043: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747286193.661353    1329 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747286193.666213    1329 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-15 05:16:33.683850: I tensorflow/core/platform/cpu_feature_guard.c

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/a4shi/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
  return F.mse_loss(input, target, reduction=self.reduction)
/home/a4shi/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Validation: |          | 0/? [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)



Epoch 0 - Average training loss: 0.501503


Validation: |          | 0/? [00:00<?, ?it/s]


Epoch 1 - Average training loss: 0.366051


Validation: |          | 0/? [00:00<?, ?it/s]


Epoch 2 - Average training loss: 0.336115


Validation: |          | 0/? [00:00<?, ?it/s]


Epoch 3 - Average training loss: 0.327124


Validation: |          | 0/? [00:00<?, ?it/s]


Epoch 4 - Average training loss: 0.316043


`Trainer.fit` stopped: `max_epochs=5` reached.


In [16]:
model.eval()
all_preds = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for x, y_true in data_module.test_dataloader():
    x = x.to(device)
    with torch.no_grad():
        y_pred = model(x.to(device)).squeeze(0)
    all_preds.append(y_pred.cpu().numpy())

y_pred_np = np.concatenate(all_preds, axis=0)
y_pred_output = data_module.normalizer.inverse_transform_output(y_pred_np)

lat_coords, lon_coords = data_module.get_coords()
time_coords = np.arange(y_pred_np.shape[0])
var_names = config.data['output_vars']

In [17]:
submission_df = convert_predictions_to_kaggle_format(
    y_pred_output, time_coords, lat_coords, lon_coords, var_names
)

In [18]:
print(submission_df.shape)

(2488320, 2)


In [19]:
submission_df.to_csv("submission.csv", index=False)