In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from collections import defaultdict

In [None]:
from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
# Directory that the result CSVs below are read from.
# NOTE(review): absolute, machine-specific path — the notebook will not run
# elsewhere without editing this line; consider making it configurable.
data_path = '/nethome/kgeorge/workspace/DomainPrediction/Data/al_test_experiments/Tdomain'

In [None]:
# Load results_tdomain_zeroshot.csv from data_path into df_zeroshot.
results_file = os.path.join(data_path, 'results_tdomain_zeroshot.csv')
df_zeroshot = pd.read_csv(results_file)

In [None]:
# Load results_tdomain_embed.csv from data_path into df_embed.
# Note: results_file is reused (overwritten) by each loading cell.
results_file = os.path.join(data_path, 'results_tdomain_embed.csv')
df_embed = pd.read_csv(results_file)

In [None]:
# Load results_tdomain_confit.csv from data_path into df_confit.
# Note: results_file is reused (overwritten) by each loading cell.
results_file = os.path.join(data_path, 'results_tdomain_confit.csv')
df_confit = pd.read_csv(results_file)

In [None]:
# Preview the first rows of df_zeroshot.
df_zeroshot.head()

In [None]:
# Preview the first rows of df_embed.
df_embed.head()

In [None]:
# Preview the first rows of df_confit.
df_confit.head()

#### Tables

In [None]:
def get_split_mask(df):
    """Build boolean row masks for the train, validation and test splits.

    The split of each row is read from the ``split_id`` column:
    ``2`` marks train rows, ``1`` validation rows and ``0`` test rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``split_id`` column.

    Returns
    -------
    tuple
        ``(train_mask, val_mask, test_mask)``; each is the element-wise
        comparison of ``df['split_id']`` with the corresponding code.
    """
    split_ids = df['split_id']
    return split_ids == 2, split_ids == 1, split_ids == 0

In [None]:
def get_table(df, omit=True, fit_label='fitness_log', omit_label=None):
    """Tabulate per-split Spearman correlations for every prediction column.

    Every column of ``df`` whose name contains ``'pred'`` is correlated
    (Spearman) against ``fit_label`` separately on the train, validation and
    test rows given by ``get_split_mask``. Correlations are rounded to two
    decimals.

    Parameters
    ----------
    df : pandas.DataFrame
        Results frame; must contain ``split_id``, ``fit_label`` and the
        prediction columns.
    omit : bool
        If True, rows where ``df[omit_label] == 0`` are excluded.
    fit_label : str
        Name of the column holding the reference values.
    omit_label : str or None
        Column used for the omission filter; required when ``omit`` is True.

    Returns
    -------
    pandas.DataFrame
        One row per prediction column, with columns ``train``, ``val`` and
        ``test``. Empty (but with those columns) when ``df`` has no
        prediction columns.

    Raises
    ------
    AssertionError
        If ``fit_label`` is missing, or ``omit`` is True and ``omit_label``
        is None or missing.
    """
    assert fit_label in df.columns

    # Masks do not depend on the prediction column, so build them once.
    split_masks = list(get_split_mask(df))
    if omit:
        assert omit_label is not None and omit_label in df.columns
        keep_mask = df[omit_label] != 0
        split_masks = [mask & keep_mask for mask in split_masks]

    table = {}
    for label in df.columns[df.columns.str.contains('pred')]:
        table[label] = [
            round(stats.spearmanr(df.loc[mask, fit_label], df.loc[mask, label]).statistic, 2)
            for mask in split_masks
        ]

    # from_dict keeps the column names even when no prediction column matched.
    return pd.DataFrame.from_dict(table, orient='index', columns=['train', 'val', 'test'])

In [None]:
# Per-split correlation table for df_embed, omitting rows with fitness_raw == 0.
res_table_embed = get_table(df_embed, omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Preview the first rows of the embedding correlation table.
res_table_embed.head()

In [None]:
# Show only the rows whose label (prediction column name) contains 'RF'.
res_table_embed.loc[res_table_embed.index.str.contains('RF')]

In [None]:
# Per-split correlation table for df_zeroshot, omitting rows with fitness_raw == 0.
res_table_zeroshot = get_table(df_zeroshot, omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Preview the first rows of the zero-shot correlation table.
res_table_zeroshot.head()

In [None]:
# Show only the rows whose label (prediction column name) contains 'masked'.
res_table_zeroshot.loc[res_table_zeroshot.index.str.contains('masked')]

In [None]:
# Per-split correlation table for df_confit, omitting rows with fitness_raw == 0.
res_table_confit = get_table(df_confit, omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Display the full confit correlation table.
res_table_confit

In [None]:
def get_panel(df, label, omit=True, fit_label='fitness_log', omit_label=None):
    """Plot predicted vs. reference values for the train, val and test splits.

    Draws a 1x3 figure with one scatter panel per split (as given by
    ``get_split_mask``): ``fit_label`` on the x axis, ``label`` on the y axis.
    Each panel title reports the Spearman correlation rounded to two decimals.
    The figure is shown with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Results frame; must contain ``split_id``, ``fit_label`` and ``label``.
    label : str
        Name of the prediction column to plot.
    omit : bool
        If True, rows where ``df[omit_label] == 0`` are excluded.
    fit_label : str
        Name of the column holding the reference values.
    omit_label : str or None
        Column used for the omission filter; required when ``omit`` is True.

    Raises
    ------
    AssertionError
        If ``fit_label`` is missing, or ``omit`` is True and ``omit_label``
        is None or missing.
    """
    # Validate before creating the figure so a bad call leaves no empty figure.
    assert fit_label in df.columns
    split_masks = get_split_mask(df)
    if omit:
        assert omit_label is not None and omit_label in df.columns
        keep_mask = df[omit_label] != 0
        split_masks = tuple(mask & keep_mask for mask in split_masks)

    fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')

    for panel, split_name, mask in zip(ax, ('Train', 'Val', 'Test'), split_masks):
        true_vals = df.loc[mask, fit_label]
        pred_vals = df.loc[mask, label]

        panel.plot(true_vals, pred_vals, '.', alpha=0.9)
        corr = round(stats.spearmanr(true_vals, pred_vals).statistic, 2)

        panel.set_title(f'{split_name}\nspearman corr {corr}')
        panel.set_xlabel('True')
        panel.set_ylabel('Pred')

    plt.show()

In [None]:
# Per-split scatter panels for the 'pred_ESM650M_res_mean_ridge' column of df_embed
# (rows with fitness_raw == 0 omitted).
get_panel(df_embed, label='pred_ESM650M_res_mean_ridge', omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Per-split scatter panels for the 'pred_ESM650M_masked_marginal' column of df_zeroshot
# (rows with fitness_raw == 0 omitted).
get_panel(df_zeroshot, label='pred_ESM650M_masked_marginal', omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Per-split scatter panels for the 'pred_ESM650M_confit' column of df_confit
# (rows with fitness_raw == 0 omitted).
get_panel(df_confit, label='pred_ESM650M_confit', omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Histogram of the 'n_mut' column for each split of df_zeroshot, restricted to
# rows with fitness_raw != 0. Panels are titled so they can be told apart.
train_mask, val_mask, test_mask = get_split_mask(df_zeroshot)
fig, ax = plt.subplots(1,3, figsize=(10,3), layout='constrained')
_omit_mask = df_zeroshot['fitness_raw'] != 0
_splits = (('Train', train_mask), ('Val', val_mask), ('Test', test_mask))
for i, (_split_name, _split_mask) in enumerate(_splits):
    ax[i].hist(df_zeroshot.loc[_split_mask & _omit_mask, 'n_mut'])
    ax[i].set_title(_split_name)
    ax[i].set_xlabel('# mutations')

plt.show()

In [None]:
def plot_comparison_corr_methods(df):
    """Compare the correlation distributions of ridge, RF and MLP rows.

    Rows of ``df`` are grouped by the first of ``'ridge'``, ``'RF'``,
    ``'MLP'`` found in the row label; rows whose label contains ``'cls'``
    are skipped. For each of the ``train``, ``val`` and ``test`` columns a
    grouped histogram of the three methods is drawn with 0.1-wide bins.
    The figure is shown with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Table indexed by string labels with ``train``, ``val`` and ``test``
        columns (e.g. the output of ``get_table``).

    Raises
    ------
    Exception
        If a non-``cls`` row label matches none of the three method tags.
    """
    dsets = ('train', 'val', 'test')
    ridge = defaultdict(list)
    rf    = defaultdict(list)
    mlp   = defaultdict(list)
    for idx in df.index:
        if 'cls' in idx:
            continue

        if 'ridge' in idx:
            bucket = ridge
        elif 'RF' in idx:
            bucket = rf
        elif 'MLP' in idx:
            bucket = mlp
        else:
            raise Exception('Huh?')

        for dset in dsets:
            bucket[dset].append(df.loc[idx, dset])

    def my_floor(a, precision=1):
        # Shift down by half a unit of the requested precision, then round,
        # so the result never exceeds ``a``.
        return np.round(a - 0.5 * 10**(-precision), precision)

    fig, ax = plt.subplots(1, 3, figsize=(10, 3), layout='constrained')
    for i, dset in enumerate(dsets):
        # Floor the smallest value once; flooring twice (as before) pushed the
        # first bin edge one extra step below the data.
        _min = my_floor(np.nanmin(ridge[dset]+rf[dset]+mlp[dset]))
        ax[i].hist([ridge[dset], rf[dset], mlp[dset]],
                   label=['linear', 'rf', 'mlp'],
                   alpha=0.8,
                   bins=np.arange(_min,1.05,0.1))

        ax[i].set_title(dset)
        ax[i].set_xlabel('corr')
        ax[i].legend()

    plt.show()

In [None]:
# Compare ridge / RF / MLP correlation distributions in the embedding table.
plot_comparison_corr_methods(res_table_embed)

In [None]:
# Scatter each 'perplexity' column of df_zeroshot against fitness_raw on a
# 2x2 grid; zip() stops after four columns. With omit=True, rows where
# fitness_raw == 0 are left out.
omit = True
fig, ax = plt.subplots(2, 2, figsize=(6,6), layout='constrained')

_perplexity_cols = df_zeroshot.columns[df_zeroshot.columns.str.contains('perplexity')]
if omit:
    _omit_mask = df_zeroshot['fitness_raw'] != 0
    _df_plot = df_zeroshot.loc[_omit_mask]
else:
    _df_plot = df_zeroshot

for _label, _ax in zip(_perplexity_cols, ax.flatten()):
    _ax.plot(_df_plot['fitness_raw'], _df_plot[_label], '.', alpha=0.8)
    _ax.set_xlabel('true / titer')
    _ax.set_ylabel('perplexity')
    _ax.set_title(_label.replace('pred_', '').replace('_perplexity', ''))

In [None]:
def get_table_all(df, omit=True, fit_label='fitness_log', omit_label='fitness_raw'):
    """Tabulate the Spearman correlation of every prediction column over all rows.

    Every column of ``df`` whose name contains ``'pred'`` is correlated
    (Spearman) against ``fit_label`` without splitting by ``split_id``.
    Correlations are rounded to two decimals.

    Parameters
    ----------
    df : pandas.DataFrame
        Results frame containing ``fit_label`` and the prediction columns.
    omit : bool
        If True, rows where ``df[omit_label] == 0`` are excluded.
    fit_label : str
        Name of the column holding the reference values.
    omit_label : str or None
        Column used for the omission filter; required when ``omit`` is True.

    Returns
    -------
    pandas.DataFrame
        One row per prediction column with a single ``corr`` column. Empty
        (but with that column) when ``df`` has no prediction columns.

    Raises
    ------
    AssertionError
        If ``fit_label`` is missing, or ``omit`` is True and ``omit_label``
        is None or missing.
    """
    assert fit_label in df.columns

    if omit:
        assert omit_label is not None and omit_label in df.columns
        data = df.loc[df[omit_label] != 0]
    else:
        # Select columns with df[...]; df.loc[name] would look up a row label.
        data = df

    table = {
        label: [round(stats.spearmanr(data[fit_label], data[label]).statistic, 2)]
        for label in df.columns[df.columns.str.contains('pred')]
    }

    # from_dict keeps the column name even when no prediction column matched.
    return pd.DataFrame.from_dict(table, orient='index', columns=['corr'])

In [None]:
# Correlation of every df_zeroshot prediction column over all rows with fitness_raw != 0.
get_table_all(df_zeroshot, omit=True, fit_label='fitness_log', omit_label='fitness_raw')

In [None]:
# Scatter each df_zeroshot column whose name contains `predictor` against
# fitness_log on a 2x2 grid; zip() stops after four columns. With omit=True,
# rows where fitness_raw == 0 are left out.
omit = True
# predictor = 'masked_marginal'
predictor = 'wt_marginal'
fig, ax = plt.subplots(2, 2, figsize=(6,6), layout='constrained')

_predictor_cols = df_zeroshot.columns[df_zeroshot.columns.str.contains(predictor)]
if omit:
    _omit_mask = df_zeroshot['fitness_raw'] != 0
    _df_plot = df_zeroshot.loc[_omit_mask]
else:
    _df_plot = df_zeroshot

for _label, _ax in zip(_predictor_cols, ax.flatten()):
    _ax.plot(_df_plot['fitness_log'], _df_plot[_label], '.', alpha=0.8)
    corr = round(stats.spearmanr(_df_plot['fitness_log'], _df_plot[_label]).statistic, 2)
    _ax.set_xlabel('true')
    # Both settings of `omit` plot predictor scores, so the y label is 'pred'
    # in both (the non-omit branch previously said 'perplexity').
    _ax.set_ylabel('pred')
    _ax.set_title(f"{_label.replace('pred_', '')}\nspearman corr {corr}", size=10)

In [None]:
def _plot(df, mask, predictor, label):
    """Plot prediction vs. ``fitness_log`` and the ``n_mut`` histogram for a row subset.

    Left panel: scatter of ``df[label]`` against ``df['fitness_log']`` on the
    rows selected by ``mask``, titled with the label (``'pred_'`` removed) and
    the Spearman correlation rounded to two decimals. Right panel: histogram
    of ``df['n_mut']`` on the same rows. The figure is shown with
    ``plt.show()``; nothing is returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with ``fitness_log``, ``n_mut`` and ``label`` columns.
    mask : boolean mask
        Row selector passed to ``df.loc``.
    predictor : str
        Not used by the function body; kept so existing calls still work.
    label : str
        Name of the prediction column to plot.
    """
    fig, ax = plt.subplots(1, 2, figsize=(7,3), layout='constrained')

    true_vals = df.loc[mask, 'fitness_log']
    pred_vals = df.loc[mask, label]

    ax[0].plot(true_vals, pred_vals, '.', alpha=0.8)
    corr = round(stats.spearmanr(true_vals, pred_vals).statistic, 2)
    ax[0].set_xlabel('true')
    ax[0].set_ylabel('pred')
    # Use the `label` argument, not the notebook-global `_label`.
    ax[0].set_title(f"{label.replace('pred_', '')}\nspearman corr {corr}", size=10)

    # Use the `df` argument, not the notebook-global `df_zeroshot`.
    ax[1].hist(df.loc[mask, 'n_mut'])
    ax[1].set_xlabel('# mutations')
    plt.show()

# Draw the _plot figure for one prediction column on several row subsets of
# df_zeroshot, all restricted to rows with fitness_raw != 0.
predictor = 'masked_marginal'
_label = 'pred_ESMC600M_masked_marginal'
_omit_mask = df_zeroshot['fitness_raw'] != 0
# _omit_mask = df_zeroshot['fitness_raw'] > -1000
mut_mask_20 = df_zeroshot['n_mut'] < 20

_subset_masks = (
    _omit_mask,                                            # all kept rows
    _omit_mask & mut_mask_20,                              # n_mut < 20
    _omit_mask & ~mut_mask_20,                             # complement of n_mut < 20
    _omit_mask & (df_zeroshot['n_mut'] < 2.5),             # n_mut < 2.5
    _omit_mask & (df_zeroshot['n_mut'].between(2.5,20)),   # 2.5 <= n_mut <= 20 (between() is inclusive)
)
for _mask in _subset_masks:
    _plot(df_zeroshot, _mask, predictor, _label)
