### HD vs CoV classification task

To adapt this code to the **Flu vs CoV** classification task, you'll need to change the dataset, the class labels, and the batch size (to 8 total). 

In [None]:
import warnings
warnings.simplefilter('ignore')


from transformers import (
    RobertaTokenizer, 
    AutoTokenizer,
    AutoModelForSequenceClassification, 
    Trainer,
    TrainingArguments,
)

import torch

import pandas as pd
import numpy as np

from datasets import (
    DatasetDict,
    ClassLabel,
    load_dataset,
)

import sklearn as skl
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    matthews_corrcoef,
    roc_auc_score,
    average_precision_score
)

import evaluate
accuracy = evaluate.load("accuracy")
from datetime import date
from random import randint

import wandb

In [None]:
# Path to the pretrained model checkpoint that each fine-tuning run below is
# initialised from. Replace with the actual model path on your machine.
checkpoint = './BALM-paired/'

In [None]:
# Tokenizer shared by the preprocessing step and the Trainer, loaded from a
# local directory.
# NOTE(review): preprocess_dataset() defaults to "./tokenizer" while this cell
# loads "../tokenizer/" -- confirm which location is the intended one.
tokenizer = RobertaTokenizer.from_pretrained("../tokenizer/")

### Process and Tokenize Data

In [None]:
# Class names in label-id order (index 0 = healthy donor, index 1 = SARS-specific).
label_names = ['Healthy-donor', 'Sars-specific']
class_labels = ClassLabel(names=label_names)
# Number of output classes for the classification head.
n_classes = len(label_names)

In [None]:
# The dataset provided on Zenodo is the full dataset (not split into train/test),
# so you'll need to create your own train/test split(s) before running the cells below.

In [None]:
# Load the five pre-computed train/test splits (one DatasetDict per split index).
itr_datasets = [
    load_dataset(
        'csv',
        data_files=DatasetDict({
            'train': f'./datasets/HD-CoV/hd-0_cov-1_train{i}.csv',
            'test': f'./datasets/HD-CoV/hd-0_cov-1_test{i}.csv',
        }),
    )
    for i in range(5)
]

### Tokenizer

In [None]:
def preprocess_dataset(
    batch,
    tokenizer=None,
    tokenizer_path="./tokenizer",
    separator="</s>",
    max_len=320,
) -> dict:
    """
    Tokenize paired heavy/light chain sequences for a batch of examples.

    Each heavy-chain sequence is joined to its light-chain partner with
    ``separator`` and the joined string is tokenized, padded to ``max_len``
    and truncated to ``max_len``.

    Parameters
    ----------
    batch
        Mapping with ``"h_sequence"`` and ``"l_sequence"`` entries, each an
        iterable of strings (paired element-wise).
    tokenizer
        Callable tokenizer. If ``None``, one is loaded from ``tokenizer_path``
        with ``AutoTokenizer.from_pretrained``.
    tokenizer_path
        Location used to load a tokenizer when ``tokenizer`` is ``None``.
    separator
        String inserted between the heavy and light chain sequences.
    max_len
        Length that every tokenized sequence is padded/truncated to.

    Returns
    -------
    The same ``batch`` mapping, with ``"input_ids"`` and ``"attention_mask"``
    entries added (the input is updated in place).
    """
    # Only load a tokenizer from disk when the caller did not supply one.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, max_len=max_len)

    # Build one "<heavy><separator><light>" string per antibody pair.
    paired_sequences = [
        heavy + separator + light
        for heavy, light in zip(batch["h_sequence"], batch["l_sequence"])
    ]
    encoded = tokenizer(
        paired_sequences,
        padding="max_length",
        max_length=max_len,
        truncation=True,
    )
    batch["input_ids"] = encoded.input_ids
    batch["attention_mask"] = encoded.attention_mask

    return batch

In [None]:
# Tokenize every split; the raw text columns are dropped once they have been
# converted to input_ids / attention_mask.
tokenized = [
    dataset.map(
        preprocess_dataset,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_len": 320,
        },
        batched=True,
        remove_columns=["name", "h_sequence", "l_sequence"],
    )
    for dataset in itr_datasets
]

### Load Model

In [None]:
# Integer id -> class name, and the inverse mapping handed to the model config.
id2label = dict(enumerate(["Healthy-donor", "Sars-specific"]))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Fig 5 presents accuracy, f1, auc, aupr, and mcc
def compute_metrics(eval_pred):
    """
    Compute binary-classification metrics from model outputs.

    ``eval_pred`` unpacks into ``(predictions, labels)``, where ``predictions``
    is a NumPy array of per-class logits (presumably shaped
    ``(n_examples, n_classes)`` -- it is softmaxed along dim 1). The
    probability of the last class is used as the score for the ranking
    metrics (auc, aupr); the argmax class is used for the thresholded
    metrics, with class 1 as the positive label.

    Returns a dict with keys: accuracy, precision, recall, f1, auc, aupr, mcc.
    """
    logits, labels = eval_pred

    # Probability assigned to the last (positive) class, for AUC / AUPR.
    class_probs = torch.softmax(torch.from_numpy(logits), dim=1).detach().numpy()
    positive_probs = class_probs[:, -1]

    # Hard class decisions, for the thresholded metrics.
    predicted = np.argmax(logits, axis=1)

    scores = {
        "accuracy": accuracy.compute(predictions=predicted, references=labels)["accuracy"],
    }
    scores["precision"] = precision_score(labels, predicted, pos_label=1)
    scores["recall"] = recall_score(labels, predicted, pos_label=1)
    scores["f1"] = f1_score(labels, predicted, pos_label=1)
    scores["auc"] = roc_auc_score(labels, positive_probs)
    scores["aupr"] = average_precision_score(labels, positive_probs, pos_label=1)
    scores["mcc"] = matthews_corrcoef(labels, predicted)
    return scores

In [None]:
# Empty results table: one row per train/test split will be added after each run.
result_columns = [
    "itr",
    "test_loss",
    "test_accuracy",
    "test_precision",
    "test_recall",
    "test_f1",
    "test_auc",
    "test_aupr",
    "test_mcc",
]
test_results = pd.DataFrame({column: [] for column in result_columns})

In [None]:
# Hyperparameters shared by every split (hoisted out of the loop: they never change).
batch_size = 32 # on 1 gpu (ie. total batch size should equal 32)
lr = 5e-5

# Fine-tune a fresh copy of the pretrained checkpoint on each train/test split,
# then record that split's test metrics as one row of `test_results`.
for n, dataset in enumerate(tokenized):
    itr = n
    run_name = f"BALM-paired_HD-CoV_itr-{itr}_{date.today().isoformat()}"
    
    # Reload from the checkpoint every iteration so no split starts from
    # weights that were already fine-tuned on a previous split.
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint, 
        num_labels=n_classes,
        label2id=label2id,
        id2label=id2label,
    )
    # use this to freeze the base model weights + train only classification head
    # for param in model.base_model.parameters():
    #     param.requires_grad = False
    
    # NOTE(review): recent transformers releases rename `evaluation_strategy`
    # to `eval_strategy` -- adjust to match the installed version.
    training_args = TrainingArguments(
        evaluation_strategy = "steps",
        logging_steps=10,
        save_strategy="no",
        eval_steps=10,
        learning_rate=lr,
        per_device_train_batch_size=batch_size, 
        per_device_eval_batch_size=batch_size, 
        num_train_epochs=1,
        warmup_ratio=0.1,
        lr_scheduler_type='linear',

        output_dir=f"./checkpoints/{run_name}",
        seed=randint(0, 1024),  # a different random seed for every run
        report_to="wandb",
        logging_dir=f"./logs/{run_name}",
        logging_first_step=True,
        run_name = run_name
    )
    
    wandb.init(
        project = 'specificity-class',
        group="HD-CoV",
        job_type="BALM-paired",
        name = run_name,
        dir = './',
    )
    
    # train
    trainer = Trainer(
        model,
        args=training_args,
        tokenizer=tokenizer,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        compute_metrics=compute_metrics
    )
    trainer.train()
    trainer.save_model(f"./models/{run_name}")
    
    # evaluate
    logits, labels, metrics = trainer.predict(dataset['test'])
    # Close the W&B run only after prediction: with report_to="wandb" the
    # Trainer's W&B callback may log the prediction metrics, which needs the
    # run to still be open.
    wandb.finish()

    metrics['itr'] = itr
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported
    # way to add a row and yields the same union of columns.
    test_results = pd.concat(
        [test_results, pd.DataFrame([metrics])],
        ignore_index=True,
    )
    
    del model # delete to ensure untrained model is being trained for each dataset

In [None]:
# Average each metric over the per-split rows only. Any existing 'mean'/'std'
# summary rows are excluded so that re-running this cell does not fold earlier
# summaries into the average.
test_results.loc['mean'] = test_results.drop(index=['mean', 'std'], errors='ignore').mean()

In [None]:
# Standard deviation over the per-split rows only. The 'mean' summary row (and
# any earlier 'std' row) must be excluded, otherwise the mean is counted as an
# extra observation and the reported spread is biased.
test_results.loc['std'] = test_results.drop(index=['mean', 'std'], errors='ignore').std()

In [None]:
from pathlib import Path

# Write the per-split metrics (plus summary rows) to disk. The output folder is
# created first so a missing ./results directory cannot make the write fail
# after all training runs have finished.
results_path = Path('./results/HD-CoV_BALM-paired.csv')
results_path.parent.mkdir(parents=True, exist_ok=True)
test_results.to_csv(results_path)