In [39]:
import torch
import mamkit
from mamkit.utils import get_audio_dataset
from mamkit.datasets import MAMKitMonomodalDataset, BERT_Collator, UnimodalCollator
from random import randint
from mamkit.models import MAMKitAudioOnly
import lightning as L
from torch.nn.utils.rnn import pad_sequence
from mamkit.utils import to_lighting_model, MetricTracker
import torchmetrics


import importlib
importlib.reload(mamkit.utils)

<module 'mamkit.utils' from '/home/stefano/Public/mamkit/mamkit/utils.py'>

In [4]:
# ACC
LABELS_TO_INT = {
    'Claim': 0,
    'Premise': 1,
}

# ASD
# LABELS_TO_INT = {
#     'ARG': 0,
#     'Not-ARG': 1,
# }

train, val, test = get_audio_dataset('usdbelec', 'wav2vec2-single', 'acc')
print("Lengths: ", len(train), len(val), len(test))


Downloading dataset wav2vec2-single...
Dataset wav2vec2-single downloaded.
Extracting dataset wav2vec2-single...
Dataset wav2vec2-single extracted.
Lengths:  9455 5201 5908


In [5]:
def label_collator(labels):
    return torch.tensor([LABELS_TO_INT[label] for label in labels])

def audio_collate_fn(features):
    features = pad_sequence(features, batch_first=True, padding_value=float('-inf'))
    attention_mask = features[:, :, 0] != float('-inf')
    features[(features == float('-inf'))] = 0
    return features, attention_mask

In [6]:
unimodal_collator = UnimodalCollator(
    features_collator = audio_collate_fn,
    label_collator = label_collator
)

In [7]:
train_dataloader = torch.utils.data.DataLoader(train, batch_size=8, shuffle=False, collate_fn=unimodal_collator)
val_dataloader = torch.utils.data.DataLoader(val, batch_size=8, shuffle=False, collate_fn=unimodal_collator)
test_dataloader = torch.utils.data.DataLoader(test, batch_size=8, shuffle=False, collate_fn=unimodal_collator)

In [42]:
classification_head = torch.nn.Sequential(
    torch.nn.Linear(768, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2)
)

audio_only = MAMKitAudioOnly(
    head = classification_head,
)

val_metrics = {"Val_Accuracy":torchmetrics.classification.Accuracy(task="multiclass", num_classes=2), "Val_F1":torchmetrics.classification.F1Score(task="multiclass", num_classes=2)}
test_metrics = {"Test_Accuracy":torchmetrics.classification.Accuracy(task="multiclass", num_classes=2), "Test_F1":torchmetrics.classification.F1Score(task="multiclass", num_classes=2)}

model = to_lighting_model(audio_only, torch.nn.CrossEntropyLoss(), torch.optim.Adam, lr=1e-3, val_metrics=val_metrics, test_metrics=test_metrics, log_metrics=True)

mt = MetricTracker()

trainer = L.Trainer(max_epochs=2, check_val_every_n_epoch=1, callbacks=[mt])
trainer.fit(model, train_dataloader, val_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | MAMKitAudioOnly  | 7.7 K 
1 | loss_function | CrossEntropyLoss | 0     
2 | val_metrics   | ModuleList       | 0     
3 | test_metrics  | ModuleList       | 0     
---------------------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 1182/1182 [00:22<00:00, 52.69it/s, v_num=48, train_loss=0.659, val_loss=0.745, Accuracy=0.541, F1=0.541]

Epoch 1: 100%|██████████| 1182/1182 [00:22<00:00, 52.88it/s, v_num=48, train_loss=0.659, val_loss=0.745, Accuracy=0.541, F1=0.541]

`Trainer.fit` stopped: `max_epochs=2` reached.




Epoch 1: 100%|██████████| 1182/1182 [00:22<00:00, 52.85it/s, v_num=48, train_loss=0.659, val_loss=0.745, Accuracy=0.541, F1=0.541]


In [45]:
mt.collection

[{'val_loss': tensor(0.6818),
  'Accuracy': tensor(0.4375),
  'F1': tensor(0.4375)},
 {'train_loss': tensor(0.6593),
  'val_loss': tensor(0.7453),
  'Accuracy': tensor(0.5414),
  'F1': tensor(0.5414)},
 {'train_loss': tensor(0.6593),
  'val_loss': tensor(0.7453),
  'Accuracy': tensor(0.5414),
  'F1': tensor(0.5414)}]

In [46]:
val_results = trainer.validate(model, val_dataloader)

/home/stefano/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation DataLoader 0: 100%|██████████| 651/651 [00:07<00:00, 85.37it/s] 


In [47]:
test_results = trainer.test(model, test_dataloader)

/home/stefano/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0:   1%|          | 8/739 [00:00<00:07, 96.05it/s] 

Testing DataLoader 0: 100%|██████████| 739/739 [00:07<00:00, 94.41it/s] 


In [56]:
test_results

[{'Accuracy': 0.5140487551689148,
  'F1': 0.5140487551689148,
  'loss': 0.6929669380187988}]