In [1]:

import torch
import lightning as L
import numpy as np

from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split

from BalticRiverPrediction.BaltNet import BaltNet
from BalticRiverPrediction.BaltNet import LightningModel
from BalticRiverPrediction.BaltNet import AtmosphereDataModule
from BalticRiverPrediction.sharedUtilities import read_netcdfs, preprocess



In [2]:
#| export

import torch
import torch.nn as nn
import lightning as L
import pytorch_lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.loggers import CSVLogger

import xarray as xr
from glob import glob
from tqdm import tqdm

from BalticRiverPrediction.convLSTM import ConvLSTM
from BalticRiverPrediction.sharedUtilities import read_netcdfs, preprocess, plot_loss_and_acc

In [3]:

# Seed all random number generators so that runs are reproducible
L.seed_everything(123)

# Use available tensor cores
torch.set_float32_matmul_precision("high")

# Data locations: raw data and preprocessed NetCDF files
datapath = "/silor/boergel/paper/runoff_prediction/data"
datapathPP = "/fast/boergel/paper/runoff_prediction"

# Target (y): the `roflux` variable of the runoff file, loaded into memory
runoff = xr.open_dataset(f"{datapathPP}/runoff.nc").load().roflux

# Inputs (X): every forcing field is restricted to the years 1979-2011
DataRain = (
    xr.open_dataset(f"{datapathPP}/rain2.nc")
    .sel(time=slice("1979", "2011"))
    .RAIN
    .squeeze()
)

DataShumi = (
    xr.open_dataset(f"{datapathPP}/shumi.nc")
    .sel(time=slice("1979", "2011"))
    .shumi
    .squeeze()
)

# Wind magnitude is derived from the `windx` and `windy` variables
DataWindMagnitude = xr.open_dataset(f"{datapathPP}/windxy.nc")
DataWindMagnitude["wind_magnitude"] = (
    DataWindMagnitude["windx"] ** 2 + DataWindMagnitude["windy"] ** 2
) ** 0.5
DataWindMagnitude = (
    DataWindMagnitude
    .sel(time=slice("1979", "2011"))
    .wind_magnitude
    .squeeze(dim="height", drop=True)
)




Global seed set to 123


In [4]:
# Display the derived wind-magnitude field to inspect its dimensions and coordinates
DataWindMagnitude

In [None]:
# The three forcing fields have to start on the same time stamp ...
assert DataShumi.time[0] == DataRain.time[0]
assert DataRain.time[0] == DataWindMagnitude.time[0]

# ... and contain the same number of time steps
assert len(DataShumi.time) == len(DataRain.time)
assert len(DataRain.time) == len(DataWindMagnitude.time)

# Combine the fields into one dataset; it must be as long as the runoff record
data = xr.merge([DataRain, DataShumi, DataWindMagnitude])
assert len(data.time) == len(runoff.time)

In [3]:

# Seed all random number generators for reproducibility
L.seed_everything(123)

# Our GPU has tensor cores, hence mixed precision training is enabled
# see https://sebastianraschka.com/blog/2023/llm-mixed-precision-copy.html
torch.set_float32_matmul_precision("medium")

datapath = "/silor/boergel/paper/runoff_prediction/data"
datapathPP = "/fast/boergel/paper/runoff_prediction"

# For every forcing variable: read all files matching the `????` directory
# pattern (combined along `time`, each file passed through `preprocess`) and
# write the result to a single NetCDF file in `datapathPP`.
# After the loop `data` holds the dataset of the last variable ("snow").
for forcing_name in ("shumi", "pair", "windxy", "snow"):
    data = read_netcdfs(
        files=f"{datapath}/atmosphericForcing/????/{forcing_name}.mom.dta.nc",
        dim="time",
        transform_func=lambda ds: preprocess(ds),
    )
    data.to_netcdf(f"{datapathPP}/{forcing_name}.nc")

Global seed set to 123
100%|██████████| 54/54 [23:30<00:00, 26.12s/it]
100%|██████████| 54/54 [21:54<00:00, 24.34s/it]
100%|██████████| 54/54 [1:30:15<00:00, 100.29s/it]
100%|██████████| 54/54 [03:25<00:00,  3.80s/it]


In [4]:
# Remove the `lon_bnds` / `lat_bnds` variables from the dataset.
# `Dataset.drop` is deprecated in xarray; `drop_vars` is its replacement.
# errors="ignore" keeps this cell re-runnable after the variables are gone.
data = data.drop_vars(["lon_bnds", "lat_bnds"], errors="ignore")

In [5]:
# Rename the horizontal axes from x/y to lon/lat
data = data.rename({"x": "lon", "y": "lat"})

In [38]:
#| export
class AtmosphericDataset(Dataset):
    """Sliding-window dataset pairing atmospheric forcing with river runoff.

    Both inputs are standardised along ``time`` (mean subtracted, divided by
    the standard deviation) and stored as ``float16`` tensors. Sample ``i``
    consists of the ``input_size`` consecutive forcing time steps
    ``i .. i + input_size - 1`` and the runoff at time step ``i + input_size``.

    Args:
        input_size: Number of consecutive time steps in one input sequence.
        atmosphericData: Forcing fields with dimensions ``time``, ``lat`` and
            ``lon``. Expected to be an xarray Dataset (``to_array`` is called
            on it); every data variable becomes one channel.
        runoff: Runoff with dimensions ``time`` and ``river``.
        transform: Accepted for interface compatibility; currently unused.
        stats_path: File that receives the runoff mean and standard deviation
            via ``np.savetxt``. Defaults to the original hard-coded location;
            pass ``None`` to skip writing the file.

    Attributes:
        atmosphericStats: ``(mean, std)`` of the forcing along ``time``.
        runoffDataStats: ``(mean, std)`` of the runoff along ``time``.
        x: Normalised forcing, shape ``(time, variable, lat, lon)``.
        y: Normalised runoff, shape ``(time, river)``.
    """

    def __init__(
        self,
        input_size,
        atmosphericData,
        runoff,
        transform=None,
        stats_path="/silor/boergel/paper/runoff_prediction/data/modelStats.txt",
    ):

        # Update 6.10.2023
        # The function is not handling the preprocessing anymore
        # which makes the function more flexible
        # Following the technical description of the river data (Germo et al.)
        # the original river data is limited to 1979 to 2011
        # start_year, end_year = 1979, 2011
        # self.timeRange = slice(str(start_year), str(end_year))

        # Length of the sequence
        self.input_size = input_size

        # input data (x)
        atmosphericDataStd = atmosphericData.std("time")  # dimension will be channel, lat, lon
        atmosphericDataMean = atmosphericData.mean("time")
        self.atmosphericStats = (atmosphericDataMean, atmosphericDataStd)

        # output data - label (y)
        runoffData = runoff.transpose("time", "river")
        runoffDataMean = runoffData.mean("time")
        runoffDataSTD = runoffData.std("time")
        self.runoffDataStats = (runoffDataMean, runoffDataSTD)

        # Persist the runoff statistics (first row: mean, second row: std)
        # so that predictions can be de-normalised later.
        if stats_path is not None:
            np.savetxt(stats_path, [runoffDataMean, runoffDataSTD])

        # normalize data
        # NOTE(review): a grid cell or river with zero standard deviation
        # yields NaN/inf here - confirm the inputs never contain one.
        X = ((atmosphericData - atmosphericDataMean) / atmosphericDataStd).compute()
        y = ((runoffData - runoffDataMean) / runoffDataSTD).compute()

        # Stack the data variables along a new "variable" (channel) axis to
        # end up with (time, channel, lat, lon)
        xStacked = X.to_array(dim="variable")
        xStacked = xStacked.transpose("time", "variable", "lat", "lon")

        assert xStacked.data.ndim == 4
        self.x = torch.tensor(xStacked.data, dtype=torch.float16)
        self.y = torch.tensor(y.data, dtype=torch.float16)

    def __getitem__(self, index):
        """Return ``(input window, target)`` for sample ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
                Without this check a negative index would silently produce an
                empty or shortened input window.
        """
        n_samples = len(self)
        try:
            # range() resolves negative indices and rejects out-of-range ones
            start = range(n_samples)[index]
        except IndexError:
            raise IndexError(
                f"sample index {index} out of range for {n_samples} samples"
            ) from None
        stop = start + int(self.input_size)
        # The target is the time step directly after the input window
        return self.x[start:stop], self.y[stop]

    def __len__(self):
        """Number of complete (window, target) pairs; never negative."""
        return max(0, self.y.shape[0] - self.input_size)

In [39]:
#| export
class AtmosphereDataModule(L.LightningDataModule):
    """Lightning data module providing train and validation loaders.

    ``setup`` wraps the forcing and runoff data in an ``AtmosphericDataset``
    and randomly splits its samples 90 % / 10 % into training and validation
    subsets.

    Args:
        atmosphericData: Forcing data handed to ``AtmosphericDataset``.
        runoff: Runoff data handed to ``AtmosphericDataset``.
        batch_size: Batch size of both data loaders.
        num_workers: Number of worker processes of both data loaders.
        add_first_dim: Stored on the instance; not used by this class.
        input_size: Sequence length handed to ``AtmosphericDataset``.
    """

    def __init__(self, atmosphericData, runoff, batch_size=64, num_workers=8, add_first_dim=True, input_size=30):
        super().__init__()

        self.data = atmosphericData
        self.runoff = runoff
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.add_first_dim = add_first_dim
        self.input_size = input_size

    def setup(self, stage:str):
        """Build the dataset and split it into ``self.train`` and ``self.val``.

        The split uses ``random_split`` without an explicit generator, so it
        depends on the state of the global torch random number generator.
        """
        # Local import: the notebook's import cells do not provide `warnings`.
        import warnings

        # The original `UserWarning("...")` only created an exception object
        # and discarded it, so the message was never shown.
        warnings.warn("Loading atmospheric data ...", UserWarning)

        dataset = AtmosphericDataset(
            atmosphericData=self.data,
            runoff=self.runoff,
            input_size=self.input_size
            )
        n_samples = len(dataset)
        train_size = int(0.9 * n_samples)
        val_size = n_samples - train_size
        # NOTE(review): samples are overlapping sliding windows, so a random
        # split puts windows sharing time steps into both subsets - confirm
        # that this is intended.
        self.train, self.val, = random_split(dataset, [train_size, val_size])
        # val_size = int(0.1 * n_samples)
        # test_size = n_samples - train_size  - val_size
        # self.train, self.val, self.test = random_split(dataset, [train_size, val_size, test_size])

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete last batch is dropped."""
        return DataLoader(
            dataset=self.train,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the unshuffled validation loader; incomplete last batch is dropped."""
        return DataLoader(
            dataset=self.val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True)

    # def test_dataloader(self):
    #     return DataLoader(
    #         self.test,
    #         batch_size=self.batch_size,
    #         shuffle=False,
    #         num_workers=self.num_workers,
    #         drop_last=True)

In [40]:
#| export
class BaltNet(nn.Module):
    """Two stacked ConvLSTM stages followed by a fully connected head.

    The first ConvLSTM processes the whole input sequence. Its returned state
    initialises a second, single-layer ConvLSTM that is run on the last time
    step only. The output of that stage is flattened and mapped to one value
    per river by two linear layers.

    Args:
        modelPar: Mapping of hyper-parameters; every entry is stored as an
            attribute of the module. Keys read by this class:
            ``input_dim``, ``hidden_dim``, ``kernel_size``, ``num_layers``,
            ``batch_first``, ``bias``, ``return_all_layers``, ``dimensions``
            (two spatial sizes), and optionally ``n_rivers`` (number of
            outputs, default 97).
    """

    def __init__(self, modelPar):
        super().__init__()

        # initialize all attributes
        for k, v in modelPar.items():
            setattr(self, k, v)

        # Number of outputs; 97 unless modelPar supplies "n_rivers"
        self.n_rivers = getattr(self, "n_rivers", 97)

        # Size of the flattened second-stage output: lat * lon * hidden channels
        self.linear_dim = self.dimensions[0]*self.dimensions[1]*self.hidden_dim

        self.convLSTM = ConvLSTM(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=self.kernel_size,
                num_layers=self.num_layers,
                batch_first=self.batch_first,
                bias=self.bias,
                return_all_layers=self.return_all_layers
        )

        self.convLSTM2 = ConvLSTM(
                input_dim=self.input_dim,
                hidden_dim=self.hidden_dim,
                kernel_size=self.kernel_size,
                num_layers=1,
                batch_first=self.batch_first,
                bias=self.bias,
                return_all_layers=self.return_all_layers
        )

        # CNN layers to map the output of convLSTM2 to 97 rivers
        # self.cnn_layers = nn.Sequential(
        #     nn.Conv2d(self.hidden_dim, 32, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.AdaptiveAvgPool2d((1, 1)),  # Global Average Pooling
        #     nn.Flatten(),
        #     nn.Linear(32, 97)
        # )

        # CNN layers to map the output of convLSTM2 to 97 rivers
        # self.cnn_layers = nn.Sequential(
        #     nn.Conv2d(self.hidden_dim, 128, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        #     nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(kernel_size=2, stride=2),
        #     nn.Flatten(),
        #     nn.Linear(256 * (self.dimensions[0] // 4) * (self.dimensions[1] // 4), 97)
        # )

        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(self.linear_dim, 256),
            torch.nn.ReLU(),
            # torch.nn.Linear(512, 256),
            # torch.nn.ReLU(),
            torch.nn.Linear(256, self.n_rivers)
            )

    def forward(self, x):
        """Predict one value per river for every sequence in the batch.

        Args:
            x: 5-D input tensor. Dimension 1 is treated as the sequence axis:
                its last element is fed to the second ConvLSTM.

        Returns:
            Tensor of shape ``(batch, n_rivers)``. The batch dimension is kept
            even for a batch of one; the previous bare ``squeeze()`` collapsed
            it to ``(n_rivers,)`` in that case.
        """
        _, encode_state = self.convLSTM(x)
        # Run the second stage on the last time step, seeded with the first stage's state
        decoder_out, _ = self.convLSTM2(x[:,-1,:,:,:].unsqueeze(dim=1), encode_state)
        # Drop the length-one sequence axis of the second-stage output
        x = decoder_out[0].squeeze(1)
        # x = self.cnn_layers(x).squeeze()
        x = torch.flatten(x, start_dim=1)
        return self.fc_layers(x)


In [35]:
# Release cached GPU memory that PyTorch's allocator is holding but not using
torch.cuda.empty_cache()


In [36]:
# Hyper-parameters handed to BaltNet
modelParameters = dict(
    input_dim=1,              # Number of channel, right now only precipitation
    hidden_dim=8,             # hidden states
    kernel_size=(5, 5),       # applied for spatial convolutions
    num_layers=3,             # number of convLSTM layers
    batch_first=True,         # first index is batch
    bias=True,
    return_all_layers=False,
    dimensions=(191, 206),    # dimensions of atmospheric forcing
)

In [37]:
### Setup model

num_epochs = 5

# Data module that serves the atmospheric data in batches
dataLoader = AtmosphereDataModule(
    atmosphericData=data,
    runoff=runoff,
    batch_size=64,
    input_size=30,
)

# Initialize the plain PyTorch model and wrap it in its Lightning module
pyTorchBaltNet = BaltNet(modelPar=modelParameters)
LighningBaltNet = LightningModel(
    pyTorchBaltNet,
    learning_rate=1e-3,
    cosine_t_max=40,
)

# Keep the checkpoint with the lowest `val_mse` and, additionally, the last one
callbacks = [
    ModelCheckpoint(
        dirpath="/silor/boergel/paper/runoff_prediction/data/modelWeights/",
        filename="BaltNetTopOne",
        monitor="val_mse",
        mode="min",
        save_top_k=1,
        save_last=True,
    )
]

# Metrics are written as CSV files below the log directory
csv_logger = CSVLogger(
    save_dir="/silor/boergel/paper/runoff_prediction/logs",
    name="BaltNet1",
)

trainer = L.Trainer(
    callbacks=callbacks,
    max_epochs=num_epochs,
    accelerator="cuda",
    devices=2,
    logger=csv_logger,
    deterministic=True,
)

trainer.fit(model=LighningBaltNet, datamodule=dataLoader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 123
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Global seed set to 123
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | BaltNet          | 40.3 M
1 | train_mse | MeanSquaredError | 0     
2 | val_mse   | MeanSquaredError | 0     


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.
