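# PLM testing

Evaluates the pretrained language model (PLM) classifiers (BERT, RoBERTa, DeBERTa, BlueBERT and XLNet) on the test split of the PubMed human vs. veterinary medicine classification dataset.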

In [None]:
import torch
from torch import cuda


# Run on the GPU whenever CUDA is available; otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

### Load models and test set

In [None]:
# One row per model: (notebook key, Kaggle model slug, checkpoint file name
# expected in the models directory once the Kaggle archive is unpacked).
_MODEL_SPECS = [
    ("bert", "bert-base-uncased-pubmed", "bert-base-uncased.pt"),
    ("roberta", "roberta-base-pubmed", "roberta-base.pt"),
    ("deberta", "deberta-base-pubmed", "deberta-base.pt"),
    ("bluebert", "bluebert-large-pubmed", "bluebert_pubmed_uncased_L-24_H-1024_A-16.pt"),
    ("xlnet", "xlnet-large-pubmed", "xlnet-large-cased.pt"),
]

model_names = {
    key: {"slug": slug, "file_name": file_name}
    for key, slug, file_name in _MODEL_SPECS
}

In [None]:
import kaggle
import os

# The classifier classes are not referenced directly below; presumably they
# must be importable so torch.load can unpickle the saved model objects.
from Source_code.z_utils.BERTClassifier import BERTClassifier
from Source_code.z_utils.RoBERTaClassifier import RoBERTaClassifier
from Source_code.z_utils.DeBERTaClassifier import DeBERTaClassifier
from Source_code.z_utils.BlueBERTClassifier import BlueBERTClassifier
from Source_code.z_utils.XLNetClassifier import XLNetClassifier


# Download each model from Kaggle unless its checkpoint is already cached
# locally, then load it into `models` keyed by the same keys as `model_names`.
kaggle.api.authenticate()
models_path = "./models/"
if not os.path.exists(models_path):
    os.makedirs(models_path)
    print(f"Directory created: {models_path}")
models = {}

for model_key, spec in model_names.items():
    target_path = os.path.join(models_path, spec["file_name"])

    if not os.path.exists(target_path):
        slug = spec["slug"]
        kaggle.api.model_instance_version_download_cli(f"marcelhiltner/{slug}/pytorch/{slug}/1", models_path, untar=True)
        # Fail here with a clear message rather than inside torch.load.
        if not os.path.exists(target_path):
            raise FileNotFoundError(f"Expected {target_path} after downloading {slug}, but it is missing.")

    # The .pt files hold whole pickled model objects (not state dicts), so
    # weights_only=False is required; torch >= 2.6 defaults it to True and
    # would refuse them. This unpickles arbitrary objects, so only load
    # checkpoints from a trusted source.
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines
    # and avoids putting every model on the GPU at once; the evaluation loop
    # moves each model to `device` when it is used.
    models[model_key] = torch.load(target_path, map_location="cpu", weights_only=False)
    models[model_key].eval()
    print(f"{model_key} loaded.")

In [None]:
import kaggle
import os
import zipfile


# Fetch the test split from Kaggle into ./datasets/, reusing a local copy when
# one is already present (the same caching approach as the model download).
kaggle.api.authenticate()
data_path = "./datasets/"
if not os.path.exists(data_path):
    os.makedirs(data_path)
    print(f"Directory created: {data_path}")

test_path = os.path.join(data_path, "test.json")
if not os.path.exists(test_path):
    kaggle.api.dataset_download_file('marcelhiltner/pubmed-human-veterinary-medicine-classification', file_name="test.json", path=data_path)
    # The file may arrive zip-compressed as test.json.zip; only unpack (and
    # delete the archive) when that archive actually exists.
    zip_path = f"{test_path}.zip"
    if os.path.exists(zip_path):
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(data_path)
        os.remove(zip_path)
print(os.listdir(data_path))

In [None]:
import pandas as pd


# Load the test split from data_path/test.json (one JSON record per sample).
try:
    test_set = pd.read_json(f"{data_path}test.json", orient="records")
except Exception as e:
    print(f"An error occurred: {e}")
    # Re-raise: carrying on without `test_set` only produces a confusing
    # NameError in the evaluation cell (or silently reuses a stale frame).
    raise
print("Data loaded successfully: test.json")
print(f"Shape: {test_set.shape}")

# The evaluation cell reads these two columns; fail early if either is absent.
missing_columns = {"title_abstract", "labels"} - set(test_set.columns)
if missing_columns:
    raise KeyError(f"test.json is missing required columns: {sorted(missing_columns)}")

### Evaluate on the test set: classification reports and confusion matrices

In [None]:
from tqdm import tqdm
import pickle
from sklearn.metrics import ConfusionMatrixDisplay, classification_report
from matplotlib import pyplot as plt
import time
import datetime
import numpy as np

from Source_code.z_utils.predict import predict
from Source_code.z_utils.data_preparing import get_dataloader
# Explicit names instead of a star import, so readers can see where these come
# from and no other module-level name (e.g. `device`) can be silently shadowed.
from Source_code.z_utils.global_constants import LABELS_MAP, MAX_LEN, TEST_BATCH_SIZE


# For each loaded model:
# - time inference on the test set and print a classification report;
# - write the misclassified abstracts and the pickled report to disk;
# - save a confusion-matrix PDF.
for model_key, model in models.items():
    # Checkpoint name after its first "/" (the whole string if there is none);
    # used to name the output files.
    checkpoint_name = model.checkpoint.split("/", 1)[-1]

    print("=" * 30)
    print(f'Model {model.checkpoint}')
    print("=" * 30)

    model.to(device)
    model.eval()

    time0 = time.monotonic_ns()
    # cls report
    dataloader = get_dataloader(test_set.title_abstract, test_set.labels, model.tokenizer, TEST_BATCH_SIZE, MAX_LEN)
    # The labels returned by predict() are unused; the ground truth is read
    # from test_set directly below.
    test_preds, _true_labels = predict(model, dataloader=dataloader, device=device)
    time1 = time.monotonic_ns()
    # monotonic_ns() is in nanoseconds; timedelta expects microseconds.
    print(datetime.timedelta(microseconds=(time1 - time0) / 1000))

    # Release GPU memory before the next (possibly large) model is evaluated.
    model.to("cpu")
    if device == "cuda":
        torch.cuda.empty_cache()

    # Predicted class index per sample. Assumes each element of test_preds is a
    # (batch_size, n_classes) score array (TODO confirm against predict()).
    preds_labels = np.concatenate(test_preds).argmax(axis=1)
    # NOTE(review): assumes LABELS_MAP's key order matches label indices 0, 1, ...
    report = classification_report(test_set["labels"], preds_labels, target_names=list(LABELS_MAP.keys()))

    test_set.title_abstract[test_set.labels != preds_labels].to_json(f"test_false_predictions_{checkpoint_name}")
    with open(f"test_report_{checkpoint_name}", 'wb') as f:
        pickle.dump(report, f)

    print(report)

    # confusion matrix
    labels = list(LABELS_MAP.keys())
    # Binary mapping: index 0 -> first class name, any other index -> second.
    test_classes = [labels[0] if label == 0 else labels[1] for label in test_set.labels]
    preds_classes = [labels[0] if label == 0 else labels[1] for label in preds_labels]
    disp = ConfusionMatrixDisplay.from_predictions(test_classes, preds_classes, labels=labels, normalize=None, cmap=plt.cm.Blues)
    disp.ax_.set_title("Confusion matrix")
    # Save and close this display's own figure rather than pyplot's implicit
    # "current figure".
    disp.figure_.savefig(f"confusion_matrix_{checkpoint_name}.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    plt.close(disp.figure_)