# Run a one-batch acquisition seeded with the variants/mutations found in GISAID


In [20]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split


#import * from utils, gaussian_process, active_learner in ../src
import sys
sys.path.append('../src')
from utils import *
from gaussian_process import *
from active_learner import *
from hist_al import *

# Load the GISAID data

In [21]:
# One row per sequence, with its count and q05_date column
rbd_df = pd.read_csv('../gisaid/rbd_dates.csv')  # comma is read_csv's default separator
rbd_df

Unnamed: 0,seq,count,q05_date
0,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,3980832,2021-06-29 00:00:00
1,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1297376,2022-01-31 00:00:00
2,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1219498,2022-06-11 00:00:00
3,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,1094468,2021-01-18 00:00:00
4,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,816046,2020-03-26 00:00:00
...,...,...,...
25720,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2023-01-02 00:00:00
25721,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2023-01-12 00:00:00
25722,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2022-10-19 00:00:00
25723,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2022-10-31 00:00:00


# DMS dataset (Bloom)

In [22]:
# Define the function get_mutation
def get_mutation(WT, sequences, site_offset=330):
    """
    Collect the point mutations, relative to WT, that occur in a list of sequences.

    Parameters:
    - WT (str): The wild-type (reference) sequence.
    - sequences (iterable): Sequences to compare against WT. No alignment is
      performed, so entries whose length differs from WT -- or that have no
      length at all, such as NaN/None from a missing value -- are skipped.
    - site_offset (int): Added to the 1-based position inside WT to obtain the
      site number written in the mutation label. The default of 330 labels the
      first residue of WT as site 331.

    Returns:
    - set: Mutation labels formatted as WT residue + site + mutant residue
      (e.g. "N501Y"). Empty if nothing differs or nothing is comparable.
    """
    wt_length = len(WT)
    mutations = set()

    for seq in sequences:
        # Only same-length sequences can be compared site by site.
        try:
            comparable = len(seq) == wt_length
        except TypeError:
            # Objects without a length (float NaN, None, ...) cannot be compared.
            comparable = False
        if not comparable:
            continue

        for position, (wt_residue, mutant_residue) in enumerate(zip(WT, seq), start=1):
            if wt_residue != mutant_residue:
                mutations.add(f"{wt_residue}{position + site_offset}{mutant_residue}")

    return mutations

# Reference (wild-type) sequence that the GISAID sequences are compared against
WT_sequence = "NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKST"

# Cumulative windows: all start on 2020-01-01, each ends half a year after the previous one
intervals = [
    ("2020-01-01", window_end)
    for window_end in (
        "2020-06-30",
        "2020-12-31",
        "2021-06-30",
        "2021-12-31",
        "2022-06-30",
        "2022-12-31",
    )
]
# Parse the dates once so the window bounds are compared chronologically.
# Comparing the raw strings is lexicographic: "2020-06-30 00:00:00" sorts after
# "2020-06-30", so sequences dated exactly on a window's end day were excluded
# even though the upper bound is meant to be inclusive. Normalizing to midnight
# makes the comparison a pure calendar-day comparison.
q05_days = pd.to_datetime(rbd_df['q05_date']).dt.normalize()

# One summary dict per date window
detailed_results = []

for start_date, end_date in intervals:
    # Keep sequences dated inside the window (both bounds inclusive) ...
    in_window = q05_days.between(pd.Timestamp(start_date), pd.Timestamp(end_date))
    # ... that were observed at least 1000 times
    subset = rbd_df[in_window & (rbd_df['count'] >= 1000)]

    # Distinct sequences remaining in the window
    unique_sequences = subset['seq'].unique()

    # Point mutations of those sequences relative to the wild type
    mutations = get_mutation(WT_sequence, unique_sequences)

    detailed_results.append({
        'start_date': start_date,
        'end_date': end_date,
        'unique_sequences': len(unique_sequences),
        'unique_mutations': len(mutations),
        'mutations_observed': mutations
    })

# Convert results to a DataFrame for easy display
detailed_results_df = pd.DataFrame(detailed_results)
detailed_results_df

Unnamed: 0,start_date,end_date,unique_sequences,unique_mutations,mutations_observed
0,2020-01-01,2020-06-30,4,3,"{V367F, A520S, Y453F}"
1,2020-01-01,2020-12-31,15,14,"{S494P, R357K, L452R, N501T, A522S, K417N, S47..."
2,2020-01-01,2021-06-30,35,23,"{S494P, D427N, A475V, R346K, S477N, F490S, A52..."
3,2020-01-01,2021-12-31,86,64,"{S494P, D427N, S477I, V382L, N354K, V483F, Q49..."
4,2020-01-01,2022-06-30,116,77,"{P384L, V483F, F490S, L455F, N501T, F490L, N45..."
5,2020-01-01,2022-12-31,156,91,"{P384L, V483F, F490S, L455F, E484R, F486I, N50..."


In [23]:
# Print the complete mutation list of every window, one list per line
for observed in detailed_results_df['mutations_observed']:
    print(list(observed))



['V367F', 'A520S', 'Y453F']
['S494P', 'R357K', 'L452R', 'N501T', 'A522S', 'K417N', 'S477N', 'Y453F', 'N439K', 'N501Y', 'V367F', 'E484K', 'A520S', 'N440K']
['S494P', 'D427N', 'A475V', 'R346K', 'S477N', 'F490S', 'A520S', 'K417T', 'N501T', 'K417N', 'N501Y', 'V367F', 'T478K', 'R357K', 'Y453F', 'N439K', 'E484K', 'N440K', 'R346S', 'E484Q', 'A522S', 'L452R', 'L452Q']
['S494P', 'D427N', 'S477I', 'V382L', 'N354K', 'V483F', 'Q498R', 'R346K', 'S477N', 'P384L', 'F490S', 'L455F', 'A475V', 'E484A', 'A352S', 'K417T', 'A520S', 'S371L', 'T385I', 'G496S', 'N501T', 'F490L', 'K417N', 'P479S', 'N501Y', 'V367F', 'S373P', 'K356R', 'T478K', 'Q493R', 'G446S', 'V367L', 'R357K', 'G476S', 'R346I', 'A344S', 'Y453F', 'N439K', 'Q414R', 'A522V', 'R408I', 'P463S', 'S459F', 'E484K', 'P479L', 'N460S', 'N440K', 'G339D', 'A411S', 'E484Q', 'R346S', 'Q414K', 'A522S', 'L452R', 'L452Q', 'D427V', 'N354T', 'S494L', 'A348S', 'G446V', 'T376I', 'Y505H', 'A419S', 'S375F']
['P384L', 'V483F', 'F490S', 'L455F', 'N501T', 'F490L', 'N450

In [24]:
# Name of the embedding set: passed to load_and_preprocess_data and used as the
# df column that stores each row's embedding.
EMBED = "esm3_coord"


In [25]:
# Load the DMS table and the embeddings selected by EMBED
df, train_x, train_y, targets = load_and_preprocess_data(EMBED)

# Phenotype columns handed to bio_model
pheno_columns = [
    "delta_log_kd_ACE2",
    "delta_log_kd_LY-CoV016",
    "delta_log_kd_REGN10987",
    "delta_log_kd_LY-CoV555",
    "delta_log_kd_S309",
]
pheno = df[pheno_columns].to_numpy()
print(pheno.shape)

fitnesses = bio_model(pheno)
# Stack five identical copies of the fitness and transpose the stacked array,
# replacing the train_y returned by the loader.
train_y = torch.tensor(np.array([fitnesses] * 5).transpose())

dataset = Dataset_perso(train_x, train_y)
site_rbd_list = np.unique(df["site_SARS2"].values)





loaded embeddings of shape torch.Size([3803, 1536])


  embeddings = torch.load(file_path)  # Load the tensor


(3122, 5)


In [26]:


# Function to find mutation encoding and indexes
def get_mutation_indexes(detailed_df, site_mutation_df):
    """
    Map every observed mutation to the index labels of its rows in site_mutation_df.

    Parameters:
    - detailed_df (DataFrame): One row per date window; its "mutations_observed"
      column holds an iterable of labels such as "Y453F".
    - site_mutation_df (DataFrame): Table with a "mutation" column (mutant
      residue) and a "site_SARS2" column (site number).

    Returns:
    - list: For each row of detailed_df, a flat list with the index labels of
      all matching rows, in the order the mutations were iterated. A mutation
      without any match contributes nothing.
    """
    def _rows_for(label):
        # "Y453F" -> wild type "Y" (not used for matching), site 453.0, mutant "F"
        _wt, site, mutant = label[0], float(label[1:-1]), label[-1]
        same_mutant = site_mutation_df["mutation"] == mutant
        same_site = site_mutation_df["site_SARS2"] == site
        return site_mutation_df[same_mutant & same_site].index.tolist()

    return [
        [hit for label in list(row["mutations_observed"]) for hit in _rows_for(label)]
        for _, row in detailed_df.iterrows()
    ]

# Resolve the observed mutations of every window to row indexes of df,
# then report each window's indexes followed by how many there are
mutation_indexes_results = get_mutation_indexes(detailed_results_df, df)
for i, indexes in enumerate(mutation_indexes_results):
    print(f"Row {i} mutation indexes: {indexes}", len(indexes), sep="\n")


Row 0 mutation indexes: [555, 3030, 1956]
3
Row 1 mutation indexes: [2629, 388, 1946, 2727, 3068, 1411, 2362, 1956, 1713, 2730, 555, 2472, 3030, 1732]
14
Row 2 mutation indexes: [2629, 1544, 2329, 217, 2362, 2574, 3030, 1416, 2727, 1411, 2730, 555, 2378, 388, 1956, 1713, 2472, 1732, 223, 2477, 3068, 1946, 1945]
23
Row 3 mutation indexes: [2629, 1544, 2358, 826, 350, 2450, 2668, 217, 2362, 864, 2574, 1975, 2329, 2465, 318, 1416, 3030, 636, 881, 2650, 2727, 2568, 1411, 2403, 2730, 555, 677, 374, 2378, 2611, 1833, 560, 388, 2346, 216, 185, 1956, 1713, 1376, 3070, 1275, 2118, 2032, 2472, 2398, 2061, 1732, 97, 1339, 2477, 223, 1371, 3068, 1946, 1945, 1550, 357, 2626, 261, 1835, 729, 2775, 1453, 707]
64
Row 4 mutation indexes: [864, 2450, 2574, 1975, 2727, 2568, 1897, 374, 3065, 2611, 1833, 560, 388, 1275, 722, 631, 2398, 1732, 97, 3068, 1945, 1550, 729, 1221, 1453, 1544, 826, 350, 2668, 2362, 3030, 636, 2055, 185, 2061, 2477, 1282, 1946, 2626, 2775, 100, 2629, 217, 318, 1416, 555, 1942, 105

In [27]:
def save_indexes(active_learner, date_index,
                 output_dir="../script_results/1_batch_real_direct_esm3_coord"):
    """
    Save the rows an active learner trained on, plus its index history, as CSV files.

    Reads the module-level `df` and `EMBED`: training points are mapped back to
    rows of `df` by exact equality between the training embedding and the
    embedding stored in `df[EMBED]`.

    Parameters:
    - active_learner: Object exposing `train_dataset.get_data()` (its first
      return value is the collection of training embeddings) and
      `get_training_indices_history()` (an iterable of rows to write).
    - date_index: Identifier of the date window; appended to both file names.
    - output_dir (str): Directory receiving the two CSV files. It is created if
      it does not exist. The default is the directory this notebook always used.

    Returns:
    - None. Writes `training_set_bloom_1000_date<date_index>.csv` and
      `training_indices_history_bloom_1000_date<date_index>.csv` in output_dir.
    """
    import csv
    import os
    from collections import Counter

    train_data_u, _ = active_learner.train_dataset.get_data()

    # Convert every training embedding exactly once and count it, so each df row
    # needs a single hashed lookup instead of a scan over all training points.
    embedding_counts = Counter(tuple(point.tolist()) for point in train_data_u)

    # A df row is listed once per training point it equals, in df order
    # (the same result the exhaustive pairwise comparison produced).
    train_data_u_indexes = []
    for label, embedding in df[EMBED].items():
        train_data_u_indexes.extend([label] * embedding_counts[tuple(embedding)])

    print("indexes checked by embedding", train_data_u_indexes)

    training_set = df.loc[train_data_u_indexes]

    os.makedirs(output_dir, exist_ok=True)

    # Rows of df that ended up in the training set
    filename = f"{output_dir}/training_set_bloom_1000_date{date_index}.csv"
    training_set.to_csv(filename, index=False)

    # History of training indices, written as one CSV row per history entry
    hist_indexes = active_learner.get_training_indices_history()
    print("hist_indexes", hist_indexes)
    filename = f"{output_dir}/training_indices_history_bloom_1000_date{date_index}.csv"

    with open(filename, "w", newline="") as f:
        csv.writer(f).writerows(hist_indexes)

    print(f"List of lists saved as CSV to {filename}")

In [28]:

# NOTE(review): this cell rebinds df, train_x, train_y and targets, which were
# already defined by load_and_preprocess_data above, while `dataset` and
# `mutation_indexes_results` were built from the earlier versions -- confirm
# both versions of df are row-aligned.

# Processed DMS table, with the site column renamed to site_SARS2
df_kd = pd.read_csv("../data_bloom/kd_bloom/df_bloom_processed.csv").rename(
    columns={"site": "site_SARS2"}
)

train_data = load_esm_embeddings("esm3_coord")

targets = [
    "delta_log_kd_ACE2",
    "delta_log_kd_LY-CoV016",
    "delta_log_kd_REGN10987",
    "delta_log_kd_LY-CoV555",
    "delta_log_kd_S309",
]

train_x = torch.tensor(train_data)
train_y = torch.tensor(df_kd[targets].values)

# Rows with a NaN in any target are removed from embeddings, targets and table alike
indexes_nan = np.unique(np.where(np.isnan(train_y))[0])
non_nan_indexes = np.setdiff1d(np.arange(train_y.shape[0]), indexes_nan)
train_x, train_y = train_x[non_nan_indexes], train_y[non_nan_indexes]
df = df_kd.drop(indexes_nan).reset_index(drop=True)

# Store each row's embedding next to its measurements
df[EMBED] = train_x.tolist()

# Drop the rows located at these sites, again from all three containers
excluded_sites = [331, 332, 333, 527, 528, 529, 530, 531]
indexNames = df.index[df["site_SARS2"].isin(excluded_sites)].tolist()
df = df.drop(indexNames).reset_index(drop=True)
train_x = np.delete(train_x, indexNames, axis=0)
train_y = np.delete(train_y, indexNames, axis=0)
print(df)

loaded embeddings of shape torch.Size([3803, 1536])


  embeddings = torch.load(file_path)  # Load the tensor


                                        mutant_sequence  log10Kd_ACE2  \
0     NITALCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.003694   
1     NITCLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.923694   
2     NITDLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.103694   
3     NITELCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.043694   
4     NITFLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.813694   
...                                                 ...           ...   
3117  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.013694   
3118  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.043694   
3119  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.993694   
3120  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.953694   
3121  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.013694   

      delta_log_kd_ACE2 mutation  site_SARS2  mut_escape_LY-CoV016  \
0             -0.003694        A       334.0         

In [29]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

# Window end dates, positionally aligned with mutation_indexes_results
end_dates = detailed_results_df['end_date'].values

# One summary dict per window; assembled into results_df after the loop
results = []

# Passed unchanged to run_active_learner for every window
biomodel = 'direct'
strategy = 'greedy'

for date_index, indexes in enumerate(mutation_indexes_results):
    # The seed set is the df rows matched to the mutations observed in this window
    NB_0 = len(indexes)
    # 500 + 165 = 665 minus the seed size; presumably the acquisition budget that
    # brings the training set to 665 points (the run log reports 665) -- TODO confirm
    NB_POINTS = 500 + 165 - NB_0
    NB_ROUNDS = 1

    # Seed dataset built from the matching rows of the full dataset
    train_0_indexes = indexes
    train_0_x = torch.stack([dataset.data_x[row] for row in train_0_indexes])
    train_0_y = torch.stack([dataset.data_y[row] for row in train_0_indexes])
    dataset_0 = Dataset_perso(train_0_x, train_0_y)

    r2_list, p_list, var_list, active_learner = run_active_learner(
        strategy=strategy,
        dataset_0=dataset_0,
        dataset=dataset,
        NB_POINTS=NB_POINTS,
        NB_ROUNDS=NB_ROUNDS,
        biomodel=biomodel,
    )

    # Keep only the last value of each metric list.
    # NOTE(review): the run log labels the first metric "AUC" while it is stored
    # here as "r2" -- confirm which name is right.
    results.append({
        "end_dates": end_dates[date_index],
        "indexes": indexes,
        "p": p_list[-1],
        "r2": r2_list[-1],
        "var": var_list[-1],
    })

    # Persist the selected rows and the index history for this window
    save_indexes(active_learner, date_index)

# One row per window
results_df = pd.DataFrame(results)

 

created new AL object
training on  3  points
Learned kernel: RationalQuadratic(alpha=1.29e+03, length_scale=592)
total_y:  3122
acquisition function with strategy:  greedy
training on  665  points
Learned kernel: RationalQuadratic(alpha=0.00336, length_scale=311)
total_y:  3122
Strategy: greedy, AUC: 0.77694360799343, P: 0.33226837060702874, Var: 1.1830344200134277
indexes checked by embedding [10, 38, 39, 40, 42, 43, 45, 46, 48, 50, 51, 52, 53, 54, 55, 57, 58, 59, 61, 62, 63, 64, 66, 67, 70, 71, 72, 73, 74, 75, 76, 77, 78, 80, 81, 86, 87, 88, 89, 90, 91, 92, 114, 115, 117, 120, 122, 123, 125, 129, 130, 135, 138, 139, 144, 148, 152, 153, 154, 156, 162, 163, 166, 209, 214, 216, 218, 225, 232, 253, 270, 271, 273, 282, 283, 292, 325, 326, 328, 361, 362, 365, 366, 368, 369, 372, 377, 385, 396, 399, 400, 401, 402, 404, 405, 406, 409, 410, 411, 412, 413, 414, 416, 418, 422, 423, 434, 437, 438, 444, 446, 452, 453, 456, 457, 458, 459, 460, 461, 462, 463, 464, 466, 467, 468, 469, 470, 472, 473,



Learned kernel: RationalQuadratic(alpha=0.0007, length_scale=26.8)
total_y:  3122
Strategy: greedy, AUC: 0.5890523770417009, P: 0.16613418530351437, Var: 0.5326253175735474
indexes checked by embedding [2, 3, 5, 11, 12, 14, 15, 29, 30, 31, 34, 35, 78, 79, 84, 88, 94, 97, 105, 106, 110, 116, 136, 140, 142, 160, 172, 173, 182, 183, 186, 201, 202, 203, 205, 221, 222, 228, 230, 231, 238, 239, 240, 243, 244, 246, 248, 249, 261, 262, 266, 268, 269, 271, 277, 278, 279, 281, 285, 287, 288, 289, 290, 291, 292, 295, 296, 297, 298, 301, 302, 303, 305, 306, 312, 314, 315, 316, 318, 319, 320, 326, 342, 344, 345, 354, 357, 358, 361, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 388, 392, 420, 421, 429, 430, 431, 433, 439, 482, 486, 487, 491, 495, 496, 498, 505, 506, 508, 509, 543, 545, 547, 551, 553, 555, 560, 562, 563, 564, 567, 593, 642, 657, 661, 665, 667, 668, 670, 676, 678, 680, 692, 703, 705, 706, 714, 715, 716, 718, 719, 722, 724, 725, 731, 735, 738, 740

MemoryError: Unable to allocate 3.37 MiB for an array with shape (665, 665) and data type float64

In [None]:
# Scalar metrics first, the long index lists last
column_order = ["end_dates", "p", "r2", "var", "indexes"]
results_df = results_df[column_order]

# Show the summary and write it next to the per-window outputs
print(results_df)
results_df.to_csv(
    "../script_results/1_batch_real_direct_esm3_coord/bloom_1_batch_real_1000.csv",
    index=False,
)

    end_dates         p        r2       var  \
0  2020-06-30  0.335463  0.781978  1.182352   
1  2020-12-31  0.175719  0.593399  0.533985   
2  2021-06-30  0.175719  0.600066  0.568055   
3  2021-12-31  0.258786  0.699422  1.007511   
4  2022-06-30  0.271565  0.706349  0.836747   
5  2022-12-31  0.277955  0.680278  0.854968   

                                             indexes  
0                                  [555, 3030, 1956]  
1  [2629, 388, 1946, 2727, 3068, 1411, 2362, 1956...  
2  [2629, 1544, 2329, 217, 2362, 2574, 3030, 1416...  
3  [2629, 1544, 2358, 826, 350, 2450, 2668, 217, ...  
4  [864, 2450, 2574, 1975, 2727, 2568, 1897, 374,...  
5  [864, 2450, 2574, 1975, 2478, 2509, 2727, 2568...  


In [None]:
#     end_dates         p        r2       var  \
# 0  2020-06-30  0.335463  0.781978  1.182352   
# 1  2020-12-31  0.175719  0.593399  0.533985   
# 2  2021-06-30  0.175719  0.600066  0.568055   
# 3  2021-12-31  0.258786  0.699422  1.007511   
# 4  2022-06-30  0.271565  0.706349  0.836747   
# 5  2022-12-31  0.277955  0.680278  0.854968 