In [None]:
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, f1_score
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          DataCollatorWithPadding, Trainer, TrainingArguments)

from data import SharedTaskData


# Prepare dataset

In [None]:
# Wrap the Task A train / dev CSV files
train_data = SharedTaskData("TaskA_train.csv")
dev_data = SharedTaskData("TaskA_dev.csv")

# The dev split is converted with the training split's features so that both
# are built against the same feature definitions.
train_dataset = train_data.convert_to_hf_dataset()
dev_dataset = dev_data.convert_to_hf_dataset(features=train_dataset.features)

# Make sure the label mapping is identical across datasets. The ordered, public
# `names` list determines the string -> id mapping (index == class id), so
# comparing it avoids reaching into the private `_str2int` attribute.
train_label_names = train_dataset.features['validity_str'].names
dev_label_names = dev_dataset.features['validity_str'].names
assert train_label_names == dev_label_names, (
    f"Label mapping mismatch: train={train_label_names}, dev={dev_label_names}"
)

In [None]:
# Distribution of the Validity labels in the training split
sns.displot(data=train_dataset['Validity'])

In [None]:
# Distribution of the Novelty labels in the training split
sns.displot(data=train_dataset['Novelty'])

# Data imbalance
The label distributions plotted above are imbalanced, so we may need to oversample the under-represented classes.

In [None]:
# We focus on predicting Validity label for now
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Build both label mappings from the feature's public, ordered `names` list
# (position in the list == class id) rather than the private `_str2int` dict.
validity_feature = train_dataset.features['validity_str']
validity_label_names = validity_feature.names

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=validity_feature.num_classes,
    label2id={name: idx for idx, name in enumerate(validity_label_names)},
    id2label=dict(enumerate(validity_label_names)),
)


def tokenize_function(examples):
    """Tokenize a batch of examples and attach the validity labels.

    For every row, the topic, premise and conclusion are joined into a single
    string with the tokenizer's separator token between them, and the resulting
    strings are tokenized together.

    Args:
        examples: A batch as passed by ``Dataset.map(..., batched=True)``: a
            mapping from column name to a list of values. The columns
            ``topic``, ``Premise``, ``Conclusion`` and ``validity_str`` are read.

    Returns:
        The tokenizer output for the batch with an extra ``labels`` entry
        holding the values of the ``validity_str`` column.
    """
    separator = tokenizer.sep_token
    batched_inputs = [
        separator.join((topic, premise, conclusion))
        for topic, premise, conclusion in zip(
            examples['topic'], examples['Premise'], examples['Conclusion']
        )
    ]
    # No padding here: padding inside a batched `map` pads every example to the
    # longest one in the (large) map batch. Padding is left to the data
    # collator, which pads each training batch to its own longest sequence.
    # NOTE(review): with truncation enabled, over-long inputs lose tokens from
    # the joined string -- presumably from the end, i.e. the conclusion; confirm
    # the tokenizer's truncation side if long premises are expected.
    samples = tokenizer(batched_inputs, truncation=True)
    samples['labels'] = examples['validity_str']
    return samples

# Columns exposed as torch tensors when the tokenized datasets are indexed
model_input_columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']

tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_dev_dataset = dev_dataset.map(tokenize_function, batched=True)

# Apply the same output format to both splits
for tokenized_split in (tokenized_train_dataset, tokenized_dev_dataset):
    tokenized_split.set_format(type='torch', columns=model_input_columns)

In [None]:
# define metrics

def single_label_metrics(predictions, labels, average='micro'):
    """Compute F1 and accuracy for a single-label classification task.

    Args:
        predictions: Per-class scores (logits), one row per example and one
            column per class; anything ``torch.as_tensor`` accepts.
        labels: The true class ids, one per example.
        average: Averaging mode forwarded to ``f1_score``. Defaults to
            ``'micro'``. Note that for single-label problems micro-averaged F1
            coincides with accuracy; pass e.g. ``'macro'`` for a score that is
            sensitive to class imbalance.

    Returns:
        A dict with the keys ``'f1'`` and ``'accuracy'``.
    """
    scores = torch.as_tensor(predictions, dtype=torch.float32)
    # Softmax is strictly increasing, so the arg-max of the raw scores is the
    # predicted class; skipping it also avoids ties created by float rounding.
    y_pred = scores.argmax(dim=1).cpu().numpy()
    y_true = labels
    f1_value = f1_score(y_true=y_true, y_pred=y_pred, average=average)
    accuracy = accuracy_score(y_true, y_pred)
    return {'f1': f1_value, 'accuracy': accuracy}

def compute_metrics(p):
    """Compute evaluation metrics from a prediction object.

    Args:
        p: An object exposing ``predictions`` and ``label_ids``. When
            ``predictions`` is a tuple, only its first element is used.

    Returns:
        Whatever ``single_label_metrics`` returns for these predictions and
        labels.
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, tuple):
        # Keep only the first element of the tuple
        raw_predictions = raw_predictions[0]
    return single_label_metrics(predictions=raw_predictions, labels=p.label_ids)

In [None]:
# Training configuration: logging, evaluation and checkpointing all follow the
# same per-epoch schedule.
training_args = TrainingArguments(
    output_dir="argmining2022_trainer",
    num_train_epochs=10,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
)

# Collator that pads the examples of each batch using the tokenizer
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_dev_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Start fine-tuning with the Trainer configured in the previous cell
trainer.train()