In [2]:
import sys
# Put the directory two levels up on the import path — presumably the repository root
# that holds the `basicts` package imported in the next cell (TODO confirm).
sys.path.append('../..')

In [7]:
from basicts.utils import load_pkl
from torch.utils.data import Dataset
import torch
import os

In [57]:
class ForecastingDataset(Dataset):

    def __init__(self, data_file_path: str, index_file_path: str, mode: str, seq_len:int) -> None:
        """Init the dataset in the forecasting stage.

        Args:
            data_file_path (str): data file path.
            index_file_path (str): index file path.
            mode (str): train, valid, or test.
            seq_len (int): the length of long term historical data.
        """

        super().__init__()
        assert mode in ["train", "valid", "test"], "error mode"
        self._check_if_file_exists(data_file_path, index_file_path)
        # read raw data (normalized)
        data = load_pkl(data_file_path)
        processed_data = data["processed_data"]
        self.data = torch.from_numpy(processed_data).float()
        # read index
        self.index = load_pkl(index_file_path)[mode]
        # length of long term historical data
        self.seq_len = seq_len
        # mask
        self.mask = torch.zeros(self.seq_len, self.data.shape[1], self.data.shape[2])

    def _check_if_file_exists(self, data_file_path: str, index_file_path: str):
        """Check if data file and index file exist.

        Args:
            data_file_path (str): data file path
            index_file_path (str): index file path

        Raises:
            FileNotFoundError: no data file
            FileNotFoundError: no index file
        """

        if not os.path.isfile(data_file_path):
            raise FileNotFoundError("BasicTS can not find data file {0}".format(data_file_path))
        if not os.path.isfile(index_file_path):
            raise FileNotFoundError("BasicTS can not find index file {0}".format(index_file_path))

    def __getitem__(self, index: int|list) -> tuple:
        """Get a sample.

        Args:
            index (int): the iteration index (not the self.index)

        Returns:
            tuple: (future_data, history_data), where the shape of each is L x N x C.
        """
        if isinstance(index, int):
            idx = list(self.index[index])
        else:
            idx = index
        history_data = self.data[idx[0]:idx[1]]     # 12
        future_data = self.data[idx[1]:idx[2]]      # 12
        if idx[1] - self.seq_len < 0:
            long_history_data = self.mask
        else:
            long_history_data = self.data[idx[1] - self.seq_len:idx[1]]     # 11

        return future_data, history_data, long_history_data

    def __len__(self):
        """Dataset length

        Returns:
            int: dataset length
        """

        return len(self.index)


In [58]:
# Test-split dataset for METR-LA (12-step input / 12-step output files).
# seq_len = 288 * 7 — presumably 288 samples per day for 7 days (TODO confirm sampling rate).
dd = ForecastingDataset(
    data_file_path="/home/seyed/PycharmProjects/step/STEP/datasets/METR-LA/data_in12_out12.pkl",
    index_file_path="/home/seyed/PycharmProjects/step/STEP/datasets/METR-LA/index_in12_out12.pkl",
    mode="test",
    seq_len=288 * 7,
)

In [73]:
from collections import defaultdict

In [86]:
def _new_list_dict():
    """Factory for the inner level: a dict whose missing keys start as empty lists."""
    return defaultdict(list)


# Two-level mapping: a missing outer key yields a fresh inner defaultdict(list).
nested_dict = defaultdict(_new_list_dict)

In [88]:
# Neither key exists yet: both levels are created on first access, then 2 is appended.
nested_dict["rrr"]["10"].append(2)

In [89]:
# Display the list stored under the two keys used above.
nested_dict["rrr"]["10"]

[2]

In [63]:
import datetime
def find_indices(
    date: datetime.datetime,
    data_path: str = "/home/seyed/PycharmProjects/step/STEP/datasets/raw_data/METR-LA/METR-LA.h5",
    history_len: int = 12,
    future_len: int = 12,
) -> list:
    """Locate a timestamp in the raw data and return the window positions around it.

    Args:
        date (datetime.datetime): timestamp to look up; it must match an entry of the
            frame's index exactly.
        data_path (str): HDF file read with ``pandas.read_hdf``. Defaults to the
            METR-LA file this notebook was written against.
        history_len (int): number of steps placed before the located position.
        future_len (int): number of steps placed after the located position.

    Returns:
        list: ``[position - history_len, position, position + future_len]``.
            The values are not bounds-checked: the first entry is negative when the
            timestamp lies within the first ``history_len`` rows.

    Raises:
        KeyError: if ``date`` is not present in the index.
    """
    # Imported locally because the visible notebook cells never import pandas, so a
    # fresh kernel would otherwise fail here with NameError.
    import pandas as pd

    df = pd.read_hdf(data_path)
    timestamps = df.index.to_pydatetime().tolist()
    # Map each timestamp to its row position; if a timestamp repeats, the last
    # occurrence wins. The loop variable no longer shadows the `date` argument.
    position_of = {stamp: position for position, stamp in enumerate(timestamps)}
    idx = position_of[date]
    return [idx - history_len, idx, idx + future_len]

In [64]:
# Hand the [start, split, end] positions straight to __getitem__ (its list branch),
# bypassing the precomputed sample index, to fetch the window around 2012-04-01 00:35.
data = dd[find_indices(datetime.datetime(2012, 4, 1, 0, 35))]

In [65]:
# Shapes of (future_data, history_data, long_history_data); showing only shapes keeps the
# output short and matches the three torch.Size values recorded for this cell.
data[0].shape, data[1].shape, data[2].shape

(torch.Size([12, 207, 3]),
 torch.Size([12, 207, 3]),
 torch.Size([2016, 207, 3]))

In [66]:
# DataLoader = torch.utils.data.DataLoader(dd, batch_size=1, shuffle=False, num_workers=0)

In [68]:
# NOTE(review): `DataLoader` is presumably the loader instance built in the previous cell
# (that line is now commented out, so a fresh kernel would raise NameError here instead).
# A DataLoader cannot be subscripted, which is the TypeError recorded below; index the
# dataset `dd` directly to fetch a specific window.
DataLoader[find_indices(datetime.datetime(2012, 4, 1, 0, 35))]

TypeError: 'DataLoader' object is not subscriptable

In [72]:
# Show the first element of the returned tuple, i.e. the future (target) window.
data[0]

tensor([[[0.7030, 0.0243, 6.0000],
         [0.7372, 0.0243, 6.0000],
         [0.7429, 0.0243, 6.0000],
         ...,
         [0.6346, 0.0243, 6.0000],
         [0.7429, 0.0243, 6.0000],
         [0.4352, 0.0243, 6.0000]],

        [[0.7230, 0.0278, 6.0000],
         [0.6012, 0.0278, 6.0000],
         [0.6460, 0.0278, 6.0000],
         ...,
         [0.6525, 0.0278, 6.0000],
         [0.2741, 0.0278, 6.0000],
         [0.4280, 0.0278, 6.0000]],

        [[0.7600, 0.0312, 6.0000],
         [0.6859, 0.0312, 6.0000],
         [0.7258, 0.0312, 6.0000],
         ...,
         [0.6118, 0.0312, 6.0000],
         [0.7201, 0.0312, 6.0000],
         [0.4124, 0.0312, 6.0000]],

        ...,

        [[0.6175, 0.0556, 6.0000],
         [0.6517, 0.0556, 6.0000],
         [0.6745, 0.0556, 6.0000],
         ...,
         [0.5321, 0.0556, 6.0000],
         [0.7828, 0.0556, 6.0000],
         [0.4523, 0.0556, 6.0000]],

        [[0.6460, 0.0590, 6.0000],
         [0.6589, 0.0590, 6.0000],
         [0.