# Load data

In [68]:
from climate_learn.utils.datetime import Year, Days, Hours
from climate_learn.data import DataModule

dm = DataModule(
    dataset="ERA5",
    task="forecasting",
    root_dir="../climate-learn/data/weatherbench/era5/5.625",
    in_vars=["2m_temperature"],
    out_vars=["2m_temperature"],
    train_start_year=Year(1979),
    val_start_year=Year(1980),
    test_start_year=Year(1981),
    end_year=Year(1982),
    pred_range=Days(3),
    subsample=Hours(6),
    batch_size=128,
    num_workers=8
)

Creating train dataset


  0%|          | 0/1 [00:00<?, ?it/s]

Creating val dataset


  0%|          | 0/1 [00:00<?, ?it/s]

Creating test dataset


  0%|          | 0/2 [00:00<?, ?it/s]

Create BayesianCNN.

In [69]:
dm.train_dataset[0][0].shape

torch.Size([1, 1, 32, 64])

# Build BayesCNN

In [165]:
# Scratch cell for computing output sizes
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

import numpy as np

hout = np.floor((29 + 2 * 0 - 1 * (3 - 1) - 1) / 1 + 1)
wout = np.floor((61 + 2 * 0 - 1 * (3 - 1) - 1) / 1 + 1)

print(hout, wout)

27.0 59.0


In [210]:
from layers import *
import torch.nn as nn

# https://arxiv.org/pdf/1901.02731.pdf

class BayesCNN(ModuleWrapper):
    def __init__(self):
        super().__init__()
        self.conv1 = BBB_Conv2d(
            in_channels=1,
            out_channels=3,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.act1 = nn.Softplus()
        self.conv2 = BBB_Conv2d(
            in_channels=3,
            out_channels=5,
            kernel_size=3,
            stride=1,
            padding=1
        )
        # self.maxpool2 = nn.AvgPool2d(2)
        # self.bn2 = nn.BatchNorm2d(5)
        # self.tanh2 = nn.Tanh()
        # self.conv3 = BBB_Conv2d(
        #     in_channels=5,
        #     out_channels=7,
        #     kernel_size=3,
        #     stride=1
        # )
        # self.bn3 = nn.BatchNorm2d(7)
        # self.tanh3 = nn.Tanh()
        # self.conv4 = BBB_Conv2d(
        #     in_channels=7,
        #     out_channels=7,
        #     kernel_size=3,
        #     stride=1
        # )
        # self.flatten = FlattenLayer(2352)
        # self.fc = BBB_Linear(2352, 32*64)
        self.act2 = nn.Softplus()
        self.conv3 = BBB_Conv2d(
            in_channels=5,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1
        )
    
    # def forward(self, x):
    #     orig_shape = x.shape
    #     x = self.conv1(x)
    #     x = self.bn1(x)
    #     x = self.tanh1(x)
    #     x = self.conv2(x)
    #     x = self.maxpool2(x)
    #     x = self.bn2(x)
    #     x = self.tanh2(x)
    #     x = self.conv3(x)
    #     x = self.bn3(x)
    #     x = self.tanh3(x)
    #     x = self.conv4(x)
    #     x = self.flatten(x)
    #     x = self.fc(x)
    #     z = x.reshape(orig_shape)
    #     return z

In [211]:
import torch

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
    
print(device)

cuda:0


In [215]:
# Test it
# outputs are (logits, kl)
net = BayesCNN().cuda()
logits, _ = net(dm.train_dataset[0][0].to(device))

In [216]:
logits

tensor([[[[ 6.7370e-02, -1.2802e-01, -6.8495e-03,  ...,  2.5561e-03,
            8.8802e-03, -5.3915e-02],
          [ 2.1930e-01, -1.0504e-02,  6.8589e-02,  ...,  5.9458e-02,
            4.7761e-02,  7.1697e-03],
          [ 2.6401e-01,  5.7176e-02,  1.3267e-01,  ...,  7.8685e-02,
            7.5449e-02, -2.7352e-02],
          ...,
          [ 2.4301e-01,  4.5493e-05,  8.4145e-02,  ...,  9.2749e-02,
            1.0889e-01,  6.9901e-02],
          [ 2.1565e-01, -1.6570e-02,  4.6437e-02,  ..., -9.4138e-03,
           -7.0095e-03, -5.5375e-02],
          [ 1.3910e-01,  4.5960e-02,  6.4850e-02,  ...,  4.1872e-02,
            3.8163e-02, -7.2926e-02]],

         [[-2.1275e-01, -5.7509e-02, -5.8342e-02,  ..., -3.0651e-02,
           -3.0758e-02,  6.0147e-02],
          [-1.9617e-01, -2.2463e-02, -3.5384e-02,  ..., -1.0935e-02,
            8.1864e-02,  4.8895e-02],
          [-1.8487e-01, -3.0218e-02, -4.0427e-02,  ..., -5.7251e-02,
            1.4066e-02,  4.0066e-02],
          ...,
     

In [217]:
logits.detach().cpu().size()

torch.Size([1, 3, 32, 64])

In [218]:
import pytorch_lightning as pl
from torchvision.transforms import transforms
from climate_learn.models.modules.utils.metrics import (
    lat_weighted_mse,
    lat_weighted_rmse,
    lat_weighted_acc,    
)
from sklearn.linear_model import Ridge

class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = BayesCNN()
        self.train_loss = [lat_weighted_mse]
        self.val_loss = [
            lat_weighted_rmse,
            lat_weighted_acc
        ]
        
    def forward(self, x):
        # if using the BayesCNN pre-defined forward
        logits, _ = self.net(torch.squeeze(x, 1))
        # otherwise
        # logits = self.net(torch.squeeze(x, 1))
        return logits
    
    def set_denormalization(self, mean, std):
        self.denormalization = transforms.Normalize(mean, std)

        mean_mean_denorm, mean_std_denorm = -mean / std, 1 / std
        self.mean_denormalize = transforms.Normalize(mean_mean_denorm, mean_std_denorm)

        std_mean_denorm, std_std_denorm = np.zeros_like(std), 1 / std
        self.std_denormalize = transforms.Normalize(std_mean_denorm, std_std_denorm)

        mean_mean_denorm, mean_std_denorm = -mean / std, 1 / std
        self.mean_denormalize = transforms.Normalize(mean_mean_denorm, mean_std_denorm)

        std_mean_denorm, std_std_denorm = np.zeros_like(std), 1 / std
        self.std_denormalize = transforms.Normalize(std_mean_denorm, std_std_denorm)

        mean_mean_denorm, mean_std_denorm = -mean / std, 1 / std
        self.mean_denormalize = transforms.Normalize(mean_mean_denorm, mean_std_denorm)

        std_mean_denorm, std_std_denorm = np.zeros_like(std), 1 / std
        self.std_denormalize = transforms.Normalize(std_mean_denorm, std_std_denorm)

    def set_lat_lon(self, lat, lon):
        self.lat = lat
        self.lon = lon

    def set_pred_range(self, r):
        self.pred_range = r

    def set_train_climatology(self, clim):
        self.train_clim = clim

    def set_val_climatology(self, clim):
        self.val_clim = clim

    def set_test_climatology(self, clim):
        self.test_clim = clim
    
    def training_step(self, batch, batch_idx):
        x, y, _, out_variables = batch
        y_hat = self(x)
        loss_dict = [
            m(y_hat, y, out_variables, lat=self.lat)
            for m in self.train_loss
        ][0]
        for var in loss_dict.keys():
            self.log(
                "train/" + var,
                loss_dict[var],
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                batch_size=len(x)
            )
        return loss_dict
    
    def validation_step(self, batch, batch_idx):
        x, y, variables, out_variables = batch
        pred_steps = y.shape[1]
        pred_range = self.pred_range.hours()
        
        default_days = [1, 3, 5]
        days_each_step = pred_range / 24
        default_steps = [
            d / days_each_step for d in default_days if d % days_each_step == 0
        ]
        steps = [int(s) for s in default_steps if s <= pred_steps and s > 0]
        days = [int(s * pred_range / 24) for s in steps]
        day = int(days_each_step)
                
        preds = []
        for _ in range(pred_steps):
            x = self.forward(x)
            preds.append(x)
        preds = torch.stack(preds, dim=1)
        if len(y.shape) == 4:
            y = y.unsqueeze(1)
        loss_dict = [
            m(preds, y, out_variables, transform=self.denormalization, lat=self.lat,
              log_steps=steps, log_days=days, clim=self.val_clim)
            for m in self.val_loss
        ][0]
        for var in loss_dict.keys():
            self.log(
                "val/" + var,
                loss_dict[var],
                on_step=True,
                on_epoch=False,
                prog_bar=True,
                batch_size=len(x)
            )
        return loss_dict
    
    def test_step(self, batch, batch_idx):
        x, y, variables, out_variables = batch
        pred_steps = y.shape[1]
        pred_range = self.pred_range.hours()
        day = int(pred_range / 24)
        
        default_days = [1, 3, 5]
        days_each_step = pred_range / 24
        default_steps = [
            d / days_each_step for d in default_days if d % days_each_step == 0
        ]
        steps = [int(s) for s in default_steps if s <= pred_steps and s > 0]
        days = [int(s * pred_range / 24) for s in steps]
        day = int(days_each_step)
        
        # rmse for climatology baseline
        clim_pred = self.train_clim  # C, H, W
        clim_pred = (
            clim_pred.unsqueeze(0)
            .unsqueeze(0)
            .repeat(y.shape[0], y.shape[1], 1, 1, 1)
            .to(y.device)
        )
        baseline_rmse = lat_weighted_rmse(
            clim_pred,
            y,
            out_variables,
            transform_pred=False,
            transform=self.denormalization,
            lat=self.lat,
            log_steps=steps,
            log_days=days,
        )
        for var in baseline_rmse.keys():
            self.log(
                "test_climatology_baseline/" + var,
                baseline_rmse[var],
                on_step=False,
                on_epoch=True,
                sync_dist=True,
                batch_size=len(x),
            )

        # rmse for persistence baseline
        pers_pred = x  # B, 1, C, H, W
        baseline_rmse = lat_weighted_rmse(
            pers_pred,
            y,
            out_variables,
            transform_pred=True,
            transform=self.denormalization,
            lat=self.lat,
            log_steps=steps,
            log_days=days,
        )
        for var in baseline_rmse.keys():
            self.log(
                "test_persistence_baseline/" + var,
                baseline_rmse[var],
                on_step=False,
                on_epoch=True,
                sync_dist=True,
                batch_size=len(x),
            )

        # rmse for linear regression baseline
        # check if fit_lin_reg_baseline is called by checking whether self.lr_baseline is initialized
        try:
            lr_pred = self.lr_baseline.predict(
                x.cpu().reshape((x.shape[0], -1))
            ).reshape(y.shape)
        except AttributeError as e:
            raise NotImplementedError(
                "Expect climate_learn.models.fit_lin_reg_baseline be implemented before test steps."
            ) from None

        lr_pred = lr_pred[:, np.newaxis, :, :, :]  # B, 1, C, H, W
        lr_pred = torch.from_numpy(lr_pred).float().to(y.device)
        baseline_rmse = lat_weighted_rmse(
            lr_pred,
            y,
            out_variables,
            transform_pred=True,
            transform=self.denormalization,
            lat=self.lat,
            log_steps=steps,
            log_days=days,
        )
        for var in baseline_rmse.keys():
            self.log(
                "test_ridge_regression_baseline/" + var,
                baseline_rmse[var],
                on_step=False,
                on_epoch=True,
                sync_dist=True,
                batch_size=len(x),
            )
            
        preds = []
        for _ in range(pred_steps):
            x = self.forward(x)
            preds.append(x)
        preds = torch.stack(preds, dim=1)
        if len(y.shape) == 4:
            y = y.unsqueeze(1)
        loss_dict = [
            m(preds, y, out_variables, transform=self.denormalization, lat=self.lat,
              log_steps=steps, log_days=days, clim=self.test_clim)
            for m in self.val_loss
        ][0]
        for var in loss_dict.keys():
            self.log(
                "test/" + var,
                loss_dict[var],
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                batch_size=len(x)
            )
        return loss_dict
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
    def fit_lin_reg_baseline(self, train_dataset, reg_hparam=0.0):
        X_train = train_dataset.inp_data.reshape(train_dataset.inp_data.shape[0], -1)
        y_train = train_dataset.out_data.reshape(train_dataset.out_data.shape[0], -1)
        self.lr_baseline = Ridge(alpha=reg_hparam)
        self.lr_baseline.fit(X_train, y_train)

In [219]:
lm = LitModule()

In [220]:
from climate_learn.models import set_climatology
set_climatology(lm, dm)

In [221]:
from climate_learn.training import Trainer

trainer = Trainer(
    seed=0,
    accelerator="gpu",
    precision=16,
    max_epochs=5
)

Global seed set to 0


In [222]:
trainer.fit(lm, dm)

You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Output()

In [223]:
lm.fit_lin_reg_baseline(dm.train_dataset)

In [224]:
trainer.test(lm, dm)

You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Output()