In [1]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.utils.data import DataLoader

from etudelib.data.synthetic.synthetic import SyntheticDataset
from etudelib.models.lightsans.lightning_model import LightSANs
from etudelib.models.lightsans.torch_model import LightSANsModel

In [2]:
# Configuration shared by the data, model and trainer cells below.
qty_interactions = 10000
n_items = 5000  # item vocabulary size; becomes the row count of the item embedding table
max_seq_length = 43  # becomes the row count of the position embedding table
qty_sessions = qty_interactions  # NOTE(review): as many sessions as interactions — confirm intended
batch_size = 32

# The run is stochastic (synthetic data, weight init, shuffling, dropout): seed every
# RNG up front so Restart & Run All reproduces the same training run.
SEED = 42
seed_everything(SEED, workers=True);

In [3]:
# Synthetic training sessions sized by the configuration cell, served in shuffled
# mini-batches.
train_ds = SyntheticDataset(
    qty_interactions=qty_interactions,
    qty_sessions=qty_sessions,
    n_items=n_items,
    max_seq_length=max_seq_length,
)
train_loader = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)


In [4]:
# LightSANs backbone. `n_items` and `max_seq_length` must agree with the dataset:
# they size the item and position embedding tables shown in the printed model.
backbone = LightSANsModel(
    # encoder shape
    n_layers=2,
    n_heads=2,
    k_interests=15,  # presumably the interest count pooled by ItemToInterestAggregation — confirm
    hidden_size=64,
    inner_size=256,
    # dropout
    hidden_dropout_prob=0.5,
    attn_dropout_prob=0.5,
    # activation, normalisation, weight init
    hidden_act="gelu",
    layer_norm_eps=1e-12,
    initializer_range=0.02,
    # data-dependent sizes
    max_seq_length=max_seq_length,
    n_items=n_items,
    topk=21,  # NOTE(review): likely the number of items returned per prediction — confirm
)

In [5]:
# Wrap the backbone in its LightningModule. Ending the cell with the bare object lets
# Jupyter render the module summary itself rather than going through print().
model = LightSANs(backbone)
model

LightSANs(
  (backbone): LightSANsModel(
    (item_embedding): Embedding(5000, 64, padding_idx=0)
    (position_embedding): Embedding(43, 64)
    (trm_encoder): LightTransformerEncoder(
      (layer): ModuleList(
        (0): LightTransformerLayer(
          (multi_head_attention): LightMultiHeadAttention(
            (query): Linear(in_features=64, out_features=64, bias=True)
            (key): Linear(in_features=64, out_features=64, bias=True)
            (value): Linear(in_features=64, out_features=64, bias=True)
            (attpooling_key): ItemToInterestAggregation()
            (attpooling_value): ItemToInterestAggregation()
            (pos_q_linear): Linear(in_features=64, out_features=64, bias=True)
            (pos_k_linear): Linear(in_features=64, out_features=64, bias=True)
            (pos_ln): LayerNorm((64,), eps=1e-12, elementwise_affine=True)
            (attn_dropout): Dropout(p=0.5, inplace=False)
            (dense): Linear(in_features=64, out_features=64, bias=Tru

In [6]:
# Accelerator is chosen automatically. `devices` is left at the library default:
# an explicit `devices=None` only restates the 1.x default and is not one of the
# types the 2.x `devices` parameter declares.
trainer = Trainer(
    accelerator="auto",
    max_epochs=3,
    callbacks=[TQDMProgressBar()],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train on the synthetic loader only; with no validation loader Lightning skips the
# validation loop (see the warning in the output).
trainer.fit(model, train_dataloaders=train_loader)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

  | Name     | Type           | Params
--------------------------------------------
0 | backbone | LightSANsModel | 443 K 
--------------------------------------------
443 K     Trainable params
0         Non-trainable params
443 K     Total params
1.774     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]