In [1]:
import torch
from mamkit.utils import get_text_dataset
from mamkit.datasets import MAMKitMonomodalDataset, BERT_Collator, UnimodalCollator
from random import randint
from mamkit.models import MAMKitTextOnly
import lightning as L
from torch.nn.utils.rnn import pad_sequence
from mamkit.utils import to_lighting_model


In [2]:
# ACC
LABELS_TO_INT = {
    'Claim': 0,
    'Premise': 1,
}

# ASD
# LABELS_TO_INT = {
#     'ARG': 0,
#     'Not-ARG': 1,
# }

train, val, test = get_text_dataset('usdbelec', lambda x: x, 'acc')


In [3]:
def label_collator(labels):
    return torch.tensor([LABELS_TO_INT[label] for label in labels])

In [4]:
unimodal_collator = UnimodalCollator(
    features_collator = BERT_Collator(),
    label_collator = label_collator
)

In [5]:
train_dataloader = torch.utils.data.DataLoader(train, batch_size=8, shuffle=False, collate_fn=unimodal_collator)
val_dataloader = torch.utils.data.DataLoader(val, batch_size=8, shuffle=False, collate_fn=unimodal_collator)
test_dataloader = torch.utils.data.DataLoader(test, batch_size=8, shuffle=False, collate_fn=unimodal_collator)

In [6]:
classification_head = torch.nn.Sequential(
    torch.nn.Linear(768, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 3)
)

text_only = MAMKitTextOnly(
    head = classification_head,
)

model = to_lighting_model(text_only, torch.nn.CrossEntropyLoss(), torch.optim.Adam, lr=1e-3)

trainer = L.Trainer(max_epochs=1)
trainer.fit(model, train_dataloader, val_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MAMKitTextOnly   | 7.7 K 
1 | loss_function | CrossEntropyLoss | 0     
---------------------------------------------------
7.7 K     Trainable params
0  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [7]:
test_results = trainer.test(model, test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
           acc               0.627115786075592
          loss              0.6716703772544861
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [8]:
test_results

[{'loss': 0.6716703772544861, 'acc': 0.627115786075592}]