# Run a one batch acquisition using the variants/mutations that were found in Gisaid


In [17]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split


#import * from utils, gaussian_process, active_learner in ../src
import sys
sys.path.append('../src')
from utils import *
from gaussian_process import *
from active_learner import *
from hist_al import *

# Get gisaid data

In [18]:
rbd_df=pd.read_csv('../gisaid/rbd_dates.csv', sep=',')
rbd_df

Unnamed: 0,seq,count,q05_date
0,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,3980832,2021-06-29 00:00:00
1,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1297376,2022-01-31 00:00:00
2,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1219498,2022-06-11 00:00:00
3,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,1094468,2021-01-18 00:00:00
4,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,816046,2020-03-26 00:00:00
...,...,...,...
25720,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2023-01-02 00:00:00
25721,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2023-01-12 00:00:00
25722,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2022-10-19 00:00:00
25723,NITNLCPFDEVFNATTFASVYAWNRKRISNCVADYSVLYNFAPFFA...,1,2022-10-31 00:00:00


# DMS dataset (Bloom)

In [None]:
# Define the function get_mutation
def get_mutation(WT, sequences):
    """
    Function to determine mutations from the WT sequence given a list of sequences.
    
    Parameters:
    - WT (str): The wild-type (reference) sequence.
    - sequences (list): List of sequences to analyze for mutations.
    
    Returns:
    - set: A set of mutations, where each mutation is represented as WT+site+mutant.
    """
    mutations = set()
    
    # Iterate through each sequence in the list
    for seq in sequences:
        # Ensure sequences are of the same length to allow site-based comparison
        if len(seq) != len(WT):
            continue
        
        # Compare each site in the sequence to the WT
        for i, (wt_residue, mutant_residue) in enumerate(zip(WT, seq), start=1):
            if wt_residue != mutant_residue:
                # Format the mutation as WT+site+mutant (e.g., "A12G" for a mutation from A to G at position 12)
                mutation = f"{wt_residue}{i+330}{mutant_residue}"
                mutations.add(mutation)
                
    return mutations

# Define WT sequence
WT_sequence = "NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKST"

# Define date intervals for subsets
intervals = [
    ("2020-01-01", "2020-06-30"),
    ("2020-01-01", "2020-12-31"),
    ("2020-01-01", "2021-06-30"),
    ("2020-01-01", "2021-12-31"),
    ("2020-01-01", "2022-06-30"),
    ("2020-01-01", "2022-12-31")
]
# Initialize list to store detailed results
detailed_results = []

# Iterate over each date interval, similar to previous analysis
for start_date, end_date in intervals:
    # Filter by date range
    subset = rbd_df[(rbd_df['q05_date'] >= start_date) & (rbd_df['q05_date'] <= end_date)]
    subset = subset[subset['count'] >= 1000]
    
    # Get unique sequences in the subset
    unique_sequences = subset['seq'].unique()
    
    # Calculate mutations using the get_mutation function
    mutations = get_mutation(WT_sequence, unique_sequences)
    
    # Append result for the current subset
    detailed_results.append({
        'start_date': start_date,
        'end_date': end_date,
        'unique_sequences': len(unique_sequences),
        'unique_mutations': len(mutations),
        'mutations_observed': mutations
    })

# Convert results to a DataFrame for easy display
detailed_results_df = pd.DataFrame(detailed_results)
detailed_results_df

Unnamed: 0,start_date,end_date,unique_sequences,unique_mutations,mutations_observed
0,2020-01-01,2020-06-30,25,24,"{R408I, G446V, V367F, P479S, A522V, E484Q, Y50..."
1,2020-01-01,2020-12-31,59,55,"{S477N, R346K, R408I, S494P, G446V, V367F, P47..."
2,2020-01-01,2021-06-30,190,120,"{V401L, G504S, F374L, R346T, K528G, G446V, P47..."
3,2020-01-01,2021-12-31,437,367,"{G504S, R346T, K528G, G446V, P479S, A344T, I41..."
4,2020-01-01,2022-06-30,631,389,"{G504S, R346T, K528G, G446V, P479S, A344T, I41..."
5,2020-01-01,2022-12-31,821,410,"{G504S, R346T, K528G, G446V, P479S, A344T, I41..."


In [20]:
for index, row in detailed_results_df.iterrows():
    print(list(row['mutations_observed']))



['R408I', 'G446V', 'V367F', 'P479S', 'A522V', 'E484Q', 'Y508H', 'S373L', 'G476S', 'V483A', 'V483F', 'T478I', 'F490L', 'A475V', 'N450K', 'A344S', 'A520S', 'P384L', 'N481K', 'P384S', 'S477I', 'Y453F', 'A348S', 'N370S']
['S477N', 'R346K', 'R408I', 'S494P', 'G446V', 'V367F', 'P479S', 'A522V', 'E484Q', 'N354D', 'S459F', 'R357K', 'Y508H', 'A352S', 'S373L', 'I410V', 'K417N', 'G476S', 'V483A', 'L452R', 'V483F', 'V362F', 'T478I', 'N501T', 'F490L', 'A475V', 'N450K', 'K529R', 'A522S', 'F490S', 'E484K', 'V382L', 'A520S', 'G339D', 'T385I', 'A344S', 'K356R', 'N501Y', 'P384L', 'N440K', 'T478R', 'N481K', 'S477R', 'S477I', 'P384S', 'E406Q', 'L452M', 'Q414R', 'R408K', 'Y453F', 'S359N', 'A348S', 'N439K', 'N370S', 'N354K']
['V401L', 'G504S', 'F374L', 'R346T', 'K528G', 'G446V', 'P479S', 'V483L', 'Q414K', 'L455F', 'I410V', 'K417N', 'V367L', 'L452R', 'K529T', 'V483F', 'V362F', 'F490L', 'K529R', 'I468V', 'F490S', 'Y449S', 'V503I', 'R403K', 'V382L', 'A520S', 'K417T', 'N501Y', 'P384L', 'S530N', 'P384S', 'V341I'

In [21]:
EMBED = "esm3_coord"


In [22]:
df, train_x, train_y, targets = load_and_preprocess_data(EMBED)
pheno_columns = [
    "delta_log_kd_ACE2",
    "delta_log_kd_LY-CoV016",
    "delta_log_kd_REGN10987",
    "delta_log_kd_LY-CoV555",
    "delta_log_kd_S309",
]
pheno = df[pheno_columns]
pheno = pheno.to_numpy()
print(pheno.shape)

fitnesses = bio_model(pheno)
train_y = torch.tensor(
    np.array([fitnesses, fitnesses, fitnesses, fitnesses, fitnesses]).transpose()
)

dataset = Dataset_perso(train_x, train_y)
site_rbd_list = np.unique(df["site_SARS2"].values)





loaded embeddings of shape torch.Size([3803, 1536])
(3122, 5)


In [23]:


# Function to find mutation encoding and indexes
def get_mutation_indexes(detailed_df, site_mutation_df):
    # Prepare a list to store the results
    all_results = []

    # Iterate over each row in detailed_results_df
    for index, row in detailed_df.iterrows():
        mutations = list(row["mutations_observed"])
        mutation_indexes = []

        # Iterate through each mutation in the list
        for mutation in mutations:
            # Extract WT, site, and mutant from the mutation string (e.g., "Y453F")
            wt = mutation[0]
            site = float(mutation[1:-1])
            mutant = mutation[-1]

            # Find rows in the second dataframe matching mutant and site
            matching_rows = site_mutation_df[
                (site_mutation_df["mutation"] == mutant) & (site_mutation_df["site_SARS2"] == site)
            ]

            # Store the indexes of matching rows
            mutation_indexes.extend(matching_rows.index.tolist())

        # Append the result for the current row
        all_results.append(mutation_indexes)

    return all_results

# Call the function and print the results
mutation_indexes_results = get_mutation_indexes(detailed_results_df, df)
for i, indexes in enumerate(mutation_indexes_results):
    print(f"Row {i} mutation indexes: {indexes}")
    print(len(indexes))


Row 0 mutation indexes: [1275, 1835, 555, 2403, 3070, 2477, 2832, 674, 2346, 2446, 2450, 2377, 2568, 2329, 1903, 185, 3030, 864, 2416, 869, 2358, 1956, 261, 622]
24
Row 1 mutation indexes: [2362, 217, 1275, 2629, 1835, 555, 2403, 3070, 2477, 344, 2032, 388, 2832, 318, 674, 1322, 1411, 2346, 2446, 1946, 2450, 479, 2377, 2727, 2568, 2329, 1903, 3068, 2574, 2472, 826, 3030, 97, 881, 185, 374, 2730, 864, 1732, 2384, 2416, 2365, 2358, 869, 1242, 1942, 1376, 1276, 1956, 429, 261, 1713, 622, 350]
54
Row 2 mutation indexes: [1144, 2764, 692, 224, 1835, 2403, 2455, 1371, 1975, 1322, 1411, 560, 1946, 2450, 479, 2568, 2196, 2574, 1891, 2738, 1181, 826, 3030, 1416, 2730, 864, 869, 140, 2650, 1242, 642, 3065, 3066, 1804, 23, 1738, 1275, 555, 3070, 344, 388, 2346, 2446, 2821, 2727, 1903, 2300, 1790, 2472, 707, 2384, 2416, 3090, 2358, 2131, 2378, 2249, 2229, 1756, 1276, 2611, 1713, 2226, 84, 770, 677, 2775, 1414, 1339, 2477, 2032, 2832, 318, 729, 674, 2377, 66, 2329, 3068, 433, 2356, 1945, 881, 374, 

In [None]:
def save_indexes(active_learner, date_index):
    train_data_u, _ = active_learner.train_dataset.get_data()

    train_data_u_indexes = []
    for j in range(len(df)):
        for i in range(len(train_data_u)):
            if df[EMBED][j] == train_data_u[i].tolist():
                train_data_u_indexes.append(j)

    print("indexes checked by embedding", train_data_u_indexes)

    training_set = df.loc[train_data_u_indexes]
    # save training set in csv with name training_set_+strategy+"_run_"+run
    filename = "../script_results/1_batch_real_direct_esm3_coord/training_set_bloom_1000_date"+str(date_index)+".csv"

    # Save the training set to a CSV file
    training_set.to_csv(filename, index=False)

    hist_indexes = active_learner.get_training_indices_history()
    # save as a npy with the right name
    # Save the history as a .npy file
    print("hist_indexes", hist_indexes)
    filename = "../script_results/1_batch_real_direct_esm3_coord/training_indices_history_bloom_1000_date"+str(date_index)+".csv"

    # Save list of lists to a CSV file
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerows(hist_indexes)

    print(f"List of lists saved as CSV to {filename}")

In [25]:

df_kd = pd.read_csv("../data_bloom/kd_bloom/df_bloom_processed.csv")
# rename columns site as site_SARS2
df_kd.rename(columns={"site": "site_SARS2"}, inplace=True)


train_data = load_esm_embeddings("esm3_coord")


targets = [
    "delta_log_kd_ACE2",
    "delta_log_kd_LY-CoV016",
    "delta_log_kd_REGN10987",
    "delta_log_kd_LY-CoV555",
    "delta_log_kd_S309",
]

train_x = torch.tensor(train_data)
train_y = torch.tensor(df_kd[targets].values)

# Remove rows with NaN
indexes_nan = np.unique(np.where(np.isnan(train_y))[0])
non_nan_indexes = np.setdiff1d(np.arange(train_y.shape[0]), indexes_nan)

train_x = train_x[non_nan_indexes]
train_y = train_y[non_nan_indexes]
df = df_kd.drop(indexes_nan).reset_index(drop=True)

df[EMBED] = train_x.tolist()
indexNames = []
for i in range(len(df)):
    if df["site_SARS2"][i] in [331, 332, 333, 527, 528, 529, 530, 531]:
        indexNames.append(i)
df.drop(indexNames, inplace=True)
#reset index
df = df.reset_index(drop=True)
train_x = np.delete(train_x, indexNames, axis=0)
train_y = np.delete(train_y, indexNames, axis=0)
print(df)

loaded embeddings of shape torch.Size([3803, 1536])
                                        mutant_sequence  log10Kd_ACE2  \
0     NITALCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.003694   
1     NITCLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.923694   
2     NITDLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.103694   
3     NITELCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.043694   
4     NITFLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.813694   
...                                                 ...           ...   
3117  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.013694   
3118  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.043694   
3119  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.993694   
3120  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      8.953694   
3121  NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...      9.013694   

      delta_log_kd_ACE2 mutation  site_SARS2  mut_escape_LY-CoV016  \
0

In [26]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

# Get end dates from detailed_results_df
end_dates = detailed_results_df['end_date'].values

# Prepare a list to hold all rows for the final dataframe
results = []

# Define biomodel and strategy
biomodel = 'direct'
strategy = 'greedy'

# Loop through mutation indexes and process results
for date_index, indexes in enumerate(mutation_indexes_results):
    NB_0 = len(indexes)  # Initial number of points
    NB_POINTS = 500 + 174 - NB_0  # Total points to use
    NB_ROUNDS = 1  # Only one round

    # Prepare initial training data
    train_0_indexes = indexes
    train_0_x = torch.stack([dataset.data_x[i] for i in train_0_indexes])
    train_0_y = torch.stack([dataset.data_y[i] for i in train_0_indexes])
    dataset_0 = Dataset_perso(train_0_x, train_0_y)

    # Run the active learner and collect r2, p, and var
    r2_list, p_list, var_list, active_learner = run_active_learner(
        strategy=strategy,
        dataset_0=dataset_0,
        dataset=dataset,
        NB_POINTS=NB_POINTS,
        NB_ROUNDS=NB_ROUNDS,
        biomodel=biomodel,
    )

    # Collect results for the single round (end of round 1)
    results.append({
        "end_dates": end_dates[date_index],
        "indexes": indexes,
        "p": p_list[-1],   # Last value of p (end of round 1)
        "r2": r2_list[-1], # Last value of r2 (end of round 1)
        "var": var_list[-1] # Last value of var (end of round 1)
    })

    save_indexes(active_learner, date_index)


 

# Convert the results list into a dataframe
results_df = pd.DataFrame(results)

 

created new AL object
training on  24  points


Training: 100%|██████████| 1000/1000 [00:04<00:00, 242.74it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:09<00:00, 101.77it/s]


total_y:  3122
Strategy: greedy, AUC: 0.7352107856556256, P: 0.34185303514376997, Var: 1.3743394613265991
indexes checked by embedding [17, 19, 24, 27, 29, 32, 33, 34, 38, 39, 40, 42, 45, 47, 48, 49, 50, 51, 53, 54, 55, 70, 74, 78, 80, 87, 99, 114, 115, 117, 118, 119, 120, 121, 124, 125, 126, 127, 128, 129, 130, 131, 132, 135, 138, 141, 145, 147, 152, 153, 154, 155, 156, 162, 163, 178, 185, 217, 223, 261, 302, 311, 325, 327, 329, 330, 331, 332, 337, 338, 341, 350, 355, 356, 361, 362, 363, 364, 366, 372, 381, 388, 401, 402, 404, 406, 410, 412, 438, 454, 456, 457, 458, 460, 462, 463, 466, 467, 468, 469, 471, 472, 483, 486, 489, 490, 491, 497, 500, 501, 502, 508, 509, 512, 513, 514, 515, 516, 517, 518, 525, 526, 529, 533, 534, 544, 549, 551, 555, 566, 567, 570, 571, 572, 573, 575, 580, 581, 582, 583, 584, 585, 589, 591, 594, 597, 598, 600, 601, 603, 616, 619, 622, 625, 627, 632, 639, 644, 648, 651, 653, 656, 658, 659, 660, 661, 670, 674, 677, 684, 685, 686, 687, 688, 695, 696, 699, 704, 7

Training: 100%|██████████| 1000/1000 [00:04<00:00, 206.02it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:10<00:00, 99.39it/s]


total_y:  3122
Strategy: greedy, AUC: 0.7182167624783282, P: 0.3450479233226837, Var: 1.621087670326233
indexes checked by embedding [17, 20, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56, 58, 74, 77, 78, 79, 80, 83, 86, 87, 89, 90, 96, 97, 114, 115, 119, 120, 121, 123, 125, 126, 127, 129, 131, 134, 135, 138, 141, 145, 147, 148, 152, 153, 154, 155, 156, 157, 159, 162, 163, 164, 165, 166, 171, 185, 188, 191, 210, 217, 229, 232, 261, 267, 304, 318, 324, 325, 327, 343, 344, 350, 362, 363, 364, 365, 366, 368, 370, 372, 374, 376, 377, 379, 381, 388, 397, 399, 400, 401, 402, 404, 405, 406, 409, 410, 411, 412, 413, 414, 416, 419, 429, 438, 444, 448, 452, 453, 454, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 479, 497, 500, 513, 514, 515, 516, 518, 519, 521, 524, 525, 526, 527, 528, 529, 530, 533, 544, 549, 552, 555, 570, 571, 572, 573, 575, 576, 578, 580, 581, 582, 583, 584, 585, 590, 601, 619, 622, 625, 628, 639, 644,

Training: 100%|██████████| 1000/1000 [00:05<00:00, 188.02it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:09<00:00, 103.08it/s]


total_y:  3122
Strategy: greedy, AUC: 0.7652831006478693, P: 0.3354632587859425, Var: 1.569736361503601
indexes checked by embedding [17, 19, 20, 23, 30, 39, 40, 42, 43, 45, 46, 47, 48, 50, 51, 55, 58, 61, 66, 67, 74, 75, 77, 80, 83, 84, 87, 89, 96, 97, 115, 117, 119, 121, 123, 127, 131, 134, 138, 140, 147, 153, 154, 155, 156, 157, 159, 162, 163, 164, 165, 171, 185, 188, 191, 207, 210, 217, 223, 224, 226, 229, 261, 264, 267, 304, 318, 324, 325, 327, 328, 330, 332, 335, 338, 341, 343, 344, 350, 357, 359, 362, 365, 366, 368, 369, 370, 372, 374, 376, 377, 378, 379, 381, 388, 397, 399, 400, 401, 402, 404, 405, 406, 409, 410, 411, 412, 413, 414, 416, 419, 429, 433, 435, 438, 441, 444, 447, 453, 454, 456, 457, 458, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 479, 497, 500, 502, 503, 510, 512, 513, 514, 515, 516, 517, 518, 521, 522, 525, 526, 527, 528, 529, 530, 533, 542, 549, 551, 552, 553, 554, 555, 560, 564, 567, 571, 572, 573, 575, 576, 578, 580, 581, 582, 5

Training: 100%|██████████| 1000/1000 [00:07<00:00, 140.53it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:09<00:00, 106.75it/s]


total_y:  3122
Strategy: greedy, AUC: 0.76510174285975, P: 0.3610223642172524, Var: 1.1898654699325562
indexes checked by embedding [17, 23, 32, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 55, 56, 58, 59, 60, 62, 65, 66, 67, 68, 69, 70, 71, 74, 84, 97, 104, 109, 114, 116, 118, 121, 127, 128, 129, 130, 134, 140, 153, 154, 155, 156, 157, 163, 164, 165, 166, 172, 175, 177, 180, 185, 186, 187, 188, 191, 193, 199, 200, 201, 205, 207, 208, 214, 216, 217, 222, 223, 224, 230, 231, 232, 239, 244, 249, 256, 261, 262, 268, 269, 271, 278, 282, 287, 288, 294, 318, 319, 320, 326, 327, 328, 329, 330, 334, 335, 338, 340, 341, 342, 344, 345, 348, 350, 352, 356, 357, 362, 364, 367, 369, 370, 371, 373, 374, 381, 388, 397, 401, 402, 403, 412, 415, 418, 419, 423, 428, 429, 432, 433, 438, 440, 447, 450, 454, 456, 457, 458, 460, 462, 463, 464, 466, 467, 468, 469, 470, 471, 472, 473, 479, 489, 490, 493, 500, 504, 510, 515, 516, 526, 530, 532, 533, 542, 549, 551, 555, 558, 560, 569, 572, 57

Training: 100%|██████████| 1000/1000 [00:07<00:00, 138.02it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:09<00:00, 107.82it/s]


total_y:  3122
Strategy: greedy, AUC: 0.7467817547221463, P: 0.3035143769968051, Var: 1.357261061668396
indexes checked by embedding [17, 23, 36, 39, 40, 42, 45, 47, 50, 51, 55, 58, 59, 62, 66, 67, 70, 71, 74, 78, 79, 80, 83, 84, 89, 96, 97, 100, 104, 105, 109, 113, 114, 115, 116, 118, 121, 127, 128, 129, 130, 131, 134, 140, 153, 154, 155, 156, 157, 159, 161, 163, 164, 165, 166, 169, 175, 185, 186, 188, 191, 193, 199, 200, 201, 205, 207, 210, 214, 216, 217, 222, 223, 224, 229, 232, 244, 258, 261, 262, 267, 271, 278, 282, 294, 318, 319, 320, 326, 327, 335, 340, 341, 343, 344, 348, 350, 356, 357, 359, 362, 364, 367, 370, 371, 373, 374, 376, 377, 381, 388, 397, 401, 402, 403, 404, 412, 415, 418, 419, 423, 428, 429, 432, 433, 438, 440, 447, 449, 450, 453, 454, 456, 457, 458, 462, 463, 464, 466, 467, 468, 469, 470, 471, 472, 473, 479, 490, 504, 510, 514, 515, 516, 526, 530, 532, 533, 542, 544, 549, 551, 555, 558, 560, 571, 572, 573, 575, 576, 577, 578, 579, 580, 582, 583, 589, 590, 601, 609

Training: 100%|██████████| 1000/1000 [00:07<00:00, 136.79it/s]


total_y:  3122
acquisition function with strategy:  greedy
training on  674  points


Training: 100%|██████████| 1000/1000 [00:09<00:00, 107.20it/s]


total_y:  3122
Strategy: greedy, AUC: 0.7581953189159595, P: 0.38338658146964855, Var: 1.226318359375
indexes checked by embedding [17, 23, 32, 35, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 58, 59, 61, 62, 64, 65, 66, 67, 68, 69, 70, 71, 74, 75, 78, 84, 96, 97, 100, 103, 105, 109, 113, 114, 115, 116, 118, 121, 124, 126, 127, 128, 129, 130, 131, 134, 140, 153, 154, 155, 156, 157, 160, 164, 165, 166, 170, 175, 177, 185, 186, 191, 199, 200, 201, 205, 210, 212, 214, 216, 217, 222, 223, 224, 230, 232, 239, 244, 258, 261, 262, 271, 278, 282, 287, 294, 318, 319, 320, 326, 327, 329, 330, 332, 334, 335, 338, 340, 341, 342, 343, 344, 348, 350, 354, 356, 357, 362, 364, 367, 370, 371, 373, 374, 376, 377, 381, 388, 397, 401, 402, 403, 412, 415, 418, 419, 423, 429, 432, 433, 437, 438, 448, 450, 453, 454, 456, 457, 458, 460, 462, 463, 464, 466, 467, 468, 469, 470, 471, 472, 473, 474, 479, 490, 500, 504, 510, 515, 516, 526, 530, 532, 533, 549, 551, 555, 558, 560, 571,

In [None]:
results_df = results_df[["end_dates", "p", "r2", "var", "indexes"]]

# Display or save the dataframe
print(results_df)
results_df.to_csv("../script_results/1_batch_real_direct_esm3_coord/bloom_1_batch_real_1000.csv", index=False) 

    end_dates         p        r2       var  \
0  2020-06-30  0.341853  0.735211  1.374339   
1  2020-12-31  0.345048  0.718217  1.621088   
2  2021-06-30  0.335463  0.765283  1.569736   
3  2021-12-31  0.361022  0.765102  1.189865   
4  2022-06-30  0.303514  0.746782  1.357261   
5  2022-12-31  0.383387  0.758195  1.226318   

                                             indexes  
0  [1275, 1835, 555, 2403, 3070, 2477, 2832, 674,...  
1  [2362, 217, 1275, 2629, 1835, 555, 2403, 3070,...  
2  [1144, 2764, 692, 224, 1835, 2403, 2455, 1371,...  
3  [2764, 224, 1835, 2403, 186, 1322, 560, 1946, ...  
4  [2764, 224, 1835, 2403, 186, 1322, 560, 2509, ...  
5  [2764, 224, 1835, 2403, 186, 1322, 560, 2509, ...  


# CM dataset (Desai)

In [28]:
EMBED='esm3_coord'
df, train_x, train_y, targets = load_and_preprocess_data(EMBED, 'desai')
pheno_columns = [
    "delta_log_kd_ACE2",
    "delta_log_kd_LY-CoV016",
    "delta_log_kd_REGN10987",
    "delta_log_kd_LY-CoV555",
    "delta_log_kd_S309",
]
pheno = df[pheno_columns]
pheno = pheno.to_numpy()
print(pheno.shape)

fitnesses = bio_model(pheno)
train_y = torch.tensor(
    np.array([fitnesses, fitnesses, fitnesses, fitnesses, fitnesses]).transpose()
)

dataset = Dataset_perso(train_x, train_y)

df


loaded embeddings of shape torch.Size([32768, 1536])
(32768, 5)


Unnamed: 0.1,Unnamed: 0,mutant_sequence,log10Kd_ACE2,log10Kd_CB6,log10Kd_CoV555,log10Kd_REGN10987,log10Kd_S309,count,real_f,pred_f,average_date,grammaticality,semantic_change,delta_log_kd_ACE2,delta_log_kd_LY-CoV016,delta_log_kd_LY-CoV555,delta_log_kd_REGN10987,delta_log_kd_S309,esm3_coord
0,0,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,8.961147,9.952776,10.153433,9.983337,9.314233,29270.0,1.412018,1.449018,2021-01-19,-99.609791,0.226296,0.038853,-0.152776,0.046567,0.016663,-0.014233,"[1.0022965669631958, -1.1477677822113037, 0.73..."
1,1,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,8.141450,9.714128,10.098001,9.976880,9.363101,6.0,1.456767,1.394222,2021-11-29,-103.413168,0.457919,0.858550,0.085872,0.101999,0.023120,-0.063101,"[0.3557801842689514, -1.0486171245574951, 0.61..."
2,2,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,9.307811,9.684239,10.091608,10.053366,9.366271,102.0,1.468875,1.534511,2021-03-19,-102.308620,0.433742,-0.307811,0.115761,0.108392,-0.053366,-0.066271,"[1.7678442001342773, -0.7169723510742188, 1.35..."
3,3,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,9.035689,9.833378,10.119421,9.959428,9.315663,0.0,1.515426,1.487580,5,-106.121214,0.570067,-0.035689,-0.033378,0.080579,0.040572,-0.015663,"[0.947010338306427, -0.8925265073776245, 0.945..."
4,4,NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFST...,8.820523,9.716320,9.991677,9.996232,9.280689,2.0,1.455213,1.496562,2021-03-31,-102.091247,0.351558,0.179477,0.083680,0.208323,0.003768,0.019311,"[1.0406644344329834, -0.7063531875610352, 1.21..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
32763,32763,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFT...,8.716156,5.000000,5.000000,5.000000,8.572664,0.0,1.598322,1.438997,5,-129.311489,0.911976,0.283844,4.800000,5.200000,5.000000,0.727336,"[-1.1058300733566284, 0.7039570212364197, -1.3..."
32764,32764,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFT...,8.124485,5.000000,5.000000,5.000000,8.550620,0.0,1.534816,1.432982,5,-125.234889,1.030324,0.875515,4.800000,5.200000,5.000000,0.749380,"[-1.0378551483154297, 0.8639916181564331, -0.9..."
32765,32765,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFT...,7.885902,5.000000,5.000000,5.000000,8.487258,0.0,1.583456,1.442057,5,-128.895375,1.473291,1.114098,4.800000,5.200000,5.000000,0.812742,"[-1.5707778930664062, 0.7529399394989014, -0.9..."
32766,32766,NITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFT...,9.526724,5.000000,5.000000,5.000000,8.562622,0.0,1.596617,1.469037,5,-128.512596,0.959698,-0.526724,4.800000,5.200000,5.000000,0.737378,"[-0.11324627697467804, 0.934541642665863, -0.6..."


In [None]:

# Define date intervals for subsets
intervals = [
    ("2020-01-01", "2020-06-30"),
    ("2020-01-01", "2020-12-31"),
    ("2020-01-01", "2021-06-30"),
    ("2020-01-01", "2021-12-31"),
    ("2020-01-01", "2022-06-30"),
    ("2020-01-01", "2022-12-31")
]
# Initialize list to store detailed results
detailed_results = []

# Iterate over each date interval, similar to previous analysis
for start_date, end_date in intervals:
    # Filter by date range
    subset = rbd_df[(rbd_df['q05_date'] >= start_date) & (rbd_df['q05_date'] <= end_date)]
    subset = subset[subset['count'] >= 1000]
    
    # Get unique sequences in the subset
    unique_sequences = subset['seq'].unique()
    
    # Calculate mutations using the get_mutation function
    
    # Append result for the current subset
    detailed_results.append({
        'start_date': start_date,
        'end_date': end_date,
        'unique_sequences': unique_sequences,

    })

# Convert results to a DataFrame for easy display
detailed_results_df = pd.DataFrame(detailed_results)
detailed_results_df

Unnamed: 0,start_date,end_date,unique_sequences
0,2020-01-01,2020-06-30,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...
1,2020-01-01,2020-12-31,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...
2,2020-01-01,2021-06-30,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...
3,2020-01-01,2021-12-31,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...
4,2020-01-01,2022-06-30,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...
5,2020-01-01,2022-12-31,[NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFS...


In [30]:
# Function to find mutation encoding and indexes
def get_variant_indexes(detailed_df, desai_df):
    # Create a dictionary mapping mutant_sequence to index for efficient lookups
    sequence_to_index = {sequence: idx for idx, sequence in desai_df["mutant_sequence"].items()}

    # Prepare a list to store the results
    all_results = []

    # Iterate over each row in detailed_results_df
    for row_index, row in detailed_df.iterrows():
        variants = row["unique_sequences"]  # Contains a list of sequences
        variant_indexes = []

        # Iterate through each mutation in the list
        for variant in variants:
            if variant in sequence_to_index:
                # Add the corresponding index from the dictionary
                variant_indexes.append(sequence_to_index[variant])
            # else:
            #     # If the variant is not found, log it (optional)
            #     print(f"Variant '{variant}' not found in desai_df.")

        # Append the result for the current row
        all_results.append(variant_indexes)

    return all_results

# Call the function and print the results
variant_indexes_results = get_variant_indexes(detailed_results_df, df)

# Print results for each row
for i, indexes in enumerate(variant_indexes_results):
    print(f"Row {i} variant indexes: {indexes}")
    print(f"Number of indexes found: {len(indexes)}")


Row 0 variant indexes: [16384]
Number of indexes found: 1
Row 1 variant indexes: [16384, 16448, 16640, 16896, 24576]
Number of indexes found: 5
Row 2 variant indexes: [16386, 16384, 16448, 0, 16640, 16898, 16394, 16896, 24576, 16402, 16387, 2]
Number of indexes found: 12
Row 3 variant indexes: [16386, 16384, 16383, 16448, 0, 15871, 15487, 16640, 16127, 16898, 16382, 16352, 8703, 16255, 16376, 32640, 32128, 16394, 1023, 16896, 24576, 15999, 8191, 31744, 255, 511, 31775, 16402, 15615, 16319, 127, 12287, 16320, 15455, 9215, 16387, 15359, 8447, 16351, 2]
Number of indexes found: 40
Row 4 variant indexes: [16386, 16384, 16383, 16448, 0, 15871, 15487, 16640, 16127, 119, 16898, 16382, 16352, 8703, 16255, 16376, 32640, 32128, 16394, 8311, 1023, 16896, 24576, 15999, 8191, 31744, 255, 511, 16367, 31775, 16402, 15615, 16319, 127, 12287, 16320, 15455, 9215, 16247, 16387, 15359, 16439, 8447, 16351, 2]
Number of indexes found: 45
Row 5 variant indexes: [16386, 16384, 16383, 16448, 0, 15871, 15487, 1

In [31]:
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt

# Get end dates from detailed_results_df
end_dates = detailed_results_df['end_date'].values

# Prepare a list to hold all rows for the final dataframe
results = []

# Define biomodel and strategy
biomodel = 'fitness'
strategy = 'greedy'

# Loop through mutation indexes and process results
for i, indexes in enumerate(variant_indexes_results):
    NB_0 = len(indexes)  # Initial number of points
    NB_POINTS = 100 + 20 - NB_0  # Total points to use
    NB_ROUNDS = 1  # Only one round

    # Prepare initial training data
    train_0_indexes = indexes
    train_0_x = torch.stack([dataset.data_x[i] for i in train_0_indexes])
    train_0_y = torch.stack([dataset.data_y[i] for i in train_0_indexes])
    dataset_0 = Dataset_perso(train_0_x, train_0_y)

    # Run the active learner and collect r2, p, and var
    try:
        r2_list, p_list, var_list, active_learner = run_active_learner(
            strategy=strategy,
            dataset_0=dataset_0,
            dataset=dataset,
            NB_POINTS=NB_POINTS,
            NB_ROUNDS=NB_ROUNDS,
            biomodel='direct',
        )

        # Collect results for the single round (end of round 1)
        results.append({
            "end_dates": end_dates[i],
            "indexes": indexes,
            "p": p_list[-1],   # Last value of p (end of round 1)
            "r2": r2_list[-1], # Last value of r2 (end of round 1)
            "var": var_list[-1] # Last value of var (end of round 1)
        })
    except Exception as e:
        print(f"Error processing row {i}: {e}")

# Convert the results list into a dataframe
results_df = pd.DataFrame(results)



created new AL object
training on  1  points


Training: 100%|██████████| 1000/1000 [00:03<00:00, 273.63it/s]


total_y:  32768


  train_x.var(dim=0).mean().detach().numpy(),


acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:05<00:00, 194.68it/s]


total_y:  32768
Strategy: greedy, AUC: 0.6746325849091006, P: 0.0, Var: 0.3496619164943695
created new AL object
training on  5  points


Training: 100%|██████████| 1000/1000 [00:03<00:00, 250.76it/s]


total_y:  32768
acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:05<00:00, 195.24it/s]


total_y:  32768
Strategy: greedy, AUC: 0.572642658665987, P: 0.02197131522734208, Var: 0.5142082571983337
created new AL object
training on  12  points


Training: 100%|██████████| 1000/1000 [00:04<00:00, 243.74it/s]


total_y:  32768
acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:05<00:00, 193.27it/s]


total_y:  32768
Strategy: greedy, AUC: 0.5839050100629511, P: 0.018919743667989014, Var: 0.7121909260749817
created new AL object
training on  40  points


Training: 100%|██████████| 1000/1000 [00:03<00:00, 258.47it/s]


total_y:  32768
acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:04<00:00, 219.96it/s]


total_y:  32768
Strategy: greedy, AUC: 0.6907267913303448, P: 0.015868172108635947, Var: 1.0642000436782837
created new AL object
training on  45  points


Training: 100%|██████████| 1000/1000 [00:04<00:00, 225.78it/s]


total_y:  32768
acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:05<00:00, 185.86it/s]


total_y:  32768
Strategy: greedy, AUC: 0.7220550290577755, P: 0.010680500457735734, Var: 1.0792138576507568
created new AL object
training on  45  points


Training: 100%|██████████| 1000/1000 [00:04<00:00, 231.75it/s]


total_y:  32768
acquisition function with strategy:  greedy
training on  120  points


Training: 100%|██████████| 1000/1000 [00:05<00:00, 193.32it/s]


total_y:  32768
Strategy: greedy, AUC: 0.6754946592995894, P: 0.01098565761367104, Var: 1.1257914304733276


In [None]:
results_df = results_df[["end_dates", "p", "r2", "var", "indexes"]]

# Display or save the dataframe
print(results_df)
results_df.to_csv("../script_results/1_batch_real_direct_esm3_coord/desai_1_batch_real_1000.csv", index=False) 

    end_dates         p        r2       var  \
0  2020-06-30  0.000000  0.674633  0.349662   
1  2020-12-31  0.021971  0.572643  0.514208   
2  2021-06-30  0.018920  0.583905  0.712191   
3  2021-12-31  0.015868  0.690727  1.064200   
4  2022-06-30  0.010681  0.722055  1.079214   
5  2022-12-31  0.010986  0.675495  1.125791   

                                             indexes  
0                                            [16384]  
1                [16384, 16448, 16640, 16896, 24576]  
2  [16386, 16384, 16448, 0, 16640, 16898, 16394, ...  
3  [16386, 16384, 16383, 16448, 0, 15871, 15487, ...  
4  [16386, 16384, 16383, 16448, 0, 15871, 15487, ...  
5  [16386, 16384, 16383, 16448, 0, 15871, 15487, ...  
