In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.al import top_model as topmodel
from DomainPrediction.al.embeddings import one_hot_encode

In [None]:
from DomainPrediction.esm.esm2 import ESM2
from DomainPrediction.al.confit import ESM2ConFit

In [None]:
sys.path.append('../../esm')
from DomainPrediction.esm.esm3 import ESM3LM
from DomainPrediction.esm.esmc import ESMCLM
from DomainPrediction.al.confit import ESMCConFit

#### Load Data

In [None]:
data_path = '/nethome/kgeorge/workspace/DomainPrediction/Data/al_test_experiments/Tdomain'

In [None]:
file = os.path.join(data_path, 'dataset_2_tdomain.csv')
df = pd.read_csv(file)

In [None]:
df.head()

In [None]:
df.shape

In [None]:
results_file = os.path.join(data_path, 'results_2_tdomain_confit.csv')

In [None]:
# Resume from previously saved predictions when the results file exists;
# otherwise start from a copy of the raw dataset.
df_results = pd.read_csv(results_file) if os.path.isfile(results_file) else df.copy()

In [None]:
df_results.head()

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
len(df_results.columns[df_results.columns.str.contains('pred')])

In [None]:
def get_split_mask(df, omit_zero=False):
    """Build boolean row masks for the train / validation / test partitions.

    Rows are selected through the ``split_id`` column: 2 -> train,
    1 -> validation, and {0, 1} -> test, so validation rows are also part of
    the test mask.

    Args:
        df: DataFrame with a ``split_id`` column; ``fitness_raw`` is read only
            when ``omit_zero`` is true.
        omit_zero: If true, rows with ``fitness_raw == 0`` are removed from
            the training mask (validation and test masks are unaffected).

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean Series aligned
        with ``df``.
    """
    split = df['split_id']

    train_mask = split == 2
    if omit_zero:
        train_mask = train_mask & (df['fitness_raw'] != 0)

    # Use ``split == 0`` for the last mask to get a test set without validation rows.
    return train_mask, split == 1, split.isin([0, 1])

In [None]:
def get_spearmanr_bootstrap(a, b, n=1000, seed=0, ci=(5, 95)):
    """Bootstrap the Spearman rank correlation between two equal-length arrays.

    ``n`` resamples (with replacement, same size as the input) are drawn and
    the Spearman correlation is computed on each. Resamples whose correlation
    is NaN are discarded.

    Args:
        a: ``np.ndarray`` of reference values.
        b: ``np.ndarray`` of predicted values, same length as ``a``.
        n: Number of bootstrap resamples.
        seed: Seed of the private random generator used for resampling.
        ci: Pair of percentiles (lower, upper) reported as the interval. The
            default ``(5, 95)`` is a 90% percentile interval.

    Returns:
        Tuple ``(mean_corr, ci_lower, ci_upper, corr, p_value)`` where the
        first three entries are rounded to two decimals and ``corr`` /
        ``p_value`` are the per-resample lists. If every resample was
        discarded the three summary values are NaN and the lists are empty.
    """
    assert type(a) == type(b) == np.ndarray
    assert len(a) == len(b)

    # A private legacy RandomState yields exactly the draws that
    # ``np.random.seed(seed)`` + ``np.random.choice`` would, without resetting
    # the process-wide NumPy RNG for whatever runs after this call.
    rng = np.random.RandomState(seed)

    corr = []
    p_value = []
    for _ in range(n):
        indices = rng.choice(len(a), size=len(a), replace=True)
        res = stats.spearmanr(a[indices], b[indices])

        # Keep only resamples for which the correlation is defined.
        if not np.isnan(res.statistic):
            corr.append(res.statistic)
            p_value.append(res.pvalue)

    if not corr:
        # Nothing to summarise; avoid np.percentile / np.mean on an empty list.
        return float('nan'), float('nan'), float('nan'), corr, p_value

    ci_lower, ci_upper = np.percentile(corr, list(ci))
    mean_corr = np.mean(corr)

    return round(mean_corr, 2), round(ci_lower, 2), round(ci_upper, 2), corr, p_value

#### OHE

In [None]:
embeddings = one_hot_encode(df['seq'])

In [None]:
# Partition features and targets by split. Targets are the 'fitness_log'
# column; set ``target_col`` to 'fitness_raw' to use raw fitness instead.
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
split_masks = (train_mask, val_mask, test_mask)
target_col = 'fitness_log'

X_train, X_val, X_test = (embeddings[mask] for mask in split_masks)
y_train, y_val, y_test = (
    df.loc[mask, target_col].to_numpy().astype(np.float32) for mask in split_masks
)

In [None]:
# Fit a ridge surrogate (alpha=1.0) on the training split; the validation
# split is handed to `trainmodel` through `val`.
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Predict every sequence, store the scores in their own results column,
# print whole-dataset metrics and write the table back to disk.
pred_col = 'pred_OHE_ridge'
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results[pred_col] = y_pred

y_true_all, y_pred_all = df_results['fitness_log'], df_results[pred_col]
mse = mean_squared_error(y_true_all, y_pred_all)
corr = stats.spearmanr(y_true_all, y_pred_all)
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
# Fit an RFSurrogate with its default settings on the training split; the
# validation split is handed to `trainmodel` through `val`.
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Predict every sequence, store the scores in their own results column,
# print whole-dataset metrics and write the table back to disk.
pred_col = 'pred_OHE_RF'
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results[pred_col] = y_pred

y_true_all, y_pred_all = df_results['fitness_log'], df_results[pred_col]
mse = mean_squared_error(y_true_all, y_pred_all)
corr = stats.spearmanr(y_true_all, y_pred_all)
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
# MLP surrogate on the current features. The input layer width is taken from
# the feature matrix so the config cannot drift out of sync with the
# embedding size (it was previously hard-coded).
# Early stopping is disabled here, so training runs for the full 'epoch' count.
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 100,
        'batch_size': 16,
        'patience': 100,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Predict every sequence, store the scores in their own results column,
# print whole-dataset metrics and write the table back to disk.
pred_col = 'pred_OHE_MLP'
y_pred = surrogate.predict(embeddings)
assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
df_results[pred_col] = y_pred

y_true_all, y_pred_all = df_results['fitness_log'], df_results[pred_col]
mse = mean_squared_error(y_true_all, y_pred_all)
corr = stats.spearmanr(y_true_all, y_pred_all)
s_corr = round(corr.statistic, 2)
print(f'mse : {str(round(mse, 2))} || spearman correlation = {s_corr}')

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ESM2, ESM3 and ESMC Embeddings

In [None]:
# esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
# esm3 = ESM3LM(device='gpu')
# esmc = ESMCLM(name='esmc_300m', device='gpu')
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
# embeddings = esm2.get_embeddings_flatten(df['seq'])
# embeddings = esm3.get_embeddings_flatten(df['seq'])
embeddings = esmc.get_embeddings_flatten(df['seq'])

In [None]:
embeddings.shape

In [None]:
# Partition the language-model embeddings and targets by split. Targets are
# the 'fitness_log' column; set ``target_col`` to 'fitness_raw' for raw fitness.
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
split_masks = (train_mask, val_mask, test_mask)
target_col = 'fitness_log'

X_train, X_val, X_test = (embeddings[mask] for mask in split_masks)
y_train, y_val, y_test = (
    df.loc[mask, target_col].to_numpy().astype(np.float32) for mask in split_masks
)

In [None]:
print(f'train {train_mask.sum()} val {val_mask.sum()} test {test_mask.sum()}')

In [None]:
# Fit a ridge surrogate (alpha=1.0) on the embedding features of the training
# split; the validation split is handed to `trainmodel` through `val`.
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Predict every sequence, check the row counts line up, then store the scores
# in their own results column and write the table back to disk.
pred_col = 'pred_ESMC600M_concat_ridge'
y_pred = surrogate.predict(embeddings)
n_pred, n_emb, n_rows = y_pred.shape[0], embeddings.shape[0], df_results.shape[0]
assert n_pred == n_emb == n_rows
df_results[pred_col] = y_pred

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
# Fit an RFSurrogate with its default settings on the embedding features of
# the training split; the validation split is passed through `val`.
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Predict every sequence, check the row counts line up, then store the scores
# in their own results column and write the table back to disk.
pred_col = 'pred_ESMC600M_concat_RF'
y_pred = surrogate.predict(embeddings)
n_pred, n_emb, n_rows = y_pred.shape[0], embeddings.shape[0], df_results.shape[0]
assert n_pred == n_emb == n_rows
df_results[pred_col] = y_pred

df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

In [None]:
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
# MLP surrogate on the embedding features. The input layer width is taken
# from the feature matrix rather than hard-coded, so it follows whichever
# embedding model / pooling produced `embeddings`.
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 300,
        'batch_size': 16,
        'patience': 200,
        'early_stopping': True,
        'lr': 1e-4,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
# Early stopping is enabled, so validate on the validation split (as every
# other surrogate in this notebook does) instead of on the test split, which
# would leak test information into model selection.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict each split and draw one true-vs-predicted scatter per split,
# titled with the MSE and the bootstrapped Spearman correlation.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_hat) in zip(ax, panels):
    axis.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    axis.set_title(f'{label} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# y_pred = surrogate.predict(embeddings)
# assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_feat_mean_MLP'] = y_pred

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ESM2, ESM3, ESMC - perplexities

In [None]:
# esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
# esm3 = ESM3LM(device='gpu')
# esmc = ESMCLM(name='esmc_300m', device='gpu')
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
y = df['fitness_log'].to_numpy().astype(np.float32)

In [None]:
# One perplexity value per sequence from the loaded ESMC model
# (esm2 / esm3 expose the same `compute_perplexity` call).
y_pred = np.array([esmc.compute_perplexity(seq) for seq in tqdm(df['seq'])])

In [None]:
# Negated perplexity vs. log-fitness, on the whole dataset (left) and with
# zero-fitness rows removed (right). Panel titles name the subset shown;
# both used to read 'Train' although no train split is involved here.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

mask = ~(df['fitness_raw'] == 0)
panels = [
    ('Full Dataset', y, -y_pred),
    ('Omit fitness = 0', y[mask], -y_pred[mask]),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_perplexity'] = y_pred

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ESM2, ESM3, ESMC - ZeroShot Marginals

In [None]:
# from DomainPrediction.utils.constants import *

In [None]:
# gxps_wt = helper.read_fasta('/nethome/kgeorge/workspace/DomainPrediction/Data/gxps/GxpS_ATC.fasta', mode='str')[0]
# A_domain_wt = ''.join([s for i, s in enumerate(gxps_wt) if i in A_gxps_atc])
# C_domain_wt = ''.join([s for i, s in enumerate(gxps_wt) if i in C_gxps_atc])
# TplusLinker_wt = ''.join([s for i, s in enumerate(gxps_wt) if i not in A_gxps_atc+C_gxps_atc])

In [None]:
# assert TplusLinker_wt == df.loc[df['name'] == 'WT', 'seq'].iloc[0]

In [None]:
# df['TplusLinker'] = df['seq']

In [None]:
# assert gxps_wt ==  A_domain_wt + TplusLinker_wt + C_domain_wt

In [None]:
# df['seq'] = df['seq'].apply(lambda x: A_domain_wt+x+C_domain_wt)

In [None]:
# df.head()

In [None]:
# assert gxps_wt == df.loc[df['name'] == 'WT', 'seq'].iloc[0]

In [None]:
# esm2 = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
# esm3 = ESM3LM(device='gpu')
# esmc = ESMCLM(name='esmc_300m', device='gpu')
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
wt_sequence = df.loc[df['name'] == 'WT', 'seq'].iloc[0]

In [None]:
# ## wt marginals
# # wt_log_prob = esm2.get_log_prob(wt_sequence)
# # wt_log_prob = esm3.get_log_prob(wt_sequence)
# wt_log_prob = esmc.get_log_prob(wt_sequence)
# y_pred = []
# for i, row in tqdm(df.iterrows()):
#     mt_sequence = row['seq']
#     # score, n_muts = esm2.get_wildtype_marginal(mt_sequence, wt_sequence, wt_log_prob)
#     # score, n_muts = esm3.get_wildtype_marginal(mt_sequence, wt_sequence, wt_log_prob)
#     score, n_muts = esmc.get_wildtype_marginal(mt_sequence, wt_sequence, wt_log_prob)

#     assert n_muts == row['n_mut']

#     y_pred.append(score)

# y_pred = np.array(y_pred)

In [None]:
# ## masked marginals
# y_pred = []
# for i, row in tqdm(df.iterrows()):
#     mt_sequence = row['seq']
#     # score, n_muts = esm2.get_masked_marginal(mt_sequence, wt_sequence)
#     # score, n_muts = esm3.get_masked_marginal(mt_sequence, wt_sequence)
#     score, n_muts = esmc.get_masked_marginal(mt_sequence, wt_sequence)

#     assert n_muts == row['n_mut']

#     y_pred.append(score)

# y_pred = np.array(y_pred)

In [None]:
# ## masked marginals mutant
# y_pred = []
# for i, row in tqdm(df.iterrows()):
#     mt_sequence = row['seq']
#     # score, n_muts = esm2.get_masked_marginal_var(mt_sequence, wt_sequence)
#     # score, n_muts = esm3.get_masked_marginal_var(mt_sequence, wt_sequence)
#     score, n_muts = esmc.get_masked_marginal_var(mt_sequence, wt_sequence, mode='mt')

#     assert n_muts == row['n_mut']

#     y_pred.append(score)

# y_pred = np.array(y_pred)

In [None]:
## pseudolikelihood
# Score every sequence with the ESMC pseudolikelihood. Iterating the column
# directly (instead of the `iterrows()` generator) gives tqdm a known total,
# so the progress bar shows completion.
y_pred = []
for mt_sequence in tqdm(df['seq']):
    score = esmc.pseudolikelihood(mt_sequence)
    y_pred.append(score)

y_pred = np.array(y_pred)

In [None]:
y = df['fitness_log'].to_numpy().astype(np.float32)

In [None]:
# Zero-shot score vs. log-fitness, on the whole dataset (left) and with
# zero-fitness rows removed (right).
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

mask = ~(df['fitness_raw'] == 0)
panels = [
    ('Full Dataset', y, y_pred),
    ('Omit fitness = 0', y[mask], y_pred[mask]),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
plt.figure(figsize=(3, 3))
_ = df['n_mut'].hist()

In [None]:
# Zero-shot score vs. log-fitness, split by mutation count. The right panel
# is the complement of `n_mut < 15`, i.e. n_mut >= 15, and is labelled so.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

mask_muts = (df['n_mut'] < 15)
panels = [
    ('n mutations < 15', mask_muts),
    ('n mutations >= 15', ~mask_muts),
]
for axis, (label, subset) in zip(ax, panels):
    axis.plot(y[subset], y_pred[subset], '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y[subset], y_pred[subset])
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Same mutation-count split as above, restricted to rows with
# fitness_raw > 0.01. The right panel is the complement of `n_mut < 15`,
# i.e. n_mut >= 15, and is labelled so.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

mask_muts = (df['n_mut'] < 15)
mask_omit = df['fitness_raw'] > 0.01
panels = [
    ('n mutations < 15 and omit', mask_muts & mask_omit),
    ('n mutations >= 15 and omit', ~mask_muts & mask_omit),
]
for axis, (label, subset) in zip(ax, panels):
    axis.plot(y[subset], y_pred[subset], '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y[subset], y_pred[subset])
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600_wt_marginal'] = y_pred

# df_results.to_csv(results_file, index=False)

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600_masked_marginal'] = y_pred

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### ConFit - Contrastive Fitness Learning

In [None]:
# Row subsets of the dataset for ConFit training and evaluation.
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
df_train, df_val, df_test = (df[mask] for mask in (train_mask, val_mask, test_mask))

In [None]:
# config={'model_path': '/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt',
#         'epoch': 30, 
#         'batch_size': 8,
#         'lambda': 0.1,
#         'accumulate_batch_size': 32,
#         'patience': 20,
#         'early_stopping': False,
#         'lr': 5e-4,
#         'print_every_n_epoch': 1,
#         'device': 'gpu'}
# surrogate = ESM2ConFit(config=config)
# surrogate.print_trainable_parameters(surrogate.model)

In [None]:
# Configuration for the ConFit surrogate built on the ESMC 600M model.
# Note: 'epoch' is overwritten to 20 on the surrogate's config in a later
# cell, before `trainmodel` is called. 'patience' is set although
# 'early_stopping' is False.
config={'epoch': 10, 
        'batch_size': 8,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'device': 'gpu'}
# surrogate = ESMCConFit(name='esmc_300m', config=config)
surrogate = ESMCConFit(name='esmc_600m', config=config)
surrogate.print_trainable_parameters(surrogate.model)

In [None]:
wt_sequence = df.loc[df['name'] == 'WT', 'seq'].iloc[0]

In [None]:
surrogate.sanity_check(df_train, wt_sequence)

In [None]:
surrogate.config['epoch'] = 20

In [None]:
# NOTE(review): the third argument is the test split restricted to
# fitness_raw > 0.01. If `trainmodel` uses it for validation or checkpoint
# selection, the test metrics reported below are optimistic — confirm.
surrogate.trainmodel(df_train, wt_sequence, df_test[df_test['fitness_raw']>0.01])

In [None]:
## masked marginals
# Score every sequence against the wild type with the fine-tuned surrogate.
# `total` is passed because `iterrows()` is a generator and tqdm cannot
# otherwise show completion.
y_pred = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    mt_sequence = row['seq']
    score, n_muts = surrogate.get_masked_marginal(mt_sequence, wt_sequence)

    # The mutation count found by the scorer must match the dataset annotation.
    assert n_muts == row['n_mut']

    y_pred.append(score)

y_pred = np.array(y_pred)
y = df['fitness_log'].to_numpy().astype(np.float32)

In [None]:
# ConFit score vs. log-fitness per split, titled with the bootstrapped
# Spearman correlation. The MSE that used to be computed here was never
# displayed or reused, so it has been dropped.
y_train_pred, y_train = y_pred[train_mask], y[train_mask]
y_val_pred, y_val = y_pred[val_mask], y[val_mask]
y_test_pred, y_test = y_pred[test_mask], y[test_mask]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.8)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Same per-split evaluation, restricted to rows with non-zero raw fitness
# (use `df['fitness_raw'] > 0.01` for the stricter cut). The MSE that used to
# be computed here was never displayed or reused, so it has been dropped.
omit_mask = df['fitness_raw'] != 0
y_train_pred, y_train = y_pred[train_mask & omit_mask], y[train_mask & omit_mask]
y_val_pred, y_val = y_pred[val_mask & omit_mask], y[val_mask & omit_mask]
y_test_pred, y_test = y_pred[test_mask & omit_mask], y[test_mask & omit_mask]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.8)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Same per-split evaluation, restricted to rows with fitness_raw > 0.01
# (use `df['fitness_raw'] != 0` for the looser cut). The MSE that used to be
# computed here was never displayed or reused, so it has been dropped.
omit_mask = df['fitness_raw'] > 0.01
y_train_pred, y_train = y_pred[train_mask & omit_mask], y[train_mask & omit_mask]
y_val_pred, y_val = y_pred[val_mask & omit_mask], y[val_mask & omit_mask]
y_test_pred, y_test = y_pred[test_mask & omit_mask], y[test_mask & omit_mask]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.8)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')
    if label == 'Val':
        # The validation panel is drawn but, as before, carries no statistics.
        continue
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

plt.show()

In [None]:
## pseudolikelihood
# Pseudolikelihood of every sequence under the fine-tuned surrogate.
# Iterating the column directly gives tqdm a known total for its progress bar.
pseudo_likelihood_surrogate = []
for mt_sequence in tqdm(df['seq']):
    score = surrogate.pseudolikelihood(mt_sequence)
    pseudo_likelihood_surrogate.append(score)

pseudo_likelihood_surrogate = np.array(pseudo_likelihood_surrogate)

In [None]:
# Baseline: pseudolikelihood of every sequence under a freshly loaded ESMC
# 600M model, for comparison with the fine-tuned surrogate.
esmc = ESMCLM(name='esmc_600m', device='gpu')
## pseudolikelihood
# Iterating the column directly gives tqdm a known total for its progress bar.
pseudo_likelihood_baseline = []
for mt_sequence in tqdm(df['seq']):
    score = esmc.pseudolikelihood(mt_sequence)
    pseudo_likelihood_baseline.append(score)

pseudo_likelihood_baseline = np.array(pseudo_likelihood_baseline)

In [None]:
# Baseline pseudolikelihood vs. log-fitness per split, restricted to rows
# with fitness_raw > 0.01. The MSE that used to be computed here was never
# displayed or reused, so it has been dropped.
omit_mask = df['fitness_raw'] > 0.01
y_train_pred, y_train = pseudo_likelihood_baseline[train_mask & omit_mask], y[train_mask & omit_mask]
y_val_pred, y_val = pseudo_likelihood_baseline[val_mask & omit_mask], y[val_mask & omit_mask]
y_test_pred, y_test = pseudo_likelihood_baseline[test_mask & omit_mask], y[test_mask & omit_mask]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.8)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')
    if label == 'Val':
        # The validation panel is drawn but, as before, carries no statistics.
        continue
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

plt.show()

In [None]:
# Fine-tuned surrogate pseudolikelihood vs. log-fitness per split, restricted
# to rows with fitness_raw > 0.01. The MSE that used to be computed here was
# never displayed or reused, so it has been dropped.
omit_mask = df['fitness_raw'] > 0.01
y_train_pred, y_train = pseudo_likelihood_surrogate[train_mask & omit_mask], y[train_mask & omit_mask]
y_val_pred, y_val = pseudo_likelihood_surrogate[val_mask & omit_mask], y[val_mask & omit_mask]
y_test_pred, y_test = pseudo_likelihood_surrogate[test_mask & omit_mask], y[test_mask & omit_mask]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for axis, (label, y_true, y_score) in zip(ax, panels):
    axis.plot(y_true, y_score, '.', alpha=0.8)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')
    if label == 'Val':
        # The validation panel is drawn but, as before, carries no statistics.
        continue
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_score)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

plt.show()

In [None]:
# assert y_pred.shape[0] == df_results.shape[0]
# df_results['pred_ESMC600M_confit'] = y_pred

# df_results.to_csv(results_file, index=False)

In [None]:
df_results.columns[df_results.columns.str.contains('pred')]

#### Changing WT to ESM3-2

In [None]:
esmc = ESMCLM(name='esmc_600m', device='gpu')

In [None]:
wt_sequence = df.loc[df['name'] == 'WT', 'seq'].iloc[0]
esm2_seq = df.loc[df['name'] == 'ESM2', 'seq'].iloc[0]

In [None]:
## wildtype marginals w.r.t. two reference sequences
# Each sequence is scored twice with `get_wildtype_marginal`: once against the
# WT sequence and once against the 'ESM2' reference sequence.
# (`esmc.get_masked_marginal(mt_sequence, reference)` is the masked-marginal
# alternative with the same return shape.)
y_pred_wt = []
y_pred_esm2 = []
muts_wt = []
muts_esm2 = []
# `total` is passed because `iterrows()` is a generator and tqdm cannot
# otherwise show completion.
for _, row in tqdm(df.iterrows(), total=len(df)):
    mt_sequence = row['seq']

    score_wt, n_muts_wt = esmc.get_wildtype_marginal(mt_sequence, wt_sequence)
    score_esm2, n_muts_esm2 = esmc.get_wildtype_marginal(mt_sequence, esm2_seq)

    # Only the WT-relative mutation count is annotated in the dataset.
    assert n_muts_wt == row['n_mut']

    y_pred_wt.append(score_wt)
    y_pred_esm2.append(score_esm2)
    muts_wt.append(n_muts_wt)
    muts_esm2.append(n_muts_esm2)


y_pred_wt = np.array(y_pred_wt)
y_pred_esm2 = np.array(y_pred_esm2)
muts_wt = np.array(muts_wt)
muts_esm2 = np.array(muts_esm2)

In [None]:
y = df['fitness_log'].to_numpy().astype(np.float32)

In [None]:
# Score vs. log-fitness for both reference sequences on the whole dataset.
# (A plain `stats.spearmanr` result used to be assigned to `corr` and then
# immediately overwritten by the bootstrap; that dead call is removed.)
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

panels = [
    ('wrt wt', y_pred_wt, 0.8),
    ('wrt esm2', y_pred_esm2, 0.9),
]
for axis, (label, scores, alpha) in zip(ax, panels):
    axis.plot(y, scores, '.', alpha=alpha)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y, scores)
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

In [None]:
# Score vs. log-fitness for both reference sequences, restricted to rows with
# fitness_raw > 0.01, followed by the mutation-count histograms of that subset.
# (A dead `stats.spearmanr` assignment that was immediately overwritten by the
# bootstrap result has been removed.)
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

_mask = df['fitness_raw'] > 0.01

panels = [
    ('wrt wt', y_pred_wt, 0.8),
    ('wrt esm2', y_pred_esm2, 0.9),
]
for axis, (label, scores, alpha) in zip(ax, panels):
    axis.plot(y[_mask], scores[_mask], '.', alpha=alpha)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y[_mask], scores[_mask])
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

fig, ax = plt.subplots(1, 2, figsize=(5, 2), layout='constrained')
for axis, label, n_muts in zip(ax, ('wrt wt', 'wrt esm2'), (muts_wt, muts_esm2)):
    axis.hist(n_muts[_mask])
    axis.set_title(label)

plt.show()

In [None]:
# Score vs. log-fitness for both reference sequences, restricted to rows with
# fitness_raw > 0.01 and fewer than 15 mutations, followed by the
# mutation-count histograms of that subset.
# (A dead `stats.spearmanr` assignment that was immediately overwritten by the
# bootstrap result has been removed.)
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

__mask = df['fitness_raw'] > 0.01
_mask = __mask & (df['n_mut'] < 15)

print(df.loc[_mask, 'name'].to_numpy())

panels = [
    ('wrt wt', y_pred_wt, 0.8),
    ('wrt esm2', y_pred_esm2, 0.9),
]
for axis, (label, scores, alpha) in zip(ax, panels):
    axis.plot(y[_mask], scores[_mask], '.', alpha=alpha)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y[_mask], scores[_mask])
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

fig, ax = plt.subplots(1, 2, figsize=(5, 2), layout='constrained')
for axis, label, n_muts in zip(ax, ('wrt wt', 'wrt esm2'), (muts_wt, muts_esm2)):
    axis.hist(n_muts[_mask])
    axis.set_title(label)

plt.show()

In [None]:
# Score vs. log-fitness for both reference sequences, restricted to rows with
# fitness_raw > 0.01 and more than 15 mutations, followed by the
# mutation-count histograms of that subset.
# (A dead `stats.spearmanr` assignment that was immediately overwritten by the
# bootstrap result has been removed.)
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

__mask = df['fitness_raw'] > 0.01
# NOTE(review): the companion cell selects n_mut < 15, so rows with exactly
# 15 mutations appear in neither plot — confirm whether `>= 15` was intended.
_mask = __mask & (df['n_mut'] > 15)

print(df.loc[_mask, 'name'].to_numpy())

panels = [
    ('wrt wt', y_pred_wt, 0.8),
    ('wrt esm2', y_pred_esm2, 0.9),
]
for axis, (label, scores, alpha) in zip(ax, panels):
    axis.plot(y[_mask], scores[_mask], '.', alpha=alpha)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y[_mask], scores[_mask])
    axis.set_title(f'{label} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    axis.set_xlabel('True')
    axis.set_ylabel('Pred')

plt.show()

fig, ax = plt.subplots(1, 2, figsize=(5, 2), layout='constrained')
for axis, label, n_muts in zip(ax, ('wrt wt', 'wrt esm2'), (muts_wt, muts_esm2)):
    axis.hist(n_muts[_mask])
    axis.set_title(label)

plt.show()