In [1]:
import torch
from pytorch_lightning import Trainer

from dataloaders.synthetic import SyntheticDataModule
from models.classifier import TransformerEncoder

In [2]:
# Seed torch's global RNG before the model is built so that weight
# initialisation and the later random draws in this notebook are repeatable
# when the cells are run in order on a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters handed to TransformerEncoder as a single dict.
hparams = {
    "hidden_size": 96,  # size of the hidden layers and embeddings
    "hidden_ff": 128,  # size of the position-wise feed-forward layer
    "n_encoders": 4,  # number of encoder blocks
    "n_heads": 8,  # number of attention heads in the multiheadattention module
    "n_local": 2,  # number of local attention heads
    "local_window_size": 4,  # size of the window for local attention
    "batch_size": 12,  # number of sequences per batch
    "max_length": 500,  # maximum length of the input sequence
    "vocab_size": 1000,  # size of the vocabulary
    "learning_rate": 0.001,
    "num_epochs": 30,
    "attention_type": "performer",
    "norm_type": "rezero",
    "num_random_features": 32,  # number of random features for the Attention module (Performer uses this)
    "emb_dropout": 0.1,  # dropout for the embedding block
    "fw_dropout": 0.1,  # dropout for the position-wise feed-forward layer
    "att_dropout": 0.1,  # dropout for the multiheadattention module
    "dc_dropout": 0.1,  # dropout for the decoder block
    "hidden_act": "swish",  # activation function for the hidden layers (attention layers use ReLU)
    "epsilon": 1e-8,  # presumably the optimizer epsilon -- confirm in models.classifier
    "weight_decay": 0.01,
    "beta1": 0.9,  # beta1/beta2: presumably Adam-style moment coefficients -- confirm
    "beta2": 0.999,
}
model = TransformerEncoder(hparams)

In [3]:
def _random_batch(hp):
    """Assemble a batch of random tensors sized from the hyperparameters `hp`.

    Reads `hp.batch_size`, `hp.max_length` and `hp.vocab_size` and returns a
    dict with the keys 'tokens', 'abspos', 'age', 'padding_mask' and 'targets'.
    """
    rows, seq_len = hp.batch_size, hp.max_length

    # Random draws are made in the order tokens -> age -> targets.
    tokens = torch.randint(0, hp.vocab_size, (rows, seq_len))
    # One 0..seq_len-1 position row, copied for every sequence in the batch.
    positions = torch.arange(0, seq_len).unsqueeze(0).repeat(rows, 1)
    age = torch.randint(0, 100, (rows, 1))
    # All-zero mask of shape (rows, seq_len).
    mask = torch.zeros(rows, seq_len)
    # Binary labels converted to float.
    labels = torch.randint(0, 2, (rows, 1)).float()

    return {
        'tokens': tokens,
        'abspos': positions,
        'age': age,
        'padding_mask': mask,
        'targets': labels,
    }


batch = _random_batch(model.hparams)

#### Sanity Check

In [4]:
# Forward pass on the dummy batch; the displayed result is a dict holding
# 'logits' and 'preds' tensors with one row per sequence.
model(batch)

{'logits': tensor([[ 0.0265],
         [ 0.0760],
         [-0.1372],
         [-0.1907],
         [-0.0547],
         [-0.0885],
         [-0.0600],
         [ 0.1507],
         [-0.1601],
         [-0.1576],
         [-0.1622],
         [-0.1141]], grad_fn=<AddmmBackward0>),
 'preds': tensor([[0.5066],
         [0.5190],
         [0.4657],
         [0.4525],
         [0.4863],
         [0.4779],
         [0.4850],
         [0.5376],
         [0.4601],
         [0.4607],
         [0.4595],
         [0.4715]], grad_fn=<SigmoidBackward0>)}

In [5]:
# Call training_step directly (batch index 0), without a Trainer; the
# displayed value is the loss tensor for the dummy batch.
model.training_step(batch, 0)

tensor(0.7142, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [6]:
# Invoke the epoch-end hook by hand; it prints the train loss, accuracy and MCC.
# Lightning warns about `self.log()` here because no Trainer is attached yet.
model.on_train_epoch_end()

Train Metrics
Train Metrics
Loss: tensor(0.7142)
Accuracy: tensor(0.3333)
MCC: tensor(-0.2500)


/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/core/module.py:447: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


### Example: full training pipeline

In [7]:
# Synthetic data sized from the same hyperparameters the model was built with.
# NOTE: despite its name, `dataloader` holds a SyntheticDataModule object
# (presumably a Lightning data module -- confirm in dataloaders.synthetic);
# the name is kept because the training cell below refers to it.
dataloader = SyntheticDataModule(
    num_samples=100,
    max_length=hparams["max_length"],
    batch_size=hparams["batch_size"],
    vocab_size=hparams["vocab_size"],
)

In [None]:
# Train for the epoch count configured in `hparams` rather than repeating the literal.
trainer = Trainer(
    max_epochs=hparams["num_epochs"],
    accelerator="cpu",  # change to "cuda" or "gpu" to train on an accelerator
    # An integer val_check_interval runs validation every N training batches
    # (here: every 3 batches). NOTE(review): if "every 3 epochs" was intended,
    # use check_val_every_n_epoch=3 instead -- confirm.
    val_check_interval=3,
)
trainer.fit(model, dataloader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

   | Name        | Type                   | Params | Mode 
--------------

Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.90it/s]

/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (9) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 9/9 [00:04<00:00,  2.12it/s, v_num=3]Train Metrics
Train Metrics
Loss: tensor(0.7025)
Accuracy: tensor(0.4107)
MCC: tensor(-0.1706)
Epoch 1: 100%|██████████| 9/9 [00:03<00:00,  2.34it/s, v_num=3]Train Metrics
Train Metrics
Loss: tensor(0.6964)
Accuracy: tensor(0.4528)
MCC: tensor(-0.0774)
Epoch 2:  67%|██████▋   | 6/9 [00:02<00:01,  2.88it/s, v_num=3]

/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
