In [None]:
import lightning as L
import torch
from torch.nn.utils.rnn import pad_sequence
from mamkit.utils import get_multimodal_dataset, to_lighting_model
from mamkit.datasets import MultimodalCollator, BERT_Collator
from mamkit.models import MAMKitCSA

In [None]:
# Load the three dataset splits. The identity lambda is passed where the loader
# takes a callable (presumably a text-processing hook — TODO confirm against
# mamkit); 'wav2vec2-single' and 'acc' presumably select the audio feature set
# and the task — verify against the mamkit documentation.
train, val, test = get_multimodal_dataset('usdbelec', lambda x: x, 'wav2vec2-single', 'acc')

In [None]:
# Label vocabulary for the ACC task: class name -> integer class id.
LABELS_TO_INT = dict(Claim=0, Premise=1)

# Vocabulary for the ASD task (disabled). Presumably it has to be enabled
# together with loading the dataset for that task — TODO confirm.
# LABELS_TO_INT = {'ARG': 0, 'Not-ARG': 1}

In [None]:
def audio_collate_fn(features):
    """Pad variable-length audio feature sequences into one batch tensor.

    Args:
        features: iterable of tensors, one per sample. The first dimension of
            each tensor is its sequence length; the remaining dimensions must
            agree across samples.

    Returns:
        Tuple ``(padded, attention_mask)``. ``padded`` has shape
        ``(batch, max_len, ...)`` and is zero at padded positions.
        ``attention_mask`` is a bool tensor of shape ``(batch, max_len)`` that
        is True at real time steps and False at padding.
    """
    features = list(features)
    padded = pad_sequence(features, batch_first=True, padding_value=0)
    # Build the mask from the true sequence lengths instead of looking for a
    # -inf sentinel in the data: real frames can never be mistaken for padding,
    # real values are never overwritten, and integer-dtype features (which
    # cannot hold -inf) are supported as well.
    lengths = torch.tensor([seq.size(0) for seq in features], device=padded.device)
    positions = torch.arange(padded.size(1), device=padded.device)
    attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, attention_mask

def label_collator(labels):
    """Encode a batch of label names as a tensor of integer class ids.

    Each name is looked up in ``LABELS_TO_INT``; a name that is missing from
    that mapping raises ``KeyError``.
    """
    class_ids = []
    for name in labels:
        class_ids.append(LABELS_TO_INT[name])
    return torch.tensor(class_ids)

In [None]:
# Combine the per-modality collators into the single collate function handed to
# the DataLoaders: a BERT_Collator instance for text, the padding/masking
# function for audio, and the label-name-to-id encoder for labels.
collate_fn = MultimodalCollator(
    text_collator=BERT_Collator(),
    audio_collator=audio_collate_fn,
    label_collator=label_collator
)

In [None]:
# Batch size shared by all three loaders.
BATCH_SIZE = 8

# Training batches are reshuffled every epoch so the optimizer does not see the
# samples in the dataset's storage order; the seeded generator keeps the shuffle
# order reproducible after a kernel restart. Validation and test loaders keep
# the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
    collate_fn=collate_fn,
)
val_dataloader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
from mamkit.modules import CustomEncoder, PositionalEncoding

# Feature width shared by the positional encoding, the encoder and the
# classifier input (presumably the embedding size of the text/audio features —
# TODO confirm).
EMBEDDING_DIM = 768

pos = PositionalEncoding(EMBEDDING_DIM)

csa_encoder = CustomEncoder(
    d_model = EMBEDDING_DIM,
    ffn_hidden = 16,
    n_head = 1,
    n_layers = 1,
    drop_prob = 0.1
)

# One output logit per class in LABELS_TO_INT, so the head always matches the
# label set of the selected task instead of a hard-coded class count.
csa_head = torch.nn.Linear(EMBEDDING_DIM, len(LABELS_TO_INT))


model = MAMKitCSA(
    transformer = csa_encoder,
    head = csa_head,
    positional_encoder = pos
)



In [None]:
# Wrap the model together with a cross-entropy loss and the Adam optimizer class
# (passed uninstantiated); the result is what the Lightning Trainer is given.
lt_model = to_lighting_model(model, torch.nn.CrossEntropyLoss(), torch.optim.Adam)

In [None]:
# Fit for a single epoch, passing the training and validation loaders by name.
trainer = L.Trainer(max_epochs=1)
trainer.fit(
    model=lt_model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Run the test loop on the test loader and keep whatever trainer.test returns.
test_results = trainer.test(model=lt_model, dataloaders=test_dataloader)

In [None]:
# Show the value returned by trainer.test as the cell output.
test_results