In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.al import top_model as topmodel
from DomainPrediction.al.embeddings import one_hot_encode

In [None]:
from DomainPrediction.esm.esm2 import ESM2
from DomainPrediction.al.confit import ESM2ConFit

In [None]:
sys.path.append('../../esm')
from DomainPrediction.esm.esm3 import ESM3LM
from DomainPrediction.esm.esmc import ESMCLM
from DomainPrediction.al.confit import ESMCConFit

#### Load Data

In [None]:
data_path = '/nethome/kgeorge/workspace/DomainPrediction/Data/al_test_experiments/Envision'

In [None]:
file = os.path.join(data_path, 'dataset_tem1.csv')
df = pd.read_csv(file)

In [None]:
df.head()

In [None]:
results_file = os.path.join(data_path, 'results_TEM1_zeroshot.csv')

In [None]:
if os.path.isfile(results_file):
    df_results = pd.read_csv(results_file)
else:
    df_results = df.copy()

In [None]:
df_results.head()

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
len(df_results.columns[df_results.columns.str.contains('pred')])

In [None]:
def get_split_mask(df, omit_zero=False):
    """Build boolean row masks for the predefined data splits.

    Rows are assigned by the ``split_id`` column: 2 -> train, 1 -> val, 0 -> test.

    Args:
        df: DataFrame with a ``split_id`` column.
        omit_zero: NOTE(review): currently has no effect -- in the original code both
            branches selected ``split_id == 2`` for training. Kept for call
            compatibility; confirm what it was meant to filter.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean Series aligned with ``df``.
    """
    split_ids = df['split_id']
    return tuple(split_ids == code for code in (2, 1, 0))

#### One-Hot Encoding (OHE)

In [None]:
embeddings = one_hot_encode(df['seq'])

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

X_train = embeddings[train_mask]
X_val = embeddings[val_mask]
X_test = embeddings[test_mask]

y_train = df.loc[train_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_test = df.loc[test_mask, 'fitness_raw'].to_numpy().astype(np.float32)

# y_train = df.loc[train_mask, 'fitness_norm'].to_numpy().astype(np.float32)
# y_val = df.loc[val_mask, 'fitness_norm'].to_numpy().astype(np.float32)
# y_test = df.loc[test_mask, 'fitness_norm'].to_numpy().astype(np.float32)

In [None]:
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results['pred_OHE_ridge'] = y_pred

mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_OHE_ridge'])
corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_OHE_ridge'])
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results['pred_OHE_RF'] = y_pred

mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_OHE_RF'])
corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_OHE_RF'])
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
# MLP surrogate on the one-hot features. The input width is read from the data
# (it was hard-coded as 5720) so the first layer always matches X_train.
config = {'layers': [X_train.shape[1], 512, 1],
          'epoch': 100,
          'batch_size': 16,
          'patience': 10,
          'early_stopping': True,
          'lr': 1e-3,
          'print_every_n_epoch': 10,
          'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results['pred_OHE_MLP'] = y_pred

mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_OHE_MLP'])
corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_OHE_MLP'])
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ESM2, ESM3 and ESMC Embeddings

In [None]:
# esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
# esm3 = ESM3LM(device='gpu')
# esmc = ESMCLM(name='esmc_300m', device='gpu')
# esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
# embeddings = esm2.get_embeddings_flatten(df['seq'])
# embeddings = esm3.get_embeddings_flatten(df['seq'])
embeddings = esmc.get_embeddings_flatten(df['seq'])

In [None]:
embeddings.shape

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

X_train = embeddings[train_mask]
X_val = embeddings[val_mask]
X_test = embeddings[test_mask]

y_train = df.loc[train_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_test = df.loc[test_mask, 'fitness_raw'].to_numpy().astype(np.float32)

# y_train = df.loc[train_mask, 'fitness_norm'].to_numpy().astype(np.float32)
# y_val = df.loc[val_mask, 'fitness_norm'].to_numpy().astype(np.float32)
# y_test = df.loc[test_mask, 'fitness_norm'].to_numpy().astype(np.float32)

In [None]:
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results['pred_ESMC600M_concat_ridge'] = y_pred

mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_ESMC600M_concat_ridge'])
corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_ESMC600M_concat_ridge'])
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# y_pred = surrogate.predict(embeddings)
# assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_feat_mean_RF'] = y_pred

# mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_ESMC600M_feat_mean_RF'])
# corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_ESMC600M_feat_mean_RF'])
# s_corr = round(corr.statistic, 2)
# print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
config={'layers': [286, 64, 1], 
        'epoch': 100, 
        'batch_size': 16,
        'patience': 20,
        'early_stopping': True,
        'lr': 1e-3,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

# (label, true fitness, predicted fitness) -> one scatter panel per split.
panels = [('Train', y_train, y_train_pred),
          ('Val', y_val, y_val_pred),
          ('Test', y_test, y_test_pred)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, y_true, y_hat) in enumerate(panels):
    ax[i].plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr = stats.spearmanr(y_true, y_hat)
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# y_pred = surrogate.predict(embeddings)
# assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_feat_mean_MLP'] = y_pred

# mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_ESMC600M_feat_mean_MLP'])
# corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_ESMC600M_feat_mean_MLP'])
# s_corr = round(corr.statistic, 2)
# print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.filter(like='pred').columns

#### ESM2, ESM3, ESMC - Zero-Shot Masked Marginals

In [None]:
# esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
# esm3 = ESM3LM(device='gpu')
# esmc = ESMCLM(name='esmc_300m', device='gpu')
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
wt_sequence = helper.read_fasta(os.path.join(data_path, 'TEM1_WT.fasta'), mode='str')[0]

In [None]:
## masked marginals
y_pred = []
for i, row in tqdm(df.iterrows()):
    mt_sequence = row['seq']
    # score, n_muts = esm2.get_masked_marginal(mt_sequence, wt_sequence)
    # score, n_muts = esm3.get_masked_marginal(mt_sequence, wt_sequence)
    score, n_muts = esmc.get_masked_marginal(mt_sequence, wt_sequence)

    assert n_muts == row['n_mut']

    y_pred.append(score)

y_pred = np.array(y_pred)

In [None]:
y = df['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
fig, ax = plt.subplots(1,1, figsize=(3,3), layout='constrained')

ax.plot(y, y_pred, '.', alpha=0.5)
mse = mean_squared_error(y, y_pred)
corr = stats.spearmanr(y, y_pred)
s_corr = round(corr.statistic, 2)
ax.set_title(f'Full Dataset \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

In [None]:
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')

train_mask, val_mask, test_mask = get_split_mask(df)

for i, (dset, _mask) in enumerate(zip(['train', 'val', 'test'], 
                                      [train_mask, val_mask, test_mask])):
    ax[i].plot(y[_mask], y_pred[_mask], '.', alpha=0.5)
    mse = mean_squared_error(y[_mask], y_pred[_mask])
    corr = stats.spearmanr(y[_mask], y_pred[_mask])
    s_corr = round(corr.statistic, 2)
    ax[i].set_title(f'{dset} \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_masked_marginal'] = y_pred

# mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_ESMC600M_masked_marginal'])
# corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_ESMC600M_masked_marginal'])
# s_corr = round(corr.statistic, 2)
# print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ConFit - Contrastive Fitness Learning

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
df_train = df[train_mask]
df_val = df[val_mask]
df_test = df[test_mask]

In [None]:
config={'epoch': 10, 
        'batch_size': 16,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'device': 'gpu'}
surrogate = ESMCConFit(name='esmc_300m', config=config)
surrogate.print_trainable_parameters(surrogate.model)

In [None]:
wt_sequence = helper.read_fasta(os.path.join(data_path, 'TEM1_WT.fasta'), mode='str')[0]

In [None]:
surrogate.sanity_check(df_train, wt_sequence)

In [None]:
surrogate.trainmodel(df_train, wt_sequence, df_val)

In [None]:
## masked marginals
y_pred = []
for i, row in tqdm(df.iterrows()):
    mt_sequence = row['seq']
    score, n_muts = surrogate.get_masked_marginal(mt_sequence, wt_sequence)

    assert n_muts == row['n_mut']

    y_pred.append(score)

y_pred = np.array(y_pred)
y = df['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
y_train_pred, y_train = y_pred[train_mask], y[train_mask]
y_val_pred, y_val = y_pred[val_mask], y[val_mask]
y_test_pred, y_test = y_pred[test_mask], y[test_mask]

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
ax[0].plot(y_train, y_train_pred, '.', alpha=0.3)
ax[1].plot(y_val, y_val_pred, '.', alpha=0.5)
ax[2].plot(y_test, y_test_pred, '.', alpha=0.5)

mse = mean_squared_error(y_train, y_train_pred)
corr = stats.spearmanr(y_train, y_train_pred)
s_corr = round(corr.statistic, 2)
ax[0].set_title(f'Train \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

mse = mean_squared_error(y_val, y_val_pred)
corr = stats.spearmanr(y_val, y_val_pred)
s_corr = round(corr.statistic, 2)
ax[1].set_title(f'Val \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

mse = mean_squared_error(y_test, y_test_pred)
corr = stats.spearmanr(y_test, y_test_pred)
s_corr = round(corr.statistic, 2)
ax[2].set_title(f'Test \nmse : {str(round(mse, 2))} \nspearman correlation = {s_corr}')

for i in range(3):
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC300M_confit'] = y_pred

# mse = mean_squared_error(df_results['fitness_raw'], df_results['pred_ESMC300M_confit'])
# corr = stats.spearmanr(df_results['fitness_raw'], df_results['pred_ESMC300M_confit'])
# s_corr = round(corr.statistic, 2)
# print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### Simulating a Low-N Setting

In [None]:
n_samples = 53

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
print(f'n samples: train {train_mask.sum()} val {val_mask.sum()} test {test_mask.sum()}')

random_indices = df[train_mask].sample(n=n_samples, random_state=0).index
selected_mask = df.index.isin(random_indices)

train_mask_selected = selected_mask
train_mask_rest = train_mask & ~selected_mask
print(f'n samples: train selected {train_mask_selected.sum()} train rest {train_mask_rest.sum()}')

df_train = df[train_mask_selected]
df_train_rest = df[train_mask_rest]
df_val = df[val_mask]
df_test = df[test_mask]

In [None]:
df_train.shape[0] + df_train_rest.shape[0] + df_val.shape[0] + df_test.shape[0] == df.shape[0]

In [None]:
# Index([ 593, 1955, 2673, 1569, 4422, 3783, 4495, 2588, 1461,  410, 5104, 2561,
#        4194, 2647, 4675, 1503, 4722,  793, 3294, 5192, 4313, 4681,  306,  673,
#        3732, 1213, 4123, 4173,  247, 3459, 5033, 3498, 2718, 3406,  963,  158,
#        4658, 1970, 1328, 4207,  735, 3717,  119, 5038, 2925, 2602, 3560, 2376,
#         476, 4519,  768,  676, 2228],
#       dtype='int64')

In [None]:
random_indices

In [None]:
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
embeddings = esmc.get_embeddings_flatten(df['seq'])

In [None]:
X_train = embeddings[train_mask_selected]
X_val = embeddings[val_mask]

y_train = df.loc[train_mask_selected, 'fitness_raw'].to_numpy().astype(np.float32)
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)

In [None]:
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_pred = surrogate.predict(embeddings)
y = df['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
def get_spearmanr_bootstrap(a, b, n=1000, ci_percentiles=(5, 95), seed=0):
    """Bootstrap the Spearman correlation between two paired arrays.

    Index positions are resampled with replacement ``n`` times; the Spearman
    correlation of each resample is collected and summarised.

    Args:
        a, b: numpy arrays of equal length (e.g. true vs. predicted fitness).
        n: number of bootstrap resamples.
        ci_percentiles: lower/upper percentiles of the bootstrap distribution used
            as the interval. The default (5, 95) is a 90% percentile interval,
            matching the original behaviour; pass (2.5, 97.5) for a 95% interval.
        seed: seed of a private ``RandomState``. The default reproduces the
            resamples of the original ``np.random.seed(0)`` version without
            resetting numpy's global RNG as a side effect.

    Returns:
        ``(mean_corr, ci_lower, ci_upper, corr, p_value)`` -- the first three
        rounded to 2 decimals; ``corr`` / ``p_value`` are the per-resample lists.
        Resamples with a NaN correlation (e.g. constant input) are skipped; if all
        are NaN the three summary values are NaN instead of raising.
    """
    assert type(a) == type(b) == np.ndarray
    assert len(a) == len(b)

    # Local legacy generator: same stream as np.random.seed(seed) + np.random.choice.
    rng = np.random.RandomState(seed)
    corr = []
    p_value = []
    for _ in range(n):
        indices = rng.choice(len(a), size=len(a), replace=True)
        res = stats.spearmanr(a[indices], b[indices])

        if not np.isnan(res.statistic):
            corr.append(res.statistic)
            p_value.append(res.pvalue)

    if not corr:
        # Nothing to summarise (n == 0 or every resample degenerate).
        return np.nan, np.nan, np.nan, corr, p_value

    ci_lower, ci_upper = np.percentile(corr, list(ci_percentiles))
    mean_corr = np.mean(corr)

    return round(mean_corr, 2), round(ci_lower, 2), round(ci_upper, 2), corr, p_value

In [None]:
# Low-N ridge: selected training subset, unused train rows, val and test.
# "CI" is the bootstrap percentile interval returned by get_spearmanr_bootstrap.
panels = [('Train', train_mask_selected, 0.8),
          ('Train rest', train_mask_rest, 0.3),
          ('Val', val_mask, 0.5),
          ('Test', test_mask, 0.3)]

fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
for i, (label, split_mask, alpha) in enumerate(panels):
    y_true, y_hat = y[split_mask], y_pred[split_mask]
    ax[i].plot(y_true, y_hat, '.', alpha=alpha)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Same metrics without the 'Train rest' panel. The original computed an extra
# stats.spearmanr for Test that was immediately overwritten; that dead call is gone.
panels = [('Train', train_mask_selected, 0.8),
          ('Val', val_mask, 0.5),
          ('Test', test_mask, 0.3)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, split_mask, alpha) in enumerate(panels):
    y_true, y_hat = y[split_mask], y_pred[split_mask]
    ax[i].plot(y_true, y_hat, '.', alpha=alpha)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
config={'epoch': 30, 
        'batch_size': 8,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'device': 'gpu'}
surrogate = ESMCConFit(name='esmc_600m', config=config)
surrogate.print_trainable_parameters(surrogate.model)

In [None]:
wt_sequence = helper.read_fasta(os.path.join(data_path, 'TEM1_WT.fasta'), mode='str')[0]

In [None]:
surrogate.sanity_check(df_train, wt_sequence)

In [None]:
surrogate.trainmodel(df_train, wt_sequence, df_val)

In [None]:
## masked marginals: score each variant against the WT with the low-N ConFit model
y_pred = []
# iterrows() is a generator, so pass total= for a proper progress bar / ETA.
for _, row in tqdm(df.iterrows(), total=len(df)):
    mt_sequence = row['seq']
    score, n_muts = surrogate.get_masked_marginal(mt_sequence, wt_sequence)

    # The mutation count found against the WT must agree with the dataset annotation.
    assert n_muts == row['n_mut']

    y_pred.append(score)

y_pred = np.array(y_pred)
y = df['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
# Low-N ConFit: selected training subset, unused train rows, val and test.
# "CI" is the bootstrap percentile interval returned by get_spearmanr_bootstrap.
panels = [('Train', train_mask_selected, 0.8),
          ('Train rest', train_mask_rest, 0.3),
          ('Val', val_mask, 0.5),
          ('Test', test_mask, 0.3)]

fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
for i, (label, split_mask, alpha) in enumerate(panels):
    y_true, y_hat = y[split_mask], y_pred[split_mask]
    ax[i].plot(y_true, y_hat, '.', alpha=alpha)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Same metrics without the 'Train rest' panel. The original computed an extra
# stats.spearmanr for Test that was immediately overwritten; that dead call is gone.
panels = [('Train', train_mask_selected, 0.8),
          ('Val', val_mask, 0.5),
          ('Test', test_mask, 0.3)]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
for i, (label, split_mask, alpha) in enumerate(panels):
    y_true, y_hat = y[split_mask], y_pred[split_mask]
    ax[i].plot(y_true, y_hat, '.', alpha=alpha)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    ax[i].set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()