## 1. Packages

In [None]:
! pip install xarray zarr gcsfs fsspec
! pip install --quiet climetlab
%config InteractiveShell.ast_node_interactivity = "all"

Collecting zarr
  Downloading zarr-3.1.1-py3-none-any.whl.metadata (10 kB)
Collecting donfig>=0.8 (from zarr)
  Downloading donfig-0.8.1.post1-py3-none-any.whl.metadata (5.0 kB)
Collecting numcodecs>=0.14 (from numcodecs[crc32c]>=0.14->zarr)
  Downloading numcodecs-0.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Collecting crc32c>=2.7 (from numcodecs[crc32c]>=0.14->zarr)
  Downloading crc32c-2.7.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Downloading zarr-3.1.1-py3-none-any.whl (255 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m255.4/255.4 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading donfig-0.8.1.post1-py3-none-any.whl (21 kB)
Downloading numcodecs-0.16.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m62.4 MB/s[0m eta [36m0:00:00[0m
[

In [None]:
from torch.utils.data import Dataset
from typing import Tuple, List
import xarray as xr
import gcsfs
import os
import pandas as pd
import numpy as np
import climetlab as cml
import kagglehub
import pickle
from google.colab import auth
import time

## 2. Load and preprocess the datasets

### Load CMA Dataset

In [None]:
# Download the CMA best-track dataset through kagglehub and print the directory
# that the download call returned.
path = kagglehub.dataset_download("chriszhengao/cma-best-track-data")
print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/cma-best-track-data


In [None]:
# Read the best-track CSV from the downloaded dataset directory into a DataFrame.
cma_raw_data = pd.read_csv(os.path.join(path, "CMA_Best_Track_Data.csv"))

  cma_raw_data = pd.read_csv(os.path.join(path, "CMA_Best_Track_Data.csv"))


### Load ERA5 Dataset

In [None]:
# Anonymous (token='anon') Google Cloud Storage filesystem handle; used below to list the bucket.
fs = gcsfs.GCSFileSystem(token='anon')

In [None]:
# List the entries under the WeatherBench2 ERA5 prefix to see which stores are available.
era5_path = 'gs://weatherbench2/datasets/era5/'
files = fs.ls(era5_path)
print("Available files:", files)

Available files: ['weatherbench2/datasets/era5/', 'weatherbench2/datasets/era5/1959-2022-1h-240x121_equiangular_with_poles_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-1h-360x181_equiangular_with_poles_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-128x64_equiangular_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-128x64_equiangular_with_poles_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-240x121_equiangular_with_poles_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_with_poles_conservative.zarr', 'weatherbench2/datasets/era5/1959-2022-6h-64x33.zarr', 'weatherbench2/datasets/era5/1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2', 'weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25de

In [None]:
# ERA5 store opened below (6-hourly, 0.25 degree, 37 levels, judging by its name).
era5_file_path = 'gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg-chunk-1.zarr-v2'
# NOTE(review): data_file is not referenced anywhere in the visible part of the notebook,
# and the /content prefix looks Colab-specific — confirm it is still needed.
data_file = '/content/cleaned_cma_best_track_data.csv'

In [None]:
# Open the ERA5 Zarr store. Anonymous access is requested explicitly (the same
# token used for the filesystem handle that lists this bucket) instead of relying
# on gcsfs falling back through its credential lookup chain.
era5_raw_data = xr.open_zarr(era5_file_path, storage_options={'token': 'anon'})

### Preprocessing CMA

In [None]:
# Work on an independent copy so the raw frame stays untouched (wrapping a frame in
# pd.DataFrame(...) does not copy its data), then parse the timestamp column.
# Item assignment is used instead of attribute assignment so the column is set unambiguously.
df = cma_raw_data.copy()
df["Time"] = pd.to_datetime(df["Time"])

In [None]:
# Keep records from 1980-2023 that fall on a 6-hourly time step, restricted to the
# columns used downstream, and derive a per-storm identifier.
columns = ["Tropical Cyclone Number", "Time", "Latitude", "Longitude", "Minimum Central Pressure", "Maximum Wind Speed"]
in_period = df["Time"].dt.year.isin(range(1980, 2024))
on_six_hour_step = df["Time"].dt.hour % 6 == 0
# One .loc selection followed by .copy() yields an independent frame, so the column
# assignments below cannot raise SettingWithCopyWarning.
df = df.loc[in_period & on_six_hour_step, columns].copy()
df['Year'] = df['Time'].dt.year
# Storm numbers are combined with the year to form the SID.
df['SID'] = df['Tropical Cyclone Number'].astype(str) + '-' + df['Year'].astype(str)
df = df.drop(columns=['Tropical Cyclone Number'])

In [None]:
def drop_miss_id(df):
    """Return the SIDs whose track is not sampled at a strict 6-hour interval.

    Args:
        df: DataFrame with at least the columns 'SID' and 'Time'. 'Time' may hold
            datetimes or anything ``pd.to_datetime`` can parse. The input frame is
            not modified.

    Returns:
        list: SIDs, in order of first appearance in ``df``, for which at least one
        gap between consecutive (time-sorted) records differs from 6 hours.
        Storms with fewer than two records are never reported.
    """
    expected_step = pd.Timedelta(hours=6)

    # Copy only the two columns that are needed instead of the whole frame.
    tracks = df[['SID', 'Time']].copy()
    tracks['Time'] = pd.to_datetime(tracks['Time'])

    drop_sids = []
    # A single groupby pass replaces a boolean scan of the full frame per SID;
    # sort=False keeps the SIDs in order of first appearance.
    for sid, times in tracks.groupby('SID', sort=False)['Time']:
        if len(times) < 2:
            continue
        deltas = times.sort_values().diff().dropna()
        if not (deltas == expected_step).all():
            drop_sids.append(sid)
    return drop_sids


# Identify the storms with irregular time steps, then keep only the remaining ones.
drop_sids = drop_miss_id(df)
cma = df.loc[~df["SID"].isin(drop_sids)]

In [None]:
# NOTE(review): this displays `df`, which still contains the storms listed in
# `drop_sids`; the gap-filtered frame is `cma` — confirm which one was meant here.
df

### Preprocessing ERA5

In [None]:
# CONFIGURATION
# Long ERA5 variable name -> short name; filter_era5_data applies it as a rename.
LONG2SHORT_DICT = {
    "geopotential": "z", "temperature": "t", "specific_humidity": "q", "relative_humidity": "r",
    "u_component_of_wind": "u", "v_component_of_wind": "v", "vorticity": "vo", "potential_vorticity": "pv",
    "2m_temperature": "t2m", "10m_u_component_of_wind": "u10", "10m_v_component_of_wind": "v10",
    "total_cloud_cover": "tcc", "total_precipitation": "tp", "toa_incident_solar_radiation": "tisr",
    "mean_sea_level_pressure": "msl"
}

SINGLE_LEVEL_VARS = [
    "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature", "mean_sea_level_pressure"
] # u10, v10, t2m, msl

MULTI_LEVEL_VARS = [
    "geopotential", "specific_humidity", "u_component_of_wind", "v_component_of_wind", "temperature"
] # z, q, u, v, t

# PARAMETERS
# Alternative: take all 37 height levels (and extract a patch of 2 or 3 instead of 31 or 40)
HEIGHT_LEVELS = [250, 500, 750, 1000] # pressure levels
TRAIN_BEGIN_YEAR = 1980
TIME_FREQUENCY = [0, 6, 12, 18] # hours of the day to keep: one track point every 6 h
def filter_era5_data(ds):
    """
    Reduce an ERA5 dataset to the variables, times and levels used in this notebook.

    Args:
        ds: dataset containing the variables named in SINGLE_LEVEL_VARS and
            MULTI_LEVEL_VARS, with a ``time`` coordinate and, for the multi-level
            variables, a ``level`` coordinate.

    Returns:
        A dataset holding the single-level variables followed by the multi-level
        variables (restricted to HEIGHT_LEVELS), limited to times from
        TRAIN_BEGIN_YEAR onwards whose hour is in TIME_FREQUENCY, with variables
        renamed to their short names from LONG2SHORT_DICT.
    """

    def _select_times(subset):
        # Keep times from TRAIN_BEGIN_YEAR onwards, then only the configured hours.
        subset = subset.sel(time=slice(f'{TRAIN_BEGIN_YEAR}-01-01', None))
        return subset.sel(time=subset.time.dt.hour.isin(TIME_FREQUENCY))

    # Single-level variables: time filtering only.
    ds_single = _select_times(ds[SINGLE_LEVEL_VARS])

    # Multi-level variables: time filtering plus the selected levels.
    ds_multi = _select_times(ds[MULTI_LEVEL_VARS]).sel(level=HEIGHT_LEVELS)

    # Merge, single-level variables first.
    ds_filtered = xr.merge([ds_single, ds_multi])

    # Rename only the variables that are actually present.
    rename_dict = {k: v for k, v in LONG2SHORT_DICT.items() if k in ds_filtered.data_vars}
    ds_filtered = ds_filtered.rename(rename_dict)

    # `.sizes` is the dimension-name -> length mapping; using `Dataset.dims` as a
    # mapping emits a FutureWarning in recent xarray releases.
    print(f"  Final dataset shape: {dict(ds_filtered.sizes)}")
    print(f"  Final variables: {list(ds_filtered.data_vars)}")

    return ds_filtered

In [None]:
# Build the filtered ERA5 dataset and show its repr as the cell output.
era5 = filter_era5_data(era5_raw_data)
era5

  Final dataset shape: {'time': 61364, 'latitude': 721, 'longitude': 1440, 'level': 4}
  Final variables: ['u10', 'v10', 't2m', 'msl', 'z', 'q', 'u', 'v', 't']


  print(f"  Final dataset shape: {dict(ds_filtered.dims)}")


Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 237.34 GiB 3.96 MiB Shape (61364, 721, 1440) (1, 721, 1440) Dask graph 61364 chunks in 3 graph layers Data type float32 numpy.ndarray",1440  721  61364,

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 237.34 GiB 3.96 MiB Shape (61364, 721, 1440) (1, 721, 1440) Dask graph 61364 chunks in 3 graph layers Data type float32 numpy.ndarray",1440  721  61364,

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 237.34 GiB 3.96 MiB Shape (61364, 721, 1440) (1, 721, 1440) Dask graph 61364 chunks in 3 graph layers Data type float32 numpy.ndarray",1440  721  61364,

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 237.34 GiB 3.96 MiB Shape (61364, 721, 1440) (1, 721, 1440) Dask graph 61364 chunks in 3 graph layers Data type float32 numpy.ndarray",1440  721  61364,

Unnamed: 0,Array,Chunk
Bytes,237.34 GiB,3.96 MiB
Shape,"(61364, 721, 1440)","(1, 721, 1440)"
Dask graph,61364 chunks in 3 graph layers,61364 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 0.93 TiB 15.84 MiB Shape (61364, 4, 721, 1440) (1, 4, 721, 1440) Dask graph 61364 chunks in 4 graph layers Data type float32 numpy.ndarray",61364  1  1440  721  4,

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 0.93 TiB 15.84 MiB Shape (61364, 4, 721, 1440) (1, 4, 721, 1440) Dask graph 61364 chunks in 4 graph layers Data type float32 numpy.ndarray",61364  1  1440  721  4,

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 0.93 TiB 15.84 MiB Shape (61364, 4, 721, 1440) (1, 4, 721, 1440) Dask graph 61364 chunks in 4 graph layers Data type float32 numpy.ndarray",61364  1  1440  721  4,

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 0.93 TiB 15.84 MiB Shape (61364, 4, 721, 1440) (1, 4, 721, 1440) Dask graph 61364 chunks in 4 graph layers Data type float32 numpy.ndarray",61364  1  1440  721  4,

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 0.93 TiB 15.84 MiB Shape (61364, 4, 721, 1440) (1, 4, 721, 1440) Dask graph 61364 chunks in 4 graph layers Data type float32 numpy.ndarray",61364  1  1440  721  4,

Unnamed: 0,Array,Chunk
Bytes,0.93 TiB,15.84 MiB
Shape,"(61364, 4, 721, 1440)","(1, 4, 721, 1440)"
Dask graph,61364 chunks in 4 graph layers,61364 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


### Extract Storm Data from CMA and ERA5

In [None]:
import numpy as np
import pandas as pd

def extract_storm_data(era5_data, cma_data, years, PATCH=31, save_dir='storm_by_years'):
    """
    Extract an ERA5 patch around every best-track record, one year at a time.

    For each requested year, every record of ``cma_data`` for that year is matched
    to the nearest ERA5 grid point and time step (nearest by absolute difference,
    with no tolerance check), a patch centred on that grid point is read, and the
    per-storm results are pickled to ``<save_dir>/storm_data_<year>.pkl``.

    Args:
        era5_data: filtered ERA5 dataset exposing ``latitude``, ``longitude`` and
            ``time`` coordinates and the short-named variables derived from
            SINGLE_LEVEL_VARS and MULTI_LEVEL_VARS.
        cma_data: pandas DataFrame with the columns 'Time', 'SID', 'Year',
            'Latitude', 'Longitude', 'Maximum Wind Speed', 'Minimum Central Pressure'.
        years: int or iterable of int, the years to extract (e.g. 2005 or [2005, 2010]).
        PATCH: patch size in grid points (default 31). Patches are clipped at the
            grid edges, so patches near a boundary are smaller than PATCH; there
            is no longitude wrap-around.
        save_dir: directory for the per-year pickle files; created if missing.

    Returns:
        dict mapping SID -> list of storm_sequence dicts for the LAST year
        processed only (earlier years are available from their pickle files).
        An empty dict is returned when ``years`` is empty.
    """
    # Accept a single year, given as a Python or NumPy integer, as well as an iterable.
    if isinstance(years, (int, np.integer)):
        years = [years]

    os.makedirs(save_dir, exist_ok=True)

    HALF_PATCH = PATCH // 2

    # Short variable names, in the order in which they are stacked into the patches.
    single_short_names = [LONG2SHORT_DICT[var]
                          for var in SINGLE_LEVEL_VARS if var in LONG2SHORT_DICT]
    multi_short_names = [LONG2SHORT_DICT[var]
                         for var in MULTI_LEVEL_VARS if var in LONG2SHORT_DICT]

    # ERA5 grid and time axes as plain NumPy arrays for the nearest-neighbour lookups.
    lat_coords = era5_data.latitude.values
    lon_coords = era5_data.longitude.values
    time_coords = era5_data.time.values

    # Defined up front so that an empty `years` returns {} instead of raising.
    storm_data = {}

    for year in years:
        storm_data = {}
        # Reset per year so the figure printed in the yearly summary is a per-year count.
        failed_extractions = 0
        cma_per_year = cma_data[cma_data.Year == year].copy()
        total_records = len(cma_per_year)
        print(f"===================== Bắt đầu trích xuất bão năm {year} =====================")
        print(f"Tổng records của năm {year}: {total_records}")

        for idx, row in cma_per_year.iterrows():
            orig_time = row['Time']
            sid = row['SID']
            storm_data.setdefault(sid, [])

            # Nearest grid point / time step. Kept outside the try block so that a
            # malformed coordinate column still fails loudly rather than being
            # counted as an ordinary extraction failure.
            lat_idx = np.argmin(np.abs(lat_coords - row['Latitude']))
            lon_idx = np.argmin(np.abs(lon_coords - row['Longitude']))
            time_idx = np.argmin(np.abs(time_coords - np.datetime64(row['Time'])))

            try:
                vmax = row['Maximum Wind Speed']
                pmin = row['Minimum Central Pressure']

                era5_lat = float(lat_coords[lat_idx])
                era5_lon = float(lon_coords[lon_idx])
                era5_time = time_coords[time_idx]

                # Index window around the centre, clipped to the grid extent.
                lat_slice = slice(max(lat_idx - HALF_PATCH, 0),
                                  min(lat_idx + HALF_PATCH + 1, len(lat_coords)))
                lon_slice = slice(max(lon_idx - HALF_PATCH, 0),
                                  min(lon_idx + HALF_PATCH + 1, len(lon_coords)))

                subset = era5_data.isel(time=time_idx, latitude=lat_slice, longitude=lon_slice)

                # Stack the variables along a leading "feature" axis; `.values`
                # materialises the data as NumPy arrays.
                # Presumably (feature, level, lat, lon) for multi-level and
                # (feature, lat, lon) for single-level — TODO confirm.
                patch_multi = subset[multi_short_names].to_array(dim="feature").values
                patch_single = subset[single_short_names].to_array(dim="feature").values

                storm_sequence = {
                    'time': era5_time,
                    'features': {
                        'multi': patch_multi,
                        'single': patch_single
                    },
                    'targets': {
                        'center_lat': era5_lat,
                        'center_lon': era5_lon,
                        'vmax': vmax,
                        'pmin': pmin
                    },
                    'metadata': {
                        'sid': sid,
                        'orig_time': orig_time,
                        'center_lat': era5_lat,
                        'center_lon': era5_lon,
                        'patch_bounds': {
                            'lat_slice': (lat_slice.start, lat_slice.stop),
                            'lon_slice': (lon_slice.start, lon_slice.stop)
                        }
                    }
                }

                storm_data[sid].append(storm_sequence)

            except Exception as e:
                # Best-effort extraction: count the failure and move on. The first
                # failure and every tenth one are reported so that problems are
                # visible without flooding the output.
                failed_extractions += 1
                if failed_extractions == 1 or failed_extractions % 10 == 0:
                    print(f"Warning: failed at record {idx}: {e}")
                continue

        total_storms = len(storm_data)
        total_records = sum(len(v) for v in storm_data.values())
        print(f"\nTrích xuất thành công cho năm {year}:")
        print(f"  Số cơn bão: {total_storms}")
        print(f"  Tổng record: {total_records}")
        print(f"  Patch thất bại: {failed_extractions}")

        storm_path = os.path.join(save_dir, f"storm_data_{year}.pkl")
        with open(storm_path, 'wb') as f:
            pickle.dump(storm_data, f)
        print(f"Đã lưu dữ liệu năm {year} vào {storm_path}")

    return storm_data


In [None]:
# range(1980, 2022) covers the years 1980 through 2021. Each year is pickled to its
# own file inside the function; the returned dict holds only the last year processed.
storm_data = extract_storm_data(era5, cma, list(range(1980, 2022)))

Tổng records của năm 2016: 725

Trích xuất thành công cho năm 2016:
  Số cơn bão: 29
  Tổng record: 725
  Patch thất bại: 0
Đã lưu dữ liệu năm 2016 vào storm_by_years/storm_data_2016.pkl
Tổng records của năm 2017: 762

Trích xuất thành công cho năm 2017:
  Số cơn bão: 29
  Tổng record: 762
  Patch thất bại: 0
Đã lưu dữ liệu năm 2017 vào storm_by_years/storm_data_2017.pkl
