# CoV Classifier Inference

Runs inference with the CoV classifiers for AttCAT systematic data selection (Fig. 5D).

In [None]:
from transformers import (
    EsmTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

import torch
import pandas as pd
import numpy as np

import datasets
from datasets import (
    Dataset,
    DatasetDict,
    Sequence,
    Value,
    ClassLabel,
    load_dataset,
)

from tqdm.notebook import tqdm

In [None]:
# use the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selected device in the cell output

In [None]:
# Registry keyed by iteration number (0-4): for each iteration, the paths of
# the three classifier checkpoints and the path of the test split they are
# evaluated on.
# NOTE: replace with actual paths to models/test data
cov_dict = {
    itr: dict(
        models=dict(
            uniform_250k=f"../models/uniform-250k_itr{itr}_50ep_HD-CoV",
            uniform_350k=f"../models/uniform-350k_itr{itr}_50ep_HD-CoV",
            preferential_250k=f"../models/preferential-250k_itr{itr}_50ep_HD-CoV",
        ),
        data=f'./train-test_splits/E_hd-0_cov-1_test{itr}.csv',
    )
    for itr in range(5)
}

cov_dict

In [None]:
# Run inference with every model of every iteration on that iteration's test
# split (sequence pairs that exceed the length limit are skipped) and write
# one CSV of per-sequence predictions per iteration.
import os

# Longest combined heavy+light chain length (in residues) that is scored.
# The models have a max length of 320 from training, which must also hold the
# start token, the separator (2 tokens) and the end token.
# NOTE(review): 320 - 4 special tokens would allow 316 residues; 315 is kept
# unchanged from the original filter -- confirm against the training setup.
MAX_PAIR_RESIDUES = 315
MAX_TOKENS = 320

# label schema shared by every split (loop-invariant, so built once)
class_labels = ClassLabel(names=['Healthy-donor', 'CoV-specific'])
n_classes = len(class_labels.names)
label2id = {"Healthy-donor": 0, "CoV-specific": 1}
id2label = {0: "Healthy-donor", 1: "CoV-specific"}

# (prediction, label) -> category; every other combination is reported as
# "false_negative", mirroring a plain if/elif/else chain.
_CATEGORY_BY_OUTCOME = {
    (1, 1): "true_positive",
    (0, 0): "true_negative",
    (1, 0): "false_positive",
}


def filter_long_sequences(item, max_residues=MAX_PAIR_RESIDUES):
    """Return True when the paired sequence is short enough to be scored.

    Args:
        item: Mapping with string values under "h_sequence" and "l_sequence".
        max_residues: Largest allowed combined length of the two sequences.
    """
    return (len(item['h_sequence']) + len(item['l_sequence'])) <= max_residues


def preprocess_dataset(
    batch,
    tokenizer=None,
    tokenizer_path="./tokenizer",
    separator="<cls><cls>",
    max_len=320,
) -> dict:
    """Tokenize a batch of paired heavy/light chain sequences.

    Each pair is joined as ``h_sequence + separator + l_sequence`` and padded
    to ``max_len`` tokens. The resulting ``input_ids`` and ``attention_mask``
    are stored on ``batch``, which is modified in place and returned.

    Args:
        batch: Batched rows providing "h_sequence" and "l_sequence" lists.
        tokenizer: Callable tokenizer applied to the joined sequences.
        tokenizer_path: Unused; retained so existing callers keep working.
        separator: Text inserted between the heavy and light chain.
        max_len: Length every tokenized sequence is padded to.

    Returns:
        The same ``batch`` with "input_ids" and "attention_mask" added.

    Raises:
        ValueError: If no tokenizer is supplied.
    """
    if tokenizer is None:
        raise ValueError("preprocess_dataset requires a tokenizer (pass it via fn_kwargs)")

    # tokenize the H/L sequence pair
    sequences = [h + separator + l for h, l in zip(batch["h_sequence"], batch["l_sequence"])]
    tokenized = tokenizer(sequences, padding="max_length", max_length=max_len)
    batch["input_ids"] = tokenized.input_ids
    batch["attention_mask"] = tokenized.attention_mask

    return batch


def _categorize(pred, label):
    """Map one (prediction, label) pair to its category name."""
    return _CATEGORY_BY_OUTCOME.get((int(pred), int(label)), "false_negative")


def _unanimous(row, columns):
    """Return the value shared by all `columns` of `row`, or None if they differ."""
    first = row[columns[0]]
    return first if all(row[col] == first for col in columns[1:]) else None


# the tokenizer does not depend on the iteration, so load it once
tokenizer = EsmTokenizer.from_pretrained("../tokenizer/vocab.txt")

# create the output directory up front so a missing folder cannot discard
# the results after inference has already run
os.makedirs("./results", exist_ok=True)

for i, split in tqdm(cov_dict.items()):

    # load test data
    test_data = pd.read_csv(split["data"])

    # Drop over-length pairs from the DataFrame itself (not only from the
    # huggingface dataset) so that the rows in `test_preds` stay in one-to-one
    # order with the predictions concatenated onto them below.
    keep = np.array(
        [
            filter_long_sequences({"h_sequence": h, "l_sequence": l})
            for h, l in zip(test_data["h_sequence"], test_data["l_sequence"])
        ],
        dtype=bool,
    )
    n_dropped = int((~keep).sum())
    if n_dropped:
        print(f"itr {i}: skipping {n_dropped} sequence pair(s) longer than {MAX_PAIR_RESIDUES} residues")
    # a fresh RangeIndex keeps `Dataset.from_pandas` from adding an index column
    test_data = test_data.loc[keep].reset_index(drop=True)
    test_preds = test_data.copy()  # for storing prediction metrics

    # make huggingface dataset
    dataset = datasets.Dataset.from_pandas(test_data)
    dataset = dataset.cast_column("label", class_labels)

    # tokenize
    tokenized_dataset = dataset.map(
        preprocess_dataset,
        fn_kwargs={
            "tokenizer": tokenizer,
            "max_len": MAX_TOKENS,
        },
        batched=True,
        remove_columns=["name", "h_sequence", "l_sequence"],
    )

    # load each model
    for model_id, model_path in split["models"].items():
        model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

        # predict on test set and get metrics
        trainer = Trainer(
            model=model,
            tokenizer=tokenizer,
            args=TrainingArguments(output_dir="./",
                                   report_to="none"),  # to turn off wandb logging
            eval_dataset=tokenized_dataset,
        )
        logits, labels, metrics = trainer.predict(tokenized_dataset)
        # probability assigned to the last class (index 1, "CoV-specific")
        probabilities = torch.softmax(torch.from_numpy(logits), dim=1).detach().numpy()[:, -1]
        predictions = np.argmax(logits, axis=1)

        # free up memory: the trainer holds its own reference to the model,
        # so both must go before the next checkpoint is loaded
        del trainer, model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # categorize predictions (one record per scored sequence)
        pred_df = pd.DataFrame(
            [
                {
                    f"{model_id}_prediction": pred,
                    f"{model_id}_probability": prob,
                    f"{model_id}_category": _categorize(pred, label),
                    f"{model_id}_logits": logit,
                }
                for pred, prob, label, logit in zip(predictions, probabilities, labels, logits)
            ]
        )

        # the column-wise concat below pairs rows by index, so a length
        # mismatch would silently attach predictions to the wrong sequences
        if len(pred_df) != len(test_preds):
            raise RuntimeError(
                f"{model_id} (itr {i}): got {len(pred_df)} predictions for {len(test_preds)} test sequences"
            )

        # store predictive performance with its corresponding sequence
        test_preds = pd.concat([test_preds, pred_df], axis=1)

    model_ids = list(split["models"])
    prediction_cols = [f"{m}_prediction" for m in model_ids]
    category_cols = [f"{m}_category" for m in model_ids]
    probability_cols = [f"{m}_probability" for m in model_ids]

    # predictions that are consistent across all models (for AttCAT systematic
    # selection); rows where the models disagree get None
    test_preds["consistent_prediction"] = test_preds.apply(_unanimous, axis=1, args=(prediction_cols,))
    test_preds["consistent_category"] = test_preds.apply(_unanimous, axis=1, args=(category_cols,))

    # average prediction probability
    test_preds["avg_probability"] = test_preds.apply(lambda row: row[probability_cols].mean(), axis=1)

    # save as csv
    test_preds.to_csv(f"./results/CoV_predictions_itr{i}.csv", index=False)
    