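# Train a ConvLSTM surrogate model

Train a ConvLSTM on space-time cubelets read from the surrogate-input zarr store, validate it on a held-out period, and plot the loss/metric curves.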

In [1]:
# %load_ext autoreload
# %autoreload 2

In [3]:
import xarray as xr
import torch

from hython.models.convLSTM import ConvLSTM
from hython.datasets.datasets import get_dataset
from hython.sampler import SamplerBuilder, CubeletsDownsampler
from hython.trainer import HythonTrainer, RNNTrainParams, train_val
from hython.metrics import MSEMetric
from hython.losses import RMSELoss
from hython.normalizer import Normalizer
from hython.utils import write_to_zarr, read_from_zarr, set_seed


import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

# Settings

In [4]:

EXPERIMENT = "exp1"

# Zarr store with the pre-processed surrogate data (groups "xd", "xs", "y", "mask").
SURROGATE_INPUT = "https://eurac-eo.s3.amazonaws.com/INTERTWIN/SURROGATE_INPUT/adg1km_eobs_original.zarr/"

# Output path handed to train_val (presumably where the trained model is saved) — placeholder, edit before running.
SURROGATE_MODEL_OUTPUT = f"path/to/model/output/directory/{EXPERIMENT}.pt"

# Placeholder directory for saved normalization statistics — edit before use.
TMP_STATS = "path/to/temporary/stats/directory"

# === FILTER ==============================================================

# train/test temporal range (label-based .sel(time=...); both ends inclusive in xarray)
train_temporal_range = slice("2012-01-01", "2018-12-31")
test_temporal_range = slice("2019-01-01", "2020-12-31")

# variables
dynamic_names = ["precip", "pet", "temp"]
static_names = ["thetaS", "thetaR", "KsatVer", "SoilThickness", "RootingDepth", "f", "Swood", "Sl", "Kext"]
target_names = ["vwc", "actevap"]  # alternative: ["vwc", "actevap", "snow", "snowwater"]

# === MASK ========================================================================================

mask_names = ["mask_missing", "mask_lake"]  # layer names depend on the preprocessing application

# === DATASET ========================================================================================

DATASET = "CubeletsDataset"

# cubelet extent (x, y, time) and overlap between neighbouring cubelets,
# passed to the dataset as batch_size / overlap
XSIZE, YSIZE, TSIZE = 10, 10, 360
XOVER, YOVER, TOVER = 5, 5, 220

# NOTE(review): accepts a float threshold or "any" / "all"; exact semantics are
# defined by the hython dataset — confirm there.
MISSING_POLICY = 0.05

# == MODEL  ========================================================================================

HIDDEN_SIZE = 36
DYNAMIC_INPUT_SIZE = len(dynamic_names)
STATIC_INPUT_SIZE = len(static_names)
KERNEL_SIZE = (3, 3)  # height, width
NUM_LSTM_LAYER = 1
OUTPUT_SIZE = len(target_names)

# equal weight for every target
TARGET_WEIGHTS = {t: 1 / len(target_names) for t in target_names}


# === SAMPLER/TRAINER ===================================================================================

# downsampling (spelling kept: the downsampler cell references DONWSAMPLING by name)
DONWSAMPLING = False

TEMPORAL_FRAC = [0.8, 0.8]  # train, test
SPATIAL_FRAC = [1, 1]  # train, test

# gradient clipping ({"max_norm": ...} or None)
# NOTE(review): not forwarded to RNNTrainParams in the trainer cell — confirm intended.
gradient_clip = {"max_norm": 1}

SEED = 42
EPOCHS = 20
BATCH = 32


# Target weights must sum to ~1. abs() makes the check catch under-weighting
# (sum < 1) as well as over-weighting; without it any sum below 1 passed.
assert abs(sum(TARGET_WEIGHTS.values()) - 1) < 0.01, "check target weights"

In [5]:
# Seed the RNGs via hython's helper before anything stochastic runs, then use
# the first CUDA device when one is visible and fall back to the CPU otherwise.
set_seed(SEED)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# === READ TRAIN / TEST ==============================================================
# Each zarr group is opened once, and both temporal splits are sliced from that
# single handle. This avoids a second (presumably remote) open of "xd" and "y".
xd_store = read_from_zarr(url=SURROGATE_INPUT, group="xd")
y_store = read_from_zarr(url=SURROGATE_INPUT, group="y")

# --- train split ---
Xd = xd_store.sel(time=train_temporal_range)[dynamic_names]
# static features are not time-sliced; the same Xs is reused for the test split
Xs = read_from_zarr(url=SURROGATE_INPUT, group="xs")[static_names]
Y = y_store.sel(time=train_temporal_range)[target_names]

# shape metadata stored in the dynamic dataset's attrs (kept for reference)
SHAPE = Xd.attrs["shape"]

# --- test split ---
Y_test = y_store.sel(time=test_temporal_range)[target_names]
Xd_test = xd_store.sel(time=test_temporal_range)[dynamic_names]

In [7]:
# Collapse the selected mask layers into one boolean field: a grid cell is True
# if ANY of the layers in `mask_names` is True there (presumably True = exclude).
mask_store = read_from_zarr(url=SURROGATE_INPUT, group="mask")
masks = mask_store["mask"].sel(mask_layer=mask_names).any(dim="mask_layer")

In [8]:
# Optional cubelet downsampling. The first entry of TEMPORAL_FRAC / SPATIAL_FRAC
# configures the training split and the last entry the test split; with
# downsampling disabled both downsamplers stay None.
train_downsampler = test_downsampler = None
if DONWSAMPLING:
    train_downsampler = CubeletsDownsampler(
        temporal_downsample_fraction=TEMPORAL_FRAC[0],
        spatial_downsample_fraction=SPATIAL_FRAC[0],
    )
    test_downsampler = CubeletsDownsampler(
        temporal_downsample_fraction=TEMPORAL_FRAC[-1],
        spatial_downsample_fraction=SPATIAL_FRAC[-1],
    )

In [9]:
# Standardizing normalizers for dynamic inputs, static inputs and targets
# (presumably statistics over space+time for "spacetime", over space for "space").
# To persist the statistics, add save_stats=f"{TMP_STATS}/{EXPERIMENT}_<xd|xs|y>.nc".
# NOTE(review): the dataset cell does not pass these objects (its normalizer_*
# arguments are commented out), so they are currently unused — confirm whether
# the zarr inputs are already standardized.
normalizer_common = {"method": "standardize", "axis_order": "xarray_dataset"}
normalizer_dynamic = Normalizer(type="spacetime", **normalizer_common)
normalizer_static = Normalizer(type="space", **normalizer_common)
normalizer_target = Normalizer(type="spacetime", **normalizer_common)

In [10]:
def build_cubelets_dataset(xd, y, xs, mask, downsampler):
    """Build the configured hython dataset (``DATASET``) for one temporal split.

    Train and test share every setting from the config cell (cubelet size,
    overlap, missing-value policy, ...); only the data and the downsampler
    differ, so both splits are built through this one function.

    Args:
        xd: dynamic inputs (xarray Dataset restricted to ``dynamic_names``).
        y: targets (xarray Dataset restricted to ``target_names``).
        xs: static inputs (xarray Dataset restricted to ``static_names``).
        mask: mask passed straight through to the dataset.
        downsampler: a CubeletsDownsampler, or None for no downsampling.

    Returns:
        The dataset instance created by ``get_dataset(DATASET)``.
    """
    return get_dataset(DATASET)(
        xd.chunk("auto"),
        y.chunk("auto"),
        xs.chunk("auto"),
        mask=mask,
        downsampler=downsampler,
        # NOTE(review): normalizers exist in the notebook but are not passed:
        # normalizer_dynamic=normalizer_dynamic,
        # normalizer_static=normalizer_static,
        # normalizer_target=normalizer_target,
        # Shape taken from the first dynamic variable rather than a hard-coded
        # "precip", so editing dynamic_names cannot break this cell.
        shape=xd[dynamic_names[0]].shape,  # time, lat, lon
        batch_size={"xsize": XSIZE, "ysize": YSIZE, "tsize": TSIZE},
        overlap={"xover": XOVER, "yover": YOVER, "tover": TOVER},
        missing_policy=MISSING_POLICY,
        fill_missing=0,
        persist=True,
        lstm_1d=False,
        static_to_dynamic=True,
    )


train_dataset = build_cubelets_dataset(Xd, Y, Xs, masks, train_downsampler)
test_dataset = build_cubelets_dataset(Xd_test, Y_test, Xs, masks, test_downsampler)

In [11]:
# Number of samples per split. Fail fast on an empty split (e.g. a temporal
# range outside the store, or a too-strict MISSING_POLICY) rather than
# discovering it only during training.
n_train, n_test = len(train_dataset), len(test_dataset)
assert n_train > 0 and n_test > 0, f"empty dataset split: train={n_train}, test={n_test}"
n_train, n_test

(7440, 1395)

In [12]:
# === SAMPLER ===================================================================
# Training samples are drawn in random order; the test split is visited
# sequentially. Both run in single-GPU mode.
train_sampler_builder = SamplerBuilder(train_dataset, sampling="random", processing="single-gpu")
test_sampler_builder = SamplerBuilder(test_dataset, sampling="sequential", processing="single-gpu")

train_sampler, test_sampler = (
    train_sampler_builder.get_sampler(),
    test_sampler_builder.get_sampler(),
)

In [13]:
# === DATA LOADER ================================================================
# Sample ordering comes from the samplers, so `shuffle` stays at its default.
train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH)
test_loader = DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH)

In [14]:
# === MODEL ===================================================================
# Input channels = dynamic + static features (the datasets are built with
# static_to_dynamic=True); one output channel per target variable.
model = ConvLSTM(
    input_dim=DYNAMIC_INPUT_SIZE + STATIC_INPUT_SIZE,
    hidden_dim=HIDDEN_SIZE,
    output_dim=OUTPUT_SIZE,
    kernel_size=KERNEL_SIZE,
    num_layers=NUM_LSTM_LAYER,
    batch_first=True,
    bias=False,
    return_all_layers=False,
)
model = model.to(device)

In [15]:
# Adam optimiser. ReduceLROnPlateau halves the learning rate (factor=0.5) after
# 10 epochs without improvement of the value it is stepped with (mode="min").
opt = optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=10)

# Loss presumably weights each target by TARGET_WEIGHTS; the metric is tracked per target.
loss_fn = RMSELoss(target_weight=TARGET_WEIGHTS)
metric_fn = MSEMetric(target_names=target_names)

# NOTE(review): `gradient_clip` from the config cell is not forwarded here, so
# no clipping is configured through RNNTrainParams — confirm this is intended.
train_params = RNNTrainParams(
    experiment=EXPERIMENT,
    target_names=target_names,
    metric_func=metric_fn,
    loss_func=loss_fn,
    # gradient_clip=gradient_clip,
)
trainer = HythonTrainer(train_params)

<hython.trainer.RNNTrainParams object at 0x7571a49a3af0>


In [None]:
# Train for EPOCHS epochs, validating on the test loader. train_val receives
# SURROGATE_MODEL_OUTPUT, so it presumably saves the model there — confirm in hython.
train_val_args = (
    trainer,
    model,
    train_loader,
    test_loader,
    EPOCHS,
    opt,
    lr_scheduler,
    SURROGATE_MODEL_OUTPUT,
    device,
)
model, loss_history, metric_history = train_val(*train_val_args)

In [None]:
def _as_plottable(values):
    """Return ``values`` as a list, moving torch tensors (possibly on GPU) to host numpy."""
    return [v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v for v in values]


def _plot_train_val(ax, train_values, val_values, title, ylabel):
    """Draw a training (blue) and a validation (red) curve on ``ax`` against 1-based epochs."""
    for values, color, label in ((train_values, "b", "Training"), (val_values, "r", "Validation")):
        values = _as_plottable(values)
        # x is derived from the series length, not EPOCHS, so a run that
        # recorded fewer epochs still plots (matplotlib needs len(x) == len(y)).
        ax.plot(range(1, len(values) + 1), values, marker=".", linestyle="-", color=color, label=label)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend(bbox_to_anchor=(1, 1))


# epoch numbers of the recorded training history
lepochs = list(range(1, len(loss_history["train"]) + 1))

# An instance has no __name__ unless its class defines one; fall back to the
# class name, matching how the metric label is derived.
loss_label = getattr(loss_fn, "__name__", type(loss_fn).__name__)
metric_label = type(metric_fn).__name__

# squeeze=False keeps `axs` indexable even when there is a single row.
fig, axs = plt.subplots(len(target_names) + 1, 1, figsize=(12, 10), sharex=True, squeeze=False)
axs = axs[:, 0]

_plot_train_val(axs[0], loss_history["train"], loss_history["val"], "Loss", loss_label)
for ax, variable in zip(axs[1:], target_names):
    _plot_train_val(
        ax,
        metric_history[f"train_{variable}"],
        metric_history[f"val_{variable}"],
        variable,
        metric_label,
    )

axs[-1].set_xlabel("Epoch")
plt.show()