In [2]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

In [1]:
# Shell escape: print the current working directory of the notebook kernel.
!pwd

/p/projects/ncps/slaf-project/data


In [18]:
class ESTDataset(Dataset):
    """EST hourly dataset.

    Exposes one station's series as sliding windows: item ``i`` is the pair
    ``(sequence, target)`` where ``sequence`` holds ``sequence_length``
    consecutive rows starting at row ``i`` and ``target`` is the row that
    immediately follows the window.
    """
    STATIONS = [
        "AEP",
        "COMED",
        "DAYTON",
        "DEOK",
        "DOM",
        "DUQ",
        "EKPC",
        "FE",
        "NI",
        "PJME",
        "PJMW",
        "PJM_Load"
    ]

    def __init__(self, dataset_file, sequence_length=10, station="AEP"):
        """
        Arguments:
            dataset_file (string): Path to the parquet file; it must contain
                a column named after the selected station.
            sequence_length (int): Number of consecutive rows in each input
                window.
            station (string): Station column to load. Matched against
                ``STATIONS`` ignoring case.

        Raises:
            ValueError: If ``station`` does not match any entry of
                ``STATIONS``.
        """
        self.sequence_length = sequence_length
        # Use the canonical spelling for the column lookup so that e.g. "aep"
        # or "pjm_load" selects the real column instead of raising KeyError.
        self.station = self._resolve_station(station)

        frame = pd.read_parquet(dataset_file)[[self.station]]
        # Rows with a missing reading are dropped before windowing; the result
        # is a 2-D numpy array with a single column.
        self.data = frame.dropna().to_numpy()

    @classmethod
    def _resolve_station(cls, station):
        """Return the entry of ``STATIONS`` that matches ``station`` ignoring case."""
        if station in cls.STATIONS:
            return station
        wanted = station.upper()
        for name in cls.STATIONS:
            # Compare upper-cased on both sides: "PJM_Load" is mixed-case, so
            # upper-casing only the query would never match it.
            if name.upper() == wanted:
                return name
        raise ValueError(f"{station} is not found in the dataset.")

    def __len__(self):
        """Number of complete (sequence, target) pairs; never negative."""
        # Every window needs one extra row after it to serve as the target.
        # Clamp at zero so len() stays valid when there are too few rows.
        return max(0, len(self.data) - self.sequence_length)

    def __getitem__(self, idx):
        """Return the ``(sequence, target)`` float32 tensors for window ``idx``.

        Negative indices count from the end, as with Python sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            raise IndexError(
                f"index out of range for dataset with {size} samples"
            )
        stop = idx + self.sequence_length
        sequence = torch.tensor(self.data[idx:stop], dtype=torch.float32)
        target = torch.tensor(self.data[stop], dtype=torch.float32)
        return sequence, target


tensor([[[16862.],
         [13491.],
         [12746.],
         [12459.],
         [12250.],
         [12490.],
         [13139.],
         [14329.],
         [15515.],
         [16288.]],

        [[21816.],
         [19926.],
         [15425.],
         [14570.],
         [13977.],
         [13591.],
         [13349.],
         [13304.],
         [13221.],
         [13501.]],

        [[15339.],
         [15219.],
         [14620.],
         [13700.],
         [12821.],
         [13383.],
         [12795.],
         [12400.],
         [12113.],
         [12308.]],

        [[13812.],
         [13662.],
         [13517.],
         [13465.],
         [13561.],
         [13623.],
         [13806.],
         [13981.],
         [14208.],
         [14304.]],

        [[15882.],
         [15947.],
         [15825.],
         [15834.],
         [15687.],
         [15552.],
         [15458.],
         [15808.],
         [15546.],
         [14488.]],

        [[12503.],
         [12958.],
  