In [41]:
import makassar_ml as ml
import pathlib
import pytorch_lightning as pl
import torch
from typing import Optional

In [42]:
class BeijingPM25LightningModule(pl.LightningDataModule):
    """Lightning data module for windowed forecasting on the Beijing PM2.5 dataset.

    The module downloads the dataset in `prepare_data`, carves the training
    portion into contiguous train/validation subsets in `setup`, and wraps each
    subset in `ml.datasets.TimeseriesForecastDatasetWrapper` so the dataloaders
    yield history/horizon windows.

    Args:
        root: Dataset root directory; passed unchanged to
            `ml.datasets.BeijingPM25Dataset`.
        feature_cols: Column indices forwarded to the forecast wrapper as features.
        target_cols: Column indices forwarded to the forecast wrapper as targets.
        history: `history` argument forwarded to the forecast wrapper.
        horizon: `horizon` argument forwarded to the forecast wrapper.
        split: `split` argument forwarded to `ml.datasets.BeijingPM25Dataset`
            (presumably the train/test fraction -- confirm against that class).
        batch_size: Batch size used by every dataloader.
        val_split: Fraction of the training portion held out (from its tail)
            for validation. Defaults to 0.25, i.e. 75% train / 25% val.

    Raises:
        ValueError: If `val_split` is outside the closed interval [0, 1].
    """

    def __init__(self, 
        root: str, 
        feature_cols: list[int], 
        target_cols: list[int], 
        history: int, 
        horizon: int, 
        split: float,
        batch_size: int,
        val_split: float = 0.25,
        ):
        # The base class must be initialised, otherwise the attributes that
        # Lightning's Trainer expects on a data module are never created.
        super().__init__()

        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be within [0, 1], got {val_split!r}")

        self.root = root
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        self.history = history
        self.horizon = horizon
        self.split = split
        self.batch_size = batch_size
        self.val_split = val_split

    def _wrap(self, dataset):
        """Wrap `dataset` in the forecast window wrapper using this module's settings."""
        return ml.datasets.TimeseriesForecastDatasetWrapper(
            dataset=dataset,
            feature_cols=self.feature_cols,
            target_cols=self.target_cols,
            history=self.history,
            horizon=self.horizon,
            )

    def prepare_data(self):
        """Download the dataset into `self.root`; no state is kept on the instance."""
        ml.datasets.BeijingPM25Dataset(
            root=self.root,
            download=True,
            )

    def setup(self, stage: Optional[str] = None):
        """Build the datasets required for `stage`.

        Args:
            stage: `'fit'` or `'validate'` builds the train/val datasets,
                `'test'` builds the test dataset, and `None` builds all of them.
        """

        # Create train/val datasets for dataloaders.
        # 'validate' is included because Lightning calls setup('validate') before
        # val_dataloader() when validation is run on its own.
        if stage in ('fit', 'validate') or stage is None:
            dataset_train_full = ml.datasets.BeijingPM25Dataset(
                root=self.root,
                download=False,
                train=True,
                split=self.split,
                )
            train_n = len(dataset_train_full)
            # The split is positional: leading indices train, trailing indices validate.
            train_val_cutoff = train_n - round(train_n * self.val_split)

            self.dataset_train = torch.utils.data.Subset(dataset_train_full, range(0, train_val_cutoff))
            self.dataset_val = torch.utils.data.Subset(dataset_train_full, range(train_val_cutoff, train_n))

            # Each subset is wrapped separately, so no window spans the cutoff.
            self.dataset_train_wrap = self._wrap(self.dataset_train)
            self.dataset_val_wrap = self._wrap(self.dataset_val)

        # Create test dataset for dataloaders.
        if stage == 'test' or stage is None:
            self.dataset_test = ml.datasets.BeijingPM25Dataset(
                root=self.root,
                download=False,
                train=False,
                split=self.split,
                )
            self.dataset_test_wrap = self._wrap(self.dataset_test)

    def train_dataloader(self):
        """Return a dataloader over the wrapped training subset (requires `setup`)."""
        return torch.utils.data.DataLoader(
            dataset=self.dataset_train_wrap,
            batch_size=self.batch_size,
            )

    def val_dataloader(self):
        """Return a dataloader over the wrapped validation subset (requires `setup`)."""
        return torch.utils.data.DataLoader(
            dataset=self.dataset_val_wrap,
            batch_size=self.batch_size,
            )

    def test_dataloader(self):
        """Return a dataloader over the wrapped test dataset (requires `setup`)."""
        return torch.utils.data.DataLoader(
            dataset=self.dataset_test_wrap,
            batch_size=self.batch_size,
            )

In [43]:
# Configuration for the data module.
root = pathlib.Path('../datasets/')
feature_cols = [0, 1, 2, 3]
# NOTE(review): presumably a negative index counts back from the last column --
# confirm that -3 selects the intended target.
target_cols = [-3]
history, horizon = 5, 3
split = 0.15
batch_size = 1

# Gather the settings once, then build the data module from them.
dm_config = dict(
    root=root,
    feature_cols=feature_cols,
    target_cols=target_cols,
    history=history,
    horizon=horizon,
    split=split,
    batch_size=batch_size,
)
dm = BeijingPM25LightningModule(**dm_config)

In [44]:
# Download (if needed), build every split, and grab one dataloader per split.
dm.prepare_data()
dm.setup()
train, val, test = dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()

# Report the number of batches in each split and overall.
for label, loader in (('train', train), ('val', val), ('test', test)):
    print(label, len(loader))
print('total', len(train) + len(val) + len(test))

# Show the first four columns of the first and last row of each split so the
# boundaries can be checked by eye for gaps between consecutive splits.
boundary_checks = [
    ('train[0]:', dm.dataset_train, 0),
    ('train[-1]:', dm.dataset_train, -1),
    ('val[0]:', dm.dataset_val, 0),
    ('val[-1]:', dm.dataset_val, -1),
    ('test[0]:', dm.dataset_test, 0),
    ('test[-1]:', dm.dataset_test, -1),
]
for label, subset, idx in boundary_checks:
    print(label, subset[idx][0:4])

train 27931
val 9305
test 6567
total 43803
train[0]: tensor([2.0100e+03, 1.0000e+00, 1.0000e+00, 0.0000e+00], dtype=torch.float64)
train[-1]: tensor([2.0130e+03, 3.0000e+00, 1.0000e+01, 1.0000e+00], dtype=torch.float64)
val[0]: tensor([2.0130e+03, 3.0000e+00, 1.0000e+01, 2.0000e+00], dtype=torch.float64)
val[-1]: tensor([2.0140e+03, 4.0000e+00, 2.0000e+00, 1.0000e+00], dtype=torch.float64)
test[0]: tensor([2.0140e+03, 4.0000e+00, 2.0000e+00, 2.0000e+00], dtype=torch.float64)
test[-1]: tensor([2014.,   12.,   31.,   23.], dtype=torch.float64)
test[-1]: tensor([2014.0000,   12.0000,   31.0000,   23.0000,   12.0000,  -21.0000,
          -3.0000, 1034.0000,  249.8500,    0.0000,    0.0000],
       dtype=torch.float64)


In [45]:
# class BeijingPM25ForecastTransformer(pl.LightningModule):
#     def __init__(self, *args: Any, **kwargs: Any) -> None:
#         super().__init__(*args, **kwargs)