In [None]:
import sys
sys.path.append('../..')

In [None]:
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from DomainPrediction.utils import helper

In [None]:
# Directory holding the Tdomain input files read below (relative path, so it is
# resolved against the kernel's current working directory).
data_path = '../../../Data/al_test_experiments/Tdomain'

### Load Data for processing

In [None]:
# Load Tdomainr2.csv; later cells read its 'Name', 'Sequence' and 'norm_WT' columns.
df = pd.read_csv(os.path.join(data_path, 'Tdomainr2.csv'))

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# (rows, columns) of the table as loaded, before duplicate sequences are dropped.
df.shape

In [None]:
# Wild-type reference: the first entry of Tdomain_WT.fasta.
# NOTE(review): presumably mode='str' makes read_fasta return plain sequence strings — confirm in DomainPrediction.utils.helper.
wt = helper.read_fasta(os.path.join(data_path, 'Tdomain_WT.fasta'), mode='str')[0]

In [None]:
# Length of the wild-type sequence; hamming_distance(wt, seq) below raises unless every variant has this length.
len(wt)

In [None]:
## Sanity check: each sequence in round_1_extraction.fasta must equal the
## 'Sequence' entry of `df` stored under the same integer label.
check_seqs = helper.read_fasta(os.path.join(data_path, 'round_1_extraction.fasta'), mode='str')
for i, seq in enumerate(check_seqs):
    table_seq = df['Sequence'][i]
    assert table_seq == seq

In [None]:
def hamming_distance(seq1, seq2):
    """Return the number of positions at which two equal-length sequences differ.

    Raises:
        ValueError: if `seq1` and `seq2` do not have the same length.
    """
    if len(seq1) == len(seq2):
        # Walk both sequences in lockstep and count the mismatching positions.
        return sum(1 for a, b in zip(seq1, seq2) if a != b)

    raise ValueError("Sequences must be of equal length to compute Hamming distance.")

In [None]:
# Pairwise Hamming distances between all sequences in `df`, as a square float
# matrix whose rows and columns follow the row order of the table.
sequences = df['Sequence'].to_numpy()
dist_matrix = np.array(
    [hamming_distance(seq_a, seq_b) for seq_a in sequences for seq_b in sequences],
    dtype=float,
).reshape(sequences.shape[0], sequences.shape[0])

In [None]:
# First row of the matrix: distances from the first sequence in the table to every sequence.
dist_matrix[0,:]

In [None]:
# Heatmap of the pairwise distance matrix (brighter = more differing positions).
heat_fig, heat_ax = plt.subplots(figsize=(10, 10))
heat_img = heat_ax.imshow(dist_matrix, cmap='hot')
heat_fig.colorbar(heat_img, ax=heat_ax, shrink=0.75)
plt.show()

In [None]:
# Hamming distance of every sequence in the table to the wild type.
dist_from_wt = [hamming_distance(wt, seq) for seq in df['Sequence']]

# Histogram with 20 bins of width 5 spanning [0, 100]; distances above 100
# would fall outside the bins and not be drawn.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.hist(dist_from_wt, bins=np.linspace(0, 100, 21))
plt.show()
# print(np.array(dist_from_wt))
# print(df['Name'].to_numpy())

#### Duplicates

In [None]:
# Print the name of every row whose sequence occurs more than once in the table.
for name in df['Name']:
    # Sequence of the first row carrying this name.
    ref_seq = df.loc[df['Name'] == name, 'Sequence'].iloc[0]
    n_copies = (df['Sequence'] == ref_seq).sum()
    if n_copies > 1:
        print(name)

In [None]:
# All rows sharing the sequence of the entry named 'WT' (`.iloc[0]` raises IndexError if no row has that name).
df[df['Sequence'] == df.loc[df['Name'] == 'WT', 'Sequence'].iloc[0]]

In [None]:
# All rows sharing the sequence of the entry named 'ESM2'.
df[df['Sequence'] == df.loc[df['Name'] == 'ESM2', 'Sequence'].iloc[0]]

In [None]:
# All rows sharing the sequence of the entry named 'ESM2_EP_10_1_4'.
df[df['Sequence'] == df.loc[df['Name'] == 'ESM2_EP_10_1_4', 'Sequence'].iloc[0]]

In [None]:
# Keep one row per distinct sequence. drop_duplicates keeps the first occurrence
# by default, so the names and measurements of later duplicates are discarded.
# The index is not reset, so index labels may have gaps from here on.
df = df.drop_duplicates(subset='Sequence')

In [None]:
# (rows, columns) after removing duplicate sequences.
df.shape

In [None]:
# Repeat the duplicate scan on the de-duplicated table; a name is printed only
# if its sequence still appears in more than one row.
for name in df['Name']:
    first_seq = df.loc[df['Name'] == name, 'Sequence'].iloc[0]
    if df['Sequence'].eq(first_seq).sum() > 1:
        print(name)

#### Splitting

In [None]:
# Recompute the pairwise Hamming distance matrix on the de-duplicated table.
sequences = df['Sequence'].to_numpy()
n_rows = sequences.shape[0]
dist_matrix = np.array(
    [[hamming_distance(seq_a, seq_b) for seq_b in sequences] for seq_a in sequences],
    dtype=float,
).reshape(n_rows, n_rows)

In [None]:
# Heatmap of the current pairwise distance matrix.
heat_fig, heat_ax = plt.subplots(figsize=(10, 10))
heat_img = heat_ax.imshow(dist_matrix, cmap='hot')
heat_fig.colorbar(heat_img, ax=heat_ax, shrink=0.75)
plt.show()

In [None]:
# Total rows, then how many names contain "ESM2", "esm3" and "evodiff".
# str.contains matches anywhere in the name, not only at the start.
df.shape[0], df['Name'].str.contains("ESM2").sum(), df['Name'].str.contains("esm3").sum(), df['Name'].str.contains("evodiff").sum() #need 13 more to set up a dataset

In [None]:
# Names of the reference and generated designs whose mutual distances are plotted.
ai_gen = ['WT', 'ESM1', 'ESM2', 'ESM3', 'ESM4', 'IN1', 'IN2', 'IN3', 'IN4',
            'IN5', 'IN6', 'IN7', 'esm3_gen_17', 'esm3_str_gen_53',
            'esm3_str_gen_170', 'esm3_gen_192', 'esm3_gen_333',
            'esm3_str_gen_365', 'esm3_gen_385', 'evodiff_gen_77',
            'II_guidance-ESM3', 'Test-IN1-improved', 'Test-IN4-improved',
            'Test-IN7-improved', 'evodiff_gen_21', 'evodiff_gen_88',
            'evodiff_gen_155', 'evodiff_gen_357', 'esm3_gen_93',
            'esm3_str_gen_314', 'esm3_str_gen_323', 'II guidance-WT',
            'II guidance-ESM2', 'II guidance-IN2', 'prob_guidancd-ESM2',
            'Test-IN2-improved', 'Test-IN5-improved']

mask = df['Name'].isin(ai_gen)

# Take the sequences AND their names from the same filtered rows. The rows come
# back in table order, which need not be the order of `ai_gen`, and a listed
# name can be absent from `df` (e.g. removed as a duplicate sequence). Labelling
# the axes with `ai_gen` itself would then mislabel or misalign the ticks.
sequences = df.loc[mask, 'Sequence'].to_numpy()
tick_labels = df.loc[mask, 'Name'].to_numpy()

# Pairwise Hamming distances between the selected sequences.
dist_matrix = np.zeros((len(sequences), len(sequences)))
for i in range(sequences.shape[0]):
    for j in range(sequences.shape[0]):
        dist_matrix[i, j] = hamming_distance(sequences[i], sequences[j])

plt.figure(figsize=(10,10))
plt.imshow(dist_matrix, cmap='hot')
plt.colorbar(shrink=0.75)
plt.xticks(np.arange(len(tick_labels)), tick_labels, size=7, rotation=90)
plt.yticks(np.arange(len(tick_labels)), tick_labels, size=7)
plt.show()

In [None]:
# Total rows, then how many names contain "ESM2", "esm3" and "evodiff"; these
# are the same patterns the split assignment below keys on.
df.shape[0], df['Name'].str.contains("ESM2").sum(), df['Name'].str.contains("esm3").sum(), df['Name'].str.contains("evodiff").sum() #need 13 more to set up a dataset

In [None]:
# Split ids: 2 = train (default), 1 = validation, 0 = test.
# DataFrame.assign returns a NEW frame, so its result must be bound back to
# `df`. Otherwise the column is only created by the `.loc` writes below, every
# remaining row stays NaN and the training split ends up empty.
df = df.assign(split_id=2)
# Test split: names containing "ESM2", "esm3_str" or "evodiff".
df.loc[df['Name'].str.contains("ESM2"), 'split_id'] = 0
df.loc[df['Name'].str.contains("esm3_str"), 'split_id'] = 0
df.loc[df['Name'].str.contains("evodiff"), 'split_id'] = 0
# Validation split: names containing "esm3_gen".
df.loc[df['Name'].str.contains("esm3_gen"), 'split_id'] = 1

train_size = df[df['split_id']==2].shape[0]
val_size = df[df['split_id']==1].shape[0]
test_size = df[df['split_id']==0].shape[0]
# Every row must belong to exactly one of the three splits.
assert train_size + val_size + test_size == df.shape[0], "split_id does not partition the table"
print(f"train: {train_size}, val: {val_size}, test: {test_size}")

In [None]:
# Distance matrix with the rows/columns reordered: split 2 sequences first,
# followed by the sequences of splits 0 and 1.
in_train = df['split_id'] == 2
in_heldout = df['split_id'].isin([0, 1])
sequences = np.concatenate((df.loc[in_train, 'Sequence'].to_numpy(), df.loc[in_heldout, 'Sequence'].to_numpy()))

dist_matrix = np.eye(df['Sequence'].shape[0])
for i, seq_i in enumerate(sequences):
    for j, seq_j in enumerate(sequences):
        dist_matrix[i, j] = hamming_distance(seq_i, seq_j)

plt.figure(figsize=(10,10))
plt.imshow(dist_matrix, cmap='hot')
plt.colorbar(shrink=0.75)
# One tick per column: 2 for the leading split-2 block, 0 for the remaining columns.
split_ticks = [2]*in_train.sum() + [0]*in_heldout.sum()
plt.xticks(np.arange(len(df)), split_ticks, size=7, rotation=0)
plt.show()

#### creating new df

In [None]:
# New frame holding only the columns needed downstream, under their dataset names.
df_new = df[['Name', 'Sequence', 'norm_WT', 'split_id']].rename(
    columns={'Name': 'name', 'Sequence': 'seq', 'norm_WT': 'fitness_raw'}
)

In [None]:
# Preview the renamed dataset columns.
df_new.head()

In [None]:
# Number of positions at which each variant differs from the wild type.
n_muts = []
for variant_seq in df_new['seq']:
    # A position-wise comparison is only meaningful for equal-length sequences.
    assert len(variant_seq) == len(wt)
    n_muts.append(sum(1 for a, b in zip(variant_seq, wt) if a != b))

df_new['n_mut'] = n_muts

In [None]:
# Preview the dataset with the new 'n_mut' column.
df_new.head()

In [None]:
# Distinct mutation counts present in the dataset (in order of first appearance).
df_new['n_mut'].unique()

In [None]:
# Mutation count of every row, in table order.
df_new['n_mut'].to_numpy()

In [None]:
# Pairwise Hamming distances between all sequences of the new dataset.
sequences = df_new['seq'].to_numpy()
dist_matrix = np.array(
    [[hamming_distance(seq_a, seq_b) for seq_b in sequences] for seq_a in sequences],
    dtype=float,
).reshape(sequences.shape[0], sequences.shape[0])

# Row 0 holds the distances to the first sequence of the table; this passes only
# if they coincide with the distances to `wt` stored in 'n_mut'.
assert np.array_equal(df_new['n_mut'].to_numpy(), dist_matrix[0])

In [None]:
from collections import Counter

In [None]:
# For each split id: the distinct mutation counts it contains and how often each occurs.
for split in df_new['split_id'].unique():
    split_n_mut = df_new.loc[df_new['split_id']==split, 'n_mut']
    temp = split_n_mut.unique()
    print(f'n mutation for split {split}: {temp}')
    print(Counter(split_n_mut))

In [None]:
# Copy of the raw fitness values with exact zeros replaced by 1e-6, so that the
# logarithm taken afterwards is finite for them.
temp = [1e-6 if x == 0 else x for x in df_new['fitness_raw']]

In [None]:
# Natural log of the zero-clamped fitness values held in `temp`.
# NOTE(review): negative fitness values, if any exist, would become NaN here — confirm the data is non-negative.
df_new['fitness_log'] = np.log(temp)

In [None]:
# file = os.path.join(data_path, 'dataset_2_tdomain.csv')
# df_new.to_csv(file, index=False)

### Load data for analysis

In [None]:
# Reload the processed dataset from disk, replacing the in-memory `df_new`.
# NOTE(review): the only cell that writes dataset_2_tdomain.csv is commented out
# above, so the file must already exist for this cell to run.
file = os.path.join(data_path, 'dataset_2_tdomain.csv')
df_new = pd.read_csv(file)

In [None]:
# Raw fitness distribution: split 2 (left, titled 'Train') vs. splits 0 and 1 (right, titled 'Test').
fig, ax = plt.subplots(1, 2, figsize=(7,3), layout='constrained')
train_rows = df_new['split_id']==2
heldout_rows = df_new['split_id'].isin([0, 1])
for panel, rows in zip(ax, (train_rows, heldout_rows)):
    panel.hist(df_new.loc[rows, 'fitness_raw'])
ax[0].set_title('Train')
ax[1].set_title('Test')

In [None]:
# Log-fitness distribution: split 2 (left, titled 'Train') vs. splits 0 and 1 (right, titled 'Test').
fig, ax = plt.subplots(1, 2, figsize=(7,3), layout='constrained')
log_train = df_new.loc[df_new['split_id']==2, 'fitness_log']
log_heldout = df_new.loc[df_new['split_id'].isin([0, 1]), 'fitness_log']
ax[0].hist(log_train)
ax[1].hist(log_heldout)
ax[0].set_title('Train')
ax[1].set_title('Test')

In [None]:
# Smallest log-fitness value in the dataset.
df_new['fitness_log'].min()

In [None]:
# Rows whose raw fitness is exactly zero.
df_new.loc[df_new['fitness_raw'] == 0]

In [None]:
# Distinct mutation counts, in ascending order.
np.sort(df_new['n_mut'].unique())

In [None]:
# Rows with zero mutations, i.e. sequences identical to the wild type.
df_new[df_new['n_mut']==0]

In [None]:
# (rows, columns) of the reloaded dataset.
df_new.shape

In [None]:
# Print the name of every entry whose sequence appears in more than one row of the dataset.
for name in df_new['name']:
    # Sequence of the first row carrying this name.
    ref_seq = df_new.loc[df_new['name'] == name, 'seq'].iloc[0]
    n_copies = (df_new['seq'] == ref_seq).sum()
    if n_copies > 1:
        print(name)

In [None]:
# Number of rows per split id (2 = train, 1 = val, 0 = test, as labelled in the message).
print(f"train: {(df_new['split_id']==2).sum()}, val: {(df_new['split_id']==1).sum()}, test: {(df_new['split_id']==0).sum()}")

In [None]:
def hamming_distance(seq1, seq2):
    """Count the positions at which `seq1` and `seq2` hold different elements.

    Same definition as in the processing section; presumably repeated so that
    the analysis section can be run on its own.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    # Guard: a position-wise comparison needs sequences of the same length.
    if len(seq1) != len(seq2):
        raise ValueError("Sequences must be of equal length to compute Hamming distance.")

    mismatches = 0
    for left, right in zip(seq1, seq2):
        if left != right:
            mismatches += 1
    return mismatches

In [None]:
# Pairwise Hamming distances between all sequences of the dataset, in table order.
sequences = df_new['seq'].to_numpy()
dist_matrix = np.array(
    [hamming_distance(seq_a, seq_b) for seq_a in sequences for seq_b in sequences],
    dtype=float,
).reshape(sequences.shape[0], sequences.shape[0])

# Heatmap of the full matrix.
heat_fig, heat_ax = plt.subplots(figsize=(10, 10))
heat_img = heat_ax.imshow(dist_matrix, cmap='hot')
heat_fig.colorbar(heat_img, ax=heat_ax, shrink=0.75)
plt.show()

In [None]:
# Distances from the sequences whose name contains 'ESM2' (rows) to all other
# sequences (columns). Assumes `dist_matrix` is in `df_new` row order.
is_esm2 = df_new['name'].str.contains('ESM2')
esm2_dist = dist_matrix[is_esm2][:, ~is_esm2]
print(esm2_dist.shape)

plt.figure(figsize=(3,3))
plt.hist(esm2_dist.flatten())
plt.xlabel('n mutations')
plt.show()

In [None]:
# Distances from the sequences whose name contains 'esm3' or 'evodiff' (rows)
# to all remaining sequences (columns).
is_generated = df_new['name'].str.contains('esm3') | df_new['name'].str.contains('evodiff')
gen_dist = dist_matrix[is_generated][:, ~is_generated]
print(gen_dist.shape)

plt.figure(figsize=(3,3))
plt.hist(gen_dist.flatten())
plt.xlabel('n mutations')
plt.show()

In [None]:
# Rows: sequences in splits 0 and 1; columns: all other sequences.
in_heldout = df_new['split_id'].isin([0, 1])
test_train_dist = dist_matrix[in_heldout][:, ~in_heldout]
print(test_train_dist.shape)

# For every held-out sequence, the distance to its nearest sequence outside
# splits 0/1. The x-axis starts at 20, so smaller distances would not be visible.
plt.figure(figsize=(3,3))
plt.hist(test_train_dist.min(axis=1).flatten())
plt.xlabel('n mutations to closest seq \nin train set', size=10)
plt.xlim(left=20)
plt.show()

In [None]:
# Distance matrix with the rows/columns reordered: split 2 sequences first,
# followed by the sequences of splits 0 and 1. This overwrites the table-ordered
# `dist_matrix` used by the cells above.
in_train = df_new['split_id']==2
in_heldout = df_new['split_id'].isin([0, 1])
sequences = np.concatenate((df_new.loc[in_train, 'seq'].to_numpy(), df_new.loc[in_heldout, 'seq'].to_numpy()))

dist_matrix = np.eye(df_new['seq'].shape[0])
for i, seq_i in enumerate(sequences):
    for j, seq_j in enumerate(sequences):
        dist_matrix[i, j] = hamming_distance(seq_i, seq_j)

plt.figure(figsize=(10,10))
plt.imshow(dist_matrix, cmap='hot')
plt.colorbar(shrink=0.75)
# One tick per column: 2 for the leading split-2 block, 0 for the remaining columns.
split_ticks = [2]*in_train.sum() + [0]*in_heldout.sum()
plt.xticks(np.arange(len(df_new)), split_ticks, size=7, rotation=0)
plt.show()

In [None]:
# Reference and generated designs, listed in the order they should appear on the heatmap axes.
ai_gen = ['WT', 'ESM1', 'ESM2', 'ESM3', 'ESM4', 'IN1', 'IN2', 'IN3', 'IN4', 'IN5', 'IN6', 'IN7', 
          'esm3_gen_17', 'esm3_gen_93', 'esm3_gen_192', 'esm3_gen_333', 'esm3_gen_385', 
          'esm3_str_gen_53', 'esm3_str_gen_170', 'esm3_str_gen_314', 'esm3_str_gen_323', 'esm3_str_gen_365', 
          'evodiff_gen_21', 'evodiff_gen_77', 'evodiff_gen_88', 'evodiff_gen_155', 'evodiff_gen_357', 
          'Test-IN1-improved', 'Test-IN2-improved', 'Test-IN4-improved', 'Test-IN5-improved', 'Test-IN7-improved', 
          'II_guidance-ESM3', 'II guidance-WT', 'II guidance-ESM2', 'II guidance-IN2', 'prob_guidancd-ESM2']

mask = df_new['name'].isin(ai_gen)
# Build the sorted subset with .assign so the column is written on a new frame
# rather than on a slice of `df_new` (which triggers SettingWithCopyWarning).
# The ordered categorical makes sort_values follow the order of `ai_gen`.
_df = (
    df_new[mask]
    .assign(name=lambda d: pd.Categorical(d['name'], categories=ai_gen, ordered=True))
    .sort_values('name')
)

# Pairwise Hamming distances between the selected sequences, in `ai_gen` order.
sequences = _df['seq'].to_numpy()
dist_matrix = np.zeros((len(sequences), len(sequences)))
for i in range(sequences.shape[0]):
    for j in range(sequences.shape[0]):
        dist_matrix[i, j] = hamming_distance(sequences[i], sequences[j])

# Label the axes with the names that are actually present: if a name from
# `ai_gen` is missing in `df_new`, labelling with the full list would shift
# every following tick label.
tick_labels = _df['name'].astype(str).to_numpy()

plt.figure(figsize=(10,10))
plt.imshow(dist_matrix, cmap='hot')
plt.colorbar(shrink=0.75)
plt.xticks(np.arange(len(tick_labels)), tick_labels, size=7, rotation=90)
plt.yticks(np.arange(len(tick_labels)), tick_labels, size=7)
plt.show()

In [None]:
# All entry names of the dataset, in table order.
df_new['name'].to_numpy()

In [None]:
# Mutation-count distribution: split 2 (left, titled 'Train') vs. splits 0 and 1 (right, titled 'Test').
fig, ax = plt.subplots(1, 2, figsize=(7,3), layout='constrained')
split_rows = (df_new['split_id']==2, df_new['split_id'].isin([0, 1]))
for panel, rows in zip(ax, split_rows):
    panel.hist(df_new.loc[rows, 'n_mut'])
ax[0].set_title('Train')
ax[1].set_title('Test')

In [None]:
# Bin n_mut into half-open intervals [k, k + 5). With right=False the last edge
# is exclusive, so it must lie strictly above the maximum. Stopping at max + 5
# makes the last edge equal to the maximum whenever that is a multiple of 5,
# which silently drops those rows into no bin; the extra + 1 prevents that.
bin_edges = np.arange(0, max(df_new['n_mut']) + 5 + 1, 5)
df_new['n_mut_bin'] = pd.cut(df_new['n_mut'], bins=bin_edges, right=False)
# One list of raw fitness values per bin.
grouped = df_new.groupby('n_mut_bin')['fitness_raw'].apply(list)
boxplot_data = [group for group in grouped]
bin_labels = [str(group) for group in grouped.index]

# Fitness distribution per mutation-count bin. Boxes are drawn at positions
# 1..N (the boxplot default), and the tick labels are set through plt.xticks
# because the boxplot `labels` keyword was renamed in recent Matplotlib releases.
plt.figure(figsize=(10, 3))
plt.boxplot(boxplot_data)
plt.xticks(np.arange(1, len(bin_labels) + 1), bin_labels, rotation=90)
plt.xlabel('n_mut Bins')
plt.ylabel('Fitness Values')

# Number of data points in each bin.
plt.figure(figsize=(10, 3))
plt.bar(bin_labels,[len(x) for x in boxplot_data])
plt.xticks(rotation=90)
plt.xlabel('n_mut Bins')
plt.ylabel('# of points')