In [3]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from typing import Union, List, Optional
import os

from llmfoundry.data import StreamingTextDataset
from llmfoundry.tokenizers import ChronosTokenizerWrapper
from chronos.chronos import ChronosConfig, ChronosPipeline

from streaming import Stream
from streaming.base import MDSWriter, StreamingDataset
from transformers import PreTrainedTokenizerBase, PreTrainedTokenizer, AutoConfig

from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union, Tuple, cast

%cd /mnt/workdisk/kushal/llm-foundry/

/mnt/workdisk/kushal/llm-foundry


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
# Rendezvous address and port for the process group
# (presumably consumed by the streaming / torch.distributed setup — TODO confirm).
os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

In [3]:
# Load the train / test CSVs (paths are relative to the working directory set above).
train = pd.read_csv('./kushal-testing/hospital_train.csv')
test = pd.read_csv('./kushal-testing/hospital_test.csv')

# The column counts of the two frames are used as the context and horizon lengths.
context_len = train.shape[1]
horizon_len = test.shape[1]

In [4]:
# Load the Hugging Face config and build the Chronos tokenizer from its embedded settings.
config = AutoConfig.from_pretrained('amazon/chronos-t5-small')
assert hasattr(config, "chronos_config"), "Not a Chronos config file"

# config.chronos_config (the "chronos_config" entry of the model's config.json) is
# unpacked as keyword arguments into ChronosConfig.
chronos_config = ChronosConfig(**config.chronos_config)
chronos_tokenizer = chronos_config.create_tokenizer()

# Display the bin boundaries together with their type and shape.
bin_boundaries = chronos_tokenizer.boundaries
bin_boundaries, type(bin_boundaries), bin_boundaries.shape

(tensor([-1.0000e+20, -1.4996e+01, -1.4989e+01,  ...,  1.4989e+01,
          1.4996e+01,  1.0000e+20]),
 torch.Tensor,
 torch.Size([4094]))

In [22]:
# Map every boundary's Python float value to its tensor element.
dict((edge.item(), edge) for edge in chronos_tokenizer.boundaries)

{tensor(-1.0000e+20): tensor(-1.0000e+20),
 tensor(-14.9963): tensor(-14.9963),
 tensor(-14.9890): tensor(-14.9890),
 tensor(-14.9817): tensor(-14.9817),
 tensor(-14.9743): tensor(-14.9743),
 tensor(-14.9670): tensor(-14.9670),
 tensor(-14.9597): tensor(-14.9597),
 tensor(-14.9523): tensor(-14.9523),
 tensor(-14.9450): tensor(-14.9450),
 tensor(-14.9377): tensor(-14.9377),
 tensor(-14.9304): tensor(-14.9304),
 tensor(-14.9230): tensor(-14.9230),
 tensor(-14.9157): tensor(-14.9157),
 tensor(-14.9084): tensor(-14.9084),
 tensor(-14.9010): tensor(-14.9010),
 tensor(-14.8937): tensor(-14.8937),
 tensor(-14.8864): tensor(-14.8864),
 tensor(-14.8790): tensor(-14.8790),
 tensor(-14.8717): tensor(-14.8717),
 tensor(-14.8644): tensor(-14.8644),
 tensor(-14.8570): tensor(-14.8570),
 tensor(-14.8497): tensor(-14.8497),
 tensor(-14.8424): tensor(-14.8424),
 tensor(-14.8350): tensor(-14.8350),
 tensor(-14.8277): tensor(-14.8277),
 tensor(-14.8204): tensor(-14.8204),
 tensor(-14.8130): tensor(-14.81

In [2]:
pipeline = ChronosPipeline.from_pretrained('amazon/chronos-t5-small')

# Print each bin array followed by its length: centers first, then boundaries.
for bins in (pipeline.tokenizer.centers, pipeline.tokenizer.boundaries):
    print(bins)
    print(len(bins))

tensor([-15.0000, -14.9927, -14.9853,  ...,  14.9853,  14.9927,  15.0000])
4093
tensor([-1.0000e+20, -1.4996e+01, -1.4989e+01,  ...,  1.4989e+01,
         1.4996e+01,  1.0000e+20])
4094


In [19]:
# Hand-built example: an 18-step context and a 12-step horizon, both float32.
past_target = np.asarray(
    [23., 11., 16., 20., 12., 14., 22., 20., 17., 13.,  9., 12., 12., 11., 11., 20.,  6., 11.],
    dtype=np.float32,
)
future_target = np.asarray(
    [10.,  5., 15., 13., 14., 11., 11., 12., 15., 17., 14., 15.],
    dtype=np.float32,
)
entry = {'past_target': past_target, 'future_target': future_target}

num_special_toks = 2
n_tokens = 4096

# Scale the context by its own mean, then display the scale and the scaled values.
scale = past_target.mean()
scaled_past_target = past_target / scale
scale, scaled_past_target

(14.444445,
 array([1.5923077 , 0.76153845, 1.1076922 , 1.3846154 , 0.83076924,
        0.9692308 , 1.5230769 , 1.3846154 , 1.176923  , 0.9       ,
        0.6230769 , 0.83076924, 0.83076924, 0.76153845, 0.76153845,
        1.3846154 , 0.41538462, 0.76153845], dtype=float32))

In [13]:
# Compare the first scaled context value against a run of consecutive bin boundaries;
# the printed difference changes sign where the value crosses a boundary.
wrapped_boundaries = chronos_tokenizer.chronos_tokenizer.boundaries
first_scaled = scaled_past_target[0]
for idx in range(2262, 2269):
    edge = wrapped_boundaries[idx]
    print(idx, edge, first_scaled - edge.item())

2262 tensor(1.5799) 0.012395620346069336
2263 tensor(1.5872) 0.005064249038696289
2264 tensor(1.5946) -0.002267122268676758
2265 tensor(1.6019) -0.009598493576049805
2266 tensor(1.6092) -0.01692986488342285
2267 tensor(1.6166) -0.0242612361907959
2268 tensor(1.6239) -0.031592607498168945


In [24]:
# Pair each bin center with a token id, with ids starting at num_special_toks.
dict(zip(
    chronos_tokenizer.chronos_tokenizer.centers.tolist(),
    range(num_special_toks, n_tokens),
))

{-15.0: 2,
 -14.992668151855469: 3,
 -14.985337257385254: 4,
 -14.978005409240723: 5,
 -14.970674514770508: 6,
 -14.963342666625977: 7,
 -14.956011772155762: 8,
 -14.94867992401123: 9,
 -14.941349029541016: 10,
 -14.934017181396484: 11,
 -14.92668628692627: 12,
 -14.919354438781738: 13,
 -14.912023544311523: 14,
 -14.904691696166992: 15,
 -14.897360801696777: 16,
 -14.890028953552246: 17,
 -14.882698059082031: 18,
 -14.8753662109375: 19,
 -14.868035316467285: 20,
 -14.860703468322754: 21,
 -14.853372573852539: 22,
 -14.846040725708008: 23,
 -14.838709831237793: 24,
 -14.831377983093262: 25,
 -14.824047088623047: 26,
 -14.816715240478516: 27,
 -14.8093843460083: 28,
 -14.80205249786377: 29,
 -14.794721603393555: 30,
 -14.787389755249023: 31,
 -14.780058860778809: 32,
 -14.772727012634277: 33,
 -14.765396118164062: 34,
 -14.758064270019531: 35,
 -14.750733375549316: 36,
 -14.743401527404785: 37,
 -14.73607063293457: 38,
 -14.728738784790039: 39,
 -14.721407890319824: 40,
 -14.71407604217

In [14]:
# Print the bin centers at the same indices that were inspected for the boundaries.
# (The previously assigned `val` was never read, so it has been dropped.)
for i in [2262, 2263, 2264, 2265, 2266, 2267, 2268]:
    print(i, chronos_tokenizer.chronos_tokenizer.centers[i])

2262 tensor(1.5836)
2263 tensor(1.5909)
2264 tensor(1.5982)
2265 tensor(1.6056)
2266 tensor(1.6129)
2267 tensor(1.6202)
2268 tensor(1.6276)


In [8]:
# Build the llm-foundry tokenizer wrapper and tokenize the example entry.
# A bare `type(...)` expression in the middle of a cell is never displayed (only the
# last expression is), so that no-op line has been removed.
chronos_tokenizer = ChronosTokenizerWrapper('amazon/chronos-t5-small')
chronos_tokenizer._to_hf_format(entry)

{'input_ids': tensor([2266, 2153, 2200, 2238, 2162, 2181, 2257, 2238, 2210, 2172, 2134, 2162,
         2162, 2153, 2153, 2238, 2106, 2153,    1]),
 'attention_mask': tensor([True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True]),
 'labels': tensor([2143, 2096, 2191, 2172, 2181, 2153, 2153, 2162, 2191, 2210, 2181, 2191,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100,    1])}

Should output this:

```
{'input_ids': tensor([2266, 2153, 2200, 2238, 2162, 2181, 2257, 2238, 2210, 2172, 2134, 2162, 2162, 2153, 2153, 2238, 2106, 2153,    1]),
 'attention_mask': tensor([True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]),
 'labels': tensor([2143, 2096, 2191, 2172, 2181, 2153, 2153, 2162, 2191, 2210, 2181, 2191, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    1])}
```

In [18]:
# Inputs for reproducing the label tokenization by hand: the horizon values and a
# fixed scale.
label = torch.tensor([[10.,  5., 15., 13., 14., 11., 11., 12., 15., 17., 14., 15.]])
scale = torch.tensor([14.4444])

# 4093 evenly spaced bin centers on [-15, 15].
centers = torch.linspace(-15.0, 15.0, 4093)

# Boundaries are the midpoints between neighbouring centers, padded with very large
# sentinels at both ends.
lower_sentinel = torch.tensor([-1e20], device=centers.device)
upper_sentinel = torch.tensor([1e20], device=centers.device)
midpoints = (centers[:-1] + centers[1:]) / 2
boundaries = torch.cat((lower_sentinel, midpoints, upper_sentinel))

In [19]:
# Tokenize `label` by hand: mask NaNs, mean-scale, bucketize into the uniform bins and
# shift the bin index past the special-token ids.
context = label
attention_mask = ~torch.isnan(context)  # True wherever a value is present

# Only derive a scale when none was supplied; otherwise the provided scale is reused.
if scale is None:
    scale = torch.nansum(
        torch.abs(context) * attention_mask, dim=-1
    ) / torch.nansum(attention_mask, dim=-1)
    scale[~(scale > 0)] = 1.0  # also catches NaN scales, which fail the `> 0` test

scaled_context = context / scale.unsqueeze(dim=-1)
token_ids = (
    torch.bucketize(
        input=scaled_context,
        boundaries=boundaries,
        right=True,
    )
    + chronos_config.n_special_tokens
)
# This snippet runs at cell level, where there is no `self`; the pad token id is read
# from the ChronosConfig built earlier in the notebook instead of `self.config`.
token_ids[~attention_mask] = chronos_config.pad_token_id

# token_ids, attention_mask, scale

In [14]:
scaled_context

tensor([[0.6923, 0.3462, 1.0385, 0.9000, 0.9692, 0.7615, 0.7615, 0.8308, 1.0385,
         1.1769, 0.9692, 1.0385]])

In [4]:
from chronos.scripts import ChronosDataset

ModuleNotFoundError: No module named 'chronos.scripts'

# Writing to MDS

In [4]:
# Load the full hospital CSV and display it.
df = pd.read_csv(
    '/mnt/workdisk/kushal/llm-foundry/kushal-testing/hospital.csv',
)
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,16.0,13.0,15.0,16.0,21.0,12.0,14.0,15.0,21.0,20.0,...,18.0,14.0,6.0,15.0,21.0,17.0,14.0,12.0,8.0,17.0
1,14.0,7.0,8.0,12.0,18.0,19.0,16.0,15.0,16.0,15.0,...,13.0,9.0,17.0,22.0,11.0,19.0,11.0,12.0,20.0,10.0
2,204.0,202.0,206.0,207.0,210.0,228.0,194.0,166.0,198.0,225.0,...,196.0,192.0,210.0,198.0,193.0,190.0,186.0,181.0,198.0,169.0
3,124.0,128.0,131.0,127.0,115.0,131.0,116.0,124.0,116.0,97.0,...,105.0,67.0,102.0,90.0,93.0,103.0,103.0,82.0,76.0,81.0
4,17.0,14.0,10.0,19.0,14.0,10.0,12.0,19.0,11.0,18.0,...,10.0,17.0,18.0,13.0,23.0,14.0,22.0,26.0,24.0,19.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
762,30.0,39.0,33.0,22.0,27.0,33.0,31.0,26.0,47.0,42.0,...,35.0,31.0,41.0,34.0,34.0,30.0,21.0,33.0,37.0,32.0
763,42.0,42.0,51.0,54.0,57.0,62.0,51.0,59.0,61.0,42.0,...,42.0,54.0,39.0,38.0,45.0,39.0,45.0,55.0,46.0,31.0
764,54.0,53.0,60.0,51.0,55.0,54.0,57.0,61.0,60.0,51.0,...,39.0,42.0,37.0,36.0,40.0,34.0,39.0,42.0,44.0,30.0
765,230.0,289.0,231.0,288.0,286.0,291.0,231.0,252.0,286.0,368.0,...,410.0,361.0,374.0,388.0,401.0,434.0,415.0,546.0,453.0,456.0


In [14]:
class TimeSeriesDataset(Dataset):
    """Map-style dataset that splits each DataFrame row into a context and a horizon window."""

    def __init__(self, data: Union[torch.Tensor, List[torch.Tensor], pd.DataFrame, np.ndarray], 
                 context_len: int, horizon_len: int, transform: Optional[Callable]=None):
        """
        Args:
            data: Table with one series per row. ``__getitem__`` indexes it with ``.iloc``,
                so in practice a ``pd.DataFrame`` is required.
            context_len (int): Number of leading columns returned as the context.
            horizon_len (int): Number of columns after the context returned as the horizon.
            transform (callable, optional): Applied to the context slice before it is
                converted to NumPy; its result must still provide ``to_numpy()``.
        """
        self.data = data
        self.context_len = context_len
        self.horizon_len = horizon_len
        self.transform = transform

    def __len__(self):
        # Number of rows, i.e. number of series in the dataset.
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(context, horizon)`` for row ``idx`` as float16 NumPy arrays.

        Raises:
            IndexError: Propagated from ``.iloc`` when ``idx`` is out of range, which is
                also what terminates plain ``for x, y in dataset`` iteration.
        """
        horizon_end = self.context_len + self.horizon_len
        x = self.data.iloc[idx, :self.context_len]
        # Bound the horizon by horizon_len instead of taking every remaining column.
        y = self.data.iloc[idx, self.context_len:horizon_end]

        if self.transform:
            x = self.transform(x)
        
        return x.to_numpy().astype(np.float16), y.to_numpy().astype(np.float16)

## Split into `context` and `horizon` columns for MDS

In [98]:
def split_row(row, context_len=18):
    """Split one row positionally into ``(context, horizon)`` float64 arrays.

    The first ``context_len`` values form the context; the rest form the horizon.
    """
    context = np.array(row.iloc[:context_len], dtype=np.float64)
    horizon = np.array(row.iloc[context_len:], dtype=np.float64)
    return context, horizon

# Apply the splitter to every row, then lay the (context, horizon) pairs out as two columns.
new_data = df.apply(split_row, axis='columns')
new_df = pd.DataFrame(list(new_data), columns=['context', 'horizon'])
new_df

Unnamed: 0,context,horizon
0,"[16.0, 13.0, 15.0, 16.0, 21.0, 12.0, 14.0, 15....","[13.0, 19.0, 18.0, 14.0, 6.0, 15.0, 21.0, 17.0..."
1,"[14.0, 7.0, 8.0, 12.0, 18.0, 19.0, 16.0, 15.0,...","[13.0, 15.0, 13.0, 9.0, 17.0, 22.0, 11.0, 19.0..."
2,"[204.0, 202.0, 206.0, 207.0, 210.0, 228.0, 194...","[205.0, 180.0, 196.0, 192.0, 210.0, 198.0, 193..."
3,"[124.0, 128.0, 131.0, 127.0, 115.0, 131.0, 116...","[142.0, 115.0, 105.0, 67.0, 102.0, 90.0, 93.0,..."
4,"[17.0, 14.0, 10.0, 19.0, 14.0, 10.0, 12.0, 19....","[15.0, 6.0, 10.0, 17.0, 18.0, 13.0, 23.0, 14.0..."
...,...,...
762,"[30.0, 39.0, 33.0, 22.0, 27.0, 33.0, 31.0, 26....","[34.0, 33.0, 35.0, 31.0, 41.0, 34.0, 34.0, 30...."
763,"[42.0, 42.0, 51.0, 54.0, 57.0, 62.0, 51.0, 59....","[47.0, 40.0, 42.0, 54.0, 39.0, 38.0, 45.0, 39...."
764,"[54.0, 53.0, 60.0, 51.0, 55.0, 54.0, 57.0, 61....","[41.0, 41.0, 39.0, 42.0, 37.0, 36.0, 40.0, 34...."
765,"[230.0, 289.0, 231.0, 288.0, 286.0, 291.0, 231...","[347.0, 373.0, 410.0, 361.0, 374.0, 388.0, 401..."


In [16]:
# Write to MDS
output_dir = '/mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_context_horizon_split/'  # shards written to a local directory
columns = {
    'context': 'ndarray:float16:18',
    'horizon': 'ndarray:float16:12',
}

%rm -rf /mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_context_horizon_split/
from streaming.base import MDSWriter, StreamingDataset
dataset = TimeSeriesDataset(data=df, context_len=18, horizon_len=12)
with MDSWriter(out=output_dir, columns=columns) as out:
    for x, y in dataset:
        out.write({'context': x.astype(np.float16), 'horizon': y.astype(np.float16)})

In [17]:
# Inspect dataset
# Opens the shards written to `output_dir` as a StreamingDataset.
# NOTE(review): the recorded run of this cell ended in
# "DistStoreError: Timed out after 601 seconds waiting for clients. 1/2 clients joined."
# — the distributed environment presumably needs fixing before this can work; confirm.
dataset = StreamingDataset(local=output_dir, batch_size=16)

Because `predownload` was not specified, it will default to 8*batch_size if batch_size is not None, otherwise 64. Prior to Streaming v0.7.0, `predownload` defaulted to max(batch_size, 256 * batch_size // num_canonical_nodes).


DistStoreError: Timed out after 601 seconds waiting for clients. 1/2 clients joined.

In [None]:
# Echo every sample index and dump the sample itself at every 50th index.
for i in range(dataset.num_samples):
    print(i)
    if not i % 50:
        print(dataset[i])

## Keep a single column in MDS format

In [6]:
class TestStreamingTimeSeriesDataset(StreamingDataset):
    """StreamingDataset subclass that returns a single named column of each MDS sample.

    Args:
        tokenizer: Tokenizer kept on the instance as ``self.tokenizer`` (not used here).
        seq_len: Sequence length kept on the instance as ``self.seq_len`` (not used here).
        transform: Optional callable applied to the selected column value in
            ``__getitem__``.
        data_key: Name of the sample column to return. Defaults to ``'data'`` for
            backward compatibility; pass the column name used when the shards were
            written (e.g. ``'text'`` for the single-column shards in this notebook).
        **kwargs: Accepted for call-site compatibility and ignored.

    All remaining parameters are forwarded unchanged to ``StreamingDataset.__init__``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        seq_len: int,
        streams: Optional[Sequence[Stream]] = None,
        remote: Optional[str] = None,
        local: Optional[str] = None,
        split: Optional[str] = None,
        download_retry: int = 2,
        download_timeout: float = 60,
        validate_hash: Optional[str] = None,
        keep_zip: bool = False,
        epoch_size: Optional[Union[int, str]] = None,
        predownload: Optional[int] = None,
        cache_limit: Optional[Union[int, str]] = None,
        partition_algo: str = 'relaxed',
        num_canonical_nodes: Optional[int] = None,
        batch_size: Optional[int] = None,
        shuffle: bool = False,
        shuffle_algo: str = 'py1e',
        shuffle_seed: int = 9176,
        shuffle_block_size: Optional[int] = None,
        sampling_method: str = 'balanced',
        sampling_granularity: int = 1,
        batching_method: str = 'random',
        allow_unsafe_types: bool = False,
        replication: Optional[int] = None,
        transform: Optional[Callable] = None,
        data_key: str = 'data',
        **kwargs: Any,
    ):
        # Build Dataset
        super().__init__(
            streams=streams,
            remote=remote,
            local=local,
            split=split,
            download_retry=download_retry,
            download_timeout=download_timeout,
            validate_hash=validate_hash,
            keep_zip=keep_zip,
            epoch_size=epoch_size,
            predownload=predownload,
            cache_limit=cache_limit,
            partition_algo=partition_algo,
            num_canonical_nodes=num_canonical_nodes,
            batch_size=batch_size,
            shuffle=shuffle,
            shuffle_algo=shuffle_algo,
            shuffle_seed=shuffle_seed,
            shuffle_block_size=shuffle_block_size,
            sampling_method=sampling_method,
            sampling_granularity=sampling_granularity,
            batching_method=batching_method,
            allow_unsafe_types=allow_unsafe_types,
            replication=replication,
        )
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.transform = transform
        self.data_key = data_key

    def __getitem__(self, idx: int) -> Any:
        """Return the ``data_key`` column of sample ``idx``, after the optional transform.

        Raises:
            KeyError: If the sample has no column named ``data_key``; the message lists
                the columns that are available.
        """
        sample = super().__getitem__(idx)
        if self.data_key not in sample:
            raise KeyError(
                f'Column {self.data_key!r} not found in sample; '
                f'available columns: {sorted(sample)}'
            )
        sample_data = sample[self.data_key]
        if self.transform is not None:
            sample_data = self.transform(sample_data)
        return sample_data

In [11]:
def row_to_numpy(row: pd.Series):
    """Return the row's values as a new float64 NumPy array."""
    return np.array(row, dtype=np.float64)

# Collapse every row into one float64 array and keep the arrays in a single 'text' column.
one_col_data = df.apply(row_to_numpy, axis='columns')
one_col_df = one_col_data.to_frame(name='text')
one_col_df

Unnamed: 0,text
0,"[16.0, 13.0, 15.0, 16.0, 21.0, 12.0, 14.0, 15...."
1,"[14.0, 7.0, 8.0, 12.0, 18.0, 19.0, 16.0, 15.0,..."
2,"[204.0, 202.0, 206.0, 207.0, 210.0, 228.0, 194..."
3,"[124.0, 128.0, 131.0, 127.0, 115.0, 131.0, 116..."
4,"[17.0, 14.0, 10.0, 19.0, 14.0, 10.0, 12.0, 19...."
...,...
762,"[30.0, 39.0, 33.0, 22.0, 27.0, 33.0, 31.0, 26...."
763,"[42.0, 42.0, 51.0, 54.0, 57.0, 62.0, 51.0, 59...."
764,"[54.0, 53.0, 60.0, 51.0, 55.0, 54.0, 57.0, 61...."
765,"[230.0, 289.0, 231.0, 288.0, 286.0, 291.0, 231..."


In [12]:
# Keep only the first 20 rows as a small test subset.
one_col_small_df = one_col_df.head(20)
one_col_small_df

Unnamed: 0,text
0,"[16.0, 13.0, 15.0, 16.0, 21.0, 12.0, 14.0, 15...."
1,"[14.0, 7.0, 8.0, 12.0, 18.0, 19.0, 16.0, 15.0,..."
2,"[204.0, 202.0, 206.0, 207.0, 210.0, 228.0, 194..."
3,"[124.0, 128.0, 131.0, 127.0, 115.0, 131.0, 116..."
4,"[17.0, 14.0, 10.0, 19.0, 14.0, 10.0, 12.0, 19...."
5,"[12.0, 14.0, 12.0, 12.0, 8.0, 17.0, 10.0, 11.0..."
6,"[6.0, 9.0, 15.0, 15.0, 13.0, 11.0, 9.0, 14.0, ..."
7,"[26.0, 40.0, 30.0, 25.0, 30.0, 20.0, 22.0, 26...."
8,"[23.0, 11.0, 16.0, 20.0, 12.0, 14.0, 22.0, 20...."
9,"[31.0, 32.0, 23.0, 35.0, 29.0, 24.0, 21.0, 21...."


In [13]:
# Write to MDS
%rm -rf /mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col/

output_dir = '/mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col/'  # shards written to a local directory
columns = {'text': 'ndarray:float64:30'}

with MDSWriter(out=output_dir, columns=columns) as out:
    for index, row in one_col_df.iterrows():
        out.write({'text': row[0].astype(np.float64)})

  out.write({'text': row[0].astype(np.float64)})


In [14]:
# Write small to MDS
%rm -rf /mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col_small/

output_dir = '/mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col_small/'  # shards written to a local directory
columns = {'text': 'ndarray:float64:30'}

with MDSWriter(out=output_dir, columns=columns) as out:
    for index, row in one_col_small_df.iterrows():
        out.write({'text': row[0].astype(np.float64)})

  out.write({'text': row[0].astype(np.float64)})


In [47]:
# Read in entire large dataset - CANNOT DO THIS IN JUPYTER NOTEBOOK
ds = StreamingTimeSeriesDataset(
    tokenizer=ChronosTokenizerWrapper('amazon/chronos-t5-small'),
    seq_len=30,
    local='/mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col/',
)

Because `predownload` was not specified, it will default to 8*batch_size if batch_size is not None, otherwise 64. Prior to Streaming v0.7.0, `predownload` defaulted to max(batch_size, 256 * batch_size // num_canonical_nodes).


In [20]:
# Read in small dataset - CANNOT DO THIS IN JUPYTER NOTEBOOK
ds_small = StreamingTimeSeriesDataset(
    tokenizer=ChronosTokenizerWrapper('amazon/chronos-t5-small'),
    seq_len=30,
    local='/mnt/workdisk/kushal/llm-foundry/kushal-testing/mds_single_col_small/',
)

Because `predownload` was not specified, it will default to 8*batch_size if batch_size is not None, otherwise 64. Prior to Streaming v0.7.0, `predownload` defaulted to max(batch_size, 256 * batch_size // num_canonical_nodes).


# Testing Pipeline

In [19]:
# NOTE(review): `local` points at a CSV file here, whereas the other cells in this
# notebook pass a directory of MDS shards — confirm this is intended.
streaming_ts = StreamingTimeSeriesDataset(
    tokenizer=ChronosTokenizerWrapper('amazon/chronos-t5-small'),
    seq_len=30,
    local='/mnt/workdisk/kushal/llm-foundry/kushal-testing/hospital_train.csv',
)

TypeError: StreamingTimeSeriesDataset.__init__() missing 1 required positional argument: 'seq_len'

# Creating torch.Dataset or StreamingDataset

In [21]:
# Template for creating your own torch.Dataset
class CustomDataset(Dataset):
    """Minimal map-style dataset pairing ``data[idx]`` with ``labels[idx]``."""

    def __init__(self, data: Union[torch.Tensor, List[torch.Tensor], pd.DataFrame], labels: Union[torch.Tensor, List[torch.Tensor], pd.DataFrame], transform: Optional[Callable]=None):
        """
        Args:
            data: Samples, indexable by position (tensor, list of tensors, or DataFrame rows).
            labels: Labels, aligned with ``data`` by position.
            transform (callable, optional): Applied to the sample (not the label) on access.
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    @staticmethod
    def _take(container, idx: int):
        # pandas objects resolve plain [] by label (a column for DataFrames), so rows
        # must be fetched positionally to stay consistent with __len__.
        if isinstance(container, (pd.DataFrame, pd.Series)):
            return container.iloc[idx]
        return container[idx]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(sample, label)`` at position ``idx``; the sample is transformed if set."""
        sample = self._take(self.data, idx)
        label = self._take(self.labels, idx)

        if self.transform:
            sample = self.transform(sample)

        return sample, label


In [104]:
class TimeSeriesDataset(Dataset):
    """Map-style dataset over pre-split ``'context'`` / ``'horizon'`` columns."""

    def __init__(self, data: Union[torch.Tensor, List[torch.Tensor], pd.DataFrame, np.ndarray], 
                 context_len: int, horizon_len: int, transform: Optional[Callable]=None):
        """
        Args:
            data: Container exposing ``data['context']`` and ``data['horizon']``, each
                indexable by position and holding one array per sample (e.g. a DataFrame
                with those two columns).
            context_len (int): Context length; stored on the instance, not used for slicing.
            horizon_len (int): Horizon length; stored on the instance, not used for slicing.
            transform (callable, optional): Applied to the context array before the
                float64 cast; its result must still provide ``astype``.
        """
        self.data = data
        self.context_len = context_len
        self.horizon_len = horizon_len
        self.transform = transform

    def __len__(self):
        # Return the number of sequences in the dataset
        return len(self.data)

    @staticmethod
    def _take(column, idx: int):
        # pandas Series resolve plain [] by label; fetch by position so non-default
        # indexes and negative positions behave like ordinary sequences.
        if isinstance(column, pd.Series):
            return column.iloc[idx]
        return column[idx]

    def __getitem__(self, idx: int):
        """Return ``(context, horizon)`` for sample ``idx`` as float64 arrays.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. Raising (rather
                than returning ``None``) is what lets ``for x, y in dataset`` terminate.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f'index {idx} is out of range for dataset of length {size}')

        x = self._take(self.data['context'], idx)
        y = self._take(self.data['horizon'], idx)

        if self.transform:
            x = self.transform(x)

        return x.astype(np.float64), y.astype(np.float64)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,16.0,13.0,15.0,16.0,21.0,12.0,14.0,15.0,21.0,20.0,...,18.0,14.0,6.0,15.0,21.0,17.0,14.0,12.0,8.0,17.0


In [45]:
# Read in pd.DataFrame
# NOTE(review): `targets` and `seq_length` are not parameters of either TimeSeriesDataset
# definition in this notebook (both take data, context_len, horizon_len, transform), so
# this call presumably dates from an earlier version of the class — confirm before re-running.
dataset = TimeSeriesDataset(data=train, targets=test, seq_length=20, transform=None)
dataset

<__main__.TimeSeriesDataset at 0x7f429ad3f550>

In [41]:
# Read in torch.Tensor
# Same construction as the DataFrame variant, but the frames are converted to tensors first.
# NOTE(review): `targets` and `seq_length` are not parameters of either TimeSeriesDataset
# definition in this notebook (both take data, context_len, horizon_len, transform), so
# this call presumably dates from an earlier version of the class — confirm before re-running.
dataset = TimeSeriesDataset(data=torch.tensor(train.values), targets=torch.tensor(test.values), seq_length=20, transform=None)
dataset

<__main__.TimeSeriesDataset at 0x7f429adf51d0>

In [46]:
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f429d2ee750>

In [47]:
# Print every batch with a 1-based batch number.
for batch_idx, batch in enumerate(dataloader):
    inputs, labels = batch
    print(f"Batch {batch_idx + 1}")
    print("Inputs:", inputs)
    print("Labels:", labels)

Batch 1
Inputs: tensor([[[5.6500e+02, 5.4400e+02, 6.0400e+02, 5.0900e+02, 5.2300e+02,
          5.3800e+02, 4.0200e+02, 3.7000e+02, 3.6400e+02, 3.1100e+02,
          3.0400e+02, 3.1900e+02, 4.3900e+02, 4.0600e+02, 3.4300e+02,
          3.5500e+02, 3.5200e+02, 2.8600e+02]],

        [[5.4003e+02, 4.7680e+02, 5.9097e+02, 4.4837e+02, 5.6257e+02,
          5.0162e+02, 5.5028e+02, 4.9375e+02, 4.8995e+02, 5.9162e+02,
          6.3917e+02, 4.9929e+02, 5.5674e+02, 4.8648e+02, 5.8371e+02,
          4.7555e+02, 5.7836e+02, 5.2526e+02]],

        [[5.3340e+02, 5.2860e+02, 5.2812e+02, 5.9384e+02, 5.2046e+02,
          4.9946e+02, 5.6732e+02, 5.8370e+02, 5.7809e+02, 5.5311e+02,
          5.1725e+02, 5.3925e+02, 5.4666e+02, 5.8754e+02, 4.6565e+02,
          5.3589e+02, 4.7663e+02, 5.2292e+02]],

        [[5.4259e+02, 4.6580e+02, 4.1824e+02, 4.1735e+02, 4.8898e+02,
          4.4101e+02, 4.6335e+02, 4.6479e+02, 4.1452e+02, 4.3501e+02,
          4.2747e+02, 5.1433e+02, 5.0989e+02, 4.2930e+02, 4.1680e+0