# Transformer encoder

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
from os import path
import random

import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.optim import AdamW

from accelerate import Accelerator

import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch import loggers as pl_loggers

# from gluonts.dataset.pandas import PandasDataset
# from gluonts.dataset.split import split, InputDataset, LabelDataset
# from gluonts.time_feature import (
#     time_features_from_frequency_str,
#     TimeFeature,
#     get_lags_for_frequency,
# )
# from gluonts.dataset.field_names import FieldName
# from gluonts.transform import (
#     AddAgeFeature,
#     AddObservedValuesIndicator,
#     AddTimeFeatures,
#     AsNumpyArray,
#     Chain,
#     ExpectedNumInstanceSampler,
#     InstanceSplitter,
#     RemoveFields,
#     SelectFields,
#     SetField,
#     TestSplitSampler,
#     Transformation,
#     ValidationSplitSampler,
#     VstackFeatures,
#     RenameFields,
# )
# from gluonts.transform.sampler import InstanceSampler
# from gluonts.itertools import Cyclic, Cached
# from gluonts.dataset.loader import as_stacked_batches

from tqdm import tqdm

from typing import Optional, Iterable, Sized, Iterator

  torch.utils._pytree._register_pytree_node(


In [3]:
from data_preprocessor.data_preprocessor import CompositeDataPreprocessor, ReduceMemUsageDataPreprocessor, FillNaPreProcessor
from data_preprocessor.feature_engineering import (
    BasicFeaturesPreprocessor,
    DupletsTripletsPreprocessor,
    MovingAvgPreProcessor,
    RemoveIrrelevantFeaturesDataPreprocessor,
    DropTargetNADataPreprocessor,
    FarNearPriceFillNaPreprocessor,
    MovingAvgFillNaPreprocessor,
    RemoveRecordsByStockDateIdPreprocessor,
)
from data_preprocessor.polynomial_features import PolynomialFeaturesPreProcessor

In [4]:
# https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
seed_everything(42, workers=True)

[rank: 0] Seed set to 42


42

In [5]:
# TODO: remove logger for better performance (string construction could be slow too)

class ModelLogger:
    def log(self, msg):
        pass

    def reset(self):
        pass

class BasicModelLogger(ModelLogger):
    def __init__(self, msg_prefix):
        self.msg_prefix = msg_prefix
        # self.log_idx = 0

    def log(self, msg):
        print(f"{self.msg_prefix} - {msg}")
        # self.log_idx += 1

    def reset(self):
        # self.log_idx = 0
        pass

class NoopModelLogger(ModelLogger):
    def log(self, msg):
        pass

    def reset(self):
        pass

# Model

TODO:
1. use input_ff_sigmoid?
2. any further model enhancement?

In [6]:
class TransformerModel(nn.Module):
    def __init__(
        self,
        num_input_features: int,
        num_classes: int,
        embedding_dim: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float,
        logger: ModelLogger,
    ):
        super().__init__()
        self.model_type = 'Transformer'
        self.embedding = nn.Embedding(num_classes, embedding_dim)
        self.input_ff = nn.Linear(num_input_features + embedding_dim, d_model)
        self.input_ff_sigmoid = nn.Sigmoid()
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.final_linear = nn.Linear(d_model, 1)
        self.logger = logger

    def forward(
        self,
        src: Tensor,
        item_ids: Tensor,
        src_mask: Tensor = None,
    ) -> Tensor:
        # src: [batch_size b, seq_len k 55, features 27]
        # item_ids: [batch_size b]
        batch_size = src.size(dim=0)
        seq_len = src.size(dim=1)
        num_input_features = src.size(dim=2)
        output = src
        embedded = self.embedding(item_ids)
        # idea from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py#L1290
        embedded = embedded.unsqueeze(dim=1).expand(-1, seq_len, -1)

        output = torch.cat((src, embedded), dim=-1)
        # [batch_size b, seq_len k 55, features 27 + embedding_dim]
        self.logger.log(f"{output.size()}")

        output = self.input_ff(output)
        # [batch_size b, seq_len k 55, d_model]
        self.logger.log(f"input_ff - {output.size()}")

        # TODO: do we need sigmoid?
        # output = self.input_ff_sigmoid(output)
        # self.logger.log(f"input_ff_sigmoid - {output.size()}")

        if src_mask is None:
            """Generate a square causal mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
            """
            src_mask = nn.Transformer.generate_square_subsequent_mask(src.size(dim=1))
            # no "to device" for lightning
            # https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to
            src_mask = src_mask.to(output)
            # src_mask = squared (triangle matrix) matrix [seq_len k 55, seq_len k 55]
        self.logger.log(f"src_mask - {src_mask.size()}")

        output = self.transformer_encoder(output, src_mask)
        # [batch_size b, seq_len k 55, d_model]
        self.logger.log(f"encoder - {output.size()}")

        output = self.final_linear(output)
        # [batch_size b, seq_len k 55, 1]
        self.logger.log(f"final_linear - {output.size()}")

        # take the "last" prediction, which includes all previous information
        output = output[:, -1, :]
        # [batch_size b, 1]
        output = output.squeeze(dim=1)
        # [batch_size b]
        self.logger.log(f"output - {output.size()}")

        return output

In [7]:
# https://stackoverflow.com/questions/49433936/how-do-i-initialize-weights-in-pytorch
# https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

## Load data

In [8]:
df = pd.read_csv(
    "../optiver-trading-at-the-close/train.csv",
    dtype={
        # stock_id should be int64 / long for embedding
        "date_id": np.float32,
        "seconds_in_bucket": np.float32,
        "imbalance_size": np.float32,
        "imbalance_buy_sell_flag": np.float32,
        "reference_price": np.float32,
        "matched_size": np.float32,
        "far_price": np.float32,
        "near_price": np.float32,
        "bid_price": np.float32,
        "bid_size": np.float32,
        "ask_price": np.float32,
        "ask_size": np.float32,
        "wap": np.float32,
        "target": np.float32,
        "time_id": np.float32,
    },
)
raw_df = df

In [9]:
# deep copy for easy reference to raw df without reloading from csv
df = raw_df.copy(deep=True)

In [10]:
df

Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,bid_size,ask_price,ask_size,wap,target,time_id,row_id
0,0,0.0,0.0,3.180603e+06,1.0,0.999812,13380277.00,,,0.999812,60651.500000,1.000026,8493.030273,1.000000,-3.029704,0.0,0_0_0
1,1,0.0,0.0,1.666039e+05,-1.0,0.999896,1642214.25,,,0.999896,3233.040039,1.000660,20605.089844,1.000000,-5.519986,0.0,0_0_1
2,2,0.0,0.0,3.028799e+05,-1.0,0.999561,1819368.00,,,0.999403,37956.000000,1.000298,18995.000000,1.000000,-8.389950,0.0,0_0_2
3,3,0.0,0.0,1.191768e+07,-1.0,1.000171,18389746.00,,,0.999999,2324.899902,1.000214,479032.406250,1.000000,-4.010201,0.0,0_0_3
4,4,0.0,0.0,4.475500e+05,-1.0,0.999532,17860614.00,,,0.999394,16485.539062,1.000016,434.100006,1.000000,-7.349849,0.0,0_0_4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5237975,195,480.0,540.0,2.440723e+06,-1.0,1.000317,28280362.00,0.999734,0.999734,1.000317,32257.039062,1.000434,319862.406250,1.000328,2.310276,26454.0,480_540_195
5237976,196,480.0,540.0,3.495105e+05,-1.0,1.000643,9187699.00,1.000129,1.000386,1.000643,205108.406250,1.000900,93393.070312,1.000819,-8.220077,26454.0,480_540_196
5237977,197,480.0,540.0,0.000000e+00,0.0,0.995789,12725436.00,0.995789,0.995789,0.995789,16790.660156,0.995883,180038.312500,0.995797,1.169443,26454.0,480_540_197
5237978,198,480.0,540.0,1.000899e+06,1.0,0.999210,94773272.00,0.999210,0.999210,0.998970,125631.718750,0.999210,669893.000000,0.999008,-1.540184,26454.0,480_540_198


## Data pre-processing and features

In [11]:
processors = [
    RemoveRecordsByStockDateIdPreprocessor([
        {"stock_id": 19, "date_id": 438},
        {"stock_id": 101, "date_id": 328},
        {"stock_id": 131, "date_id": 35},
        {"stock_id": 158, "date_id": 388},
    ]),
    FarNearPriceFillNaPreprocessor(),
    # ReduceMemUsageDataPreprocessor(verbose=True),
    # BasicFeaturesPreprocessor(),    
    # DupletsTripletsPreprocessor(),
    MovingAvgPreProcessor("wap"),
    MovingAvgFillNaPreprocessor("wap", 1.0),
    # StockIdFeaturesPreProcessor(),  
    # DropTargetNADataPreprocessor(),    
    # RemoveIrrelevantFeaturesDataPreprocessor(['stock_id', 'date_id','time_id', 'row_id']),
    # FillNaPreProcessor(),
    # PolynomialFeaturesPreProcessor(),
]
processor = CompositeDataPreprocessor(processors)

In [12]:
df = processor.apply(df)

CompositeDataPreprocessor - original df shape: (5237980, 17)
Processing RemoveRecordsByStockDateIdPreprocessor...
RemoveRecordsByStockDateIdPreprocessor - removing 220 records
RemoveRecordsByStockDateIdPreprocessor took 0.76s. New df shape: (5237760, 17).
Processing FarNearPriceFillNaPreprocessor...
FarNearPriceFillNaPreprocessor took 0.07s. New df shape: (5237760, 17).
Processing MovingAvgPreProcessor...
MovingAvgPreProcessor took 29.55s. New df shape: (5237760, 21).
Processing MovingAvgFillNaPreprocessor...
MovingAvgFillNaPreprocessor took 0.15s. New df shape: (5237760, 21).
CompositeDataPreprocessor - final df shape: (5237760, 21)


In [13]:
print(df.columns)
display(df)

Index(['stock_id', 'date_id', 'seconds_in_bucket', 'imbalance_size',
       'imbalance_buy_sell_flag', 'reference_price', 'matched_size',
       'far_price', 'near_price', 'bid_price', 'bid_size', 'ask_price',
       'ask_size', 'wap', 'target', 'time_id', 'row_id', 'wap_mov_avg_3_1',
       'wap_mov_avg_6_3', 'wap_mov_avg_12_6', 'wap_mov_avg_24_12'],
      dtype='object')


Unnamed: 0,stock_id,date_id,seconds_in_bucket,imbalance_size,imbalance_buy_sell_flag,reference_price,matched_size,far_price,near_price,bid_price,...,ask_price,ask_size,wap,target,time_id,row_id,wap_mov_avg_3_1,wap_mov_avg_6_3,wap_mov_avg_12_6,wap_mov_avg_24_12
0,0,0.0,0.0,3.180603e+06,1.0,0.999812,13380277.00,1.000000,1.000000,0.999812,...,1.000026,8493.030273,1.000000,-3.029704,0.0,0_0_0,1.000000,1.000000,1.000000,1.000000
1,1,0.0,0.0,1.666039e+05,-1.0,0.999896,1642214.25,1.000000,1.000000,0.999896,...,1.000660,20605.089844,1.000000,-5.519986,0.0,0_0_1,1.000000,1.000000,1.000000,1.000000
2,2,0.0,0.0,3.028799e+05,-1.0,0.999561,1819368.00,1.000000,1.000000,0.999403,...,1.000298,18995.000000,1.000000,-8.389950,0.0,0_0_2,1.000000,1.000000,1.000000,1.000000
3,3,0.0,0.0,1.191768e+07,-1.0,1.000171,18389746.00,1.000000,1.000000,0.999999,...,1.000214,479032.406250,1.000000,-4.010201,0.0,0_0_3,1.000000,1.000000,1.000000,1.000000
4,4,0.0,0.0,4.475500e+05,-1.0,0.999532,17860614.00,1.000000,1.000000,0.999394,...,1.000016,434.100006,1.000000,-7.349849,0.0,0_0_4,1.000000,1.000000,1.000000,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5237975,195,480.0,540.0,2.440723e+06,-1.0,1.000317,28280362.00,0.999734,0.999734,1.000317,...,1.000434,319862.406250,1.000328,2.310276,26454.0,480_540_195,1.000345,1.000304,1.000318,1.000202
5237976,196,480.0,540.0,3.495105e+05,-1.0,1.000643,9187699.00,1.000129,1.000386,1.000643,...,1.000900,93393.070312,1.000819,-8.220077,26454.0,480_540_196,1.000816,1.000710,1.000560,1.000506
5237977,197,480.0,540.0,0.000000e+00,0.0,0.995789,12725436.00,0.995789,0.995789,0.995789,...,0.995883,180038.312500,0.995797,1.169443,26454.0,480_540_197,0.995958,0.996070,0.996130,0.996436
5237978,198,480.0,540.0,1.000899e+06,1.0,0.999210,94773272.00,0.999210,0.999210,0.998970,...,0.999210,669893.000000,0.999008,-1.540184,26454.0,480_540_198,0.999116,0.999218,0.999305,0.999313


In [14]:
feat_dynamic_real = [
    "date_id",
    "seconds_in_bucket",
    "imbalance_size",
    "reference_price",
    "matched_size",
    "far_price",
    "near_price",
    "bid_price",
    "bid_size",
    "ask_price",
    "ask_size",
    "wap",
    "wap_mov_avg_3_1",
    "wap_mov_avg_6_3",
    "wap_mov_avg_12_6",
    "wap_mov_avg_24_12",
]
num_input_features = len(feat_dynamic_real)
print(num_input_features)

16


In [15]:
# should not have any na features
any_na_values_mask = df[feat_dynamic_real].isna().any(axis=1)
print(any_na_values_mask.shape, any_na_values_mask[any_na_values_mask].shape)

(5237760,) (0,)


## Group by stock_id

TODO: group by date_id too? or use embedding for date_id?

In [16]:
df_grouped = df.groupby("stock_id")
num_classes = len(df_grouped)
print(num_classes)
print(df_grouped.size())

200
stock_id
0      26455
1      26455
2      26455
3      26455
4      26455
       ...  
195    26455
196    26455
197    26455
198    26455
199    21615
Length: 200, dtype: int64


## Normalize features per stock

TODO: move as pre-processor? do not normalize all columns?

TODO: should not normalize date_id, seconds_in_bucket, imbalance_buy_sell_flag, far_price, near_price, bid_price, ask_price?

In [17]:
normalize_columns = list(set(feat_dynamic_real) - set(["date_id"]))
print(normalize_columns)
display(df_grouped.get_group(0)[normalize_columns].head())
df[normalize_columns] = df_grouped[normalize_columns].transform(lambda x: (x - x.mean()) / x.std())
display(df_grouped.get_group(0)[normalize_columns].head())

['near_price', 'wap', 'wap_mov_avg_6_3', 'seconds_in_bucket', 'far_price', 'ask_price', 'imbalance_size', 'wap_mov_avg_3_1', 'matched_size', 'wap_mov_avg_24_12', 'bid_price', 'reference_price', 'bid_size', 'wap_mov_avg_12_6', 'ask_size']


Unnamed: 0,near_price,wap,wap_mov_avg_6_3,seconds_in_bucket,far_price,ask_price,imbalance_size,wap_mov_avg_3_1,matched_size,wap_mov_avg_24_12,bid_price,reference_price,bid_size,wap_mov_avg_12_6,ask_size
0,1.0,1.0,1.0,0.0,1.0,1.000026,3180602.75,1.0,13380277.0,1.0,0.999812,0.999812,60651.5,1.0,8493.030273
191,1.0,0.999892,1.0,10.0,1.0,1.000026,1299772.75,0.999946,15261107.0,1.0,0.999812,1.000026,13996.5,1.0,23519.160156
382,1.0,0.999842,0.999911,20.0,1.0,0.999919,1299772.75,0.999911,15261107.0,1.0,0.999812,0.999919,4665.5,1.0,12131.599609
573,1.0,1.000085,0.999955,30.0,1.0,1.000133,1299772.75,0.99994,15261107.0,1.0,1.000026,1.000133,55998.0,1.0,46203.300781
764,1.0,1.000317,1.000027,40.0,1.0,1.000455,1218204.375,1.000081,15342675.0,1.0,1.000241,1.000455,14655.950195,1.0,26610.449219


Unnamed: 0,near_price,wap,wap_mov_avg_6_3,seconds_in_bucket,far_price,ask_price,imbalance_size,wap_mov_avg_3_1,matched_size,wap_mov_avg_24_12,bid_price,reference_price,bid_size,wap_mov_avg_12_6,ask_size
0,0.133338,0.094047,0.094393,-1.700898,0.123295,0.042608,-0.123788,0.094223,-0.587912,0.085538,0.045702,-0.013458,0.326688,0.092848,-0.517845
191,0.133338,0.029667,0.094393,-1.637902,0.123295,0.042608,-0.349815,0.061263,-0.498555,0.085538,0.045702,0.114022,-0.29275,0.092848,-0.257284
382,0.133338,-0.000142,0.038456,-1.574906,0.123295,-0.021127,-0.349815,0.04009,-0.498555,0.085538,0.045702,0.050282,-0.416638,0.092848,-0.45475
573,0.133338,0.144713,0.065861,-1.51191,0.123295,0.106379,-0.349815,0.057407,-0.498555,0.085538,0.173482,0.177797,0.264904,0.092848,0.13607
764,0.133338,0.282995,0.111535,-1.448913,0.123295,0.298187,-0.359617,0.143845,-0.49468,0.085538,0.301903,0.369621,-0.283995,0.092848,-0.20368


## Data hyperparameters

In [18]:
# TODO: change to align with lightgbm
training_set_max_date_id = 480 - 20
print(training_set_max_date_id)

460


In [19]:
prediction_length = 1
seq_len = 55

training_batch_size = 256
validation_batch_size = 256

## Prepare Pytorch datasets

In [20]:
class StockTrainingDataset(torch.utils.data.Dataset):
    def __init__(self, stock_df, feature_names, target_col, item_id, context_length):
        super().__init__()
        self.features = stock_df[feature_names]
        self.targets = stock_df[target_col]
        self.item_id = item_id
        self.context_length = context_length
        # possible idx = 0 ... 26455 - 55 + 1
        self.total_size = self.features.shape[0] - context_length + 1

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        start_idx = idx
        end_idx = idx + self.context_length
        features = self.features.iloc[start_idx:end_idx]
        targets = self.targets.iloc[start_idx:end_idx]
        return features.values, self.item_id, targets.values

In [22]:
stock_training_datasets = []
stock_validation_datasets = []
for item_id, gdf in df_grouped:
    training_mask = gdf["date_id"] < training_set_max_date_id
    training_df = gdf[training_mask]
    validation_df = gdf[~training_mask]
    assert training_df.shape[0] > 0 \
        and validation_df.shape[0] > 0 \
        and training_df.shape[0] + validation_df.shape[0] == gdf.shape[0], f"{item_id} invalid shape, training_df: {training_df.shape}, validation_df: {validation_df.shape}"
    stock_training_datasets.append(StockTrainingDataset(training_df, feat_dynamic_real, "target", item_id, seq_len))
    stock_validation_datasets.append(StockTrainingDataset(validation_df, feat_dynamic_real, "target", item_id, seq_len))
print(len(stock_training_datasets), len(stock_validation_datasets))

200 200


In [23]:
full_training_dataset = torch.utils.data.ConcatDataset(stock_training_datasets)
full_validation_dataset = torch.utils.data.ConcatDataset(stock_validation_datasets)
print(len(full_training_dataset), len(full_validation_dataset))

4995960 220200


In [24]:
training_sampler = torch.utils.data.RandomSampler(full_training_dataset)
validation_sampler = torch.utils.data.SequentialSampler(full_validation_dataset)

In [25]:
training_dataloader = torch.utils.data.DataLoader(
    full_training_dataset,
    batch_size=training_batch_size,
    sampler=training_sampler,
    # https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading
    num_workers=4,
    # https://pytorch.org/docs/stable/data.html#memory-pinning
    pin_memory=True,
)
validation_dataloader = torch.utils.data.DataLoader(
    full_validation_dataset,
    batch_size=validation_batch_size,
    sampler=validation_sampler,
    num_workers=4,
    pin_memory=True,
)

In [26]:
training_sample_batch = next(iter(training_dataloader))
print("training_sample_batch", training_sample_batch[0].size(), training_sample_batch[1].size(), training_sample_batch[2].size())
print("training_sample_batch", training_sample_batch[0].type(), training_sample_batch[1].type(), training_sample_batch[2].type())
validation_sample_batch = next(iter(validation_dataloader))
print("validation_sample_batch", validation_sample_batch[0].size(), validation_sample_batch[1].size(), validation_sample_batch[2].size())
print("validation_sample_batch", validation_sample_batch[0].type(), validation_sample_batch[1].type(), validation_sample_batch[2].type())

training_sample_batch torch.Size([256, 55, 16]) torch.Size([256]) torch.Size([256, 55])
training_sample_batch torch.FloatTensor torch.LongTensor torch.FloatTensor
validation_sample_batch torch.Size([256, 55, 16]) torch.Size([256]) torch.Size([256, 55])
validation_sample_batch torch.FloatTensor torch.LongTensor torch.FloatTensor


## Hyperparameters

In [27]:
embedding_dim = 4
d_model = 32
nhead = 4
d_hid = 32
nlayers = 2
dropout = 0.1

## Create model

In [28]:
model = TransformerModel(
    num_input_features=num_input_features,
    num_classes=num_classes,
    embedding_dim=embedding_dim,
    d_model=d_model,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
    logger=NoopModelLogger(),
)
model.apply(init_weights)

TransformerModel(
  (embedding): Embedding(200, 4)
  (input_ff): Linear(in_features=20, out_features=32, bias=True)
  (input_ff_sigmoid): Sigmoid()
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_features=32, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=32, out_features=32, bias=True)
        (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (final_linear): Linear(in_features=32, out_features=1, bias=True)
)

In [29]:
criterion = nn.L1Loss()
validation_criterion = nn.L1Loss()

## Test model with 2 identical samples, test model is non-random

In [30]:
sample_model = TransformerModel(
    num_input_features=num_input_features,
    num_classes=num_classes,
    embedding_dim=4,
    d_model=4,
    nhead=2,
    d_hid=4,
    nlayers=2,
    dropout=dropout,
    logger=BasicModelLogger("test_model_with_2_identical_samples"),
)
sample_model.apply(init_weights)
sample_model.eval()
sample_input_features = training_sample_batch[0][0]
sample_input_features = sample_input_features.expand(2, -1, -1)
sample_input_item_id = training_sample_batch[1][0]
sample_input_item_id = sample_input_item_id.expand(2)
print(sample_input_features.size(), sample_input_item_id.size(), sample_input_item_id)
sample_output = sample_model(sample_input_features, sample_input_item_id)
sample_targets = training_sample_batch[2][0]
sample_targets = sample_targets.expand(2, -1)
sample_actual_targets = sample_targets[:, -1]
print(sample_output.size(), sample_output, sample_targets.size(), sample_actual_targets.size(), sample_actual_targets)
sample_loss = criterion(sample_output, sample_actual_targets)
print(sample_loss)
del sample_model, sample_input_features, sample_input_item_id, sample_output, sample_targets, sample_loss

torch.Size([2, 55, 16]) torch.Size([2]) tensor([83, 83])
test_model_with_2_identical_samples - torch.Size([2, 55, 20])
test_model_with_2_identical_samples - input_ff - torch.Size([2, 55, 4])
test_model_with_2_identical_samples - src_mask - torch.Size([55, 55])
test_model_with_2_identical_samples - encoder - torch.Size([2, 55, 4])
test_model_with_2_identical_samples - final_linear - torch.Size([2, 55, 1])
test_model_with_2_identical_samples - output - torch.Size([2])
torch.Size([2]) tensor([-0.1174, -0.1174], grad_fn=<SqueezeBackward1>) torch.Size([2, 55]) torch.Size([2]) tensor([4.9198, 4.9198])
tensor(5.0372, grad_fn=<MeanBackward0>)


## Tensorboard

In [31]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 2329232), started 0:02:57 ago. (Use '!kill 2329232' to kill it.)

## Module and trainer (lightning)

In [32]:
class TransformerModelModule(L.LightningModule):
    def __init__(self, model, criterion, validation_criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.validation_criterion = validation_criterion
        self.validation_step_outputs = []
        self.validation_step_actual_targets = []

    def training_step(self, batch, batch_idx):
        features, item_id, targets = batch[0], batch[1], batch[2]
        actual_targets = targets[:, -1]
        output = self.model(features, item_id)
        loss = self.criterion(output, actual_targets)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        features, item_id, targets = batch[0], batch[1], batch[2]
        actual_targets = targets[:, -1]
        output = self.model(features, item_id)
        # TODO: is loss being averaged based on batch size
        loss = self.validation_criterion(output, actual_targets)
        # lightning will take weighted-average on loss per step based on batch size
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.append(output)
        self.validation_step_actual_targets.append(actual_targets)

    def on_validation_epoch_end(self):
        # TODO: remove manual calculation of validation loss if we can confirm lightning will take weighted average
        # cat is used instead of stack, last step may have different batch size
        all_preds = torch.cat(self.validation_step_outputs)
        all_actual_targets = torch.cat(self.validation_step_actual_targets)
        manual_loss = self.validation_criterion(all_preds, all_actual_targets)
        self.log("val_loss_manual", manual_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.validation_step_outputs.clear()  # free memory
        self.validation_step_actual_targets.clear()  # free memory

    def configure_optimizers(self):
        # TODO: tune learning rate
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=1e-1)
        return [optimizer]

In [33]:
model = TransformerModelModule(model, criterion, validation_criterion)

### Training hyperparameters

In [34]:
num_epochs = 10
limit_train_batches = 1.0
# limit_train_batches = 100
gradient_clip_val = 0.5

In [35]:
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html
# https://lightning.ai/docs/pytorch/stable/extensions/logging.html
# TODO: custom version name
tb_logger = pl_loggers.TensorBoardLogger(".", version=None)

In [36]:
trainer = L.Trainer(
    max_epochs=num_epochs,
    limit_train_batches=limit_train_batches,
    # https://lightning.ai/docs/pytorch/stable/advanced/training_tricks.html#gradient-clipping
    gradient_clip_val=gradient_clip_val,
    callbacks=[
        # https://lightning.ai/docs/pytorch/stable/common/progress_bar.html#richprogressbar
        RichProgressBar(leave=True)
    ],
    logger=tb_logger,
    # https://lightning.ai/docs/pytorch/stable/common/trainer.html#reproducibility
    deterministic=True,
)

/userhome/cs2/tsangsyf/anaconda3_2/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /userhome/cs2/tsangsyf/anaconda3_2/lib/python3.11/si ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..


In [37]:
trainer.validate(
    model=model,
    dataloaders=validation_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'val_loss': 6.131535530090332, 'val_loss_manual': 6.131534099578857}]

## Training

In [38]:
trainer.fit(
    model=model,
    train_dataloaders=training_dataloader,
    val_dataloaders=validation_dataloader,
)

/userhome/cs2/tsangsyf/anaconda3_2/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /userhome/cs2/tsangsyf/anaconda3_2/lib/python3.11/si ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

`Trainer.fit` stopped: `max_epochs=10` reached.


In [39]:
trainer.validate(
    model=model,
    dataloaders=validation_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

[{'val_loss': 5.738607883453369, 'val_loss_manual': 5.738612651824951}]