# Progress Report

In [1]:
import torch
import esm
import pandas as pd
import requests
import os
from tqdm import tqdm
import numpy as np
from sequence_models.pretrained import load_model_and_alphabet

In [2]:
# Load Model
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()
device = torch.device("mps")
model.to(device)


MSATransformer(
  (embed_tokens): Embedding(33, 768, padding_idx=1)
  (dropout_module): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x AxialTransformerLayer(
      (row_self_attention): NormalizedResidualBlock(
        (layer): RowSelfAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
          (dropout_module): Dropout(p=0.1, inplace=False)
        )
        (dropout_module): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (column_self_attention): NormalizedResidualBlock(
        (layer): ColumnSelfAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768,

In [3]:
# Load Data
url = "https://elifesciences.org/download/aHR0cHM6Ly9jZG4uZWxpZmVzY2llbmNlcy5vcmcvYXJ0aWNsZXMvMTY5NjUvZWxpZmUtMTY5NjUtc3VwcDEtdjQueGxzelife-16965-supp1-v4.xlsx?_hash=UsG4XAO0qnBOtjEvbXpBgu%2FazhuWskkDqs417%2BIpaAM%3D"
r = requests.get(url, allow_redirects=True)
filepath = 'data/gb1data.xlsx'
if not os.path.exists(filepath):
    with open(filepath, 'wb') as f:
        f.write(r.content)
GB1_data = pd.read_excel(filepath)

  warn(msg)


In [4]:
# Load sequences
with open("data/GB1_Wu_2016.fasta", "r") as f:
    f.readline() # Skip the header
    wildtype = f.read().replace("\n", "")

sites = [39, 40, 41, 54]

In [5]:
# One-hot encode sequences
aa_list = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P","Q", "R", "S", "T", "V", "W", "Y"]
aa_to_onehot = {aa: np.eye(len(aa_list))[i] for i, aa in enumerate(aa_list)}
# Create a one-hot encoding for the wild-type sequence
onehot_wildtype = np.array([aa_to_onehot[aa] for aa in wildtype])

def encode_mutants_onehot(row, onehot_wildtype):
    mutant = onehot_wildtype.copy()
    mutations = row['Variants']
    for i, mutation in enumerate(mutations):
        mutant[sites[i]-1] = aa_to_onehot[mutation]
    return mutant
def encode_sequence(row, wild):
    mutant = list(wild)
    mutations = row['Variants']
    for i, mutation in enumerate(mutations):
        mutant[sites[i]-1] = mutation
    mutant = "".join(mutant)
    return mutant

In [6]:
# Apply processing for dataframe to get sequence/onehot
rows = len(GB1_data)
tqdm.pandas(total=rows)
GB1_data['onehot_sequence'] = GB1_data.progress_apply(lambda row: encode_mutants_onehot(row, onehot_wildtype), axis=1)
GB1_data['full_sequence'] = GB1_data.progress_apply(lambda row: encode_sequence(row, wildtype), axis=1)

100%|██████████| 149361/149361 [00:01<00:00, 124093.26it/s]
100%|██████████| 149361/149361 [00:00<00:00, 235594.01it/s]


In [7]:
# Get masked sequence
to_mask_sequence = list(wildtype)
for site in sites:
    to_mask_sequence[site-1] = "<mask>"
masked_sequence = "".join(to_mask_sequence)
masked_sequence

'MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNG<mask><mask><mask>EWTYDDATKTFT<mask>TE'

In [8]:
def get_prediction(data, masked_batch):
    _, _, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)
    batch_tokens = torch.concat((masked_batch, batch_tokens), dim=1)

    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    logits = results['logits'][0][0]
    logits = logits.cpu()
    predicted_tokens = torch.argmax(logits, dim=-1)
    predicted_sequence = "".join([alphabet.all_toks[token_idx.item()] for token_idx in predicted_tokens])[1:]
    predicted_fitness = float(GB1_data[GB1_data.full_sequence == predicted_sequence]['Fitness'].item())
    return predicted_sequence, predicted_fitness

def few_shot_prediction(masked_data, beat_wt_sequences, iters=3, context=5):
    _, _, batch_tokens_mask = batch_converter(masked_data)
    masked_batch = batch_tokens_mask[:, :, :-20]
    masked_batch = masked_batch.to(device)

    fits_avg = []
    fits_max = []
    old_preds = []
    tot_fit = 0.0
    max_fit = 0.0
    for i in tqdm(range(context ** iters)):
        full_data = [(f"{j}", seq) for j, seq in enumerate(np.random.choice(beat_wt_sequences, context, replace=False))]
        try:
            predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
        except:
            predicted_sequence, predicted_fitness = wildtype, 1.0
        max_fit = max(max_fit, predicted_fitness)
        tot_fit += predicted_fitness
        old_preds.append(predicted_sequence)
    fits_avg.append(tot_fit / (context ** iters))
    fits_max.append(max_fit)

    new_preds = []
    while len(old_preds) != 1:
        tot_fit = 0.0
        count = 0
        max_fit = 0.0
        while len(old_preds) != 0:
            count += 1
            full_data = [(f"{j}", old_preds.pop()) for j in range(context)]
            try:
                predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
            except:
                predicted_sequence, predicted_fitness = wildtype, 1.0
            max_fit = max(max_fit, predicted_fitness)
            tot_fit += predicted_fitness
            new_preds.append(predicted_sequence)
        fits_avg.append(tot_fit / count)
        fits_max.append(max_fit)
        old_preds = new_preds
        new_preds = []
    return old_preds, np.array(fits_avg), np.array(fits_max)

In [9]:
data_mask = [
    ("GB1", masked_sequence),
]

In [481]:
# varying context length
contexts = [5, 10, 15, 20, 25]
iters = 2
fitness_per_context_avg = []
fitness_per_context_max = []
for context in contexts:
    curr_context_fitness_avg = np.array([0.0, 0.0, 0.0])
    curr_context_fitness_max = np.array([0.0, 0.0, 0.0])
    for trial in range(10):
        final_pred, fitness_avg, fitness_max = few_shot_prediction(data_mask, GB1_data[GB1_data.Fitness > 1.0].full_sequence, iters, context)
        curr_context_fitness_avg += fitness_avg
        curr_context_fitness_max += fitness_max

    curr_context_fitness_avg /= 10.0
    curr_context_fitness_max /= 10.0

    fitness_per_context_avg.append(curr_context_fitness_avg)
    fitness_per_context_max.append(curr_context_fitness_max)


100%|██████████| 25/25 [00:16<00:00,  1.49it/s]
100%|██████████| 25/25 [00:16<00:00,  1.51it/s]
100%|██████████| 25/25 [00:16<00:00,  1.54it/s]
100%|██████████| 25/25 [00:16<00:00,  1.54it/s]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s]
100%|██████████| 25/25 [00:16<00:00,  1.52it/s]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 100/100 [01:11<00:00,  1.40it/s]
100%|██████████| 225

In [482]:
fitness_per_context_avg

[array([2.1337431 , 1.74710995, 1.22521008]),
 array([2.04258572, 1.95656149, 1.46530294]),
 array([2.06059032, 2.31722678, 1.75491497]),
 array([2.14678006, 2.6657728 , 2.93525528]),
 array([2.18196093, 2.73182302, 2.90317318])]

In [483]:
fitness_per_context_max

[array([5.91097697, 3.6112047 , 1.22521008]),
 array([5.95734381, 4.33668472, 1.46530294]),
 array([6.76436923, 5.22139873, 1.75491497]),
 array([6.67776499, 5.35068811, 2.93525528]),
 array([7.31623955, 4.9909823 , 2.90317318])]

We see an increase at 20 context, so do more iterations there.

In [10]:
def few_shot_prediction_iters(masked_data, beat_wt_sequences, iters=3, context=5):
    _, _, batch_tokens_mask = batch_converter(masked_data)
    masked_batch = batch_tokens_mask[:, :, :-20]
    masked_batch = masked_batch.to(device)

    fits_avg = []
    fits_max = []
    old_preds = []
    tot_fit = 0.0
    max_fit = 0.0
    for i in tqdm(range(context ** 2)):
        full_data = [(f"{j}", seq) for j, seq in enumerate(np.random.choice(beat_wt_sequences, context, replace=False))]
        try:
            predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
        except:
            predicted_sequence, predicted_fitness = wildtype, 1.0
        max_fit = max(max_fit, predicted_fitness)
        tot_fit += predicted_fitness
        old_preds.append(predicted_sequence)
    fits_avg.append(tot_fit / (context ** iters))
    fits_max.append(max_fit)

    new_preds = []
    for i in range(iters):
        tot_fit = 0.0
        max_fit = 0.0
        while len(new_preds) != (context ** 2):
            sample = np.random.choice(old_preds, context, replace=False)
            full_data = [(f"{j}", samp) for j, samp in enumerate(sample)]
            try:
                predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
            except:
                predicted_sequence, predicted_fitness = wildtype, 1.0
            max_fit = max(max_fit, predicted_fitness)
            tot_fit += predicted_fitness
            new_preds.append(predicted_sequence)
        fits_avg.append(tot_fit / (context ** 2))
        fits_max.append(max_fit)
        old_preds = new_preds
        new_preds = []
    
    sample = np.random.choice(old_preds, context, replace=False)
    full_data = [(f"{j}", samp) for j, samp in enumerate(sample)]
    try:
        predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
    except:
        predicted_sequence, predicted_fitness = wildtype, 1.0
    return predicted_sequence, predicted_fitness, np.array(fits_avg), np.array(fits_max)

In [11]:
iters_pred, iters_fitness, iters_avg_fit, iters_max_fit = few_shot_prediction_iters(data_mask, GB1_data[(GB1_data.Fitness > 1.0)].full_sequence, iters=10, context = 20)

100%|██████████| 400/400 [05:24<00:00,  1.23it/s]


In [12]:
iters_pred, iters_fitness

('MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGFSGEWTYDDATKTFTATE', 3.51664630346)

In [13]:
iters_avg_fit

array([8.31241451e-11, 2.79386615e+00, 3.29377146e+00, 3.82471291e+00,
       3.51920806e+00, 3.51664630e+00, 3.51664630e+00, 3.51664630e+00,
       3.51664630e+00, 3.51664630e+00, 3.51664630e+00])

In [14]:
iters_max_fit

array([6.04207764, 5.44061457, 5.44061457, 5.44061457, 4.54135002,
       3.5166463 , 3.5166463 , 3.5166463 , 3.5166463 , 3.5166463 ,
       3.5166463 ])

Seems to converge on a prediction that is better than the average initial fitness. 

In [13]:
def few_shot_prediction_iters_weighted(masked_data, beat_wt_data, iters=3, context=5):
    _, _, batch_tokens_mask = batch_converter(masked_data)
    masked_batch = batch_tokens_mask[:, :, :-20]
    masked_batch = masked_batch.to(device)

    fits_avg = []
    fits_max = []
    old_preds = []
    tot_fit = 0.0
    max_fit = 0.0
    for i in tqdm(range(context ** 2)):
        weights = list(beat_wt_data.Fitness / beat_wt_data.Fitness.sum())
        sample = np.random.choice(beat_wt_data.full_sequence, context, replace=False, p=weights)
        full_data = [(f"{j}", seq) for j, seq in enumerate(sample)]
        try:
            predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
        except:
            predicted_sequence, predicted_fitness = wildtype, 1.0
        max_fit = max(max_fit, predicted_fitness)
        tot_fit += predicted_fitness
        old_preds.append([predicted_sequence, float(predicted_fitness)])
    fits_avg.append(tot_fit / (context ** iters))
    fits_max.append(max_fit)

    new_preds = []
    for i in tqdm(range(iters)):
        tot_fit = 0.0
        max_fit = 0.0
        while len(new_preds) != (context ** 2):
            old_preds = np.array(old_preds)
            weights = old_preds[:, 1].astype(float) / np.sum(old_preds[:, 1].astype(float))
            sample = np.random.choice(old_preds[:, 0], context, replace=False, p=weights)
            full_data = [(f"{j}", samp) for j, samp in enumerate(sample)]
            try:
                predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
            except:
                predicted_sequence, predicted_fitness = wildtype, 1.0
            max_fit = max(max_fit, predicted_fitness)
            tot_fit += predicted_fitness
            new_preds.append([predicted_sequence, float(predicted_fitness)])
        fits_avg.append(tot_fit / (context ** 2))
        fits_max.append(max_fit)
        old_preds = new_preds
        new_preds = []
    old_preds = np.array(old_preds)
    weights = old_preds[:, 1].astype(float) / np.sum(old_preds[:, 1].astype(float))
    sample = np.random.choice(old_preds[:, 0], context, replace=False, p=weights)
    full_data = [(f"{j}", samp) for j, samp in enumerate(sample)]
    try:
        predicted_sequence, predicted_fitness = get_prediction(full_data, masked_batch)
    except:
        predicted_sequence, predicted_fitness = wildtype, 1.0
    return predicted_sequence, predicted_fitness, np.array(fits_avg), np.array(fits_max)

Weighted and limit search space to start

In [14]:
iters_pred_weighted, iters_fitness_weighted, iters_avg_fit_weighted, iters_max_fit_weighted = few_shot_prediction_iters_weighted(data_mask, GB1_data[(GB1_data['Fitness'] > 1.0) & (GB1_data['Fitness'] < 4.0)][['full_sequence', 'Fitness']], iters=10, context = 20)

100%|██████████| 400/400 [05:24<00:00,  1.23it/s]
100%|██████████| 10/10 [53:55<00:00, 323.52s/it]


In [15]:
iters_avg_fit_weighted

array([8.61112654e-11, 4.41664563e+00, 5.16145244e+00, 5.44061457e+00,
       5.44061457e+00, 5.44061457e+00, 5.44061457e+00, 5.44061457e+00,
       5.44061457e+00, 5.44061457e+00, 5.44061457e+00])

In [16]:
iters_max_fit_weighted

array([8.76196566, 5.44061457, 5.44061457, 5.44061457, 5.44061457,
       5.44061457, 5.44061457, 5.44061457, 5.44061457, 5.44061457,
       5.44061457])

In [17]:
iters_fitness_weighted

5.44061456678

# CARP

In [113]:
model, collater = load_model_and_alphabet('carp_38M')

In [114]:
masked_sequence_carp = masked_sequence.replace("<mask>", '#')

In [115]:
masked_sequence_carp

'MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNG###EWTYDDATKTFT#TE'

In [116]:
x = collater(masked_sequence_carp)[0]

In [117]:
results = model(x, logits=True)

In [118]:
test = model(collater("ABCD#DDE")[0], logits=True)

In [119]:
masked_sequence_carp.index('#')

38

In [120]:
test['logits'][4]

tensor([[  0.2644,  -0.9251,  -0.3768,  -0.2805,  -0.3468,   0.4952,  -0.7283,
          -0.2962,  -0.3186,   0.5028,   3.4290,  -0.5135,  -0.2505,  -0.4270,
           0.2813,   0.1844,  -0.1102,   0.2527,  -1.3649,  -0.5134,  -8.6467,
          -8.9015,  -1.2752, -16.6647, -14.5850,  -9.1555, -16.6833, -16.7042,
         -16.6712, -16.6683]], grad_fn=<SelectBackward0>)

In [121]:
results['logits'][38]

tensor([[  0.2644,  -0.9251,  -0.3768,  -0.2805,  -0.3468,   0.4952,  -0.7283,
          -0.2962,  -0.3186,   0.5028,   3.4290,  -0.5135,  -0.2505,  -0.4270,
           0.2813,   0.1844,  -0.1102,   0.2527,  -1.3649,  -0.5134,  -8.6467,
          -8.9015,  -1.2752, -16.6647, -14.5850,  -9.1555, -16.6833, -16.7042,
         -16.6712, -16.6683]], grad_fn=<SelectBackward0>)

In [122]:
results['logits'][39]

tensor([[  0.2644,  -0.9251,  -0.3768,  -0.2805,  -0.3468,   0.4952,  -0.7283,
          -0.2962,  -0.3186,   0.5028,   3.4290,  -0.5135,  -0.2505,  -0.4270,
           0.2813,   0.1844,  -0.1102,   0.2527,  -1.3649,  -0.5134,  -8.6467,
          -8.9015,  -1.2752, -16.6647, -14.5850,  -9.1555, -16.6833, -16.7042,
         -16.6712, -16.6683]], grad_fn=<SelectBackward0>)

# What progress has been done

So far, we have preprocessed the data, implemented ESM for inference using the pretrained model, implemented functions to extract fitness from predicted variants, and have started to implement CARP.

In addition to the above, we have done some research into how to fine-tune ESM for our purposes and also what other datasets we can bring in that could give us similar insights to GB1. 


# Challenges

There was little documentation on how to get the actual sequences out of the ESM model, so we used a combination of the internet and chat-gpt to get us through it and we think what we have makes sense. 

Additonally, getting models to work locally has been a challenge sine we both have M1/M2 macbooks which have some glitches with newer versions of pytorch and python. This also took a long time to debug. Later on, when we bring in larger ESM and CARP models and more data we will likely run it on the class cluster.

Finally, We are having trouble finding documentation on doing masked sequence prediction with CARP and are currently researching more into this.


# Next steps

Our next steps are to implement fine-tuning for ESM so that we can run comparisons to the pre-trained and the fine-tuned ESM for different model sizes and get a wide range of comparisons.

We are also working on finding other datasets to bring in to compare to the performance on the GB1 dataset.

Finally, we are in the process of implementing Microsoft's CARP model for this task and will likely have it ready by early next week.