In [1]:
import os
import random
import time

import numpy as np
import torch
from absl import app
# from klearn_tcyclone.training_utils.args import FLAGS, ALL_FLAGS
from klearn_tcyclone.training_utils.training_utils import get_default_flag_values
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils import data

from klearn_tcyclone.climada.tc_tracks import TCTracks
from klearn_tcyclone.data_utils import (
    LinearScaler,
)
from klearn_tcyclone.KNF.modules.eval_metrics import RMSE_TCTracks
from klearn_tcyclone.KNF.modules.models import Koopman
from klearn_tcyclone.KNF.modules.train_utils import (
    eval_epoch_koopman,
    train_epoch_koopman,
)
from klearn_tcyclone.knf_data_utils import TCTrackDataset
from klearn_tcyclone.training_utils.training_utils import set_flags
from absl import app, flags

from klearn_tcyclone.training_utils.training_utils import extend_by_default_flag_values

from klearn_tcyclone.koopkernel_seq2seq import KoopmanKernelSeq2Seq, RBFKernel
from klearn_tcyclone.koopkernel_seq2seq import KoopKernelLoss, batch_tensor_context

In [2]:
torch.cuda.is_available()

True

## Import data

Set some specific parameters and load default values for all other parameters.

In [3]:
flag_params = {
    # "seed": 42,
    "year_range": [1980, 1988],
    # "batch_size": 16,
    "num_epochs": 2,
    "train_output_length": 1,
    "context_length": 14,
}
flag_params["input_length"] = flag_params["context_length"]
flag_params = extend_by_default_flag_values(flag_params)

In [4]:
random.seed(flag_params["seed"])  # python random generator
np.random.seed(flag_params["seed"])  # numpy random generator

torch.manual_seed(flag_params["seed"])
torch.cuda.manual_seed_all(flag_params["seed"])

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

feature_list = [
    "lon",
    "lat",
    "max_sustained_wind",
    # "radius_max_wind",
    # "radius_oci",
    "central_pressure",
    "environmental_pressure",
]

# feature_list = [
#     "lon",
#     "lat",
#     "max_sustained_wind",
#     "radius_max_wind",
#     "radius_oci",
#     "central_pressure",
#     "environmental_pressure",
# ]

# these are not contained as flags
# encoder_hidden_dim = flag_params["hidden_dim"]
# decoder_hidden_dim = flag_params["hidden_dim"]
# encoder_num_layers = flag_params["num_layers"]
# decoder_num_layers = flag_params["num_layers"]

output_dim = flag_params["input_dim"]
num_feats = len(feature_list)
learning_rate = flag_params["learning_rate"]
# ---------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device", device)

scaler = LinearScaler()
eval_metric = RMSE_TCTracks

Device cuda


In [5]:
# Datasets
tc_tracks = TCTracks.from_ibtracs_netcdf(
    provider="official",
    year_range=flag_params["year_range"],
    basin="NA",
)

tc_tracks_train, tc_tracks_test = train_test_split(tc_tracks.data, test_size=0.1)



  if ibtracs_ds.dims['storm'] == 0:


In [6]:
len(tc_tracks_train), tc_tracks_train[5]

(73,
 <xarray.Dataset> Size: 8kB
 Dimensions:                 (time: 134)
 Coordinates:
   * time                    (time) datetime64[ns] 1kB 1986-08-13T12:00:00 ......
     lat                     (time) float32 536B 30.1 30.45 30.8 ... 56.2 56.2
     lon                     (time) float32 536B -84.0 -84.0 -84.0 ... 7.0 8.0
 Data variables:
     radius_max_wind         (time) float32 536B 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
     radius_oci              (time) float32 536B 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
     max_sustained_wind      (time) float32 536B 10.0 10.0 10.0 ... 15.0 15.0
     central_pressure        (time) float32 536B 1.009e+03 1.01e+03 ... 1.006e+03
     environmental_pressure  (time) float64 1kB 1.01e+03 1.01e+03 ... 1.01e+03
     time_step               (time) float64 1kB 3.0 3.0 3.0 3.0 ... 3.0 3.0 3.0
     basin                   (time) <U2 1kB 'NA' 'NA' 'NA' ... 'NA' 'NA' 'NA'
 Attributes:
     max_sustained_wind_unit:  kn
     central_pressure_unit:    mb
     orig_event

In [7]:
from klearn_tcyclone.data_utils import (
    load_model,
    standardized_context_dataset_from_TCTracks,
)

In [8]:
time_lag = 1
scaler = LinearScaler()
basin = "NA"

The idea is to feed the tensor_context_dataset into the model. Because for the kernels I need the lookback window and the shifted version of this.

In [9]:
feature_list

['lon',
 'lat',
 'max_sustained_wind',
 'central_pressure',
 'environmental_pressure']

In [10]:
flag_params["context_length"]

14

In [11]:
tensor_context_train_standardized = standardized_context_dataset_from_TCTracks(
    tc_tracks_train,
    feature_list=feature_list,
    scaler=scaler,
    context_length=flag_params["context_length"],
    time_lag=time_lag,
    fit=True,
    periodic_shift=True,
    basin=basin,
)

In [12]:

train_set = TCTrackDataset(
    input_length=flag_params["input_length"],
    output_length=flag_params["train_output_length"],
    tc_tracks=tc_tracks_train,
    feature_list=feature_list,
    mode="train",
    jumps=flag_params["jumps"],
    scaler=scaler,
    fit=True,
)
valid_set = TCTrackDataset(
    input_length=flag_params["input_length"],
    output_length=flag_params["train_output_length"],
    tc_tracks=tc_tracks_train,
    feature_list=feature_list,
    mode="valid",
    jumps=flag_params["jumps"],
    scaler=scaler,
    fit=False,
)
test_set = TCTrackDataset(
    input_length=flag_params["input_length"],
    output_length=flag_params["test_output_length"],
    tc_tracks=tc_tracks_test,
    feature_list=feature_list,
    mode="test",
    # jumps=flag_params["jumps"], # jumps not used in test mode
    scaler=scaler,
    fit=False,
)
train_loader = data.DataLoader(
    train_set, batch_size=flag_params["batch_size"], shuffle=True, num_workers=1
)
valid_loader = data.DataLoader(
    valid_set, batch_size=flag_params["batch_size"], shuffle=True, num_workers=1
)
test_loader = data.DataLoader(
    test_set, batch_size=flag_params["batch_size"], shuffle=False, num_workers=1
)

if len(train_loader) == 0:
    raise Exception(
        "There are likely too few data points in the test set. Try to increase year_range."
    )

Check why we have nan values!!!

In [13]:
train_loader.dataset[0][0].shape

torch.Size([14, 5])

In [14]:
counter = 0
for inps, tgts in train_loader:
    if counter < 5:
        print(counter)
        print(inps.shape, type(inps))
        print(tgts.shape, type(inps))
        print(inps[0,:,0])
        print(tgts[0,:,0])
        print()
    
    counter += 1


0
torch.Size([32, 14, 5]) <class 'torch.Tensor'>
torch.Size([32, 1, 5]) <class 'torch.Tensor'>
tensor([0.7815, 0.8036, 0.8226, 0.8416, 0.8575, 0.8733, 0.8892, 0.9050, 0.9184,
        0.9319, 0.9422, 0.9525, 0.9588, 0.9652])
tensor([0.9667])

1
torch.Size([32, 14, 5]) <class 'torch.Tensor'>
torch.Size([32, 1, 5]) <class 'torch.Tensor'>
tensor([0.0538, 0.0689, 0.0784, 0.0879, 0.0855, 0.0831, 0.0744, 0.0657, 0.0499,
        0.0340, 0.0293, 0.0245, 0.0238, 0.0230])
tensor([0.0245])

2
torch.Size([32, 14, 5]) <class 'torch.Tensor'>
torch.Size([32, 1, 5]) <class 'torch.Tensor'>
tensor([-0.2043, -0.2162, -0.2257, -0.2352, -0.2494, -0.2637, -0.2755, -0.2874,
        -0.3017, -0.3159, -0.3294, -0.3428, -0.3555, -0.3682])
tensor([-0.3824])

3
torch.Size([32, 14, 5]) <class 'torch.Tensor'>
torch.Size([32, 1, 5]) <class 'torch.Tensor'>
tensor([-0.0428, -0.0546, -0.0665, -0.0784, -0.0895, -0.1006, -0.1108, -0.1211,
        -0.1314, -0.1417, -0.1520, -0.1623, -0.1734, -0.1845])
tensor([-0.1940])

4


In [63]:
rbf = RBFKernel(length_scale=1.0)
flag_params["koopman_kernel_num_centers"] = 100

In [64]:
koopkernelmodel = KoopmanKernelSeq2Seq(
    kernel=rbf,
    input_dim = 1,
    input_length = 1,
    output_length = 1,
    output_dim = 1,
    num_steps = 1,
    num_nys_centers = flag_params["koopman_kernel_num_centers"],
    rng_seed = 42,
)

koopkernelmodel._initialize_nystrom_data(tensor_context_train_standardized)

torch.Size([100, 5]) torch.Size([100, 5])


In [65]:
inps = tensor_context_train_standardized[:flag_params["batch_size"]].lookback(13)
inps = torch.tensor(tensor_context_train_standardized[:flag_params["batch_size"]].lookback(tensor_context_train_standardized.context_length - 1), dtype=torch.float32).to(
    device
)
target = torch.tensor(tensor_context_train_standardized[:flag_params["batch_size"]].lookback(tensor_context_train_standardized.context_length - 1, slide_by=1), dtype=torch.float32).to(
    device
)
print(inps.shape, target.shape)

outs = koopkernelmodel.forward(inps)
outs.shape

torch.Size([32, 13, 5]) torch.Size([32, 13, 5])


torch.Size([32, 13, 5])

In [67]:
tensor_context_inps, tensor_context_tgts = batch_tensor_context(tensor_context_train_standardized, batch_size=flag_params["batch_size"], flag_params=flag_params)
assert torch.all(tensor_context_inps[:,:,1:] == tensor_context_tgts[:,:,:-1])
tensor_context_inps.shape, tensor_context_tgts.shape

(torch.Size([32, 103, 13, 5]), torch.Size([32, 103, 13, 5]))

In [68]:
optimizer = torch.optim.Adam(koopkernelmodel.parameters(), lr=learning_rate)
loss_koopkernel = KoopKernelLoss(koopkernelmodel.nystrom_data_Y, koopkernelmodel._kernel)

from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
tb_writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))

scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=1, gamma=flag_params["decay_rate"]
)  # stepwise learning rate decay

# def train_one_epoch(epoch_index, tb_writer):
def train_one_epoch(epoch_index, tb_writer, tensor_context_inps, tensor_context_tgts):
    """From https://pytorch.org/tutorials/beginner/introyt/trainingyt.html."""

    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting

    print(range(tensor_context_inps.shape[1]))
    for i in range(tensor_context_inps.shape[1]):

        # Every data instance is an input + label pair
        inputs, labels = tensor_context_inps[:,i], tensor_context_tgts[:,i]

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = koopkernelmodel(inputs)

        # Compute the loss and its gradients
        loss = loss_koopkernel(outputs, labels)
        loss.backward()

        # print(loss.item())

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 10 == 9:
            last_loss = running_loss / 10 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * tensor_context_inps.shape[1] + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    
    print(optimizer.state_dict()["param_groups"][0])

    scheduler.step()

    return last_loss

In [69]:
epoch_index = 1
train_loss = train_one_epoch(epoch_index, tb_writer, tensor_context_inps, tensor_context_tgts)

range(0, 103)
  batch 10 loss: 35990.28666992187
  batch 20 loss: 11655.11181640625
  batch 30 loss: 3359.58955078125
  batch 40 loss: 2076.6636138916015
  batch 50 loss: 838.6673706054687
  batch 60 loss: 518.437353515625
  batch 70 loss: 449.80617065429686
  batch 80 loss: 314.98473052978517
  batch 90 loss: 261.50956573486326
  batch 100 loss: 237.34925689697266
{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.001, 'params': [0]}


In [70]:
epoch_index = 1
train_loss = train_one_epoch(epoch_index, tb_writer, tensor_context_inps, tensor_context_tgts)

range(0, 103)
  batch 10 loss: 242.28590393066406
  batch 20 loss: 257.22379913330076
  batch 30 loss: 176.7944793701172
  batch 40 loss: 172.2018928527832
  batch 50 loss: 198.46255645751953
  batch 60 loss: 138.46106719970703
  batch 70 loss: 185.59183578491212
  batch 80 loss: 136.82908630371094
  batch 90 loss: 125.52691192626953
  batch 100 loss: 135.94329910278321
{'lr': 0.0009000000000000001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'initial_lr': 0.001, 'params': [0]}
