# Set up Environment

In [None]:
# Python ≥3.5 is required
import sys, re
assert sys.version_info >= (3, 5)

# Is this notebook running on Colab or Kaggle?
# Detected by whether each platform's module is already present in sys.modules.
# NOTE(review): neither flag is read anywhere else in the visible notebook.
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules

# Scikit-Learn ≥0.20 is required
import sklearn
# Compare numeric (major, minor) components: a plain string comparison would
# wrongly accept e.g. "0.3" as newer than "0.20".
assert tuple(int(part) for part in re.findall(r"\d+", sklearn.__version__)[:2]) >= (0, 20)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score,cross_val_predict
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score, recall_score
from sklearn.model_selection import train_test_split

# Common imports
import numpy as np
import os
import pandas as pd
import plotly.express as px
import time

# Import custom utility functions
import glycan_bionames

# to make this notebook's output stable across runs
# Seeds NumPy's global RNG only; the sklearn split and classifiers below also
# receive an explicit random_state.
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
# Larger fonts for axis labels and tick labels in every matplotlib figure.
for _rc_group, _rc_size in (('axes', 14), ('xtick', 12), ('ytick', 12)):
    mpl.rc(_rc_group, labelsize=_rc_size)

# Figures are written under ./images/classification (created if missing).
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

# Define custom functions
def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Save the current matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.

    Parameters
    ----------
    fig_id : str
        Base file name (no extension); also echoed to stdout.
    tight_layout : bool
        When true, call ``plt.tight_layout()`` before saving.
    fig_extension : str
        File extension, also passed as ``format`` to ``plt.savefig``.
    resolution : int
        Output DPI.
    """
    target = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if not tight_layout:
        plt.savefig(target, format=fig_extension, dpi=resolution)
        return
    plt.tight_layout()
    plt.savefig(target, format=fig_extension, dpi=resolution)
    
def restrict_RBD_window(df,nm):
    '''Drop the features of glycans whose mean distance to the RBD exceeds ``nm``.

    Glycans are found from columns containing ``RBD__2__GLY``; the glycan name is
    that column name with ``RBD__2__`` removed. When the mean of ``RBD__2__<g>``
    is greater than ``nm``, whichever of ``RBD__2__<g>``, ``<g>:ROF``,
    ``<g>:RMSD``, ``<g>_x``, ``<g>_y`` and ``<g>_z`` exist are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature table; it is modified in place.
    nm : float
        Distance cutoff (the notebook treats these distances as nm).

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object after the drops.
    '''
    prefix = 'RBD__2__'
    present = df.keys().to_list()
    glycan_names = np.unique([col.replace(prefix, '') for col in present if prefix + 'GLY' in col])

    doomed = []
    for glycan in glycan_names:
        if df[prefix + glycan].mean() > nm:
            candidates = [prefix + glycan] + [glycan + tag for tag in (':ROF', ':RMSD', '_x', '_y', '_z')]
            doomed.extend(name for name in candidates if name in present)

    # One in-place drop instead of one per column.
    if doomed:
        df.drop(columns=doomed, inplace=True)
    return df

def overlapping_hist(open_df,closed_df,feat,mutant_df=None):
    '''Plot overlaid 50-bin histograms of one feature for each dataset.

    Parameters
    ----------
    open_df, closed_df : pandas.DataFrame
        Open- and closed-state feature tables; both must contain ``feat``.
    feat : str
        Column to plot. ``RBD__2__`` features get an "nm" x-axis label.
    mutant_df : pandas.DataFrame, optional
        Mutant (open) feature table. The original version silently read the
        notebook-global ``mutant_df``. That global is still used when this
        argument is omitted. The mutant histogram is skipped (rather than
        raising NameError) when neither is available.
    '''
    if mutant_df is None:
        # Backward compatibility with the original implicit global lookup.
        mutant_df = globals().get('mutant_df')

    datasets = [(open_df, 'Open'), (closed_df, 'Closed')]
    if mutant_df is not None:
        datasets.append((mutant_df, 'Mutant (open)'))

    for frame, _ in datasets:
        frame[feat].hist(bins=50)
    # Legend entries follow the plotting order above.
    plt.legend([label for _, label in datasets])
    plt.title(feat)
    if 'RBD__2__' in feat:
        plt.xlabel('nm')
        
def drop_feats(df,flag):
    '''Drop, in place, every column of ``df`` whose name contains ``flag``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to modify in place.
    flag : str
        Substring matched against column names (e.g. ``'_x'``).

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, for chaining.
    '''
    # Collect first, then drop once. Dropping label by label re-copies the frame
    # each time, and with duplicated labels the second per-label drop raised
    # KeyError because the first had already removed every copy.
    # dict.fromkeys de-duplicates while preserving column order.
    doomed = list(dict.fromkeys(col for col in df.columns if flag in col))
    if doomed:
        df.drop(columns=doomed, inplace=True)
    return df

def read_n_filter_dfs(fname,num_reps,RBD_wind,val_reps_open,val_reps_closed,label_val,dfs_train=None,dfs_val=None,
                      fname_fmt='{fname}.csv'):
    '''Read replicate feature tables, filter their columns, and sort each table
    into the training or validation list.

    Parameters
    ----------
    fname : str
        Path prefix of the feature CSV(s).
    num_reps : int
        Number of replicates, numbered 1..num_reps.
    RBD_wind : float
        Cutoff passed to ``restrict_RBD_window``.
    val_reps_open, val_reps_closed : container of int
        Replicate numbers withheld for validation when ``label_val`` is 1
        (open) or 0 (closed), respectively.
    label_val : int
        Value written to every row's ``label`` column.
    dfs_train, dfs_val : list, optional
        Lists to append to. Fresh lists are created when omitted.
    fname_fmt : str, optional
        ``str.format`` template for each file path. It receives ``fname`` and
        the 1-based replicate number ``i``. The default ``'{fname}.csv'``
        reproduces the previous behaviour, which reads the SAME file for every
        replicate.
        NOTE(review): pass a template that uses ``{i}`` (the real per-replicate
        naming must be confirmed). Otherwise the "validation" tables are copies
        of the training tables.

    Returns
    -------
    tuple of (list, list)
        ``(dfs_train, dfs_val)``.
    '''
    if dfs_train is None:
        dfs_train = []
    if dfs_val is None:
        dfs_val = []

    # Which replicate numbers are withheld depends on the class this file holds.
    if label_val == 1:
        held_out = val_reps_open
    elif label_val == 0:
        held_out = val_reps_closed
    else:
        held_out = ()

    for i in range(1, num_reps + 1):
        path = fname_fmt.format(fname=fname, i=i)
        # iloc[:, 1:] drops the first CSV column (presumably a saved index).
        df = pd.read_csv(path).assign(label=label_val).iloc[:, 1:]
        # Only use glycans within certain range of the RBD
        df = restrict_RBD_window(df, RBD_wind)
        # Drop per-axis coordinate features
        for axis_flag in ('_x', '_y', '_z'):
            df = drop_feats(df, axis_flag)

        # Withhold the selected replicates for a separate validation set
        (dfs_val if i in held_out else dfs_train).append(df)

    return dfs_train, dfs_val

def remove_corr_feats(full_df,corr_thresh= 0.65,target_col='RBD_CA0:RMSD'):
    '''Keep only features weakly correlated with a reference feature.

    This does not prune features against each other. It computes each column's
    Pearson correlation with ``target_col`` and keeps the columns whose
    correlation lies strictly inside ``(-corr_thresh, corr_thresh)``. For
    ``corr_thresh <= 1`` the reference column itself (r = 1) is removed.
    Columns with a NaN correlation (e.g. constant columns) are also removed.
    ``label`` is appended if it was filtered out.

    Parameters
    ----------
    full_df : pandas.DataFrame
        Feature table containing ``target_col`` and ``label``.
    corr_thresh : float
        Absolute-correlation cutoff.
    target_col : str
        Reference column. Defaults to the previously hard-coded ``'RBD_CA0:RMSD'``.

    Returns
    -------
    pandas.DataFrame
        Column subset of ``full_df`` in its original column order, with
        ``label`` last if it had to be re-added.
    '''
    target_corr = full_df.corr()[target_col]
    # |r| < t is the same test as -t < r < t; NaN compares False and is dropped.
    final_features = target_corr[target_corr.abs() < corr_thresh].index.to_list()
    if 'label' not in final_features:
        final_features.append('label')
    return full_df.loc[:, final_features]

def prep_ML_data(clf_df,ts,rs,labelnames):
    '''Split a labelled feature table into train/test features and labels.

    Parameters
    ----------
    clf_df : pandas.DataFrame
        Features plus a ``label`` column.
    ts : float
        Test fraction, passed as ``test_size``.
    rs : int
        Seed, passed as ``random_state``.
    labelnames : array-like
        Passed as ``stratify``; must line up row-for-row with ``clf_df``.

    Returns
    -------
    tuple
        ``(train_X, test_X, train_labels, test_labels)``.
    '''
    train_set, test_set = train_test_split(
        clf_df, test_size=ts, random_state=rs, stratify=labelnames
    )
    print(f'Train set : {train_set.shape}, Test set : {test_set.shape}')

    def _features_and_labels(part):
        # Features = every column except 'label'; labels = a copy of 'label'.
        return part.drop("label", axis=1), part["label"].copy()

    train_X, train_labels = _features_and_labels(train_set)
    test_X, test_labels = _features_and_labels(test_set)
    return train_X, test_X, train_labels, test_labels
   

# Load data

### Load all replicates as one dataframe

In [None]:
# Open-state features are labelled 1 and closed-state features 0.
# iloc[:, 1:] discards the first CSV column (presumably a saved index).
fname = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_spike_open_prot_glyc_amarolab/results/FinalExtractedFeature_open.csv'
open_df = pd.read_csv(fname)
open_df = open_df.assign(label=1).iloc[:, 1:]

fname = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_spike_closed_prot_glyc_amarolab/results/FinalExtractedFeature_closed.csv'
closed_df = pd.read_csv(fname)
closed_df = closed_df.assign(label=0).iloc[:, 1:]

In [None]:
# Mutant dataset (plotted as "Mutant (open)" elsewhere; labelled 1).
# NOTE(review): unlike open_df/closed_df, the first CSV column is NOT dropped
# here (no .iloc[:, 1:]). Confirm whether this file has the same leading
# index column.
fname = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_spike_mutant_prot_glyc_amarolab/results/FinalExtractedFeature_mutant.csv'
mutant_df = pd.read_csv(fname).assign(label=1)

In [None]:
# Opening dataset (continuous spike-opening trajectory)
fname = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_continuous_spike_opening_WE_chong_and_amarolab/results/FinalExtractedFeature.csv'
opening_df = pd.read_csv(fname)
# Rows with index < xidx are labelled closed (0); the rest are labelled open (1).
# A single boolean makes the boundary row unambiguous and yields an integer
# label column like open_df/closed_df. The two overlapping inclusive .loc
# slices used before produced floats.
xidx = 1500
opening_df['label'] = (opening_df.index >= xidx).astype(int)
fig = px.line(opening_df,y = ['RBD__2__CH_CA0','RBD__2__backbone0'],title='Labeling Closed and Open')
fig.add_vline(x=xidx)


# Filter out features

In [None]:
# Only use glycans within RBD_wind nm of the RBD (the old comment said 10 nm,
# but the cutoff actually applied was 8).
RBD_wind = 8
open_df = restrict_RBD_window(open_df,RBD_wind)
closed_df = restrict_RBD_window(closed_df,RBD_wind)
opening_df = restrict_RBD_window(opening_df,RBD_wind)
print(open_df.shape)

# Only use columns that exist in all datasets (add or remove val datasets as
# needed). Keep open_df's column order: iterating a set of strings gives an
# order that changes between Python sessions (hash randomisation). That made
# the feature order non-reproducible, along with any position-based pick of
# columns further down.
shared_cols = set(open_df.columns).intersection(closed_df.columns, opening_df.columns)
common_cols = [c for c in open_df.columns if c in shared_cols]
# Only use open & closed datasets for training
full_df = pd.concat([open_df.loc[:,common_cols],closed_df.loc[:,common_cols]]).drop(['frame'],axis=1)
# OR use all datasets.
# NOTE(review): common_cols above does not intersect with mutant_df's columns;
# add mutant_df to the intersection before enabling this line.
#full_df = pd.concat([open_df.loc[:,common_cols],closed_df.loc[:,common_cols],opening_df.loc[:,common_cols],mutant_df.loc[:,common_cols]]).drop(['frame'],axis=1)
print(full_df.shape)

# Keep only features weakly correlated (|r| < 0.5) with RBD_CA0:RMSD
clf_df = remove_corr_feats(full_df,0.5)
print(clf_df.shape)

# Prepare the Data for Machine Learning Algorithms

In [None]:
# 30 % test split, seed 42, stratified on the label column.
train_X, test_X, train_labels, test_labels = prep_ML_data(
    clf_df, ts=0.3, rs=42, labelnames=full_df.label
)

# Standardise each feature with statistics from the training split only, then
# apply the same transform to the test split.
num_pipeline = Pipeline([('std_scaler', StandardScaler())])
train_X_prepared = num_pipeline.fit_transform(train_X)
test_X_prepared = num_pipeline.transform(test_X)

# Train and Test Model

In [None]:
# Initialize classifier (fixed seed for reproducibility)
sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)

# Perform 10-fold cross-validation on training data
print('Cross Validation Scores:')
# Out-of-fold predictions: each training row is predicted by a fold model that did not see it.
y_train_pred = cross_val_predict(sgd_clf,train_X_prepared, train_labels, cv=10)
t = time.time()
print(cross_val_score(sgd_clf, train_X_prepared, train_labels, cv=10, scoring="accuracy"))
print('')
# NOTE: the timer starts after cross_val_predict, so this times cross_val_score only.
print(str(time.time()-t) + ' sec elapsed')

# Out-of-fold confusion matrix, precision and recall on the training data.
# The matrix used to be computed and discarded: it is not the cell's last
# expression, so Jupyter never displayed it. Print it explicitly.
print('')
print(confusion_matrix(train_labels, y_train_pred))
print('')
print(f' Train precison : {precision_score(train_labels, y_train_pred)}, Train recall {recall_score(train_labels, y_train_pred)}')

# Refit on the whole training split and score the held-out test split
sgd_clf.fit(train_X_prepared,train_labels)
y_test_pred = sgd_clf.predict(test_X_prepared)
print('')
print(f' Test precison : {precision_score(test_labels, y_test_pred)}, Test recall {recall_score(test_labels, y_test_pred)}')



### Display feature importances

In [None]:
# "Importance" = |coefficient| of the linear SGD model fitted above, one bar per
# feature. The model was trained on standardised features, so magnitudes are
# roughly comparable across features.
x_vals = train_X.columns.to_list()
y_vals = np.abs(sgd_clf.coef_[0])
col_vals = x_vals  # one colour per feature name

# The colour channel encodes the feature name, so its legend is titled
# 'Feature'. It was titled 'Importance', which is what the y-axis shows.
fig1 = px.bar(x=x_vals,y=y_vals,color=col_vals,labels={'x':'Feature','y':'Importance','color':'Feature'}).update_xaxes(categoryorder='total ascending')
fig1.show()

# Iterative Replicate Analysis

Run an iterative leave-out analysis. In each run, one open replicate (of 6) and one closed replicate (of 3) are withheld from the training/testing data and used afterwards as a separate "validation" dataset; this is repeated for all 6 × 3 combinations. The idea is to apply the trained model to completely "new" data and see whether its performance holds up.

In [None]:
RBD_wind = 8
N_OPEN_REPS, N_CLOSED_REPS = 6, 3  # replicates per state
open_prefix = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_spike_open_prot_glyc_amarolab/results/FinalExtractedFeature'
closed_prefix = '/net/jam-amaro-shared/dse_project/Spike_Dataset/TRAJECTORIES_spike_closed_prot_glyc_amarolab/results/FinalExtractedFeature'

leftouts = []
# Metric grids: row i-1 = open replicate i withheld, column j-1 = closed replicate j withheld.
grid_shape = (N_OPEN_REPS, N_CLOSED_REPS)
train_precs = np.zeros(grid_shape)
train_recalls = np.zeros(grid_shape)
test_precs = np.zeros(grid_shape)
test_recalls = np.zeros(grid_shape)
val_precs = np.zeros(grid_shape)
val_recalls = np.zeros(grid_shape)
top_feats = []
for i in range(1, N_OPEN_REPS + 1):
    for j in range(1, N_CLOSED_REPS + 1):
        val_reps_closed = [j]
        val_reps_open = [i]

        # Read open (label 1), then closed (label 0) replicates.
        # NOTE(review): with its default arguments read_n_filter_dfs reads
        # fname + '.csv' for every replicate number, i.e. the same file each
        # time. Confirm before trusting the "unseen" validation scores.
        fname = open_prefix
        dfs_train, dfs_val = read_n_filter_dfs(fname, N_OPEN_REPS, RBD_wind, val_reps_open, val_reps_closed, 1)
        fname = closed_prefix
        dfs_train, dfs_val = read_n_filter_dfs(fname, N_CLOSED_REPS, RBD_wind, val_reps_open, val_reps_closed, 0, dfs_train, dfs_val)
        val_df = pd.concat(dfs_val)
        print('Val Size: ')
        print(val_df.shape)

        # Only use columns that exist in every table. Keep the first table's
        # column order so feature order is reproducible across sessions
        # (set iteration order is not).
        all_dfs = dfs_train + dfs_val
        shared_cols = set(all_dfs[0].columns).intersection(*(d.columns for d in all_dfs[1:]))
        common_cols = [c for c in all_dfs[0].columns if c in shared_cols]
        full_df = pd.concat(dfs_train).loc[:, common_cols].drop(['frame'], axis=1)

        # Keep only features weakly correlated with RBD_CA0:RMSD
        clf_df = remove_corr_feats(full_df, 0.5)

        # Split train/test data
        train_X, test_X, train_labels, test_labels = prep_ML_data(clf_df, 0.3, 42, full_df.label)

        # Normalize data (scaler fitted on the training split only)
        num_pipeline = Pipeline([('std_scaler', StandardScaler())])
        train_X_prepared = num_pipeline.fit_transform(train_X)
        test_X_prepared = num_pipeline.transform(test_X)

        # Initialize classifier
        sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)

        # 10-fold cross-validation on training data (timer covers cross_val_score only)
        y_train_pred = cross_val_predict(sgd_clf, train_X_prepared, train_labels, cv=10)
        t = time.time()
        print(cross_val_score(sgd_clf, train_X_prepared, train_labels, cv=10, scoring="accuracy"))
        print(str(time.time()-t) + ' sec elapsed')

        # Out-of-fold training metrics. The confusion matrix used to be computed and discarded.
        print(confusion_matrix(train_labels, y_train_pred))
        train_prec = precision_score(train_labels, y_train_pred)
        train_rec = recall_score(train_labels, y_train_pred)
        print(f' Train precison : {train_prec}, train recall {train_rec}')

        # Refit on the whole training split and score the test split
        sgd_clf.fit(train_X_prepared, train_labels)
        y_test_pred = sgd_clf.predict(test_X_prepared)
        test_prec = precision_score(test_labels, y_test_pred)
        test_rec = recall_score(test_labels, y_test_pred)
        print(f' Test precison : {test_prec}, Test recall {test_rec}')

        # Score the withheld replicate(s) with the same scaler and model
        val_X = val_df.loc[:, train_X.keys()]
        val_labels = val_df.label
        val_X_prepared = num_pipeline.transform(val_X)
        y_val_pred = sgd_clf.predict(val_X_prepared)
        val_prec = precision_score(val_labels, y_val_pred)
        val_rec = recall_score(val_labels, y_val_pred)
        print(f' Val precison : {val_prec}, Val recall {val_rec}')

        # Save results
        leftouts.append(['open '+ str(x) +' ' for x in val_reps_open] + ['closed ' + str(x) + ' ' for x in val_reps_closed])
        train_precs[i-1, j-1] = train_prec
        train_recalls[i-1, j-1] = train_rec
        test_precs[i-1, j-1] = test_prec
        test_recalls[i-1, j-1] = test_rec
        val_precs[i-1, j-1] = val_prec
        val_recalls[i-1, j-1] = val_rec
        # Names of the 5 largest-|coef| features, in ascending |coef| order
        a = list(np.abs(sgd_clf.coef_[0]))
        idx = sorted(range(len(a)), key = lambda k: a[k])[-5:]
        x_vals = train_X.columns.to_list()
        top_feats.append(list(np.array(x_vals)[idx]))


In [None]:
# Display the 5 highest-|coefficient| feature names recorded for each
# (open, closed) held-out pair in the loop above.
top_feats

In [None]:
# Histogram of the feature names in top_feats (each inner list is one run's
# top 5), with x categories sorted by ascending total count.
px.histogram(top_feats,title='Commonly-Important Features').update_xaxes(categoryorder='total ascending')

In [None]:
# Rows = open replicate withheld, columns = closed replicate withheld. Use the
# 1-based replicate numbers from the loop as tick values; imshow would
# otherwise show 0-based array indices.
testPrec = px.imshow(test_precs,range_color=[0.5,1],title='Testing Precisions',
                     x=list(range(1, test_precs.shape[1] + 1)),
                     y=list(range(1, test_precs.shape[0] + 1)),
                     labels={'x': 'Closed replicate withheld', 'y': 'Open replicate withheld', 'color': 'Precision'})
testPrec.show()

In [None]:
# Rows = open replicate withheld, columns = closed replicate withheld, labelled
# with the 1-based replicate numbers used in the loop.
testRec = px.imshow(test_recalls,range_color=[0.5,1],title='Testing Recalls',
                    x=list(range(1, test_recalls.shape[1] + 1)),
                    y=list(range(1, test_recalls.shape[0] + 1)),
                    labels={'x': 'Closed replicate withheld', 'y': 'Open replicate withheld', 'color': 'Recall'})
testRec.show()

In [None]:
# Rows = open replicate withheld, columns = closed replicate withheld, labelled
# with the 1-based replicate numbers used in the loop.
unseenPrec = px.imshow(val_precs,range_color=[0.5, 1],title='Precision on Unseen Replicants',
                       x=list(range(1, val_precs.shape[1] + 1)),
                       y=list(range(1, val_precs.shape[0] + 1)),
                       labels={'x': 'Closed replicate withheld', 'y': 'Open replicate withheld', 'color': 'Precision'})
unseenPrec.show()

In [None]:
# Rows = open replicate withheld, columns = closed replicate withheld, labelled
# with the 1-based replicate numbers used in the loop.
unseenRec = px.imshow(val_recalls, range_color=[0.5,1],title='Recall on Unseen Replicants',
                      x=list(range(1, val_recalls.shape[1] + 1)),
                      y=list(range(1, val_recalls.shape[0] + 1)),
                      labels={'x': 'Closed replicate withheld', 'y': 'Open replicate withheld', 'color': 'Recall'})
unseenRec.show()

# Dashboard

In [None]:
import mdtraj as md
from biopandas.pdb import PandasPdb

In [None]:
def extract_glycan_residues_4m_pdb(dcdObj):
    '''Return the unique residue names of glycan atoms in frame 0 of ``dcdObj``.

    Glycan atoms are those whose PDB segment id starts with ``G`` followed by
    digits (G1, G2, ...). The frame is written to a temporary PDB file so that
    biopandas can expose the segment ids.

    Parameters
    ----------
    dcdObj : mdtraj trajectory
        Only frame 0 is used.

    Returns
    -------
    numpy.ndarray
        Unique ``residue_name`` values of the matching atoms, in order of first appearance.
    '''
    import tempfile

    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    glycan_segment = re.compile(r'G\d+')

    # Use a private temp file instead of '.tmp.pdb' in the working directory,
    # and delete it even if writing or parsing fails.
    fd, tmp_path = tempfile.mkstemp(suffix='.pdb')
    os.close(fd)
    try:
        dcdObj[0].save_pdb(tmp_path)
        pdb_atom_df = PandasPdb().read_pdb(tmp_path).df['ATOM']
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    glycan_mask = pdb_atom_df.segment_id.apply(lambda seg: bool(glycan_segment.match(seg)))
    return pdb_atom_df[glycan_mask].residue_name.unique()

def get_atom_ids_for_feature(dcd_traj=None,feature='protein'):
    '''Return the atom indices that match a topology selection string.

    Parameters
    ----------
    dcd_traj : trajectory, optional
        Object whose ``top.select(feature)`` is queried. When omitted, a
        notebook-level ``dcd_traj`` is used if one exists. The previous default
        ``dcd_traj=dcd_traj`` was evaluated when the cell ran and raised
        NameError, because no ``dcd_traj`` is defined anywhere in the visible
        notebook.
    feature : str
        Selection expression, e.g. ``'protein'`` or ``'name == CA'``.

    Returns
    -------
    list
        Matching atom indices. If the selection fails, an error message is
        printed and an empty list is returned. Previously the function printed
        the error and implicitly returned None, which broke callers that
        iterate over the result.

    Raises
    ------
    ValueError
        If no trajectory is given and no notebook-level ``dcd_traj`` exists.
    '''
    if dcd_traj is None:
        dcd_traj = globals().get('dcd_traj')
        if dcd_traj is None:
            raise ValueError('dcd_traj must be given (no notebook-level dcd_traj is defined)')
    try:
        selected = dcd_traj.top.select(feature)
    except Exception:
        # Any failure while selecting is reported and treated as "no atoms".
        print(f'[ERROR] {feature} not recognized for atom filtering')
        return []
    return list(selected)

def build_atom_lup_4_common_features(dcd_traj=None,flist = ('protein', 'backbone','sidechain')):
    '''Map each selection keyword in ``flist`` to its atom indices in ``dcd_traj``.

    Parameters
    ----------
    dcd_traj : trajectory, optional
        Trajectory to query. When omitted, a notebook-level ``dcd_traj`` is used
        if one exists. The previous default ``dcd_traj=dcd_traj`` raised
        NameError when the cell ran, because no such global is defined in the
        visible notebook.
    flist : iterable of str
        Selection keywords. The default is now an immutable tuple rather than
        a shared list; any iterable still works.

    Returns
    -------
    dict
        ``{keyword: atom indices from get_atom_ids_for_feature}``.

    Raises
    ------
    ValueError
        If no trajectory is given and no notebook-level ``dcd_traj`` exists.
    '''
    if dcd_traj is None:
        dcd_traj = globals().get('dcd_traj')
        if dcd_traj is None:
            raise ValueError('dcd_traj must be given (no notebook-level dcd_traj is defined)')
    return {f: get_atom_ids_for_feature(dcd_traj, f) for f in flist}

def get_xyz_perFrame(traj,atom_ids,frame=0):
    '''Return the x/y/z coordinates of ``atom_ids`` in one frame of ``traj``.

    Parameters
    ----------
    traj : object with an ``xyz`` array
        Indexed as ``xyz[frame, atom, axis]`` (e.g. an mdtraj trajectory).
    atom_ids : sequence of int
        Atom indices to extract.
    frame : int, optional
        Frame index. Defaults to 0, the frame that was previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        One row per atom id, with columns ``x``, ``y`` and ``z``.
    '''
    return pd.DataFrame(columns=['x','y','z'], data=traj.xyz[frame, atom_ids])

def gen_xyz_Table_4_LUP(traj=None, LUP = None, keyNames =('sidechain','RBD_CA', 'CH_CA', 'GLY','backbone') ):
    '''Stack frame-0 coordinates of selected atom groups into one table.

    For every key of ``LUP`` (in ``LUP``'s order) that is also in ``keyNames``,
    the frame-0 x/y/z of that group's atoms are added. ``type`` is set to the
    key and ``typeID`` to a running 0-based counter over the included keys.

    Parameters
    ----------
    traj : trajectory, optional
        Defaults to a notebook-level ``dcd_traj``, looked up at call time.
    LUP : dict, optional
        Maps group name to atom ids. Defaults to the notebook-level
        ``atom_id_LUP``, looked up at call time. The old defaults were
        evaluated when this cell ran, before ``atom_id_LUP`` is created in a
        later cell, so defining the function raised NameError.
    keyNames : collection of str
        Group names to include.

    Returns
    -------
    pandas.DataFrame
        Columns ``type, typeID, x, y, z``; the index restarts at 0 for each group.

    Raises
    ------
    ValueError
        If ``traj`` or ``LUP`` is omitted and no notebook-level default exists.
    '''
    if traj is None:
        traj = globals().get('dcd_traj')
    if LUP is None:
        LUP = globals().get('atom_id_LUP')
    if traj is None or LUP is None:
        raise ValueError('traj and LUP must be given (no notebook-level default found)')

    columns = ['type', 'typeID', 'x', 'y', 'z']
    # Build all parts and concatenate once. DataFrame.append was removed in
    # pandas 2.0, and appending in a loop is quadratic.
    parts = [
        get_xyz_perFrame(traj, LUP[key]).assign(type=key, typeID=type_id)
        for type_id, key in enumerate(k for k in LUP.keys() if k in keyNames)
    ]
    if not parts:
        return pd.DataFrame(columns=columns)
    return pd.concat(parts)[columns]


def gly_4m_featname(featname):
    '''Reduce a feature column name to its glycan/segment name.

    Removes every occurrence of ``:ROF``, ``RBD__2__``, ``:RMSD``, ``_x``,
    ``_y`` and ``_z`` (in that order), then shortens ``GLY`` to ``G``.
    For example, ``'RBD__2__GLY12'`` becomes ``'G12'``.
    '''
    name = featname
    for marker in (':ROF', 'RBD__2__', ':RMSD', '_x', '_y', '_z'):
        name = name.replace(marker, '')
    return name.replace('GLY', 'G')

In [None]:
# Load the "_1" closed and open trajectories (DCD coordinates + PSF topology).
# Both md.load calls previously referenced the undefined name `dcdFile`; each
# now loads the DCD path defined just above it.
# NOTE(review): trajDir is not read anywhere else in the visible notebook.

dcdClosed = './amarolab_covid19/TRAJECTORIES_spike_closed_prot_glyc_amarolab/spike_closed_prot_glyc_amarolab_1.dcd'
psfClosed = './amarolab_covid19/TRAJECTORIES_spike_closed_prot_glyc_amarolab/spike_closed_prot_glyc_amarolab.psf'
trajDir = os.path.dirname(dcdClosed)
trajClosed = md.load(dcdClosed, top = psfClosed)

dcdOpen = './amarolab_covid19/TRAJECTORIES_spike_open_prot_glyc_amarolab/spike_open_prot_glyc_amarolab_1.dcd'
psfOpen = './amarolab_covid19/TRAJECTORIES_spike_open_prot_glyc_amarolab/spike_open_prot_glyc_amarolab.psf'
trajDir = os.path.dirname(dcdOpen)
trajOpen = md.load(dcdOpen, top = psfOpen)

In [None]:
# Atom-index lookup table for the closed trajectory: the standard groups, then
# every atom of each glycan residue, then the RBD_CA and CH_CA C-alpha ranges.
atom_id_LUP = build_atom_lup_4_common_features(trajClosed)
atom_id_LUP['GLY'] = [
    gly_atom
    for gly in extract_glycan_residues_4m_pdb(trajClosed)
    for gly_atom in get_atom_ids_for_feature(trajClosed, f"resn =~ {gly}")
]

# NOTE(review): in mdtraj selections `resid` is the 0-based residue index,
# while `resSeq` is the residue number from the structure file. Confirm that
# these ranges are meant as indices rather than sequence numbers.
atom_id_LUP['RBD_CA'] = get_atom_ids_for_feature(trajClosed, "resid >= 330 and resid <= 530 and name == CA")
atom_id_LUP['CH_CA'] = get_atom_ids_for_feature(
    trajClosed,
    "((resid >= 747 and resid <= 784) or (resid >= 946 and resid <= 967) or (resid >= 986 and resid <= 1034)) and (name == CA)",
)

In [None]:
# Glycan/segment names derived from the trained model's feature columns.
feats = [gly_4m_featname(col) for col in train_X.columns.to_list()]

# Add an atom group for each of the first five names via a `segname` selection.
# NOTE(review): these are the first five columns of train_X, not the five
# highest-|coef| features. Confirm which was intended.
for j in feats[:5]:
    atom_id_LUP[j] = trajClosed.top.select('segname ' + j)

In [None]:
keyNames =['sidechain','RBD_CA', 'CH_CA', 'GLY','backbone']+feats[:5]
# Pass LUP explicitly rather than relying on the function's default.
Closed_coord_df = gen_xyz_Table_4_LUP(traj=trajClosed, LUP=atom_id_LUP, keyNames =keyNames)
# One marker size per plotted point. This previously used `frame_0_coord_df`,
# which exists only inside gen_xyz_Table_4_LUP (NameError here).
figClosed = px.scatter_3d(Closed_coord_df, title='Closed Spike', x='x', y='y', z='z',
                          color='type', width=800, height=800, opacity=0.5,
                          size=[1]*len(Closed_coord_df))

In [None]:
# Render the closed-state 3-D scatter built in the previous cell.
figClosed.show()

In [None]:
# NOTE(review): keyNames and atom_id_LUP were built from trajClosed's topology.
# Reusing them for trajOpen assumes both systems share the same atom ordering;
# confirm.
Open_coord_df = gen_xyz_Table_4_LUP(traj=trajOpen, LUP=atom_id_LUP, keyNames =keyNames)
# One marker size per plotted point. This previously used `frame_0_coord_df`,
# which exists only inside gen_xyz_Table_4_LUP (NameError here).
figOpen = px.scatter_3d(Open_coord_df, title='Open Spike', x='x', y='y', z='z',
                        color='type', width=800, height=800, opacity=0.5,
                        size=[1]*len(Open_coord_df))

In [None]:
# Render the open-state 3-D scatter built in the previous cell.
figOpen.show()

In [None]:
from dash import Dash, html, dcc

app = Dash(__name__)

# Page layout: title, then the feature-importance bar chart, then the closed and
# open 3-D structures side by side, then the four precision/recall heatmaps.
app.layout = html.Div(
    children=[
        html.H1(children="Predicting Effects of SARS-CoV-2 Variant Mutations on Spike Protein Dynamics and Mechanism",),
        html.Div(
            dcc.Graph(figure=fig1),
        ),
        html.Div(children=[
            dcc.Graph(
            figure=figClosed,
            style={'display': 'inline-block'}
            ),
            dcc.Graph(
            figure=figOpen,
            style={'display': 'inline-block'}
            ),
        ]),
        html.Div(children=[
            dcc.Graph(
                figure=testPrec,
                style={'display': 'inline-block'}
            ),
            dcc.Graph(
                figure=testRec,
                style={'display': 'inline-block'}
            ),
            dcc.Graph(
                figure=unseenPrec,
                style={'display': 'inline-block'}
            ),
            dcc.Graph(
                figure=unseenRec,
                style={'display': 'inline-block'}
            ),
        ])
    ]
)

if __name__ == "__main__":
    # Dash.run supersedes the deprecated Dash.run_server, which was removed in
    # Dash 3. Fall back to run_server only on older releases that lack run().
    # NOTE(review): in a Jupyter kernel __name__ is also "__main__", so this
    # starts the server from the notebook.
    _run_app = getattr(app, "run", None) or app.run_server
    _run_app(debug=False)