In [1]:
#| default_exp BaltNet
#| default_cls_lvl 3

# BaltNet

Model architecture, dataset, and Lightning training wrappers used for predicting river runoff in the Baltic Sea.

In [2]:
#| export

import torch
import torch.nn as nn
import lightning as L
import pytorch_lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.loggers import CSVLogger

import xarray as xr
from glob import glob
from tqdm import tqdm

from BalticRiverPrediction.convLSTM import ConvLSTM
from BalticRiverPrediction.sharedUtilities import read_netcdfs, preprocess, plot_loss_and_acc

In [3]:
#| export
class BaltNet(nn.Module):
    """ConvLSTM encoder/decoder followed by a fully connected regression head.

    Every key/value pair in ``modelPar`` is copied onto the instance as an
    attribute.  The constructor then reads the following attributes:

    * ``dimensions`` -- its first two entries are multiplied together
      (presumably the spatial grid size, e.g. (lat, lon) -- TODO confirm)
    * ``input_dim``, ``hidden_dim``, ``kernel_size``, ``num_layers``,
      ``batch_first``, ``bias``, ``return_all_layers`` -- forwarded to
      ``ConvLSTM``.  ``hidden_dim`` must be a plain number here because it is
      used in an integer product to size the first linear layer.

    The head maps the flattened decoder output to 97 values per sample
    (presumably one per river -- confirm against the runoff data).
    """

    def __init__(self, modelPar):
        super().__init__()

        # Expose every entry of the parameter mapping as an attribute.
        for k, v in modelPar.items():
            setattr(self, k, v)

        # Number of features the first linear layer receives after flattening
        # the decoder output (assumes that output holds a single time step with
        # ``hidden_dim`` channels on the full grid -- TODO confirm).
        self.linear_dim = self.dimensions[0]*self.dimensions[1]*self.hidden_dim

        # Encoder: consumes the complete input sequence.
        self.convLSTM = ConvLSTM(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=self.kernel_size,
                num_layers=self.num_layers,
                batch_first=self.batch_first,
                bias=self.bias,
                return_all_layers=self.return_all_layers
        )

        # Decoder: always a single layer; it is fed the last input step
        # together with the state returned by the encoder.
        self.convLSTM2 = ConvLSTM(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=self.kernel_size,
                num_layers=1,
                batch_first=self.batch_first,
                bias=self.bias,
                return_all_layers=self.return_all_layers
        )

        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(self.linear_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 97)
            )

    def forward(self, x):
        """Run the encoder/decoder and the regression head.

        Args:
            x: 5-D input tensor; axis 1 is sliced with ``-1:`` to obtain the
                last step (assumes a (batch, time, channel, lat, lon) layout
                -- TODO confirm against ``batch_first``).

        Returns:
            Tensor of shape ``(batch, 97)``.  The batch axis is always kept,
            including for a batch containing a single sample.
        """
        _, encode_state = self.convLSTM(x)
        decoder_out, _ = self.convLSTM2(x[:, -1:, :, :, :], encode_state)
        x = torch.flatten(decoder_out[0], start_dim=1)
        # No argument-less ``squeeze()`` here: it would also drop the batch
        # axis for a one-sample batch, turning (1, 97) into (97,) so the
        # prediction would no longer line up with (1, 97) labels.
        return self.fc_layers(x)


In [4]:
#| export
class BaseLineModel(nn.Module):
    """Fully connected baseline that regresses one value per sample.

    Every key/value pair in ``modelPar`` is copied onto the instance as an
    attribute.  The constructor reads ``dimensions`` (first two entries),
    ``hidden_dim`` and ``input_dim``; their product is the number of input
    features the first linear layer expects after flattening.
    """

    def __init__(self, modelPar):
        super().__init__()

        # Expose every entry of the parameter mapping as an attribute.
        for k, v in modelPar.items():
            setattr(self, k, v)

        # Size of one flattened sample.
        # NOTE(review): ``hidden_dim`` is part of this product although the
        # model has no recurrent layer -- presumably it stands in for the
        # sequence length; confirm against the caller.
        self.linear_dim = self.dimensions[0]*self.dimensions[1]*self.hidden_dim*self.input_dim

        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(self.linear_dim, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 1)
            )

    def forward(self, x):
        """Flatten each sample and apply the fully connected layers.

        Args:
            x: Tensor whose axes after the first multiply to ``linear_dim``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        x = torch.flatten(x, start_dim=1)
        # Drop only the trailing singleton axis.  An argument-less
        # ``squeeze()`` would also remove the batch axis for a one-sample
        # batch and return a 0-d tensor.
        return self.fc_layers(x).squeeze(-1)

In [5]:
#| export

class LightningModel(L.LightningModule):
    """Lightning wrapper that trains a regression model with an MSE loss.

    Args:
        model: The network to train; called as ``model(features)``.
        learning_rate: Learning rate handed to the Adam optimizer.
        cosine_t_max: ``T_max`` of the cosine annealing learning-rate schedule.
    """

    def __init__(self, model, learning_rate, cosine_t_max):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.cosine_t_max = cosine_t_max

        # The model itself is excluded from the stored hyperparameters.
        self.save_hyperparameters(ignore=["model"])

        # One metric object per stage so their internal states stay separate.
        self.train_mse = torchmetrics.MeanSquaredError()
        self.val_mse = torchmetrics.MeanSquaredError()
        self.test_mse = torchmetrics.MeanSquaredError()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _shared_step(self, batch, debug=False):
        """Compute loss and predictions for one ``(features, labels)`` batch.

        Args:
            batch: Pair ``(features, true_labels)``.
            debug: When truthy, the loss is printed.

        Returns:
            Tuple ``(loss, true_labels, predictions)``.
        """
        features, true_labels = batch
        logits = self.model(features)
        loss = F.mse_loss(logits, true_labels)
        if debug:
            print(loss)
        return loss, true_labels, logits

    def training_step(self, batch, batch_idx):
        """Log ``train_mse``/``train_loss`` and return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch, debug=False)
        mse = self.train_mse(predicted_labels, true_labels)
        metrics = {"train_mse": mse, "train_loss": loss}
        self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_mse`` for one validation batch."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        mse = self.val_mse(predicted_labels, true_labels)
        self.log("val_loss", loss, sync_dist=True)
        self.log("val_mse", mse, prog_bar=True, sync_dist=True)

    def test_step(self, batch, _):
        """Log ``test_loss`` and ``test_mse`` and return the loss."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        mse = self.test_mse(predicted_labels, true_labels)
        # This call runs on every rank, so the loss is synchronised like all
        # other logged values instead of being flagged ``rank_zero_only``.
        self.log("test_loss", loss, sync_dist=True)
        self.log("test_mse", mse, sync_dist=True)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Return the model output for a batch.

        ``batch`` may be a ``(features, labels)`` sequence (labels are
        ignored) or the features alone.  No loss is computed.
        """
        features = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.model(features)

    def configure_optimizers(self):
        """Adam optimizer with a cosine annealing schedule (``eta_min=1e-7``)."""
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, eta_min=1e-7, T_max=self.cosine_t_max)

        return [opt], [sch]


In [6]:
#| export
class AtmosphericDataset(Dataset):
    """Sliding-window dataset pairing atmospheric fields with river runoff.

    Both inputs are standardised with their own mean and standard deviation
    along ``"time"``.  Sample ``i`` consists of ``input_size`` consecutive time
    steps of the atmospheric data, starting at step ``i``, and the runoff of
    the time step directly after that window.

    Preprocessing (e.g. restricting the time range) is expected to be done by
    the caller before the data is handed to this class.

    Args:
        input_size: Number of consecutive time steps per input window.
        atmosphericData: Object providing ``mean("time")``, ``std("time")``
            and ``to_array`` with dimensions ``time``, ``lat`` and ``lon``
            (presumably an ``xarray.Dataset`` -- TODO confirm).
        runoff: Object with dimensions ``time`` and ``river``.
        transform: Accepted for interface compatibility; currently unused.
        statsPath: File the runoff mean and standard deviation are written to
            with ``np.savetxt``.  Pass ``None`` to skip writing.

    Raises:
        ValueError: If the stacked atmospheric data is neither 3-D nor 4-D.
    """

    def __init__(
        self,
        input_size,
        atmosphericData,
        runoff,
        transform=None,
        statsPath="/silor/boergel/paper/runoff_prediction/data/modelStats.txt",
    ):
        # Length of the sequence
        self.input_size = input_size

        # input data (x): statistics along time
        atmosphericDataStd = atmosphericData.std("time")
        atmosphericDataMean = atmosphericData.mean("time")
        self.atmosphericStats = (atmosphericDataMean, atmosphericDataStd)

        # output data - label (y)
        runoffData = runoff.transpose("time", "river")
        runoffDataMean = runoffData.mean("time")
        runoffDataSTD = runoffData.std("time")
        self.runoffData = (runoffDataMean, runoffDataSTD)

        # Persist the runoff statistics (the values used to standardise y).
        if statsPath is not None:
            np.savetxt(statsPath, [runoffDataMean, runoffDataSTD])

        # normalize data
        # NOTE(review): a zero standard deviation yields non-finite values
        # here -- confirm the inputs never contain constant series.
        X = ((atmosphericData - atmosphericDataMean)/atmosphericDataStd).compute()
        y = ((runoffData - runoffDataMean)/runoffDataSTD).compute()

        # Stack the variables into a channel axis: (time, channel, lat, lon).
        xStacked = X.to_array(dim='variable')
        xStacked = xStacked.transpose("time", "variable", "lat", "lon")

        x = torch.tensor(xStacked.data, dtype=torch.float32)
        # ``ndim`` is an integer and is compared directly.
        if x.ndim == 3:
            # (time, lat, lon): insert the missing channel axis.
            x = x.unsqueeze(dim=1)
        elif x.ndim != 4:
            raise ValueError(
                f"expected 3 or 4 dimensions for the atmospheric data, got {x.ndim}"
            )
        self.x = x

        self.y = torch.tensor(y.data, dtype=torch.float32)

    def __getitem__(self, index):
        """Return ``(window, label)`` for sample ``index``.

        ``window`` holds time steps ``index .. index + input_size - 1`` of
        ``self.x`` (time is the leading axis); ``label`` is ``self.y`` at time
        step ``index + input_size``.
        """
        stop = index + int(self.input_size)
        return self.x[index:stop], self.y[stop]

    def __len__(self):
        """Number of windows that still have a label after them."""
        return self.y.shape[0]-(self.input_size)


In [7]:
#| export
class AtmosphereDataModule(L.LightningDataModule):
    """Lightning data module providing train/validation loaders.

    Args:
        atmosphericData: Atmospheric input handed to ``AtmosphericDataset``.
        runoff: Runoff data handed to ``AtmosphericDataset``.
        batch_size: Batch size of both data loaders.
        num_workers: Worker processes per data loader.
        add_first_dim: Stored on the instance; not read in this class.
        input_size: Window length handed to ``AtmosphericDataset``.
    """

    def __init__(self, atmosphericData, runoff, batch_size=64, num_workers=8, add_first_dim=True, input_size=30):
        super().__init__()

        self.data = atmosphericData
        self.runoff = runoff
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_first_dim = add_first_dim
        self.input_size = input_size

    def setup(self, stage:str):
        """Build the dataset and split it 90 % / 10 % into train and validation."""
        import warnings

        # Instantiating ``UserWarning(...)`` alone emits nothing; the message
        # has to go through ``warnings.warn``.
        warnings.warn("Loading atmospheric data ...", UserWarning, stacklevel=2)

        dataset = AtmosphericDataset(
            atmosphericData=self.data,
            runoff=self.runoff,
            input_size=self.input_size
            )
        n_samples = len(dataset)
        train_size = int(0.9 * n_samples)
        val_size = n_samples - train_size
        # NOTE(review): no generator is passed, so the split depends on the
        # global torch RNG state and differs between calls unless it is seeded.
        self.train, self.val = random_split(dataset, [train_size, val_size])
        # val_size = int(0.1 * n_samples)
        # test_size = n_samples - train_size  - val_size
        # self.train, self.val, self.test = random_split(dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        """Shuffled loader over the training split; incomplete batches are dropped."""
        return DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation split; incomplete batches are dropped."""
        return DataLoader(
            dataset=self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True)

    # def test_dataloader(self):
    #     return DataLoader(
    #         self.test,
    #         batch_size=self.batch_size,
    #         shuffle=False,
    #         num_workers=self.num_workers, 
    #         drop_last=True)