In [9]:
#| default_exp BaltNet
#| default_cls_lvl 3

# BaltNet

Model architecture used for predicting river runoff in the Baltic Sea

In [10]:
#| hide
from nbdev import show_doc

In [11]:
#| export

import torch
import torch.nn as nn
import lightning as L
import pytorch_lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.loggers import CSVLogger

import xarray as xr
from glob import glob
from tqdm import tqdm

from BalticRiverPrediction.convLSTM import ConvLSTM
from BalticRiverPrediction.sharedUtilities import EnhancedMSELoss, EnhancedMSEMetric
from BalticRiverPrediction.sharedUtilities import PredictionPlottingCallback


In [12]:
#| export
class BaltNet(nn.Module):
    def __init__(self, modelPar):
        super(BaltNet, self).__init__()

        # Initialize all attributes
        for k, v in modelPar.items():
            setattr(self, k, v)

        self.encoder = ConvLSTM(
            input_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            kernel_size=self.kernel_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            bias=self.bias,
            return_all_layers=False
        )

        self.decoder = ConvLSTM(
            input_dim=self.hidden_dim,
            hidden_dim=self.hidden_dim,
            kernel_size=self.kernel_size,
            num_layers=1,
            batch_first=self.batch_first,
            bias=self.bias,
            return_all_layers=False
        )

        self.linear_dim = self.dimensions[0] * self.dimensions[1] * self.hidden_dim 

        # Single fully connected network for all rivers
         
        self.river_predictors = nn.Sequential(
            nn.Linear(self.linear_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 97)
        )

        # Creating separate attention weights for each river
        # self.attention_weights = nn.Parameter(torch.randn(self.hidden_dim, 1, 1), requires_grad=True)  # 97 rivers

    # def spatial_attention(self, x):
    #     """Spatial attention mechanism."""
    #     B, T, C, H, W = x.size()

    #     x = x.view(B * T, C, H, W)
        
    #     # Apply attention weights for all rivers
    #     self.attention_map = torch.sigmoid(F.conv2d(x, self.attention_weights.unsqueeze(0), bias=None, stride=1, padding=0))
        
    #     # Weighted sum
    #     output = x * self.attention_map  # B*T, C, H, W
    #     output = output.view(B, T, C, H, W)  # B, T, C

    #     return output

    def forward(self, x):
        B, _, _, _, _ = x.size()

        # Pass through encoder
        encoder_outputs, encoder_hidden = self.encoder(x)

        # Use the entire encoder output as input to the decoder
        decoder_input = encoder_outputs[0][:,-1,:,:,:].unsqueeze(1)

        # Pass through decoder using the final hidden state of the encoder
        decoder_outputs, _ = self.decoder(decoder_input, encoder_hidden)

        # Apply spatial attention
        # decoder_with_spatial_attention = self.spatial_attention(decoder_outputs[0])  # B, T, C, H, W
            
        # Flatten the temporal sequence
        decoder_with_spatial_attention_flattened = decoder_outputs[0].view(B, -1)  #
            
        # Pass through its own predictor
        output = self.river_predictors(decoder_with_spatial_attention_flattened)  # B, -1

        return output

In [13]:
show_doc(BaltNet)

---

### BaltNet

>      BaltNet (modelPar)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.

.. note::
    As per the example above, an ``__init__()`` call to the parent class
    must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or
                evaluation mode.
:vartype training: bool

In [14]:
# show_doc(BaltNet.spatial_attention)


In [15]:
show_doc(BaltNet.forward)

---

### BaltNet.forward

>      BaltNet.forward (x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.

In [16]:
#| export
class LightningModel(L.LightningModule):
    """
    A PyTorch Lightning model for training and evaluation.
    
    Attributes:
        model (nn.Module): The neural network model.
        learning_rate (float): Learning rate for the optimizer.
        cosine_t_max (int): Maximum number of iterations for the cosine annealing scheduler.
        train_mse (torchmetrics.MeanSquaredError): Metric for training mean squared error.
        val_mse (torchmetrics.MeanSquaredError): Metric for validation mean squared error.
        test_mse (torchmetrics.MeanSquaredError): Metric for testing mean squared error.
    """
    
    def __init__(self, model, learning_rate, cosine_t_max, alpha=4):
        """
        Initializes the LightningModel.

        Args:
            model (nn.Module): The neural network model.
            learning_rate (float): Learning rate for the optimizer.
            cosine_t_max (int): Maximum number of iterations for the cosine annealing scheduler.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.cosine_t_max = cosine_t_max
        self.loss_function = EnhancedMSELoss(alpha=alpha)

        # Save hyperparameters except the model
        self.save_hyperparameters(ignore=["model"])

        # Define metrics
        self.train_mse = EnhancedMSEMetric(alpha=alpha)
        self.val_mse = EnhancedMSEMetric(alpha=alpha)
        self.test_mse = EnhancedMSEMetric(alpha=alpha)

    def forward(self, x):
        """Defines the forward pass of the model."""
        return self.model(x)
    
    def _shared_step(self, batch, debug=False):
        """
        Shared step for training, validation, and testing.

        Args:
            batch (tuple): Input batch of data.
            debug (bool, optional): If True, prints the loss. Defaults to False.

        Returns:
            tuple: Computed loss, true labels, and predicted labels.
        """
        features, true_labels = batch
        logits = self.model(features)
        loss = self.loss_function(logits, true_labels)
        if debug:
            print(loss)
        return loss, true_labels, logits
    
    def training_step(self, batch, batch_idx):
        """Training step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        mse = self.train_mse(predicted_labels, true_labels)
        metrics = {"train_mse": mse, "train_loss": loss}
        self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        """Validation step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        mse = self.val_mse(predicted_labels, true_labels)
        self.log("val_loss", loss, sync_dist=True)
        self.log("val_mse", mse, prog_bar=True, sync_dist=True)
    
    def test_step(self, batch, _):
        """Test step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        mse = self.test_mse(predicted_labels, true_labels)
        self.log("test_loss", loss, rank_zero_only=True)
        self.log("test_mse", mse, sync_dist=True)
        #return loss
        return predicted_labels
    
    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Prediction step."""
        _, _, predicted_labels = self._shared_step(batch)
        return predicted_labels

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Returns:
            tuple: List of optimizers and list of learning rate schedulers.
        """
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=10, verbose=False)
        return {"optimizer": opt, "lr_scheduler": sch, "monitor": "val_mse"}

In [17]:
show_doc(LightningModel)

---

### LightningModel

>      LightningModel (model, learning_rate, cosine_t_max, alpha=4)

A PyTorch Lightning model for training and evaluation.

Attributes:
    model (nn.Module): The neural network model.
    learning_rate (float): Learning rate for the optimizer.
    cosine_t_max (int): Maximum number of iterations for the cosine annealing scheduler.
    train_mse (torchmetrics.MeanSquaredError): Metric for training mean squared error.
    val_mse (torchmetrics.MeanSquaredError): Metric for validation mean squared error.
    test_mse (torchmetrics.MeanSquaredError): Metric for testing mean squared error.

In [18]:
show_doc(LightningModel._shared_step)

---

### LightningModel._shared_step

>      LightningModel._shared_step (batch, debug=False)

Shared step for training, validation, and testing.

Args:
    batch (tuple): Input batch of data.
    debug (bool, optional): If True, prints the loss. Defaults to False.

Returns:
    tuple: Computed loss, true labels, and predicted labels.

In [19]:
#| export
class AtmosphericDataset(Dataset):
    def __init__(self, input_size, atmosphericData, runoff, transform=None):

        # Length of the sequence
        self.input_size = input_size

        # input data (x) 
        atmosphericDataStd = atmosphericData.std("time") # dimension will be channel, lat, lon
        atmosphericDataMean = atmosphericData.mean("time")
        self.atmosphericStats = (atmosphericDataMean, atmosphericDataStd)

        # output data - label (y)
        runoffData = runoff.transpose("time", "river")
        runoffDataMean = runoffData.mean("time")
        runoffDataSTD = runoffData.std("time")
        self.runoffDataStats = (runoffDataMean, runoffDataSTD)
        
        # normalize data
        X = ((atmosphericData - atmosphericDataMean)/atmosphericDataStd).compute()
        y = ((runoffData - runoffDataMean)/runoffDataSTD).compute()
        
        # an additional dimension for the channel is added
        # to end up with (time, channel, lat, lon)
        xStacked = X.to_array(dim='variable')
        xStacked = xStacked.transpose("time", "variable", "y", "x")

        assert xStacked.data.ndim == 4
        self.x = torch.tensor(xStacked.data, dtype=torch.float32)
        self.y = torch.tensor(y.data, dtype=torch.float32)

    def __getitem__(self, index):
        return self.x[index:index+(self.input_size)], self.y[index+int(self.input_size)]

    def __len__(self):
        return self.y.shape[0]-(self.input_size)

In [20]:
show_doc(AtmosphericDataset)

---

### AtmosphericDataset

>      AtmosphericDataset (input_size, atmosphericData, runoff, transform=None)

An abstract class representing a :class:`Dataset`.

All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.

.. note::
  :class:`~torch.utils.data.DataLoader` by default constructs a index
  sampler that yields integral indices.  To make it work with a map-style
  dataset with non-integral indices/keys, a custom sampler must be provided.

In [21]:
#| export
class AtmosphereDataModule(L.LightningDataModule):

    def __init__(self, atmosphericData, runoff, batch_size=64, num_workers=8, add_first_dim=True, input_size=30):
        super().__init__()

        self.data = atmosphericData
        self.runoff = runoff
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_first_dim = add_first_dim
        self.input_size = input_size

    def setup(self, stage:str):

        UserWarning("Loading atmospheric data ...")
        dataset = AtmosphericDataset(
            atmosphericData=self.data,
            runoff=self.runoff,
            input_size=self.input_size
            )
        n_samples = len(dataset)
        train_size = int(0.8 * n_samples)  
        val_size = int(0.1 * n_samples)   
        test_size = n_samples - train_size - val_size  

        generator1 = torch.Generator().manual_seed(42)
        self.train, self.val, self.test = random_split(dataset, [train_size, val_size, test_size], generator=generator1)
        self.runoffDataStats = dataset.runoffDataStats
        
    def train_dataloader(self):
        return DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            shuffle=True, 
            drop_last=True, 
            num_workers=self.num_workers,
            pin_memory=False  # Speed up data transfer to GPU
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=False  # Speed up data transfer to GPU
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
            pin_memory=False  # Speed up data transfer to GPU
    )


In [22]:
show_doc(AtmosphereDataModule)

---

### AtmosphereDataModule

>      AtmosphereDataModule (atmosphericData, runoff, batch_size=64,
>                            num_workers=8, add_first_dim=True, input_size=30)

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is
consistent data splits, data preparation and transforms across models.

Example::

    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
        def prepare_data(self):
            # download, split, etc...
            # only called on 1 GPU/TPU in distributed
        def setup(self, stage):
            # make assignments here (val/train/test split)
            # called on every process in DDP
        def train_dataloader(self):
            train_split = Dataset(...)
            return DataLoader(train_split)
        def val_dataloader(self):
            val_split = Dataset(...)
            return DataLoader(val_split)
        def test_dataloader(self):
            test_split = Dataset(...)
            return DataLoader(test_split)
        def teardown(self):
            # clean up after fit or test
            # called on every process in DDP