# Load data

In [1]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

# ERA5 forecasting setup: 2m temperature is used to predict itself 3 days ahead.
# Splits start in 1979 (train), 1980 (val) and 1981 (test); end_year is 1982.
era5_config = dict(
    dataset="ERA5",
    task="forecasting",
    root_dir="../climate-learn/data/weatherbench/era5/5.625",
    in_vars=["2m_temperature"],
    out_vars=["2m_temperature"],
    train_start_year=Year(1979),
    val_start_year=Year(1980),
    test_start_year=Year(1981),
    end_year=Year(1982),
    pred_range=Days(3),
    subsample=Hours(6),
    batch_size=128,
    num_workers=8,
)
dm = DataModule(**era5_config)

Creating train dataset


  0%|          | 0/1 [00:00<?, ?it/s]

Creating val dataset


  0%|          | 0/1 [00:00<?, ?it/s]

Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

Create BayesianCNN.

In [2]:
# Shape of the first training input (a single sample, before batching).
dm.train_dataset[0][0].size()

torch.Size([1, 1, 32, 64])

# Build BayesCNN

In [3]:
# Scratch cell for computing output sizes
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

import numpy as np

def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of ``nn.Conv2d`` along one spatial axis.

    Implements the formula from the PyTorch ``Conv2d`` docs:
    floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1).
    Floor division keeps the result an integer (a size, not a float).
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


hout = conv2d_output_size(29, kernel_size=3)
wout = conv2d_output_size(61, kernel_size=3)

print(hout, wout)

27.0 59.0


In [81]:
from layers import *
import torch.nn as nn

# https://arxiv.org/pdf/1901.02731.pdf

class BayesCNN(ModuleWrapper):
    """Small fully-convolutional Bayes-by-Backprop network (https://arxiv.org/pdf/1901.02731.pdf).

    Four ``BBB_Conv2d`` layers with widths in_channels -> 3 -> 5 -> 3 -> out_channels.
    Each hidden layer is followed by Softplus and BatchNorm. Every convolution uses
    stride 1 and padding ``kernel_size // 2``, so the input's spatial size is preserved.

    No ``forward`` is defined here; calls go through ``ModuleWrapper`` (from ``layers``).
    Calling the net yields ``(output, kl)``.
    NOTE(review): ``ModuleWrapper`` presumably applies the child modules in the order
    they are registered below. Do not reorder the attribute assignments; confirm
    against ``layers``.

    Args:
        in_channels: number of input channels (default 1, the single-variable setup
            used in this notebook).
        out_channels: number of output channels (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.conv1 = BBB_Conv2d(
            in_channels=in_channels,
            out_channels=3,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.act1 = nn.Softplus()
        self.bn1 = nn.BatchNorm2d(3)
        self.conv2 = BBB_Conv2d(
            in_channels=3,
            out_channels=5,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.act2 = nn.Softplus()
        self.bn2 = nn.BatchNorm2d(5)
        self.conv3 = BBB_Conv2d(
            in_channels=5,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.act3 = nn.Softplus()
        self.bn3 = nn.BatchNorm2d(3)
        # Final projection has no activation, so the output is unbounded (regression).
        self.conv4 = BBB_Conv2d(
            in_channels=3,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )

In [82]:
import torch

# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [83]:
# Smoke-test the network on one training sample.
# Calling the net returns (output, kl); only the output is kept here.
# Use the `device` selected above instead of hard-coding .cuda(), so this also runs on CPU.
net = BayesCNN().to(device)
logits, _ = net(dm.train_dataset[0][0].to(device))

In [84]:
# Raw output of the freshly constructed (untrained) network for one training sample.
logits

tensor([[[[-1.1002e-02,  4.5736e-01,  8.6272e-02,  ..., -1.6405e-01,
            4.7771e-04, -9.2713e-02],
          [-6.0578e-01, -8.3075e-02, -1.5220e-01,  ..., -5.1696e-01,
            5.3144e-01,  6.2323e-01],
          [-6.2265e-01, -6.1950e-01, -1.4407e+00,  ..., -1.1370e+00,
           -1.2253e-01, -2.7827e-01],
          ...,
          [-3.7115e-01,  5.3484e-01, -5.0040e-01,  ..., -8.6558e-01,
           -4.5810e-01, -6.9343e-02],
          [-5.9130e-01, -1.6491e-01, -6.4190e-01,  ..., -8.3902e-01,
            3.9776e-01,  9.0848e-01],
          [-1.4385e-01,  2.9654e-01, -4.2162e-01,  ..., -9.0225e-01,
           -2.1108e-01, -3.7161e-01]]]], device='cuda:0',
       grad_fn=<ConvolutionBackward0>)

In [85]:
# Output shape, taken from a detached CPU copy of the tensor.
logits.detach().cpu().shape

torch.Size([1, 1, 32, 64])

In [100]:
import pytorch_lightning as pl
from torchvision.transforms import transforms
from climate_learn.models.modules.utils.metrics import (
    lat_weighted_mse,
    lat_weighted_rmse,
    lat_weighted_acc,
)
from sklearn.linear_model import Ridge

class LitModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``BayesCNN`` with climate_learn metrics.

    Training logs both the plain latitude-weighted MSE and its KL-regularized
    counterpart (``wkl_*`` = metric + ``kl_weight * kl``), and optimizes the latter.
    Validation and testing roll the network forward ``pred_steps`` times and report
    every metric in ``self.val_loss`` (latitude-weighted RMSE and ACC). Testing also
    reports climatology, persistence and ridge-regression RMSE baselines.

    Call the ``set_*`` hooks before training or evaluation, and call
    ``fit_lin_reg_baseline`` before testing.

    Args:
        kl_weight: multiplier applied to the network's KL term in the ``wkl_*`` losses.
    """

    def __init__(self, kl_weight):
        super().__init__()
        self.net = BayesCNN()
        self.kl_weight = kl_weight

        def wkl_loss(loss_func):
            """Wrap ``loss_func`` so ``kl`` is consumed from its kwargs, and every entry
            of the returned dict is renamed ``wkl_<key>`` with ``kl_weight * kl`` added."""
            def wrap(*args, **kwargs):
                kl_term = kl_weight * kwargs.pop("kl")
                result = loss_func(*args, **kwargs)
                return {f"wkl_{key}": value + kl_term for key, value in result.items()}
            return wrap

        def non_wkl_loss(loss_func):
            """Wrap ``loss_func`` so an optional ``kl`` kwarg is dropped and ignored."""
            def wrap(*args, **kwargs):
                kwargs.pop("kl", None)
                return loss_func(*args, **kwargs)
            return wrap

        self.train_loss = [
            non_wkl_loss(lat_weighted_mse),
            wkl_loss(lat_weighted_mse),
        ]
        self.val_loss = [
            lat_weighted_rmse,
            lat_weighted_acc,
        ]

    def forward(self, x):
        """Return only the network output for ``x``; the KL term is discarded."""
        logits, _ = self.predict(x)
        return logits

    def predict(self, x):
        """Return ``(output, kl)`` from the Bayesian net, after squeezing dim 1 of ``x``.

        NOTE(review): the squeeze assumes ``x`` is (B, 1, C, H, W), as the test-step
        comments suggest. Confirm against the DataModule.
        """
        logits, kl = self.net(torch.squeeze(x, 1))
        return logits, kl

    def set_denormalization(self, mean, std):
        """Build the transforms passed to metrics as ``transform``.

        ``mean``/``std`` go to ``transforms.Normalize`` unchanged for
        ``self.denormalization``. NOTE(review): presumably the caller passes already
        inverted statistics (-m/s, 1/s); confirm against climate_learn's
        ``set_climatology``.
        """
        self.denormalization = transforms.Normalize(mean, std)
        self.mean_denormalize = transforms.Normalize(-mean / std, 1 / std)
        self.std_denormalize = transforms.Normalize(np.zeros_like(std), 1 / std)

    def set_lat_lon(self, lat, lon):
        """Store the grid latitudes (used for metric weighting) and longitudes."""
        self.lat = lat
        self.lon = lon

    def set_pred_range(self, r):
        """Store the per-step forecast horizon (an object exposing ``.hours()``)."""
        self.pred_range = r

    def set_train_climatology(self, clim):
        """Store the training climatology used by the climatology baseline."""
        self.train_clim = clim

    def set_val_climatology(self, clim):
        """Store the validation climatology passed to validation metrics."""
        self.val_clim = clim

    def set_test_climatology(self, clim):
        """Store the test climatology passed to test metrics."""
        self.test_clim = clim

    def training_step(self, batch, batch_idx):
        """Log the plain and KL-weighted MSE losses, and optimize the KL-weighted one."""
        x, y, _, out_variables = batch
        y_hat, kl = self.predict(x)
        loss_dict = {}
        for loss_fn in self.train_loss:
            loss_dict.update(loss_fn(y_hat, y, out_variables, lat=self.lat, kl=kl))
        for name, value in loss_dict.items():
            self.log(
                "train/" + name,
                value,
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                batch_size=len(x)
            )
        # Lightning back-propagates the "loss" entry. The original code returned the
        # plain-MSE dict, so kl_weight/KL never reached the optimizer. "wkl_loss" exists
        # whenever the base metric returns "loss", which Lightning already required.
        return {**loss_dict, "loss": loss_dict["wkl_loss"]}

    def validation_step(self, batch, batch_idx):
        """Roll the model out and log all validation metrics, aggregated per epoch."""
        x, y, _, out_variables = batch
        batch_size = len(x)
        pred_steps = y.shape[1]
        steps, days = self._eval_steps_and_days(pred_steps)

        preds = self._rollout(x, pred_steps)
        if len(y.shape) == 4:
            y = y.unsqueeze(1)
        loss_dict = self._eval_metrics(preds, y, out_variables, steps, days, self.val_clim)
        for var, value in loss_dict.items():
            self.log(
                "val/" + var,
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                batch_size=batch_size
            )
        return loss_dict

    def test_step(self, batch, batch_idx):
        """Log climatology, persistence and ridge-regression RMSE baselines, then the
        model's test metrics. Returns the model's metric dict.

        Raises:
            NotImplementedError: if ``fit_lin_reg_baseline`` has not been called.
        """
        x, y, _, out_variables = batch
        batch_size = len(x)
        pred_steps = y.shape[1]
        steps, days = self._eval_steps_and_days(pred_steps)

        # rmse for climatology baseline: tile the training climatology (C, H, W) over the batch
        clim_pred = (
            self.train_clim.unsqueeze(0)
            .unsqueeze(0)
            .repeat(y.shape[0], y.shape[1], 1, 1, 1)
            .to(y.device)
        )
        self._log_baseline(
            "test_climatology_baseline/",
            self._baseline_rmse(clim_pred, y, out_variables, False, steps, days),
            batch_size,
        )

        # rmse for persistence baseline: the input itself is the forecast (B, 1, C, H, W)
        self._log_baseline(
            "test_persistence_baseline/",
            self._baseline_rmse(x, y, out_variables, True, steps, days),
            batch_size,
        )

        # rmse for linear regression baseline. Check the attribute explicitly instead of
        # catching AttributeError, which could mask unrelated errors inside predict().
        if not hasattr(self, "lr_baseline"):
            raise NotImplementedError(
                "Expect climate_learn.models.fit_lin_reg_baseline be implemented before test steps."
            )
        lr_pred = self.lr_baseline.predict(
            x.cpu().reshape((x.shape[0], -1))
        ).reshape(y.shape)
        lr_pred = lr_pred[:, np.newaxis, :, :, :]  # B, 1, C, H, W
        lr_pred = torch.from_numpy(lr_pred).float().to(y.device)
        self._log_baseline(
            "test_ridge_regression_baseline/",
            self._baseline_rmse(lr_pred, y, out_variables, True, steps, days),
            batch_size,
        )

        preds = self._rollout(x, pred_steps)
        if len(y.shape) == 4:
            y = y.unsqueeze(1)
        loss_dict = self._eval_metrics(preds, y, out_variables, steps, days, self.test_clim)
        for var, value in loss_dict.items():
            self.log(
                "test/" + var,
                value,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                batch_size=batch_size
            )
        return loss_dict

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def fit_lin_reg_baseline(self, train_dataset, reg_hparam=0.0):
        """Fit a scikit-learn ``Ridge`` mapping flattened ``inp_data`` to flattened
        ``out_data``. ``test_step`` needs this to have been called.

        Args:
            train_dataset: object exposing array-like ``inp_data``/``out_data``, indexed by sample.
            reg_hparam: Ridge ``alpha``; the default 0.0 means no regularization.
        """
        X_train = train_dataset.inp_data.reshape(train_dataset.inp_data.shape[0], -1)
        y_train = train_dataset.out_data.reshape(train_dataset.out_data.shape[0], -1)
        self.lr_baseline = Ridge(alpha=reg_hparam)
        self.lr_baseline.fit(X_train, y_train)

    # --- private helpers shared by validation_step / test_step ---

    def _eval_steps_and_days(self, pred_steps):
        """Return the rollout steps (and matching lead days) at which to report metrics.

        Keeps the default lead times of 1, 3 and 5 days that are exact multiples of the
        per-step horizon and fall within ``pred_steps``.
        """
        pred_range = self.pred_range.hours()
        days_each_step = pred_range / 24
        candidate_steps = [
            d / days_each_step for d in (1, 3, 5) if d % days_each_step == 0
        ]
        steps = [int(s) for s in candidate_steps if 0 < s <= pred_steps]
        days = [int(s * pred_range / 24) for s in steps]
        return steps, days

    def _rollout(self, x, pred_steps):
        """Feed each prediction back in ``pred_steps`` times; stack the outputs on dim 1.

        NOTE(review): ``predict`` squeezes dim 1 of its input, so feeding a prediction
        back in is only shape-safe for ``pred_steps == 1``. Confirm before using longer
        rollouts.
        """
        preds = []
        current = x
        for _ in range(pred_steps):
            current, _ = self.predict(current)
            preds.append(current)
        return torch.stack(preds, dim=1)

    def _eval_metrics(self, preds, y, out_variables, steps, days, clim):
        """Evaluate every metric in ``self.val_loss`` and merge the results into one dict."""
        metrics = {}
        for metric in self.val_loss:
            metrics.update(
                metric(preds, y, out_variables, transform=self.denormalization, lat=self.lat,
                       log_steps=steps, log_days=days, clim=clim)
            )
        return metrics

    def _baseline_rmse(self, pred, y, out_variables, transform_pred, steps, days):
        """Latitude-weighted RMSE of a baseline forecast ``pred`` against ``y``."""
        return lat_weighted_rmse(
            pred,
            y,
            out_variables,
            transform_pred=transform_pred,
            transform=self.denormalization,
            lat=self.lat,
            log_steps=steps,
            log_days=days,
        )

    def _log_baseline(self, prefix, metrics, batch_size):
        """Log each baseline metric under ``prefix`` as an epoch-level, distributed-synced value."""
        for var, value in metrics.items():
            self.log(
                prefix + var,
                value,
                on_step=False,
                on_epoch=True,
                sync_dist=True,
                batch_size=batch_size,
            )

In [101]:
# Multiplier on the network's KL term in the "wkl_" training losses.
KL_WEIGHT = 0.05
lm = LitModule(kl_weight=KL_WEIGHT)

In [102]:
from climate_learn.models import set_climatology
# Pass the DataModule's statistics to the model. NOTE(review): this presumably calls the
# LitModule set_* hooks (set_denormalization, set_lat_lon, set_pred_range,
# set_*_climatology) that the validation/test steps read. Confirm in climate_learn.models.
set_climatology(lm, dm)

In [103]:
from climate_learn.training import Trainer

trainer_kwargs = {
    "seed": 0,
    "accelerator": "gpu",
    "precision": 16,
    "max_epochs": 5,
}
trainer = Trainer(**trainer_kwargs)

Global seed set to 0


In [104]:
# Train the LitModule on the ERA5 DataModule.
trainer.fit(lm, dm)

You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [105]:
# Fit the ridge-regression baseline that test_step evaluates against.
lm.fit_lin_reg_baseline(train_dataset=dm.train_dataset)

In [106]:
# Evaluate on the test split. LitModule.test_step raises NotImplementedError unless
# fit_lin_reg_baseline has been called first.
trainer.test(lm, dm)

You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Output()