# Example script for Hackathon

Within each cycle of active learning, you can:

1. Collect training data (original training data + your query data).

2. Train a prediction model to predict the DMS_score for each mutant (e.g., M0A).

3. Use the trained model to predict the score for all mutants in the test set.

4. Select query mutants for the next round based on certain criteria. Make sure you don't query the same mutant twice, since you have only a limited number of queries in total.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from copy import deepcopy
import pandas as pd
from scipy.stats import spearmanr
import argparse
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from scipy.stats import spearmanr
import torch.nn.functional as F
from torch.nn.functional import gelu

In [2]:
# Prefer the GPU when one is visible to torch; everything below moves tensors to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [None]:
# Configure parameters 
config = {
    "learning_rate": 9e-5,
    "batch_size": 8,
    "num_epochs": 5,
    "hidden_dim": 256,
    "esm_model_name": "esm2_t33_650M_UR50D", 
    "embedding_layer": 33, #Final layer for the esm model
}

## 1. Collect training data

Upload `sequence.fasta`, `train.csv`, and `test.csv` to the current runtime:

1. click the folder icon on the left

2. click the upload icon and upload the files to the current directory

In [None]:
def read_fasta(filepath):
    """Return the residue string stored in a FASTA file.

    Header lines (those starting with '>') are dropped. Every other line is
    stripped of surrounding whitespace and the pieces are concatenated, so a
    multi-record file yields a single concatenated string.
    """
    with open(filepath, 'r') as handle:
        residue_lines = [raw.strip() for raw in handle.readlines() if not raw.startswith(">")]
    return "".join(residue_lines)

# Wild-type protein sequence; all mutant sequences are derived from it.
wt_fasta_path = 'sequence.fasta'
sequence_wt = read_fasta(wt_fasta_path)
sequence_wt

'MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLREKMRRRLESGDKWFSLEFFPPRTAEGAVNLISRFDRMAAGGPLYIDVTWHPAGDPGSDKETSSMMIASTAVNYCGLETILHMTCCRQRLEEITGHLHKAKQLGLKNIMALRGDPIGDQWEEEEGGFNYAVDLVKHIRSEFGDYFDICVAGYPKGHPEAGSFEADLKHLKEKVSAGADFIITQLFFEADTFFRFVKACTDMGITCPIVPGIFPIQGYHSLRQLVKLSKLEVPQEIKDVIEPIKDNDAAIRNYGIELAVSLCQELLASGLVPGLHFYTLNREMATTEVLKRLGMWTEDPRRPLPWALSAHPKRREEDVRPIFWASRPKSYIYRTQEWDEFPNGRWGNSSSPAFGELKDYYLFYLKSKSPKEELLKMWGEELTSEESVFEVFVLYLSGEPNRNGHKVTCLPWNDEPLAAETSLLKEELLRVNRQGILTINSQPNINGKPSSDPIVGWGPSGGYVFQKAYLEFFTSRETAEALLQVLKKYELRVNYHLVNVKGENITNAPELQPNAVTWGIFPGREIIQPTVVDPVSFMFWKDEAFALWIERWGKLYEEESPSRTIIQYIHDNYFLVNLVDNDFPLDNCLWQVVEDTLELLNRPTQNARETEAP'

In [5]:
# Sanity check on the wild-type length. Mutant positions in the CSVs index this string
# from 0 (e.g. M0Y replaces the first residue, as the train output below shows).
len(sequence_wt)

656

In [6]:
def get_mutated_sequence(mut, sequence_wt):
  """Apply a single amino-acid substitution such as 'M0Y' to the wild-type sequence.

  Args:
    mut: mutation code '<wt residue><position><mutant residue>'. The position is a
      0-based index into sequence_wt ('M0Y' replaces the first residue with Y).
    sequence_wt: wild-type sequence string.

  Returns:
    The mutated sequence, the same length as sequence_wt.

  Raises:
    ValueError: if the position lies outside sequence_wt, or the wild-type residue
      named in `mut` does not match sequence_wt at that position. Without this check
      a wrong position silently yields a wrong (or, past the end, longer) sequence.
  """
  wt, pos, mt = mut[0], int(mut[1:-1]), mut[-1]

  if not 0 <= pos < len(sequence_wt):
    raise ValueError(
        f"Mutation {mut!r}: position {pos} is outside a sequence of length {len(sequence_wt)}")
  if sequence_wt[pos] != wt:
    raise ValueError(
        f"Mutation {mut!r}: expected wild-type residue {wt!r} at position {pos}, "
        f"found {sequence_wt[pos]!r}")

  # str is immutable, so slicing already builds a new string; no copy is needed.
  return sequence_wt[:pos] + mt + sequence_wt[pos + 1:]

In [7]:
def generate_mutant_sequence(sequence_wt, mut):
    """Argument-order adapter: equivalent to get_mutated_sequence(mut, sequence_wt)."""
    mutated = get_mutated_sequence(mut=mut, sequence_wt=sequence_wt)
    return mutated

In [8]:
# Labelled training mutants, with the full mutated sequence attached to each row.
df_train = pd.read_csv('train.csv')
df_train['sequence'] = [get_mutated_sequence(code, sequence_wt) for code in df_train['mutant']]
df_train

Unnamed: 0,mutant,DMS_score,sequence
0,M0Y,0.2730,YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,M0W,0.2857,WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,M0V,0.2153,VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,M0T,0.3122,TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,M0S,0.2180,SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...,...
1135,P347D,0.3876,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1136,P347C,0.1837,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1137,P347A,0.4611,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1138,P347M,0.2412,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [9]:
# Unlabelled test mutants (the pool we predict on and query from).
df_test = pd.read_csv('test.csv')
df_test['sequence'] = [get_mutated_sequence(code, sequence_wt) for code in df_test['mutant']]
df_test

Unnamed: 0,mutant,sequence
0,V1D,MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,V1Y,MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,V1C,MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,V1A,MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,V1E,MENEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...
11319,P655S,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11320,P655T,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11321,P655V,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11322,P655A,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [None]:
# Integrating first set of query data:
df_query_new = pd.read_csv('queried_data.csv')
df_query_new['sequence'] = df_query_new['mutant'].apply(lambda x: get_mutated_sequence(x, sequence_wt))
# De-duplicate on mutant so that re-running this cell does not append the same rows
# again; keep='last' prefers the newest measurement if a mutant appears twice.
df_train = (
    pd.concat([df_train, df_query_new], ignore_index=True)
    .drop_duplicates(subset='mutant', keep='last', ignore_index=True)
)

In [None]:
# Integrating second set of query data:
df_query_new2 = pd.read_csv('queried_data2.csv')
df_query_new2['sequence'] = df_query_new2['mutant'].apply(lambda x: get_mutated_sequence(x, sequence_wt))
# De-duplicate on mutant so that re-running this cell does not append the same rows
# again; keep='last' prefers the newest measurement if a mutant appears twice.
df_train = (
    pd.concat([df_train, df_query_new2], ignore_index=True)
    .drop_duplicates(subset='mutant', keep='last', ignore_index=True)
)

In [None]:
# Integrating third set of query data:
df_query_new3 = pd.read_csv('queried_data3.csv')
df_query_new3['sequence'] = df_query_new3['mutant'].apply(lambda x: get_mutated_sequence(x, sequence_wt))
# De-duplicate on mutant so that re-running this cell does not append the same rows
# again; keep='last' prefers the newest measurement if a mutant appears twice.
df_train = (
    pd.concat([df_train, df_query_new3], ignore_index=True)
    .drop_duplicates(subset='mutant', keep='last', ignore_index=True)
)

In [13]:
# Inspect the merged training frame (original train.csv rows followed by queried rows).
print(df_train)


     mutant  DMS_score                                           sequence
0       M0Y   0.273000  YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1       M0W   0.285700  WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2       M0V   0.215300  VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3       M0T   0.312200  TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4       M0S   0.218000  SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...     ...        ...                                                ...
1435  L639H   0.476989  MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1436  L639D   0.704033  MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1437  P655Y   0.645727  MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1438  P655F   0.727330  MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1439  P655W   0.596169  MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...

[1440 rows x 3 columns]


In [14]:
# Split train into training and validation sets (90/10, fixed seed).
# NOTE: this rebinds df_train to the 90% split, so re-running the cell splits the
# already-split frame again, and labelled mutants now live in df_train AND df_valid.
df_train, df_valid = train_test_split(df_train, test_size=0.1, random_state=42)
for split_name, split_df in (("Training", df_train), ("Validation", df_valid)):
    print(f"{split_name} set shape:", split_df.shape)

Training set shape: (1296, 3)
Validation set shape: (144, 3)


## 2. Train a prediction model

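The original template provided a linear regression model that used one-hot encoding to encode each variant, and noted that you would need to build your own model to achieve better performance. This notebook replaces that baseline with the pretrained ESM-2 protein language model (`esm2_t33_650M_UR50D`), fine-tuned through LoRA adapters, with a small regression head on top.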

Hint: you can perform cross-validation on the training set to evaluate your predictor before making predictions on the test set.

In [15]:
!pip install fair-esm
!pip install biopython
!pip install peft

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [16]:
import esm

# Names exposed by esm.pretrained; config["esm_model_name"] must be one of these loaders.
available_loaders = dir(esm.pretrained)
available_loaders

['ESM2',
 'Namespace',
 'Path',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '_download_model_and_regression_data',
 '_has_regression_weights',
 '_load_model_and_alphabet_core_v1',
 '_load_model_and_alphabet_core_v2',
 'esm',
 'esm1_t12_85M_UR50S',
 'esm1_t34_670M_UR100',
 'esm1_t34_670M_UR50D',
 'esm1_t34_670M_UR50S',
 'esm1_t6_43M_UR50S',
 'esm1b_t33_650M_UR50S',
 'esm1v_t33_650M_UR90S',
 'esm1v_t33_650M_UR90S_1',
 'esm1v_t33_650M_UR90S_2',
 'esm1v_t33_650M_UR90S_3',
 'esm1v_t33_650M_UR90S_4',
 'esm1v_t33_650M_UR90S_5',
 'esm2_t12_35M_UR50D',
 'esm2_t30_150M_UR50D',
 'esm2_t33_650M_UR50D',
 'esm2_t36_3B_UR50D',
 'esm2_t48_15B_UR50D',
 'esm2_t6_8M_UR50D',
 'esm_if1_gvp4_t16_142M_UR50',
 'esm_msa1_t12_100M_UR50S',
 'esm_msa1b_t12_100M_UR50S',
 'esmfold_v0',
 'esmfold_v1',
 'has_emb_layer_norm_before',
 'load_hub_workaround',
 'load_model_and_alphabet',
 'load_model_and_alphabet_core',
 'load_model_and_alphabet_hub',


In [17]:
# Release cached, currently-unused GPU memory before loading the 650M-parameter ESM model.
torch.cuda.empty_cache()

In [None]:
# Load the ESM model as specified in our config: look up the loader function by name
# in esm.pretrained and call it to get (model, alphabet).
loader_fn = vars(esm.pretrained)[config["esm_model_name"]]
model, alphabet = loader_fn()
model = model.to(device)

In [19]:
from esm.modules import RobertaLMHead

In [None]:
# Configure LoRA
from peft import LoraConfig, get_peft_model

# Low-rank adapters on the attention query/value projections of the ESM backbone.
lora_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # attention projection layers that get adapters
    r=4,                                  # low-rank dimension
    lora_alpha=16,                        # scaling factor
    lora_dropout=0.1,                     # dropout applied inside the adapters
    bias="none",                          # do not train bias terms
)

# Wrap the ESM model with LoRA adapters, then switch to training mode
# (the last expression also displays the wrapped model).
model = get_peft_model(model, lora_config)
model.train()

  from .autonotebook import tqdm as notebook_tqdm


PeftModel(
  (base_model): LoraModel(
    (model): ESM2(
      (embed_tokens): Embedding(33, 1280, padding_idx=1)
      (layers): ModuleList(
        (0-32): 33 x TransformerLayer(
          (self_attn): MultiheadAttention(
            (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
            (v_proj): lora.Linear(
              (base_layer): Linear(in_features=1280, out_features=1280, bias=True)
              (lora_dropout): ModuleDict(
                (default): Dropout(p=0.1, inplace=False)
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=1280, out_features=4, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=4, out_features=1280, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (q_proj): lora.Linear(
 

In [21]:
# Verify parameter counts in a single pass: total vs. those that will receive gradients.
num_total = 0
num_trainable = 0
for param in model.parameters():
    num_total += param.numel()
    if param.requires_grad:
        num_trainable += param.numel()
print(f"Total parameters: {num_total}")
print(f"Trainable parameters (LoRA adapters + unfrozen layers): {num_trainable}")

Total parameters: 651719094
Trainable parameters (LoRA adapters + unfrozen layers): 675840


In [None]:
# Here we create a new dataset that returns raw sequences.
class ProteinDatasetRaw(Dataset):
    """Map-style dataset that yields raw (untokenised) mutant sequences.

    Each item is a dict with keys "sequence", "label" and "mutation". "label" is the
    row's DMS_score, or None when the frame has no DMS_score column.
    """

    def __init__(self, df, wt_seq):
        # Store a copy with a fresh 0..n-1 index (the caller's index is dropped).
        self.df = df.reset_index(drop=True)
        self.wt_seq = wt_seq

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        mutation = record['mutant']
        has_label = 'DMS_score' in self.df.columns
        label = record['DMS_score'] if has_label else None
        return {
            "sequence": generate_mutant_sequence(self.wt_seq, mutation),
            "label": label,
            "mutation": mutation,
        }

In [None]:
#convert raw to token
batch_converter = alphabet.get_batch_converter()


def collate_fn(batch):
    """Tokenise a list of dataset items into a (tokens, labels) pair on `device`.

    `labels` is a float32 tensor when the first item carries a non-None "label",
    otherwise None (unlabelled/test batches).
    """
    named_seqs = [(f"protein{i}", item["sequence"]) for i, item in enumerate(batch)]
    _, _, tokens = batch_converter(named_seqs)
    tokens = tokens.to(device)

    first = batch[0]
    has_labels = "label" in first and first["label"] is not None
    if not has_labels:
        return tokens, None
    # For training/validation batches, also gather labels
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.float32).to(device)
    return tokens, labels

In [24]:
# Create training and validation datasets and loaders (same batch size and collate;
# only the training loader shuffles).
train_dataset_raw = ProteinDatasetRaw(df_train, sequence_wt)
valid_dataset_raw = ProteinDatasetRaw(df_valid, sequence_wt)
loader_kwargs = dict(batch_size=config["batch_size"], collate_fn=collate_fn)
train_loader = DataLoader(train_dataset_raw, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset_raw, shuffle=False, **loader_kwargs)


In [None]:
class ESMWithLMHead(nn.Module):
    """Scalar fitness regressor: ESM backbone -> mean-pooled embedding -> LM head -> mean.

    The prediction for each sequence is the mean of the head's per-vocabulary outputs
    computed from the mean-pooled residue embeddings of `embedding_layer`.
    """

    def __init__(self, esm_model, embedding_layer):
        super().__init__()
        self.esm_model = esm_model
        self.embedding_layer = embedding_layer
        # A new RobertaLMHead instance (not the backbone's own head). Its output
        # projection weight is tied to the backbone's token-embedding matrix, so
        # output_dim is the alphabet size (33), not a number of layers.
        backbone = self.esm_model
        self.lm_head = RobertaLMHead(
            embed_dim=backbone.embed_dim,          # 1280
            output_dim=backbone.alphabet_size,     # 33
            weight=backbone.embed_tokens.weight,
        )

    def forward(self, tokens):
        layer = self.embedding_layer
        out = self.esm_model(tokens, repr_layers=[layer], return_contacts=False)
        per_token = out["representations"][layer]
        # Drop the first and last token positions (presumably the BOS/EOS added by the
        # batch converter), then average over the remaining residues.
        pooled = per_token[:, 1:-1].mean(dim=1)
        head_out = self.lm_head(pooled)
        # Collapse the per-vocabulary outputs into one scalar per sequence.
        return head_out.mean(dim=1)

# combine model with ESM and embedding layer
combined_model = ESMWithLMHead(model, config["embedding_layer"]).to(device)

In [26]:
#Training Loop: MSE regression loss on DMS_score, Adam over the combined model.
criterion = nn.MSELoss()
optimizer = optim.Adam(combined_model.parameters(), lr=config["learning_rate"])

def train_epoch(model, data_loader, optimizer, criterion):
    """Run one optimisation pass over `data_loader`.

    Returns the example-weighted mean training loss: the sum of
    batch_loss * batch_size divided by the dataset length.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(data_loader, desc="Training", leave=False)
    for tokens, labels in progress:
        optimizer.zero_grad()
        batch_loss = criterion(model(tokens), labels)
        batch_loss.backward()
        optimizer.step()
        batch_value = batch_loss.item()
        loss_sum += batch_value * tokens.size(0)
        progress.set_postfix(loss=batch_value)
    return loss_sum / len(data_loader.dataset)

In [None]:
def evaluate(model, data_loader, criterion):
    """Score `model` on a labelled loader without gradient tracking.

    Puts the model in eval mode. Returns (mean_loss, predictions, labels), where
    mean_loss is weighted by batch size and the two lists hold one numpy scalar per
    example, in loader order.
    """
    model.eval()
    total_loss = 0.0
    preds, targets = [], []
    with torch.no_grad():
        for tokens, labels in data_loader:
            outputs = model(tokens)
            total_loss += criterion(outputs, labels).item() * tokens.size(0)
            preds += list(outputs.cpu().numpy())
            targets += list(labels.cpu().numpy())
    return total_loss / len(data_loader.dataset), preds, targets

best_val_loss = float('inf')
best_model_state = None  # stays None if no epoch beats the initial best (e.g. NaN losses)
best_epoch = None
for epoch in range(config["num_epochs"]):
    train_loss = train_epoch(combined_model, train_loader, optimizer, criterion)
    # Evaluate on training data
    train_loss_eval, train_preds, train_labels = evaluate(combined_model, train_loader, criterion)
    train_spearman, _ = spearmanr(train_preds, train_labels)
    # Evaluate on validation data
    val_loss, val_preds, val_labels = evaluate(combined_model, valid_loader, criterion)
    val_spearman, _ = spearmanr(val_preds, val_labels)
    print(f"Epoch {epoch+1}/{config['num_epochs']}, "
          f"Train Loss: {train_loss:.4f}, Train Spearman: {train_spearman:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Spearman: {val_spearman:.4f}")
    # Keep the checkpoint with the lowest validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        # state_dict() returns references to the live tensors, so storing it as-is would
        # track every later optimizer step instead of freezing this epoch's weights.
        # Clone each tensor, on CPU, so the large backbone is not duplicated in GPU memory.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in combined_model.state_dict().items()
        }

if best_model_state is None:
    raise RuntimeError("No epoch achieved a finite validation loss; there is no checkpoint to restore.")

# Load the best model state after training
combined_model.load_state_dict(best_model_state)
print(f"Restored weights from epoch {best_epoch} (val loss {best_val_loss:.4f}).")

                                                                        

Epoch 1/5, Train Loss: 0.1799, Train Spearman: 0.3180, Val Loss: 0.0955, Val Spearman: 0.2343


Training:  15%|█▍        | 24/162 [00:29<02:50,  1.23s/it, loss=0.0886]

## 3. Preparing For Submission

In [None]:
# Create a test dataset 
class ProteinTestDataset(Dataset):
    """Map-style dataset yielding only the raw mutated sequence string for each row."""

    def __init__(self, df, wt_seq):
        # Store a copy with a fresh 0..n-1 index (the caller's index is dropped).
        self.df = df.reset_index(drop=True)
        self.wt_seq = wt_seq

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        mutation = self.df.iloc[idx]['mutant']
        return generate_mutant_sequence(self.wt_seq, mutation)

test_dataset = ProteinTestDataset(df_test, sequence_wt)


def test_collate_fn(batch):
    """Tokenise a batch of raw test sequences; returns (tokens, None) because there are no labels."""
    named_seqs = [(f"protein{i}", seq) for i, seq in enumerate(batch)]
    tokens = batch_converter(named_seqs)[2].to(device)
    return tokens, None


test_loader = DataLoader(
    test_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    collate_fn=test_collate_fn,
)


In [None]:
# Generating a new query list (#1 randomly):

# Mutants that already have labels. The train/validation split earlier in the notebook
# moved 10% of the labelled rows into df_valid, so both halves are excluded here;
# otherwise previously queried mutants that landed in df_valid could be queried again.
train_mutants = set(df_train['mutant']) | set(df_valid['mutant'])

# Filter the test set to candidates that have no label yet
query_candidates = df_test[~df_test['mutant'].isin(train_mutants)]

# Sample up to 100 new mutants (fixed seed so the list is reproducible)
num_queries = min(100, len(query_candidates))
query_mutants = query_candidates.sample(n=num_queries, random_state=42)['mutant'].tolist()

# Save the list of mutants to query.txt, one per line
with open("query.txt", "w") as f:
    for mut in query_mutants:
        f.write(mut + "\n")

print("query.txt generated with", len(query_mutants), "mutants.")


NameError: name 'df_train' is not defined

In [None]:
# Uncertainty method 
def mc_dropout_inference(model, tokens, num_samples=10):
    """Monte-Carlo dropout: summarise `num_samples` stochastic forward passes.

    The model is switched to train mode so its dropout layers are active, and it is
    restored to the caller's mode afterwards. Previously it stayed in train mode, so
    any later prediction would silently have been stochastic.

    Args:
        model: module whose forward takes `tokens`.
        tokens: model input for one batch.
        num_samples: number of stochastic passes. Must be >= 2, because the std of a
            single sample is undefined (NaN).

    Returns:
        (mean_pred, uncertainty): element-wise mean and std over the passes, each with
        the shape of a single model output.

    Raises:
        ValueError: if num_samples < 2.
    """
    if num_samples < 2:
        raise ValueError(f"num_samples must be >= 2 to estimate an uncertainty, got {num_samples}")
    was_training = model.training
    model.train()  # enable dropout so each pass differs
    try:
        with torch.no_grad():
            preds = torch.stack([model(tokens) for _ in range(num_samples)], dim=0)
    finally:
        model.train(was_training)  # restore the caller's train/eval mode
    mean_pred = preds.mean(dim=0)      # average over the passes
    uncertainty = preds.std(dim=0)     # spread over the passes
    return mean_pred, uncertainty


In [None]:

# MC-dropout inference over the whole test set. mc_dropout_inference puts the model in
# train mode for its sampling passes, so the eval() call below is not what governs them.
all_preds = []
all_uncertainties = []
combined_model.eval()

#run dropout on testset
for batch_tokens, _ in tqdm(test_loader, desc="MC Dropout Inference"):
    batch_mean, batch_std = mc_dropout_inference(combined_model, batch_tokens, num_samples=10)
    all_preds += list(batch_mean.cpu().numpy())
    all_uncertainties += list(batch_std.cpu().numpy())

# add predictions and uncertainty to test dataframe (row order matches test_loader, which does not shuffle)
df_test['DMS_score_predicted'] = all_preds
df_test['uncertainty'] = all_uncertainties


MC Dropout Inference: 100%|██████████| 1416/1416 [2:06:54<00:00,  5.38s/it] 


In [None]:
#generate top100 for query submission (#2 order by highest uncertainty score)

# Mutants that already have labels. The train/validation split earlier in the notebook
# moved 10% of the labelled rows into df_valid, so both halves are excluded here;
# otherwise previously queried mutants that landed in df_valid could be queried again.
train_mutants = set(df_train['mutant']) | set(df_valid['mutant'])

# filter test set to include only mutants without a label yet
query_candidates = df_test[~df_test['mutant'].isin(train_mutants)]

# sort the remaining candidates by uncertainty (most uncertain first)
query_candidates = query_candidates.sort_values('uncertainty', ascending=False)

# select top 100 mutants
query_mutants = query_candidates.head(100)['mutant'].tolist()

# save the query list to query.txt
with open("query.txt", "w") as f:
    for mut in query_mutants:
        f.write(mut + "\n")

print("query.txt generated with", len(query_mutants), "mutants.")


query.txt generated with 100 mutants.


In [None]:
# Keep only the submission columns (mutant id and predicted score) and write them out.
predictions_df = df_test.loc[:, ['mutant', 'DMS_score_predicted']]
predictions_df.to_csv('predictions.csv', index=False)

print("predictions.csv generated with", len(predictions_df), "entries.")


predictions.csv generated with 11324 entries.


In [None]:
#generate top100 for query submission (#3 order by highest DMS score)

# Load our predictions
df_preds = pd.read_csv('predictions.csv')

# Sort by predicted DMS (highest first)
df_preds_sorted = df_preds.sort_values(by='DMS_score_predicted', ascending=False)

# Remove mutants that already have labels. After the earlier train/validation split,
# df_train holds only 90% of them; the rest are in df_valid and must be excluded too.
mutants_in_train = pd.concat([df_train['mutant'], df_valid['mutant']]).unique()
df_preds_filtered = df_preds_sorted[~df_preds_sorted['mutant'].isin(mutants_in_train)]

# Print results
print(df_preds_filtered.head())

#save to query3.txt
df_preds_filtered['mutant'].head(100).to_csv('query3.txt', index=False, header=False)


     mutant  DMS_score_predicted
2423  L134R             0.702731
4890  L300W             0.699369
2418  L134K             0.697462
2374  R131H             0.696087
4227  F256M             0.694895


In [None]:
# Reload the untouched train.csv as df_train_raw for the top10.txt cell below
# (df_train has since been merged with query data and split).
df_train_raw = pd.read_csv('train.csv')
df_train_raw['sequence'] = [get_mutated_sequence(code, sequence_wt) for code in df_train_raw['mutant']]
df_train_raw

In [None]:
# Generate top10.txt for query submission:
# drop mutants that appear in the original train.csv, then keep the 10 highest predicted DMS scores.
train_mutants = set(df_train_raw['mutant'].unique())
is_new = ~predictions_df['mutant'].isin(train_mutants)
query_candidates = predictions_df[is_new]
top10_df = (
    query_candidates
    .sort_values('DMS_score_predicted', ascending=False)
    .head(10)
)

#create top10.txt file, one mutant per line
with open("top10.txt", "w") as f:
    f.writelines(mutant + "\n" for mutant in top10_df['mutant'])
print("top10.txt generated with mutants:")
print(top10_df['mutant'].tolist())


top10.txt generated with mutants:
['L134R', 'L300W', 'L134K', 'R131H', 'F256M', 'T243F', 'R593P', 'T243Y', 'R593G', 'W632G']
