In [None]:
!pip install transformers datasets evaluate

In [None]:
from datasets import load_dataset

# Load the MIAM "loria" configuration and drop speaker/dialogue bookkeeping
# columns that the preprocessing below does not use.
miam_dataset = load_dataset('miam', 'loria').remove_columns(
    ['Speaker', 'Dialogue_ID', 'File_ID', 'Idx']
)

In [None]:
from transformers import BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")


def preprocess(sample):
    """Tokenize a batch of utterances and attach the label columns.

    Meant for ``Dataset.map(..., batched=True)``, so every value in ``sample``
    is a list covering the whole batch.

    Returns a dict with:
        input_ids / attention_mask: tokenizer output, padded to the tokenizer's
            maximum length and truncated to it when longer.
        labels: the integer ``Label`` column as a tensor.
        str_labels: the ``Dialogue_Act`` strings.
        sentences: the raw ``Utterance`` texts.
    """
    label_tensor = torch.tensor(sample["Label"])
    utterances = sample["Utterance"]
    encoded = tokenizer(utterances, padding="max_length", truncation=True)
    return {
        "input_ids": encoded.input_ids,
        "attention_mask": encoded.attention_mask,
        "labels": label_tensor,
        "str_labels": sample["Dialogue_Act"],
        "sentences": utterances,
    }


miam_tokenized = miam_dataset.map(preprocess, batched=True)

In [None]:
from transformers import BertForSequenceClassification

# Size the classification head from the dataset's ClassLabel feature instead
# of a hard-coded count, and record id<->name maps in the model config so any
# saved checkpoint carries human-readable dialogue-act names.
da_label_names = miam_dataset["train"].features["Label"].names
mBert = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(da_label_names),
    id2label=dict(enumerate(da_label_names)),
    label2id={name: idx for idx, name in enumerate(da_label_names)},
)

In [32]:
import evaluate
import numpy as np

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Accuracy callback for ``Trainer``.

    ``eval_pred`` unpacks into ``(logits, labels)``; each prediction is the
    argmax of the logits along the last axis.
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="mBERT-training-DA",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    # `evaluation_strategy` was deprecated in favour of `eval_strategy` and is
    # rejected by current transformers releases.
    eval_strategy="epoch",
)

trainer = Trainer(
    model=mBert,
    args=training_args,
    train_dataset=miam_tokenized["train"],
    eval_dataset=miam_tokenized["validation"],
    compute_metrics=compute_metrics,
)

train_result = trainer.train()

# train() by itself only writes intermediate checkpoint-*/ folders. Save the
# final model (and its tokenizer) at the root of output_dir so that
# from_pretrained("mBERT-training-DA") can reload it.
trainer.save_model(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)

train_result

In [None]:
# Reload the fine-tuned classifier from the training output directory, in
# eval mode for inference.
# NOTE: Trainer.train() alone only writes checkpoint-*/ subfolders there; this
# path loads only if the final model was also saved at the directory root
# (e.g. via trainer.save_model()).
mBert = BertForSequenceClassification.from_pretrained("mBERT-training-DA").eval()

In [None]:
def get_pred(model, tokenizer, sentence):
    """Predict the class id of a single sentence.

    Args:
        model: sequence-classification model whose output exposes ``.logits``
            and which has a ``.device`` attribute.
        tokenizer: tokenizer matching ``model``.
        sentence: raw input text.

    Returns:
        A 0-d ``torch`` tensor holding the index of the highest-scoring class.
    """
    # Truncate so sentences longer than the model's max length don't break the
    # forward pass. Move the tensors to the model's device so this also works
    # when the model sits on a GPU (e.g. right after training).
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True).to(model.device)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        out = model(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
        )

    # softmax is monotonic, so argmax over the logits picks the same class as
    # argmax over the probabilities. squeeze(0) drops the batch axis of size 1.
    return torch.argmax(out.logits.squeeze(0), dim=-1)

In [None]:
# Predict a class id for every test utterance. The model variable is `mBert`
# (the original loop referenced an undefined `mbert`). int(...) turns the 0-d
# tensors returned by get_pred into plain ints for sklearn.
test_preds = [
    int(get_pred(mBert, tokenizer, sentence))
    for sentence in miam_dataset["test"]["Utterance"]
]

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(labels, preds, label_names):
    """Plot a row-normalised confusion matrix annotated with raw counts.

    Cell colours show, for each true class, the fraction of its samples
    predicted as each class. Rows sum to 1, or are all 0 for classes absent
    from ``labels``. The cell annotations show the absolute counts.

    Args:
        labels: true class ids.
        preds: predicted class ids.
        label_names: class names; entry ``i`` names class id ``i``.

    Returns:
        The integer confusion matrix of raw counts, shape
        ``(len(label_names), len(label_names))``, with rows = true class and
        columns = predicted class.
    """
    import numpy as np

    class_ids = list(range(len(label_names)))
    confusion = confusion_matrix(labels, preds, labels=class_ids)

    # Row-normalise the counts directly instead of calling confusion_matrix a
    # second time. Rows with no true samples stay 0, as with sklearn's
    # normalize="true".
    row_totals = confusion.sum(axis=1, keepdims=True)
    confusion_norm = np.divide(
        confusion,
        row_totals,
        out=np.zeros(confusion.shape, dtype=float),
        where=row_totals != 0,
    )

    _, ax = plt.subplots(figsize=(16, 14))
    sns.heatmap(
        confusion_norm,
        annot=confusion,
        cbar=False,
        fmt="d",
        xticklabels=label_names,
        yticklabels=label_names,
        cmap="viridis",
        ax=ax,
    )
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    ax.set_title("Confusion matrix (colour: row-normalised, text: counts)")
    return confusion

In [None]:
# Compare the model's test-set predictions with the gold dialogue acts.
predictions = test_preds
test_split = miam_dataset["test"]
label_names = test_split.features["Label"].names
labels = test_split["Label"]

cm = plot_confusion_matrix(labels, predictions, label_names)