In [11]:
from models.classifier import TransformerEncoder
import torch

In [12]:
hparams = {
    "hidden_size": 48,  # size of the hidden layers and embeddings
    "hidden_ff": 96,  # size of the position-wise feed-forward layer
    "n_encoders": 4,  # number of encoder blocks
    "n_heads": 2,  # number of attention heads in the multiheadattention module
    "n_local": 2,  # number of local attention heads
    "local_window_size": 4,  # size of the window for local attention
    'batch_size': 4,
    "max_length": 30,  # maximum length of the input sequence
    "vocab_size": 100,  # size of the vocabulary
    "learning_rate": 0.001,
    "num_epochs": 30,
    "attention_type": "performer",
    "norm_type": "rezero",
    "num_random_features": 32,  # number of random features for the Attention module (Performer uses this)
    "emb_dropout": 0.1,  # dropout for the embedding block
    "fw_dropout": 0.1,  # dropout for the position-wise feed-forward layer
    "att_dropout": 0.1,  # dropout for the multiheadattention module
    "dc_dropout": 0.1,  # dropout for the decoder block
    "hidden_act": "swish",  # activation function for the hidden layers (attention layers use ReLU)
    "epsilon": 1e-8,
    "weight_decay": 0.01,
    "beta1": 0.9,
    "beta2": 0.999,
}
model = TransformerEncoder(hparams)

# Sanity Check
Check if returns values and if the output looks OK.

In [13]:
batch = {
    'tokens': torch.randint(0, model.hparams.vocab_size, (model.hparams.batch_size, model.hparams.max_length)),
    'abspos': torch.arange(0, model.hparams.max_length).unsqueeze(0).repeat(model.hparams.batch_size, 1),
    'age': torch.randint(0, 100, (model.hparams.batch_size, 1)).repeat(1, model.hparams.max_length),
    'padding_mask': torch.zeros(model.hparams.batch_size, model.hparams.max_length),
    'targets': torch.randint(0, 2, (model.hparams.batch_size, 1)).float(),
}

In [14]:
batch

{'tokens': tensor([[70, 33, 44, 72,  9, 40, 55, 65,  8, 67, 70, 92, 48, 22, 12, 22, 59, 52,
          28, 31, 74, 71, 51, 92,  3, 98, 18, 76, 26, 92],
         [28, 24, 22, 22, 81, 28, 71, 65,  2, 46, 28,  4, 62, 71, 12, 36, 36, 93,
          50, 74, 47,  3, 73, 89, 56, 79,  4, 87, 37, 89],
         [72, 69, 64, 11, 85, 24,  6, 40, 87,  6, 69, 62,  3, 31, 82,  3, 87, 58,
          39, 38, 76, 39, 12, 54,  1, 48, 95,  7, 51, 68],
         [88, 45, 71, 68, 46, 65, 93, 45, 54, 92, 98, 50, 17, 21, 37, 36, 90, 27,
          88, 48, 35, 21, 80, 25, 98, 99, 58, 48, 67, 28]]),
 'abspos': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 

In [15]:
model(batch)

{'logits': tensor([[-1.0735],
         [-0.7359],
         [-0.9112],
         [-0.5476]], grad_fn=<AddmmBackward0>),
 'preds': tensor([[0.2547],
         [0.3239],
         [0.2868],
         [0.3664]], grad_fn=<SigmoidBackward0>)}

In [16]:
model.training_step(batch, 0)

tensor(0.7365, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [17]:
model.on_train_epoch_end()


Train Metrics
	Loss: 0.737
	Accuracy: 0.500
	MCC: 0.000



### Full pipeline would be something like

In [18]:
from dataloaders.synthetic import SyntheticDataModule

dataloader = SyntheticDataModule(num_samples=1000, max_length=hparams['max_length'],
                                  batch_size=hparams['batch_size'], vocab_size=hparams['vocab_size'])

In [21]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

model_checkpoint = ModelCheckpoint(monitor='val_mcc', save_top_k=2, mode='max')
early_stopping = EarlyStopping(monitor='val_mcc', patience=5, mode='max')
logger = TensorBoardLogger("lightning_logs", name="transformer")

trainer = Trainer(max_epochs=30,
                accelerator="cpu",   ### change to "cuda" or "gpu" or 'msp'
                limit_train_batches=0.5,
                logger=logger,
                accumulate_grad_batches=4,
                num_sanity_val_steps=8,
                callbacks = [model_checkpoint, early_stopping],
                check_val_every_n_epoch=1)


GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/carlomarx/.local/share/virtualenvs/project2vec-D-jEdA35-python/lib/python3.11/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [22]:
trainer.fit(model, dataloader)

Missing logger folder: lightning_logs/transformer

   | Name        | Type                   | Params | Mode
---------------------------------------------------------------
0  | transformer | Transformer            | 79.5 K | eval
1  | decoder     | CLS_Decoder            | 2.4 K  | eval
2  | loss        | BCEWithLogitsLoss      | 0      | eval
3  | train_loss  | MeanMetric             | 0      | eval
4  | val_loss    | MeanMetric             | 0      | eval
5  | test_loss   | MeanMetric             | 0      | eval
6  | train_acc   | BinaryAccuracy         | 0      | eval
7  | val_acc     | BinaryAccuracy         | 0      | eval
8  | test_acc    | BinaryAccuracy         | 0      | eval
9  | train_mcc   | BinaryMatthewsCorrCoef | 0      | eval
10 | val_mcc     | BinaryMatthewsCorrCoef | 0      | eval
11 | test_mcc    | BinaryMatthewsCorrCoef | 0      | eval
---------------------------------------------------------------
81.9 K    Trainable params
0         Non-trainable params
81.9 K   

Sanity Checking DataLoader 0: 100%|██████████| 8/8 [00:00<00:00, 105.54it/s]
Val Metrics
	Loss: 0.757
	Accuracy: 0.514
	MCC: 0.003

                                                                            

/Users/carlomarx/.local/share/virtualenvs/project2vec-D-jEdA35-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/Users/carlomarx/.local/share/virtualenvs/project2vec-D-jEdA35-python/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 125/125 [00:01<00:00, 98.67it/s, v_num=0]
Val Metrics
	Loss: 0.751
	Accuracy: 0.504
	MCC: 0.000

Epoch 0: 100%|██████████| 125/125 [00:02<00:00, 45.61it/s, v_num=0]
Train Metrics
	Loss: 0.770
	Accuracy: 0.493
	MCC: -0.001

Epoch 1: 100%|██████████| 125/125 [00:01<00:00, 94.90it/s, v_num=0]
Val Metrics
	Loss: 0.743
	Accuracy: 0.502
	MCC: 0.000

Epoch 1: 100%|██████████| 125/125 [00:02<00:00, 44.23it/s, v_num=0]
Train Metrics
	Loss: 0.765
	Accuracy: 0.489
	MCC: -0.001

Epoch 2: 100%|██████████| 125/125 [00:01<00:00, 96.60it/s, v_num=0]
Val Metrics
	Loss: 0.737
	Accuracy: 0.496
	MCC: -0.000

Epoch 2: 100%|██████████| 125/125 [00:02<00:00, 44.13it/s, v_num=0]
Train Metrics
	Loss: 0.753
	Accuracy: 0.494
	MCC: -0.001

Epoch 3:  86%|████████▌ | 107/125 [00:01<00:00, 95.45it/s, v_num=0]

/Users/carlomarx/.local/share/virtualenvs/project2vec-D-jEdA35-python/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
