In [1]:
import lightning as L
import torch
from torch.nn.utils.rnn import pad_sequence
from mamkit.utils import get_multimodal_dataset, to_lighting_model
from mamkit.datasets import MultimodalCollator, BERT_Collator
from mamkit.models import MAMKitCSA

In [2]:
# Fetch the three splits used by the rest of the notebook.
# NOTE(review): the positional arguments are presumably (dataset name, text
# transform, audio feature set, task) -- confirm against mamkit's signature.
train, val, test = get_multimodal_dataset(
    'usdbelec',
    lambda x: x,  # identity function: returns its argument unchanged
    'wav2vec2-single',
    'acc',
)

Downloading dataset wav2vec2-single...
Dataset wav2vec2-single downloaded.
Extracting dataset wav2vec2-single...
Dataset wav2vec2-single extracted.


In [3]:
# Label vocabulary: every class name is mapped to its position in the tuple.

# ACC
LABELS_TO_INT = {
    name: index for index, name in enumerate(('Claim', 'Premise'))
}

# ASD -- to switch task, replace the tuple above with:
#     ('ARG', 'Not-ARG')

In [4]:
def audio_collate_fn(features):
    """Pad a batch of variable-length audio feature sequences and build its mask.

    Args:
        features: list of tensors, one per sample, each indexed as
            (frames, feature_dim); the number of frames may differ per sample.

    Returns:
        A ``(padded, attention_mask)`` tuple. ``padded`` has shape
        (batch, max_frames, feature_dim) with zeros in the padded positions.
        ``attention_mask`` is a boolean (batch, max_frames) tensor that is
        True on real frames and False on padding.
    """
    features = list(features)
    padded = pad_sequence(features, batch_first=True, padding_value=0.0)

    # Build the mask from the true sequence lengths instead of looking for a
    # -inf sentinel in the first feature channel: a sentinel cannot tell
    # padding apart from a real frame that happens to hold that value, and
    # zeroing every sentinel match would also overwrite real data.
    lengths = torch.tensor(
        [sample.shape[0] for sample in features], device=padded.device
    )
    frame_index = torch.arange(padded.shape[1], device=padded.device)
    attention_mask = frame_index.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, attention_mask

def label_collator(labels):
    """Turn a batch of label names into a tensor of class indices.

    Every name is looked up in the module-level ``LABELS_TO_INT`` mapping;
    a name that is missing from the mapping raises ``KeyError``.
    """
    class_indices = []
    for name in labels:
        class_indices.append(LABELS_TO_INT[name])
    return torch.tensor(class_indices)

In [5]:
# Assemble the batch collator out of one collator per field.
# NOTE(review): presumably MultimodalCollator applies each of these to the
# matching part of every sample -- confirm against mamkit.
bert_text_collator = BERT_Collator()
collate_fn = MultimodalCollator(
    text_collator=bert_text_collator,
    audio_collator=audio_collate_fn,
    label_collator=label_collator,
)

In [6]:
# Number of samples per batch, shared by every split.
BATCH_SIZE = 8
# Seed of the shuffling order, so that reruns iterate the same batches.
SHUFFLE_SEED = 42


def make_dataloader(dataset, shuffle=False):
    """Wrap ``dataset`` in a DataLoader batched with the notebook's ``collate_fn``.

    Args:
        dataset: the split to iterate over.
            NOTE(review): shuffling assumes a map-style dataset -- confirm.
        shuffle: when True, samples are drawn in a seeded random order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of ``BATCH_SIZE``.
    """
    # A dedicated generator keeps the shuffle reproducible without touching
    # the global torch RNG state.
    generator = torch.Generator().manual_seed(SHUFFLE_SEED) if shuffle else None
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        generator=generator,
    )


# Only the training split is shuffled: a fixed sample order would present
# the same batches in the same sequence on every epoch. Validation and test
# keep their original order.
train_dataloader = make_dataloader(train, shuffle=True)
val_dataloader = make_dataloader(val)
test_dataloader = make_dataloader(test)

In [7]:
from mamkit.modules import CustomEncoder, PositionalEncoding

# Embedding width shared by the positional encoding, the encoder and the
# input of the classification head; the three must agree.
EMBEDDING_DIM = 768

pos = PositionalEncoding(EMBEDDING_DIM)

csa_encoder = CustomEncoder(
    d_model=EMBEDDING_DIM,
    ffn_hidden=16,
    n_head=1,
    n_layers=1,
    drop_prob=0.1,
)

# One output logit per entry of LABELS_TO_INT. A fixed size of 3 did not
# match the two-class label mapping and left a logit that no label can
# ever select as a target.
csa_head = torch.nn.Linear(EMBEDDING_DIM, len(LABELS_TO_INT))

model = MAMKitCSA(
    transformer=csa_encoder,
    head=csa_head,
    positional_encoder=pos,
)



In [8]:
# Wrap the model for Lightning together with a cross-entropy loss instance
# and the Adam optimizer class (passed uninstantiated).
loss_function = torch.nn.CrossEntropyLoss()
optimizer_class = torch.optim.Adam
lt_model = to_lighting_model(model, loss_function, optimizer_class)

In [9]:
# Upper bound on the number of passes over the training data.
MAX_EPOCHS = 1

trainer = L.Trainer(max_epochs=MAX_EPOCHS)

# Train on the training loader and validate on the validation loader.
trainer.fit(
    model=lt_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MAMKitCSA        | 2.4 M 
1 | loss_function | CrossEntropyLoss | 0     
---------------------------------------------------
2.4 M     Trainable params
0  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/home/andrea/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
