# Active Fine-tuning of PLMs

## Objectives
- Trial three approaches for active fine-tuning of ESM2 on our Pikh1 HMA data:
    1. Train an ensemble of models and use the mean and variance of their predictions to guide learning.
    2. Keep dropout active during prediction and use the mean and variance across repeated stochastic forward passes.
    3. Mean-variance estimation: fine-tune the model to predict two values (a mean and a variance) with a Gaussian negative log-likelihood loss.
- Compare data efficiency between methods, that is, how many training labels each one needs to reach the same Spearman's ρ on a universal test set.
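
For approach 3, the loss rewards the model for predicting a larger variance where its mean is off. Below is a minimal pure-Python sketch of the per-sample Gaussian negative log-likelihood with the constant term dropped. The function name `gaussian_nll` and the `eps` floor are illustrative and not part of this notebook; for tensors, PyTorch provides `torch.nn.GaussianNLLLoss`.

```python
import math

def gaussian_nll(mu, var, y, eps=1e-6):
    """Per-sample Gaussian NLL, omitting the constant 0.5 * log(2 * pi)."""
    var = max(var, eps)  # floor the variance for numerical stability
    return 0.5 * (math.log(var) + (y - mu) ** 2 / var)

# Same prediction error (0.5), two different predicted variances (toy values)
confident = gaussian_nll(mu=1.0, var=0.01, y=1.5)
uncertain = gaussian_nll(mu=1.0, var=0.25, y=1.5)
print(f"confident: {confident:.3f} | uncertain: {uncertain:.3f}")
```

With the same error, the overconfident prediction is penalized more heavily, which is what should make the predicted variance usable as an acquisition score.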

## Prepare data splits
The starting data comes from FACS of surface-displayed Pikh1 HMA variants exposed to 1 µM AVR-PikC. There are 3960 labels in this dataset; each pairs a full sequence with an enrichment score that roughly correlates with binding affinity. We'll load these into a torch dataset, then split into 80% train, 10% val, and 10% test. The training split is the pool used for all active learning loops, with validation on the val set. The test set is held out as a universal final test set for all models.
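
As a quick sanity check on the split sizes, here is the arithmetic for 3960 labels with the batch size of 12 used below. The variable names are illustrative only.

```python
n_total = 3960
batch_size = 12

train_size = int(0.8 * n_total)
val_size = int(0.1 * n_total)
test_size = n_total - train_size - val_size  # remainder goes to test

# number of batches per evaluation pass (ceiling division)
n_test_batches = -(-test_size // batch_size)

# pool left over once 144 sequences have been labeled
n_pool_batches = -(-(train_size - 144) // batch_size)

print(train_size, val_size, test_size, n_test_batches, n_pool_batches)
```

The 33 test batches and 252 pool batches agree with the `[Testing]` and `[Surveying]` progress bars in the outputs further down.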

In [8]:
import pandas as pd

df = pd.read_csv('avrpikC_full.csv')

df.head()

|   | aa_sequence | enrichment_score |
|---|---|---|
| 0 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV... | 1.468796 |
| 1 | GLKRIIVIKVAREGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.415944 |
| 2 | GLKRIIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.389615 |
| 3 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV... | 1.359651 |
| 4 | GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV... | 1.343857 |


In [9]:
sequences = df.aa_sequence
scores = df.enrichment_score

In [10]:
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class BindingDataset(Dataset):
    def __init__(self, sequences, scores):
        # make sure sequence and scores have the same length
        assert len(sequences) == len(scores), f"Sequences and scores must be of the same length.\nNumber of sequences: {len(sequences)}\nNumber of scores: {len(scores)}"
        self.sequences = sequences
        self.scores = scores
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self,idx):
        sequence = self.sequences[idx]
        label = torch.tensor(self.scores[idx], dtype=torch.float)

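        # tokenize the sequence, padding/truncating to a fixed length so that
        # samples can always be stacked into batches by the default collate function
        tokenized = self.tokenizer(
            sequence,
            padding='max_length',
            truncation=True,
            max_length=80, # 78 residues + 2 extra tokens
            return_tensors='pt'
        )

        # keep input_ids and attention_mask, removing the batch dimension
        inputs = {key: val.squeeze(0) for key, val in tokenized.items()}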

        return inputs, label

In [11]:
from torch.utils.data import random_split, DataLoader
torch.manual_seed(42)
torch.cuda.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else 'cpu'

BATCH_SIZE = 12

full_dataset = BindingDataset(sequences, scores)

# split the data into train, val, and test sets
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
training_pool, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

# define dataloaders for val and test sets, train will be defined later for subsets
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Approach 1: Ensemble Predictions
1. Initialize model 1: pretrained esm2 with a randomized binding head
2. Train on an initial batch of 24 samples (test random and zero-shot predictions).
    - Initial hyperparameters (note that during an actual active learning run, you can't change this in the middle of the campaign)
        - optimizer: AdamW
        - learning rate: 2e-5
        - weight decay: 0.01
        - early stopping
3. Evaluate on validation set
4. Evaluate on the rest of the available training pool
5. Initialize model 2
6. Repeat 2-5 until 5 models have been trained and used to evaluate the pool
7. Calculate variance of model predictions on each sequence in the training pool.
8. Select the next 24 sequences with the highest variance.
9. Reinitialize model 1 with pretrained esm2 and a freshly randomized binding head.
10. Repeat ensemble training, evaluation, and acquisition
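
Steps 7 and 8 above can be sketched without any model at all. This is a toy illustration in plain Python, with made-up predictions from three models over a four-sequence pool; `select_highest_variance` is a hypothetical helper, not the function used later in this notebook. It uses the sample variance, which is also the default behavior of `torch.var`.

```python
from statistics import variance

def select_highest_variance(ensemble_predictions, k):
    """ensemble_predictions: one list of predictions per model, all over the same pool."""
    n_pool = len(ensemble_predictions[0])
    # sample variance across models for each pool sequence
    scores = [
        variance([model_preds[i] for model_preds in ensemble_predictions])
        for i in range(n_pool)
    ]
    # pool positions sorted from most to least disagreement
    ranked = sorted(range(n_pool), key=lambda i: scores[i], reverse=True)
    return ranked[:k], scores

toy_predictions = [
    [0.9, 0.2, 1.4, 0.5],   # model 1
    [1.0, 0.8, 1.2, 0.5],   # model 2
    [1.1, -0.1, 1.6, 0.6],  # model 3
]
picked, scores = select_highest_variance(toy_predictions, k=2)
print(picked)
```

The models disagree most on pool positions 1 and 2, so those would be sent for labeling next.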

- Notes about alternative approaches for the baseline fine-tuning approach:
    - Fine-tuning ESM2 is still under active research, and recent literature suggests that just using the CLS token for mutant effects may not be the best approach
    - [ESM Effect](https://www.biorxiv.org/content/10.1101/2025.02.03.635741v1.full.pdf) details a framework for using the mutant token vectors as the basis for predicting mutant effects.
        - Here, they use the 35M model with the last two layers unfrozen and predict mutant effects from the mutant token vectors, using a prediction head composed of two linear layers. However, their datasets are deep mutational scans, so each variant carries only one mutation. I would have to adapt this for variants with more mutations.
            - To adapt it for a higher number of mutations, I could pool or concatenate the mutant token vectors, or use an attention mechanism.
- For now, I'm going to set up the active training loop using the standard CLS token approach (aka, what the AutoModelForSequenceClassification class does from the transformers library), but I will need to separately experiment with other approaches.
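
To make the "pool" option from the notes above concrete, here is a rough sketch in plain Python lists. It assumes substitutions only (equal-length sequences) and one leading special token before the residues, which is why positions are shifted by `offset=1`. Both helper names and the toy embeddings are hypothetical; nothing here comes from the ESM Effect code.

```python
def mutated_positions(wild_type, variant):
    """0-based residue positions where the variant differs from the wild type."""
    assert len(wild_type) == len(variant), "sketch assumes substitutions only"
    return [i for i, (a, b) in enumerate(zip(wild_type, variant)) if a != b]

def mean_pool_mutant_tokens(token_embeddings, positions, offset=1):
    """Average the embeddings at mutated residues; offset skips the leading special token."""
    if not positions:
        raise ValueError("no mutated positions to pool")
    picked = [token_embeddings[p + offset] for p in positions]
    dim = len(picked[0])
    return [sum(vec[d] for vec in picked) / len(picked) for d in range(dim)]

# toy example: 4 residues -> 6 token vectors (start token, 4 residues, end token)
positions = mutated_positions("GLKQ", "GLRR")
toy_embeddings = [[0, 0], [1, 1], [2, 2], [3, 5], [5, 7], [9, 9]]
pooled = mean_pool_mutant_tokens(toy_embeddings, positions)
print(positions, pooled)
```

The pooled vector would then feed the prediction head in place of the CLS vector. Concatenation or attention would need a fixed or padded number of mutated positions, which is one reason pooling is the simplest starting point.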

In [12]:
from transformers import AutoModelForSequenceClassification
from torch.nn import MSELoss
from torch.optim import AdamW
from tqdm import tqdm
from torchmetrics.regression import SpearmanCorrCoef
from transformers import logging

logging.set_verbosity_error() # suppress initialization warnings

def initialize_HF_ESM2(model_name, learning_rate, weight_decay):
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = 1)
    # loss_fn = MSELoss() : HuggingFace automatically handles the loss
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    spearman = SpearmanCorrCoef()
    return model, optimizer, spearman #,loss_fn

def train_step(model, optimizer, train_dataloader):
    model.to(device)
    model.train()
    total_train_loss = 0
    for inputs, labels in train_dataloader:
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
    
    avg_train_loss = total_train_loss / len(train_dataloader)
    return avg_train_loss

def val_step(model, val_dataloader, spearman):
    model.eval()
    total_val_loss = 0

    all_preds = []
    all_labels = []

    with torch.inference_mode():
        for inputs, labels in val_dataloader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            outputs = model(**inputs, labels=labels)
            preds = outputs.logits.squeeze(-1) # (batch, 1) -> (batch,); stays 1D even for a batch of one, so torch.cat works
            loss = outputs.loss

            total_val_loss += loss.item()

            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    avg_val_loss = total_val_loss / len(val_dataloader)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    spearmanr = spearman(all_preds, all_labels).item()

    return avg_val_loss, spearmanr
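
Since Spearman's ρ is the metric used for early stopping and for the final comparison, it is worth recalling that it is just the Pearson correlation of the ranks. Here is a small pure-Python sketch with toy numbers; the helper names are illustrative, and the notebook itself relies on `torchmetrics` for the actual computation.

```python
def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for j in range(start, end + 1):
            ranks[order[j]] = mean_rank
        start = end + 1
    return ranks

def pearson(x, y):
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = sum((a - mx) ** 2 for a in x) ** 0.5
    sy = sum((b - my) ** 2 for b in y) ** 0.5
    return cov / (sx * sy)

def spearman(x, y):
    return pearson(average_ranks(x), average_ranks(y))

toy_preds = [0.1, 0.4, 0.35, 0.8]
toy_labels = [1.0, 1.2, 1.3, 9.0]
print(spearman(toy_preds, toy_labels))
```

Only the ordering matters: the outlier label of 9.0 has no more influence than any other top-ranked value, which suits enrichment scores that are only roughly correlated with affinity.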

In [13]:
def initialize_and_train_new_model(
        model_name, 
        learning_rate, 
        weight_decay,
        epochs, 
        train_dataloader, 
        val_dataloader, 
        patience=5,
        return_history=False,
        checkpoint_path="best_model.pth"
        ):
    
    model, optimizer, spearman = initialize_HF_ESM2(model_name, learning_rate, weight_decay)

    # initialize variables for early stopping
    best_val_spearman = -1
    epochs_wo_improvement = 0

    # initialize lists to store metrics
    train_loss_history = []
    val_loss_history = []
    spearmanr_history = []

    # main training loop
    for epoch in tqdm(range(epochs), desc="[Training]"):
        train_loss = train_step(model, optimizer, train_dataloader)
        val_loss, spearmanr = val_step(model, val_dataloader, spearman)

        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)
        spearmanr_history.append(spearmanr)

        # early stopping logic
        if spearmanr > best_val_spearman:
            best_val_spearman = spearmanr
            epochs_wo_improvement = 0
            # save the best model for later
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_wo_improvement += 1
        
        if epochs_wo_improvement == patience:
            print(f"Early stopping triggered after {patience} epochs with no improvement.")
            break
        
    print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | SpearmanR: {spearmanr:.4f}')

    # load the best model before output
    model.load_state_dict(torch.load(checkpoint_path))

    if return_history:
        history = {
            'train_loss': train_loss_history,
            'val_loss': val_loss_history,
            'spearmanr': spearmanr_history
        }
        return model, history

    return model
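
The early-stopping logic above can be traced on a plain list of validation scores. This is a simplified sketch with toy values: `early_stopping_epoch` is a hypothetical helper, and it starts from negative infinity rather than the -1 used in the function above.

```python
def early_stopping_epoch(val_scores, patience):
    """Return (stop_epoch, best_epoch), both 0-based, for a higher-is-better metric."""
    best_score = float("-inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            # new best: this is where the checkpoint would be saved
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement == patience:
            return epoch, best_epoch
    return len(val_scores) - 1, best_epoch

toy_scores = [0.10, 0.30, 0.45, 0.44, 0.40, 0.43, 0.41, 0.39, 0.50]
stop_epoch, best_epoch = early_stopping_epoch(toy_scores, patience=5)
print(stop_epoch, best_epoch)
```

Two things follow from this. Training stops before the later, better score of 0.50 is ever seen, so the patience value trades compute against missed improvements. Also, in the function above the summary line printed after the loop reports the last epoch's metrics, while the returned weights come from the best checkpoint.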

In [14]:
import numpy as np
from torch.utils.data import Subset

def acquire_new_batch(dataset, initial_batch_size, batch_size_to_acquire, labeled_indices, unlabeled_indices, acquisition_scores=None):
    # if initial batch, when there are no acquisition scores, select randomly
    if acquisition_scores is None:
        initial_batch_size = min(initial_batch_size, len(unlabeled_indices))
        indices_to_acquire = np.random.choice(unlabeled_indices, size=initial_batch_size, replace=False)
    
    # else select based on top acquisition scores
    else:
        # make sure we don't overshoot samples to acquire if on the final batch
        batch_size_to_acquire = min(batch_size_to_acquire, len(acquisition_scores))
        # get the positions of the top acquisition scores within the current pool
        top_k_indices = acquisition_scores.topk(batch_size_to_acquire).indices
        # map these pool positions back to indices in the original dataset
        indices_to_acquire = unlabeled_indices[top_k_indices.cpu().numpy()]
    
    # update the indices lists
    labeled_indices = np.concatenate([labeled_indices, indices_to_acquire])
    unlabeled_indices = np.setdiff1d(unlabeled_indices, indices_to_acquire, assume_unique=True)
    
    # create new subsets and dataloaders
    train_subset = Subset(dataset, labeled_indices.tolist())
    pool_subset = Subset(dataset, unlabeled_indices.tolist())
    train_dataloader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
    pool_dataloader = DataLoader(pool_subset, batch_size=BATCH_SIZE, shuffle=False)
    
    return train_dataloader, pool_dataloader, labeled_indices, unlabeled_indices

In [15]:
def get_model_predictions(model, pool_dataloader):
    model.eval()
    all_preds = []

    with torch.inference_mode():
        # iterate through pool loader
        for inputs, _ in tqdm(pool_dataloader, desc="[Surveying]"):
            # get model predictions and collect them batch by batch
            inputs = {k: v.to(device) for k, v in inputs.items()}
            # pool labels are treated as unknown here, so no loss is computed
            outputs = model(**inputs)
            preds = outputs.logits.squeeze(-1) # (batch, 1) -> (batch,)
            all_preds.append(preds.cpu())
    
    # concat predictions from all batches for a single prediction tensor
    all_preds = torch.cat(all_preds)
    return all_preds

In [16]:
def train_ensemble(
        n_models, 
        model_name, 
        learning_rate,
        weight_decay,
        epochs,
        train_dataloader, 
        pool_dataloader, 
        val_dataloader,
        patience
        ):
    
    # define list to store predictions as each model is trained then evaluated
    ensemble_predictions = []
    
    for i in range(n_models):
        print(f"\nTraining Model {i+1}...")
        # set a changing manual seed
        torch.manual_seed(i)
        torch.cuda.manual_seed(i)

        # initialize and train a new model
        model = initialize_and_train_new_model(model_name, learning_rate, weight_decay, epochs, train_dataloader, val_dataloader, patience)
        
        # get model predictions on pool dataloader, append to ensemble predictions list
        pool_preds = get_model_predictions(model, pool_dataloader)
        ensemble_predictions.append(pool_preds)

    # stack ensemble predictions to create tensor of shape (n_models, n_unlabeled_samples)
    ensemble_predictions = torch.stack(ensemble_predictions, dim=0)
    print("Ensemble training complete, submitting predictions for next cycle.")
    # return the stacked tensor of ensemble predictions
    return ensemble_predictions

In [17]:
# get acquisition scores (variance) given model predictions
def get_acquisition_scores(ensemble_predictions):
    # calculate variance for each index
    variances = torch.var(ensemble_predictions, dim=0)
    # return a tensor of acquisition scores, one per pool sample
    return variances.squeeze()

In [None]:
from IPython.display import clear_output

# run active learning campaign given num samples, num samples per batch
def run_active_learning_campaign(
        n_samples,
        initial_n_samples,
        n_samples_per_batch,
        model_name, 
        learning_rate,
        weight_decay,
        epochs,
        training_pool, 
        val_dataloader,
        patience,
        n_models
        ):
    # initialize index lists
    total_pool_size = len(training_pool)
    unlabeled_indices = np.arange(total_pool_size)
    labeled_indices = np.array([], dtype=np.int64)

    ensemble_predictions = None
    current_cycle = 1
    
    while len(labeled_indices) < n_samples and len(unlabeled_indices) > 0:
        print(f"\nCycle {current_cycle}/{int(np.ceil((n_samples-initial_n_samples)/n_samples_per_batch)) + 1}\n-------------------------------------------------")

        # on the first cycle, choose random samples of initial_n_samples size
        if ensemble_predictions is None:
            print(f"Choosing initial {initial_n_samples} samples randomly...")
            train_dataloader, pool_dataloader, labeled_indices, unlabeled_indices = acquire_new_batch(
                training_pool, initial_n_samples, n_samples_per_batch, labeled_indices, unlabeled_indices, acquisition_scores=None
            )
        # each other time, use the n_samples_per_batch with acquisition scores to select
        else:
            scores = get_acquisition_scores(ensemble_predictions)
            print(f"Selecting new data points...")
            train_dataloader, pool_dataloader, labeled_indices, unlabeled_indices = acquire_new_batch(
                training_pool, initial_n_samples, n_samples_per_batch, labeled_indices, unlabeled_indices, acquisition_scores=scores
            )
        
        if len(unlabeled_indices) == 0:
            print("Unlabeled pool is empty. Proceeding to final model training.")
            break

        print("Starting ensemble training and pool evaluation...")
        ensemble_predictions = train_ensemble(n_models, model_name, learning_rate, weight_decay, epochs, train_dataloader, pool_dataloader, val_dataloader, patience)

        current_cycle += 1

    print("\nActive learning campaign complete.")
    print(f"Training final model on {len(labeled_indices)} actively selected samples...")
    model, history = initialize_and_train_new_model(model_name, learning_rate, weight_decay, epochs, train_dataloader, val_dataloader, patience, return_history=True)
    
    return model, history
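
To see how many cycles a campaign runs and how large the labeled set gets, the while-loop above can be mirrored with plain counters. `labeled_set_sizes` is a hypothetical helper written for this note; it only follows the bookkeeping, with no training involved.

```python
import math

def labeled_set_sizes(n_samples, initial_n_samples, n_samples_per_batch, pool_size):
    """Labeled-set size after each acquisition, following the campaign's loop condition."""
    sizes = []
    labeled, unlabeled = 0, pool_size
    while labeled < n_samples and unlabeled > 0:
        # first cycle uses the initial batch size, later cycles the per-batch size
        requested = initial_n_samples if labeled == 0 else n_samples_per_batch
        acquired = min(requested, unlabeled)
        labeled += acquired
        unlabeled -= acquired
        sizes.append(labeled)
    return sizes

sizes = labeled_set_sizes(144, 72, 24, pool_size=3168)
expected_cycles = math.ceil((144 - 72) / 24) + 1  # same formula as the cycle print
overshoot = labeled_set_sizes(100, 72, 24, pool_size=3168)
print(sizes, expected_cycles, overshoot)
```

Because whole batches are acquired, a target that is not reachable in whole batches is overshot (100 requested, 120 labeled in the second call). As written, the campaign also trains an ensemble after the last acquisition, even though those predictions are never used for selection.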

In [44]:
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
EPOCHS = 100
PATIENCE = 5
N_MODELS = 5

model_active, history_active = run_active_learning_campaign(
        n_samples = 144,
        initial_n_samples = 72,
        n_samples_per_batch = 24,
        model_name = MODEL_NAME, 
        learning_rate = LEARNING_RATE,
        weight_decay = WEIGHT_DECAY,
        epochs = EPOCHS,
        training_pool = training_pool, 
        val_dataloader = val_dataloader,
        patience = PATIENCE,
        n_models = N_MODELS
        )


Cycle 6/3
-------------------------------------------------
Selecting new data points...
Starting ensemble training and pool evaluation...

Training Model 1...


[Training]:  19%|█▉        | 19/100 [00:09<00:40,  2.01it/s]


Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.0612 | Val Loss: 0.1342 | SpearmanR: 0.5872


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 149.07it/s]



Training Model 2...


[Training]:  11%|█         | 11/100 [00:05<00:44,  2.02it/s]


Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.1146 | Val Loss: 0.2053 | SpearmanR: 0.5294


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 152.78it/s]



Training Model 3...


[Training]:  21%|██        | 21/100 [00:10<00:40,  1.95it/s]


Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.0607 | Val Loss: 0.1454 | SpearmanR: 0.5548


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 152.20it/s]



Training Model 4...


[Training]:  18%|█▊        | 18/100 [00:08<00:39,  2.06it/s]


Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.0705 | Val Loss: 0.1532 | SpearmanR: 0.5010


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 155.94it/s]



Training Model 5...


[Training]:   9%|▉         | 9/100 [00:04<00:45,  1.98it/s]


Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.1030 | Val Loss: 0.2079 | SpearmanR: 0.4205


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 154.81it/s]


Ensemble training complete, submitting predictions for next cycle.

Active learning campaign complete.
Training final model on 144 actively selected samples...


[Training]:  22%|██▏       | 22/100 [00:10<00:38,  2.00it/s]

Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.0614 | Val Loss: 0.1512 | SpearmanR: 0.5417





In [19]:
# run standard fine-tuning procedure given num samples
def run_standard_finetuning(
        n_samples, 
        model_name, 
        learning_rate, 
        weight_decay, 
        epochs, 
        training_pool, 
        val_dataloader,
        patience,
        ):
    # get dataloader of random train data
    random_indices = torch.randperm(len(training_pool))[:n_samples].tolist()
    train_subset = Subset(training_pool, random_indices)
    train_dataloader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)

    # train model
    model, history = initialize_and_train_new_model(model_name, learning_rate, weight_decay, epochs, train_dataloader, val_dataloader, patience, return_history=True)
    return model, history

In [46]:
model_standard, history_standard = run_standard_finetuning(
        n_samples = 144,
        model_name = MODEL_NAME, 
        learning_rate = LEARNING_RATE,
        weight_decay = WEIGHT_DECAY,
        epochs = EPOCHS,
        training_pool = training_pool, 
        val_dataloader = val_dataloader,
        patience = PATIENCE
        )

[Training]:  11%|█         | 11/100 [00:06<00:52,  1.70it/s]

Early stopping triggered after 5 epochs with no improvement.
Train Loss: 0.1754 | Val Loss: 0.1673 | SpearmanR: 0.4616





In [20]:
import torch
from tqdm import tqdm
from torchmetrics.regression import SpearmanCorrCoef, PearsonCorrCoef, MeanSquaredError

def test_model(model, test_dataloader, return_results=False):
    # Initialize metrics
    spearman = SpearmanCorrCoef().to(device)
    pearson = PearsonCorrCoef().to(device)
    mse = MeanSquaredError().to(device)

    model.to(device)
    model.eval()
    
    total_test_loss = 0
    all_preds = []
    all_labels = []

    with torch.inference_mode():
        for inputs, labels in tqdm(test_dataloader, desc="[Testing]"):
            inputs = {k: v.to(device) for k, v in inputs.items()}
            labels = labels.to(device)

            outputs = model(**inputs, labels=labels)
            preds = outputs.logits.squeeze(-1) # (batch, 1) -> (batch,); stays 1D even for a batch of one
            loss = outputs.loss

            total_test_loss += loss.item()

            all_preds.append(preds)
            all_labels.append(labels)

    # Concatenate all predictions and labels from all batches
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Calculate final metrics
    avg_test_loss = total_test_loss / len(test_dataloader)
    spearmanr = spearman(all_preds, all_labels).item()
    pearsonr = pearson(all_preds, all_labels).item()
    final_mse = mse(all_preds, all_labels).item()

    if return_results:
        results = {
            "avg_test_loss": avg_test_loss,
            "spearmanr": spearmanr,
            "pearsonr": pearsonr,
            "final_mse": final_mse
        }
        return results
    else:
        # Print the report
        print(f"Test Loss: {avg_test_loss:.4f}")
        print(f"Spearman's Rho: {spearmanr:.4f}")
        print(f"Pearson's Rho: {pearsonr:.4f}")
        print(f"Mean Squared Error (MSE): {final_mse:.4f}")

In [47]:
test_model(model_active, test_dataloader)

[Testing]: 100%|██████████| 33/33 [00:00<00:00, 76.09it/s]

Test Loss: 0.1396
Spearman's Rho: 0.5594
Pearson's Rho: 0.5615
Mean Squared Error (MSE): 0.1396





In [48]:
test_model(model_standard, test_dataloader)

[Testing]: 100%|██████████| 33/33 [00:00<00:00, 159.47it/s]

Test Loss: 0.1576
Spearman's Rho: 0.5311
Pearson's Rho: 0.4833
Mean Squared Error (MSE): 0.1576





In [49]:
df_active_results = pd.DataFrame(history_active)
df_standard_results = pd.DataFrame(history_standard)

df_active_results

|   | train_loss | val_loss | spearmanr |
|---|---|---|---|
| 0 | 0.292849 | 0.237494 | 0.071654 |
| 1 | 0.260006 | 0.23866 | 0.171154 |
| 2 | 0.256437 | 0.232513 | 0.250838 |
| 3 | 0.24837 | 0.258319 | 0.398959 |
| 4 | 0.253237 | 0.201234 | 0.36116 |
| 5 | 0.216148 | 0.194606 | 0.412462 |
| 6 | 0.195385 | 0.186426 | 0.433561 |
| 7 | 0.155915 | 0.22753 | 0.436336 |
| 8 | 0.150174 | 0.200697 | 0.455136 |
| 9 | 0.172423 | 0.166411 | 0.46465 |


- More notes
    - This simulation plays out in a defined pool of samples. How will this work in the real world, where there are practically infinite options?
        - First, you need to define the limits of the space: probably a cap on the number of mutations per sequence, and perhaps a target distribution of mutation counts in the final set.
        - Boring approach: generate mutants and evaluate the ensemble on them for a defined amount of time, then select the highest-variance mutants from that set.
        - Cool approach: train an adversarial model that learns how to optimally challenge the PLM.
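
The "boring approach" needs a candidate generator that respects the limits of the space. Here is a minimal sketch, assuming substitutions only, a uniform choice of mutation count between 1 and a cap, and a short toy parent sequence. The function name, the cap, and the uniform distribution are all assumptions made for illustration.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"

def propose_mutants(parent, n_mutants, max_mutations, seed=0):
    """Generate unique substitution mutants with 1..max_mutations changes each."""
    rng = random.Random(seed)
    mutants = set()
    while len(mutants) < n_mutants:
        n_mut = rng.randint(1, max_mutations)            # how many positions to change
        positions = rng.sample(range(len(parent)), n_mut)
        seq = list(parent)
        for pos in positions:
            # always substitute with a different residue than the parent's
            choices = [aa for aa in AMINO_ACIDS if aa != parent[pos]]
            seq[pos] = rng.choice(choices)
        mutants.add("".join(seq))
    return sorted(mutants)

parent = "GLKQKIVIKV"  # toy parent, not the full HMA sequence
mutants = propose_mutants(parent, n_mutants=5, max_mutations=3)
print(mutants)
```

The generated candidates would then be scored by the ensemble, keeping the highest-variance ones, in the same way the fixed pool is scored in the simulation.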

Anyways, now that we have the active learning loop working, let's start running some experiments to better understand how different parameters affect performance.

In [50]:
from pathlib import Path
# define the experiment conditions as a list of dicts with the hyperparameters, and 
# two extra entries, one with the parameter currently changing, and the other with 
# the local experiment index

# define a results list (will be a list of dicts)
def run_experiments(experiments, training_pool, val_dataloader, test_dataloader, results_path='active_vs_standard_results.csv'):
    results_path = Path(results_path)

    # Load existing results if the file exists, otherwise start with a fresh DataFrame.
    if results_path.exists():
        all_results_df = pd.read_csv(results_path)
    else:
        all_results_df = pd.DataFrame()

    final_results = []
    # loop through experiment conditions
    for i, exp in enumerate(experiments):
        print(f"\nEXPERIMENT {i}\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        # run active learning campaign, ignore history for now
        params_active = {
            "n_samples": exp["n_samples"],
            "initial_n_samples": exp["initial_n_samples"],
            "n_samples_per_batch": exp["n_samples_per_batch"],
            "model_name": exp["model_name"], 
            "learning_rate": exp["learning_rate"],
            "weight_decay": exp["weight_decay"],
            "epochs": exp["epochs"],
            "patience": exp["patience"],
            "n_models": exp["n_models"]
        }
        model, _ = run_active_learning_campaign(
            **params_active, 
            training_pool=training_pool, 
            val_dataloader=val_dataloader
            )
        # run model on test set, returning results
        results = test_model(model, test_dataloader, return_results=True)
        # add to the results dict the changing var, local experiment idx, the value of the changing var, and training method active
        results = {
            'changing_var': exp['changing_var'],
            'local_exp_idx': exp['local_exp_idx'],
            'value': params_active[exp['changing_var']],
            'training_method': 'active',
            **results
        }
        # append dict to results list
        final_results.append(results)

        # run standard fine-tuning, ignore history for now
        print(f"\nTraining using standard approach, with {exp['n_samples']} randomly selected samples...")
        params_standard = {
            "n_samples": exp["n_samples"],
            "model_name": exp["model_name"], 
            "learning_rate": exp["learning_rate"],
            "weight_decay": exp["weight_decay"],
            "epochs": exp["epochs"],
            "patience": exp["patience"],
        }
        model, _ = run_standard_finetuning(
            **params_standard, 
            training_pool=training_pool, 
            val_dataloader=val_dataloader
            )
        # run model on test set, returning results
        results = test_model(model, test_dataloader, return_results=True)
        # add to the results dict the changing var, local experiment idx, the value of the changing var, and training method standard
        results = {
            'changing_var': exp['changing_var'],
            'local_exp_idx': exp['local_exp_idx'],
            'value': params_active[exp['changing_var']],
            'training_method': 'standard',
            **results
        }
        # append dict to results list
        final_results.append(results)

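        # save to disk each time to save progress; append only this experiment's
        # two rows (active and standard) so earlier rows are not duplicated
        results_df = pd.DataFrame(final_results[-2:])
        all_results_df = pd.concat([all_results_df, results_df], ignore_index=True)
        all_results_df.to_csv(results_path, index=False)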
        print(f"Progress for experiment {i} appended to {results_path}")
    return all_results_df

In [None]:
experiments_initial_n = [
    {
        "changing_var": "n_samples",
        "local_exp_idx": 0,
        "n_samples": 48,
        "initial_n_samples": 24,
        "n_samples_per_batch": 24,
        "model_name": "facebook/esm2_t6_8M_UR50D", 
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epochs": 50,
        "patience": 10,
        "n_models": 5
    },
    {
        "changing_var": "n_samples",
        "local_exp_idx": 1,
        "n_samples": 72,
        "initial_n_samples": 24,
        "n_samples_per_batch": 24,
        "model_name": "facebook/esm2_t6_8M_UR50D", 
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epochs": 50,
        "patience": 10,
        "n_models": 5
    },
    {
        "changing_var": "n_samples",
        "local_exp_idx": 2,
        "n_samples": 96,
        "initial_n_samples": 24,
        "n_samples_per_batch": 24,
        "model_name": "facebook/esm2_t6_8M_UR50D", 
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epochs": 50,
        "patience": 10,
        "n_models": 5
    },
    {
        "changing_var": "n_samples",
        "local_exp_idx": 3,
        "n_samples": 144,
        "initial_n_samples": 24,
        "n_samples_per_batch": 24,
        "model_name": "facebook/esm2_t6_8M_UR50D", 
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epochs": 50,
        "patience": 10,
        "n_models": 5
    },
    {
        "changing_var": "n_samples",
        "local_exp_idx": 4,
        "n_samples": 240,
        "initial_n_samples": 24,
        "n_samples_per_batch": 24,
        "model_name": "facebook/esm2_t6_8M_UR50D", 
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "epochs": 50,
        "patience": 10,
        "n_models": 5
    },
]
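
Since the entries above differ only in `n_samples` and `local_exp_idx`, the same list could be generated from one base config. This is a small sketch; `make_sweep` and `base_config` are hypothetical names, and the values are copied from the list above.

```python
base_config = {
    "initial_n_samples": 24,
    "n_samples_per_batch": 24,
    "model_name": "facebook/esm2_t6_8M_UR50D",
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "epochs": 50,
    "patience": 10,
    "n_models": 5,
}

def make_sweep(base_config, changing_var, values):
    """One experiment dict per value, overriding only the parameter being swept."""
    return [
        {
            "changing_var": changing_var,
            "local_exp_idx": idx,
            **base_config,
            changing_var: value,  # placed last so it overrides any base value
        }
        for idx, value in enumerate(values)
    ]

experiments = make_sweep(base_config, "n_samples", [48, 72, 96, 144, 240])
print(len(experiments), experiments[3]["n_samples"])
```

Sweeping a different parameter then only needs a base config that fixes `n_samples` and a new `changing_var`, which keeps the `changing_var` label and the swept values in sync.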


In [49]:
run_experiments(experiments_initial_n, training_pool, val_dataloader, test_dataloader)


Cycle 4/4
-------------------------------------------------
Selecting new data points...
Starting ensemble training and pool evaluation...

Training Model 1...


[Training]: 100%|██████████| 50/50 [00:26<00:00,  1.87it/s]


Train Loss: 0.0043 | Val Loss: 0.1255 | SpearmanR: 0.6381


[Surveying]: 100%|██████████| 252/252 [00:00<00:00, 1658.27it/s]



Training Model 2...


[Training]:  76%|███████▌  | 38/50 [00:21<00:06,  1.80it/s]


Early stopping triggered after 10 epochs with no improvement.
Train Loss: 0.0178 | Val Loss: 0.1302 | SpearmanR: 0.6569


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 126.51it/s]



Training Model 3...


[Training]: 100%|██████████| 50/50 [00:24<00:00,  2.00it/s]


Train Loss: 0.0086 | Val Loss: 0.1421 | SpearmanR: 0.6605


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 131.49it/s]



Training Model 4...


[Training]:  70%|███████   | 35/50 [00:17<00:07,  1.99it/s]


Early stopping triggered after 10 epochs with no improvement.
Train Loss: 0.0259 | Val Loss: 0.1164 | SpearmanR: 0.6514


[Surveying]: 100%|██████████| 252/252 [00:01<00:00, 126.84it/s]



Training Model 5...


[Training]:  74%|███████▍  | 37/50 [00:21<00:07,  1.73it/s]


Early stopping triggered after 10 epochs with no improvement.
Train Loss: 0.0157 | Val Loss: 0.1068 | SpearmanR: 0.6706


[Surveying]: 100%|██████████| 252/252 [00:02<00:00, 116.26it/s]


Ensemble training complete, submitting predictions for next cycle.

Active learning campaign complete.
Training final model on 144 actively selected samples...


[Training]:  78%|███████▊  | 39/50 [00:22<00:06,  1.75it/s]


Early stopping triggered after 10 epochs with no improvement.
Train Loss: 0.0182 | Val Loss: 0.1454 | SpearmanR: 0.6433


[Testing]: 100%|██████████| 33/33 [00:00<00:00, 116.89it/s]



Training using standard approach, with 24 randomly selected samples...


[Training]:  86%|████████▌ | 43/50 [00:23<00:03,  1.85it/s]


Early stopping triggered after 10 epochs with no improvement.
Train Loss: 0.0232 | Val Loss: 0.1222 | SpearmanR: 0.6210


[Testing]: 100%|██████████| 33/33 [00:00<00:00, 122.94it/s]

Progress for experiment 1 appended to active_vs_standard_results.csv





|   | changing_var | local_exp_idx | value | training_method | avg_test_loss | spearmanr | pearsonr | final_mse |
|---|---|---|---|---|---|---|---|---|
| 0 | initial_n_samples | 1 | 48 | active | 0.152872 | 0.615488 | 0.619794 | 0.152872 |
| 1 | initial_n_samples | 1 | 48 | standard | 0.145488 | 0.505765 | 0.539804 | 0.145488 |
| 2 | initial_n_samples | 2 | 72 | active | 0.114177 | 0.666418 | 0.684209 | 0.114177 |
| 3 | initial_n_samples | 2 | 72 | standard | 0.11918 | 0.631235 | 0.642157 | 0.11918 |
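
To read the table as a data-efficiency comparison, the test-set Spearman's ρ can be paired up by sample count. The sketch below hard-codes the four rows reported above. Each row is a single run, so the differences are only indicative until the experiment is repeated across seeds.

```python
# (training_method, value, spearmanr) copied from the results table above
rows = [
    ("active", 48, 0.615488),
    ("standard", 48, 0.505765),
    ("active", 72, 0.666418),
    ("standard", 72, 0.631235),
]

# group the two methods under each sample count
by_value = {}
for method, value, rho in rows:
    by_value.setdefault(value, {})[method] = rho

# active minus standard, per sample count
deltas = {
    value: round(methods["active"] - methods["standard"], 6)
    for value, methods in by_value.items()
}
for value, delta in deltas.items():
    print(f"{value} samples: active - standard = {delta:+.6f}")
```

In these two runs the active model scores higher at both sample counts, with a smaller gap at 72 samples than at 48.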
