In [1]:
from models.classifier import TransformerEncoder
from pytorch_lightning import Trainer
import torch

In [2]:
hparams = {
    "hidden_size": 48,  # size of the hidden layers and embeddings
    "hidden_ff": 96,  # size of the position-wise feed-forward layer
    "n_encoders": 4,  # number of encoder blocks
    "n_heads": 2,  # number of attention heads in the multiheadattention module
    "n_local": 2,  # number of local attention heads
    "local_window_size": 4,  # size of the window for local attention
    'batch_size': 4,
    "max_length": 30,  # maximum length of the input sequence
    "vocab_size": 100,  # size of the vocabulary
    "learning_rate": 0.001,
    "num_epochs": 30,
    "attention_type": "performer",
    "norm_type": "rezero",
    "num_random_features": 32,  # number of random features for the Attention module (Performer uses this)
    "emb_dropout": 0.1,  # dropout for the embedding block
    "fw_dropout": 0.1,  # dropout for the position-wise feed-forward layer
    "att_dropout": 0.1,  # dropout for the multiheadattention module
    "dc_dropout": 0.1,  # dropout for the decoder block
    "hidden_act": "swish",  # activation function for the hidden layers (attention layers use ReLU)
    "epsilon": 1e-8,
    "weight_decay": 0.01,
    "beta1": 0.9,
    "beta2": 0.999,
}
model = TransformerEncoder(hparams)

#### Sanity Check

In [3]:
batch = {
    'tokens': torch.randint(0, model.hparams.vocab_size, (model.hparams.batch_size, model.hparams.max_length)),
    'abspos': torch.arange(0, model.hparams.max_length).unsqueeze(0).repeat(model.hparams.batch_size, 1),
    'age': torch.randint(0, 100, (model.hparams.batch_size, 1)).repeat(1, model.hparams.max_length),
    'padding_mask': torch.zeros(model.hparams.batch_size, model.hparams.max_length),
    'targets': torch.randint(0, 2, (model.hparams.batch_size, 1)).float(),
}

In [4]:
batch

{'tokens': tensor([[ 3, 72, 16, 92, 63, 80,  1,  3, 35,  3, 18,  9, 88, 70, 21, 76, 16, 58,
          11, 57, 51, 67, 79, 91, 62, 63, 10, 65, 36, 61],
         [19, 24, 65, 39, 60,  0,  0, 14, 62, 11, 13, 46, 11, 99, 40, 86, 24, 57,
          94, 13,  7, 23, 70, 88, 82, 70, 93, 94, 15, 33],
         [91, 57, 30,  4, 48, 61, 11, 52, 55, 33, 82, 49, 44, 78, 74, 56, 30, 89,
          11, 93, 49, 85,  1, 62,  6, 52, 85, 83, 99, 39],
         [21,  2, 21, 52,  9, 63,  6, 37, 84, 21, 13, 87, 73, 42, 91, 61, 69, 58,
          99, 87, 76, 42, 41,  3, 42, 64, 92, 96, 18, 84]]),
 'abspos': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 

In [5]:
model(batch)

{'logits': tensor([[0.6862],
         [0.6543],
         [0.8233],
         [0.6555]], grad_fn=<AddmmBackward0>),
 'preds': tensor([[0.6651],
         [0.6580],
         [0.6949],
         [0.6582]], grad_fn=<SigmoidBackward0>)}

In [6]:
model.training_step(batch, 0)

tensor(0.9925, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [7]:
model.on_train_epoch_end()


Train Metrics
	Loss: 0.993
	Accuracy: 0.250
	MCC: -0.577



/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/core/module.py:447: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


### Full pipeline would be something like

In [8]:
from dataloaders.synthetic import SyntheticDataModule
dataloader = SyntheticDataModule(num_samples=1000, max_length=hparams['max_length'],
                                  batch_size=hparams['batch_size'], vocab_size=hparams['vocab_size'])

In [None]:
trainer = Trainer(max_epochs=30,
                accelerator="cpu",   ### change to "cuda" or "gpu" or 'msp'
                limit_train_batches=0.5,
                accumulate_grad_batches=4,
                num_sanity_val_steps=8,
                check_val_every_n_epoch=1)
trainer.fit(model, dataloader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

   | Name        | Type                   | Params | Mode 
--------------

Sanity Checking DataLoader 0: 100%|██████████| 8/8 [00:00<00:00, 118.21it/s]
Val Metrics
	Loss: 0.664
	Accuracy: 0.625
	MCC: 0.091

                                                                            

/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/carlomarx/.local/share/virtualenvs/life2vec-light-Ez8u7ZRp-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 125/125 [00:01<00:00, 82.87it/s, v_num=6]
Val Metrics
	Loss: 0.739
	Accuracy: 0.497
	MCC: -0.000

Epoch 0: 100%|██████████| 125/125 [00:03<00:00, 40.64it/s, v_num=6]
Train Metrics
	Loss: 0.750
	Accuracy: 0.504
	MCC: 0.001

Epoch 1: 100%|██████████| 125/125 [00:01<00:00, 88.04it/s, v_num=6]
Val Metrics
	Loss: 0.729
	Accuracy: 0.502
	MCC: 0.000

Epoch 1: 100%|██████████| 125/125 [00:02<00:00, 43.35it/s, v_num=6]
Train Metrics
	Loss: 0.744
	Accuracy: 0.492
	MCC: -0.055

Epoch 2: 100%|██████████| 125/125 [00:01<00:00, 91.00it/s, v_num=6]
Val Metrics
	Loss: 0.719
	Accuracy: 0.509
	MCC: 0.001

Epoch 2: 100%|██████████| 125/125 [00:02<00:00, 43.56it/s, v_num=6]
Train Metrics
	Loss: 0.740
	Accuracy: 0.489
	MCC: -0.014

Epoch 3: 100%|██████████| 125/125 [00:01<00:00, 87.67it/s, v_num=6]
Val Metrics
	Loss: 0.714
	Accuracy: 0.502
	MCC: 0.000

Epoch 3: 100%|██████████| 125/125 [00:02<00:00, 42.33it/s, v_num=6]
Train Metrics
	Loss: 0.732
	Accuracy: 0.490
	MCC: -0.005

Epoc