# Active Fine-tuning of PLMs

## Objectives
- Trial three approaches for active fine-tuning of ESM2 on our Pikh1 HMA data
    1. train an ensemble of models and use the mean and variance of their predictions to guide learning
    2. keep dropout layers active during prediction (Monte Carlo dropout) and take the mean and variance over repeated forward passes
    3. mean-variance estimation (MVE) fine-tuning: predict two values (a mean and a variance) and train with a Gaussian negative log-likelihood loss
- Compare data efficiency between the methods (that is, how many training labels each needs to reach the same Spearman r on a universal test set).
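To make the three uncertainty estimates concrete, here is a minimal toy sketch in PyTorch. It swaps ESM2 for a small MLP on random 8-dimensional features (an assumption purely for illustration), so it only shows where each method's mean and variance come from, not the real fine-tuning setup. One caveat for approach 2: as far as I know, the released ESM2 configs set `hidden_dropout_prob` to 0.0, so MC dropout on the HF model would give zero variance unless dropout is switched on (for example by passing `hidden_dropout_prob=0.1` to `from_pretrained`); check `model.config` before relying on it.

```python
import torch
from torch import nn

torch.manual_seed(0)
# toy stand-ins (assumption): 8-dim features instead of ESM2 embeddings, synthetic scores
X = torch.randn(64, 8)
y = X[:, 0] + 0.1 * torch.randn(64)

def mlp(dropout=0.0):
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(dropout), nn.Linear(16, 1))

# 1. Ensemble: spread of point predictions across independently initialized members
#    (members are left untrained here to keep the sketch short)
ensemble = [mlp() for _ in range(5)]
with torch.no_grad():
    ens_preds = torch.stack([m(X).squeeze(-1) for m in ensemble])  # (5, 64)
ens_mean, ens_var = ens_preds.mean(dim=0), ens_preds.var(dim=0)

# 2. MC dropout: one model, dropout kept active at inference, many stochastic passes
mc_model = mlp(dropout=0.1)
mc_model.train()  # train mode keeps nn.Dropout sampling new masks on every pass
with torch.no_grad():
    mc_preds = torch.stack([mc_model(X).squeeze(-1) for _ in range(20)])  # (20, 64)
mc_mean, mc_var = mc_preds.mean(dim=0), mc_preds.var(dim=0)

# 3. Mean-variance estimation: the head outputs a mean and a log-variance,
#    and training minimizes the Gaussian negative log-likelihood
class MeanVarianceNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        mean, log_var = self.head(self.body(x)).unbind(dim=-1)
        return mean, log_var.exp()  # exp keeps the predicted variance positive

mve = MeanVarianceNet()
optimizer = torch.optim.AdamW(mve.parameters(), lr=1e-2)
gaussian_nll = nn.GaussianNLLLoss()
losses = []
for _ in range(200):
    optimizer.zero_grad()
    mve_mean, mve_var = mve(X)
    loss = gaussian_nll(mve_mean, y, mve_var)  # argument order: input, target, var
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

`nn.GaussianNLLLoss` takes the predicted variance as a third argument; predicting a log-variance and exponentiating is one common way to keep it positive, not the only one.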

## Prepare data splits
Starting data is from FACS of surface-displayed Pikh1 HMA variants exposed to 1 µM AVR-PikC. There are 3960 labels in this dataset, each with a full sequence and an enrichment score that roughly correlates with binding affinity. We'll load these into a torch dataset, then split them into 80% train, 10% val, and 10% test. The training split is the pool for all active learning loops, and models are validated on the val split. The test set will be used as a universal final test set for all models.

In [2]:
import pandas as pd

df = pd.read_csv('avrpikC_full.csv')

df.head()

| Unnamed: 0 | aa_sequence | enrichment_score |
|---|---|---|
| 0 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV... | 1.468796 |
| 1 | GLKRIIVIKVAREGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.415944 |
| 2 | GLKRIIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.389615 |
| 3 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV... | 1.359651 |
| 4 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.343857 |


In [3]:
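# convert to plain lists so BindingDataset indexes by position rather than by pandas label
sequences = df.aa_sequence.tolist()
scores = df.enrichment_score.tolist()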

In [6]:
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class BindingDataset(Dataset):
    def __init__(self, sequences, scores):
        # make sure sequence and scores have the same length
        assert len(sequences) == len(scores), f"Sequences and scores must be of the same length.\nNumber of sequences: {len(sequences)}\nNumber of scores: {len(scores)}"
        self.sequences = sequences
        self.scores = scores
        # ESM2 checkpoints share one vocabulary, so this tokenizer also matches other ESM2 sizes
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self,idx):
        sequence = self.sequences[idx]
        label = torch.tensor(self.scores[idx], dtype=torch.float)

        # tokenize the sequence
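        # pad/truncate to a fixed length so every item has the same shape and batches stack
        tokenized = self.tokenizer(
            sequence,
            max_length=80, # 78 residues + <cls> and <eos> tokens
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # return input_ids and attention_mask, removing the batch dimension
        inputs = {key: val.squeeze(0) for key, val in tokenized.items()}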

        return inputs, label

In [7]:
from torch.utils.data import random_split, DataLoader
torch.manual_seed(42)
torch.cuda.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 16

full_dataset = BindingDataset(sequences, scores)

# split the data into train, val, and test sets
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# define dataloaders for the val and test sets; the train dataloader is built later from labeled subsets of train_dataset
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
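For reference, the split sizes work out to 3168 / 396 / 396 for 3960 labels, with the test split absorbing any rounding remainder. The sketch below reproduces the split on bare indices and passes an explicit `torch.Generator` (my choice, not what the cell above does) so the split stays fixed even if earlier cells consume random numbers. Each returned `Subset` keeps an `.indices` list that maps its positions back to rows of the full dataset, which is what the later "map back to the original dataset" step relies on.

```python
import torch
from torch.utils.data import random_split

n = 3960
train_size = int(0.8 * n)              # 3168
val_size = int(0.1 * n)                # 396
test_size = n - train_size - val_size  # 396 (absorbs any rounding remainder)

# a dedicated generator makes the split reproducible independently of the global RNG
generator = torch.Generator().manual_seed(42)
train_split, val_split, test_split = random_split(
    range(n), [train_size, val_size, test_size], generator=generator
)

# each split is a Subset; .indices maps positions in the split back to rows of the full dataset
print(len(train_split), len(val_split), len(test_split))  # 3168 396 396
```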

## The active learning loop
1. Initialize model 1: pretrained esm2 with a randomized binding head
2. Train on an initial batch of 24 samples (test random and zero-shot predictions).
    - Initial hyperparameters (note that during an actual active learning run, you can't change this in the middle of the campaign)
        - optimizer: AdamW
        - learning rate: 2e-5
        - weight decay: 0.01
        - early stopping
3. Evaluate on validation set
4. Evaluate on the rest of the available training pool
5. Initialize model 2
6. Repeat steps 2-5 until 5 models have been trained and used to evaluate the pool
7. Calculate variance of model predictions on each sequence in the training pool.
8. Select the next 24 sequences with the highest variance.
9. Reinitialize model 1 with pretrained esm2 and a freshly randomized binding head.
10. Repeat ensemble training, evaluation, and acquisition

- Notes about alternative approaches for the baseline fine-tuning approach:
    - Fine-tuning ESM2 is still under active research, and recent literature suggests that just using the CLS token for mutant effects may not be the best approach
    - [ESM Effect](https://www.biorxiv.org/content/10.1101/2025.02.03.635741v1.full.pdf) details a framework for using the mutant token vectors as the basis for predicting mutant effects.
        - They use the 35M model with the last two layers unfrozen and feed the mutant token embeddings into a prediction head composed of two linear layers to predict mutant effects. However, their datasets are deep mutational scans, so each variant carries only one mutation. I would have to adapt this for variants with multiple mutations.
            - to adapt it for a higher number of mutations, I could pool, concatenate, or use an attention mechanism over the mutant token embeddings
- For now, I'm going to set up the active learning loop using the standard CLS token approach (i.e., what the AutoModelForSequenceClassification class from the transformers library does), but I will need to separately experiment with other approaches.
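One hypothetical way to extend the mutant-token idea to multi-mutants is mean pooling over the mutated positions. This sketch is my own illustration, not the ESM Effect architecture: the class name, layer sizes, and ReLU are assumptions. It expects per-residue embeddings such as `last_hidden_state` from `AutoModel.from_pretrained(MODEL_NAME)(**inputs)`, plus a boolean mask built by comparing each variant to wild type. It assumes substitutions only (no indels), so residue `i` sits at token `i + 1` after `<cls>`.

```python
import torch
from torch import nn

class MutantPoolingHead(nn.Module):
    """Hypothetical regression head: mean-pools per-residue embeddings at mutated positions."""

    def __init__(self, hidden_size, inner_size=128):
        super().__init__()
        # two linear layers; sizes and activation are illustrative assumptions
        self.mlp = nn.Sequential(nn.Linear(hidden_size, inner_size), nn.ReLU(), nn.Linear(inner_size, 1))

    def forward(self, hidden_states, mut_mask):
        # hidden_states: (batch, seq_len, hidden_size); mut_mask: (batch, seq_len), True at mutated tokens
        weights = mut_mask.unsqueeze(-1).to(hidden_states.dtype)
        counts = weights.sum(dim=1).clamp(min=1.0)  # avoid dividing by zero for wild-type rows
        pooled = (hidden_states * weights).sum(dim=1) / counts
        return self.mlp(pooled).squeeze(-1)  # (batch,)

def build_mutant_mask(sequences, wild_type):
    """Boolean (batch, len(wild_type) + 2) mask; +2 accounts for the <cls> and <eos> tokens."""
    mask = torch.zeros(len(sequences), len(wild_type) + 2, dtype=torch.bool)
    for row, seq in enumerate(sequences):
        for pos, (aa, wt_aa) in enumerate(zip(seq, wild_type)):
            if aa != wt_aa:
                mask[row, pos + 1] = True  # +1 skips the leading <cls> token
    return mask
```

A wild-type row has no mutated tokens and pools to a zero vector, so its prediction depends only on the head's biases; falling back to the `<cls>` embedding for such rows is one option. Concatenation would need a fixed maximum number of mutations, and attention pooling would replace the masked mean with learned weights.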

In [8]:
from transformers import AutoModelForSequenceClassification
from torch.optim import AdamW
from tqdm import tqdm  # import the function, not the module, so tqdm(...) is callable
from torchmetrics.regression import SpearmanCorrCoef

# the 8M ESM2 checkpoint has 6 layers ("esm2_t12_8M_UR50D" does not exist on the Hub);
# use "facebook/esm2_t12_35M_UR50D" for the 35M model
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
ES_PATIENCE = 8

def initialize_HF_ESM2():
    # num_labels=1 makes HF treat this as regression and compute MSE loss internally when labels are passed
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=1)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    spearman = SpearmanCorrCoef()
    return model, optimizer, spearman

def train_step(model, optimizer, train_dataloader):
    model.to(device)
    model.train()
    total_train_loss = 0
    for inputs, labels in tqdm(train_dataloader, desc=f"[Training]"):
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_dataloader)
    return avg_train_loss

def val_step(model, val_dataloader, spearman):
    model.eval()
    total_val_loss = 0

    all_preds = []
    all_labels = []

    with torch.inference_mode():
        for inputs, labels in tqdm(val_dataloader, desc=f"[Validation]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            outputs = model(**inputs, labels=labels)
            preds = outputs.logits
            loss = outputs.loss

            total_val_loss += loss.item()

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    avg_val_loss = total_val_loss / len(val_dataloader)

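    # concatenate batches into flat 1-D tensors; SpearmanCorrCoef expects tensors, not lists
    all_preds = torch.cat(all_preds).squeeze(-1)
    all_labels = torch.cat(all_labels)
    spearmanr = spearman(all_preds, all_labels).item()
    spearman.reset()  # drop accumulated state so earlier epochs' predictions aren't kept around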

    return avg_val_loss, spearmanr

In [9]:
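def training_loop(epochs, train_dataloader, val_dataloader):
    model, optimizer, spearman = initialize_HF_ESM2()
    best_val_loss, bad_epochs = float('inf'), 0
    for epoch in range(epochs):
        print(f'\nEpoch {epoch + 1}/{epochs}\n----------------------------')
        train_loss = train_step(model, optimizer, train_dataloader)
        val_loss, spearmanr = val_step(model, val_dataloader, spearman)
        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | SpearmanR: {spearmanr:.4f}')
        if val_loss < best_val_loss:  # early stopping on val loss (returns final, not best, weights)
            best_val_loss, bad_epochs = val_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= ES_PATIENCE:
                break
    return model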

In [None]:
# get initial train batch given dataset, num samples
    # get the total number of samples in the training pool
    # list out their indices
    # randomly select the given number of those
    # store unlabeled indices
    # make subsets
    # return train dataloader, pool dataloader, labeled indices, and unlabeled indices

# get new train batch given dataset, acquisition scores, currently labeled, unlabeled indices, num samples
    # get the indices of the top acquisition scores (num of samples)
        # handle when remaining samples < num samples
    # use these to find the indices that map back to the original dataset
    # add these to the currently labeled indices, remove them from unlabeled indices
    # make subsets
    # return train dataloader, pool dataloader, labeled indices, and unlabeled indices
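A minimal sketch of the two batch helpers outlined above, using `torch.utils.data.Subset`. The function names, the `batch_size`/`seed` parameters, and the use of `torch.topk` are my assumptions. `dataset` would be `train_dataset`: positions in `unlabeled_idx` index into it, and since `train_dataset` is itself a `Subset`, `train_dataset.indices[i]` recovers the row in `full_dataset`. The pool loader is left unshuffled so that prediction `i` always belongs to `unlabeled_idx[i]`.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

def get_initial_batch(dataset, num_samples, batch_size=16, seed=0):
    """Randomly pick the first labeled batch from the training pool (hypothetical helper)."""
    generator = torch.Generator().manual_seed(seed)
    # random permutation of pool positions; the first num_samples become labeled
    perm = torch.randperm(len(dataset), generator=generator).tolist()
    labeled_idx, unlabeled_idx = perm[:num_samples], perm[num_samples:]
    train_loader = DataLoader(Subset(dataset, labeled_idx), batch_size=batch_size, shuffle=True)
    # no shuffling, so pool predictions line up with unlabeled_idx
    pool_loader = DataLoader(Subset(dataset, unlabeled_idx), batch_size=batch_size, shuffle=False)
    return train_loader, pool_loader, labeled_idx, unlabeled_idx

def get_new_batch(dataset, acquisition_scores, labeled_idx, unlabeled_idx, num_samples, batch_size=16):
    """Move the top-scoring unlabeled samples into the labeled set (hypothetical helper).

    acquisition_scores[i] is the score for unlabeled_idx[i].
    """
    k = min(num_samples, len(unlabeled_idx))  # handle a nearly exhausted pool
    top_positions = torch.topk(torch.as_tensor(acquisition_scores), k).indices.tolist()
    # map positions in the pool back to indices of `dataset`
    newly_labeled = [unlabeled_idx[p] for p in top_positions]
    chosen = set(newly_labeled)
    labeled_idx = labeled_idx + newly_labeled
    unlabeled_idx = [i for i in unlabeled_idx if i not in chosen]
    train_loader = DataLoader(Subset(dataset, labeled_idx), batch_size=batch_size, shuffle=True)
    pool_loader = DataLoader(Subset(dataset, unlabeled_idx), batch_size=batch_size, shuffle=False)
    return train_loader, pool_loader, labeled_idx, unlabeled_idx
```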

In [None]:
# get model predictions on unlabeled dataset given pool dataloader
    # set the model to eval mode
    # define list to store predictions
    # set inference mode
        # iterate through pool loader
            # get model predictions, append them to list (num batches, batch size)
    # concat predictions from all batches for a single prediction tensor
    # return full prediction tensor

# train ensemble given training subset, pool subset, validation set, and number of models
    # define list to store ensemble predictions
    # for i in number of models
        # set manual seed i
        # initialize model
        # run training loop
        # get model predictions on pool subset, append to list 
    # return list of ensemble predictions
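The prediction and ensemble steps could look like the sketch below. `get_pool_predictions` handles both HF-style dict batches (like those from `BindingDataset`) and plain tensors; `train_ensemble` takes a zero-argument `train_fn` as a hypothetical stand-in for `training_loop`, assuming that function returns the trained model. Reseeding before each member mirrors the "set manual seed i" step, so each member starts from a different randomly initialized head (and, with a shuffled train loader, sees a different data order).

```python
import torch
from torch import nn

def get_pool_predictions(model, pool_loader, device="cpu"):
    """Predict on the unlabeled pool; pool_loader must not shuffle so outputs align with its indices."""
    model.to(device)
    model.eval()
    batch_preds = []
    with torch.inference_mode():
        for inputs, _ in pool_loader:  # pool labels exist in the simulation but are never used
            if isinstance(inputs, dict):  # HF-style batch, e.g. from BindingDataset
                logits = model(**{k: v.to(device) for k, v in inputs.items()}).logits
            else:
                logits = model(inputs.to(device))
            batch_preds.append(logits.squeeze(-1).cpu())  # (batch_size,)
    return torch.cat(batch_preds)  # (pool_size,)

def train_ensemble(train_fn, pool_loader, num_models=5, device="cpu"):
    """Train num_models members via train_fn() and stack their pool predictions."""
    ensemble_preds = []
    for i in range(num_models):
        torch.manual_seed(i)  # different seed per member -> different head initialization
        model = train_fn()
        ensemble_preds.append(get_pool_predictions(model, pool_loader, device))
    return torch.stack(ensemble_preds)  # (num_models, pool_size)
```

In the notebook this might be called as `train_ensemble(lambda: training_loop(EPOCHS, train_loader, val_dataloader), pool_loader, device=device)`, where `EPOCHS`, `train_loader`, and `pool_loader` are placeholder names.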

In [None]:
# get acquisition scores (variance) given model predictions
    # calculate variance for each index
    # return list of acquisition scores
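The acquisition score reduces to a variance along the model axis of the stacked predictions. A minimal sketch, assuming predictions are stacked into a `(num_models, pool_size)` tensor:

```python
import torch

def get_acquisition_scores(ensemble_preds):
    """Per-sample variance across ensemble members.

    ensemble_preds: tensor of shape (num_models, pool_size).
    Returns a tensor of shape (pool_size,).
    """
    # dim=0 is the model axis, so this measures the members' disagreement on each pool sequence
    return ensemble_preds.var(dim=0)

preds = torch.tensor([[1.0, 2.0, 3.0],
                      [1.0, 4.0, 3.0]])
scores = get_acquisition_scores(preds)  # tensor([0., 2., 0.])
```

`Tensor.var` uses the unbiased estimator by default; that only rescales every score by the same factor, so it does not change which sequences land in the top 24.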

In [None]:
# run active learning campaign given num samples, num samples per batch
    # calculate num cycles
    # for each cycle
        # if first cycle
            # get initial train batch
        # else
            # get acquisition scores
            # get new train batch
        # train ensemble, storing predictions
    # train final model using last train dataloader
    # return final model
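Here is a sketch of the campaign driver working on pool indices only, so it runs without ESM2. The driver and `train_ensemble_fn` are hypothetical: `train_ensemble_fn(labeled, unlabeled)` stands in for ensemble training plus pool prediction. One deliberate deviation from the outline: the ensemble is not retrained after the last acquisition, since its predictions would never be used; the final model would then be trained on the returned indices.

```python
import math
import random
import statistics

def run_active_learning(pool_size, num_samples, batch_size, train_ensemble_fn, seed=0):
    """Hypothetical driver returning labeled pool indices in the order they were acquired.

    train_ensemble_fn(labeled, unlabeled) must return one list of predictions per
    ensemble member, each aligned with `unlabeled`.
    """
    rng = random.Random(seed)
    labeled, unlabeled = [], list(range(pool_size))
    ensemble_preds = None
    num_cycles = math.ceil(num_samples / batch_size)
    for cycle in range(num_cycles):
        # the last batch may be smaller than batch_size, and the pool may run out
        k = min(batch_size, num_samples - len(labeled), len(unlabeled))
        if cycle == 0:
            chosen = rng.sample(unlabeled, k)  # initial batch: uniform random
        else:
            # acquisition score: variance of member predictions for each pool sample
            scores = [statistics.pvariance(col) for col in zip(*ensemble_preds)]
            top = sorted(range(len(unlabeled)), key=scores.__getitem__, reverse=True)[:k]
            chosen = [unlabeled[p] for p in top]  # pool positions -> pool indices
        labeled += chosen
        chosen_set = set(chosen)
        unlabeled = [i for i in unlabeled if i not in chosen_set]
        if cycle < num_cycles - 1:  # the final ensemble would never be used for acquisition
            ensemble_preds = train_ensemble_fn(labeled, unlabeled)
    return labeled
```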

In [None]:
# run standard fine-tuning procedure given num samples
    # get dataloader of random train data
    # train model
    # return model
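For the data-efficiency comparison in the objectives, the passive baseline only needs a random subset of the same size, and each method's results can be summarized as the smallest label budget that reaches a target test Spearman r. Both helpers below are hypothetical sketches; `curve` is assumed to map a label count to the test Spearman r measured at that budget.

```python
import random

def random_baseline_indices(pool_size, num_samples, seed=0):
    """Passive baseline: a uniform random subset of training-pool positions."""
    return random.Random(seed).sample(range(pool_size), num_samples)

def labels_needed(curve, target_r):
    """Smallest label budget whose test Spearman r reaches target_r, or None if none does.

    curve: dict mapping number of training labels -> test Spearman r at that budget.
    """
    for num_labels, r in sorted(curve.items()):
        if r >= target_r:
            return num_labels
    return None
```

Comparing `labels_needed` for the baseline and each active method at the same `target_r` gives the data-efficiency gap directly.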

- More notes
    - This simulation plays out in a defined pool of samples. How will this work in the real world, where there are practically infinite options?
        - first, you need to define the limits of the space: probably something like a maximum number of mutations per sequence, and perhaps a target distribution of mutation counts in your final set
        - boring approach: generate mutants and evaluate the ensemble on them for a defined amount of time, then select the highest-variance mutants from that set.
        - cool approach: train an adversarial model that learns how to optimally challenge the PLM
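For the boring approach, candidate generation can start as simply as random substitutions within the defined limits. A hypothetical sketch: the helper name, the uniform draw of the mutation count, and the optional `positions` restriction are all assumptions. The returned sequences would then be scored by the ensemble and ranked by variance, as in the pool simulation.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 canonical residues

def random_mutants(parent, n, max_mutations=3, positions=None, seed=0, max_tries=None):
    """Sample up to n unique variants of `parent` carrying 1..max_mutations substitutions.

    positions: optional allowed positions (0-based), e.g. a binding interface.
    The number of mutations is drawn uniformly here; that distribution is an assumption.
    """
    rng = random.Random(seed)
    positions = list(range(len(parent))) if positions is None else list(positions)
    max_tries = 50 * n if max_tries is None else max_tries
    seen, mutants = {parent}, []
    for _ in range(max_tries):
        if len(mutants) == n:
            break
        k = rng.randint(1, min(max_mutations, len(positions)))
        seq = list(parent)
        for pos in rng.sample(positions, k):
            # always pick a residue different from the parent, so the variant has exactly k mutations
            seq[pos] = rng.choice(AMINO_ACIDS.replace(parent[pos], ""))
        variant = "".join(seq)
        if variant not in seen:
            seen.add(variant)
            mutants.append(variant)
    return mutants
```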