In [1]:
import makassar_ml as ml
import pathlib
import pytorch_lightning as pl
import torch
from typing import Optional

In [2]:
class BeijingPM25LightningDataModule(pl.LightningDataModule):
    def __init__(self, 
        root: str, 
        feature_cols: list[int], 
        target_cols: list[int], 
        history: int, 
        horizon: int, 
        split: float,
        batch_size: int,
        ):
        self.root = root
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        self.history = history
        self.horizon = horizon
        self.split = split
        self.batch_size = batch_size

    def prepare_data(self):
        # Download the dataset.
        ml.datasets.BeijingPM25Dataset(
            root=self.root,
            download=True,
            )

    def setup(self, stage: Optional[str] = None):

        # Create train/val datasets for dataloaders.
        if stage == 'fit' or stage is None:
            dataset_train_full = ml.datasets.BeijingPM25Dataset(
                root=self.root,
                download=False,
                train=True,
                split=self.split,
                )
            train_n = len(dataset_train_full)
            train_val_cutoff = train_n - round(train_n*.25) # 75% train, 25% val

            self.dataset_train = torch.utils.data.Subset(dataset_train_full, list(range(0, train_val_cutoff)))
            self.dataset_val = torch.utils.data.Subset(dataset_train_full, list(range(train_val_cutoff, train_n)))

            self.dataset_train_wrap = ml.datasets.TimeseriesForecastDatasetWrapper(
                dataset=self.dataset_train,
                feature_cols=self.feature_cols,
                target_cols=self.target_cols,
                history=self.history,
                horizon=self.horizon,
                )
            self.dataset_val_wrap = ml.datasets.TimeseriesForecastDatasetWrapper(
                dataset=self.dataset_val,
                feature_cols=self.feature_cols,
                target_cols=self.target_cols,
                history=self.history,
                horizon=self.horizon,
                )

        # Create test dataset for dataloaders.
        if stage == 'test' or stage is None:
            self.dataset_test = ml.datasets.BeijingPM25Dataset(
                root=self.root,
                download=False,
                train=False,
                split=self.split,
                )
            self.dataset_test_wrap = ml.datasets.TimeseriesForecastDatasetWrapper(
                dataset=self.dataset_test,
                feature_cols=self.feature_cols,
                target_cols=self.target_cols,
                history=self.history,
                horizon=self.horizon,
                )

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=self.dataset_train_wrap,
            batch_size=self.batch_size,
            )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=self.dataset_val_wrap,
            batch_size=self.batch_size,
            )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            dataset=self.dataset_test_wrap,
            batch_size=self.batch_size,
            )

In [3]:
def create_attn_mask(length: int, device: str = None):
    """Generate mask used for attention mechanisms.

    Masks are a lower-triangular matrix of zeros
    with the other entries taking value "-inf".

    Args:
        length (int): Length of square-matrix dimension.
        device (str, optional): PyTorch device.

    Examples:

        >>> create_mask(3)
        tensor([[0., -inf, -inf],
                [0., 0., -inf],
                [0., 0., 0.]])
    """
    # Get lower-triangular matrix of ones.
    mask = torch.tril(torch.ones(length, length, device=device))

    # Replace 0 -> "-inf" and 1 -> 0.0
    mask = (
        mask
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask

In [4]:
class TimeseriesTransformer(torch.nn.Module):

    def __init__(self,
        n_input_features: int,
        n_output_features: int,
        d_model: int = 512,
        dropout: float = 0.1,
        batch_first: bool = False,
        ):
        super().__init__()

        self.batch_first = batch_first

        # Linear transformation from input-feature space into arbitrary n-dimension space.
        # This is similar to a word embedding used in NLP tasks.
        self.encoder_projection = torch.nn.Linear(in_features=n_input_features, out_features=d_model)
        self.decoder_projection = torch.nn.Linear(in_features=n_output_features, out_features=d_model)

        # Transformer encoder/decoder layers.
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8, # Number of multihead-attention models.
            dropout=dropout,
            dim_feedforward=4*d_model,
            batch_first=batch_first,
        )
        decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=8, # Number of multihead-attention models.
            dropout=dropout,
            dim_feedforward=4*d_model,
            batch_first=batch_first,
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=8)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer=decoder_layer, num_layers=8)

        # Linear output layer.
        # We typically only predict a single data point at a time, so output features is typically 1.
        self.linear = torch.nn.Linear(in_features=d_model, out_features=n_output_features)

    def encode(self, src):
        # Transform source into arbitrary feature space.
        x = self.encoder_projection(src)

        # Create source mask.
        if self.batch_first:
            src_length, batch_size = src.size(1), src.size(0)
        else:
            src_length, batch_size = src.size(0), src.size(1)
        src_mask = create_attn_mask(length=src_length, device=src.device)

        # Pass the linear transformation through the encoder layers.
        x = self.encoder(x, mask=src_mask)
        # x = self.encoder(x)

        return x

    def decode(self, tgt, memory):
        # Transform target into arbitrary feature space.
        x = self.decoder_projection(tgt)

        # Create target attention mask.
        if self.batch_first:
            tgt_length, batch_size = tgt.size(1), tgt.size(0)
        else:
            tgt_length, batch_size = tgt.size(0), tgt.size(1)

        tgt_mask = create_attn_mask(length=tgt_length, device=tgt.device)

        # Pass the linear transformation through the decoder layers.
        x = self.decoder(tgt=x, memory=memory, tgt_mask=tgt_mask)

        # Pass the output of the decoder through the linear prediction layer.
        x = self.linear(x)

        return x

    def forward(self, x):
        src, tgt = x
        y = self.encode(src)
        y = self.decode(tgt=tgt, memory=y)
        return y

    def step(self, batch):
        src, tgt_int, tgt_out = batch

In [5]:
class BeijingPM25ForecastTransformer(pl.LightningModule):
    def __init__(self,
        n_input_features: int,
        n_output_features: int,
        d_model: int = 512,
        dropout: float = 0.1,
        ):
        super().__init__()

        self.criterion = torch.nn.MSELoss(reduction='mean')

        # Create the transformer model.
        self.model = TimeseriesTransformer(
            n_input_features=n_input_features,
            n_output_features=n_output_features,
            d_model=d_model,
            dropout=dropout,
            batch_first=True,
            )

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def compute_loss(self, y_hat, y):
        return self.criterion(y_hat, y)

    def training_step(self, batch, batch_idx):
        history_x, history_y, horizon_x, horizon_y = batch
        # y_hat = self((history_x, history_y,))
        y_hat = self((history_x, horizon_y,))
        loss = self.compute_loss(y_hat, horizon_y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

In [6]:
# Define parameters for the dataset.
root = pathlib.Path('../datasets/')
feature_cols = [0,1,2,3]
target_cols = [-3]
history = 5
horizon = 3
split = 0.15
batch_size = 32

# Create the dataset.
dm = BeijingPM25LightningDataModule(
    root=root,
    feature_cols=feature_cols,
    target_cols=target_cols,
    history=history,
    horizon=horizon,
    split=split,
    batch_size=batch_size,
)

n_input_features: int = len(feature_cols)
n_output_features: int = len(target_cols)
d_model: int = 512
dropout: float = 0.1
model = BeijingPM25ForecastTransformer(
    n_input_features=n_input_features,
    n_output_features=n_output_features,
    d_model=d_model,
    dropout=dropout,
    )

trainer = pl.Trainer()
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name      | Type                  | Params
----------------------------------------------------
0 | criterion | MSELoss               | 0     
1 | model     | TimeseriesTransformer | 58.9 M
----------------------------------------------------
58.9 M    Trainable params
0         Non-trainable params
58.9 M    Total params
235.422   Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   4%|▍         | 34/873 [00:36<15:07,  1.08s/it, loss=2.28e+03, v_num=31, train_loss_step=141.0]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# dm.prepare_data()
# dm.setup()
# train = dm.train_dataloader()
# val = dm.val_dataloader()
# test = dm.test_dataloader()

# # Print counts for each split.
# print('train',len(train))
# print('val',len(val))
# print('test',len(test))
# print('total',len(train)+len(val)+len(test))

# # Visually inspect the split boundaries to ensure that no values are missing.
# print('train[0]:',dm.dataset_train[0][0:4])
# print('train[-1]:',dm.dataset_train[-1][0:4])
# print('val[0]:',dm.dataset_val[0][0:4])
# print('val[-1]:',dm.dataset_val[-1][0:4])
# print('test[0]:',dm.dataset_test[0][0:4])
# print('test[-1]:',dm.dataset_test[-1][0:4])

train 27931
val 9305
test 6567
total 43803
train[0]: tensor([2.0100e+03, 1.0000e+00, 1.0000e+00, 0.0000e+00], dtype=torch.float64)
train[-1]: tensor([2.0130e+03, 3.0000e+00, 1.0000e+01, 1.0000e+00], dtype=torch.float64)
val[0]: tensor([2.0130e+03, 3.0000e+00, 1.0000e+01, 2.0000e+00], dtype=torch.float64)
val[-1]: tensor([2.0140e+03, 4.0000e+00, 2.0000e+00, 1.0000e+00], dtype=torch.float64)
test[0]: tensor([2.0140e+03, 4.0000e+00, 2.0000e+00, 2.0000e+00], dtype=torch.float64)
test[-1]: tensor([2014.,   12.,   31.,   23.], dtype=torch.float64)
