In [None]:
import zarr
import xarray as xr
import matplotlib.pyplot as plt
import lightning.pytorch as pl
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from main import ClimateEmulationDataModule, ClimateEmulationModule
from _climate_kaggle_metric import score as kaggle_score

from src.models import SimpleCNN
from src.utils import convert_predictions_to_kaggle_format

In [None]:
# Location of the processed Zarr store, relative to the notebook's working directory.
data_path = "data/processed_data_cse151b_v2_corrupted_ssp245.zarr"

In [None]:
# Experiment configuration. `data` is unpacked directly into ClimateEmulationDataModule;
# `training` holds the optimisation hyperparameters; `seed` makes re-runs reproducible.
config = OmegaConf.create({
    "seed": 42,
    "data": {
        "path": data_path,
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "target_member_id": 0,
        "batch_size": 4,
        "num_workers": 4
    },
    "training": {
        "lr": 1e-3,
        "weight_decay": 1e-5,
        "max_epochs": 3,
        "early_stopping_patience": 10,
        "gradient_clip_val": 1.0,
        "accumulate_grad_batches": 1
    }
})

# Seed Python, NumPy and torch (and dataloader workers) before any model/data objects are
# created, so that Restart & Run All reproduces the same weights and shuffling.
pl.seed_everything(config.seed, workers=True)

# Number of input / output variables, i.e. the channel counts of the model.
inputs = len(config.data.input_vars)
outputs = len(config.data.output_vars)

In [None]:
# Build the data module from the `data` section of the config and run its setup() hook.
# NOTE(review): setup() is called without a stage argument; presumably this prepares the
# train/val/test splits used below (test_dataloader(), get_coords()) — confirm in main.py.
data_module = ClimateEmulationDataModule(**config.data)
data_module.setup()

In [None]:
# Baseline alternative (disabled): the SimpleCNN imported from src.models.
# The custom myNN network defined below is the model actually trained in this notebook.
# model = SimpleCNN(
#     n_input_channels = inputs,
#     n_output_channels = outputs
# )

In [None]:
class myNN(nn.Module):
    """Small 3D CNN: two (Conv3d -> BatchNorm3d -> ReLU) stages followed by a 1x1x1 Conv3d.

    ``forward`` takes a 5D tensor laid out as ``[B, D, C, H, W]`` (or a 4D ``[D, C, H, W]``
    tensor, which is treated as a single sample) with ``C == n_input_channels`` and returns
    the same layout with ``C == n_output_channels``. ``D``, ``H`` and ``W`` are preserved
    (3x3x3 kernels use padding=1, the last layer is 1x1x1).

    NOTE(review): if the dataloader yields 4D batches ``[B, C, H, W]``, the batch axis is
    consumed as the depth axis ``D``, so the 3x3x3 kernels mix neighbouring samples of the
    same batch — confirm this is intended.

    Args:
        n_input_channels: number of input channels ``C`` (axis 2 of a 5D input).
        n_output_channels: number of output channels produced.
        hidden_channels: width of the two hidden convolution stages.
    """

    def __init__(
        self,
        n_input_channels=5,
        n_output_channels=2,
        hidden_channels=32
    ):
        super().__init__()
        self.conv1 = nn.Conv3d(n_input_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(hidden_channels)
        self.conv2 = nn.Conv3d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(hidden_channels)
        self.conv3 = nn.Conv3d(hidden_channels, n_output_channels, kernel_size=1)

    def forward(self, x):
        """Apply the network.

        Args:
            x: ``[B, D, C, H, W]`` or unbatched ``[D, C, H, W]`` tensor.

        Returns:
            Tensor with the same rank and layout as ``x``, with ``C`` replaced by
            ``n_output_channels``.

        Raises:
            ValueError: if ``x`` is neither 4D nor 5D.
        """
        if x.dim() not in (4, 5):
            raise ValueError(
                f"myNN expects a 4D [D, C, H, W] or 5D [B, D, C, H, W] tensor, "
                f"got shape {tuple(x.shape)}"
            )

        # Only remove the batch axis at the end if we added it here; a genuine batch of
        # size 1 must keep its batch dimension.
        added_batch_dim = x.dim() == 4
        if added_batch_dim:
            x = x.unsqueeze(0)

        # Conv3d wants channels on axis 1: [B, D, C, H, W] -> [B, C, D, H, W].
        x = x.permute(0, 2, 1, 3, 4)

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.conv3(x)

        # Back to the caller's layout: [B, C, D, H, W] -> [B, D, C, H, W].
        x = x.permute(0, 2, 1, 3, 4)

        if added_batch_dim:
            x = x.squeeze(0)
        return x

In [None]:
class myLightningModule(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with an MSE loss and Adam.

    Args:
        model: network whose output is compared with the batch targets via ``nn.MSELoss``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty). The default 0.0 is plain Adam.
    """

    def __init__(self, model: nn.Module, learning_rate: float, weight_decay: float = 0.0):
        super().__init__()
        self.model = model
        # Store learning_rate / weight_decay in self.hparams; the model object itself is excluded.
        self.save_hyperparameters(ignore=["model"])
        self.criterion = nn.MSELoss()
        # Filled from the datamodule in on_fit_start (None if it has no normalizer).
        self.normalizer = None
        # NOTE(review): not read anywhere in this class.
        self.training_step_outputs = []
        # Detached per-step training losses of the current epoch.
        self.current_epoch_losses = []

    def forward(self, x):
        return self.model(x)

    def on_fit_start(self):
        # trainer.datamodule is None when fit() receives plain dataloaders; don't crash then.
        self.normalizer = getattr(self.trainer.datamodule, "normalizer", None)

    def on_train_epoch_start(self):
        # Drop leftovers from an interrupted epoch (e.g. a KeyboardInterrupt in the notebook)
        # so they do not pollute this epoch's average.
        self.current_epoch_losses = []

    def training_step(self, batch, batch_idx):
        x, y_true_norm = batch
        y_pred_norm = self(x)
        loss = self.criterion(y_pred_norm, y_true_norm)

        # Keep a detached tensor instead of loss.item(): .item() would force a
        # host/device synchronisation on every step.
        self.current_epoch_losses.append(loss.detach())

        # Log loss for progress bar
        self.log("train/loss", loss, prog_bar=True, batch_size=x.size(0))
        return loss

    def on_train_epoch_end(self):
        # Nothing was recorded (e.g. empty dataloader): skip instead of printing a NaN mean.
        if not self.current_epoch_losses:
            return
        # One transfer per epoch; average in float64 on CPU (MPS has no float64 support).
        avg_loss = torch.stack(self.current_epoch_losses).cpu().double().mean().item()
        print(f"\nEpoch {self.current_epoch} - Average training loss: {avg_loss:.6f}")
        self.current_epoch_losses = []  # Reset for next epoch

    def validation_step(self, batch, batch_idx):
        x, y_true_norm = batch
        y_pred_norm = self(x)
        loss = self.criterion(y_pred_norm, y_true_norm)
        self.log("val/loss", loss, prog_bar=True, batch_size=x.size(0))
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

In [None]:
# Size the network's input/output channels from the configured variable lists so that
# editing input_vars / output_vars in the config cell cannot desynchronise the model.
model = myNN(
    n_input_channels=inputs,
    n_output_channels=outputs,
    hidden_channels=32
)

In [None]:
# All training hyperparameters come from `config.training`, so the config cell is the
# single place to tune them.
# NOTE(review): config.training.weight_decay and early_stopping_patience are not used here.
lightning_model = myLightningModule(model, learning_rate=config.training.lr)
trainer = pl.Trainer(
    max_epochs=config.training.max_epochs,
    gradient_clip_val=config.training.gradient_clip_val,
    accumulate_grad_batches=config.training.accumulate_grad_batches,
    accelerator="auto",
    devices="auto",
)

In [None]:
# Fit the Lightning module on the dataloaders provided by the data module.
trainer.fit(model=lightning_model, datamodule=data_module)

In [None]:
# Test-split inference: gather predictions and targets as NumPy arrays, then convert the
# predictions into the Kaggle submission format.
# NOTE(review): training_step treats its targets as normalized (y_true_norm), so y_pred_np
# is presumably in normalized units; confirm whether the test dataloader yields normalized
# or raw targets and whether predictions must be inverse-transformed before scoring.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

all_preds = []
all_trues = []
with torch.no_grad():
    for x, y_true in data_module.test_dataloader():
        x = x.to(device)
        y_pred = model(x)
        all_preds.append(y_pred.cpu().numpy())
        all_trues.append(y_true.cpu().numpy())

# Stack the per-batch arrays along the leading axis.
y_pred_np = np.concatenate(all_preds)
y_true_np = np.concatenate(all_trues)

var_names = config.data['output_vars']
lat_coords, lon_coords = data_module.get_coords()
# One integer index per entry along the leading axis of the predictions.
time_coords = np.arange(len(y_pred_np))

submission_df = convert_predictions_to_kaggle_format(
    y_pred_np, time_coords, lat_coords, lon_coords, var_names
)
submission_df.shape

In [None]:
# Score the test predictions with the competition metric, after converting the ground
# truth into the same format as the submission.
solution_df = convert_predictions_to_kaggle_format(
    y_true_np,
    time_coords,
    lat_coords,
    lon_coords,
    var_names,
)
# NOTE(review): "ID" is presumably the row-identifier column produced by the converter — confirm.
kaggle_val_score = kaggle_score(solution_df, submission_df, "ID")
print("Kaggle metric score:", kaggle_val_score)

In [None]:
# The solution frame and metric are computed in the scoring cell above; re-display that
# result rather than duplicating the conversion and scoring code.
print("Kaggle metric score:", kaggle_val_score)

In [49]:
# Write the Kaggle submission, omitting the DataFrame index column.
submission_df.to_csv(path_or_buf="submission.csv", index=False)