# Conversion between TCTracks data and torch data loader

This notebook demonstrates how to convert CLIMADA `TCTracks` data into a PyTorch `Dataset` and wrap it in a `DataLoader` for training.

In [1]:
from itertools import islice

import numpy as np
import torch
from climada.hazard import TCTracks
from sklearn.model_selection import train_test_split
from torch.utils import data
# from matplotlib import pyplot as plt

from klearn_tcyclone.data_utils import data_array_list_from_TCTracks
from klearn_tcyclone.knf_data_utils import TCTrackDataset

Dask dataframe query planning is disabled because dask-expr is not installed.

You can install it with `pip install dask[dataframe]` or `conda install dask`.
This will raise in a future version.



Load example TCTracks data: IBTrACS tracks from the USA provider for the North Atlantic basin, 2000–2021.

In [2]:
# Load IBTrACS tracks reported by the USA provider for the North Atlantic
# basin over 2000-2021, without pressure correction.
tc_tracks = TCTracks.from_ibtracs_netcdf(
    provider='usa',
    basin='NA',
    year_range=(2000, 2021),
    correct_pres=False,
)
print('Number of tracks:', tc_tracks.size)

Cannot find the ecCodes library




  if ibtracs_ds.dims['storm'] == 0:


Number of tracks: 393


Split the tracks into train and test sets, holding out 10% of the tracks for testing.

In [3]:
# Hold out 10% of the tracks for testing. A fixed random_state makes the split
# reproducible across kernel restarts; without it every run would produce a
# different train/test partition.
tc_tracks_train, tc_tracks_test = train_test_split(
    tc_tracks.data, test_size=0.1, random_state=42
)

We build the torch `TCTrackDataset` by specifying `input_length`, `output_length`, the step `jumps` between consecutive sliding windows, and the list of features to extract from the TCTracks data.

In [4]:
# Sliding-window configuration for TCTrackDataset.
input_length = 45   # length of the input window of each sample
output_length = 10  # length of the target window of each sample
jumps = 3           # step between consecutive sliding windows

# Track variables extracted from the TCTracks data.
feature_list = [
    "lat",
    "lon",
    "central_pressure",
]

In [5]:
# Training dataset built from the training split of the tracks, using the
# sliding-window settings defined above.
train_set = TCTrackDataset(
    tc_tracks=tc_tracks_train,      # tracks from the training split
    feature_list=feature_list,      # variables to extract from each track
    input_length=input_length,
    output_length=output_length,
    jumps=jumps,                    # step between sliding windows
    mode="train",
)

Individual samples can be retrieved by index; each one is a pair of tensors (input window, target window).

In [6]:
# Index the dataset directly (this dispatches to Dataset.__getitem__) rather
# than calling the dunder method by hand.
item = train_set[1]
# Each sample is presumably an (input, target) pair — the output below shows two tensors.
item_inputs, item_targets = item
item_inputs.shape, item_targets.shape

(torch.Size([45, 3]), torch.Size([10, 3]))

Finally, we wrap the dataset in a torch `DataLoader` with the chosen `batch_size`.

In [7]:
# Batch and shuffle the training samples. A seeded generator makes the shuffle
# order reproducible across runs; with shuffle=True and no generator, each run
# draws a fresh random order.
batch_size = 32
train_loader = data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    generator=torch.Generator().manual_seed(0),
)


Number of batches per epoch. `len()` of a `DataLoader` counts batches, not the samples in the dataset.

In [8]:
# len() of a DataLoader is the number of batches per epoch, not the sample count.
num_batches = len(train_loader)
num_batches

55

Check the types and shapes of the first few batches.

In [9]:
# Inspect only the first few batches. islice stops pulling from the loader
# after n_batches_to_show batches instead of iterating (and loading) every
# batch of the epoch just to skip them.
n_batches_to_show = 6
for batch in islice(train_loader, n_batches_to_show):
    print(type(batch), len(batch))
    print(type(batch[0]), type(batch[1]))
    print(batch[0].shape, batch[1].shape)

<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
<class 'list'> 2
<class 'torch.Tensor'> <class 'torch.Tensor'>
torch.Size([32, 45, 3]) torch.Size([32, 10, 3])
