In [1]:
# Standard imports
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import re

# Insert local path to MAVE-NN at beginning of Python's path
import sys
sys.path.insert(0, '/Users/tareen/Desktop/Research_Projects/2020_mavenn_github/mavenn')

# Load mavenn
import mavenn
print(mavenn.__path__)

['/Users/tareen/Desktop/Research_Projects/2020_mavenn_github/mavenn/mavenn']


In [2]:
abreviation_dict = {
    'Ala':'A', 
    'Arg':'R',
    'Asn':'N',
    'Asp':'D',
    'Cys':'C',
    'Glu':'E',
    'Gln':'Q',
    'Gly':'G',
    'His':'H',
    'Ile':'I',
    'Leu':'L',
    'Lys':'K',
    'Met':'M',
    'Phe':'F',
    'Pro':'P',
    'Ser':'S',
    'Thr':'T',
    'Trp':'W',
    'Tyr':'Y',
    'Val':'V'
}

aas = list(abreviation_dict.values())

In [3]:
df  = pd.read_csv('../downloaded_files/121.txt',skiprows=4)
df.head()

Unnamed: 0,accession,hgvs_nt,hgvs_pro,score,sigma
0,urn:mavedb:00000053-a-1#1,,p.[Gly10Thr;Thr12Gln],-0.470584,1.414214
1,urn:mavedb:00000053-a-1#2,,p.[Thr12Phe;Gly13Ser],-0.470584,1.414214
2,urn:mavedb:00000053-a-1#3,,p.[Thr12Cys;Gly13Thr],-0.470584,0.816497
3,urn:mavedb:00000053-a-1#4,,p.[Gly10Ser;Thr12Gln],0.040241,0.730297
4,urn:mavedb:00000053-a-1#5,,p.[Ser11Gln;Thr12Gln],-1.163732,1.224745


In [4]:
# Create y_df
y_df = pd.DataFrame()
y_df['y'] = df['score'].astype(float).copy()
print(f'len(y_df): {len(y_df)}')
y_df.head()

len(y_df): 648022


Unnamed: 0,y
0,-0.470584
1,-0.470584
2,-0.470584
3,0.040241
4,-1.163732


In [5]:
# Parse hgvs notation
matches_list = [re.findall('([A-Za-z\*]+)([0-9]+)([A-Za-z\*]+)', s) for s in df['hgvs_pro']]

# Add hamming_dist col to y_df
y_df.insert(loc=0, column='hamming_dist', value=[len(m) for m in matches_list])

# Assign to trianing and test sets
N = len(y_df)
training_frac=.8
np.random.seed(0)
r = np.random.rand(N)
test_frac = .2
val_frac = .2
ix_train = (test_frac + val_frac <= r)
ix_val = (test_frac <= r) & (r < test_frac + val_frac)
ix_test = (r < test_frac)
y_df.insert(loc=0, column='set', value='')
y_df.loc[ix_train, 'set'] = 'training'
y_df.loc[ix_val, 'set'] = 'validation'
y_df.loc[ix_test, 'set'] = 'test'
assert all([len(x)>0 for x in y_df['set']])

y_df.head()

Unnamed: 0,set,hamming_dist,y
0,training,2,-0.470584
1,training,2,-0.470584
2,training,2,-0.470584
3,training,2,0.040241
4,training,2,-1.163732


In [6]:
### Create mut_df

# Parse strings in 'hgvs_pro' column
f = open('tmp.txt','w')
f.write('id,l,c\n')
for i, matches in enumerate(matches_list):
    for _, l, c in matches:
        f.write(f'{i},{int(l)-1},{c}\n')
f.close()
mut_df = pd.read_csv('tmp.txt')

# Map long-form aa to short-form aa
mut_df['c'] = mut_df['c'].map(abreviation_dict).astype(str)

# Remove all unrecognized 'c'
ix = mut_df['c'].isin(aas)
mut_df = mut_df[ix]

# preview mut_df
print(f'min l: {min(mut_df["l"])}')
print(f'max l: {max(mut_df["l"])}')
print(f'max id: {max(mut_df["id"])}')
mut_df.head()

min l: 9
max l: 74
max id: 648021


Unnamed: 0,id,l,c
0,0,9,T
1,0,11,Q
2,1,11,F
3,1,12,S
4,2,11,C


In [7]:

dna_wt_seq = 'CCACGCCGCATCGTCATCCACCGTGGGTCAACGGGGTTAGGCTTCAATATCGTCGGTGGAGAGGATGGTGAGGGAATCTTCATCTCATTCATTCTGGCGGGAGGACCGGCCGATTTAAGCGGAGAACTTCGCAAAGGTGACCAGATCCTTTCGGTGAATGGCGTAGATTTGCGCAACGCATCACACGAACAGGCGGCCATCGCATTAAAGAACGCCGGCCAGACCGTTACGATTATCGCGCAGTATAAA'
# obtained from EMBOSS Transeq from ebi.ac.uk/Tools
wt_seq = 'PRRIVIHRGSTGLGFNIVGGEDGEGIFISFILAGGPADLSGELRKGDQILSVNGVDLRNASHEQAAIALKNAGQTVTIIAQYK'

In [None]:
from mavenn.src.mavedb import mutations_to_dataset

# Create dataset
data_df = mutations_to_dataset(wt_seq=wt_seq, mut_df=mut_df, y_df=y_df)

# Dropna and reindex
data_df.dropna(inplace=True)
data_df.reset_index(inplace=True, drop=True)

data_df.head(10)

Unnamed: 0,set,hamming_dist,y,x
0,training,2,-0.470584,PRRIVIHRGTTQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
1,training,2,-0.470584,PRRIVIHRGSTFSGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
2,training,2,-0.470584,PRRIVIHRGSTCTGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
3,training,2,0.040241,PRRIVIHRGSTQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
4,training,2,-1.163732,PRRIVIHRGSQQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
5,training,2,-0.470584,PRRIVIHRGSCELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
6,training,2,-0.87605,PRRIVIHRGSEELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
7,training,2,-0.470584,PRRIVIHRGTTELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
8,training,2,-0.470584,PRRIVIHRGSQELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
9,validation,2,-0.470584,PRRIVIHRGFTGLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...


In [None]:
# Separate test from data_df
ix_test = data_df['set']=='test'
test_df = data_df[ix_test].reset_index(drop=True)
print(f'test N: {len(test_df):,}')

# Remove test data from data_df
data_df = data_df[~ix_test].reset_index(drop=True)
print(f'training + validation N: {len(data_df):,}')
data_df.head(10)

test N: 129,793
training + validation N: 518,229


Unnamed: 0,set,hamming_dist,y,x
0,training,2,-0.470584,PRRIVIHRGTTQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
1,training,2,-0.470584,PRRIVIHRGSTFSGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
2,training,2,-0.470584,PRRIVIHRGSTCTGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
3,training,2,0.040241,PRRIVIHRGSTQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
4,training,2,-1.163732,PRRIVIHRGSQQLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
5,training,2,-0.470584,PRRIVIHRGSCELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
6,training,2,-0.87605,PRRIVIHRGSEELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
7,training,2,-0.470584,PRRIVIHRGTTELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
8,training,2,-0.470584,PRRIVIHRGSQELGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...
9,validation,2,-0.470584,PRRIVIHRGFTGLGFNIVGGEDGEGIFISFILAGGPADLSGELRKG...


In [None]:
# Get sequence length
L = len(data_df['x'][0])

# Define model
model = mavenn.Model(regression_type='GE',
                     L=L,
                     alphabet='protein',
                     gpmap_type='pairwise',                     
                     ge_noise_model_type='SkewedT',
                     ge_nonlinearity_hidden_nodes=100,
                     ge_heteroskedasticity_order=0)

In [None]:
# Set training data
model.set_data(x=data_df['x'],
               y=data_df['y'],
               validation_flags=(data_df['set']=='validation'),
               shuffle=True)

N = 518,229 observations set as training data.
Using 24.9% for validation.
Data shuffled.
Time to set data: 75.4 sec.


In [None]:
# Fit model to data
history = model.fit(learning_rate=.0002,
                    epochs=1000,
                    batch_size=200,
                    early_stopping=True,
                    early_stopping_patience=25,
                    linear_initialization=False)

Epoch 1/1000
  16/1946 [..............................] - ETA: 3:23:53 - loss: 329.0460 - I_var: -0.6839

In [None]:
# Save model
#model.save('models/gfp_ge_additive_homogaussian')

In [None]:
# Load model
#model = mavenn.load('models/gfp_ge_additive_homogaussian')

In [None]:
# Subsample indices for easy plotting and information estimation
N_test = len(test_df)
ix = np.random.rand(N_test) < .5

# Get x and y
x_test = test_df['x'].values[ix]
y_test = test_df['y'].values[ix]

In [None]:
# Show training history
print('On test data:')

# Compute likelihood information
I_var, dI_var =  model.I_variational(x=x_test, y=y_test)
print(f'I_var_test: {I_var:.3f} +- {dI_var:.3f} bits') 

# Compute predictive information
I_pred, dI_pred = model.I_predictive(x=x_test, y=y_test)
print(f'I_pred_test: {I_pred:.3f} +- {dI_pred:.3f} bits')

I_var_hist = model.history['I_var']
val_I_var_hist = model.history['val_I_var']

fig, ax = plt.subplots(1,1,figsize=[4,4])
ax.plot(I_var_hist, label='I_var_train')
ax.plot(val_I_var_hist, label='I_var_val')
ax.axhline(I_var, color='C2', linestyle=':', label='I_var_test')
ax.axhline(I_pred, color='C3', linestyle=':', label='I_pred_test')
ax.legend()
ax.set_xlabel('epochs')
ax.set_ylabel('bits')
ax.set_title('training hisotry')
#ax.set_ylim([0, I_pred*1.2]);

In [None]:
# Compute phi and yhat values
phi = model.x_to_phi(x_test)
yhat = model.phi_to_yhat(phi)

# Create grid for plotting yhat and yqs
phi_lim = [-5, 2.5]
phi_grid = np.linspace(phi_lim[0], phi_lim[1], 1000)
yhat_grid = model.phi_to_yhat(phi_grid)
yqs_grid = model.yhat_to_yq(yhat_grid, q=[.16,.84])

# Create two panels
fig, ax = plt.subplots(1, 2, figsize=[8, 4])

# Illustrate measurement process with GE curve
ax[0].scatter(phi, y_test, color='C0', s=5, alpha=.1, label='test data')
ax[0].plot(phi_grid, yhat_grid, linewidth=2, color='C1',
        label='$\hat{y} = g(\phi)$')
ax[0].plot(phi_grid, yqs_grid[:, 0], linestyle='--', color='C1',
        label='68% CI')
ax[0].plot(phi_grid, yqs_grid[:, 1], linestyle='--', color='C1')
ax[0].set_xlim(phi_lim)
ax[0].set_xlabel('latent phenotype ($\phi$)')
ax[0].set_ylabel('measurement ($y$)')
ax[0].set_title('measurement process')
ax[0].legend()

R_square_on_test = np.corrcoef(yhat,y_test)[0][1]**2
ax[1].plot((min(yhat), max(yhat)), (min(y_test), max(y_test)), '--',color='black')
ax[1].scatter(yhat, y_test, color='C0', s=5, alpha=.1, label='test data',zorder=10)
ax[1].set_xlabel('predictions ($\hat{y}$)')
ax[1].set_ylabel('measurement ($y$)')
ax[1].set_title("$R^2 = $"+str(R_square_on_test)[0:5])


# Fix up plot
fig.tight_layout()
#plt.savefig('gfp_ge_additive_measurement_process_yhat.png',bbox_inches='tight',dpi=300)
plt.show()

In [None]:
# Set wild-type sequence
gfp_consensus_seq = model.x_stats['consensus_seq']

# Get effects of all single-point mutations on phi
theta_dict = model.get_theta(gauge='user',
                             x_wt=gfp_consensus_seq)

# Create two panels
fig, ax = plt.subplots(1, 1, figsize=[12, 4])

# Left panel: draw heatmap illustrating 1pt mutation effects
ax, cb = mavenn.heatmap(theta_dict['theta_lc'],
                        alphabet=theta_dict['alphabet'],
                        seq=gfp_consensus_seq,
                        #cmap='PiYG',
                        ccenter=0,
                        ax=ax)
ax.set_xlabel('position ($l$)')
ax.set_ylabel('amino acid ($c$)')
cb.set_label('effect ($\Delta\phi$)', rotation=-90, va="bottom")
ax.set_title('mutation effects')



# Fix up plot
fig.tight_layout()
#plt.savefig('gfp_additive_heatmap.png',bbox_inches='tight',dpi=300)
plt.show()
