In [None]:
# Checkpoint name handed to every `from_pretrained` call below (encoder and tokenizer);
# it also names the saved-weights file.
MODEL_CHECKPOINT: str = "xlnet-base-cased"

In [None]:
import torch
from transformers import XLNetModel


class XLNetClassifier(torch.nn.Module):
    """Sequence classifier built on a pretrained XLNet encoder.

    The head summarises each sequence by the encoder's hidden state at the last
    position, then applies ``Linear -> Tanh -> Dropout -> Linear`` to produce
    ``num_labels`` logits.

    Args:
        model_checkpoint: Name or path passed to ``XLNetModel.from_pretrained``.
        dropout: Dropout probability of the head. ``None`` (default) uses the
            checkpoint's ``config.summary_last_dropout``; an explicit ``0.0``
            disables dropout.
        num_labels: Number of output logits (default 2, i.e. binary classification).
    """

    def __init__(self, model_checkpoint, dropout=None, num_labels=2):
        super().__init__()
        self.checkpoint = model_checkpoint
        self.xlnet = XLNetModel.from_pretrained(model_checkpoint, num_labels=num_labels)
        hidden_size = self.xlnet.config.hidden_size

        self.linear1 = torch.nn.Linear(hidden_size, hidden_size)
        self.tanh = torch.nn.Tanh()
        # Compare against None explicitly: a falsy ``dropout=0.0`` must mean
        # "no dropout", not "fall back to the checkpoint default".
        if dropout is None:
            dropout = self.xlnet.config.summary_last_dropout
        self.dropout = torch.nn.Dropout(dropout)
        self.linear2 = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return unnormalised class logits, one row of ``num_labels`` values per sequence.

        Args:
            input_ids: Token ids for the batch (assumed ``(batch, seq)`` -- TODO confirm).
            attention_mask: Mask aligned with ``input_ids``.
            token_type_ids: Optional segment ids; ``None`` lets the encoder use its default.
        """
        xlnet_output = self.xlnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Summarise each sequence by its final position.
        # NOTE(review): this is the last *real* token only when padding is on the left,
        # which is the XLNet tokenizer's default -- confirm for other tokenizers.
        last_position_state = xlnet_output.last_hidden_state[:, -1]
        summary = self.tanh(self.linear1(last_position_state))
        return self.linear2(self.dropout(summary))

In [None]:
from torch import cuda
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

In [None]:
from transformers import AutoTokenizer

# Use the tokenizer that ships with the encoder checkpoint so token ids match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_CHECKPOINT)

In [None]:
import os
import pandas as pd

from data_preparing import get_dataloader, split_data
from data_preprocessing import xml_to_df
from global_parameters import MAX_LEN, TEST_BATCH_SIZE, TRAIN_BATCH_SIZE, VAL_BATCH_SIZE


def _list_entries(folder_path):
    """Return the full paths of every entry in ``folder_path``, in sorted name order.

    ``os.listdir`` returns entries in an arbitrary, filesystem-dependent order. Sorting
    makes the file order -- and hence, presumably, the row order ``xml_to_df`` builds
    and the seeded ``sample`` calls below draw from -- reproducible across machines.
    """
    return [os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path))]


def _split_by_case_report(df):
    """Split ``df`` into (case reports, all other text types) using its ``text_types`` column."""
    is_case_report = df["text_types"].apply(lambda text_types: "Case Reports" in text_types).astype(bool)
    return df[is_case_report], df[~is_case_report]


def _sample_rows(df, n_rows, seed=42):
    """Reproducibly draw ``n_rows`` rows without replacement and renumber the index 0..n-1."""
    return df.sample(n=n_rows, random_state=seed).reset_index(drop=True)


def _with_title_abstract(df):
    """Return a copy of ``df`` with a ``title_abstract`` column holding "<title> <abstract>".

    Missing titles/abstracts become empty strings instead of the literal ``"nan"`` /
    ``"None"`` text that ``astype(str)`` would otherwise feed to the tokenizer.
    Using ``assign`` (rather than column assignment) avoids SettingWithCopyWarning
    if the input is a slice of another frame.
    """
    title = df["title"].fillna("").astype(str)
    abstract = df["abstract"].fillna("").astype(str)
    return df.assign(title_abstract=(title + " " + abstract).str.strip())


data_path = "../data-querying/results/"
folder_names = ["human_medical_data/", "veterinary_medical_data/"]
# One list of file paths per folder, in the same order as ``folder_names``.
xml_files = [_list_entries(os.path.join(data_path, folder)) for folder in folder_names]

hum_df, vet_df = xml_to_df(xml_files)

# Balance case reports and other text types: draw the same number of documents from
# each (domain, text type) group. ``sample`` raises if asked for more rows than a group
# holds, so every group is sized to the smallest one.
vet_case_rep_all, vet_jour_art_all = _split_by_case_report(vet_df)
hum_case_rep_all, hum_jour_art_all = _split_by_case_report(hum_df)
max_num = min(len(vet_case_rep_all), len(vet_jour_art_all), len(hum_case_rep_all), len(hum_jour_art_all))

vet_case_rep = _sample_rows(vet_case_rep_all, max_num)
vet_jour_art = _sample_rows(vet_jour_art_all, max_num)
hum_case_rep = _sample_rows(hum_case_rep_all, max_num)
hum_jour_art = _sample_rows(hum_jour_art_all, max_num)
# ignore_index gives each combined frame a unique index instead of two overlapping 0..n-1 ranges.
hum_df_balanced = pd.concat([hum_case_rep, hum_jour_art], ignore_index=True)
vet_df_balanced = pd.concat([vet_case_rep, vet_jour_art], ignore_index=True)

# NOTE(review): the meaning of the third argument (3) is defined by data_preparing.split_data.
train_set, val_set, test_set = (
    _with_title_abstract(split) for split in split_data(hum_df_balanced, vet_df_balanced, 3)
)

train_dataloader = get_dataloader(train_set.title_abstract, train_set.labels, tokenizer, TRAIN_BATCH_SIZE, MAX_LEN)
val_dataloader = get_dataloader(val_set.title_abstract, val_set.labels, tokenizer, VAL_BATCH_SIZE, MAX_LEN)
test_dataloader = get_dataloader(test_set.title_abstract, test_set.labels, tokenizer, TEST_BATCH_SIZE, MAX_LEN)

In [None]:
from transformers import get_linear_schedule_with_warmup
import os

from global_parameters import EPOCHS, LEARNING_RATE, PATH_SAVED_MODELS
from training import train_model
from loss_fn import loss_fn


# Set ``train = False`` to reuse a previously saved model instead of retraining.
train = True

model_filename = f"{MODEL_CHECKPOINT}.bin"
# os.path.join works whether or not PATH_SAVED_MODELS ends with a path separator.
saved_model_path = os.path.join(PATH_SAVED_MODELS, model_filename)

if not train and os.path.isfile(saved_model_path):
    model = XLNetClassifier(MODEL_CHECKPOINT)
    # map_location lets weights saved during a GPU run be loaded on a CPU-only machine
    # (and vice versa). NOTE(review): torch.load unpickles the file -- only load
    # checkpoints you trust; consider weights_only=True on PyTorch versions that support it.
    state_dict = torch.load(saved_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Disable dropout for the inference cells below (harmless if predict() also does this).
    model.eval()
    print(f"{model_filename} loaded.")
else:
    print("No saved model found or forced to train.")

    # Release cached, unused GPU memory left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

    model = XLNetClassifier(MODEL_CHECKPOINT).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Linear decay of the learning rate to 0 with no warmup, assuming train_model
    # steps the scheduler once per training batch.
    total_steps = len(train_dataloader) * EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps,
    )

    history = train_model(model, train_dataloader, val_dataloader, TRAIN_BATCH_SIZE, loss_fn, optimizer, device, scheduler, EPOCHS)
    # Disable dropout for the inference cells below (harmless if predict() also does this).
    model.eval()

In [None]:
import pandas as pd
from training import plot


# Only plot when the training branch above has run in this kernel session
# (at notebook top level, globals() is the cell namespace).
if "history" in globals():
    history_df = pd.DataFrame(history)
    data_list = [history_df["train_acc"], history_df["val_acc"]]
    label_list = ["train accuracy", "validation accuracy"]
    # The last argument appears to be the y-axis range (TODO confirm against training.plot).
    # Keep 0.95 as the lower bound when every accuracy is above it, but extend it
    # downward so lower values (e.g. early epochs) are not clipped out of view.
    # float() also accepts 0-dim tensors, in case train_model records accuracies as tensors.
    lowest_acc = min((float(acc) for series in data_list for acc in series), default=0.95)
    plot(data_list, label_list, "Training history", "Accuracy", "Epoch", [min(0.95, lowest_acc), 1])

In [None]:
from sklearn.metrics import classification_report

from global_parameters import LABELS_MAP
from predict import predict


# Score the held-out texts; this assumes predict() yields one per-class score tensor
# per text, whose argmax is the predicted label id.
test_preds = predict(model, test_set.title_abstract, tokenizer, device)
preds_labels = list(map(lambda scores: torch.argmax(scores).item(), test_preds))

# Per-class precision / recall / F1, with classes named by the LABELS_MAP keys.
report = classification_report(
    test_set.labels,
    preds_labels,
    target_names=list(LABELS_MAP),
)
print(report)

In [None]:
from matplotlib import pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay


labels = list(LABELS_MAP.keys())


def _class_name(label_id):
    """Map a label id to its class name: 0 -> first LABELS_MAP key, anything else -> second."""
    if label_id == 0:
        return labels[0]
    return labels[1]


test_classes = list(map(_class_name, test_set.labels))
preds_classes = list(map(_class_name, preds_labels))

# Row-normalised ("true") confusion matrix: each row shows how one true class was predicted.
disp = ConfusionMatrixDisplay.from_predictions(
    test_classes,
    preds_classes,
    labels=labels,
    normalize="true",
    cmap=plt.cm.Blues,
)
disp.ax_.set_title("Confusion matrix (normalized)")
plt.show()