In [None]:
import sys
# Put the parent directory on the import path (presumably so the project-local
# `DomainPrediction` package imported below resolves — confirm against repo layout).
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import torch

from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
from DomainPrediction.utils import helper
from DomainPrediction.eval import metrics
from DomainPrediction.al import top_model as topmodel
from DomainPrediction.al.embeddings import one_hot_encode

In [None]:
# Import the model classes that are installable in the active conda env: the ESM-C
# classes in 'workspace', the ESM2 classes in 'workspace-esm'. Only one set of names
# exists afterwards, so every later `model_choices` must agree with the env.
# `.get` so that running outside conda reaches the explanatory error below instead
# of a bare KeyError.
env = os.environ.get('CONDA_DEFAULT_ENV', '')
if env == 'workspace':
    sys.path.append('../../esm')
    from DomainPrediction.esm.esmc import ESMCLM
    from DomainPrediction.al.finetuning import ESMCLoraRegression, ESMCConFit
elif env == 'workspace-esm':
    from DomainPrediction.esm.esm2 import ESM2
    from DomainPrediction.al.finetuning import ESM2LoraRegression, ESM2ConFit
else:
    raise RuntimeError(f'I designed this for my envs (got CONDA_DEFAULT_ENV={env!r}). Feel free to modify accordingly')

### Load Data

In [None]:
# Folder holding dataset_gb1.csv, GB1_WT.fasta and results_gb1.csv. Override with the
# GB1_DATA_PATH environment variable; the default is the original author's location.
data_path = os.environ.get('GB1_DATA_PATH', '/nethome/kgeorge/workspace/DomainPrediction/Data/fitness_prediction/GB1')

In [None]:
# Load the GB1 variant table; later cells read its 'seq', 'split_id',
# 'fitness_raw', 'fitness_log' and 'n_mut' columns.
file = os.path.join(data_path, 'dataset_gb1.csv')
df = pd.read_csv(file)

In [None]:
# Preview the first rows of the dataset.
df.head()

In [None]:
# (rows, columns) of the dataset.
df.shape

In [None]:
# CSV that accumulates one 'pred_*' column per model; rewritten after each model below.
results_file = os.path.join(data_path, 'results_gb1.csv')

In [None]:
# Reuse predictions saved by earlier runs so new model columns are appended to them;
# otherwise start from a copy of the dataset.
if os.path.isfile(results_file):
    df_results = pd.read_csv(results_file)
    # Prediction arrays are assigned to df_results by position in later cells, so the
    # cached file must list exactly the same sequences, in the same order, as df.
    assert len(df_results) == len(df), 'results file row count does not match dataset'
    assert (df_results['seq'].to_numpy() == df['seq'].to_numpy()).all(), 'results file sequences/order do not match dataset'
else:
    df_results = df.copy()

In [None]:
# Prediction columns already stored in the results table.
df_results.filter(like='pred').columns

In [None]:
# How many prediction columns have been stored so far.
len(df_results.filter(like='pred').columns)

In [None]:
def get_split_mask(df, omit_zero=False):
    """Build boolean row masks for the train / validation / test splits.

    Rows are assigned by ``df['split_id']``: 2 -> train, 1 -> validation, 0 -> test.

    Args:
        df: DataFrame with a ``split_id`` column (and ``fitness_raw`` if ``omit_zero``).
        omit_zero: if true, training rows whose ``fitness_raw`` equals 0 are also
            excluded; the validation and test masks are unaffected.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean Series aligned with ``df``.
    """
    split = df['split_id']
    train_mask = split == 2
    if omit_zero:
        train_mask = train_mask & (df['fitness_raw'] != 0)
    return train_mask, split == 1, split == 0

In [None]:
def get_spearmanr_bootstrap(a, b, n=1000, ci = 95):
    """Bootstrap the Spearman rank correlation between two paired 1-D sequences.

    Rows are resampled with replacement ``n`` times from a fixed seed (0), so results
    are reproducible. Resamples whose correlation is undefined (NaN, e.g. a constant
    resample) are dropped.

    Args:
        a, b: paired array-likes of equal length (e.g. true and predicted values).
            Converted with ``np.asarray`` so indexing is positional even for
            pandas Series.
        n: number of bootstrap resamples.
        ci: two-sided confidence level in percent (95 -> 2.5th..97.5th percentile).

    Returns:
        ``(mean_corr, ci_lower, ci_upper, p_value, corr, p_values)``.
        - The first three are rounded to 2 decimals.
        - ``p_value`` is the fraction of bootstrap correlations below 0.
        - ``corr`` / ``p_values`` are the per-resample Spearman statistics and p-values.
        - If every resample is NaN, the four summary values are NaN.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    assert len(a) == len(b)
    # Local generator: yields the same draws as np.random.seed(0) followed by
    # np.random.choice, without resetting numpy's global RNG for the caller.
    rng = np.random.RandomState(0)
    corr = []
    p_values = []
    for _ in range(n):
        indices = rng.choice(len(a), size=len(a), replace=True)
        res = stats.spearmanr(a[indices], b[indices])

        if not np.isnan(res.statistic):
            corr.append(res.statistic)
            p_values.append(res.pvalue)

    if not corr:
        # np.percentile would raise on an empty list; report "undefined" instead.
        return np.nan, np.nan, np.nan, np.nan, corr, p_values

    # Two-sided interval: split the excluded mass (100 - ci) evenly between both tails.
    tail = (100 - ci) / 2
    ci_lower, ci_upper = np.percentile(corr, [tail, 100 - tail])
    mean_corr = np.mean(corr)
    # One-sided bootstrap p-value: share of resamples with a negative correlation.
    p_value = np.mean(np.array(corr) < 0)

    return round(mean_corr, 2), round(ci_lower, 2), round(ci_upper, 2), p_value, corr, p_values

### One-hot-encoding (OHE) based models

#### Get embeddings and splits

In [None]:
# One-hot features for every sequence, row-aligned with df (sliced by the split masks below).
# NOTE(review): presumably shape (n_sequences, seq_len * alphabet_size) — confirm in one_hot_encode.
embeddings = one_hot_encode(df['seq'])

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# Feature rows per split (masks are boolean Series aligned with df's rows).
X_train, X_val, X_test = (embeddings[mask] for mask in (train_mask, val_mask, test_mask))

# float32 log-fitness targets, in the same row order as the features.
y_train, y_val, y_test = (
    df.loc[mask, 'fitness_log'].to_numpy().astype(np.float32)
    for mask in (train_mask, val_mask, test_mask)
)

In [None]:
# Number of rows in each split.
print(f'train {train_mask.sum()} val {val_mask.sum()} test {test_mask.sum()}')

#### Linear model

In [None]:
# Ridge regression (alpha=1.0) on the one-hot features; the validation split is
# handed to trainmodel as `val`.
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(
    X=X_train,
    y=y_train,
    val=(X_val, y_val),
)

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results['pred_OHE_ridge'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

#### Random Forest

In [None]:
# Random forest on the one-hot features. Pass the validation split, as the OHE ridge
# cell does: the test split must stay unseen until the final evaluation below.
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results['pred_OHE_RF'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

#### MLP

In [None]:
# Feature width of X_train — presumably what the first entry of the MLP 'layers' config must equal.
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
# MLP head on the one-hot features. The input width is read from X_train rather than
# hard-coded, so it always matches the feature matrix.
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 100,
        'batch_size': 16,
        'patience': 100,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
# Monitor on the validation split; the test split stays unseen until evaluation.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results['pred_OHE_MLP'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

### Embedding-based Models

#### Get embeddings and splits

In [None]:
model_choices = 'esmc'  # 'esm2', 'esmc'

# Load the pretrained PLM used as feature extractor. Only the classes imported for
# the active conda env (see the env cell) are available.
if model_choices not in ('esm2', 'esmc'):
    raise ValueError('model not found')

if model_choices == 'esmc':
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    base_model = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')

In [None]:
# How PLM outputs become per-sequence features: 'mean' calls get_embeddings_mean,
# 'concat' calls get_embeddings_flatten (names suggest mean-pooled vs. flattened
# per-residue embeddings — confirm in the model wrapper).
embedding_choice = 'concat'  # 'mean' or 'concat'

if embedding_choice == 'mean':
    embeddings = base_model.get_embeddings_mean(df['seq'])
elif embedding_choice == 'concat':
    embeddings = base_model.get_embeddings_flatten(df['seq'])
else:
    raise ValueError('embedding choice not found')

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# Feature rows per split (masks are boolean Series aligned with df's rows).
X_train, X_val, X_test = (embeddings[mask] for mask in (train_mask, val_mask, test_mask))

# float32 log-fitness targets, in the same row order as the features.
y_train, y_val, y_test = (
    df.loc[mask, 'fitness_log'].to_numpy().astype(np.float32)
    for mask in (train_mask, val_mask, test_mask)
)

In [None]:
# Number of rows in each split.
print(f'train {train_mask.sum()} val {val_mask.sum()} test {test_mask.sum()}')

#### Linear Model

In [None]:
# Ridge regression on the PLM features. Pass the validation split (as the OHE ridge
# cell does); the test split must stay unseen until the final evaluation below.
surrogate = topmodel.RidgeSurrogate(alpha=1.0)
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_{embedding_choice}_ridge'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

#### Random Forest

In [None]:
# Random forest on the PLM features. Pass the validation split (as the OHE ridge
# cell does); the test split must stay unseen until the final evaluation below.
surrogate = topmodel.RFSurrogate()
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_{embedding_choice}_RF'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

#### MLP

In [None]:
# Feature width of X_train — presumably what the first entry of the MLP 'layers' config must equal.
print(f'input layer shape: {X_train.shape[1]}')

In [None]:
# MLP head on the PLM features. The input width is read from X_train so it follows
# model_choices / embedding_choice (a fixed width only fits one model + pooling combo).
config={'layers': [X_train.shape[1], 512, 1],
        'epoch': 300,
        'batch_size': 16,
        'patience': 200,
        'early_stopping': True,
        'lr': 1e-4,
        'print_every_n_epoch': 10,
        'debug': True}
surrogate = topmodel.MLPSurrogate(config=config)
# Monitor on the validation split. With early stopping enabled, passing the test split
# here (if trainmodel stops on `val`) would let the test set pick the model and
# inflate the test metrics reported below.
surrogate.trainmodel(X=X_train, y=y_train, val=(X_val, y_val))

In [None]:
# Predict on each split, then draw one true-vs-predicted scatter per split with its
# MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(X_train)
y_val_pred = surrogate.predict(X_val)
y_test_pred = surrogate.predict(X_test)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fitted model and store the result as a new column;
# the CSV is rewritten each time so predictions from all models accumulate.
y_pred = surrogate.predict(embeddings)
assert df_results.shape[0] == embeddings.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_{embedding_choice}_MLP'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

### Zero-Shot PLMs

In [None]:
model_choices = 'esmc'

# Load the pretrained PLM that scores variants zero-shot. Only the classes imported
# for the active conda env (see the env cell) are available.
if model_choices not in ('esm2', 'esmc'):
    raise ValueError('model not found')

if model_choices == 'esmc':
    base_model = ESMCLM(name='esmc_600m', device='gpu')
else:
    base_model = ESM2(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', device='gpu')

In [None]:
# Wild-type GB1 sequence: first entry returned by helper.read_fasta; reference for the marginal scores.
wt_sequence = helper.read_fasta(os.path.join(data_path, 'GB1_WT.fasta'), mode='str')[0]

In [None]:
# One of 'wt_marginal', 'masked_marginal', 'pseudolikelihood' (handled by the scoring loop below).
zero_shot_method = 'masked_marginal'

In [None]:
# Score every variant zero-shot with the PLM. The method is validated once up front.
# The marginal methods also return a mutation count, which is checked against the
# dataset's 'n_mut' column.
if zero_shot_method not in ('wt_marginal', 'masked_marginal', 'pseudolikelihood'):
    raise ValueError('method not found')

y_pred = []
# total= lets tqdm show progress/ETA; iterrows() alone has no known length.
for _, row in tqdm(df.iterrows(), total=len(df)):
    mt_sequence = row['seq']

    if zero_shot_method == 'wt_marginal':
        score, n_muts = base_model.get_wildtype_marginal(mt_sequence, wt_sequence)
        assert n_muts == row['n_mut']
    elif zero_shot_method == 'masked_marginal':
        score, n_muts = base_model.get_masked_marginal(mt_sequence, wt_sequence)
        assert n_muts == row['n_mut']
    else:
        score = base_model.pseudolikelihood(mt_sequence)

    y_pred.append(score)

y_pred = np.array(y_pred)

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

y = df['fitness_log'].to_numpy().astype(np.float32)

# Slice measured fitness and zero-shot scores with the same masks so they stay paired.
y_train, y_val, y_test = (y[mask] for mask in (train_mask, val_mask, test_mask))
y_train_pred, y_val_pred, y_test_pred = (y_pred[mask] for mask in (train_mask, val_mask, test_mask))

In [None]:
# Zero-shot score vs. measured log-fitness on the full dataset, and again after
# dropping variants whose raw fitness is exactly 0.
fig, ax = plt.subplots(1,2, figsize=(7,3), layout='constrained')

mask = df['fitness_raw'] != 0
overview_panels = [
    ('Full Dataset', y, y_pred),
    ('Omit fitness = 0', y[mask], y_pred[mask]),
]
for i, (panel_title, y_true_panel, y_pred_panel) in enumerate(overview_panels):
    ax[i].plot(y_true_panel, y_pred_panel, '.', alpha=0.5)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_panel, y_pred_panel)
    ax[i].set_title(f'{panel_title} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Per-split scatter of zero-shot score vs. measured log-fitness.
# NOTE(review): MSE is only meaningful if the zero-shot scores share the fitness scale.
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Store the zero-shot scores (one per dataset row) as a new prediction column.
assert df_results.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_{zero_shot_method}'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

### Regression Finetuning

In [None]:
model_choices = 'esmc'  # 'esm2', 'esmc'

# LoRA regression fine-tuning settings, passed unchanged to the surrogate.
config={'epoch': 50,
        'batch_size': 8,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'lr': 1e-3,
        'print_every_n_epoch': 1,
        'device': 'gpu'}

# Only the classes imported for the active conda env exist (see the env cell).
if model_choices == 'esmc':
    surrogate = ESMCLoraRegression(name='esmc_600m', config=config)
elif model_choices == 'esm2':
    surrogate = ESM2LoraRegression(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
else:
    # Fail loudly: otherwise `surrogate` would silently remain the previously trained model.
    raise ValueError('model not found')

surrogate.print_trainable_parameters(surrogate)

In [None]:
# Display the model's repr.
surrogate

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# Row subsets per split; the fine-tuning surrogates consume the DataFrames directly.
df_train, df_val, df_test = (df[mask] for mask in (train_mask, val_mask, test_mask))

# float32 log-fitness targets for evaluation, in the same row order as each subset.
y_train, y_val, y_test = (
    part['fitness_log'].to_numpy().astype(np.float32)
    for part in (df_train, df_val, df_test)
)

In [None]:
# Fine-tune on the training rows, passing the validation rows to the trainer.
surrogate.trainmodel(df_train, df_val)

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback.
surrogate.trainer.checkpoint_callback.best_model_path

In [None]:
# Restore the best checkpoint's weights before evaluating. Loading onto the CPU avoids
# a second copy of the weights on the GPU; nn.Module.load_state_dict copies the values
# into the model's existing parameters wherever they live.
best_ckpt = surrogate.trainer.checkpoint_callback.best_model_path
assert best_ckpt, 'no checkpoint was saved during training'
surrogate.load_state_dict(torch.load(best_ckpt, map_location='cpu')['state_dict'])

In [None]:
# Predict each split from raw sequences, then draw one true-vs-predicted scatter per
# split with its MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(df_train['seq'])
y_val_pred = surrogate.predict(df_val['seq'])
y_test_pred = surrogate.predict(df_test['seq'])

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the fine-tuned model and store the result as a new column.
y_pred = surrogate.predict(df['seq'])
assert df_results.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_regfit'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns

### Contrastive Finetuning

In [None]:
model_choices = 'esm2'  # 'esm2', 'esmc'

# Contrastive fine-tuning (ConFit) settings, passed unchanged to the surrogate.
config={'epoch': 60,
        'batch_size': 32,
        'lambda': 0.1,
        'accumulate_batch_size': 32,
        'patience': 20,
        'early_stopping': False,
        'model_checkpoint': True,
        'lr': 5e-4,
        'print_every_n_epoch': 1,
        'use_seq_head': True,
        'device': 'gpu'}

# Only the classes imported for the active conda env exist (see the env cell).
if model_choices == 'esmc':
    surrogate = ESMCConFit(name='esmc_600m', config=config)
elif model_choices == 'esm2':
    surrogate = ESM2ConFit(model_path='/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t33_650M_UR50D.pt', config=config)
else:
    # Fail loudly: otherwise `surrogate` would silently remain the previously trained model.
    raise ValueError('model not found')

surrogate.print_trainable_parameters(surrogate)

In [None]:
# Display the model's repr.
surrogate

In [None]:
# Wild-type GB1 sequence (first entry returned by helper.read_fasta); ConFit training and prediction take it as a reference.
wt_sequence = helper.read_fasta(os.path.join(data_path, 'GB1_WT.fasta'), mode='str')[0]

In [None]:
train_mask, val_mask, test_mask = get_split_mask(df, omit_zero=False)

# Row subsets per split; the fine-tuning surrogates consume the DataFrames directly.
df_train, df_val, df_test = (df[mask] for mask in (train_mask, val_mask, test_mask))

# float32 log-fitness targets for evaluation, in the same row order as each subset.
y_train, y_val, y_test = (
    part['fitness_log'].to_numpy().astype(np.float32)
    for part in (df_train, df_val, df_test)
)

In [None]:
# Contrastive fine-tuning on the training rows; the wild-type sequence and validation rows are passed too.
surrogate.trainmodel(df_train, wt_sequence, val=df_val)

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback.
surrogate.trainer.checkpoint_callback.best_model_path

In [None]:
# Set to True to evaluate the best checkpoint recorded by the checkpoint callback
# instead of whatever weights trainmodel left in memory (the regression section
# restores its best checkpoint unconditionally). A specific run can be loaded instead
# by swapping in its path, e.g.
# '/nethome/kgeorge/workspace/DomainPrediction/src/fitness_prediction/lightning_logs/version_2/checkpoints/best-checkpoint-epoch=37.ckpt'.
load_best_checkpoint = False
if load_best_checkpoint:
    surrogate.load_state_dict(torch.load(surrogate.trainer.checkpoint_callback.best_model_path, map_location='cpu')['state_dict'])

In [None]:
# Predict each split relative to the wild type, then draw one true-vs-predicted scatter
# per split with its MSE and bootstrapped Spearman correlation (mean and CI) in the title.
y_train_pred = surrogate.predict(df_train['seq'], wt_sequence)
y_val_pred = surrogate.predict(df_val['seq'], wt_sequence)
y_test_pred = surrogate.predict(df_test['seq'], wt_sequence)

fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
split_panels = [
    ('Train', y_train, y_train_pred),
    ('Val', y_val, y_val_pred),
    ('Test', y_test, y_test_pred),
]
for i, (split_name, y_true_split, y_pred_split) in enumerate(split_panels):
    ax[i].plot(y_true_split, y_pred_split, '.', alpha=0.5)
    mse = mean_squared_error(y_true_split, y_pred_split)
    corr, ci_lower, ci_upper, *_ = get_spearmanr_bootstrap(y_true_split, y_pred_split)
    ax[i].set_title(f'{split_name} \nmse : {str(round(mse, 2))} \nspearman correlation = {corr} CI ({ci_lower}, {ci_upper})', size=10)
    ax[i].set_xlabel('True')
    ax[i].set_ylabel('Pred')

plt.show()

In [None]:
# Score every variant with the contrastively fine-tuned model and store it as a new column.
y_pred = surrogate.predict(df['seq'], wt_sequence)
assert df_results.shape[0] == y_pred.shape[0]

df_results[f'pred_{model_choices}_confit'] = y_pred
df_results.to_csv(results_file, index=False)

# Prediction columns stored so far.
df_results.filter(like='pred').columns