In [None]:
import sys
sys.path.append('..')

In [None]:
import os
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from sklearn.metrics import mean_squared_error
from scipy import stats

In [None]:
from DomainPrediction.utils import helper

In [None]:
sys.path.append('../../esm')
from DomainPrediction.esm.esm3 import ESM3LM
from DomainPrediction.esm.esmc import ESMCLM

### Load and Verify Generations

In [None]:
# Root directory for all inputs/outputs of this notebook. Set the
# DOMAIN_PREDICTION_DATA_PATH environment variable to run elsewhere; the default is
# the original hard-coded location.
data_path = os.environ.get('DOMAIN_PREDICTION_DATA_PATH', '/nethome/kgeorge/workspace/DomainPrediction/Data/new_system')

In [None]:
# Reference (wild-type) FASTA files for both systems; the first record of each,
# read as a plain string, is the WT sequence.
fasta_PSG, fasta_pCK = (os.path.join(data_path, f'{system}_KG.fasta') for system in ('PSG', 'pCK'))

pCK_base_seq, PSG_base_seq = (helper.read_fasta(path, mode='str')[0] for path in (fasta_pCK, fasta_PSG))

In [None]:
# Domain boundaries as 0-based residue positions (ranges are end-exclusive).
# Positions in neither 'A' nor 'C' make up the masked region that create_df keeps.
PSG_annotations = {
    'A': list(range(0, 954)),
    'T': list(range(968, 1039)),
    'C': list(range(1057, 1489)),
}

pCK_star_annotations = {
    'A': list(range(0, 485)),
    'T': list(range(485, 610)),
    'C': list(range(610, 1017)),
}

In [None]:
# Generation FASTA files per system (one esm3 run and one evodiff run), relative to data_path.
fasta_files_pcK = [f'pCK_KG_{run}.fasta' for run in ('esm3_str_1000', 'evodiff_1000')]
fasta_files_PSG = [f'PSG_KG_{run}.fasta' for run in ('esm3_str_1000', 'evodiff_1000')]

def create_df(fasta_files: list, annotations: dict, wt_seq: str, data_dir=None):
    """Load generated sequences from FASTA files into a DataFrame.

    Every generated sequence must carry the WT residues at all 'A' and 'C'
    domain positions; only the remaining positions may differ.

    Args:
        fasta_files: FASTA file names, resolved relative to ``data_dir``.
        annotations: maps a domain label to a list of 0-based positions; the
            keys 'A' and 'C' are required.
        wt_seq: wild-type sequence the generations are compared against.
        data_dir: directory holding ``fasta_files``; defaults to the
            notebook-level ``data_path``.

    Returns:
        pd.DataFrame with columns 'name' (record id), 'full_seq' and
        'masked_seq' (the sequence with all A and C positions removed), one
        row per record in file order.

    Raises:
        AssertionError: if a generation differs from ``wt_seq`` within A or C.
    """
    if data_dir is None:
        data_dir = data_path

    # Position sets give O(1) membership tests; list lookups (and rebuilding
    # A+C for every residue) made this quadratic per sequence.
    a_pos = set(annotations['A'])
    c_pos = set(annotations['C'])
    fixed_pos = a_pos | c_pos

    def _extract_seq(seq, positions):
        # Residues of `seq` at `positions`, in sequence order.
        return ''.join(aa for i, aa in enumerate(seq) if i in positions)

    # WT domain strings are the same for every record: compute once.
    wt_a = _extract_seq(wt_seq, a_pos)
    wt_c = _extract_seq(wt_seq, c_pos)

    gen_list = []
    for _name in fasta_files:
        _fasta = os.path.join(data_dir, _name)
        for rec in helper.read_fasta(_fasta):
            seq = str(rec.seq)
            gen_list.append({
                'name': rec.id,
                'full_seq': seq,
                'masked_seq': ''.join(aa for i, aa in enumerate(seq) if i not in fixed_pos),
            })

            assert _extract_seq(seq, a_pos) == wt_a, f'{rec.id}: A domain differs from WT'
            assert _extract_seq(seq, c_pos) == wt_c, f'{rec.id}: C domain differs from WT'

    return pd.DataFrame(gen_list)

# One row per generated sequence, for each system.
df_pCK = create_df(fasta_files=fasta_files_pcK, annotations=pCK_star_annotations, wt_seq=pCK_base_seq)
df_PSG = create_df(fasta_files=fasta_files_PSG, annotations=PSG_annotations, wt_seq=PSG_base_seq)

In [None]:
# Masked regions must have the expected length (103 for PSG, 125 for pCK),
# and each system must contain 2000 generations.
assert df_PSG['masked_seq'].str.len().eq(103).all()
assert df_pCK['masked_seq'].str.len().eq(125).all()

assert len(df_PSG) == 2000 and len(df_pCK) == 2000

In [None]:
df_pCK.head(n=5)  # first five pCK generations

In [None]:
df_PSG.head(n=5)  # first five PSG generations

In [None]:
# Reference sequences restricted to the masked region: every position outside the
# A and C domains, the same rule create_df uses for 'masked_seq'. Position sets keep
# the membership test O(1); previously the A+C list was rebuilt for every residue.
PSG_fixed_pos = set(PSG_annotations['A']) | set(PSG_annotations['C'])
PSG_masked_WT = ''.join(aa for pos, aa in enumerate(PSG_base_seq) if pos not in PSG_fixed_pos)
assert len(PSG_masked_WT) == 103

pCK_fixed_pos = set(pCK_star_annotations['A']) | set(pCK_star_annotations['C'])
pCK_masked_WT = ''.join(aa for pos, aa in enumerate(pCK_base_seq) if pos not in pCK_fixed_pos)

# Row 580 of df_pCK (FASTA load order) serves as an alternative reference ("anchor").
# NOTE(review): hand-picked row number; confirm which record it is via df_pCK['name'][580].
PCK_ANCHOR_ROW = 580
pCK_masked_anchor = df_pCK['masked_seq'][PCK_ANCHOR_ROW]  ## anchor sequence
assert len(pCK_masked_WT) == 125 == len(pCK_masked_anchor)

PSG_full_WT = PSG_base_seq
pCK_full_WT = pCK_base_seq
pCK_full_anchor = df_pCK['full_seq'][PCK_ANCHOR_ROW]  ## anchor sequence

In [None]:
def hamming_distance(seq1, seq2):
    """Count the positions at which two equal-length sequences differ.

    Raises:
        ValueError: if ``seq1`` and ``seq2`` have different lengths.
    """
    if len(seq1) == len(seq2):
        return sum(1 for a, b in zip(seq1, seq2) if a != b)
    raise ValueError("Sequences must be of equal length to compute Hamming distance.")

# Mutation counts in the masked region relative to the WT (and, for pCK, the anchor).
df_PSG['n_mut'] = [hamming_distance(seq, PSG_masked_WT) for seq in df_PSG['masked_seq']]
df_pCK['n_mut'] = [hamming_distance(seq, pCK_masked_WT) for seq in df_pCK['masked_seq']]
df_pCK['n_mut_anchor'] = [hamming_distance(seq, pCK_masked_anchor) for seq in df_pCK['masked_seq']]

In [None]:
# Each system must hold exactly 1000 generations from each generator.
for df_sys in (df_PSG, df_pCK):
    for gen in ('esm3', 'evodiff'):
        assert df_sys['name'].str.contains(gen).sum() == 1000

# Distribution of mutation counts per generator: vs the WT for PSG and pCK, and
# vs the anchor sequence for pCK.
hist_panels = [
    (df_PSG, 'n_mut', 'PSG'),
    (df_pCK, 'n_mut', 'pCK'),
    (df_pCK, 'n_mut_anchor', 'pCK anchor'),
]

fig, ax = plt.subplots(1, 3, figsize=(10, 3))
for panel, (df_sys, col, title) in zip(ax, hist_panels):
    # One set of 20 bin edges per panel, shared by both generators, so the overlaid
    # bars line up (separate bins=20 calls give each series its own edges).
    edges = np.histogram_bin_edges(df_sys[col], bins=20)
    for gen in ('esm3', 'evodiff'):
        panel.hist(df_sys.loc[df_sys['name'].str.contains(gen), col], bins=edges, label=gen, alpha=0.5)
    panel.set_title(title)
    panel.set_xlabel(col)
    # Legend on every panel; a single plt.legend() only labels the last axes.
    panel.legend()

plt.show()

### Predicting zero-shot fitness values

#### PSG

In [None]:
# Both scoring models on the GPU: ESM C (600M weights) and ESM3.
esmc, esm3 = ESMCLM(name='esmc_600m', device='gpu'), ESM3LM(device='gpu')

In [None]:
# Zero-shot scores for every PSG generation, from ESM C and ESM3:
#   *_wt_marginal / *_masked_marginal : masked region (masked_seq vs PSG_masked_WT)
#   *_pll                             : first value returned by pseudolikelihood(masked_seq)
#   *_full_*                          : full-length sequence vs PSG_full_WT
# The A and C domains equal the WT (asserted in create_df), so every scorer must
# report the same mutation count as row['n_mut'].
y_pred = []
# total= gives tqdm a length (iterrows() has none) so it can show progress and ETA.
for _, row in tqdm(df_PSG.iterrows(), total=len(df_PSG)):
    masked_sequence = row['masked_seq']
    full_sequence = row['full_seq']
    expected_n_mut = row['n_mut']

    esmc_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(masked_sequence, PSG_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_score_masked_marginal, n_muts = esmc.get_masked_marginal(masked_sequence, PSG_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_score_pll = esmc.pseudolikelihood(masked_sequence)[0]

    esm3_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(masked_sequence, PSG_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_score_masked_marginal, n_muts = esm3.get_masked_marginal(masked_sequence, PSG_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_score_pll = esm3.pseudolikelihood(masked_sequence)[0]

    esmc_full_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(full_sequence, PSG_full_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_full_score_masked_marginal, n_muts = esmc.get_masked_marginal(full_sequence, PSG_full_WT)
    assert n_muts == expected_n_mut, row['name']

    esm3_full_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(full_sequence, PSG_full_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_full_score_masked_marginal, n_muts = esm3.get_masked_marginal(full_sequence, PSG_full_WT)
    assert n_muts == expected_n_mut, row['name']

    y_pred.append({
        'esmc_wt_marginal': esmc_score_wt_marginal,
        'esmc_masked_marginal': esmc_score_masked_marginal,
        'esmc_pll': esmc_score_pll,

        'esm3_wt_marginal': esm3_score_wt_marginal,
        'esm3_masked_marginal': esm3_score_masked_marginal,
        'esm3_pll': esm3_score_pll,

        'esmc_full_wt_marginal': esmc_full_score_wt_marginal,
        'esmc_full_masked_marginal': esmc_full_score_masked_marginal,

        'esm3_full_wt_marginal': esm3_full_score_wt_marginal,
        'esm3_full_masked_marginal': esm3_full_score_masked_marginal,
    })

In [None]:
df_pred = pd.DataFrame.from_records(y_pred)  # one row of scores per generation

In [None]:
df_pred.head(n=5)  # first five rows of PSG scores

In [None]:
# Attach the scores column-wise; both frames have a 0..N-1 RangeIndex built from lists.
# Any score columns already present are dropped first, so re-running the cell replaces
# them instead of appending duplicate columns.
assert len(df_pred) == len(df_PSG)
df_PSG = pd.concat([df_PSG.drop(columns=df_pred.columns, errors='ignore'), df_pred], axis=1)

In [None]:
# Persist sequences plus scores; the selection section reloads this file.
df_PSG.to_csv(path_or_buf=os.path.join(data_path, 'PSG_KG_results.csv'), index=False)

#### pCK against the given WT

In [None]:
# Same two models as in the PSG section, created again — presumably so this section
# can run on its own.
esmc, esm3 = ESMCLM(name='esmc_600m', device='gpu'), ESM3LM(device='gpu')

In [None]:
# Zero-shot scores for every pCK generation against the given pCK WT, from ESM C and ESM3:
#   *_wt_marginal / *_masked_marginal : masked region (masked_seq vs pCK_masked_WT)
#   *_pll                             : first value returned by pseudolikelihood(masked_seq)
#   *_full_*                          : full-length sequence vs pCK_full_WT
# The A and C domains equal the WT (asserted in create_df), so every scorer must
# report the same mutation count as row['n_mut'].
y_pred = []
# total= gives tqdm a length (iterrows() has none) so it can show progress and ETA.
for _, row in tqdm(df_pCK.iterrows(), total=len(df_pCK)):
    masked_sequence = row['masked_seq']
    full_sequence = row['full_seq']
    expected_n_mut = row['n_mut']

    esmc_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(masked_sequence, pCK_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_score_masked_marginal, n_muts = esmc.get_masked_marginal(masked_sequence, pCK_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_score_pll = esmc.pseudolikelihood(masked_sequence)[0]

    esm3_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(masked_sequence, pCK_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_score_masked_marginal, n_muts = esm3.get_masked_marginal(masked_sequence, pCK_masked_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_score_pll = esm3.pseudolikelihood(masked_sequence)[0]

    esmc_full_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(full_sequence, pCK_full_WT)
    assert n_muts == expected_n_mut, row['name']
    esmc_full_score_masked_marginal, n_muts = esmc.get_masked_marginal(full_sequence, pCK_full_WT)
    assert n_muts == expected_n_mut, row['name']

    esm3_full_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(full_sequence, pCK_full_WT)
    assert n_muts == expected_n_mut, row['name']
    esm3_full_score_masked_marginal, n_muts = esm3.get_masked_marginal(full_sequence, pCK_full_WT)
    assert n_muts == expected_n_mut, row['name']

    y_pred.append({
        'esmc_wt_marginal': esmc_score_wt_marginal,
        'esmc_masked_marginal': esmc_score_masked_marginal,
        'esmc_pll': esmc_score_pll,

        'esm3_wt_marginal': esm3_score_wt_marginal,
        'esm3_masked_marginal': esm3_score_masked_marginal,
        'esm3_pll': esm3_score_pll,

        'esmc_full_wt_marginal': esmc_full_score_wt_marginal,
        'esmc_full_masked_marginal': esmc_full_score_masked_marginal,

        'esm3_full_wt_marginal': esm3_full_score_wt_marginal,
        'esm3_full_masked_marginal': esm3_full_score_masked_marginal,
    })

In [None]:
df_pred = pd.DataFrame.from_records(y_pred)  # one row of base-WT scores per generation

In [None]:
df_pred.head(n=5)  # first five rows of pCK base-WT scores

In [None]:
# Attach the base-WT scores column-wise; both frames have a 0..N-1 RangeIndex built
# from lists. Existing columns with the same names are dropped first, so re-running
# the cell replaces them instead of appending duplicate columns.
assert len(df_pred) == len(df_pCK)
df_pCK = pd.concat([df_pCK.drop(columns=df_pred.columns, errors='ignore'), df_pred], axis=1)

In [None]:
df_pCK.head(n=5)  # pCK generations with base-WT scores attached

In [None]:
# Persist pCK sequences plus base-WT scores; reloaded in the selection section.
df_pCK.to_csv(path_or_buf=os.path.join(data_path, 'pCK_KG_base_results.csv'), index=False)

#### pCK against the anchor sequence

In [None]:
# Same two models again — presumably so the anchor section can run on its own.
esmc, esm3 = ESMCLM(name='esmc_600m', device='gpu'), ESM3LM(device='gpu')

In [None]:
# Same marginal scores as the base-WT loop, but with the anchor sequence
# (pCK_masked_anchor / pCK_full_anchor) as the reference. pseudolikelihood() takes no
# reference sequence, so no *_pll scores are computed here. Every scorer must report
# the anchor mutation count stored in row['n_mut_anchor'].
y_pred = []
# total= gives tqdm a length (iterrows() has none) so it can show progress and ETA.
for _, row in tqdm(df_pCK.iterrows(), total=len(df_pCK)):
    masked_sequence = row['masked_seq']
    full_sequence = row['full_seq']
    expected_n_mut = row['n_mut_anchor']

    esmc_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(masked_sequence, pCK_masked_anchor)
    assert n_muts == expected_n_mut, row['name']
    esmc_score_masked_marginal, n_muts = esmc.get_masked_marginal(masked_sequence, pCK_masked_anchor)
    assert n_muts == expected_n_mut, row['name']

    esm3_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(masked_sequence, pCK_masked_anchor)
    assert n_muts == expected_n_mut, row['name']
    esm3_score_masked_marginal, n_muts = esm3.get_masked_marginal(masked_sequence, pCK_masked_anchor)
    assert n_muts == expected_n_mut, row['name']

    esmc_full_score_wt_marginal, n_muts = esmc.get_wildtype_marginal(full_sequence, pCK_full_anchor)
    assert n_muts == expected_n_mut, row['name']
    esmc_full_score_masked_marginal, n_muts = esmc.get_masked_marginal(full_sequence, pCK_full_anchor)
    assert n_muts == expected_n_mut, row['name']

    esm3_full_score_wt_marginal, n_muts = esm3.get_wildtype_marginal(full_sequence, pCK_full_anchor)
    assert n_muts == expected_n_mut, row['name']
    esm3_full_score_masked_marginal, n_muts = esm3.get_masked_marginal(full_sequence, pCK_full_anchor)
    assert n_muts == expected_n_mut, row['name']

    y_pred.append({
        'esmc_wt_marginal': esmc_score_wt_marginal,
        'esmc_masked_marginal': esmc_score_masked_marginal,

        'esm3_wt_marginal': esm3_score_wt_marginal,
        'esm3_masked_marginal': esm3_score_masked_marginal,

        'esmc_full_wt_marginal': esmc_full_score_wt_marginal,
        'esmc_full_masked_marginal': esmc_full_score_masked_marginal,

        'esm3_full_wt_marginal': esm3_full_score_wt_marginal,
        'esm3_full_masked_marginal': esm3_full_score_masked_marginal,
    })

In [None]:
df_pred = pd.DataFrame.from_records(y_pred)  # one row of anchor scores per generation

In [None]:
df_pred.head(n=5)  # first five rows of pCK anchor scores

In [None]:
# When the base-WT cells above have run, df_pCK already holds base-WT scores under the
# same column names as df_pred. A plain concat duplicates those columns; after
# to_csv/read_csv the anchor copies come back renamed '<col>.1', so
# df_pCK_anchor['esmc_masked_marginal'] etc. in the selection section would silently be
# base-WT scores. Replace the overlapping columns with the anchor scores instead. The
# base-WT scores are already saved in pCK_KG_base_results.csv, and the *_pll columns
# (not recomputed for the anchor) are kept.
assert len(df_pred) == len(df_pCK)
df_pCK = pd.concat([df_pCK.drop(columns=df_pred.columns, errors='ignore'), df_pred], axis=1)

In [None]:
df_pCK.head(n=5)  # pCK generations after attaching the anchor scores

In [None]:
# Persist pCK sequences plus anchor scores; reloaded as df_pCK_anchor in the selection section.
df_pCK.to_csv(path_or_buf=os.path.join(data_path, 'pCK_KG_anchor_results.csv'), index=False)

### Selection

#### PSG selection

In [None]:
# Same data root as at the top of the notebook (env override, original path as default),
# so this section can run on a fresh kernel without re-scoring.
data_path = os.environ.get('DOMAIN_PREDICTION_DATA_PATH', '/nethome/kgeorge/workspace/DomainPrediction/Data/new_system')

df_PSG = pd.read_csv(os.path.join(data_path, 'PSG_KG_results.csv'))

In [None]:
df_PSG.head(n=5)  # reloaded PSG scores

In [None]:
# Masked-region vs full-length version of each marginal score, one figure per generator.
score_pairs = [
    ('esmc_masked_marginal', 'esmc_full_masked_marginal'),
    ('esmc_wt_marginal', 'esmc_full_wt_marginal'),
    ('esm3_masked_marginal', 'esm3_full_masked_marginal'),
    ('esm3_wt_marginal', 'esm3_full_wt_marginal'),
]

for generator in ('esm3', 'evodiff'):
    fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
    mask = df_PSG['name'].str.contains(generator)
    print(f'no of {generator} sequences: {mask.sum()}')
    for panel, (x_col, y_col) in zip(ax, score_pairs):
        panel.scatter(df_PSG[x_col][mask], df_PSG[y_col][mask], alpha=0.5)
        panel.set_xlabel(x_col)
        panel.set_ylabel(y_col)
    fig.suptitle(f'{generator} sequences')


In [None]:
# Same masked-vs-full comparison, keeping only generations with esmc_wt_marginal > -300.
val_mask = df_PSG['esmc_wt_marginal'] > -300

score_pairs = [
    ('esmc_masked_marginal', 'esmc_full_masked_marginal'),
    ('esmc_wt_marginal', 'esmc_full_wt_marginal'),
    ('esm3_masked_marginal', 'esm3_full_masked_marginal'),
    ('esm3_wt_marginal', 'esm3_full_wt_marginal'),
]

for generator in ('esm3', 'evodiff'):
    fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
    mask = df_PSG['name'].str.contains(generator) & val_mask
    print(f'no of {generator} sequences: {mask.sum()}')
    for panel, (x_col, y_col) in zip(ax, score_pairs):
        panel.scatter(df_PSG[x_col][mask], df_PSG[y_col][mask], alpha=0.5)
        panel.set_xlabel(x_col)
        panel.set_ylabel(y_col)
    fig.suptitle(f'{generator} sequences')

In [None]:
# Every other score against esmc_masked_marginal; evodiff drawn first, esm3 on top.
x_anchor = 'esmc_masked_marginal'
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
mask = df_PSG['name'].str.contains('esm3')

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    for label, sel in (('evodiff', ~mask), ('esm3', mask)):
        ax_.scatter(df_PSG[x_anchor][sel], df_PSG[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

In [None]:
# Predictor grid restricted to generations with esmc_wt_marginal > -300.
x_anchor = 'esmc_masked_marginal'
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
mask = df_PSG['name'].str.contains('esm3')
val_mask = df_PSG['esmc_wt_marginal'] > -300

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    for label, sel in (('evodiff', ~mask & val_mask), ('esm3', mask & val_mask)):
        ax_.scatter(df_PSG[x_anchor][sel], df_PSG[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

In [None]:
# Predictor grid for the selection window: esmc_wt_marginal > -300 and
# esmc_masked_marginal within [-10, 10]. esm3 is drawn first, evodiff on top.
x_anchor = 'esmc_masked_marginal'
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
mask = df_PSG['name'].str.contains('esm3')
val_mask = (df_PSG['esmc_wt_marginal'] > -300) & df_PSG['esmc_masked_marginal'].between(-10, 10)

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')

print(f'no of esm3 sequences: {(mask & val_mask).sum()}')
print(f'no of evodiff sequences: {(~mask & val_mask).sum()}')
layers = [
    (mask & val_mask, dict(alpha=0.4, label='esm3')),
    (~mask & val_mask, dict(alpha=0.7, label='evodiff')),
]
for ax_, y_col in zip(ax.flat, predictors):
    for sel, style in layers:
        ax_.scatter(df_PSG[x_anchor][sel], df_PSG[y_col][sel], **style)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

In [None]:
## Select ESM3 sequences by binning: keep ESM3 generations in the selection window,
## split esmc_masked_marginal over [-10, 10] into equal-width bins and take the
## highest-scoring sequence from each non-empty bin.

mask = df_PSG['name'].str.contains('esm3') & (df_PSG['esmc_wt_marginal'] > -300) & df_PSG['esmc_masked_marginal'].between(-10, 10)
# .copy(): the 'bin' column below must go into an independent frame, not a view of df_PSG.
df_filtered = df_PSG[mask].copy()

# Define the number of bins
num_bins = 6
# Generate the bin edges
bin_edges = np.linspace(-10, 10, num_bins + 1)

# Bin the esmc_masked_marginal column. include_lowest=True so a score of exactly -10
# (kept by the inclusive between() above) lands in bin 0 instead of NaN.
df_filtered['bin'] = pd.cut(df_filtered['esmc_masked_marginal'], bins=bin_edges, labels=False, include_lowest=True)

print(df_filtered['bin'].unique())

# Select the sequence with the highest esmc_masked_marginal in each bin (rows ordered by bin).
df_selected = df_filtered.loc[df_filtered.groupby('bin')['esmc_masked_marginal'].idxmax()]

# Pairwise Hamming distances among the picks. Index positionally: df_selected keeps
# df_PSG's row labels, and bins may be empty, so labels are not 0..n-1.
selected_seqs = df_selected['masked_seq'].to_list()
n_selected = len(selected_seqs)
dist_matrix = np.zeros((n_selected, n_selected))
for i in range(n_selected):
    for j in range(n_selected):
        dist_matrix[i, j] = helper.hamming_distance(selected_seqs[i], selected_seqs[j])

plt.imshow(dist_matrix, cmap='viridis')
plt.colorbar()
plt.title('pairwise Hamming distance (selected esm3)')

esm3_selected = df_selected['name'].to_list()

In [None]:
## Select EvoDiff sequences by binning: keep EvoDiff generations in the selection
## window, split esmc_wt_marginal into equal-width bins and take, from each non-empty
## bin, the sequence with the highest esmc_masked_marginal.

mask = df_PSG['name'].str.contains('evodiff') & (df_PSG['esmc_wt_marginal'] > -300) & df_PSG['esmc_masked_marginal'].between(-10, 10)

# .copy(): the 'bin' column below must go into an independent frame, not a view of df_PSG.
df_filtered = df_PSG[mask].copy()

# Define the number of bins
num_bins = 4
# Edges padded by 1 on both sides so the minimum and maximum scores fall inside a bin.
bin_edges = np.linspace(df_filtered['esmc_wt_marginal'].min()-1, df_filtered['esmc_wt_marginal'].max()+1, num_bins + 1)

# Bin the esmc_wt_marginal column (the per-bin pick below uses esmc_masked_marginal).
df_filtered['bin'] = pd.cut(df_filtered['esmc_wt_marginal'], bins=bin_edges, labels=False)

print(df_filtered['bin'].unique())

# Select the sequence with the highest esmc_masked_marginal in each bin (rows ordered by bin).
df_selected = df_filtered.loc[df_filtered.groupby('bin')['esmc_masked_marginal'].idxmax()]

# Pairwise Hamming distances among the picks. Index positionally: df_selected keeps
# df_PSG's row labels, and bins may be empty, so labels are not 0..n-1.
selected_seqs = df_selected['masked_seq'].to_list()
n_selected = len(selected_seqs)
dist_matrix = np.zeros((n_selected, n_selected))
for i in range(n_selected):
    for j in range(n_selected):
        dist_matrix[i, j] = helper.hamming_distance(selected_seqs[i], selected_seqs[j])

plt.imshow(dist_matrix, cmap='viridis')
plt.colorbar()
plt.title('pairwise Hamming distance (selected evodiff)')

evodiff_selected = df_selected['name'].to_list()

In [None]:
# Selection-window predictor grid with the picked esm3 (red) and evodiff (green)
# sequences highlighted on top.
x_anchor = 'esmc_masked_marginal'
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
mask = df_PSG['name'].str.contains('esm3')
val_mask = (df_PSG['esmc_wt_marginal'] > -300) & df_PSG['esmc_masked_marginal'].between(-10, 10)

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')

print(f'no of esm3 sequences: {(mask & val_mask).sum()}')
print(f'no of evodiff sequences: {(~mask & val_mask).sum()}')
layers = [
    (mask & val_mask, dict(alpha=0.4, label='esm3')),
    (~mask & val_mask, dict(alpha=0.7, label='evodiff')),
    (df_PSG['name'].isin(esm3_selected), dict(alpha=0.7, label='selected_esm3', color='red')),
    (df_PSG['name'].isin(evodiff_selected), dict(alpha=0.7, label='selected_evodiff', color='green')),
]
for ax_, y_col in zip(ax.flat, predictors):
    for sel, style in layers:
        ax_.scatter(df_PSG[x_anchor][sel], df_PSG[y_col][sel], **style)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

In [None]:
# All rows of the picked esm3 and evodiff sequences, in df_PSG order.
df_selected = df_PSG.loc[df_PSG['name'].isin([*esm3_selected, *evodiff_selected])]

In [None]:
# Display the final PSG picks (selected esm3 + evodiff rows with all their scores).
df_selected

In [None]:
# Export of the selected PSG sequences to two FASTA files (full-length and
# masked-region only), currently disabled.
# NOTE(review): helper.create_fasta is called with append=True, so re-running this cell
# would presumably append duplicate records to existing files — clear them first.
# fasta_full = os.path.join(data_path, 'PSG_KG_selected_full.fasta')
# fasta_masked = os.path.join(data_path, 'PSG_KG_selected_masked.fasta')
# for i, row in df_selected.iterrows():

#     helper.create_fasta({
#         row['name']: row['full_seq']
#     }, file=fasta_full, append=True)

#     helper.create_fasta({
#         row['name']: row['masked_seq']
#     }, file=fasta_masked, append=True)

#### pCK selection

In [None]:
# Same data root as at the top of the notebook (env override, original path as default),
# so this section can run on a fresh kernel without re-scoring.
data_path = os.environ.get('DOMAIN_PREDICTION_DATA_PATH', '/nethome/kgeorge/workspace/DomainPrediction/Data/new_system')

# pCK scores against the given WT ("base") and against the anchor sequence.
df_pCK = pd.read_csv(os.path.join(data_path, 'pCK_KG_base_results.csv'))
df_pCK_anchor = pd.read_csv(os.path.join(data_path, 'pCK_KG_anchor_results.csv'))

# Anchor files written by concatenating anchor scores onto an already-scored frame
# contain duplicate score columns; read_csv renames the second copy '<col>.1'. In that
# case the unsuffixed columns hold base-WT scores, so promote the '.1' (anchor) copies.
_dup_anchor_cols = [c for c in df_pCK_anchor.columns if c.endswith('.1') and c[:-2] in df_pCK_anchor.columns]
if _dup_anchor_cols:
    df_pCK_anchor = (
        df_pCK_anchor
        .drop(columns=[c[:-2] for c in _dup_anchor_cols])
        .rename(columns={c: c[:-2] for c in _dup_anchor_cols})
    )

In [None]:
df_pCK.head(n=5)  # reloaded pCK base-WT scores

In [None]:
df_pCK_anchor.head(n=5)  # reloaded pCK anchor scores

In [None]:
# Masked-region vs full-length version of each marginal score for pCK, one figure per
# generator: first with the given WT as reference ("base"), then with the anchor.
score_pairs = [
    ('esmc_masked_marginal', 'esmc_full_masked_marginal'),
    ('esmc_wt_marginal', 'esmc_full_wt_marginal'),
    ('esm3_masked_marginal', 'esm3_full_masked_marginal'),
    ('esm3_wt_marginal', 'esm3_full_wt_marginal'),
]

for frame, reference in ((df_pCK, 'base'), (df_pCK_anchor, 'anchor')):
    for generator in ('esm3', 'evodiff'):
        fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
        mask = frame['name'].str.contains(generator)
        print(f'no of {generator} sequences: {mask.sum()}')
        for panel, (x_col, y_col) in zip(ax, score_pairs):
            panel.scatter(frame[x_col][mask], frame[y_col][mask], alpha=0.5)
            panel.set_xlabel(x_col)
            panel.set_ylabel(y_col)
        fig.suptitle(f'{generator} sequences - {reference}')


In [None]:
# Same comparison restricted to generations with base-WT esmc_wt_marginal > -300; the
# same boolean filter is applied to the anchor frame (aligned on the shared row index).
val_mask = df_pCK['esmc_wt_marginal'] > -300

score_pairs = [
    ('esmc_masked_marginal', 'esmc_full_masked_marginal'),
    ('esmc_wt_marginal', 'esmc_full_wt_marginal'),
    ('esm3_masked_marginal', 'esm3_full_masked_marginal'),
    ('esm3_wt_marginal', 'esm3_full_wt_marginal'),
]

for frame, reference in ((df_pCK, 'base'), (df_pCK_anchor, 'anchor')):
    for generator in ('esm3', 'evodiff'):
        fig, ax = plt.subplots(1, 4, figsize=(13, 3), layout='constrained')
        mask = frame['name'].str.contains(generator) & val_mask
        print(f'no of {generator} sequences: {mask.sum()}')
        for panel, (x_col, y_col) in zip(ax, score_pairs):
            panel.scatter(frame[x_col][mask], frame[y_col][mask], alpha=0.5)
            panel.set_xlabel(x_col)
            panel.set_ylabel(y_col)
        fig.suptitle(f'{generator} sequences - {reference}')


In [None]:
# Every other score against esmc_masked_marginal for pCK; evodiff drawn first, esm3 on top.
x_anchor = 'esmc_masked_marginal'
mask = df_pCK['name'].str.contains('esm3')

# Base-WT scores: nine predictors in a 3x3 grid.
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    for label, sel in (('evodiff', ~mask), ('esm3', mask)):
        ax_.scatter(df_pCK[x_anchor][sel], df_pCK[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

# Anchor scores: seven predictors (no pll panels) in a 2x4 grid; the last slot stays blank.
predictors = [
    'esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot',
]
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    if y_col == 'no_plot':
        ax_.axis('off')
        continue
    for label, sel in (('evodiff', ~mask), ('esm3', mask)):
        ax_.scatter(df_pCK_anchor[x_anchor][sel], df_pCK_anchor[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

In [None]:
# Predictor grids restricted to generations with base-WT esmc_wt_marginal > -300
# (the same row filter is used for the anchor frame).
x_anchor = 'esmc_masked_marginal'
mask = df_pCK['name'].str.contains('esm3')
val_mask = df_pCK['esmc_wt_marginal'] > -300
groups = (('evodiff', ~mask & val_mask), ('esm3', mask & val_mask))

# Base-WT scores: nine predictors in a 3x3 grid.
predictors = [
    'esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal',
    'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal',
]
fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    for label, sel in groups:
        ax_.scatter(df_pCK[x_anchor][sel], df_pCK[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

# Anchor scores: seven predictors (no pll panels) in a 2x4 grid; the last slot stays blank.
predictors = [
    'esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal',
    'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot',
]
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
for ax_, y_col in zip(ax.flat, predictors):
    if y_col == 'no_plot':
        ax_.axis('off')
        continue
    for label, sel in groups:
        ax_.scatter(df_pCK_anchor[x_anchor][sel], df_pCK_anchor[y_col][sel], alpha=0.5, label=label)
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(y_col)
    ax_.legend()

##### EvoDiff

In [None]:
# evodiff designs that pass the score filters:
#   esmc_wt_marginal > -300, esmc_masked_marginal in [-10, 10] (df_pCK),
#   and anchor esmc_masked_marginal > -50 (df_pCK_anchor).
# NOTE(review): the df_pCK and df_pCK_anchor conditions are combined by index
# alignment — assumes both frames share the same row index; confirm.
mask = df_pCK['name'].str.contains('evodiff')
val_mask = (df_pCK['esmc_wt_marginal'] > -300) & df_pCK['esmc_masked_marginal'].between(-10, 10) & (df_pCK_anchor['esmc_masked_marginal'] > -50)
evodiff_pts = mask & val_mask

print(f'no of evodiff sequences: {(mask & val_mask).sum()}')

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
fig.suptitle('evodiff sequences - filtered')

predictors = ['esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    ax_.scatter(df_pCK[x_anchor][evodiff_pts], df_pCK[predictor][evodiff_pts], alpha=0.5, label='evodiff')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()


# Anchor scores for the same filtered rows; the 8th panel is left blank.
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
fig.suptitle('evodiff sequences - filtered - anchor')

predictors = ['esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    if predictor == 'no_plot':
        ax_.axis('off')
        continue
    ax_.scatter(df_pCK_anchor[x_anchor][evodiff_pts], df_pCK_anchor[predictor][evodiff_pts], alpha=0.5, label='evodiff')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()

In [None]:
# Pick a small, diverse set of evodiff designs: bin the filtered designs by
# esmc_wt_marginal and keep the best esmc_masked_marginal scorer in each bin.
# NOTE(review): the df_pCK and df_pCK_anchor conditions are combined by index
# alignment — assumes both frames share the same row index; confirm.
mask = (
    df_pCK['name'].str.contains('evodiff')
    & (df_pCK['esmc_wt_marginal'] > -300)
    & df_pCK['esmc_masked_marginal'].between(-10, 10)
    & (df_pCK_anchor['esmc_masked_marginal'] > -50)
)

# .copy() so that adding the 'bin' column below does not write into a slice of
# df_pCK (avoids SettingWithCopyWarning).
df_filtered = df_pCK[mask].copy()
assert not df_filtered.empty, 'no evodiff sequences pass the filters'

# Define the number of bins
num_bins = 4
# Equal-width bin edges over the esmc_wt_marginal range; the +/-1 padding keeps
# the min and max values strictly inside the outer edges.
bin_edges = np.linspace(df_filtered['esmc_wt_marginal'].min()-1, df_filtered['esmc_wt_marginal'].max()+1, num_bins + 1)

# Bin the esmc_wt_marginal column (integer bin codes; empty bins simply do not appear)
df_filtered['bin'] = pd.cut(df_filtered['esmc_wt_marginal'], bins=bin_edges, labels=False)

print(df_filtered['bin'].unique())

# Select the sequence with the highest esmc_masked_marginal in each bin
# (groupby(...).idxmax() returns the row label of each bin's maximum, ordered by bin).
df_selected = df_filtered.loc[df_filtered.groupby('bin')['esmc_masked_marginal'].idxmax()]

# Pairwise distances between the selected masked_seq strings (the positions
# outside the A and C annotations). Index positionally via a list: the frame's
# index labels are not 0..n-1, so label lookups by i would be wrong.
selected_seqs = df_selected['masked_seq'].to_list()
n_selected = len(selected_seqs)
dist_matrix = np.zeros((n_selected, n_selected))
for i in range(n_selected):
    for j in range(n_selected):
        dist_matrix[i, j] = helper.hamming_distance(selected_seqs[i], selected_seqs[j])

fig, ax = plt.subplots(layout='constrained')
im = ax.imshow(dist_matrix, cmap='viridis')
fig.colorbar(im, ax=ax, label='hamming_distance (masked_seq)')
ax.set_title('evodiff selected - pairwise distance')

evodiff_selected = df_selected['name'].to_list()

In [None]:
# Filtered evodiff designs with the per-bin picks (evodiff_selected) overlaid in green.
# NOTE(review): the df_pCK and df_pCK_anchor conditions are combined by index
# alignment — assumes both frames share the same row index; confirm.
mask = df_pCK['name'].str.contains('evodiff')
val_mask = (df_pCK['esmc_wt_marginal'] > -300) & df_pCK['esmc_masked_marginal'].between(-10, 10) & (df_pCK_anchor['esmc_masked_marginal'] > -50)
evodiff_pts = mask & val_mask

print(f'no of evodiff sequences: {(mask & val_mask).sum()}')

# Subsets for the highlighted points, computed once rather than twice per panel.
evodiff_picked = df_pCK[df_pCK['name'].isin(evodiff_selected)]
evodiff_picked_anchor = df_pCK_anchor[df_pCK_anchor['name'].isin(evodiff_selected)]

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
fig.suptitle('evodiff sequences - selected')

predictors = ['esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    ax_.scatter(df_pCK[x_anchor][evodiff_pts], df_pCK[predictor][evodiff_pts], alpha=0.5, label='evodiff')
    ax_.scatter(evodiff_picked[x_anchor], evodiff_picked[predictor], alpha=0.7, label='selected_evodiff', color='green')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()


# Anchor scores for the same rows; the 8th panel is left blank.
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
fig.suptitle('evodiff sequences - selected - anchor')

predictors = ['esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    if predictor == 'no_plot':
        ax_.axis('off')
        continue
    ax_.scatter(df_pCK_anchor[x_anchor][evodiff_pts], df_pCK_anchor[predictor][evodiff_pts], alpha=0.5, label='evodiff')
    ax_.scatter(evodiff_picked_anchor[x_anchor], evodiff_picked_anchor[predictor], alpha=0.7, label='selected_evodiff', color='green')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()

##### ESM3

In [None]:
# esm3 designs that pass the score filters:
#   esmc_wt_marginal > -300 (df_pCK) and anchor esmc_masked_marginal in [-10, 10]
#   (df_pCK_anchor).
# NOTE(review): the two frames' conditions are combined by index alignment —
# assumes a shared row index; confirm.
mask = df_pCK['name'].str.contains('esm3')
val_mask = (df_pCK['esmc_wt_marginal'] > -300) & df_pCK_anchor['esmc_masked_marginal'].between(-10, 10)
esm3_pts = mask & val_mask

print(f'no of esm3 sequences: {(mask & val_mask).sum()}')

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
fig.suptitle('esm3 sequences - filtered')

predictors = ['esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    ax_.scatter(df_pCK[x_anchor][esm3_pts], df_pCK[predictor][esm3_pts], alpha=0.5, label='esm3')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()


# Anchor scores for the same filtered rows; the 8th panel is left blank.
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
fig.suptitle('esm3 sequences - filtered - anchor')

predictors = ['esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    if predictor == 'no_plot':
        ax_.axis('off')
        continue
    ax_.scatter(df_pCK_anchor[x_anchor][esm3_pts], df_pCK_anchor[predictor][esm3_pts], alpha=0.5, label='esm3')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()

In [None]:
# Pick a small, diverse set of esm3 designs: bin the filtered designs by
# esmc_masked_marginal and keep the top esmc_masked_marginal scorer in each bin.
# NOTE(review): the df_pCK and df_pCK_anchor conditions are combined by index
# alignment — assumes both frames share the same row index; confirm.
mask = (
    df_pCK['name'].str.contains('esm3')
    & (df_pCK['esmc_wt_marginal'] > -300)
    & df_pCK_anchor['esmc_masked_marginal'].between(-10, 10)
)

# .copy() so that adding the 'bin' column below does not write into a slice of
# df_pCK (avoids SettingWithCopyWarning).
df_filtered = df_pCK[mask].copy()
assert not df_filtered.empty, 'no esm3 sequences pass the filters'

# Define the number of bins
num_bins = 6
# Equal-width bin edges over the esmc_masked_marginal range; the +/-1 padding
# keeps the min and max values strictly inside the outer edges.
bin_edges = np.linspace(df_filtered['esmc_masked_marginal'].min()-1, df_filtered['esmc_masked_marginal'].max()+1, num_bins + 1)

# Bin the esmc_masked_marginal column (integer bin codes; empty bins simply do not appear)
df_filtered['bin'] = pd.cut(df_filtered['esmc_masked_marginal'], bins=bin_edges, labels=False)

print(df_filtered['bin'].unique())

# Select the sequence with the highest esmc_masked_marginal in each bin
# (groupby(...).idxmax() returns the row label of each bin's maximum, ordered by bin).
df_selected = df_filtered.loc[df_filtered.groupby('bin')['esmc_masked_marginal'].idxmax()]

# Pairwise distances between the selected masked_seq strings (the positions
# outside the A and C annotations). Index positionally via a list: the frame's
# index labels are not 0..n-1, so label lookups by i would be wrong.
selected_seqs = df_selected['masked_seq'].to_list()
n_selected = len(selected_seqs)
dist_matrix = np.zeros((n_selected, n_selected))
for i in range(n_selected):
    for j in range(n_selected):
        dist_matrix[i, j] = helper.hamming_distance(selected_seqs[i], selected_seqs[j])

fig, ax = plt.subplots(layout='constrained')
im = ax.imshow(dist_matrix, cmap='viridis')
fig.colorbar(im, ax=ax, label='hamming_distance (masked_seq)')
ax.set_title('esm3 selected - pairwise distance')

esm3_selected = df_selected['name'].to_list()

In [None]:
# Filtered esm3 designs with the per-bin picks (esm3_selected) overlaid in red.
# NOTE(review): the df_pCK and df_pCK_anchor conditions are combined by index
# alignment — assumes both frames share the same row index; confirm.
mask = df_pCK['name'].str.contains('esm3')
val_mask = (df_pCK['esmc_wt_marginal'] > -300) & df_pCK_anchor['esmc_masked_marginal'].between(-10, 10)
esm3_pts = mask & val_mask

print(f'no of esm3 sequences: {(mask & val_mask).sum()}')

# Subsets for the highlighted points, computed once rather than twice per panel.
esm3_picked = df_pCK[df_pCK['name'].isin(esm3_selected)]
esm3_picked_anchor = df_pCK_anchor[df_pCK_anchor['name'].isin(esm3_selected)]

fig, ax = plt.subplots(3, 3, figsize=(11, 9), layout='constrained')
fig.suptitle('esm3 sequences - selected')

predictors = ['esmc_wt_marginal', 'esmc_pll', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esm3_pll', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    ax_.scatter(df_pCK[x_anchor][esm3_pts], df_pCK[predictor][esm3_pts], alpha=0.5, label='esm3')
    ax_.scatter(esm3_picked[x_anchor], esm3_picked[predictor], alpha=0.7, label='selected_esm3', color='red')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()


# Anchor scores for the same rows; the 8th panel is left blank.
fig, ax = plt.subplots(2, 4, figsize=(13, 6), layout='constrained')
fig.suptitle('esm3 sequences - selected - anchor')

predictors = ['esmc_wt_marginal', 'esm3_wt_marginal', 'esm3_masked_marginal', 'esmc_full_wt_marginal', 'esmc_full_masked_marginal', 'esm3_full_wt_marginal', 'esm3_full_masked_marginal', 'no_plot']
x_anchor = 'esmc_masked_marginal'

for ax_, predictor in zip(ax.flatten(), predictors):
    if predictor == 'no_plot':
        ax_.axis('off')
        continue
    ax_.scatter(df_pCK_anchor[x_anchor][esm3_pts], df_pCK_anchor[predictor][esm3_pts], alpha=0.5, label='esm3')
    ax_.scatter(esm3_picked_anchor[x_anchor], esm3_picked_anchor[predictor], alpha=0.7, label='selected_esm3', color='red')
    ax_.set_xlabel(x_anchor)
    ax_.set_ylabel(predictor)
    ax_.legend()

##### Combined selections

In [None]:
# Rows of df_pCK picked by either the evodiff or the esm3 selection (kept in df_pCK row order).
df_selected = df_pCK.loc[df_pCK['name'].isin(set(evodiff_selected) | set(esm3_selected))]

In [None]:
# Inspect the combined evodiff + esm3 selection.
df_selected

In [None]:
# Export the combined selection to two FASTA files: one with full_seq and one
# with masked_seq per record. Disabled; uncomment to write.
# NOTE(review): append=True means re-running this cell would presumably append
# the same records again — remove existing output files first; confirm
# helper.create_fasta semantics.
# fasta_full = os.path.join(data_path, 'pCK_KG_selected_full.fasta')
# fasta_masked = os.path.join(data_path, 'pCK_KG_selected_masked.fasta')
# for i, row in df_selected.iterrows():

#     helper.create_fasta({
#         row['name']: row['full_seq']
#     }, file=fasta_full, append=True)

#     helper.create_fasta({
#         row['name']: row['masked_seq']
#     }, file=fasta_masked, append=True)