In [None]:
import sys
# Make the parent directory importable — presumably where the DomainPrediction
# package used below resolves from; confirm for your checkout layout.
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import time

In [None]:
import torch

from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.al import top_model as topmodel
from DomainPrediction.al.embeddings import one_hot_encode

In [None]:
# Model wrappers are importable only in specific conda envs: ESM-C classes in
# 'workspace', ESM2 classes in 'workspace-esm'. Use .get() so that running
# outside conda reaches the explanatory error below instead of a bare KeyError.
env = os.environ.get('CONDA_DEFAULT_ENV')
if env == 'workspace':
    sys.path.append('../../esm')
    from DomainPrediction.esm.esmc import ESMCLM
    from DomainPrediction.al.finetuning import ESMCLoraRegression, ESMCConFit
elif env == 'workspace-esm':
    from DomainPrediction.esm.esm2 import ESM2
    from DomainPrediction.al.finetuning import ESM2LoraRegression, ESM2ConFit
else:
    raise RuntimeError(f'I designed this for my envs (got CONDA_DEFAULT_ENV={env!r}). Feel free to modify accordingly')

### Load Data

In [None]:
# Root directory of the TEM1 fitness-prediction data. Set the TEM1_DATA_PATH
# environment variable to run on another machine; defaults to the original path.
data_path = os.environ.get(
    'TEM1_DATA_PATH',
    '/nethome/kgeorge/workspace/DomainPrediction/Data/fitness_prediction/TEM1',
)

In [None]:
file = os.path.join(data_path, 'dataset_tem1.csv')
df = pd.read_csv(file)

# Columns the split masks, embeddings, targets and histograms below rely on;
# fail here with a clear message rather than deep inside a training loop.
required_columns = {'seq', 'fitness_raw', 'fitness_norm', 'split_id'}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f'{file} is missing columns: {sorted(missing_columns)}'

In [None]:
# Peek at the first rows of the raw dataset.
df.head(n=5)

In [None]:
df.shape  # (rows, columns) of the raw dataset

In [None]:
# CSV that accumulates one prediction column per model / sample size / fold;
# the training cells below rewrite it after every run.
results_file = os.path.join(data_path, 'results_tem1_lowdata.csv')

In [None]:
# Resume from previously saved predictions if present, otherwise start from a
# copy of the dataset. Prediction columns are assigned positionally from
# predictions over df's rows, so a saved file must match df row-for-row.
if os.path.isfile(results_file):
    df_results = pd.read_csv(results_file)
    assert len(df_results) == len(df), 'results file has a different number of rows than the dataset'
    if 'seq' in df_results.columns:
        assert (df_results['seq'].to_numpy() == df['seq'].to_numpy()).all(), \
            'results file rows are not aligned with the dataset'
else:
    df_results = df.copy()

In [None]:
# Prediction columns stored so far.
is_pred_col = df_results.columns.str.contains('pred')
df_results.columns[is_pred_col]

In [None]:
# Number of prediction columns stored so far.
int(df_results.columns.str.contains('pred').sum())

In [None]:
def get_split_mask(df, omit_zero=False):
    """Build boolean row masks for the train/val/test splits.

    Rows are assigned by ``df['split_id']``: 2 -> train, 1 -> val, 0 -> test.

    Args:
        df: DataFrame with a ``split_id`` column (and ``fitness_raw`` when
            ``omit_zero`` is True).
        omit_zero: if True, additionally drop train rows whose ``fitness_raw``
            is exactly 0.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean Series.
    """
    split = df['split_id']

    train_mask = split == 2
    if omit_zero:
        train_mask = train_mask & (df['fitness_raw'] != 0)

    return train_mask, split == 1, split == 0

In [None]:
def get_spearmanr_bootstrap(a, b, n=1000, ci=95, seed=0):
    """Bootstrap the Spearman rank correlation between two paired arrays.

    Resamples index positions with replacement ``n`` times and computes
    ``scipy.stats.spearmanr`` on each resample. Resamples whose correlation is
    NaN (e.g. a constant resample) are skipped.

    Args:
        a, b: numpy arrays of equal length holding paired observations.
        n: number of bootstrap resamples.
        ci: two-sided confidence level in percent. The interval spans the
            ``(100 - ci) / 2`` and ``100 - (100 - ci) / 2`` percentiles of the
            bootstrap distribution (2.5 / 97.5 for ``ci=95``).
        seed: seed for a local ``RandomState``. The resamples match the former
            ``np.random.seed(seed)`` behaviour, but the global RNG is untouched.

    Returns:
        ``(mean_corr, ci_lower, ci_upper, p_value, corr, p_values)``, where:
        - the first three are rounded to 2 decimals;
        - ``p_value`` is the fraction of bootstrap correlations below zero
          (one-sided);
        - ``corr`` lists the bootstrap correlations;
        - ``p_values`` lists the matching per-resample spearmanr p-values.
        If every resample is NaN, the first four entries are NaN.
    """
    assert isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
    assert len(a) == len(b)

    # Local generator: identical draws to seeding the legacy global RNG,
    # without resetting random state for the rest of the notebook.
    rng = np.random.RandomState(seed)
    corr = []
    p_values = []
    for _ in range(n):
        indices = rng.choice(len(a), size=len(a), replace=True)
        res = stats.spearmanr(a[indices], b[indices])

        if not np.isnan(res.statistic):
            corr.append(res.statistic)
            p_values.append(res.pvalue)

    if not corr:
        # np.percentile would raise on an empty list.
        nan = float('nan')
        return nan, nan, nan, nan, corr, p_values

    # Split the excluded mass evenly between both tails. The old
    # [100 - ci, ci] bounds gave a 90% interval for ci=95.
    tail = (100 - ci) / 2
    ci_lower, ci_upper = np.percentile(corr, [tail, 100 - tail])
    mean_corr = np.mean(corr)
    p_value = np.mean(np.array(corr) < 0)

    return round(mean_corr, 2), round(ci_lower, 2), round(ci_upper, 2), p_value, corr, p_values

In [None]:
import pickle

# Pre-drawn low-data training subsets, looked up below as
# sample_dict[sample_size][fold]. NOTE: unpickling can execute arbitrary code;
# only load this file from a trusted source.
file = os.path.join(data_path, 'sampled_tem1.pkl')
with open(file, 'rb') as pkl_handle:
    sample_dict = pickle.loads(pkl_handle.read())

### ESM2/ESMC mean-embeddings random forest (RF)

In [None]:
model_choices = 'esmc'  # 'esm2', 'esmc'
embedding_choice = 'mean'  # 'mean' or 'concat'

# Validate the embedding choice before loading the (large) model so a typo
# fails fast, with a message that names the actual problem.
if embedding_choice not in ('mean', 'concat'):
    raise ValueError(f'embedding choice not found: {embedding_choice!r}')

if model_choices == 'esm2':
    # ESM2 is only imported in the 'workspace-esm' conda env.
    base_model = ESM2(model_path='/nethome/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
elif model_choices == 'esmc':
    # ESMCLM is only imported in the 'workspace' conda env.
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    raise ValueError('model not found')

# One embedding row per sequence in df (row order preserved for masking below).
if embedding_choice == 'mean':
    embeddings = base_model.get_embeddings_mean(df['seq'])
else:  # 'concat'
    embeddings = base_model.get_embeddings_flatten(df['seq'])

In [None]:
# Fit a random-forest surrogate on every (sample_size, fold) low-data training
# subset. For each fit:
# - plot true vs. predicted for train/val/test with MSE and bootstrapped
#   Spearman CI;
# - store full-dataset predictions in df_results.

# Val/test splits do not depend on the sampled training subset: build them once.
_, val_mask, test_mask = get_split_mask(df, omit_zero=False)
X_val = embeddings[val_mask]
X_test = embeddings[test_mask]
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_test = df.loc[test_mask, 'fitness_raw'].to_numpy().astype(np.float32)

for sample_size in sample_dict:
    for fold in sample_dict[sample_size]:

        train_mask = sample_dict[sample_size][fold]
        X_train = embeddings[train_mask]
        y_train = df.loc[train_mask, 'fitness_raw'].to_numpy().astype(np.float32)

        print(f'train size {X_train.shape[0]}')

        # Distribution of the normalized fitness in this training subset.
        plt.figure(figsize=(5,3))
        plt.hist(df.loc[train_mask, 'fitness_norm'].to_numpy())
        plt.xlabel('fitness_norm')
        plt.ylabel('count')
        plt.title(f'd{sample_size} fold {fold} training fitness')
        plt.show()

        surrogate = topmodel.RFSurrogate()
        surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

        y_train_pred = surrogate.predict(X_train)
        y_val_pred = surrogate.predict(X_val)
        y_test_pred = surrogate.predict(X_test)

        fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
        splits = [('Train', y_train, y_train_pred), ('Val', y_val, y_val_pred), ('Test', y_test, y_test_pred)]
        for i, (split_name, y_true, y_hat) in enumerate(splits):
            ax[i].plot(y_true, y_hat, '.', alpha=0.5)
            mse = mean_squared_error(y_true, y_hat)
            corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
            ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
            ax[i].set_xlabel('True')
            ax[i].set_ylabel('Pred')

        plt.show()

        y_pred = surrogate.predict(embeddings)
        assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]

        df_results[f'pred_{model_choices}_{embedding_choice}_RF_d{sample_size}_{fold}'] = y_pred
        # Persist after every fold so finished runs survive an interruption.
        df_results.to_csv(results_file, index=False)

        print(df_results.columns[df_results.columns.str.contains('pred')])

### ESM2/ESMC concat-embeddings linear (Ridge) regression

In [None]:
model_choices = 'esmc'  # 'esm2', 'esmc'
embedding_choice = 'concat'  # 'mean' or 'concat'

# Validate the embedding choice before loading the (large) model so a typo
# fails fast, with a message that names the actual problem.
if embedding_choice not in ('mean', 'concat'):
    raise ValueError(f'embedding choice not found: {embedding_choice!r}')

if model_choices == 'esm2':
    # ESM2 is only imported in the 'workspace-esm' conda env.
    base_model = ESM2(model_path='/nethome/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')
elif model_choices == 'esmc':
    # ESMCLM is only imported in the 'workspace' conda env.
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    raise ValueError('model not found')

# One embedding row per sequence in df (row order preserved for masking below).
if embedding_choice == 'mean':
    embeddings = base_model.get_embeddings_mean(df['seq'])
else:  # 'concat'
    embeddings = base_model.get_embeddings_flatten(df['seq'])

In [None]:
# Fit a Ridge (alpha=1.0) surrogate on every (sample_size, fold) low-data
# training subset. For each fit:
# - plot true vs. predicted for train/val/test with MSE and bootstrapped
#   Spearman CI;
# - store full-dataset predictions in df_results.

# Val/test splits do not depend on the sampled training subset: build them once.
_, val_mask, test_mask = get_split_mask(df, omit_zero=False)
X_val = embeddings[val_mask]
X_test = embeddings[test_mask]
y_val = df.loc[val_mask, 'fitness_raw'].to_numpy().astype(np.float32)
y_test = df.loc[test_mask, 'fitness_raw'].to_numpy().astype(np.float32)

for sample_size in sample_dict:
    for fold in sample_dict[sample_size]:

        train_mask = sample_dict[sample_size][fold]
        X_train = embeddings[train_mask]
        y_train = df.loc[train_mask, 'fitness_raw'].to_numpy().astype(np.float32)

        print(f'train size {X_train.shape[0]}')

        # Distribution of the normalized fitness in this training subset.
        plt.figure(figsize=(5,3))
        plt.hist(df.loc[train_mask, 'fitness_norm'].to_numpy())
        plt.xlabel('fitness_norm')
        plt.ylabel('count')
        plt.title(f'd{sample_size} fold {fold} training fitness')
        plt.show()

        surrogate = topmodel.RidgeSurrogate(alpha=1.0)
        surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

        y_train_pred = surrogate.predict(X_train)
        y_val_pred = surrogate.predict(X_val)
        y_test_pred = surrogate.predict(X_test)

        fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
        splits = [('Train', y_train, y_train_pred), ('Val', y_val, y_val_pred), ('Test', y_test, y_test_pred)]
        for i, (split_name, y_true, y_hat) in enumerate(splits):
            ax[i].plot(y_true, y_hat, '.', alpha=0.5)
            mse = mean_squared_error(y_true, y_hat)
            corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
            ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
            ax[i].set_xlabel('True')
            ax[i].set_ylabel('Pred')

        plt.show()

        y_pred = surrogate.predict(embeddings)
        assert y_pred.shape[0] == embeddings.shape[0] == df_results.shape[0]

        df_results[f'pred_{model_choices}_{embedding_choice}_ridge_d{sample_size}_{fold}'] = y_pred
        # Persist after every fold so finished runs survive an interruption.
        df_results.to_csv(results_file, index=False)

        print(df_results.columns[df_results.columns.str.contains('pred')])

### Regression Finetuning

In [None]:
# Backbone used for LoRA regression fine-tuning ('esm2' or 'esmc').
model_choices = 'esmc'

# Hyper-parameters handed to the LoRA regression wrapper via its `config` argument.
config = {
    'epoch': 30,
    'batch_size': 8,
    'lambda': 0.1,
    'accumulate_batch_size': 32,
    'patience': 20,
    'early_stopping': False,
    'lr': 1e-3,
    'print_every_n_epoch': 1,
    'device': 'gpu',
}

# Single-model construction, kept for reference (the loop below builds a fresh model per fold):
# if model_choices == 'esmc':
#     surrogate = ESMCLoraRegression(name='esmc_600m', config=config)
# elif model_choices == 'esm2':
#     surrogate = ESM2LoraRegression(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)

# surrogate.print_trainable_parameters(surrogate)

In [None]:
import gc

In [None]:
# LoRA regression fine-tuning on the d200 low-data subsets. For each fold:
# - train a fresh model and restore its best checkpoint;
# - plot true vs. predicted for train/val/test;
# - store full-dataset predictions in df_results.
sample_size = 200

# Val/test splits are identical for every fold: build them once.
_, val_mask, test_mask = get_split_mask(df, omit_zero=False)
df_val = df[val_mask]
df_test = df[test_mask]
y_val = df_val['fitness_raw'].to_numpy().astype(np.float32)
y_test = df_test['fitness_raw'].to_numpy().astype(np.float32)

# NOTE(review): only folds 2-4 run here. Folds 0-1 are presumably already in
# results_file from an earlier session; confirm before re-running.
for fold in range(2, 5):

    print(f'd_{sample_size} fold {fold}')

    # Build the backbone named by model_choices so the column tag written below
    # always matches the model that produced the predictions.
    if model_choices == 'esmc':
        surrogate = ESMCLoraRegression(name='esmc_600m', config=config)
    elif model_choices == 'esm2':
        surrogate = ESM2LoraRegression(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
    else:
        raise ValueError('model not found')
    surrogate.print_trainable_parameters(surrogate)

    train_mask = sample_dict[sample_size][fold]
    df_train = df.loc[train_mask]
    y_train = df_train['fitness_raw'].to_numpy().astype(np.float32)

    print(f'train size {df_train.shape[0]}')

    # Distribution of the normalized fitness in this training subset.
    plt.figure(figsize=(5,3))
    plt.hist(df.loc[train_mask, 'fitness_norm'].to_numpy())
    plt.xlabel('fitness_norm')
    plt.ylabel('count')
    plt.show()

    surrogate.trainmodel(df_train, df_val)

    # Evaluate with the weights from the checkpoint the trainer's checkpoint
    # callback reports as best.
    surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path)['state_dict'])

    y_train_pred = surrogate.predict(df_train['seq'])
    y_val_pred = surrogate.predict(df_val['seq'])
    y_test_pred = surrogate.predict(df_test['seq'])

    fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
    splits = [('Train', y_train, y_train_pred), ('Val', y_val, y_val_pred), ('Test', y_test, y_test_pred)]
    for i, (split_name, y_true, y_hat) in enumerate(splits):
        ax[i].plot(y_true, y_hat, '.', alpha=0.5)
        mse = mean_squared_error(y_true, y_hat)
        corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
        ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
        ax[i].set_xlabel('True')
        ax[i].set_ylabel('Pred')

    plt.show()

    y_pred = surrogate.predict(df['seq'])
    assert y_pred.shape[0] == df_results.shape[0]

    df_results[f'pred_{model_choices}_regfit_d{sample_size}_{fold}'] = y_pred
    # Persist after every fold so finished runs survive an interruption.
    df_results.to_csv(results_file, index=False)

    print(df_results.columns[df_results.columns.str.contains('pred')])

    # Drop the model and release cached GPU memory before the next fold.
    surrogate = None

    time.sleep(5)
    gc.collect()
    time.sleep(5)

    torch.cuda.empty_cache()


In [None]:
# sample_size = 50
# fold = 1

# train_mask = sample_dict[sample_size][fold]
# _, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# df_train = df.loc[train_mask]
# df_val = df[val_mask]
# df_test = df[test_mask]

# y_train = df_train['fitness_raw'].to_numpy().astype(np.float32)
# y_val = df_val['fitness_raw'].to_numpy().astype(np.float32)
# y_test = df_test['fitness_raw'].to_numpy().astype(np.float32)

# print(f'train size {df_train.shape[0]}')

# plt.figure(figsize=(5,3))
# plt.hist(df.loc[train_mask, 'fitness_norm'].to_numpy())
# plt.show() 

In [None]:
# surrogate.trainmodel(df_train, df_val)

In [None]:
# surrogate.trainer.checkpoint_callback.best_model_path

In [None]:
# surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path)['state_dict'])

In [None]:
# y_train_pred = surrogate.predict(df_train['seq'])
# y_val_pred = surrogate.predict(df_val['seq'])
# y_test_pred = surrogate.predict(df_test['seq'])

# fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
# ax[0].plot(y_train, y_train_pred, '.', alpha=0.5)
# ax[1].plot(y_val, y_val_pred, '.', alpha=0.5)
# ax[2].plot(y_test, y_test_pred, '.', alpha=0.5)

# mse = mean_squared_error(y_train, y_train_pred)
# corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_train, y_train_pred)
# ax[0].set_title(f'Train \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

# mse = mean_squared_error(y_val, y_val_pred)
# corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_val, y_val_pred)
# ax[1].set_title(f'Val \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

# mse = mean_squared_error(y_test, y_test_pred)
# corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_test, y_test_pred)
# ax[2].set_title(f'Test \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

# for i in range(3):
#     ax[i].set_xlabel('True')
#     ax[i].set_ylabel('Pred')

# plt.show()

In [None]:
# y_pred = surrogate.predict(df['seq'])
# assert y_pred.shape[0] == df_results.shape[0]

# df_results[f'pred_{model_choices}_regfit_d{sample_size}_{fold}'] = y_pred
# df_results.to_csv(results_file, index=False)

# df_results.columns[df_results.columns.str.contains('pred')]

### Contrastive Finetuning

In [None]:
model_choices = 'esmc'  # 'esm2', 'esmc'

# Hyper-parameters handed to the contrastive fine-tuning wrapper via `config`.
config={'epoch': 60, 
        'batch_size': 8,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'model_checkpoint': True,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'use_seq_head': True,
        'device': 'gpu'}

if model_choices == 'esmc':
    surrogate = ESMCConFit(name='esmc_600m', config=config)
elif model_choices == 'esm2':
    surrogate = ESM2ConFit(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
else:
    # Without this, a typo would leave `surrogate` pointing at a stale object
    # (or undefined) and fail later with a confusing error.
    raise ValueError('model not found')

surrogate.print_trainable_parameters(surrogate)

In [None]:
import gc

In [None]:
# Wild-type TEM1 sequence, taken as the first entry returned by read_fasta
# (presumably a list of sequence strings in mode='str'; confirm).
wt_fasta = os.path.join(data_path, 'TEM1_WT.fasta')
wt_sequence = helper.read_fasta(wt_fasta, mode='str')[0]

In [None]:
# sample_size = 200

# for fold in range(5):

#     print(f'd_{sample_size} fold {fold}')

#     surrogate = ESMCConFit(name='esmc_600m', config=config)
#     surrogate.print_trainable_parameters(surrogate)

#     train_mask = sample_dict[sample_size][fold]
#     _, val_mask, test_mask = get_split_mask(df, omit_zero=False)

#     df_train = df.loc[train_mask]
#     df_val = df[val_mask]
#     df_test = df[test_mask]

#     y_train = df_train['fitness_raw'].to_numpy().astype(np.float32)
#     y_val = df_val['fitness_raw'].to_numpy().astype(np.float32)
#     y_test = df_test['fitness_raw'].to_numpy().astype(np.float32)

#     print(f'train size {df_train.shape[0]}')

#     plt.figure(figsize=(5,3))
#     plt.hist(df.loc[train_mask, 'fitness_norm'].to_numpy())
#     plt.show()

#     surrogate.trainmodel(df_train, wt_sequence, val=df_val)

#     # surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path)['state_dict'])

#     y_train_pred = surrogate.predict(df_train['seq'], wt_sequence)
#     y_val_pred = surrogate.predict(df_val['seq'], wt_sequence)
#     y_test_pred = surrogate.predict(df_test['seq'], wt_sequence)

#     fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
#     ax[0].plot(y_train, y_train_pred, '.', alpha=0.5)
#     ax[1].plot(y_val, y_val_pred, '.', alpha=0.5)
#     ax[2].plot(y_test, y_test_pred, '.', alpha=0.5)

#     mse = mean_squared_error(y_train, y_train_pred)
#     corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_train, y_train_pred)
#     ax[0].set_title(f'Train \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

#     mse = mean_squared_error(y_val, y_val_pred)
#     corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_val, y_val_pred)
#     ax[1].set_title(f'Val \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

#     mse = mean_squared_error(y_test, y_test_pred)
#     corr, ci_lower, ci_upper, *_  = get_spearmanr_bootstrap(y_test, y_test_pred)
#     ax[2].set_title(f'Test \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)

#     for i in range(3):
#         ax[i].set_xlabel('True')
#         ax[i].set_ylabel('Pred')

#     plt.show()

#     y_pred = surrogate.predict(df['seq'], wt_sequence)
#     assert y_pred.shape[0] == df_results.shape[0]

#     df_results[f'pred_{model_choices}_confit_d{sample_size}_{fold}'] = y_pred
#     df_results.to_csv(results_file, index=False)

#     print(df_results.columns[df_results.columns.str.contains('pred')])

#     surrogate = None

#     time.sleep(5)
#     gc.collect()
#     time.sleep(5)

#     torch.cuda.empty_cache()


In [None]:
# Manual single-fold run: choose one low-data training subset and the shared val/test splits.
sample_size, fold = 200, 4

train_mask = sample_dict[sample_size][fold]
_, val_mask, test_mask = get_split_mask(df, omit_zero=False)

df_train = df.loc[train_mask]
df_val, df_test = df[val_mask], df[test_mask]

y_train, y_val, y_test = (
    part['fitness_raw'].to_numpy().astype(np.float32)
    for part in (df_train, df_val, df_test)
)

n_train = df_train.shape[0]
print(f'train size {n_train}')

# Distribution of the normalized fitness in this training subset.
plt.figure(figsize=(5,3))
plt.hist(df_train['fitness_norm'].to_numpy())
plt.show()

In [None]:
# Contrastive fine-tuning on the selected fold. wt_sequence is passed alongside
# the training rows (presumably as the reference for scoring variants; confirm
# in ConFit.trainmodel), and df_val is used for validation.
surrogate.trainmodel(df_train, wt_sequence, val=df_val)

In [None]:
# surrogate.trainer.checkpoint_callback.best_model_path

In [None]:
# surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path)['state_dict'])
# surrogate.load_state_dict(torch.load('/nethome/kgeorge/workspace/DomainPrediction/src/fitness_prediction/lightning_logs/version_3/checkpoints/best-checkpoint-epoch=52.ckpt')['state_dict'])

In [None]:
# Score each split with the fine-tuned contrastive model and plot true vs.
# predicted, annotated with MSE and the bootstrapped Spearman CI.
y_train_pred = surrogate.predict(df_train['seq'], wt_sequence)
y_val_pred = surrogate.predict(df_val['seq'], wt_sequence)
y_test_pred = surrogate.predict(df_test['seq'], wt_sequence)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
panels = (
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
)
for i, (split_name, y_true, y_hat) in enumerate(panels):
    panel = ax[i]
    panel.plot(y_true, y_hat, '.', alpha=0.5)
    mse = mean_squared_error(y_true, y_hat)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true, y_hat)
    panel.set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    panel.set_xlabel('True')
    panel.set_ylabel('Pred')

plt.show()

In [None]:
# Score every sequence in the dataset and persist the column to results_file.
y_pred = surrogate.predict(df['seq'], wt_sequence)
assert y_pred.shape[0] == df_results.shape[0]

pred_column = f'pred_{model_choices}_confit_d{sample_size}_{fold}'
df_results[pred_column] = y_pred
df_results.to_csv(results_file, index=False)

is_pred_col = df_results.columns.str.contains('pred')
df_results.columns[is_pred_col]