### Healthy-donor (HD) vs. Flu-specific vs. CoV-specific classification task

In [None]:
import warnings
warnings.simplefilter('ignore')


from transformers import (
    AutoTokenizer,
    RobertaTokenizer,
    AutoModelForSequenceClassification, 
    Trainer,
    TrainingArguments,
)

import torch

import pandas as pd
import numpy as np

import datasets
from datasets import (
    DatasetDict,
    ClassLabel,
    load_dataset,
)

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
)

import evaluate
accuracy = evaluate.load("accuracy")
from datetime import date
import wandb
from random import randint

In [None]:
# Location of the pretrained checkpoint to fine-tune -- point this at the real model path.
checkpoint = "./BALM-paired/"

# Short identifier for the model; embedded in run names and output file names below.
model_str = "BALM-paired"

In [None]:
# Tokenizer shared by every split; loaded from the directory one level above the CWD.
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path="../tokenizer/")

### Process and Tokenize Data

In [None]:
def preprocess_dataset(
    batch,
    tokenizer=None,
    tokenizer_path="./tokenizer",
    separator="</s>",
    max_len=512
) -> dict:
    """
    Tokenize a batch of paired H/L sequences; intended for ``Dataset.map(batched=True)``.

    Each example's ``h_sequence`` and ``l_sequence`` are joined with ``separator``
    and tokenized as a single string, padded and truncated to exactly ``max_len``
    tokens.

    Args:
        batch: Mapping of column name -> list of values. Must contain the
            ``"h_sequence"`` and ``"l_sequence"`` columns.
        tokenizer: A callable Hugging Face tokenizer. If None, one is loaded from
            ``tokenizer_path`` on every call, so pass one explicitly when mapping
            over many batches.
        tokenizer_path: Tokenizer location, used only when ``tokenizer`` is None.
        separator: String inserted between the H and L sequence of each pair.
        max_len: Fixed output length (in tokens) of every tokenized example.

    Returns:
        The same ``batch`` object, with ``"input_ids"`` and ``"attention_mask"``
        added.
    """
    # set up tokenizer if not provided
    if tokenizer is None:
        # `model_max_length` is the supported keyword; `max_len` is its deprecated alias.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, model_max_length=max_len)

    # join each H/L pair into one string so the tokenizer encodes them as one input
    sequences = [h + separator + l for h, l in zip(batch["h_sequence"], batch["l_sequence"])]
    # padding="max_length" on its own never shortens an input; truncation=True makes
    # sure no example ends up longer than max_len tokens.
    tokenized = tokenizer(sequences, padding="max_length", truncation=True, max_length=max_len)
    batch["input_ids"] = tokenized.input_ids
    batch["attention_mask"] = tokenized.attention_mask

    return batch

In [None]:
# Target classes; ClassLabel assigns each name the integer id of its list position.
class_labels = ClassLabel(names=['Healthy-Donor', 'Flu-specific', 'CoV-specific'])
n_classes = class_labels.num_classes  # equals len(class_labels.names)

In [None]:
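# The dataset provided on Zenodo is the full dataset (it is not split into train/test),
# so you need to create your own split(s) first. The next cell expects five splits at
# ./datasets/HD-Flu-CoV/hd-0_flu-1_cov-2_{train,test}{i}.csv for i = 0..4.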

In [None]:
# Load the pre-made train/test splits; each entry is the DatasetDict for one split iteration.
n_splits = 5
data_dir = "./datasets/HD-Flu-CoV"

itr_datasets = []
for i in range(n_splits):
    # `data_files` only maps split names to CSV paths, so a plain dict is the correct
    # container (DatasetDict is meant to hold already-loaded Dataset objects).
    data_files = {
        'train': f'{data_dir}/hd-0_flu-1_cov-2_train{i}.csv',
        'test': f'{data_dir}/hd-0_flu-1_cov-2_test{i}.csv',
    }
    itr_datasets.append(load_dataset('csv', data_files=data_files))

### Tokenize

In [None]:
# Tokenize every split iteration with the shared tokenizer. The raw identifier and
# sequence columns are dropped; any other columns (presumably the label) are kept.
tokenized = [
    split.map(
        preprocess_dataset,
        batched=True,
        fn_kwargs={"tokenizer": tokenizer, "max_len": 320},
        remove_columns=["name", "h_sequence", "l_sequence"],
    )
    for split in itr_datasets
]

### Load Model

In [None]:
# Derive both lookup tables from `class_labels` (the single source of truth for the
# class list) so they cannot drift out of sync with it; ids are list positions.
id2label = dict(enumerate(class_labels.names))
label2id = {name: idx for idx, name in id2label.items()}

Multi-class Metrics:
* https://www.evidentlyai.com/classification-metrics/multi-class-metrics
* https://www.kaggle.com/code/nkitgupta/evaluation-metrics-for-multi-class-classification
* https://discuss.huggingface.co/t/combining-metrics-for-multiclass-predictions-evaluations/21792/11

In [None]:
# Fig 5 presents accuracy, macro-f1, and mcc
def compute_metrics(eval_pred):
    """Compute multi-class classification metrics for the Hugging Face ``Trainer``.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class logits, or a tuple whose first
            element is the logits when the model returns extra outputs.

    Returns:
        dict mapping metric name -> value. ``Trainer`` adds a prefix to each key
        (e.g. ``eval_`` / ``test_``) when reporting.
    """
    logits, labels = eval_pred
    # Some models return (logits, hidden_states, ...); only the logits are needed.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    # zero_division=0 makes sklearn's default (score 0 for a class that is never
    # predicted) explicit. For single-label multi-class data the micro-averaged
    # scores all coincide with accuracy.
    return {
        "accuracy": accuracy_score(labels, preds),
        "macro-precision": precision_score(labels, preds, average='macro', zero_division=0),
        "macro-recall": recall_score(labels, preds, average='macro', zero_division=0),
        "macro-f1": f1_score(labels, preds, average='macro', zero_division=0),
        "micro-precision": precision_score(labels, preds, average='micro', zero_division=0),
        "micro-recall": recall_score(labels, preds, average='micro', zero_division=0),
        "micro-f1": f1_score(labels, preds, average='micro', zero_division=0),
        "mcc": matthews_corrcoef(labels, preds),
    }

In [None]:
# One row per split iteration. Column names mirror the keys returned by
# `trainer.predict`, which prefixes every compute_metrics key with "test_"
# (so the MCC column is "test_mcc", not "mcc").
test_results = pd.DataFrame(
    columns=[
        "itr",
        "test_loss",
        "test_accuracy",
        "test_macro-precision",
        "test_macro-recall",
        "test_macro-f1",
        "test_micro-precision",
        "test_micro-recall",
        "test_micro-f1",
        "test_mcc",
    ],
    dtype=float,
)

In [None]:
# Hyper-parameters shared by every split iteration.
batch_size = 8 # on 1 gpu (ie. total batch size should equal 8)
lr = 5e-5

for itr, dataset in enumerate(tokenized):
    run_name = f"{model_str}_HD-Flu-CoV_itr-{itr}_{date.today().isoformat()}"

    # Re-load the pretrained checkpoint every iteration so each split starts
    # from the same weights rather than the previous iteration's fine-tuned model.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=n_classes,
        label2id=label2id,
        id2label=id2label,
    )
    # use this to freeze the base model weights + train only classification head
    # for param in model.base_model.parameters():
    #     param.requires_grad = False

    training_args = TrainingArguments(
        # NOTE: newer transformers releases rename this argument to `eval_strategy`.
        evaluation_strategy = "steps",
        logging_steps=10,
        save_strategy="no",
        eval_steps=10,
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type='linear',

        output_dir=f"./checkpoints/{run_name}",
        seed=randint(0, 1024),
        report_to="wandb",
        logging_dir=f"./logs/{run_name}",
        logging_first_step=True,
        run_name = run_name
    )

    wandb.init(
        project = 'specificity-class',
        group="HD-Flu-CoV",
        job_type=model_str,
        name = run_name,
        dir = './',
    )

    # train
    trainer = Trainer(
        model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        compute_metrics=compute_metrics
    )
    trainer.train()
    trainer.save_model(f"./models/{run_name}")

    # evaluate -- before closing the wandb run, so anything the Trainer reports
    # during predict() still belongs to this iteration's run.
    logits, labels, metrics = trainer.predict(dataset['test'])
    wandb.finish()

    metrics['itr'] = itr
    # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
    # Extra metric keys (e.g. runtime stats) still become new columns, as before.
    test_results = pd.concat([test_results, pd.DataFrame([metrics])], ignore_index=True)

    # The Trainer also references the model, so delete both, then release cached
    # GPU memory before the next iteration loads a fresh copy (no-op without CUDA).
    del trainer, model
    torch.cuda.empty_cache()

In [None]:
# Summary row: column-wise average over every row currently in the table.
test_results.loc['mean'] = test_results.mean(axis=0)

In [None]:
# Summary row: sample standard deviation across split iterations only. Any existing
# 'mean'/'std' summary rows are excluded so they do not distort the spread.
test_results.loc['std'] = test_results.drop(index=['mean', 'std'], errors='ignore').std()

In [None]:
from pathlib import Path

# Write the per-iteration results plus summary rows. Create the output directory
# first so the write does not fail when ./results does not exist yet.
results_dir = Path('./results')
results_dir.mkdir(parents=True, exist_ok=True)
test_results.to_csv(results_dir / f'HD-Flu-CoV_{model_str}.csv')