In [None]:
# `helpers` is star-imported first (as originally) so the explicit imports below keep precedence
# over any same-named objects it re-exports. `os` and `torch` were previously only reachable
# through this star import; they are now imported explicitly.
from helpers import *

import os

import torch
from datasets import load_dataset
from torchmetrics import F1, Recall, Precision, PrecisionRecallCurve
from transformers import (
    LongformerForSequenceClassification,
    RobertaTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [None]:
# Paths: raw data, Trainer output/checkpoints, and the final saved model.
dataset_path = 'data/'
output_path = 'output/longformer-reuters-multilabel'
# os.path.join with a single argument was a no-op; join the path components explicitly.
model_save_dir = os.path.join('saved-models', 'longformer-reuters-multilabel')

# Topics used as labels; the model gets one (independent, multi-label) output per topic.
all_topics = ['earn', 'acq', 'money-fx', 'grain', 'crude', 'trade', 'interest']
num_labels = len(all_topics)

# NOTE(review): a roberta-base tokenizer is paired with a Longformer model; presumably compatible
# because Longformer was initialised from RoBERTa — confirm vocabulary and max length match.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Seed the Python/NumPy/torch RNGs before the model is built so the freshly initialised
# classification head is reproducible across Restart & Run All.
SEED = 42
set_seed(SEED)

In [None]:
# Collect the raw document files under dataset_path. get_filenames presumably comes from the
# `helpers` star import; its return format is not visible here.
filenames = get_filenames(dataset_path)

In [None]:
# Build one dictionary per split (train / validation / test) from the discovered files and the
# topic list. The dictionary layout is defined by the helper in `helpers` — not visible here.
(
    dataset_train_dict,
    dataset_val_dict,
    dataset_test_dict,
) = build_dataset_dictionaries(dataset_path, filenames, all_topics)

In [None]:
# Write each split to CSV. The returned names are later joined with dataset_path, so they are
# presumably file names relative to that directory — TODO confirm against helpers.write_to_csv.
csv_train, csv_eval, csv_test = write_to_csv(
    dataset_path,
    all_topics,
    dataset_train_dict,
    dataset_val_dict,
    dataset_test_dict,
)

In [None]:
# Load the three CSV files as a DatasetDict keyed 'train' / 'eval' / 'test'; these keys are the
# ones indexed when the Trainer is built and when predicting on the test split.
raw_datasets = load_dataset(
    'csv',
    data_files={
        split_name: os.path.join(dataset_path, csv_name)
        for split_name, csv_name in zip(('train', 'eval', 'test'), (csv_train, csv_eval, csv_test))
    },
)
raw_datasets

In [None]:
# Visualise the label distribution of the loaded splits (presumably how often each topic occurs;
# the implementation lives in `helpers`).
show_histogram_multilabel(all_topics, raw_datasets)

In [None]:
# Prepare every split for the model (tokenisation plus whatever label adjustment the helper does).
# NOTE(review): the `tokenizer` defined in the config cell is not passed here, so the helper
# presumably creates its own — confirm both use the same vocabulary and max length.
tokenized_datasets = adjust_and_tokenize_datasets(raw_datasets)

In [None]:
# Fine-tuning hyperparameters. Evaluation and checkpointing both run once per epoch; with
# load_best_model_at_end the Trainer restores the best checkpoint (by eval loss, since
# metric_for_best_model is left at its default) after training.
training_args = TrainingArguments(
    output_dir=output_path,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    num_train_epochs=6,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=1000,
)

In [None]:
# Pretrained Longformer encoder plus a new classification head with one output per topic.
# problem_type='multi_label_classification' makes the model treat each topic as an independent
# binary decision (sigmoid / BCE-with-logits loss) rather than a single softmax choice.
# id2label / label2id record the topic names in the config, so the saved model and model.config
# show readable labels instead of generic LABEL_0..LABEL_n.
model = LongformerForSequenceClassification.from_pretrained(
    'allenai/longformer-base-4096',
    num_labels=num_labels,
    id2label=dict(enumerate(all_topics)),
    label2id={topic: idx for idx, topic in enumerate(all_topics)},
    ignore_mismatched_sizes=True,
    problem_type='multi_label_classification',
)

In [None]:
# Training Trainer: fits on the 'train' split and evaluates on 'eval' as scheduled by
# training_args. The tokenizer is handed over so batches can be padded consistently.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['eval'],
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning; evaluation, checkpointing and logging follow the settings in training_args.
trainer.train()

In [None]:
# Persist the fine-tuned model. Unlike model.save_pretrained, trainer.save_model also writes the
# tokenizer the Trainer was given, so model_save_dir is self-contained for later reloading.
# With load_best_model_at_end=True, trainer.model holds the best checkpoint at this point.
trainer.save_model(model_save_dir)

In [None]:
# Reload the fine-tuned weights from disk so evaluation uses exactly what was saved.
model = LongformerForSequenceClassification.from_pretrained(model_save_dir)

# Inference-only Trainer. Reusing training_args keeps the eval batch size and output directory
# consistent; without args, Trainer falls back to a default 'tmp_trainer' output directory.
trainer = Trainer(
    model,
    args=training_args,
    tokenizer=tokenizer,
)

# Show the loaded configuration. Previously `model.config` sat mid-cell, so it was never displayed.
model.config

In [None]:
# Run the reloaded model on the held-out test split. The result exposes raw model outputs via
# .predictions and the gold labels via .label_ids.
predictions = trainer.predict(tokenized_datasets['test'])

In [None]:
# Inspect the gold test labels returned with the predictions (presumably one 0/1 column per
# topic — confirm against the tokenisation helper).
predictions.label_ids

In [None]:
# Convert logits to independent per-topic probabilities (sigmoid rather than softmax, because
# this is multi-label) and binarise them at a fixed decision threshold.
threshold = 0.5
preds = torch.sigmoid(torch.tensor(predictions.predictions))
# The comparison already yields 0/1 values; .float() is enough (the former `* 1` was redundant).
predicted_labels = (preds > threshold).float()
# torchmetrics expects integer targets, so cast the gold labels to an integer dtype.
target = torch.tensor(predictions.label_ids, dtype=torch.int8)

In [None]:
# Per-class metrics: average=None returns one score per topic (num_labels entries) instead of a
# single aggregate. They are applied to the binarised predictions and integer targets.
# NOTE(review): newer torchmetrics releases renamed F1 to F1Score — check the pinned version.
f1 = F1(num_classes=num_labels, average=None)
precision = Precision(num_classes=num_labels, average=None)
recall = Recall(num_classes=num_labels, average=None)

In [None]:
# Per-topic F1 on the test split, keyed by topic name so each score can be read directly.
dict(zip(all_topics, f1(predicted_labels, target).tolist()))

In [None]:
# Per-topic precision on the test split, keyed by topic name so each score can be read directly.
dict(zip(all_topics, precision(predicted_labels, target).tolist()))

In [None]:
# Per-topic recall on the test split, keyed by topic name so each score can be read directly.
dict(zip(all_topics, recall(predicted_labels, target).tolist()))