In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all modules before
# each cell runs, so edits to local packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Install the packages listed in requirements.txt into the kernel's environment (-q: quiet output).
%pip install -q -r requirements.txt

In [None]:
import json
import pandas as pd
import os
import re
import string
from transformers import BertForSequenceClassification
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
from transformers import TrainingArguments, Trainer
from datasets import Dataset
import numpy as np
from sklearn.model_selection import train_test_split

from vecsim_app.categories import CATEGORIES
from vecsim_app.data_utils import papers

# Location of the arXiv metadata snapshot that is handed to papers().
# NOTE(review): absolute, machine-specific path — consider making it configurable.
DATA_PATH = "/home/jovyan/arxiv/arxiv-metadata-oai-snapshot.json"
# Passed to papers() as year_cutoff; the exact filtering rule lives in vecsim_app.data_utils — TODO confirm.
YEAR_CUTOFF = 2012
# NOTE(review): alternation binds loosest, so this matches either a bare "19" or "20" followed by
# two digits. If four-digit 19xx years are intended it would need to be r"((?:19|20)[0-9]{2})" —
# confirm against how papers() consumes the pattern before changing it.
YEAR_PATTERN = r"(19|20[0-9]{2})"

In [None]:
# Materialise the records produced by papers() as a DataFrame, then show the row count.
df = pd.DataFrame(
    papers(
        data_path=DATA_PATH,
        year_cutoff=YEAR_CUTOFF,
        year_pattern=YEAR_PATTERN,
    )
)
len(df)

In [None]:
# Preview the first three rows.
df.head(3)

In [None]:
# Build the model input text: title and abstract joined by a single space.
df['text'] = df['title'] + ' ' + df['abstract']
# NOTE(review): disabled split of comma-separated category strings — presumably 'categories'
# already arrives as a list-like value; confirm and delete this line if so.
# df['categories'] = df['categories'].apply(lambda x: x.split(','))

In [None]:
# Inspect the 'categories' value of the first row.
df.iloc[0].categories

In [None]:
# Split into train and test (80/20). A fixed random_state makes the shuffle
# reproducible, so re-running the notebook trains and evaluates on the same rows.
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)

In [None]:
def get_tokenizer(tokenizer_model):
    """Load a pretrained tokenizer and build a batch-tokenizing helper around it.

    Args:
        tokenizer_model: model identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(tokenize_function, tokenizer)`` tuple. ``tokenize_function`` takes a
        batch with a ``"text"`` entry and tokenizes it with ``padding="max_length"``
        and ``truncation=True``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    return tokenize_function, tokenizer

# Load the bert-base-uncased tokenizer together with its batch-tokenizing helper.
tokenize_function, tokenizer = get_tokenizer('bert-base-uncased')

In [None]:
# Fit the multi-hot label encoder on the training split's 'categories' column.
# NOTE(review): labels that occur only in df_test will not be part of mlb.classes_ —
# confirm that is acceptable for evaluation.
mlb = MultiLabelBinarizer()
mlb.fit(df_train['categories'])

In [None]:
def preprocess_data(examples):
    """Tokenize a batch of texts and attach multi-hot float labels.

    Args:
        examples: batch with a ``"text"`` entry and a ``'categories'`` entry.

    Returns:
        The tokenizer output (padded/truncated to 128 tokens) with an added
        ``"labels"`` entry holding ``mlb.transform`` of the categories as floats.
    """
    batch = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    # Labels are cast to float for the multi-label setup used by the model.
    batch["labels"] = mlb.transform(examples['categories']).astype(float)
    return batch

In [None]:
# Load pretrained bert-base-uncased with a classification head that has one output
# per label class known to the fitted binarizer, configured for multi-label targets.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels=len(mlb.classes_), 
    problem_type="multi_label_classification"
)

In [None]:
# Only the first 100 rows of each split are converted (presumably for a quick trial run —
# raise or remove the limit for a full training run).
train_subset = df_train[['text', 'categories']].head(100)
df_train_hf = Dataset.from_pandas(train_subset)
tokenized_train = df_train_hf.map(preprocess_data, batched=True)

test_subset = df_test[['text', 'categories']].head(100)
df_test_hf = Dataset.from_pandas(test_subset)
tokenized_test = df_test_hf.map(preprocess_data, batched=True)

In [None]:
# Debugging - sanity check: decode the first test example's multi-hot label row back
# into category names and print it next to the original categories for comparison.

print("Reversed", mlb.inverse_transform(np.asarray(tokenized_test[0]['labels']).reshape(1, -1)))
print("Original categories", tokenized_test[0]['categories'])

In [None]:
# Adapted from: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb
batch_size = 16
metric_name = "f1"

# Evaluate and checkpoint once per epoch; at the end, reload the checkpoint
# that scored best on `metric_name`.
args = TrainingArguments(
    output_dir="paper-multilabel-finetuning",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)


# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Compute micro-averaged F1, micro-averaged ROC AUC and accuracy for multi-label logits.

    Args:
        predictions: raw model logits, expected shape (batch_size, num_labels).
        labels: binary indicator matrix with the same shape as ``predictions``.
        threshold: probability at or above which a label counts as predicted.

    Returns:
        dict with keys ``'f1'``, ``'roc_auc'`` and ``'accuracy'``.
    """
    # first, apply sigmoid on the logits to obtain per-label probabilities
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions)).numpy()
    # next, use threshold to turn them into hard 0/1 predictions
    y_pred = (probs >= threshold).astype(float)
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # ROC AUC is a ranking metric: it must be scored on the probabilities.
    # Feeding it thresholded 0/1 values collapses the curve to a single point.
    roc_auc = roc_auc_score(y_true, probs, average='micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    """Metrics callback: unwrap the predictions and delegate to ``multi_label_metrics``.

    When ``p.predictions`` is a tuple, only its first element is scored.
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, tuple):
        raw_predictions = raw_predictions[0]
    return multi_label_metrics(predictions=raw_predictions, labels=p.label_ids)

In [None]:
# Wire the model, training arguments, tokenized splits and metrics callback together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the arguments configured above.
trainer.train()

In [None]:
# Evaluate on the eval dataset and display the resulting metrics dict.
# NOTE(review): the name `eval` shadows Python's builtin eval(); it is kept because the
# cell that writes mlb_metrics.json reads it — rename both together if cleaning up.
eval = trainer.evaluate()
eval

In [None]:
# Run the fine-tuned model on one example text.
text = df['text'].iloc[5]

# Tokenize with the same truncation as training (max 128 tokens); without truncation
# a long title+abstract can exceed BERT's maximum sequence length and fail.
encoding = tokenizer(text, truncation=True, max_length=128, return_tensors="pt")
# Move the input tensors to the device the model lives on.
encoding = {k: v.to(trainer.model.device) for k, v in encoding.items()}

# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = trainer.model(**encoding)
logits = outputs.logits

In [None]:
# Convert logits to per-label probabilities, then mark every label whose
# probability is at least 0.1 as predicted.
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
above_threshold = np.where(probs >= 0.1)
predictions = np.zeros(probs.shape)
predictions[above_threshold] = 1

In [None]:
# Show the input text followed by the category names decoded from the 0/1 prediction vector.
print(text)
print(mlb.inverse_transform(predictions.reshape(1, -1)))

In [None]:
import pickle


# Persist the fitted MultiLabelBinarizer so predictions can be decoded later.
with open('mlb.pkl', 'wb') as handle:
    # pickle.dump writes to the open file; pickle.dumps only returns bytes,
    # which previously left mlb.pkl empty.
    pickle.dump(mlb, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Store the evaluation metrics as pretty-printed JSON.
with open('mlb_metrics.json', 'w') as metrics_file:
    serialized_metrics = json.dumps(eval, indent=4)
    metrics_file.write(serialized_metrics)