# Diagnostic 1: Testing Read Count Relationship to Model Fine-tuning Performance

I have tried several different active learning schemes; none of them improved overall performance, and some made it worse. This suggests that the active learning process introduces some kind of systematic bias. One hypothesis is that active learning drifts toward uncommon sequences with low read counts, which are therefore noisier measurements. To test this, I'll check whether model performance on limited data depends on the read-count threshold used to filter the training pool from which the model draws each time. If read counts alone are inducing this bias, we would expect model performance to improve as the threshold increases.

## Data Processing

First, we have to modify the way we process the sequencing data to output a file that includes the read counts.

In [2]:
import pandas as pd
import numpy as np

pd.set_option('display.width', 1000)

def preprocess_data(df, amplicon_length=99, adapter_5p_length=19, adapter_3p_length=2):
  """Collapse a raw Genewiz Amplicon EZ table into unique receptor sequences with summed reads.

  Args:
    df: raw DataFrame (from the Genewiz Amplicon EZ excel export) with at least
      'IndelLength', 'TargetAA' and 'Reads' columns.
    amplicon_length: required length of 'TargetAA'; rows of any other length are
      treated as truncated and dropped. Defaults to 99.
    adapter_5p_length: residues trimmed from the start of each sequence. Defaults to 19.
    adapter_3p_length: residues trimmed from the end of each sequence. Defaults to 2.
      With the defaults, 78 residues remain.

  Returns:
    DataFrame with exactly two columns: 'TargetAA' (trimmed sequence) and 'Reads'
    (summed over identical sequences). It is sorted by 'Reads' descending and has
    a fresh 0..n-1 index.
  """
  # remove indels
  kept = df[df['IndelLength'] == 0]
  # remove truncated sequences; copy so the trimming below writes to our own frame
  kept = kept[kept['TargetAA'].str.len() == amplicon_length].copy()
  # remove adapters (every remaining string has exactly amplicon_length residues)
  kept['TargetAA'] = kept['TargetAA'].str[adapter_5p_length:amplicon_length - adapter_3p_length]
  # remove X-containing sequences (literal match, not a regex)
  kept = kept[~kept['TargetAA'].str.contains('X', regex=False)]
  # one row per unique sequence with its summed reads
  summed = kept.groupby('TargetAA')['Reads'].sum().reset_index()
  # drop=True: don't leak the pre-sort row positions as an extra 'index' column
  return summed.sort_values('Reads', ascending=False).reset_index(drop=True)

def add_missing_sequences(lib_df, sel_df):
  """Return a copy of lib_df extended with the selected-only sequences at read count 1.

  Sequences that appear in sel_df but not in lib_df are appended to the library
  with 'Reads' = 1. This gives every selected sequence a library count.
  Every row of the result has 'source' = 'library'. Neither input DataFrame is
  modified.

  Args:
    lib_df: processed library DataFrame with 'TargetAA' and 'Reads' columns.
    sel_df: processed selected DataFrame with a 'TargetAA' column.

  Returns:
    New DataFrame with a fresh 0..n-1 index. It holds the lib_df rows (in their
    original order), followed by the missing sequences in order of first
    appearance in sel_df.
  """
  # tag a copy rather than writing a 'source' column into the caller's frame
  lib_tagged = lib_df.assign(source='library')

  selected_seqs = sel_df['TargetAA'].drop_duplicates()
  missing = selected_seqs[~selected_seqs.isin(lib_df['TargetAA'])]
  if missing.empty:
    # nothing to append; skip concatenating an empty frame
    return lib_tagged.reset_index(drop=True)

  missing_sequences = pd.DataFrame({
      'TargetAA': missing,
      'Reads': 1,
      'source': 'library',
  })
  return pd.concat([lib_tagged, missing_sequences], ignore_index=True)

def remove_low_abundance_seq(df, threshold):
  """Return the rows of df whose 'Reads' count is at least threshold.

  The result is an independent copy rather than a filtered slice. Callers can
  therefore add columns to it without triggering pandas' SettingWithCopyWarning.
  """
  return df[df['Reads'] >= threshold].copy()

def process_and_calculate_enrichment(sel_df, lib_df, threshold, pseudo_count=1):
  """Compute a log10 enrichment score per receptor sequence (selected vs. library).

  Args:
    sel_df: raw Amplicon EZ DataFrame for the selected (sorted) population.
    lib_df: raw Amplicon EZ DataFrame for the library.
    threshold: minimum library read count a sequence needs to be kept.
    pseudo_count: added to each sequence's reads and once to each sample's
      total before computing proportions. Defaults to 1.

  Returns:
    DataFrame with columns 'aa_sequence', 'enrichment_score', 'library_reads'
    and 'sorted_reads', sorted by 'enrichment_score' (highest first).
  """
  # preprocess both samples; the library also receives any selected-only
  # sequences (count 1) and is then filtered by read count
  df_sel_processed = preprocess_data(sel_df)
  df_lib_processed = preprocess_data(lib_df)
  df_lib_processed = add_missing_sequences(df_lib_processed, df_sel_processed)
  df_lib_processed = remove_low_abundance_seq(df_lib_processed, threshold)

  # Pseudo-counted read proportions. assign() returns new frames instead of
  # writing a column into the filtered slice above, which would trigger
  # SettingWithCopyWarning. The library total is taken after the threshold filter.
  sel_reads = df_sel_processed['Reads']
  lib_reads = df_lib_processed['Reads']
  df_sel_processed = df_sel_processed.assign(
      proportion_selected=(sel_reads + pseudo_count) / (sel_reads.sum() + pseudo_count))
  df_lib_processed = df_lib_processed.assign(
      proportion_library=(lib_reads + pseudo_count) / (lib_reads.sum() + pseudo_count))

  # inner join (pandas default): only sequences present in both tables are scored
  df_merged = pd.merge(df_sel_processed, df_lib_processed, on='TargetAA', suffixes=('_selected', '_library'))

  # log10 ratio of selected proportion to library proportion
  df_merged['enrichment_score'] = np.log10(df_merged['proportion_selected'] / df_merged['proportion_library'])

  # sort by enrichment and expose the output column names
  df_final = (
      df_merged.sort_values(by='enrichment_score', ascending=False)
      .rename(columns={
          'TargetAA': 'aa_sequence',
          'Reads_selected': 'sorted_reads',
          'Reads_library': 'library_reads',
      })
      [['aa_sequence', 'enrichment_score', 'library_reads', 'sorted_reads']]
  )
  return df_final

In [4]:
import pandas as pd

# Raw Genewiz Amplicon EZ exports: the library and the selected (sorted) population.
df_lib, df_sel = (pd.read_excel(path) for path in ("library1_raw.xlsx", "sort1_C_raw.xlsx"))

In [5]:
# Build the enrichment table with a library read-count threshold of 10 and save it.
# This baseline table is what the experiments below load and subsample.
baseline_threshold = 10
df = process_and_calculate_enrichment(df_sel, df_lib, baseline_threshold)
# plain string literal: the previous f-string had no placeholders
df.to_csv("avrpikC_full_with_read_counts.csv", index=False)

## Preparing Datasets

In [8]:
# input threshold 10 data into torch Dataset, include library read counts as an attribute
import torch
import pandas as pd
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class BindingDataset(Dataset):
  """Torch Dataset pairing tokenized receptor sequences with enrichment scores.

  Args:
    df: DataFrame with 'aa_sequence', 'enrichment_score' and 'library_reads'
      columns (the CSV written by the enrichment step has these).

  Items are fetched by position (0..len-1), which is what random_split and
  Subset pass in. The DataFrame's index labels therefore do not need to be
  0..n-1; a filtered DataFrame works too.
  """

  def __init__(self, df):
    self.sequences = df['aa_sequence']
    self.scores = df['enrichment_score']
    # not read by __getitem__; the subsampling code uses it to filter by read depth
    self.reads = df['library_reads']
    self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

  def __len__(self):
    return len(self.sequences)

  def __getitem__(self, idx):
    """Return (inputs, label) for the idx-th row.

    inputs maps each tokenizer output name (e.g. input_ids / attention_mask) to
    a tensor with the batch dimension removed. label is a float32 scalar tensor.
    """
    # positional lookup: idx is a position, not an index label
    sequence = self.sequences.iloc[idx]
    label = torch.tensor(self.scores.iloc[idx], dtype=torch.float)

    # tokenize the sequence
    tokenized = self.tokenizer(
        sequence,
        max_length=80,  # 78 residues + 2 extra tokens
        truncation=True,  # explicit, so the tokenizer does not warn about max_length without truncation
        return_tensors='pt',
    )

    # drop the leading batch dimension added by return_tensors='pt'
    inputs = {key: val.squeeze(0) for key, val in tokenized.items()}

    return inputs, label

# Reload the baseline (threshold 10) enrichment table and wrap it as a torch Dataset.
df = pd.read_csv(filepath_or_buffer="avrpikC_full_with_read_counts.csv")
full_dataset = BindingDataset(df=df)

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [9]:
# perform train val test split on threshold 10 data
from torch.utils.data import random_split, DataLoader
# Seeded split of full_dataset into a training pool (~80%) and fixed validation
# and test sets (~10% each). Only the training pool is subsampled later.
torch.manual_seed(42)

batch_size = 32
val_split, test_split = 0.1, 0.1

# validation and test sizes round down; the training pool takes the remainder
total_len = len(full_dataset)
val_len, test_len = (int(total_len * frac) for frac in (val_split, test_split))
train_len = total_len - (val_len + test_len)

training_pool, val_dataset, test_dataset = random_split(
    full_dataset, [train_len, val_len, test_len]
)

# evaluation loaders are never shuffled, so they iterate in a fixed order
val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (val_dataset, test_dataset)
)


In [10]:
import numpy as np
from torch.utils.data import Subset

# For each read-count threshold, draw training subsets of several sizes from the
# part of the training pool whose library read count meets that threshold.
# A seeded Generator makes the drawn subsets reproducible across reruns; the
# global np.random state used previously was never seeded.
sampling_rng = np.random.default_rng(42)

# positions of the training-pool sequences in full_dataset, and their library read counts
training_pool_original_indices = np.array(training_pool.indices)
training_pool_reads = full_dataset.reads.iloc[training_pool_original_indices]

thresholds = [10, 12, 15, 18, 21, 24]
sample_sizes = [32, 64, 128, 256]

datasets = []

for threshold in thresholds:
  # plain boolean array so the mask is applied positionally to the numpy indices
  meets_threshold = (training_pool_reads >= threshold).to_numpy()
  indices_within_training_pool = training_pool_original_indices[meets_threshold]

  for sample_size in sample_sizes:
    # Check if there are enough indices to sample from
    if len(indices_within_training_pool) < sample_size:
        print(f"Warning: Not enough sequences ({len(indices_within_training_pool)}) above threshold {threshold} in training pool to sample {sample_size}. Skipping this combination.")
        continue # Skip to the next sample size

    # sample without replacement; values are positions in full_dataset, as Subset expects
    subset_indices = sampling_rng.choice(indices_within_training_pool, sample_size, replace=False)
    subset = Subset(full_dataset, subset_indices)
    datasets.append({
        'reads': threshold,  # key name kept for the training cell; holds the read-count threshold
        'samples': sample_size,
        'dataset': subset,
    })

## Run experiments

In [11]:
from scripts.training import initialize_and_train_new_model, test_model
import pandas as pd
from torch.utils.data import DataLoader

# Train one fresh model per (threshold, sample size) subset and record its test
# Spearman correlation. Rows are collected in a list and converted to a DataFrame.
# Concatenating onto an initially empty DataFrame raises a pandas FutureWarning
# (visible in the captured output).
import os

results_path = "results/05_impact_of_read_counts/learning_curve.csv"
# make sure the output directory exists before the first write
os.makedirs(os.path.dirname(results_path), exist_ok=True)

result_columns = ['threshold', 'samples', 'spearmanr']
result_rows = []
results = pd.DataFrame(columns=result_columns)  # stays empty if `datasets` is empty

for dataset in datasets:
  # fresh shuffled loader over this subset, same batch size as the val/test loaders
  train_dataloader = DataLoader(dataset['dataset'], batch_size=batch_size, shuffle=True)

  print(f'\nTraining with threshold {dataset["reads"]}, and {dataset["samples"]} samples')
  print('---------------------------------------------------------------------------------')
  # NOTE(review): the positional 50 is presumably the epoch count (the training
  # progress bar shows 50 steps); with patience=50 equal to that budget, early
  # stopping presumably never triggers. Confirm against scripts.training.
  model = initialize_and_train_new_model(
      'cls-based',
      'facebook/esm2_t6_8M_UR50D',
      2e-5,
      0.01,
      50,
      train_dataloader,
      val_dataloader,
      patience=50,
  )

  print('\nTesting')
  test_results = test_model(model, test_dataloader, True)

  # the 'reads' key of each dataset entry holds its read-count threshold
  result_rows.append({
      'threshold': dataset['reads'],
      'samples': dataset['samples'],
      'spearmanr': test_results['spearmanr'],
  })
  # rewrite the CSV after every run so completed runs survive an interruption
  results = pd.DataFrame(result_rows, columns=result_columns)
  results.to_csv(results_path, index=False)

results


Training with threshold 10, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:13<00:00,  3.79it/s]


Train Loss: 0.0656 | Val Loss: 0.1388 | SpearmanR: 0.3366

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 73.88it/s]
  results = pd.concat([results, pd.DataFrame([results_dict])], ignore_index=True) # Append the results to the DataFrame



Training with threshold 10, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:12<00:00,  3.92it/s]


Train Loss: 0.0373 | Val Loss: 0.1358 | SpearmanR: 0.5086

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 77.02it/s]



Training with threshold 10, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:14<00:00,  3.41it/s]


Train Loss: 0.0606 | Val Loss: 0.1281 | SpearmanR: 0.4666

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.74it/s]



Training with threshold 10, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:22<00:00,  2.26it/s]


Train Loss: 0.0256 | Val Loss: 0.0792 | SpearmanR: 0.6990

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 72.98it/s]



Training with threshold 12, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:10<00:00,  4.86it/s]


Train Loss: 0.0268 | Val Loss: 0.1255 | SpearmanR: 0.3985

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 73.48it/s]



Training with threshold 12, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:12<00:00,  4.03it/s]


Train Loss: 0.0369 | Val Loss: 0.1445 | SpearmanR: 0.4687

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.83it/s]



Training with threshold 12, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:15<00:00,  3.23it/s]


Train Loss: 0.0291 | Val Loss: 0.0915 | SpearmanR: 0.6626

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 73.70it/s]



Training with threshold 12, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:21<00:00,  2.28it/s]


Train Loss: 0.0125 | Val Loss: 0.0715 | SpearmanR: 0.7226

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 73.62it/s]



Training with threshold 15, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:09<00:00,  5.13it/s]


Train Loss: 0.0766 | Val Loss: 0.1359 | SpearmanR: 0.4181

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 72.88it/s]



Training with threshold 15, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


Train Loss: 0.0259 | Val Loss: 0.1029 | SpearmanR: 0.5959

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 70.95it/s]



Training with threshold 15, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:16<00:00,  3.11it/s]


Train Loss: 0.0357 | Val Loss: 0.0977 | SpearmanR: 0.6244

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.41it/s]



Training with threshold 15, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:21<00:00,  2.30it/s]


Train Loss: 0.0180 | Val Loss: 0.0649 | SpearmanR: 0.7698

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 74.60it/s]



Training with threshold 18, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:10<00:00,  4.73it/s]


Train Loss: 0.0937 | Val Loss: 0.1432 | SpearmanR: 0.4881

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 74.34it/s]



Training with threshold 18, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:11<00:00,  4.22it/s]


Train Loss: 0.0252 | Val Loss: 0.1157 | SpearmanR: 0.5662

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.43it/s]



Training with threshold 18, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:15<00:00,  3.16it/s]


Train Loss: 0.0413 | Val Loss: 0.0916 | SpearmanR: 0.6558

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.18it/s]



Training with threshold 18, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:22<00:00,  2.23it/s]


Train Loss: 0.0198 | Val Loss: 0.0774 | SpearmanR: 0.7028

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 74.58it/s]



Training with threshold 21, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:10<00:00,  4.69it/s]


Train Loss: 0.0636 | Val Loss: 0.1213 | SpearmanR: 0.4771

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 73.42it/s]



Training with threshold 21, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:12<00:00,  4.07it/s]


Train Loss: 0.0268 | Val Loss: 0.1290 | SpearmanR: 0.4374

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.53it/s]



Training with threshold 21, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:15<00:00,  3.16it/s]


Train Loss: 0.0233 | Val Loss: 0.1373 | SpearmanR: 0.5060

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 74.02it/s]



Training with threshold 21, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:22<00:00,  2.23it/s]


Train Loss: 0.0190 | Val Loss: 0.0702 | SpearmanR: 0.7175

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.43it/s]



Training with threshold 24, and 32 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:11<00:00,  4.40it/s]


Train Loss: 0.0510 | Val Loss: 0.1478 | SpearmanR: 0.3449

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 74.44it/s]



Training with threshold 24, and 64 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:12<00:00,  4.01it/s]


Train Loss: 0.0403 | Val Loss: 0.1138 | SpearmanR: 0.5604

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.06it/s]



Training with threshold 24, and 128 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:16<00:00,  3.12it/s]


Train Loss: 0.0393 | Val Loss: 0.1366 | SpearmanR: 0.6272

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 75.18it/s]



Training with threshold 24, and 256 samples
---------------------------------------------------------------------------------


[Training]: 100%|██████████| 50/50 [00:23<00:00,  2.13it/s]


Train Loss: 0.0277 | Val Loss: 0.0777 | SpearmanR: 0.7144

Testing


[Testing]: 100%|██████████| 10/10 [00:00<00:00, 67.33it/s]


Unnamed: 0,threshold,samples,spearmanr
0,10,32,0.36632
1,10,64,0.541817
2,10,128,0.54321
3,10,256,0.742074
4,12,32,0.448068
5,12,64,0.441382
6,12,128,0.694497
7,12,256,0.727022
8,15,32,0.485453
9,15,64,0.646631


## Analysis

In [12]:
# plot learning curves with plotly
import plotly.express as px
import pandas as pd

# Learning curves: test Spearman correlation against training-set size, with one
# line per read-count threshold. Markers show the individual runs.
axis_titles = dict(
    xaxis_title="Sample Size",
    yaxis_title="Spearman Correlation",
    legend_title="Threshold",
)

fig = px.line(
    results,
    x="samples",
    y="spearmanr",
    color="threshold",
    markers=True,
    title="Spearman Correlation vs. Sample Size by Threshold",
)
fig.update_layout(**axis_titles)
fig.show()

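This is now the second time I've run this experiment and analysis. With this second replicate, it seems clear that there won't be any statistically discernible differences between thresholds. This means read-count bias is unlikely to be the culprit. (Caveat: each threshold/sample-size combination was trained only once, and the test set was drawn from the threshold-10 data, so this is a qualitative judgement rather than a formal statistical test.)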