# Train

In [3]:
# %load_ext autoreload
# %autoreload 2

In [3]:
import torch
import numpy as np
import xarray as xr

from hython.datasets.datasets import get_dataset
from hython.trainer import train_val
from hython.sampler import SamplerBuilder, RegularIntervalDownsampler
from hython.metrics import MSEMetric
from hython.losses import RMSELoss, nll_loss
from hython.utils import read_from_zarr, set_seed
from hython.models.cudnnLSTM import CuDNNLSTM
from hython.trainer import RNNTrainer, RNNTrainParams
from hython.normalizer import Normalizer


import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

# viz
import matplotlib.pyplot as plt

In [4]:
from hython.metrics import Metric

In [14]:
import torch
from torch import nn


class CuDNNLSTM_UQ(nn.Module):
    def __init__(
        self,
        hidden_size: int = 24,
        dynamic_input_size: int = 3,
        static_input_size: int = 9,
        output_size: int = 2,
        static_to_dynamic: bool = True,
        num_layers:int = 1,
        dropout:float = 0.0,
    ):
        super(CuDNNLSTM_UQ, self).__init__()

        self.static_to_dynamic = static_to_dynamic

        self.fc0 = nn.Linear(dynamic_input_size + static_input_size, hidden_size)

        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers = num_layers, batch_first=True, dropout=dropout)

        self.mean_fc1 = nn.Linear(hidden_size, 12)
        self.mean_fc2 = nn.Linear(12, output_size)


        self.std_fc3 = nn.Linear(hidden_size, 12)
        self.std_fc4 = nn.Linear(12, output_size)
        

    def forward(self, x):

        l1 = self.fc0(x)

        lstm_output, (h_n, c_n) = self.lstm(l1)

        # Forward pass for the mean 
        mean = self.mean_fc1(lstm_output)
        #mean = torch.abs(self.mean_fc2(mean))
        mean = self.mean_fc2(mean)

        # Forward pass for the std
        std = self.std_fc3(lstm_output)
        std = torch.abs(self.std_fc4(std))       
        #std = torch.softplus(std) # Softplus to ensure positive std
        
        return mean, std

# Settings

In [5]:
EXPERIMENT  = "exp1"

SURROGATE_INPUT = "https://eurac-eo.s3.amazonaws.com/INTERTWIN/SURROGATE_INPUT/adg1km_eobs_preprocessed.zarr/"

SURROGATE_MODEL_OUTPUT = f"path/to/model/output/directory/{EXPERIMENT}.pt"

TMP_STATS = "path/to/temporary/stats/directory" 

# === FILTER ==============================================================

# train/test temporal range
train_temporal_range = slice("2012-01-01","2018-12-31")
test_temporal_range = slice("2019-01-01", "2019-12-31")

# variables
dynamic_names = ["precip", "pet", "temp"] 
static_names = [ "thetaS", "thetaR", "KsatVer", "SoilThickness", "RootingDepth", "f", "Swood", "Sl", "Kext"]
target_names = ["vwc", "actevap"]# ["vwc", "actevap", "snow", "snowwater"] 

# === MASK ========================================================================================

mask_names = ["mask_missing", "mask_lake"] # names depends on preprocessing application

# === DATASET ========================================================================================

DATASET = "LSTMDataset"

# == MODEL  ========================================================================================

HIDDEN_SIZE = 24
DYNAMIC_INPUT_SIZE = len(dynamic_names)
STATIC_INPUT_SIZE = len(static_names)
OUTPUT_SIZE = len(target_names)

TARGET_WEIGHTS = {t:1/len(target_names) for t in target_names}


# === SAMPLER/TRAINER ===================================================================================

EPOCHS = 20
BATCH = 256
SEED = 42

# downsampling, speeds up the training!

# - spatial

DONWSAMPLING = False
TRAIN_INTERVAL = [3,3]
TRAIN_ORIGIN = [0,0]

TEST_INTERVAL = [3,3]
TEST_ORIGIN = [2,2]

# - temporal
TEMPORAL_SUBSAMPLING = True
TEMPORAL_SUBSET = [150, 150] # n of sequences 
SEQ_LENGTH = 360


assert sum(v for v in TARGET_WEIGHTS.values()) == 1, "check target weights"
TARGET_INITIALS = "".join([i[0].capitalize() for i in target_names])

In [6]:
set_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
Xd = (
    read_from_zarr(url=SURROGATE_INPUT, group="xd", multi_index="gridcell")
    .sel(time=train_temporal_range)
    .xd.sel(feat=dynamic_names)
)
Xs = read_from_zarr(url=SURROGATE_INPUT, group="xs", multi_index="gridcell").xs.sel(
    feat=static_names
)
Y = (
    read_from_zarr(url=SURROGATE_INPUT, group="y", multi_index="gridcell")
    .sel(time=train_temporal_range)
    .y.sel(feat=target_names)
)

SHAPE = Xd.attrs["shape"]


# === READ TEST ===================================================================

Y_test = (
    read_from_zarr(url=SURROGATE_INPUT, group="y", multi_index="gridcell")
    .sel(time=test_temporal_range)
    .y.sel(feat=target_names)
)
Xd_test = (
    read_from_zarr(url=SURROGATE_INPUT, group="xd", multi_index="gridcell")
    .sel(time=test_temporal_range)
    .xd.sel(feat=dynamic_names)
)

In [8]:
masks = (
    read_from_zarr(url=SURROGATE_INPUT, group="mask")
    .mask.sel(mask_layer=mask_names)
    .any(dim="mask_layer")
)

In [9]:
if DONWSAMPLING:
    train_downsampler = RegularIntervalDownsampler(
        intervals=TRAIN_INTERVAL, origin=TRAIN_ORIGIN
    )       
    test_downsampler = RegularIntervalDownsampler(
        intervals=TEST_INTERVAL, origin=TEST_ORIGIN
    )
else:
    train_downsampler, test_downsampler = None, None

In [10]:
normalizer_dynamic = Normalizer(method="standardize",
                                type="spacetime", axis_order="NTC")
normalizer_static = Normalizer(method="standardize",
                               type="space", axis_order="NTC")
normalizer_target = Normalizer(method="standardize", type="spacetime",
                               axis_order="NTC")

In [11]:
train_dataset = get_dataset(DATASET)(
        Xd,
        Y,
        Xs,
        original_domain_shape=SHAPE,
        mask=masks,
        downsampler=train_downsampler,
        normalizer_dynamic=normalizer_dynamic,
        normalizer_static=normalizer_static,
        normalizer_target=normalizer_target
)
test_dataset = get_dataset(DATASET)(
        Xd_test,
        Y_test,
        Xs,
        original_domain_shape=SHAPE,
        mask=masks,
        downsampler=test_downsampler,
        normalizer_dynamic=normalizer_dynamic,
        normalizer_static=normalizer_static,
        normalizer_target=normalizer_target
)

In [15]:
# === SAMPLER ===================================================================


train_sampler_builder = SamplerBuilder(
    train_dataset,
    sampling="random", 
    processing="single-gpu")

test_sampler_builder = SamplerBuilder(
    test_dataset,
    sampling="sequential", 
    processing="single-gpu")


train_sampler = train_sampler_builder.get_sampler()
test_sampler = test_sampler_builder.get_sampler()

In [16]:
# === DATA LOADER ===================================================================

train_loader = DataLoader(train_dataset, batch_size=BATCH , sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=BATCH , sampler=test_sampler)

In [25]:
# === MODEL ===================================================================

model = CuDNNLSTM(
                hidden_size=HIDDEN_SIZE, 
                dynamic_input_size=DYNAMIC_INPUT_SIZE,
                static_input_size=STATIC_INPUT_SIZE, 
                output_size=OUTPUT_SIZE
)

model.to(device)

CuDNNLSTM(
  (fc0): Linear(in_features=12, out_features=24, bias=True)
  (lstm): LSTM(24, 24, batch_first=True)
  (fc1): Linear(in_features=24, out_features=2, bias=True)
)

In [17]:
model_uq = CuDNNLSTM_UQ(output_size=2)
model_uq

CuDNNLSTM_UQ(
  (fc0): Linear(in_features=12, out_features=24, bias=True)
  (lstm): LSTM(24, 24, batch_first=True)
  (mean_fc1): Linear(in_features=24, out_features=12, bias=True)
  (mean_fc2): Linear(in_features=12, out_features=2, bias=True)
  (std_fc3): Linear(in_features=24, out_features=12, bias=True)
  (std_fc4): Linear(in_features=12, out_features=2, bias=True)
)

In [26]:
opt = optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10)

loss_fn = RMSELoss(target_weight=TARGET_WEIGHTS)
loss_fn = nll_loss(target_weight=TARGET_WEIGHTS)

metric_fn = MSEMetric(target_names=target_names)

trainer = RNNTrainer(
    RNNTrainParams(
            experiment=EXPERIMENT,
            temporal_subsampling=TEMPORAL_SUBSAMPLING, 
            temporal_subset=TEMPORAL_SUBSET, 
            seq_length=SEQ_LENGTH, 
            target_names=target_names,
            metric_func=metric_fn,
            loss_func=loss_fn)
)

In [None]:
model, loss_history, metric_history = train_val(
    trainer,
    model,
    train_loader,
    test_loader,
    EPOCHS,
    opt,
    lr_scheduler,
    SURROGATE_MODEL_OUTPUT,
    device
)

In [None]:
lepochs = list(range(1, EPOCHS + 1))

fig, axs = plt.subplots(len(target_names) +1, 1, figsize= (12,10), sharex=True)

axs[0].plot(lepochs, [i.detach().cpu().numpy() for i in loss_history['train']], marker='.', linestyle='-', color='b', label='Training')
axs[0].plot(lepochs, [i.detach().cpu().numpy() for i in loss_history['val']], marker='.', linestyle='-', color='r', label='Validation')
axs[0].set_title('Loss')
axs[0].set_ylabel(loss_fn.__name__)
axs[0].grid(True)
axs[0].legend(bbox_to_anchor=(1,1))

for i, variable in enumerate(target_names):
    axs[i+1].plot(lepochs, metric_history[f'train_{variable}'], marker='.', linestyle='-', color='b', label='Training')
    axs[i+1].plot(lepochs, metric_history[f'val_{variable}'], marker='.', linestyle='-', color='r', label='Validation')
    axs[i+1].set_title(variable)
    axs[i+1].set_ylabel(metric_fn.__class__.__name__)
    axs[i+1].grid(True)
    axs[i+1].legend(bbox_to_anchor=(1,1))

In [258]:
model_uq = CuDNNLSTM_UQ(output_size=1)
model_uq

CuDNNLSTM_UQ(
  (fc0): Linear(in_features=12, out_features=24, bias=True)
  (lstm): LSTM(24, 24, batch_first=True)
  (mean_fc1): Linear(in_features=24, out_features=12, bias=True)
  (mean_fc2): Linear(in_features=12, out_features=1, bias=True)
  (std_fc3): Linear(in_features=24, out_features=12, bias=True)
  (std_fc4): Linear(in_features=12, out_features=1, bias=True)
)

In [259]:
for dynamic_b, static_b, targets_b in train_loader:
    print(dynamic_b.shape)
    print(static_b.shape)
    print(targets_b.shape)
    break

torch.Size([256, 2557, 3])
torch.Size([256, 9])
torch.Size([256, 2557, 2])


In [275]:
time_range =  next(iter(train_loader))[0].shape[1]
time_range

2557

In [276]:
time_index = np.arange(0, time_range)
time_index

array([   0,    1,    2, ..., 2554, 2555, 2556])

In [277]:
for dynamic_b, static_b, targets_b in train_loader:
    batch_temporal_loss = 0            

    # every batch
    #self.temporal_index( dynamic_b.shape[1])

    for t in time_index:  # time_index could be a subset of time indices
        # filter sequence
        dynamic_bt = dynamic_b[:, t : (t + 360)].to(device)
        targets_bt = targets_b[:, t : (t + 360)].to(device)
        
        # static --> dynamic size (repeat time dim)
        static_bt = static_b.unsqueeze(1).repeat(1, dynamic_bt.size(1), 1).to(device)
        
        x_concat = torch.cat(
            (dynamic_bt, static_bt),
            dim=-1,
        )
        break
        output = model(x_concat)

print(dynamic_bt.shape)
print(targets_bt.shape)
print(static_bt.shape)
print(x_concat.shape)

torch.Size([72, 360, 3])
torch.Size([72, 360, 2])
torch.Size([72, 360, 9])
torch.Size([72, 360, 12])


In [278]:
model_uq = CuDNNLSTM_UQ(output_size=1)
model_uq

CuDNNLSTM_UQ(
  (fc0): Linear(in_features=12, out_features=24, bias=True)
  (lstm): LSTM(24, 24, batch_first=True)
  (mean_fc1): Linear(in_features=24, out_features=12, bias=True)
  (mean_fc2): Linear(in_features=12, out_features=1, bias=True)
  (std_fc3): Linear(in_features=24, out_features=12, bias=True)
  (std_fc4): Linear(in_features=12, out_features=1, bias=True)
)

In [279]:
output1 = model_uq(x_concat)
mean_outputs = output1[0]#.shape
std_outputs = output1[1] #.shape

print(mean_outputs.shape)
print(std_outputs.shape)

torch.Size([72, 360, 1])
torch.Size([72, 360, 1])


In [280]:
def predict_step(arr, steps=-1):
    """Return the n steps that should be predicted"""
    return arr[:, steps]

In [281]:
# mean & std for the last step
mean_outputs_l = predict_step(mean_outputs, steps=-1)
std_outputs_l = predict_step(std_outputs, steps=-1)
target = predict_step(targets_bt, steps=-1)

In [282]:
std_outputs_l[:,0]

tensor([0.1925, 0.1388, 0.1482, 0.1898, 0.1870, 0.1928, 0.1722, 0.1650, 0.2206,
        0.1707, 0.1607, 0.1193, 0.1531, 0.1312, 0.1467, 0.1281, 0.1221, 0.1448,
        0.1932, 0.2019, 0.1272, 0.1145, 0.2065, 0.1896, 0.1488, 0.2050, 0.1964,
        0.1662, 0.2078, 0.1626, 0.2075, 0.1897, 0.1616, 0.2069, 0.1559, 0.1847,
        0.2019, 0.1305, 0.2109, 0.1898, 0.1579, 0.1635, 0.1553, 0.1904, 0.1980,
        0.1293, 0.2036, 0.1855, 0.1947, 0.2072, 0.1480, 0.1729, 0.1885, 0.1702,
        0.1487, 0.2124, 0.1338, 0.1615, 0.1175, 0.1717, 0.1327, 0.1640, 0.1819,
        0.2238, 0.1655, 0.2044, 0.1747, 0.1923, 0.1338, 0.1237, 0.1978, 0.1632],
       grad_fn=<SelectBackward0>)

In [283]:
mean_outputs_l[:,0]

tensor([-0.1149, -0.1172, -0.0622, -0.1274, -0.0982, -0.1038, -0.1549, -0.1286,
        -0.0864, -0.1077, -0.0671, -0.1199, -0.1016, -0.1021, -0.1969, -0.0946,
        -0.1836, -0.1057, -0.1627, -0.1808, -0.1001, -0.1105, -0.0941, -0.1365,
        -0.1213, -0.1500, -0.1589, -0.0924, -0.0294,  0.0159, -0.0645, -0.0498,
        -0.1546, -0.1217, -0.1098, -0.1379, -0.0771, -0.1428, -0.1187, -0.1304,
        -0.1053, -0.1140, -0.1037, -0.1473, -0.1499, -0.1009, -0.0682, -0.0718,
        -0.1540, -0.1487, -0.1160, -0.1366, -0.1136, -0.1266, -0.0650, -0.0944,
        -0.1107, -0.1085, -0.1086, -0.0467, -0.1203, -0.1054, -0.1646, -0.1188,
        -0.0933, -0.1201, -0.1538, -0.1326, -0.1190, -0.1129, -0.1234, -0.1703],
       grad_fn=<SelectBackward0>)

In [284]:
dist = Normal(mean_outputs_l[:,0], std_outputs_l[:,0])
#nll = -dist.log_prob(target[:,0]).mean()

In [285]:
dist.entropy()#.mean()

tensor([-0.2285, -0.5555, -0.4903, -0.2427, -0.2579, -0.2270, -0.3399, -0.3831,
        -0.0924, -0.3492, -0.4093, -0.7076, -0.4578, -0.6120, -0.5005, -0.6363,
        -0.6842, -0.5136, -0.2251, -0.1812, -0.6428, -0.7480, -0.1587, -0.2437,
        -0.4859, -0.1660, -0.2086, -0.3755, -0.1525, -0.3973, -0.1539, -0.2434,
        -0.4036, -0.1567, -0.4398, -0.2700, -0.1811, -0.6173, -0.1373, -0.2431,
        -0.4266, -0.3921, -0.4433, -0.2398, -0.2008, -0.6265, -0.1728, -0.2660,
        -0.2172, -0.1552, -0.4919, -0.3361, -0.2495, -0.3520, -0.4868, -0.1306,
        -0.5925, -0.4040, -0.7223, -0.3431, -0.6010, -0.3891, -0.2853, -0.0780,
        -0.3799, -0.1687, -0.3258, -0.2297, -0.5922, -0.6707, -0.2017, -0.3936],
       grad_fn=<AddBackward0>)

In [286]:
dist = Normal(-0.1149, 0.1925)
dist.entropy()

tensor(-0.2287)

In [288]:
??dist.entropy

[0;31mSignature:[0m [0mdist[0m[0;34m.[0m[0mentropy[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Method to compute the entropy using Bregman divergence of the log normalizer.
[0;31mSource:[0m   
    [0;32mdef[0m [0mentropy[0m[0;34m([0m[0mself[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0;32mreturn[0m [0;36m0.5[0m [0;34m+[0m [0;36m0.5[0m [0;34m*[0m [0mmath[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0;36m2[0m [0;34m*[0m [0mmath[0m[0;34m.[0m[0mpi[0m[0;34m)[0m [0;34m+[0m [0mtorch[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mscale[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mFile:[0m      ~/miniconda3/envs/basic/lib/python3.11/site-packages/torch/distributions/normal.py
[0;31mType:[0m      method

In [93]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import torch.nn.functional as F


## Negative log-likelihood (NLL) loss
def nll_loss(y, distr_mean, distr_std):
    # Normal distribution
    dist = Normal(distr_mean, distr_std)
    return -dist.log(y).mean()



In [148]:
from torch.distributions import Normal

class nll_loss(_Loss):
    __name__ = "NLL"

    def __init__(
        self,
        target_weight: dict = None,
    ):
        """
        Negative log-likelihood (NLL) loss for normal distribution.

         Parameters:
         target_weight: List of targets that contribute in the loss computation, with their associated weights.
                        In the form {target: weight}
        """

        super(nll_loss, self).__init__()
        self.target_weight = target_weight

    def forward(self, y_true, distr_mean, distr_std):
        """
        Calculate the negative log-likelihood of the underlying normal distribution.

        Parameters:
        y_true (torch.Tensor): The true values.
        distr_mean (torch.Tensor): The predicted mean values. 
        distr_std (torch.Tensor): The predicted std values.

        Shape
        y_true: torch.Tensor of shape (N, T).
        distr_mean: torch.Tensor of shape (N, T).
        distr_std: torch.Tensor of shape (N, T).
        (256,3) means 256 samples with 3 targets.

        Returns:
        torch.Tensor: The NLL loss.
        """
        if self.target_weight is None:
            dist = Normal(distr_mean, distr_std)
            total_nll_loss = -dist.log_prob(y_true).mean()

        else:
            total_nll_loss = 0
            for idx, k in enumerate(self.target_weight):
                w = self.target_weight[k]
                dist = Normal(distr_mean[:, idx], distr_std[:, idx])
                nll_loss = -dist.log_prob(y_true[:, idx]).mean()
                loss = nll_loss * w
                total_nll_loss += loss

        return total_nll_loss

In [289]:
t = ["ah", "by"]
if isinstance(t, list):
    print("ok")

ok


In [None]:
def loss_batch(loss_func, output, target, opt=None, gradient_clip = None, model=None, add_losses: dict = {}):

    if isinstance(output, list):
        loss = loss_func(target, output[0], output[1])
        
    elif target.shape[-1] == 1:
        target = torch.squeeze(target)
        output = torch.squeeze(output)
        loss = loss_func(target, output)
    else:
        loss = loss_func(target, output)

    # compound more losses, in case dict is not empty
    # TODO: add user-defined weights
    for k in add_losses:
        loss += add_losses[k]

    if opt is not None: 
        opt.zero_grad()
        loss.backward()

        if gradient_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), **gradient_clip)

        opt.step()

    return loss

In [None]:
class EntropyMetric(Metric):
    """
    Entropy

    Parameters
    ----------
    y_pred (numpy.array): The true values.
    y_true (numpy.array): The predicted values.
    target_names: List of targets that contribute in the loss computation.

    Returns
    -------
    Dictionary of Entropy metric for each target. {'target': entropy_metric}
    
    """ 
    metrics = {}
    for idx, target in enumerate(target_names)
    

In [None]:
    metrics = {}

    for idx, target in enumerate(target_names):
        observed = y_true[:, idx]
        simulated = y_pred[:, idx]
        r = np.corrcoef(observed, simulated)[1, 0]
        alpha = np.std(simulated, ddof=1) / np.std(observed, ddof=1)
        beta = np.mean(simulated) / np.mean(observed)
        kge = 1 - np.sqrt(
            np.power(r - 1, 2) + np.power(alpha - 1, 2) + np.power(beta - 1, 2)
        )
        metrics[target] = kge

    return metrics


In [None]:
class MSEMetric(Metric):
    """
    Mean Squared Error (MSE)

    Parameters
    ----------
    y_pred (numpy.array): The true values.
    y_true (numpy.array): The predicted values.
    target_names: List of targets that contribute in the loss computation.

    Returns
    -------
    Dictionary of Entropy metric for each target. {'target': entropy_metric}
    
    """
    def __call__(self, y_pred, y_true, target_names: list[str]):
        return metric_decorator(y_pred, y_true, target_names)(compute_mse)()

In [154]:
nll_loss1 = nll_loss(target_weight=TARGET_WEIGHTS)

In [155]:
nll_loss1(target, mean_outputs_l, std_outputs_l)

tensor(126.3684, grad_fn=<AddBackward0>)

In [None]:
# Neural network model for the mean and std
class ProbabilisticModel(nn.Module):
    def __init__(self, input_size):
        super(ProbabilisticModel, self).__init__()
        
        # Mean network
        self.mean_fc1 = nn.Linear(input_size, 100)
        self.mean_fc2 = nn.Linear(100, 50)
        self.mean_out = nn.Linear(50, 1)

        # Standard deviation network
        self.std_fc1 = nn.Linear(input_size, 100)
        self.std_fc2 = nn.Linear(100, 50)
        self.std_fc3 = nn.Linear(50, 20)
        self.std_out = nn.Linear(20, 1)
        
    def forward(self, x):
        # Forward pass for the mean
        mean = F.relu(self.mean_fc1(x))
        mean = F.relu(self.mean_fc2(mean))
        mean = self.mean_out(mean)
        
        # Forward pass for the std deviation
        std = F.relu(self.std_fc1(x))
        std = F.dropout(std, p=0.1, training=self.training) # Dropout
        std = F.relu(self.std_fc2(std))
        std = F.dropout(std, p=0.1, training=self.training) # Dropout
        std = F.relu(self.std_fc3(std))
        std = torch.softplus(self.std_out(std)) # Softplus to ensure positive std
        
        return mean, std