In [1]:
import numpy as np
import argparse
import sys
import matplotlib.pyplot as plt

import xarray as xr
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchmetrics import Metric
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from BalticRiverPrediction.BaltNet import AtmosphericDataset
from BalticRiverPrediction.BaltNet import BaltNet

from lightning.pytorch.callbacks import BasePredictionWriter
import math

import torch
import torch.nn as nn

In [2]:
class EnhancedMSELoss(nn.Module):
    """Squared error plus an extra power-``alpha`` absolute-error penalty.

    With ``e = predictions - targets`` the loss is::

        mean(e ** 2) + mean(|e| ** alpha)

    For ``alpha > 2`` the second term exceeds the squared error whenever
    ``|e| > 1``, so large residuals are punished harder than under plain MSE.
    """

    def __init__(self, alpha=1.5):
        """
        Args:
            alpha (float): Power applied to the absolute error in the extra
                penalty term.
        """
        super().__init__()
        self.alpha = alpha

    def forward(self, predictions, targets):
        """
        Args:
            predictions (torch.Tensor): Model outputs.
            targets (torch.Tensor): Ground-truth values; subtracted from
                ``predictions`` with normal broadcasting rules.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        residual = predictions - targets
        squared_term = residual.pow(2).mean()
        power_term = residual.abs().pow(self.alpha).mean()
        return squared_term + power_term

In [3]:
class EnhancedMSEMetric(Metric):
    """Epoch-level counterpart of ``EnhancedMSELoss``.

    Accumulates ``sum(e**2) + sum(|e|**alpha)`` and the number of error
    elements over every ``update`` call since the last reset. ``compute``
    returns their ratio, i.e. ``mean(e**2) + mean(|e|**alpha)`` over all
    elements seen, where ``e = predictions - targets``.
    """

    # update() depends only on the current batch, so torchmetrics can derive the
    # per-batch value in forward() without running update() twice.
    full_state_update = False
    # Lower error is better (used e.g. by tooling that inspects metric direction).
    higher_is_better = False

    def __init__(self, alpha=1.5, dist_sync_on_step=False):
        """
        Args:
            alpha (float): Power applied to the absolute error in the extra term.
            dist_sync_on_step (bool): Forwarded to ``torchmetrics.Metric``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.alpha = alpha
        self.add_state("sum_enhanced_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, predictions: torch.Tensor, targets: torch.Tensor):
        """Add one batch of errors to the running sums."""
        # Metric values are for logging only; detaching guarantees the running
        # state never holds on to an autograd graph.
        error = (predictions - targets).detach()
        # Accumulate in at least float32 so reduced-precision inputs do not
        # degrade the running sums.
        error = error.to(torch.promote_types(error.dtype, torch.float32))

        # Summing directly is equivalent to mean * numel, without the round trip.
        self.sum_enhanced_error += error.pow(2).sum() + error.abs().pow(self.alpha).sum()
        self.total += error.numel()

    def compute(self):
        """Return the mean enhanced error over all accumulated elements."""
        return self.sum_enhanced_error / self.total

In [4]:
class AtmosphereDataModule(L.LightningDataModule):
    """Lightning data module around ``AtmosphericDataset`` with an 80/10/10
    train/validation/test split.

    Args:
        atmosphericData: Atmospheric forcing passed to ``AtmosphericDataset``.
        runoff: River runoff passed to ``AtmosphericDataset``.
        batch_size (int): Samples per batch for every dataloader.
        num_workers (int): Worker processes per dataloader.
        add_first_dim (bool): Stored on the instance; not used by this class.
        input_size (int): Forwarded to ``AtmosphericDataset``.
        split_seed (int | None): Seed for a dedicated generator used by
            ``random_split``. ``None`` uses the global torch RNG (seeded by
            ``L.seed_everything`` in this notebook).
    """

    def __init__(self, atmosphericData, runoff, batch_size=64, num_workers=8, add_first_dim=True, input_size=30, split_seed=None):
        super().__init__()

        self.data = atmosphericData
        self.runoff = runoff
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_first_dim = add_first_dim
        self.input_size = input_size
        self.split_seed = split_seed

        # Filled by setup(); None means "not split yet".
        self.train = None
        self.val = None
        self.test = None

    def setup(self, stage: str):
        """Build ``AtmosphericDataset`` once and split it 80/10/10.

        Lightning calls ``setup`` for every entry point (fit, validate, test, ...).
        The split is cached after the first call so a later call cannot
        reshuffle samples between subsets, which would leak training samples
        into the validation/test sets.

        Args:
            stage (str): Lightning stage name; not used.
        """
        if self.train is not None:
            return

        import warnings  # local import keeps this notebook cell self-contained

        # Emit the message; merely constructing UserWarning(...) shows nothing.
        warnings.warn("Loading atmospheric data ...", UserWarning, stacklevel=2)
        dataset = AtmosphericDataset(
            atmosphericData=self.data,
            runoff=self.runoff,
            input_size=self.input_size,
        )

        n_samples = len(dataset)
        train_size = int(0.8 * n_samples)
        val_size = int(0.1 * n_samples)
        test_size = n_samples - train_size - val_size  # remainder goes to test

        split_kwargs = {}
        if self.split_seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(self.split_seed)
        self.train, self.val, self.test = random_split(
            dataset, [train_size, val_size, test_size], **split_kwargs
        )
        self.runoffDataStats = dataset.runoffDataStats

    def _make_loader(self, subset, shuffle):
        """Shared DataLoader factory; only the subset and shuffling differ per stage."""
        return DataLoader(
            dataset=subset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            # NOTE(review): also drops the tail of val/test, so those metrics skip
            # up to batch_size - 1 samples; confirm full batches are required.
            drop_last=True,
            num_workers=self.num_workers,
            pin_memory=False,  # pinning disabled; True can speed up host-to-GPU copies
        )

    def train_dataloader(self):
        return self._make_loader(self.train, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test, shuffle=False)


In [5]:
class LightningModel(L.LightningModule):
    """
    PyTorch Lightning wrapper that trains a network with ``EnhancedMSELoss``.

    Attributes:
        model (nn.Module): The wrapped neural network.
        learning_rate (float): Initial learning rate for AdamW.
        cosine_t_max (int): Stored and saved as a hyperparameter; not used by the
            current scheduler (``ReduceLROnPlateau``).
        loss_function (EnhancedMSELoss): Training objective.
        train_mse (EnhancedMSEMetric): Training metric (same formula as the loss).
        val_mse (EnhancedMSEMetric): Validation metric; monitored by the scheduler.
        test_mse (EnhancedMSEMetric): Test metric.
    """

    def __init__(self, model, learning_rate, cosine_t_max, alpha=4):
        """
        Args:
            model (nn.Module): The neural network to train.
            learning_rate (float): Initial learning rate for AdamW.
            cosine_t_max (int): Kept for hyperparameter/checkpoint compatibility.
            alpha (float): Power of the extra |error| term in loss and metrics.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.cosine_t_max = cosine_t_max
        self.loss_function = EnhancedMSELoss(alpha=alpha)

        # Save hyperparameters except the model (an nn.Module, not a plain value).
        self.save_hyperparameters(ignore=["model"])

        # One metric instance per stage so their states never mix.
        self.train_mse = EnhancedMSEMetric(alpha=alpha)
        self.val_mse = EnhancedMSEMetric(alpha=alpha)
        self.test_mse = EnhancedMSEMetric(alpha=alpha)

    def forward(self, x):
        """Defines the forward pass of the model."""
        return self.model(x)

    def _shared_step(self, batch, debug=False):
        """
        Shared step for training, validation, and testing.

        Args:
            batch (tuple): ``(features, true_labels)``.
            debug (bool, optional): If True, prints the loss. Defaults to False.

        Returns:
            tuple: Computed loss, true labels, and predicted labels.
        """
        features, true_labels = batch
        logits = self.model(features)
        loss = self.loss_function(logits, true_labels)
        if debug:
            print(loss)
        return loss, true_labels, logits

    def training_step(self, batch, batch_idx):
        """Training step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        # Calling the metric updates its epoch state and caches the batch value.
        # Logging the Metric object lets Lightning log that batch value on_step,
        # call compute() on_epoch (synced across ranks) and reset the state.
        self.train_mse(predicted_labels, true_labels)
        self.log("train_mse", self.train_mse, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.val_mse(predicted_labels, true_labels)
        self.log("val_loss", loss, sync_dist=True)
        # Epoch-level value over the whole validation set; monitored by the
        # LR scheduler and checkpoint callback.
        self.log("val_mse", self.val_mse, prog_bar=True)

    def test_step(self, batch, _):
        """Test step."""
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_mse(predicted_labels, true_labels)
        # Synced like the other stages so the value reflects every rank.
        self.log("test_loss", loss, sync_dist=True)
        self.log("test_mse", self.test_mse)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Return model predictions; labels in the batch (if present) are ignored."""
        features = batch[0] if isinstance(batch, (tuple, list)) else batch
        return self.model(features)

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler.

        Returns:
            dict: AdamW optimizer and a ``ReduceLROnPlateau`` scheduler config
            that monitors ``val_mse``.
        """
        opt = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=1e-4)
        # `verbose` is deprecated in recent PyTorch; it defaulted to False anyway.
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.1, patience=10)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "monitor": "val_mse"},
        }

In [6]:
class PredictionPlottingCallback(L.Callback):
    """Plot smoothed predicted vs. true runoff for the largest rivers after each training epoch.

    The model is run over ``trainer.datamodule.test``. Predictions and labels are
    de-normalised with ``datamodule.runoffDataStats`` (``x * stats[1] + stats[0]``),
    smoothed with a rolling mean, and one panel per river is saved as a PNG.
    """

    def __init__(self, num_samples_to_plot=10, rolling_window_size=120, output_dir="figures"):
        """
        Args:
            num_samples_to_plot (int): Number of rivers to plot, chosen by largest
                mean (de-normalised) label value.
            rolling_window_size (int): Window length, in samples, of the rolling mean.
            output_dir (str): Directory for the PNG files; created if missing.
        """
        super().__init__()
        self.num_samples_to_plot = num_samples_to_plot
        self.rolling_window_size = rolling_window_size
        self.output_dir = output_dir

    def on_train_epoch_end(self, trainer, pl_module):
        # Under DDP every rank would otherwise evaluate the same subset and
        # write the same file.
        if not trainer.is_global_zero:
            return
        datamodule = trainer.datamodule
        if datamodule is None or getattr(datamodule, "test", None) is None:
            return

        test_subset = datamodule.test
        # random_split yields a Subset with shuffled indices; sort them so the
        # x-axis follows the underlying dataset order instead of a random permutation.
        if isinstance(test_subset, torch.utils.data.Subset):
            test_subset = torch.utils.data.Subset(test_subset.dataset, sorted(test_subset.indices))
        test_dataloader = DataLoader(test_subset, batch_size=32, shuffle=False)

        was_training = pl_module.training
        pl_module.eval()
        try:
            predictions, labels = [], []
            with torch.no_grad():
                for features, true_labels in test_dataloader:
                    features = features.to(pl_module.device)
                    # .float() so .numpy() works even for reduced-precision outputs.
                    predictions.append(pl_module(features).float().cpu())
                    labels.append(true_labels.cpu())

            self.all_predictions = torch.cat(predictions, dim=0)
            self.all_labels = torch.cat(labels, dim=0)
            self.plot_time_series(self.all_labels, self.all_predictions, trainer.current_epoch, datamodule.runoffDataStats)
        except Exception as e:
            # Plotting is best-effort: report the failure but never abort training.
            print(f"Exception in callback: {e}")
        finally:
            # Restore the mode even if plotting failed; otherwise training would
            # silently continue in eval mode.
            pl_module.train(was_training)

    def rolling_mean(self, data, window_size):
        """Moving average along axis 0 ("valid" mode).

        Returns ``len(data) - window_size + 1`` rows when ``len(data) >= window_size``.
        """
        cumsum_vec = np.cumsum(np.insert(data, 0, 0, axis=0), axis=0)
        return (cumsum_vec[window_size:] - cumsum_vec[:-window_size]) / window_size

    def plot_time_series(self, true_labels, predictions, epoch, runoffStats):
        """Save a grid of smoothed prediction/label curves for the largest rivers.

        Args:
            true_labels (torch.Tensor): Normalised labels; indexed as (sample, river).
            predictions (torch.Tensor): Normalised predictions, same layout.
            epoch (int): Epoch number used in the title and file name.
            runoffStats: Sequence whose items 0 and 1 expose ``.data`` and undo the
                normalisation as ``x * runoffStats[1] + runoffStats[0]``.
        """
        import os  # local import keeps this notebook cell self-contained

        offset = torch.as_tensor(np.asarray(runoffStats[0].data))
        scale = torch.as_tensor(np.asarray(runoffStats[1].data))
        true_labels = true_labels * scale + offset
        predictions = predictions * scale + offset

        smoothed_true_labels = self.rolling_mean(true_labels.numpy(), self.rolling_window_size)
        smoothed_predictions = self.rolling_mean(predictions.numpy(), self.rolling_window_size)

        # The number of rivers (last axis) caps how many panels can be drawn.
        num_plots = min(true_labels.shape[-1], self.num_samples_to_plot)
        if num_plots == 0:
            return
        _, top_river_indices = torch.topk(true_labels.mean(dim=0), num_plots, largest=True)

        num_cols = int(math.ceil(math.sqrt(num_plots)))
        num_rows = int(math.ceil(num_plots / num_cols))

        fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 4, num_rows * 4), squeeze=False)
        try:
            for i, (ax, river_index) in enumerate(zip(axes.flat, top_river_indices.tolist())):
                ax.plot(smoothed_predictions[:, river_index], label=f"Smoothed Pred for river {i}")
                ax.plot(smoothed_true_labels[:, river_index], label=f"Smoothed True for river {i}", alpha=0.5)
                ax.set_xlabel("Time")
                ax.set_ylabel("Value")
                ax.legend()
            # Hide unused cells of the grid.
            for ax in list(axes.flat)[num_plots:]:
                ax.set_visible(False)

            fig.suptitle(f"Epoch {epoch} Predictions {num_plots} largest rivers")
            os.makedirs(self.output_dir, exist_ok=True)
            fig.savefig(os.path.join(self.output_dir, f"time_series_predictions_epoch_{epoch}.png"))
        finally:
            plt.close(fig)

In [7]:
# Set seed for reproducibility
L.seed_everything(123)

# Use available tensor cores
torch.set_float32_matmul_precision("medium")

datapath = "/silor/boergel/paper/runoff_prediction/data"
datapathPP = "/silod9/boergel/runoff_prediction_ERA5_downscaled_coupled_model/resampled"
TIME_SLICE = slice("1979", "2005")


def load_atmospheric_variable(filename, variable):
    """Load one atmospheric field from ``{datapathPP}/{filename}``.

    Selects ``TIME_SLICE``, takes ``variable``, squeezes singleton dimensions,
    drops the ``lon``/``lat`` coordinates and renames ``rlon``/``rlat`` to ``x``/``y``.
    Item access (``ds[variable]``) is used because attribute access can clash
    with xarray methods/properties for short names such as ``"T"``.
    """
    ds = xr.open_dataset(f"{datapathPP}/{filename}")
    return (
        ds.sel(time=TIME_SLICE)[variable]
        .squeeze()
        .drop_vars(["lon", "lat"])
        .rename({"rlon": "x", "rlat": "y"})
    )


runoff = xr.open_dataset(f"{datapath}/runoff.nc").load()
runoff = runoff.sel(time=TIME_SLICE)["roflux"]

DataRain = load_atmospheric_variable("rain.nc", "rain")
DataShumi = load_atmospheric_variable("QV.nc", "QV")
DataWindSpeed = load_atmospheric_variable("speed.nc", "speed")
DataTemp = load_atmospheric_variable("T.nc", "T")

atmospheric_fields = [DataRain, DataShumi, DataWindSpeed, DataTemp]
# All four fields (including temperature) must share start time and length,
# otherwise xr.merge would silently align them and introduce NaNs.
for field in atmospheric_fields[1:]:
    assert field.time[0] == DataRain.time[0], f"{field.name}: start time differs from rain"
    assert len(field.time) == len(DataRain.time), f"{field.name}: length differs from rain"

data = xr.merge(atmospheric_fields)
assert len(runoff.time) == len(data.time)

Global seed set to 123


In [8]:
# Architecture/config settings handed to BaltNet(modelPar=...).
modelParameters = dict(
    input_dim=4,
    hidden_dim=8,
    kernel_size=(7, 7),
    num_layers=1,
    batch_first=True,
    bias=True,
    return_all_layers=False,
    dimensions=(222, 244),
    input_size=90,  # also passed to the data module below
)

In [9]:
num_epochs = 80

# Data module that serves the atmospheric forcing and runoff in batches.
dataLoader = AtmosphereDataModule(
    atmosphericData=data,
    runoff=runoff,
    batch_size=16,
    num_workers=8,
    input_size=modelParameters["input_size"],
)

# Plain PyTorch network, then its Lightning training wrapper.
pyTorchBaltNet = BaltNet(modelPar=modelParameters)
LightningBaltNet = LightningModel(
    model=pyTorchBaltNet,
    learning_rate=1e-3,
    cosine_t_max=num_epochs,
)

In [10]:
# Keep the single best checkpoint by validation metric, plus the latest one.
checkpoint_callback = ModelCheckpoint(
    dirpath="/silor/boergel/paper/runoff_prediction/data/modelWeights/",
    filename="TestPredictionCallBackTopOne",
    monitor="val_mse",
    mode="min",
    save_top_k=1,
    save_last=True,
)
plotting_callback = PredictionPlottingCallback()
callbacks = [checkpoint_callback, plotting_callback]

logger = TensorBoardLogger(
    save_dir="/silor/boergel/paper/runoff_prediction/logs",
    name="TestPredictionCallBack",
)

trainer_kwargs = dict(
    accelerator="cuda",
    devices=2,
    precision="bf16-mixed",
    max_epochs=num_epochs,
    callbacks=callbacks,
    logger=logger,
)
trainer = L.Trainer(**trainer_kwargs)

trainer.fit(model=LightningBaltNet, datamodule=dataLoader)


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 123
[rank: 1] Global seed set to 123
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type              | Params
----------------------------------------------------
0 | model         | BaltNet           | 222 M 
1 | loss_function | En

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=80` reached.
