In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import pickle
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary

from layer.kan_layer import KANLinear, NewGELU
from sklearn.preprocessing import StandardScaler
import time

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
os.environ["PYTHONHASHSEED"] = str(SEED)

In [2]:
def get_raw_data(raw_data_dict):
  raw_data = {}
  for dict_ in raw_data_dict.values():
    raw_data.update(dict_)
  return raw_data
def load_and_process_data(data_path, train_years, val_years, test_years):
  raw_train_data_dict = {}
  raw_val_data_dict = {}
  raw_test_data_dict = {}
  for filename in os.listdir(data_path):
    if filename.endswith('.pkl'):
      year = int(filename[11:15])
      if year not in train_years and year not in val_years and year not in test_years:
        continue
      with open(os.path.join(data_path, filename), 'rb') as f:
        data = pickle.load(f)
        if year in train_years:
          raw_train_data_dict[filename] = data
        elif year in val_years:
          raw_val_data_dict[filename] = data
        elif year in test_years:
          raw_test_data_dict[filename] = data
        else:
          raise ValueError(f"Invalid year: {year}")

  raw_train_data_dict = dict(sorted(raw_train_data_dict.items()))
  raw_val_data_dict = dict(sorted(raw_val_data_dict.items()))
  raw_test_data_dict = dict(sorted(raw_test_data_dict.items()))

  raw_train_data = get_raw_data(raw_train_data_dict)
  raw_val_data = get_raw_data(raw_val_data_dict)
  raw_test_data = get_raw_data(raw_test_data_dict)
  return raw_train_data, raw_val_data, raw_test_data

def prepare_data(storm_data, sequence_length, n_ahead, dtype=np.float32):
  total_sequence = 0
  center_grid = 15
  for sid, storm_records in storm_data.items():
    if len(storm_records) < sequence_length + n_ahead:
      continue
    total_sequence += len(storm_records) - sequence_length - n_ahead + 1

  first_key = next(iter(storm_data.keys()))

  cma_len = len(storm_data[first_key][0]['targets'])
  era5_single_len = storm_data[first_key][0]['features']['single'].shape[0]
  era5_multi_len = storm_data[first_key][0]['features']['multi'][1:4].shape[0] * storm_data[first_key][0]['features']['multi'].shape[1]
  features_len = cma_len + era5_single_len + era5_multi_len
  input_shape = (total_sequence, sequence_length, features_len)
  output_shape = (total_sequence, n_ahead)

  X_sequences = np.empty(input_shape, dtype=dtype)
  y_sequences = np.empty(output_shape, dtype=dtype)
  sequence_metadata = [None] * total_sequence

  valid_storms = 0
  idx = 0

  for sid, storm_records in storm_data.items():
    if len(storm_records) < sequence_length + n_ahead:
      continue

    valid_storms += 1
    L = len(storm_records) - sequence_length - n_ahead + 1

    for i in range(L):
      for j in range(sequence_length):
        target = storm_records[i + j]['targets']
        cma_features = dtype([target['center_lat'],target['center_lon'],target['vmax'],target['pmin']])

        era5_features = []
        single_era5_features = storm_records[i + j]['features']['single']
        multi_era5_features = storm_records[i + j]['features']['multi'][1:4, :, :, :]

        for m in range(single_era5_features.shape[0]):
          era5_features.append(single_era5_features[m, center_grid, center_grid])
        for m in range(multi_era5_features.shape[0]):
          for n in range(multi_era5_features.shape[1]):
            era5_features.append(multi_era5_features[m, n, center_grid, center_grid])

        era5_features = dtype(era5_features)

        X_sequences[idx, j, :4] = cma_features
        X_sequences[idx, j, 4:] = era5_features

      for j in range(n_ahead):
        target = storm_records[i + sequence_length + j]['targets']
        y_sequences[idx, j] = dtype(target['vmax'])

      sequence_metadata[idx] = {
        'storm_id': sid,
        'input_times': [storm_records[i + j]['time'] for j in range(sequence_length)],
        'target_time': [storm_records[i + sequence_length + j]['time'] for j in range(n_ahead)]
      }

      idx += 1
  if idx < total_sequence:
    X_sequences = X_sequences[:idx]
    y_sequences = y_sequences[:idx]
    sequence_metadata = sequence_metadata[:idx]

  metadata = {
    'n_sequences': idx,
    'sequence_length': sequence_length,
    'n_storms': valid_storms,
    'sequence_metadata': sequence_metadata,
  }

  gc.collect()
  return (X_sequences, y_sequences, metadata)

class StormDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32)
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [None]:
def prepare_data_6d(storm_data, sequence_length, n_ahead, dtype=np.float32):
    # Tạo sliding windows
    X_sequences_multi = []
    X_sequences_single = []
    sequence_metadata = []

    # in_multi_size = ()

    valid_storms = 0
    total_sequence = 0
    for sid, storm_records in storm_data.items():
        if len(storm_records) < sequence_length + n_ahead:
            continue
        total_sequence += len(storm_records) - sequence_length - n_ahead + 1

    y_sequences = np.empty((total_sequence, n_ahead))
    
    idx = 0

    for sid, storm_records in storm_data.items():
        if len(storm_records) < sequence_length + 1:  # Cần ít nhất 5 time steps (4 input + 1 target)
            print(f"Bỏ qua storm {sid}: chỉ có {len(storm_records)} time steps (cần ít nhất {sequence_length + 1})")
            continue

        valid_storms += 1
        storm_sequences = 0
        L = len(storm_records) - sequence_length - n_ahead + 1

        # Tạo sliding window cho mỗi cơn bão
        for i in range(L):
            # Input: 4 time steps liên tiếp
            input_sequence = {
                'multi': [],
                'single': []
            }
            # first_storm_time = storm_records[i]['time']
            # storm_record_targets = []
            for j in range(sequence_length):
                target = storm_records[i + j]['targets']
                cma_features = dtype([target['center_lat'],target['center_lon'],target['vmax'],target['pmin']])
                cma_features = np.array(cma_features)


                single_era5_features = storm_records[i + j]['features']['single']
                multi_era5_features = storm_records[i + j]['features']['multi']
                single_era5_features = single_era5_features[None, :, :, :]
                single_era5_features = np.array(single_era5_features)
                multi_era5_features = np.array(multi_era5_features)
                era5_features = np.concatenate([single_era5_features, multi_era5_features], axis=0)

                input_sequence['single'].append(cma_features)
                input_sequence['multi'].append(era5_features)
            

            # input_sequence['single'].append(generate_single_data(storm_record_targets, first_storm_time))
            for j in range(n_ahead):
                target = storm_records[i + sequence_length + j]['targets']
                y_sequences[idx, j] = dtype(target['vmax'])

            X_sequences_multi.append(input_sequence['multi'])
            X_sequences_single.append(input_sequence['single'])

            # Tao metadata
            sequence_metadata.append({
                'storm_id': sid,
                'input_times': [storm_records[i + j]['time'] for j in range(sequence_length)],
                'target_time': storm_records[i + sequence_length]['time']
            })

            idx += 1

    X_multi = np.array(X_sequences_multi)
    X_single = np.array(X_sequences_single)
    # X_single = X_single.reshape(X_single.shape[0], 76)
    y = np.array(y_sequences)


    # Tạo metadata
    metadata = {
        'n_sequences': idx,
        'sequence_length': sequence_length,
        'n_storms': valid_storms,
        'sequence_metadata': sequence_metadata,
    }

    # print(f"\nData preparation completed:")
    # print(f"  Số cơn bão: {valid_storms}")
    # print(f"  Số sequences: {len(X_sequences_multi)}")
    # print(f"  Shape X_multi: {X_multi.shape}")
    # print(f"  Shape X_single: {X_single.shape}")
    # print(f"  Shape y: {y.shape}")

    return X_multi, X_single, y, metadata

In [14]:
data_path = 'data/cma-era5-data'
raw_train_data, raw_val_data, raw_test_data = load_and_process_data(data_path=data_path, train_years=list(range(1980, 2017)), val_years=[2017, 2018, 2019], test_years=[2020, 2021, 2022])

In [42]:
Xm, Xs, y, _ = prepare_data_6d(raw_test_data, 5, 4)

Bỏ qua storm 7-2022: chỉ có 4 time steps (cần ít nhất 6)

Data preparation completed:
  Số cơn bão: 80
  Số sequences: 1588
  Shape X_multi: (1588, 5, 6, 4, 31, 31)
  Shape X_single: (1588, 5, 4)
  Shape y: (1588, 4)


## Register Storm dataset

### 2D Storm Dataset 
(seq_len, n_features)

In [None]:
data_path = 'data/cma-era5-data'
raw_train_data, raw_val_data, raw_test_data = load_and_process_data(data_path=data_path, train_years=list(range(1980, 2017)), val_years=[2017, 2018, 2019], test_years=[2020, 2021, 2022])

In [14]:
run = wandb.init(project="Dataset Artifact", job_type="data-loading")

raw_dataset_artifact = wandb.Artifact(
    name='raw_cma_era5',
    type='dataset',
    description='Raw CMA ERA5 train/val/test data'
)

# add file
raw_dataset_artifact.add_file('dataset/cma-era5-data/raw_cma_era5_train.pkl')
raw_dataset_artifact.add_file('dataset/cma-era5-data/raw_cma_era5_val.pkl')
raw_dataset_artifact.add_file('dataset/cma-era5-data/raw_cma_era5_test.pkl')

# log artifact
wandb.log_artifact(raw_dataset_artifact)

run.finish()

In [None]:
def regist_dataset(file_path):
    run = wandb.init(project="Dataset Artifact", job_type="data-loading")
    name = file_path.split('/')[-1][:-4]
    raw_dataset_artifact = wandb.Artifact(
        name=name,
        type='dataset',
        description=name
    )

    raw_dataset_artifact.add_file(file_path)

    wandb.log_artifact(raw_dataset_artifact)
    run.finish()

In [10]:
file_paths = ['dataset/raw_ETTh1.csv', 'dataset/raw_ETTh2.csv', 'dataset/raw_ETTm1.csv', 'dataset/raw_electricity.csv', 'dataset/raw_exchange_rate.csv', 'dataset/raw_PEMS03.pkl', 'dataset/raw_PEMS04.pkl', 'dataset/raw_PEMS07.pkl', 'dataset/raw_PEMS08.pkl']
for file_path in file_paths:
    regist_dataset(file_path)