# datasets

> PyTorch datasets for interacting with timeseries data

In [None]:
# | default_exp datasets

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# | hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
# |export

from datetime import datetime
from functools import partial
import os
from typing import *

import dask.dataframe as dd
from fastcore.basics import patch
import numpy as np
from omegaconf import MISSING
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype
from torch.utils.data import Dataset

from rlmm.core import *
from rlmm.utils import *

Datasets are used to decouple the data loading logic from the training environment.

Input data is expected to be either csv or parquet for now, and the data can either be loaded fully in memory with pandas or lazily with dask. The output of `__getitem__` should be a pandas dataframe for training. This dataframe covers the specified time range at the specified sample frequency. Time bounds may be given as `datetime` objects or as strings; strings must use the format "%y_%m_%d-%H_%M_%S".

In [None]:
# Shared example parameters used by the demo cells further down.
# Paths to the parquet test data for order-book snapshots and trades.
path_book = "../data/test/parquet/book_snapshot_25/ETHUSDT"
path_trades = "../data/test/parquet/trades/ETHUSDT"
# Time bounds are given as strings and must match `time_format`.
time_start = "21_03_05-08_30_00"
time_end = "21_03_05-10_30_00"
time_format = "%y_%m_%d-%H_%M_%S"
# Resampling settings passed to the dataset constructors.
resample_interval = "minute"
resample_frequency = 1

In [None]:
# | export


class DatasetSequential(Dataset):
    """Time-indexed parquet dataset opened lazily with dask.

    The parquet data at `path` is opened as a dask dataframe. Optional start
    and end times are parsed and checked against the first and last index
    values, and the frame is optionally resampled, keeping the first row of
    each resampling bucket.

    Args:
        path: Location of the parquet data; must exist on disk.
        time_start: Optional lower time bound, either a `datetime` or a
            string in `time_format`.
        time_end: Optional upper time bound, same conventions as
            `time_start`.
        time_format: `strptime` format used to parse string time bounds.
        resample_interval: Interval name passed to `get_resample_code`;
            `None` disables resampling.
        resample_frequency: Frequency passed to `get_resample_code` together
            with `resample_interval`.

    Raises:
        FileNotFoundError: If `path` does not exist.
        ValueError: If the index is not a datetime index, if `time_start` is
            before the first index value, or if `time_end` is after the last
            index value.
    """

    def __init__(
        self,
        path: str,
        time_start: Union[str, datetime, None] = None,
        time_end: Union[str, datetime, None] = None,
        time_format: str = "%y_%m_%d-%H_%M_%S",
        resample_interval: Optional[str] = None,
        resample_frequency: int = 1,
    ) -> None:
        self.path = path
        self.time_start = time_start
        self.time_end = time_end
        self.time_format = time_format
        self.resample_interval = resample_interval
        self.resample_frequency = resample_frequency

        if not os.path.exists(os.path.abspath(self.path)):
            raise FileNotFoundError(f"File path {self.path} does not exist")

        self.df = dd.read_parquet(self.path, calculate_divisions=True)

        # Reject anything that is not a dask index with a datetime dtype.
        # The previous form, `not isinstance(...) and is_datetime(...)`, never
        # fired for a dask index with a non-datetime dtype.
        if not (
            isinstance(self.df.index, dd.Index)
            and is_datetime64_any_dtype(self.df.index.dtype)
        ):
            raise ValueError("Index must be a datetime index")

        # First and last index values; used only to validate the time bounds.
        self.idx_start = self.df.head(1).index[0]
        self.idx_end = self.df.tail(1).index[0]

        # NOTE: the bounds are parsed and validated here, but the dataframe is
        # not sliced to them.
        if self.time_start is not None:
            if isinstance(self.time_start, str):
                self.time_start = datetime.strptime(self.time_start, self.time_format)
            if self.time_start < self.idx_start:
                raise ValueError("Start time is before starting index")

        if self.time_end is not None:
            if isinstance(self.time_end, str):
                self.time_end = datetime.strptime(self.time_end, self.time_format)
            if self.time_end > self.idx_end:
                raise ValueError("End time is after ending index")

        # TODO: add sorting of the index here.

        if self.resample_interval is not None:
            resample_code = get_resample_code(
                self.resample_interval, self.resample_frequency
            )
            # Keep the first observation of every resampling bucket.
            self.df = self.df.resample(resample_code).first()

        # Materialised index; filled on demand by `_get_index`.
        self.index = None
        self.length = None

In [None]:
# Parameters for the base dataset config. `path` is set to omegaconf's
# MISSING marker, which appears as `???` in the config shown below.
params = {
    "path": MISSING,
    "time_start": time_start,
    "time_end": time_end,
    "time_format": time_format,
    "resample_interval": resample_interval,
    "resample_frequency": resample_frequency,
}

# NOTE(review): `hydra_nb` comes from the rlmm package; presumably it emits
# the config for `obj` to the given yaml path - confirm against its definition.
hydra_nb(obj=DatasetSequential, path="../conf/datasets/base.yaml", params=params)

time_start: 21_03_05-08_30_00
time_end: 21_03_05-10_30_00
time_format: '%y_%m_%d-%H_%M_%S'
resample_interval: minute
resample_frequency: 1
path: ???



In [None]:
# | export


@patch
def _get_index(self: DatasetSequential) -> pd.Index:
    """Return the computed index of the dask frame, computing it only once.

    The first call runs `self.df.index.compute()` and stores the result on
    `self.index`; later calls return the stored value.
    """
    cached = self.index
    if cached is None:
        cached = self.df.index.compute()
        self.index = cached
    return cached

In [None]:
# | export


@patch
def __len__(self: DatasetSequential) -> int:
    """Return the number of entries in the computed index."""
    materialised = self._get_index()
    return len(materialised)

In [None]:
# | export


@patch
def __getitem__(
    self: DatasetSequential, idx: Union[int, datetime, slice]
) -> pd.DataFrame:
    """Select rows by position or by time and return them as a pandas frame.

    Args:
        idx: One of
            - an integer position into the computed index,
            - a `datetime` label,
            - a slice without step whose endpoints are both integers or both
              datetimes. Either endpoint may be omitted; an integer `stop`
              is exclusive and may equal `len(self)`.

    Returns:
        The selected rows, computed from the dask dataframe.

    Raises:
        ValueError: If the slice has a step, or integer endpoints are out of
            range.
        TypeError: If `idx` or the slice endpoints have an unsupported type.
    """
    int_types = (int, np.integer)
    dt_types = (datetime, np.datetime64)

    if isinstance(idx, int_types):
        # Translate the position into its index label before the dask lookup.
        return self.df.loc[self._get_index()[idx]].compute()
    if isinstance(idx, datetime):
        return self.df.loc[idx].compute()
    if not isinstance(idx, slice):
        raise TypeError("idx must be int, datetime, slice")

    if idx.step is not None:
        raise ValueError("slice step is not supported")

    start, stop = idx.start, idx.stop
    if start is None and stop is None:
        return self.df.compute()

    # Classify the slice by whichever endpoints were actually given, so that
    # half-open slices such as ds[:5] or ds[t0:] are accepted.
    given = [bound for bound in (start, stop) if bound is not None]

    if all(isinstance(bound, int_types) for bound in given):
        length = len(self)
        lo = 0 if start is None else start
        hi = length if stop is None else stop
        # `stop` is exclusive, so `hi == length` is valid and is the only
        # way to include the last row.
        if lo < 0 or lo >= length or hi < 0 or hi > length:
            raise ValueError("Start and stop out of range")
        return self.df.loc[self._get_index()[lo:hi]].compute()

    if all(isinstance(bound, dt_types) for bound in given):
        return self.df.loc[start:stop].compute()

    raise TypeError("start and stop of slice must be int or datetime")

In [None]:
# | export


class DatasetBook(DatasetSequential):
    """Order-book dataset restricted to the top price and volume levels.

    After the base class has loaded the data, the ask/bid price and volume
    columns are chosen by regular expression, truncated to `price_levels`
    entries each, and the dataframe is reduced to those columns in the order
    ask prices, ask volumes, bid prices, bid volumes.

    Args:
        path: Location of the parquet data.
        time_start: Optional lower time bound (`datetime` or string in
            `time_format`).
        time_end: Optional upper time bound (`datetime` or string in
            `time_format`).
        time_format: `strptime` format used for string time bounds.
        resample_interval: Interval name passed on to the base class.
        resample_frequency: Frequency passed on to the base class.
        price_levels: Number of columns kept per side and per kind.
        col_prices_ask_re: Regex selecting the ask price columns.
        col_prices_bid_re: Regex selecting the bid price columns.
        col_volumes_ask_re: Regex selecting the ask volume columns.
        col_volumes_bid_re: Regex selecting the bid volume columns.
        col_sort: Regex passed to `list_regex` as `regex_sort`.
    """

    def __init__(
        self,
        path: str,
        time_start: Union[str, datetime, None] = None,
        time_end: Union[str, datetime, None] = None,
        time_format: str = "%y_%m_%d-%H_%M_%S",
        resample_interval: Optional[str] = "minute",
        resample_frequency: int = 1,
        price_levels: int = 10,
        col_prices_ask_re: str = "(?=.*price)(?=.*ask)",
        col_prices_bid_re: str = "(?=.*price)(?=.*bid)",
        col_volumes_ask_re: str = "(?=.*amount)(?=.*ask)",
        col_volumes_bid_re: str = "(?=.*amount)(?=.*bid)",
        # Raw string: "\d" in a normal literal is an invalid escape sequence.
        col_sort: str = r"\d+",
    ):
        super().__init__(
            path=path,
            time_start=time_start,
            time_end=time_end,
            resample_interval=resample_interval,
            resample_frequency=resample_frequency,
            time_format=time_format,
        )

        self.price_levels = price_levels
        self.col_prices_ask_re = col_prices_ask_re
        self.col_prices_bid_re = col_prices_bid_re
        self.col_volumes_ask_re = col_volumes_ask_re
        self.col_volumes_bid_re = col_volumes_bid_re
        self.col_sort = col_sort

        # NOTE(review): `list_regex` comes from rlmm.utils; presumably it
        # returns the columns matching `regex_match`, ordered by the
        # `regex_sort` match - confirm against its definition.
        cols_partial = partial(
            list_regex, lst=self.df.columns.to_list(), regex_sort=self.col_sort
        )

        # Keep only the first `price_levels` columns of each group.
        self.cols_prices_ask = cols_partial(regex_match=self.col_prices_ask_re)[
            0 : self.price_levels
        ]
        self.cols_prices_bid = cols_partial(regex_match=self.col_prices_bid_re)[
            0 : self.price_levels
        ]
        self.cols_volumes_ask = cols_partial(regex_match=self.col_volumes_ask_re)[
            0 : self.price_levels
        ]
        self.cols_volumes_bid = cols_partial(regex_match=self.col_volumes_bid_re)[
            0 : self.price_levels
        ]
        self.df = self.df[
            (
                self.cols_prices_ask
                + self.cols_volumes_ask
                + self.cols_prices_bid
                + self.cols_volumes_bid
            )
        ]

In [None]:
# Book config: builds on the "base" config and points at the order-book data.
defaults = ["base"]
params = {"_target_": "rlmm.datasets.DatasetBook", "path": path_book}
# NOTE(review): judging by the config shown below, `new_only` limits the output
# to the parameters DatasetBook adds on top of the base - confirm in `hydra_nb`.
new_only = True

hydra_nb(
    obj=DatasetBook,
    path="../conf/datasets/book.yaml",
    defaults=defaults,
    params=params,
    new_only=new_only,
)

defaults:
- base
price_levels: 10
col_prices_ask_re: (?=.*price)(?=.*ask)
col_prices_bid_re: (?=.*price)(?=.*bid)
col_volumes_ask_re: (?=.*amount)(?=.*ask)
col_volumes_bid_re: (?=.*amount)(?=.*bid)
col_sort: \d+
path: ../data/test/parquet/book_snapshot_25/ETHUSDT
_target_: rlmm.datasets.DatasetBook



In [None]:
# Build the order-book dataset from the shared example parameters.
dsb = DatasetBook(
    path=path_book,
    time_start=time_start,
    time_end=time_end,
    resample_interval=resample_interval,
    resample_frequency=resample_frequency,
    time_format=time_format,
)

In [None]:
# Integer slice: positions 20 up to (not including) 25, i.e. five rows.
dsb_int = dsb[20:25]
test_eq(len(dsb_int), 5)

dsb_int

Unnamed: 0_level_0,asks[0].price,asks[1].price,asks[2].price,asks[3].price,asks[4].price,asks[5].price,asks[6].price,asks[7].price,asks[8].price,asks[9].price,...,bids[0].amount,bids[1].amount,bids[2].amount,bids[3].amount,bids[4].amount,bids[5].amount,bids[6].amount,bids[7].amount,bids[8].amount,bids[9].amount
ts,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021-03-05 00:20:00,1496.5,1496.7,1496.85,1496.87,1496.89,1496.9,1496.99,1497.01,1497.03,1497.19,...,1.6871,1.33584,0.7644,0.4,0.06661,1.3359,0.5628,1.35641,0.375,1.5495
2021-03-05 00:21:00,1498.36,1498.41,1498.61,1498.75,1498.82,1498.92,1499.0,1499.01,1499.02,1499.03,...,11.22826,15.0,8.05177,0.66743,9.60995,11.74502,1.33452,3.15002,15.0,2.03475
2021-03-05 00:22:00,1490.01,1490.04,1490.16,1490.29,1490.3,1490.33,1490.34,1490.39,1490.4,1490.42,...,248.37412,0.45226,1.1546,0.01008,0.00854,0.3355,0.02733,0.06713,0.60066,0.46729
2021-03-05 00:23:00,1491.64,1491.65,1491.71,1491.72,1491.92,1492.03,1492.05,1492.13,1492.16,1492.35,...,3.80541,15.0,0.00741,2.01227,3.15001,15.0,1.0,1.92038,1.0,0.37875
2021-03-05 00:24:00,1489.64,1489.67,1489.68,1489.71,1489.73,1489.74,1489.8,1489.95,1489.96,1490.06,...,6.0,0.46661,0.25,2.36004,0.23313,11.00518,3.05,1.68001,8.0,0.20823


In [None]:
# Datetime slice: parse the string bounds, then select rows by time.
dt_start = datetime.strptime(time_start, time_format)
dt_stop = datetime.strptime(time_end, time_format)

dsb_dt = dsb[dt_start:dt_stop]
dsb_dt.head()

Unnamed: 0_level_0,asks[0].price,asks[1].price,asks[2].price,asks[3].price,asks[4].price,asks[5].price,asks[6].price,asks[7].price,asks[8].price,asks[9].price,...,bids[0].amount,bids[1].amount,bids[2].amount,bids[3].amount,bids[4].amount,bids[5].amount,bids[6].amount,bids[7].amount,bids[8].amount,bids[9].amount
ts,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2021-03-05 08:30:00,1448.56,1448.61,1448.68,1448.69,1448.7,1448.86,1448.87,1448.88,1448.93,1448.94,...,0.02,0.375,0.27707,1.55673,7.41301,0.03361,1.0,0.02876,0.0095,1.41361
2021-03-05 08:31:00,1454.66,1454.73,1454.74,1454.75,1454.8,1454.82,1454.84,1454.93,1455.05,1455.07,...,1.49972,1.76857,0.10235,0.17158,0.6,1.2,0.00699,1.4175,6.75,0.3675
2021-03-05 08:32:00,1452.88,1453.0,1453.02,1453.05,1453.18,1453.19,1453.22,1453.33,1453.34,1453.35,...,1.65698,0.8,0.0437,1.56558,0.0757,0.43346,2.064,2.0,1.9,2.0
2021-03-05 08:33:00,1453.66,1453.73,1453.79,1453.9,1453.94,1453.95,1453.97,1453.98,1454.0,1454.11,...,0.70976,0.07566,0.41792,1.99,2.0265,9.65,1.2,1.0,6.87883,2.06338
2021-03-05 08:34:00,1453.27,1453.28,1453.31,1453.33,1453.34,1453.38,1453.39,1453.4,1453.5,1453.51,...,1.0,6.44403,15.0,0.01394,0.49096,4.67878,1.455,5.28184,4.601,2.75283


In [None]:
# | export


class DatasetTrades(DatasetSequential):
    """Trades dataset reduced to the side, price and amount columns.

    Args:
        path: Location of the parquet data.
        time_start: Optional lower time bound (`datetime` or string in
            `time_format`).
        time_end: Optional upper time bound (`datetime` or string in
            `time_format`).
        resample_interval: Interval name passed on to the base class; `None`
            disables resampling.
        resample_frequency: Frequency passed on to the base class. `None`
            falls back to 1, the base class default.
        time_format: `strptime` format used for string time bounds.
        col_side: Name of the column holding the trade side.
        col_price: Name of the column holding the trade price.
        col_amount: Name of the column holding the trade amount.
    """

    def __init__(
        self,
        path: str,
        time_start: Union[str, datetime, None] = None,
        time_end: Union[str, datetime, None] = None,
        resample_interval: Optional[str] = None,
        resample_frequency: Optional[int] = None,
        time_format: str = "%y_%m_%d-%H_%M_%S",
        col_side: str = "side",
        col_price: str = "price",
        col_amount: str = "amount",
    ):
        super().__init__(
            path=path,
            time_start=time_start,
            time_end=time_end,
            resample_interval=resample_interval,
            # Avoid forwarding None to the resampling helper when only an
            # interval was given; use the base class default instead.
            resample_frequency=1 if resample_frequency is None else resample_frequency,
            time_format=time_format,
        )

        self.col_side = col_side
        self.col_price = col_price
        self.col_amount = col_amount

        # Drop every column except the three configured ones.
        self.df = self.df[[self.col_side, self.col_price, self.col_amount]]

In [None]:
# Trades config: builds on the "base" config and points at the trades data.
defaults = ["base"]
params = {"_target_": "rlmm.datasets.DatasetTrades", "path": path_trades}
new_only = True

hydra_nb(
    DatasetTrades,
    path="../conf/datasets/trades.yaml",
    defaults=defaults,
    params=params,
    new_only=new_only,
)

defaults:
- base
col_side: side
col_price: price
col_amount: amount
path: ../data/test/parquet/trades/ETHUSDT
_target_: rlmm.datasets.DatasetTrades



In [None]:
# Build the trades dataset from the shared example parameters.
dst = DatasetTrades(
    path=path_trades,
    time_start=time_start,
    time_end=time_end,
    time_format=time_format,
    resample_interval=resample_interval,
    resample_frequency=resample_frequency,
)

In [None]:
# Integer slice: positions 20 up to (not including) 25, i.e. five rows.
dst_int = dst[20:25]
test_eq(len(dst_int), 5)

dst_int

Unnamed: 0_level_0,side,price,amount
ts,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2021-03-05 00:20:00,buy,1496.5,0.66096
2021-03-05 00:21:00,sell,1498.35,0.00874
2021-03-05 00:22:00,sell,1490.0,3.0
2021-03-05 00:23:00,buy,1491.65,3.54992
2021-03-05 00:24:00,buy,1489.61,0.06931


In [None]:
# Datetime slice, using the bounds parsed in the order-book example above.
dst_dt = dst[dt_start:dt_stop]
dst_dt.head()

Unnamed: 0_level_0,side,price,amount
ts,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
2021-03-05 08:30:00,sell,1448.55,0.02
2021-03-05 08:31:00,buy,1454.66,15.0
2021-03-05 08:32:00,sell,1452.87,0.42348
2021-03-05 08:33:00,sell,1453.73,0.19419
2021-03-05 08:34:00,sell,1452.99,0.04816


In [None]:
# | hide
import nbdev

# Run nbdev's notebook export.
nbdev.nbdev_export()