In [1]:
%config InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary

from layer.kan_layer import KANLinear, NewGELU
from sklearn.preprocessing import StandardScaler
import time

# Seed NumPy and PyTorch (CPU generator and the current CUDA device) with one shared value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# NOTE(review): PYTHONHASHSEED is read by the interpreter at start-up, so setting it
# here presumably only affects child processes, not hashing in this running kernel.
os.environ["PYTHONHASHSEED"] = str(SEED)

<torch._C.Generator at 0x264d116f9b0>

In [3]:
def get_raw_data(raw_data_dict):
  """Merge the per-file dictionaries of ``raw_data_dict`` into a single dict.

  The values of ``raw_data_dict`` are merged in iteration order; when the same
  key appears in more than one of them, the value visited last is kept.

  Args:
    raw_data_dict: mapping whose values are themselves mappings.

  Returns:
    A new dict holding the union of all inner mappings.
  """
  merged = {}
  for per_file_entries in raw_data_dict.values():
    merged.update(per_file_entries)
  return merged
def load_and_process_data(data_path, train_years, val_years, test_years):
  """Load the yearly ``.pkl`` files in ``data_path`` and split them by year.

  Only files ending in ``.pkl`` are considered. The year of a file is parsed
  from characters 11-14 of its name (``filename[11:15]``); files whose year is
  in none of the three year collections are skipped without being opened.
  A year listed in several collections is assigned to the first match in the
  order train, val, test.

  Args:
    data_path: directory containing the pickle files.
    train_years: container of integer years for the training split.
    val_years: container of integer years for the validation split.
    test_years: container of integer years for the test split.

  Returns:
    ``(raw_train_data, raw_val_data, raw_test_data)``: for each split, the
    dictionaries unpickled from its files merged (via ``get_raw_data``) in
    sorted filename order.

  Raises:
    ValueError: if ``filename[11:15]`` of a ``.pkl`` file is not an integer.
  """
  # (years, files) pairs in priority order; replaces the old if/elif chain whose
  # final ``else: raise`` could never run because unmatched years were already skipped.
  splits = [(train_years, {}), (val_years, {}), (test_years, {})]

  for filename in os.listdir(data_path):
    if not filename.endswith('.pkl'):
      continue
    year = int(filename[11:15])
    files_for_split = next((files for years, files in splits if year in years), None)
    if files_for_split is None:
      continue  # year not requested by any split
    # SECURITY: pickle.load can execute arbitrary code; only point data_path at
    # files that come from a trusted source.
    with open(os.path.join(data_path, filename), 'rb') as f:
      files_for_split[filename] = pickle.load(f)

  # os.listdir gives no ordering guarantee, so sort by filename to make the
  # merge order (and therefore duplicate-key resolution) deterministic.
  raw_train_data, raw_val_data, raw_test_data = (
    get_raw_data(dict(sorted(files.items()))) for _, files in splits
  )
  return raw_train_data, raw_val_data, raw_test_data

def prepare_data(storm_data, sequence_length, n_ahead, dtype=np.float32, center_grid=15):
  """Build sliding-window (input, target) arrays from per-storm record lists.

  Every storm with at least ``sequence_length + n_ahead`` records contributes
  one sample per window position: ``sequence_length`` consecutive records form
  the input and the ``vmax`` of the following ``n_ahead`` records form the
  target. The per-step feature vector is laid out as

    [center_lat, center_lon, vmax, pmin,
     single[:, center_grid, center_grid],
     multi[1:4, :, center_grid, center_grid] flattened row-major]

  Args:
    storm_data: mapping ``storm_id -> list of records``. Each record is a dict
      with ``'targets'`` (containing at least the four keys above),
      ``'features'`` (with array-likes ``'single'`` of rank 3 and ``'multi'``
      of rank 4) and ``'time'``.
    sequence_length: number of input steps per sample.
    n_ahead: number of target steps per sample.
    dtype: NumPy dtype of the returned arrays.
    center_grid: index taken on both of the last two axes of the feature
      arrays (default 15, the value previously hard-coded).

  Returns:
    ``(X_sequences, y_sequences, metadata)`` where ``X_sequences`` has shape
    ``(n_sequences, sequence_length, n_features)``, ``y_sequences`` has shape
    ``(n_sequences, n_ahead)`` and ``metadata`` holds ``'n_sequences'``,
    ``'sequence_length'``, ``'n_storms'`` and the per-sample
    ``'sequence_metadata'`` (storm id, input times, target times).

  Raises:
    ValueError: if ``storm_data`` contains no record at all, so the feature
      layout cannot be inferred.
  """
  # The four target fields copied into every input step. The feature width is
  # derived from this tuple (not from len(targets)) so that extra keys in
  # 'targets' no longer break the array layout.
  target_keys = ('center_lat', 'center_lon', 'vmax', 'pmin')
  n_cma = len(target_keys)
  window = sequence_length + n_ahead

  # Infer the feature layout from the first available record.
  probe = next((records[0] for records in storm_data.values() if len(records) > 0), None)
  if probe is None:
    raise ValueError("storm_data contains no records; cannot infer the feature layout")
  n_single = probe['features']['single'].shape[0]
  probe_multi = probe['features']['multi']
  n_multi = probe_multi[1:4].shape[0] * probe_multi.shape[1]
  features_len = n_cma + n_single + n_multi

  total_sequence = sum(
    len(records) - window + 1
    for records in storm_data.values()
    if len(records) >= window
  )

  X_sequences = np.empty((total_sequence, sequence_length, features_len), dtype=dtype)
  y_sequences = np.empty((total_sequence, n_ahead), dtype=dtype)
  sequence_metadata = [None] * total_sequence

  valid_storms = 0
  idx = 0

  for sid, storm_records in storm_data.items():
    n_records = len(storm_records)
    if n_records < window:
      continue
    valid_storms += 1

    # Extract each record's feature vector once instead of once per window.
    # The trailing n_ahead records are only ever targets, so they are skipped.
    n_inputs = n_records - n_ahead
    step_features = np.empty((n_inputs, features_len), dtype=dtype)
    for r in range(n_inputs):
      record = storm_records[r]
      target = record['targets']
      single = record['features']['single']
      multi = record['features']['multi'][1:4, :, :, :]
      step_features[r, :n_cma] = [target[key] for key in target_keys]
      step_features[r, n_cma:n_cma + n_single] = single[:, center_grid, center_grid]
      step_features[r, n_cma + n_single:] = multi[:, :, center_grid, center_grid].reshape(-1)

    vmax = np.array([record['targets']['vmax'] for record in storm_records], dtype=dtype)
    times = [record['time'] for record in storm_records]

    for start in range(n_records - window + 1):
      split = start + sequence_length
      X_sequences[idx] = step_features[start:split]
      y_sequences[idx] = vmax[split:split + n_ahead]
      sequence_metadata[idx] = {
        'storm_id': sid,
        'input_times': times[start:split],
        'target_time': times[split:split + n_ahead]
      }
      idx += 1

  metadata = {
    'n_sequences': idx,
    'sequence_length': sequence_length,
    'n_storms': valid_storms,
    'sequence_metadata': sequence_metadata,
  }

  gc.collect()
  return (X_sequences, y_sequences, metadata)

class StormDataset(Dataset):
    """Map-style dataset pairing input sequences ``X`` with targets ``y``.

    Both arguments may be NumPy arrays or torch tensors; either way they are
    stored as float32 tensors and indexed along their first dimension.
    """

    def __init__(self, X, y):
        self.X = self._to_float_tensor(X)
        self.y = self._to_float_tensor(y)

    @staticmethod
    def _to_float_tensor(values):
        # Tensors are used as-is; anything else goes through torch.from_numpy.
        tensor = values if torch.is_tensor(values) else torch.from_numpy(values)
        return tensor.float()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = self.y[idx]
        return features, target

In [4]:
class RNN_KAN_Cell(nn.Module):
    """One recurrent step: ``activation(i2h(x) + h2h(h))``.

    ``i2h`` is a ``KANLinear`` applied to the current input and ``h2h`` is an
    ``nn.Linear`` applied to the previous hidden state.

    Shapes (as documented by the original author):
        x: (batch, in_features)
        h: (batch, hidden_features)

    Args:
        in_features: indexable pair of widths; entries 0 and 1 are passed to
            ``KANLinear`` and the last entry must equal ``hidden_features``.
        hidden_features: width of the hidden state.
        activation: zero-argument factory for the activation module.
    """
    def __init__(self, in_features, hidden_features, activation=nn.Tanh):
        super().__init__()
        # NOTE(review): the check uses in_features[-1] while the layer below uses
        # in_features[1]; these coincide only for 2-element sequences — confirm intent.
        widths_match = in_features[-1] == hidden_features
        assert widths_match, f"in_features[-1]={in_features[-1]} phải bằng hidden_features={hidden_features}"
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.activation = activation()
        kan_in, kan_out = in_features[0], in_features[1]
        self.i2h = KANLinear(kan_in, kan_out)
        self.h2h = nn.Linear(hidden_features, hidden_features)  # W_hh @ h + b_h

    def forward(self, x, h):
        input_term = self.i2h(x)
        recurrent_term = self.h2h(h)
        return self.activation(input_term + recurrent_term)

class RNN_KAN(nn.Module):
    """Recurrent encoder followed by an autoregressive-free decoder.

    The cell is first run over the ``seq_len`` input steps; it is then run
    ``n_ahead`` more times with an all-zero input, and a linear head maps each
    of those hidden states to a prediction.

    Shapes:
        x: (batch, seq_len, in_features[0])
        h0: (batch, hidden_features), optional initial hidden state

    Args:
        in_features: indexable pair of widths handed to ``RNN_KAN_Cell``; the
            last entry must equal ``hidden_features``.
        hidden_features: width of the hidden state.
        output_features: width of each per-step prediction.
        n_ahead: number of decode steps.
        activation: zero-argument factory for the cell's activation module.
    """
    def __init__(self, in_features, hidden_features, output_features, n_ahead, activation=nn.Tanh):
        super().__init__()
        self.hidden_features = hidden_features
        self.in_features = in_features
        # Attribute name is kept as-is: it is part of the state_dict keys.
        self.RNN_KAN_Cell = RNN_KAN_Cell(in_features, hidden_features, activation)
        self.fc_out = nn.Linear(hidden_features, output_features)

        self.n_ahead = n_ahead
        self.output_features = output_features

    def forward(self, x, h0=None):
        """Return ``(y_pred, h)``.

        ``y_pred`` has shape (batch, n_ahead, output_features) with dimension 2
        squeezed away when it has size 1; ``h`` has shape
        (batch, seq_len + n_ahead, hidden_features) and holds every hidden state.
        """
        batch, Tx, _ = x.size()
        Ty = self.n_ahead
        # new_zeros follows x's device AND dtype, so non-float32 inputs are not
        # silently downcast when states are written into the buffers.
        h = x.new_zeros(batch, Tx + Ty, self.hidden_features)
        y_pred = x.new_zeros(batch, Ty, self.output_features)
        if h0 is None:
            h0 = x.new_zeros(batch, self.hidden_features)

        h_t = h0
        for t in range(Tx):
            h_t = self.RNN_KAN_Cell(x[:, t, :], h_t)
            h[:, t, :] = h_t

        # The decoder input is the same zero tensor at every step: allocate it once.
        zero_input = x.new_zeros(batch, self.in_features[0])
        for t in range(Ty):
            h_t = self.RNN_KAN_Cell(zero_input, h_t)
            h[:, Tx + t, :] = h_t
            y_pred[:, t, :] = self.fc_out(h_t)
        return y_pred.squeeze(dim=2), h



In [12]:
def test(model, test_loader, scaler_y, device):
  mae_loss = nn.L1Loss()
  running_loss = 0.0
  total_samples = 0

  model.eval()
  for batch_x, batch_y in test_loader:
    batch_x = batch_x.float().to(device)
    batch_y = batch_y.float().to(device)
    b = batch_x.size(0)
    total_samples += b

    y_pred, _ = model(batch_x)

    y_true_pred = scaler_y.inverse_transform(y_pred.cpu().detach().numpy())
    y_true_true = scaler_y.inverse_transform(batch_y.cpu().detach().numpy())

    test_mse_loss = mae_loss(torch.from_numpy(y_true_pred), torch.from_numpy(y_true_true))
    running_loss += test_mse_loss.item() * b

  test_mae_loss = running_loss / len(test_loader.dataset)

  return test_mae_loss

In [6]:
def train_val(model, criterion, optimizer, train_loader, val_loader, device, batch_size=128, epochs=10):
  """Train ``model`` for ``epochs`` epochs, validating after each one.

  Args:
    model: module whose call ``model(x, h0)`` returns ``(predictions, extra)``.
    criterion: loss taking ``(predictions, targets)`` and returning a scalar tensor.
    optimizer: optimizer already bound to ``model``'s parameters.
    train_loader: iterable of ``(batch_x, batch_y)`` pairs used for training.
    val_loader: iterable of ``(batch_x, batch_y)`` pairs used for validation.
    device: device the batches are moved to.
    batch_size: unused; kept so existing callers keep working.
    epochs: number of passes over ``train_loader``.

  Returns:
    ``(model, total_time, train_loss_per_epoch, val_loss_per_epoch)`` where
    ``total_time`` is the wall-clock duration in seconds and the two lists hold
    the sample-weighted mean loss of every epoch.
  """
  train_loss_per_epoch = []
  val_loss_per_epoch = []

  start_time = time.time()
  for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    total_samples = 0
    for i, (batch_x, batch_y) in enumerate(train_loader):
      batch_x = batch_x.float().to(device)
      batch_y = batch_y.float().to(device)

      if epoch == 0 and i == 0:
        print(f"batch_x shape: {batch_x.size()}")
        print(f"batch_y shape: {batch_y.size()}")

      b = batch_x.size(0)
      total_samples += b

      # Clear gradients BEFORE backward so anything left over from earlier
      # work on this model cannot leak into the first update.
      optimizer.zero_grad()
      y_pred, _ = model(batch_x, None)  # every batch starts from a zero hidden state
      loss = criterion(y_pred, batch_y)
      running_loss += loss.item() * b
      loss.backward()
      optimizer.step()

    train_epoch_loss = running_loss / total_samples
    train_loss_per_epoch.append(train_epoch_loss)

    # ---- validation pass ----
    model.eval()
    running_loss = 0.0
    total_samples = 0
    # No parameter updates here, so skip building autograd graphs.
    with torch.no_grad():
      for batch_x, batch_y in val_loader:
        batch_x = batch_x.float().to(device)
        batch_y = batch_y.float().to(device)
        b = batch_x.size(0)
        total_samples += b

        y_pred, _ = model(batch_x)
        loss = criterion(y_pred, batch_y)
        running_loss += loss.item() * b

    val_epoch_loss = running_loss / total_samples
    val_loss_per_epoch.append(val_epoch_loss)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}")

  total_time = time.time() - start_time

  return model, total_time, train_loss_per_epoch, val_loss_per_epoch

In [7]:
# Smoke check: build one model with the settings below and print its layer summary.
in_features = (20, 100)  # passed to RNN_KAN; entry 0 is the per-step input width used in the summary shape
hidden_features = 100    # must equal in_features[-1] (asserted in RNN_KAN_Cell)
output_features = 1
seq_len = 5
n_ahead = 4
batch_size = 128

model = RNN_KAN(in_features, hidden_features, output_features, n_ahead)

# torchinfo runs a forward pass with an input of shape (batch_size, seq_len, in_features[0]).
summary(model, (batch_size, seq_len, in_features[0]))

Layer (type:depth-idx)                   Output Shape              Param #
RNN_KAN                                  [128, 4]                  --
├─RNN_KAN_Cell: 1-1                      [128, 100]                --
│    └─KANLinear: 2-1                    [128, 100]                20,000
│    │    └─SiLU: 3-1                    [128, 20]                 --
│    │    └─SiLU: 3-2                    [128, 20]                 --
│    └─Linear: 2-2                       [128, 100]                10,100
│    └─Tanh: 2-3                         [128, 100]                --
├─RNN_KAN_Cell: 1-2                      [128, 100]                (recursive)
│    └─KANLinear: 2-4                    [128, 100]                (recursive)
│    │    └─SiLU: 3-3                    [128, 20]                 --
│    │    └─SiLU: 3-4                    [128, 20]                 --
│    └─Linear: 2-5                       [128, 100]                (recursive)
│    └─Tanh: 2-6                         [128, 100

In [8]:
# Load every yearly pickle once and split it by year into train / val / test.
data_path = 'data/cma_era5'
raw_train_data, raw_val_data, raw_test_data = load_and_process_data(
  data_path=data_path,
  train_years=list(range(1980, 2017)),
  val_years=[2017, 2018, 2019],
  test_years=[2020, 2021, 2022],
)

In [None]:
#====================== Train model ======================#
# config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch.cuda.get_device_name() raises when no GPU is present, so only query it on CUDA.
device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "CPU"
print(f"Using {device_name} for device")
# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs.
pin_memory = device.type == "cuda"

# hyperparams
in_features = (20, 100)
hidden_features = 100
output_features = 1
seq_len = 5
n_aheads = list(range(1, 13))
batch_size = 128
epochs = 10
lr = 1e-3

# results[n_ahead] = {"mae": [all test years, 2020, 2021, 2022], "time": [training seconds]}
results = {}

# The per-year test subsets do not depend on n_ahead, so build them once.
# NOTE(review): assumes every key of raw_test_data ends with a four-digit year — confirm.
raw_test_data_2020 = {k: v for k, v in raw_test_data.items() if int(k[-4:]) == 2020}
raw_test_data_2021 = {k: v for k, v in raw_test_data.items() if int(k[-4:]) == 2021}
raw_test_data_2022 = {k: v for k, v in raw_test_data.items() if int(k[-4:]) == 2022}

for n_ahead in n_aheads:
  print(f"================== TRAINING n_ahead = {n_ahead} ==================")
  # Window the raw records; seq_len replaces the literal 5 that duplicated it.
  X_train, y_train, metadata_train = prepare_data(raw_train_data, sequence_length = seq_len, n_ahead = n_ahead)
  X_val, y_val, metadata_val = prepare_data(raw_val_data, sequence_length = seq_len, n_ahead = n_ahead)
  X_test, y_test, metadata_test = prepare_data(raw_test_data, sequence_length = seq_len, n_ahead = n_ahead)

  print(f"X_shape: {X_train.shape}")
  print(f"y_shape: {y_train.shape}")

  X_test_2020, y_test_2020, metadata_test_2020 = prepare_data(raw_test_data_2020, sequence_length = seq_len, n_ahead = n_ahead)
  X_test_2021, y_test_2021, metadata_test_2021 = prepare_data(raw_test_data_2021, sequence_length = seq_len, n_ahead = n_ahead)
  X_test_2022, y_test_2022, metadata_test_2022 = prepare_data(raw_test_data_2022, sequence_length = seq_len, n_ahead = n_ahead)

  # Feature scaler: fitted on the training set only, on (samples * steps, features) rows.
  scaler_X = StandardScaler()
  X_train_scaled = scaler_X.fit_transform(X_train.reshape(-1, X_train.shape[-1])).reshape(X_train.shape)
  X_val_scaled   = scaler_X.transform(X_val.reshape(-1, X_val.shape[-1])).reshape(X_val.shape)
  X_test_scaled  = scaler_X.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(X_test.shape)
  X_test_2020_scaled = scaler_X.transform(X_test_2020.reshape(-1, X_test_2020.shape[-1])).reshape(X_test_2020.shape)
  X_test_2021_scaled = scaler_X.transform(X_test_2021.reshape(-1, X_test_2021.shape[-1])).reshape(X_test_2021.shape)
  X_test_2022_scaled = scaler_X.transform(X_test_2022.reshape(-1, X_test_2022.shape[-1])).reshape(X_test_2022.shape)

  # Target scaler: fitted on the training targets only; test() inverts it to report MAE.
  scaler_y = StandardScaler()
  y_train_scaled = scaler_y.fit_transform(y_train)
  y_val_scaled   = scaler_y.transform(y_val)
  y_test_scaled  = scaler_y.transform(y_test)
  y_test_2020_scaled = scaler_y.transform(y_test_2020)
  y_test_2021_scaled = scaler_y.transform(y_test_2021)
  y_test_2022_scaled = scaler_y.transform(y_test_2022)

  train_dataset = StormDataset(X_train_scaled, y_train_scaled)
  val_dataset = StormDataset(X_val_scaled, y_val_scaled)
  test_dataset = StormDataset(X_test_scaled, y_test_scaled)
  test_dataset_2020 = StormDataset(X_test_2020_scaled, y_test_2020_scaled)
  test_dataset_2021 = StormDataset(X_test_2021_scaled, y_test_2021_scaled)
  test_dataset_2022 = StormDataset(X_test_2022_scaled, y_test_2022_scaled)

  # Only the training loader shuffles; evaluation loaders keep dataset order.
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
  test_2020_loader = DataLoader(test_dataset_2020, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
  test_2021_loader = DataLoader(test_dataset_2021, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
  test_2022_loader = DataLoader(test_dataset_2022, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

  # A fresh model and optimizer per horizon; lr replaces the literal 1e-3 that duplicated it.
  model = RNN_KAN(in_features, hidden_features, output_features, n_ahead)
  model = model.to(device)
  criterion = nn.MSELoss()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  model, total_time, train_loss_per_epoch, val_loss_per_epoch = train_val(model, criterion, optimizer, train_loader, val_loader, device, batch_size, epochs)

  test_mae = test(model, test_loader, scaler_y, device)
  test_2020_mae = test(model, test_2020_loader, scaler_y, device)
  test_2021_mae = test(model, test_2021_loader, scaler_y, device)
  test_2022_mae = test(model, test_2022_loader, scaler_y, device)

  results[n_ahead] = {
      "mae": [test_mae, test_2020_mae, test_2021_mae, test_2022_mae],
      "time": [total_time]
  }

Using NVIDIA GeForce RTX 4060 Ti for device
X_shape: (25864, 5, 20)
y_shape: (25864, 1)
batch_x shape: torch.Size([128, 5, 20])
batch_y shape: torch.Size([128, 1])
Epoch 1/10, Train Loss: 0.0935, Val Loss: 0.0399
Epoch 2/10, Train Loss: 0.0412, Val Loss: 0.0343
Epoch 3/10, Train Loss: 0.0394, Val Loss: 0.0373
Epoch 4/10, Train Loss: 0.0384, Val Loss: 0.0359
Epoch 5/10, Train Loss: 0.0384, Val Loss: 0.0334
Epoch 6/10, Train Loss: 0.0372, Val Loss: 0.0313
Epoch 7/10, Train Loss: 0.0370, Val Loss: 0.0344
Epoch 8/10, Train Loss: 0.0362, Val Loss: 0.0329
Epoch 9/10, Train Loss: 0.0364, Val Loss: 0.0318
Epoch 10/10, Train Loss: 0.0364, Val Loss: 0.0302
X_shape: (24788, 5, 20)
y_shape: (24788, 2)
batch_x shape: torch.Size([128, 5, 20])
batch_y shape: torch.Size([128, 2])
Epoch 1/10, Train Loss: 0.1500, Val Loss: 0.0837
Epoch 2/10, Train Loss: 0.0838, Val Loss: 0.0754
Epoch 3/10, Train Loss: 0.0783, Val Loss: 0.0714
Epoch 4/10, Train Loss: 0.0748, Val Loss: 0.0694
Epoch 5/10, Train Loss: 0.072

In [14]:
# Print training time and test MAE (overall and per test year) for each horizon.
# NOTE(review): the `* 6` presumably converts steps to hours (6-hourly records) — confirm.
for horizon_steps, metrics in results.items():
  elapsed = metrics['time'][0]
  mae_all, mae_2020, mae_2021, mae_2022 = metrics['mae']
  print(f"Prediction {horizon_steps*6}h: ")
  print(f"\t Total time: {elapsed:.2f}s")
  print(f"\t Test MAE: {mae_all:.4f}")
  print(f"\t Test 2020 MAE: {mae_2020:.4f}")
  print(f"\t Test 2021 MAE: {mae_2021:.4f}")
  print(f"\t Test 2022 MAE: {mae_2022:.4f}")

Prediction 6h: 
	 Total time: 13.64s
	 Test MAE: 2.1584
	 Test 2020 MAE: 1.6396
	 Test 2021 MAE: 1.5002
	 Test 2022 MAE: 3.5745
Prediction 12h: 
	 Total time: 14.77s
	 Test MAE: 3.2734
	 Test 2020 MAE: 2.3647
	 Test 2021 MAE: 2.2472
	 Test 2022 MAE: 5.6135
Prediction 18h: 
	 Total time: 16.61s
	 Test MAE: 3.8540
	 Test 2020 MAE: 2.7879
	 Test 2021 MAE: 2.7134
	 Test 2022 MAE: 6.5492
Prediction 24h: 
	 Total time: 17.02s
	 Test MAE: 4.1209
	 Test 2020 MAE: 3.2074
	 Test 2021 MAE: 2.9418
	 Test 2022 MAE: 6.7463
Prediction 30h: 
	 Total time: 26.89s
	 Test MAE: 4.5103
	 Test 2020 MAE: 3.4612
	 Test 2021 MAE: 3.3348
	 Test 2022 MAE: 7.3145
Prediction 36h: 
	 Total time: 22.94s
	 Test MAE: 5.2042
	 Test 2020 MAE: 4.2179
	 Test 2021 MAE: 3.9945
	 Test 2022 MAE: 8.0309
Prediction 42h: 
	 Total time: 29.66s
	 Test MAE: 5.6628
	 Test 2020 MAE: 4.2999
	 Test 2021 MAE: 4.5120
	 Test 2022 MAE: 8.8635
Prediction 48h: 
	 Total time: 30.28s
	 Test MAE: 5.8051
	 Test 2020 MAE: 4.5698
	 Test 2021 MAE: 