In [1]:
import sys
sys.path.append('..')

In [2]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
import torch

from sklearn.metrics import mean_squared_error
from scipy import stats

In [4]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.al import top_model as topmodel
from DomainPrediction.al.embeddings import one_hot_encode

2025-04-16 11:27:51.606879: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-16 11:27:51.609201: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-16 11:27:51.614735: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-16 11:27:51.625688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-16 11:27:51.625706: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting t

In [5]:
# Each conda env provides a different set of model wrappers: ESM-C classes in
# 'workspace', ESM-2 classes in 'workspace-esm'. Import whichever this kernel has.
# `.get` so a non-conda kernel reaches the explanatory error below instead of a KeyError.
env = os.environ.get('CONDA_DEFAULT_ENV')
if env == 'workspace':
    sys.path.append('../../esm')
    from DomainPrediction.esm.esmc import ESMCLM
    from DomainPrediction.al.finetuning import ESMCLoraRegression, ESMCConFit
elif env == 'workspace-esm':
    from DomainPrediction.esm.esm2 import ESM2
    from DomainPrediction.al.finetuning import ESM2LoraRegression, ESM2ConFit
else:
    raise Exception('I designed this for my envs. Feel free to modify accordingly')

### Load Data

In [6]:
# Root directory of the TEM-1 fitness-prediction data. Set TEM1_DATA_PATH to
# run on another machine; the default is the original author's location.
data_path = os.environ.get(
    'TEM1_DATA_PATH',
    '/nethome/kgeorge/workspace/DomainPrediction/Data/fitness_prediction/TEM1',
)

In [7]:
# Load the TEM-1 variant table (sequence, fitness values and split assignment per row).
file = os.path.join(data_path, 'dataset_tem1.csv')
df = pd.read_csv(file)

In [8]:
# Preview the first rows of the variant table.
df.head()

Unnamed: 0,pos,variant,fitness_raw,fitness_norm,wt_aa,n_mut,seq,fold_id,split_id
0,58,F58N,1.53724,1.142208,F,1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,1,2
1,19,L19W,1.05372,1.094952,L,1,MSIQHFRVALIPFFAAFCWPVFAHPETLVKVKDAEDQLGARVGYIE...,0,2
2,15,A15Y,0.823567,1.072459,A,1,MSIQHFRVALIPFFYAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,0,2
3,182,A182M,0.768011,1.067029,A,1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,6,0
4,281,S281F,0.768011,1.067029,S,1,MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIE...,9,2


In [9]:
# (rows, columns) of the variant table.
df.shape

(5198, 9)

In [10]:
# CSV of previously saved predictions; loaded in the next cell if it exists.
results_file = os.path.join(data_path, 'results_tem1.csv')

In [11]:
# Reuse saved predictions when the results file exists; otherwise start from a copy of the dataset.
df_results = (
    pd.read_csv(results_file)
    if os.path.isfile(results_file)
    else df.copy()
)

In [12]:
# Prediction columns already present in the results table (none when starting from a copy of df).
df_results.columns[df_results.columns.str.contains('pred')]

Index([], dtype='object')

In [13]:
def get_split_mask(df, omit_zero=False):
    """Boolean row masks for the splits encoded in ``df['split_id']``.

    split_id 2 -> train, 1 -> val, 0 -> test. With ``omit_zero=True`` the train
    mask additionally excludes rows whose ``fitness_raw`` equals 0; the val and
    test masks are never filtered.

    Returns:
        (train_mask, val_mask, test_mask): boolean Series aligned with ``df``.
    """
    split = df['split_id']
    train_mask = split == 2
    if omit_zero:
        train_mask = train_mask & (df['fitness_raw'] != 0)
    return train_mask, split == 1, split == 0

In [14]:
def get_spearmanr_bootstrap(a, b, n=1000, ci = 95, seed=0):
    """Spearman correlation between ``a`` and ``b`` with a bootstrap confidence interval.

    Draws ``n`` paired resamples (with replacement) and computes Spearman's rho
    on each. Resamples whose rho is NaN (e.g. a constant resample) are dropped.

    Args:
        a, b: Paired 1-D array-likes of equal length (converted with ``np.asarray``).
        n: Number of bootstrap resamples.
        ci: Confidence level in percent. The interval is equal-tailed, i.e. the
            ``(100-ci)/2`` and ``100-(100-ci)/2`` percentiles (2.5/97.5 for 95).
        seed: Seed of a private RNG, so results are reproducible without
            modifying NumPy's global random state.

    Returns:
        (mean_corr, ci_lower, ci_upper, corr, p_value). The first three are
        rounded to 2 decimals, or NaN if every resample was NaN. ``corr`` and
        ``p_value`` are the raw per-resample values.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    assert len(a) == len(b)

    # Legacy RandomState seeded like the former global np.random.seed(seed) call:
    # identical resample indices, but the global RNG is left untouched.
    rng = np.random.RandomState(seed)
    corr = []
    p_value = []
    for _ in range(n):
        indices = rng.choice(len(a), size=len(a), replace=True)
        res = stats.spearmanr(a[indices], b[indices])

        if not np.isnan(res.statistic):
            corr.append(res.statistic)
            p_value.append(res.pvalue)

    if not corr:
        # np.percentile would raise on an empty list; report "undefined" instead.
        nan = float('nan')
        return nan, nan, nan, corr, p_value

    # Equal-tailed percentile interval (previously [100-ci, ci], i.e. 90% for ci=95).
    tail = (100 - ci) / 2
    ci_lower, ci_upper = np.percentile(corr, [tail, 100 - tail])
    mean_corr = np.mean(corr)

    return round(mean_corr, 2), round(ci_lower, 2), round(ci_upper, 2), corr, p_value

### OHE based models

#### Get embeddings and splits

In [15]:
# One-hot encode every variant sequence. Rows are selected later with the split
# masks, so they are assumed to follow df's row order.
embeddings = one_hot_encode(df['seq'])

100%|██████████| 5198/5198 [00:00<00:00, 27463.96it/s]


In [18]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)
split_masks = (train_mask, val_mask, test_mask)

# Features: one-hot rows picked out by each split mask.
X_train, X_val, X_test = (embeddings[m] for m in split_masks)

# Targets: raw fitness as float32 for the surrogate models.
y_train, y_val, y_test = (
    df.loc[m, 'fitness_raw'].to_numpy().astype(np.float32) for m in split_masks
)

In [19]:
# Number of rows in each split.
print('train', train_mask.sum(), 'val', val_mask.sum(), 'test', test_mask.sum())

train 3639 val 517 test 1042


#### Linear model

In [None]:
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
# Validate on the val split; the test split stays untouched until the evaluation cell below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

#### Random Forest

In [None]:
surrogate = topmodel.RFSurrogate()
# Validate on the val split; the test split stays untouched until the evaluation cell below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

#### MLP

In [None]:
# Feature width that the MLP's first layer must match.
print('input layer shape:', X_train.shape[1])

In [None]:
# The first layer width is read from the features, so it always matches the
# encoding (the previous hard-coded 2300 broke if the encoding width changed).
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 100,
        'batch_size': 16,
        'patience': 100,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
# Validate on the val split; the test split stays untouched until the evaluation cell below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

### Embedding-based Models

#### Get embeddings and splits

In [None]:
# Use the backbone whose wrapper was imported for this conda env: ESMCLM exists
# only in 'workspace' and ESM2 only in 'workspace-esm', so any other choice raises NameError.
model_choices = 'esmc' if env == 'workspace' else 'esm2'

if model_choices == 'esm2':
    base_model = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
elif model_choices == 'esmc':
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    raise ValueError('model not found')

In [None]:
embedding_choice = 'mean' # or 'concat'

# 'mean' -> base_model.get_embeddings_mean, 'concat' -> base_model.get_embeddings_flatten.
if embedding_choice == 'mean':
    embeddings = base_model.get_embeddings_mean(df['seq'])
elif embedding_choice == 'concat':
    embeddings = base_model.get_embeddings_flatten(df['seq'])
else:
    raise ValueError('embedding choice not found')

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

X_train = embeddings[train_mask]
X_val = embeddings[val_mask]
X_test = embeddings[test_mask]

# Target: 'fitness_raw', as in the one-hot section (the table has no 'fitness_log' column).
y_train = df.loc[train_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_test = df.loc[test_mask, 'fitness_raw'].to_numpy().astype(np.float32)

In [None]:
# Number of rows in each split (includes val, matching the one-hot section).
print(f'train {train_mask.sum()} val {val_mask.sum()} test {test_mask.sum()}')

#### Linear Model

In [None]:
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
# Validate on the val split; the test split stays untouched until the evaluation cell below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

#### Random Forest

In [None]:
surrogate = topmodel.RFSurrogate()
# Validate on the val split; the test split stays untouched until the evaluation cell below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

#### MLP

In [None]:
# Feature width that the MLP's first layer must match.
print('input layer shape:', X_train.shape[1])

In [None]:
# The first layer width is read from the embeddings; the previous hard-coded
# 1152 only fit one backbone/embedding setting and broke when switching to another.
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 300,
        'batch_size': 16,
        'patience': 200,
        'early_stopping': True,
        'lr': 1e-4,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
# Early stopping monitors the val split; stopping on the test split would leak it into the test metrics.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
y_train_pred = surrogate.predict(X_train)
y_test_pred = surrogate.predict(X_test)

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

### Zero-Shot PLMs

In [None]:
# Use the backbone whose wrapper was imported for this conda env: ESMCLM exists
# only in 'workspace' and ESM2 only in 'workspace-esm', so any other choice raises NameError.
model_choices = 'esmc' if env == 'workspace' else 'esm2'

if model_choices == 'esm2':
    base_model = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
elif model_choices == 'esmc':
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    raise ValueError('model not found')

In [None]:
# The loaded table has no 'name' column / explicit WT row (see df.head()), so
# rebuild the wild type from a single mutant by restoring its wild-type residue.
# `pos` is treated as a 1-based index into `seq`, which is consistent with the
# L19W and A15Y rows in the preview. The assert checks this against every single mutant.
if 'name' in df.columns:
    wt_sequence = df.loc[df['name'] == 'WT', 'seq'].iloc[0]
else:
    _singles = df.loc[df['n_mut'] == 1]
    _ref = _singles.iloc[0]
    _idx = int(_ref['pos']) - 1
    wt_sequence = _ref['seq'][:_idx] + _ref['wt_aa'] + _ref['seq'][_idx + 1:]
    assert all(
        wt_sequence[int(p) - 1] == aa for p, aa in zip(_singles['pos'], _singles['wt_aa'])
    ), 'reconstructed WT disagrees with pos/wt_aa columns'

In [None]:
# One of 'wt_marginal', 'masked_marginal', 'pseudolikelihood' (dispatched in the scoring loop).
zero_shot_method = 'pseudolikelihood'

In [None]:
# Score every variant with the selected zero-shot method.
# Validate the method once up front instead of failing inside the loop.
if zero_shot_method not in ('wt_marginal', 'masked_marginal', 'pseudolikelihood'):
    raise ValueError('method not found')

y_pred = []
# iterrows() has no len(); pass total so tqdm can show progress and ETA.
for i, row in tqdm(df.iterrows(), total=len(df)):
    mt_sequence = row['seq']

    if zero_shot_method == 'wt_marginal':
        score, n_muts = base_model.get_wildtype_marginal(mt_sequence, wt_sequence)
        # Sanity check: the model's mutation count must agree with the table.
        assert n_muts == row['n_mut']
    elif zero_shot_method == 'masked_marginal':
        score, n_muts = base_model.get_masked_marginal(mt_sequence, wt_sequence)
        assert n_muts == row['n_mut']
    else:  # 'pseudolikelihood'
        score = base_model.pseudolikelihood(mt_sequence)

    y_pred.append(score)

y_pred = np.array(y_pred)

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# Target: 'fitness_raw', as in the one-hot section (the table has no 'fitness_log' column).
y = df['fitness_raw'].to_numpy().astype(np.float32)

y_train = y[train_mask]
y_val = y[val_mask]
y_test = y[test_mask]

y_train_pred = y_pred[train_mask]
y_val_pred = y_pred[val_mask]
y_test_pred = y_pred[test_mask]

In [None]:
# Zero-shot scores vs. measured fitness: all variants (left) and only
# variants with non-zero raw fitness (right).
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')

mask = df['fitness_raw'] != 0
panels = (('Full Dataset', y, y_pred), ('Omit fitness = 0', y[mask], y_pred[mask]))

for panel_ax, (subset_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{subset_name} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

In [None]:
# Parity plots (true vs. zero-shot score): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

### Regression Finetuning

In [None]:
# Use the backbone whose wrapper was imported for this conda env:
# ESMCLoraRegression only in 'workspace', ESM2LoraRegression only in 'workspace-esm'.
model_choices = 'esmc' if env == 'workspace' else 'esm2'

config={'epoch': 10,
        'batch_size': 4,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 1,
        'device': 'gpu'}

if model_choices == 'esmc':
    surrogate = ESMCLoraRegression(name='esmc_600m', config=config)
elif model_choices == 'esm2':
    surrogate = ESM2LoraRegression(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
else:
    # Without this, `surrogate` would silently remain the previous section's model.
    raise ValueError('model not found')

surrogate.print_trainable_parameters(surrogate)

In [None]:
# Display the surrogate's repr for inspection.
surrogate

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

df_train = df[train_mask]
df_val = df[val_mask]
df_test = df[test_mask]

# Target: 'fitness_raw', as in the one-hot section (the table has no 'fitness_log' column).
y_train = df_train['fitness_raw'].to_numpy().astype(np.float32)
y_val = df_val['fitness_raw'].to_numpy().astype(np.float32)
y_test = df_test['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
# Fine-tune on the training split, passing the validation split as the second argument.
surrogate.trainmodel(df_train, df_val)

In [None]:
y_train_pred = surrogate.predict(df_train['seq'])
y_test_pred = surrogate.predict(df_test['seq'])

# Parity plots (true vs. predicted): train split on the left, test split on the right.
fig, ax = plt.subplots(1, 2, figsize=(7, 3), layout='constrained')
panels = (('Train', y_train, y_train_pred), ('Test', y_test, y_test_pred))

for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    panel_ax.plot(truth, pred, '.', alpha=0.5)

# Title each panel with its MSE and bootstrapped Spearman correlation.
for panel_ax, (split_name, truth, pred) in zip(ax, panels):
    mse = mean_squared_error(truth, pred)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(truth, pred)
    panel_ax.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for panel_ax in ax:
    panel_ax.set_xlabel('True')
    panel_ax.set_ylabel('Pred')

plt.show()

### Contrastive Finetuning

In [None]:
# Use the backbone whose wrapper was imported for this conda env:
# ESMCConFit only in 'workspace', ESM2ConFit only in 'workspace-esm'.
model_choices = 'esmc' if env == 'workspace' else 'esm2'

config={'epoch': 50,
        'batch_size': 8,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'model_checkpoint': True,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'use_seq_head': True,
        'device': 'gpu'}

if model_choices == 'esmc':
    surrogate = ESMCConFit(name='esmc_600m', config=config)
elif model_choices == 'esm2':
    surrogate = ESM2ConFit(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
else:
    # Without this, `surrogate` would silently remain the previous section's model.
    raise ValueError('model not found')

surrogate.print_trainable_parameters(surrogate)

In [None]:
# Display the surrogate's repr for inspection.
surrogate

In [None]:
# The loaded table has no 'name' column / explicit WT row (see df.head()), so
# rebuild the wild type from a single mutant by restoring its wild-type residue.
# `pos` is treated as a 1-based index into `seq`, which is consistent with the
# L19W and A15Y rows in the preview. The assert checks this against every single mutant.
if 'name' in df.columns:
    wt_sequence = df.loc[df['name'] == 'WT', 'seq'].iloc[0]
else:
    _singles = df.loc[df['n_mut'] == 1]
    _ref = _singles.iloc[0]
    _idx = int(_ref['pos']) - 1
    wt_sequence = _ref['seq'][:_idx] + _ref['wt_aa'] + _ref['seq'][_idx + 1:]
    assert all(
        wt_sequence[int(p) - 1] == aa for p, aa in zip(_singles['pos'], _singles['wt_aa'])
    ), 'reconstructed WT disagrees with pos/wt_aa columns'

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

df_train = df[train_mask]
df_val = df[val_mask]
df_test = df[test_mask]

# Target: 'fitness_raw', as in the one-hot section (the table has no 'fitness_log' column).
y_train = df_train['fitness_raw'].to_numpy().astype(np.float32)
y_val = df_val['fitness_raw'].to_numpy().astype(np.float32)
y_test = df_test['fitness_raw'].to_numpy().astype(np.float32)

In [None]:
# Monitor/checkpoint on the val split. The best checkpoint is reloaded below, so
# selecting it on the test split would leak test data into the reported test metrics.
surrogate.trainmodel(df_train, wt_sequence, val=df_val)

In [None]:
# Path of the best checkpoint tracked by the trainer's checkpoint callback
# (presumably enabled by 'model_checkpoint': True in the config).
surrogate.trainer.checkpoint_callback.best_model_path

In [None]:
# Restore the best checkpoint's weights before evaluating; expects the saved
# object to be a dict with a 'state_dict' entry.
surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path)['state_dict'])

In [None]:
y_train_pred = surrogate.predict(df_train['seq'], wt_sequence)
y_test_pred = surrogate.predict(df_test['seq'], wt_sequence)

# The test panel is restricted to variants with fitness_raw above this floor,
# presumably to drop near-inactive variants. Scatter, MSE and Spearman now all
# use the same subset, and the title states the filter. Previously MSE covered
# the full test set while the plot and correlation used the subset.
FITNESS_FLOOR = 0.1
test_keep = (df_test['fitness_raw'] > FITNESS_FLOOR).to_numpy()
y_test_kept = y_test[test_keep]
y_test_pred_kept = y_test_pred[test_keep]

fig, ax = plt.subplots(1,2, figsize=(7,3), layout='constrained')
ax[0].plot(y_train, y_train_pred, '.', alpha=0.5)
ax[1].plot(y_test_kept, y_test_pred_kept, '.', alpha=0.5)

mse = mean_squared_error(y_train, y_train_pred)
corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_train, y_train_pred)
ax[0].set_title(f'Train \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

mse = mean_squared_error(y_test_kept, y_test_pred_kept)
corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_test_kept, y_test_pred_kept)
ax[1].set_title(f'Test (fitness_raw > {FITNESS_FLOOR}) \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

for i in range(2):
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()