In [None]:
# Import libraries
import os,sys,tarfile,re
import time 
import numpy as np
import pandas as pd
import mdtraj as md
from biopandas.pdb import PandasPdb

import logging
from queue import Queue
from threading import Thread
from time import time

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from plotly.subplots import make_subplots

In [None]:
# Unpack the backed-up trajectory archive into the current working directory.
# NOTE(review): extractall() trusts the member paths stored in the archive; only run this on archives you created.
zip_path = './DCD_FeatureExt_02_21_2022/DCD_FeatureExt_02_21_2022/bckup.tar.gz'
with tarfile.open(name=zip_path, mode="r") as tf:
    # Read the member list once, report it, then extract everything.
    fnames = tf.getnames()
    print(f'Extracting {zip_path} : {fnames}')
    tf.extractall('./')

In [None]:
# Input files: PSF topology and DCD coordinates
psfFile = './bckup/spike_WE_renumbered.psf'
dcdFile = './bckup/frame012.dcd'
# Alternative input location (disabled):
#dcdFile = './amarolab_covid19/TRAJECTORIES_continuous_spike_opening_WE_chong_and_amarolab/spike_WE.dcd'
#psfFile = './amarolab_covid19/TRAJECTORIES_continuous_spike_opening_WE_chong_and_amarolab/spike_WE_renumbered.psf'
# Folder of the trajectory file; the extracted feature CSVs are written below it.
trajDir = os.path.dirname(dcdFile)

In [None]:
# Load the DCD coordinates with mdtraj, using the PSF file as the topology.
# dcd_traj is the full trajectory that every later cell slices from.
dcd_traj = md.load(dcdFile, top = psfFile)

In [None]:
def extract_glycan_residues_4m_pdb(dcdObj):
    '''Return the unique residue names found in glycan segments of a trajectory.

    The first frame of ``dcdObj`` is written to a scratch PDB file and parsed with
    biopandas. ATOM records whose ``segment_id`` starts with "G" followed by digits
    (G1, G2, ...) are treated as glycans.

    Parameters
    ----------
    dcdObj : trajectory object supporting ``dcdObj[0].save_pdb(path)``

    Returns
    -------
    Array of the unique ``residue_name`` values among the glycan ATOM records.
    '''
    tmp_pdb = '.tmp.pdb'
    try:
        dcdObj[0].save_pdb(tmp_pdb)
        pdb_atom_df = PandasPdb().read_pdb(tmp_pdb).df['ATOM']
        # Raw string: "\d" must reach the regex engine instead of being read as a string escape.
        glycan_mask = pdb_atom_df.segment_id.apply(lambda seg: bool(re.match(r'G\d+', seg)))
        return pdb_atom_df[glycan_mask].residue_name.unique()
    finally:
        # Remove the scratch file even when writing or parsing the PDB fails.
        if os.path.exists(tmp_pdb):
            os.remove(tmp_pdb)

In [None]:
def get_atom_ids_for_feature(dcd_traj,feature='protein'):
    '''Return the atom ids selected by ``feature`` in the trajectory topology.

    Parameters
    ----------
    dcd_traj : object exposing ``top.select(selection_string)``
    feature : str
        Selection expression passed unchanged to ``dcd_traj.top.select``.

    Returns
    -------
    list
        The selected atom ids. When the selection call raises, an error line is
        printed and an empty list is returned, so callers can always iterate
        over the result.
    '''
    try:
        atom_ids = dcd_traj.top.select(feature)
    except Exception:
        print(f'[ERROR] {feature} not recognized for atom filtering')
        return []
    return list(atom_ids)

def build_atom_lup_4_common_features(dcd_traj,flist = ['protein', 'backbone','sidechain']):
    '''Build a lookup {selection name: atom ids} for every selection in ``flist``.'''
    atom_lookup = {}
    for selection in flist:
        atom_lookup[selection] = get_atom_ids_for_feature(dcd_traj, selection)
    return atom_lookup

In [None]:
def metric_4m_mDtraj(dcdObj):
    '''Compute the mdtraj radius of gyration, density and centre of mass of ``dcdObj``.

    Returns a dict keyed 'Rofguration', 'density' and 'compute_center_of_mass'
    holding the results of ``md.compute_rg``, ``md.density`` and
    ``md.compute_center_of_mass`` respectively.
    '''
    radius_of_gyration = md.compute_rg(dcdObj)
    mass_density = md.density(dcdObj)
    center_of_mass = md.compute_center_of_mass(dcdObj)
    return {
        'Rofguration': radius_of_gyration,
        'density': mass_density,
        'compute_center_of_mass': center_of_mass,
    }

In [None]:
def rmsd_by_mdTraj(f1,f2):
    '''Return the first value of ``md.rmsd(f1, f2)``.'''
    rmsd_values = md.rmsd(f1, f2)
    return rmsd_values[0]
def gen_frame_tuples(dcdObj):
    '''Yield (frame, previous frame) index pairs: (1, 0), (2, 1), ..., (n_frames-1, n_frames-2).'''
    frame_count = dcdObj.n_frames
    return ((later, later - 1) for later in range(1, frame_count))

In [None]:
def save_dcdMetric(addon_metrics,fileOut = f'{trajDir}/dcd_ExtractedMetrics.csv'):
    '''Flatten per-frame metrics into one table, write it to ``fileOut`` as CSV and return it.

    Parameters
    ----------
    addon_metrics : dict
        Maps a metric name to its per-frame values. Lists and 1-D arrays become
        one column each; 2-D arrays are split into ``<name>_<column index>``
        columns. The keys 'Rofguration', 'density', 'rmsd' and a
        'compute_center_of_mass' array with at least three columns are required,
        because the output selects the columns derived from them.
    fileOut : str
        Destination CSV path. Missing parent folders are created.

    Returns
    -------
    pandas.DataFrame
        Columns: frame, rofgyration, density, rmsd, COM_x, COM_y, COM_z.
    '''
    dcd_metricOut_df = pd.DataFrame()
    for metric_name, values in addon_metrics.items():
        if isinstance(values, list) or len(values.shape) == 1:
            dcd_metricOut_df[metric_name] = values
        else:
            # One output column per column of a 2-D metric (e.g. the x/y/z of a centre of mass).
            for col in range(values.shape[1]):
                dcd_metricOut_df[f'{metric_name}_{col}'] = values[:, col]

    dcd_metric_Out_df = (dcd_metricOut_df
        .assign(frame = lambda df : [f'frame_{i}' for i in range(df.shape[0])])
        .rename(columns = {'Rofguration' : 'rofgyration',
                           'compute_center_of_mass_0' : 'COM_x',
                           'compute_center_of_mass_1' : 'COM_y',
                           'compute_center_of_mass_2' : 'COM_z'})
        .loc[:, ['frame', 'rofgyration', 'density', 'rmsd', 'COM_x', 'COM_y', 'COM_z']]
    )

    # makedirs copes with nested output folders; the guard skips a bare file name,
    # whose dirname is '' and cannot be created.
    out_dir = os.path.dirname(fileOut)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The dataframe index is written as well (pandas default), so a plain read_csv
    # of this file yields an extra 'Unnamed: 0' column.
    dcd_metric_Out_df.to_csv(fileOut)
    return dcd_metric_Out_df


In [None]:
def extract_features_4m_filteredTraj(trajObj,trajName):
    '''Extract whole-selection metrics for ``trajObj``, save them and return the table.

    The metrics come from ``metric_4m_mDtraj`` plus an 'rmsd' list measured against
    frame 0. The table is written to ``{trajDir}/{trajName}/featureMetric.csv`` and
    returned with an extra ``tracjectory_KEY`` column set to ``trajName``.
    '''
    metrics = metric_4m_mDtraj(trajObj)

    # NOTE(review): entry i holds the rmsd of frame i+1 against frame 0, and the last frame is
    # appended a second time to pad the list to n_frames - confirm this one-frame shift is intended.
    rmsd_to_first = []
    for frame_idx in range(1, trajObj.n_frames):
        rmsd_to_first.append(md.rmsd(trajObj[frame_idx], trajObj[0])[0])
    rmsd_to_first.append(md.rmsd(trajObj[-1], trajObj[0])[0])
    metrics['rmsd'] = rmsd_to_first

    out_csv = f'{trajDir}/{trajName}/featureMetric.csv'
    metric_df = save_dcdMetric(metrics, fileOut=out_csv)
    return metric_df.assign(tracjectory_KEY = trajName)

In [None]:
def extract_featuresPerChain_4m_filteredTraj(trajObj,trajName, chainIds = []):
    '''Extract metrics for individual chains of ``trajObj``, save one CSV per chain and return all rows.

    Parameters
    ----------
    trajObj : trajectory whose chains are analysed
    trajName : str
        Name of the output sub-folder under ``trajDir``; also stored in the
        ``tracjectory_KEY`` column.
    chainIds : sequence of int
        Chain indices to process. An empty sequence means "all chains".

    Returns
    -------
    pandas.DataFrame
        The per-chain tables stacked together, with ``chain_ID`` and
        ``tracjectory_KEY`` columns added. Empty when no chain index matches.
    '''
    if len(chainIds) == 0 :
        chainIds = range(trajObj.n_chains)

    # Atom ids of every requested chain
    chain_LUP = {f'chainID_{c}' : get_atom_ids_for_feature(trajObj,f'chainid == {c}')
                 for c in range(trajObj.n_chains) if c in chainIds}
    print(f'[INFO] {trajName} {len(chain_LUP)} chains considered for feature extraction')

    all_frames = list(range(trajObj.n_frames))
    per_chain_tables = []
    for k in chain_LUP:
        # Cut the chain out of the trajectory
        cur_time = time()
        print(f'[INFO] deriving Chain {k} tracjectory')
        chainObj = derive_trajectory(trajObj,frames=all_frames,atom_key=k,LUP = chain_LUP)
        print(f'[INFO] Chain {k} tracjectory derivation completed in {round(time() - cur_time,2) } seconds')

        # Radius of gyration, density and centre of mass, plus rmsd against frame 0
        cur_time = time()
        addon_metrics = metric_4m_mDtraj(chainObj)
        # NOTE(review): entry i is the rmsd of frame i+1 against frame 0 and the last frame is
        # repeated to pad the list to n_frames - confirm this one-frame shift is intended.
        addon_metrics['rmsd'] = [md.rmsd(chainObj[i+1],chainObj[0])[0] for i in range(chainObj.n_frames - 1)]
        addon_metrics['rmsd'].append(md.rmsd(chainObj[-1],chainObj[0])[0])
        del chainObj

        # Write the chain's CSV and keep its table for the combined result
        per_chain_tables.append(
            save_dcdMetric(addon_metrics,fileOut=f'{trajDir}/{trajName}/featureMetric__{k}.csv')
            .assign(chain_ID = k)
            .assign(tracjectory_KEY = trajName)
        )
        print(f'[INFO] Feature extraction for chain {k} completed in {round(time() - cur_time,2) } seconds')

    # Concatenate once at the end: DataFrame.append was removed in pandas 2.0 and
    # growing a frame inside the loop copies it on every iteration.
    if not per_chain_tables:
        return pd.DataFrame()
    return pd.concat(per_chain_tables)

In [None]:
def derive_trajectory(traj_Full,frames=[0], atom_key = 'backbone', LUP = {}):
    '''Return ``traj_Full`` restricted to ``frames`` and to the atoms stored under ``LUP[atom_key]``.'''
    frame_subset = traj_Full[frames]
    wanted_atoms = LUP[atom_key]
    return frame_subset.atom_slice(wanted_atoms)

In [None]:
# Atom ids of the top-level structures, plus an empty 'GLY' entry that is filled with glycan atoms below
atom_id_LUP = {**build_atom_lup_4_common_features(dcd_traj), 'GLY': []}

In [None]:
# Extract glycan atom ids: one selection per glycan residue name, merged into atom_id_LUP['GLY']
glycan_atom_ids = set()
for gly in extract_glycan_residues_4m_pdb(dcd_traj):
    # "or []" keeps the cell running when a selection yields nothing usable
    glycan_atom_ids.update(get_atom_ids_for_feature(dcd_traj,f"resn =~ {gly}") or [])
# Assign instead of appending so that re-running this cell cannot duplicate ids.
# The ids are kept unique and ascending: presumably atom_slice cuts coordinates in the order
# given while rebuilding the topology in its original atom order, so unsorted ids could
# misalign the two - NOTE(review): confirm against the mdtraj documentation.
atom_id_LUP['GLY'] = sorted(glycan_atom_ids)

In [None]:
# C-alpha atom ids of the receptor binding domain (RBD) and the central helix (CH)
# NOTE(review): the mdtraj selection keyword "resid" may be a zero-based residue index rather than
# the residue number of the structure file ("resSeq") - confirm these ranges pick the intended residues.
atom_id_LUP.update({
    'RBD_CA': get_atom_ids_for_feature(dcd_traj,"resid >= 330 and resid <= 530 and name == CA"),
    'CH_CA': get_atom_ids_for_feature(dcd_traj,"((resid >= 747 and resid <= 784) or (resid >= 946 and resid <= 967) or (resid >= 986 and resid <= 1034)) and (name == CA)"),
})

In [None]:
#atom_id_LUP.keys()

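#### Find GLY chains within 4A of the RBD (centre-of-mass distance; the code applies `distance <= 4` to mdtraj coordinates, so confirm the unit of the cutoff)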

In [None]:
## Build frame-0-only trajectories from the filtered GLY and RBD atoms (derive_trajectory defaults to frames=[0])
traj_GLY_F0, traj_RBD_F0 = (
    derive_trajectory(dcd_traj, atom_key=selection_key, LUP=atom_id_LUP)
    for selection_key in ('GLY', 'RBD_CA')
)

In [None]:
## Centre of mass (COM) of every chain in the frame-0 glycan trajectory
# Atom ids of each chain inside the glycan-only trajectory
GLY_chain_LUP = {f'chainID_{c}' : get_atom_ids_for_feature(traj_GLY_F0,f'chainid == {c}')  for c in range(traj_GLY_F0.n_chains)}
GLY_chain_COM = {}
gly_com_rows = []
for k in GLY_chain_LUP.keys():
    # Cut the chain out of the glycan trajectory and compute its COM
    chainObj = derive_trajectory(traj_GLY_F0,frames=list(range(traj_GLY_F0.n_frames)),atom_key= f'{k}',LUP = GLY_chain_LUP)
    GLY_chain_COM[k] = md.compute_center_of_mass(chainObj)
    # Keep the first frame's COM as one table row
    com_x, com_y, com_z = GLY_chain_COM[k][0]
    gly_com_rows.append({'chain': k, 'x': com_x, 'y': com_y, 'z': com_z})
# Build the table once: DataFrame.append was removed in pandas 2.0 and copies the frame on every call.
GLY_RBD_proximity_df = pd.DataFrame(gly_com_rows, columns=['chain','x','y','z'])

In [None]:
# Centre of mass of the RBD, repeated on every glycan row as RBD_x / RBD_y / RBD_z
RBD_COM = md.compute_center_of_mass(traj_RBD_F0)
for axis_pos, axis_name in enumerate(('x', 'y', 'z')):
    GLY_RBD_proximity_df[f'RBD_{axis_name}'] = RBD_COM[0][axis_pos]

In [None]:
### Distance between the RBD centre of mass and each glycan chain's centre of mass, nearest first
GLY_RBD_proximity_df = (GLY_RBD_proximity_df
    .assign(distance = lambda df : np.sqrt(
        (np.square(df.x - df.RBD_x)
         + np.square(df.y - df.RBD_y)
         + np.square(df.z - df.RBD_z)).astype(float)))
    .sort_values('distance')
)
# Keep the chains whose distance is at most 4 and turn "chainID_<n>" back into the integer n.
# NOTE(review): the comments call this cutoff "4A" - confirm the unit of the mdtraj coordinates the distance is computed from.
glycans_near_rbd = GLY_RBD_proximity_df[GLY_RBD_proximity_df.distance <= 4]
GLY_chain_ids_next_to_RBD = glycans_near_rbd.chain.str.extract(r'chainID_(\d+)')[0].astype(int).tolist()
#GLY_chain_ids_next_to_RBD, atom_id_LUP.keys()

#### Extract a feature matrix from the backbone, RBD & CH without splitting them into chains
#### Extract a feature matrix from each shortlisted glycan chain in GLY_chain_ids_next_to_RBD

In [None]:
def launchFeatureExract(dcdObj,gly_chains, LUP=atom_id_LUP):
    '''Extract features for every selection in ``LUP`` and write each feature set to file.

    The 'protein' and 'sidechain' selections are skipped. The 'GLY' selection is
    processed chain by chain, restricted to ``gly_chains``; every other selection
    is processed as a whole.

    Parameters
    ----------
    dcdObj : trajectory to extract from
    gly_chains : sequence of int
        Chain indices of the 'GLY' trajectory to process (an empty sequence
        makes the per-chain extractor process all chains).
    LUP : dict
        Maps a selection name to its atom ids. Defaults to ``atom_id_LUP``.

    Returns
    -------
    None. The results are written to CSV by the extraction helpers.
    '''
    start_0 = time()
    # Iterate the lookup that was passed in, so a caller-supplied LUP is honoured.
    for k in LUP.keys():
        if k == 'protein' or k == 'sidechain':
            continue
        cur_time = time()
        print(f'[INFO] deriving Tracjectory for {k}')
        curTraj = derive_trajectory(dcdObj,frames=list(range(dcdObj.n_frames)),atom_key=k ,LUP = LUP)
        print(f'[INFO] Tracjectory derivation completed in {round(time() - cur_time,2) } seconds')

        if k == 'GLY' :
            print(f'[INFO] Starting Feature extraction for {gly_chains} of feature {k}')
            extract_featuresPerChain_4m_filteredTraj(curTraj,k, chainIds = gly_chains)
        else :
            print(f'[INFO] Starting Feature extraction for {k}')
            extract_features_4m_filteredTraj(curTraj,k)

        print(f'[INFO] Feature extraction for  {k} completed in {round(time() - cur_time,2) } seconds')
        # Release the sliced trajectory before the next selection is derived
        del curTraj

    print(f'[INFO] Time elapsed for Feature extraction {round(time() - start_0,2) } seconds')
    

In [None]:
# Extract features for the RBD, CH, backbone, and each glycan chain shortlisted as close to the RBD.
# Each feature set is written to CSV under trajDir by the helpers that launchFeatureExract calls.
# (disabled) saves the first three frames as a small DCD file:
#dcd_traj[:3].save_dcd(f'{trajDir}/frame012.dcd')
launchFeatureExract(dcd_traj,GLY_chain_ids_next_to_RBD)

#### Generate a 3D scatter plot to help visualize the selections

In [None]:
def get_xyz_perFrame(traj,atom_ids):
    '''Return the first-frame coordinates of ``atom_ids`` as a dataframe with columns x, y, z.'''
    first_frame_coords = traj.xyz[0,atom_ids]
    return pd.DataFrame(first_frame_coords, columns=['x','y','z'])

In [None]:
def gen_xyz_Table_4_LUP(LUP = atom_id_LUP, keyNames =['sidechain','RBD_CA', 'CH_CA', 'GLY','backbone'] ):
    '''Build one table of first-frame coordinates for the selections of ``LUP`` named in ``keyNames``.

    Parameters
    ----------
    LUP : dict
        Maps a selection name to its atom ids in the global ``dcd_traj``.
    keyNames : list of str
        Selection names to include; entries of ``LUP`` not listed here are skipped.

    Returns
    -------
    pandas.DataFrame
        Columns type, typeID, x, y, z. ``type`` is the selection name and
        ``typeID`` its running number (in ``LUP`` order) among the included
        selections. Empty with the same columns when nothing matches.
    '''
    column_order = ['type','typeID','x','y','z']
    selected_keys = [k for k in LUP.keys() if k in keyNames]
    if not selected_keys:
        return pd.DataFrame(columns=column_order)
    # Build every per-selection table first and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and copies the accumulated frame on every call.
    tables = [
        get_xyz_perFrame(dcd_traj, LUP[k]).assign(type = k).assign(typeID = i)
        for i, k in enumerate(selected_keys)
    ]
    return pd.concat(tables)[column_order]

In [None]:
# First-frame coordinates of the default selections, one colour per selection type
frame_0_coord_df = gen_xyz_Table_4_LUP()
scatter_style = dict(color='type', width=800, height=800, opacity=0.5)
fig = px.scatter_3d(frame_0_coord_df, x='x', y='y', z='z', **scatter_style)


In [None]:
# Render the 3D scatter of the frame-0 coordinates built in the previous cell
fig.show()


- RBD , CH  (com)
- sidechain + glycans -->  (G1-G70)  AI (Anand/Lorenzo)
- Monomer A/B/C  --> (Needs Info)  AI (Anand/Lorenzo)
- backbone (low priority)


- Monomers A/B/C are each composed of a group of chains. Do these chainIDs need to be provided by the lab/data experts?


#### Read-in extracted feature per chain for RBD/backbone/CH

In [None]:
import glob

In [None]:
# Read the extracted feature CSVs back in and stack them into a single dataframe.
# The files are looked up under trajDir, the same folder the extraction step wrote them to.
featureFile_dict = {k : sorted(glob.glob(f'{trajDir}/{k}/*csv')) for k in ['backbone','RBD_CA', 'CH_CA', 'GLY']}
feature_tables = []
for k, csv_paths in featureFile_dict.items():
    for f in csv_paths:
        # Test the file name only, so a folder name containing "chain" cannot misclassify a file
        if 'chain' in os.path.basename(f):
            # Per-chain files are named featureMetric__chainID_<n>.csv
            cid = int(os.path.basename(f).split('_')[-1].replace('.csv',''))
            # Only the glycan chains shortlisted as close to the RBD are loaded
            if k == 'GLY' and cid in GLY_chain_ids_next_to_RBD:
                feature_tables.append(pd.read_csv(f).assign(feature = k).assign(chainID = cid))
        else :
            feature_tables.append(pd.read_csv(f).assign(feature = k).assign(chainID = 0))
if not feature_tables:
    raise FileNotFoundError(f'no feature CSV files found under {trajDir}; run the feature extraction first')
# Concatenate once (DataFrame.append was removed in pandas 2.0). Each file keeps its own
# 0..n-1 row index; 'Unnamed: 0' is the index column that to_csv wrote.
feature_df = pd.concat(feature_tables).drop(['Unnamed: 0'],axis=1)

In [None]:
# Name each row's feature as structure name + chain ID (e.g. 'RBD_CA0').
# chainID is cast to str first so it can be joined to the feature name.
feature_df = feature_df.assign(feature_chain = lambda df  : df.feature +  df.chainID.astype(str))

In [None]:
# Plot the frame-0 centres of mass of every feature together with the RBD / CH C-alpha atoms
com_points_F0 = (feature_df[feature_df.frame == 'frame_0']
    .assign(feature_chain = lambda df : df.feature_chain + '_COM')
    .rename(columns= {'COM_x' : 'x', 'COM_y' : 'y', 'COM_z' : 'z', 'feature_chain' : 'type', 'chainID' : 'typeID'})
    # Columns are picked by name, so no positional slicing of the frame is needed
    .loc[:,['type','typeID','x','y','z']]
)
ca_points_F0 = frame_0_coord_df[frame_0_coord_df.type.isin(['RBD_CA','CH_CA'])]
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0
fig = px.scatter_3d(pd.concat([com_points_F0, ca_points_F0]),
    x='x', y='y', z='z',  color='type',width=800,height=800,opacity=0.5,
)

#### Center of Mass of Filtered Glycan chains in the vicinity of RBD

In [None]:
# Render the frame-0 centre-of-mass scatter built in the previous cell
fig.show()

In [None]:
def extract_distance_metric(df1,df2):
    '''Return, per shared frame, the Euclidean distance between the COM columns of two tables.

    Both tables need the columns frame, COM_x, COM_y and COM_z. Rows are paired
    on ``frame`` with an inner merge, and the distances are returned as a list in
    the merged row order.
    '''
    paired = df1.merge(df2, left_on=['frame'], right_on=['frame'], how = 'inner', suffixes = ['_1','_2'])
    squared_gap = (
        np.square(paired.COM_x_1 - paired.COM_x_2)
        + np.square(paired.COM_y_1 - paired.COM_y_2)
        + np.square(paired.COM_z_1 - paired.COM_z_2)
    )
    distances = np.sqrt(squared_gap.astype(float))
    return distances.to_list()

In [None]:
# Preview the combined feature table
feature_df.head()

In [None]:
# One row per frame: the RBD's own metrics, then for every other feature chain its
# centre-of-mass distance to the RBD followed by that chain's metrics.
common_features = ['rofgyration','density','rmsd']
rbd_rows = feature_df[feature_df.feature_chain == 'RBD_CA0']
final_feature_df = pd.DataFrame(columns=['frame'], data = feature_df[feature_df.feature == 'RBD_CA'].frame.to_list())
for c in common_features:
    # Series assignment aligns on the row index
    final_feature_df[f'RBD_CA0:{c}'] = rbd_rows[c]

for f in sorted(feature_df.feature_chain.unique()):
    # Skip the RBD itself and any sidechain features
    if f == 'RBD_CA0' or re.match(r'sidechain\d+',f):
        continue
    # Glycan chains are used only when they were shortlisted as close to the RBD
    match_object = re.match(r'GLY(\d+)',f)
    if match_object is not None and int(match_object.group(1)) not in GLY_chain_ids_next_to_RBD:
        continue
    chain_rows = feature_df[feature_df.feature_chain == f]
    final_feature_df[f'RBD__2__{f}'] = extract_distance_metric(rbd_rows, chain_rows)
    for c in common_features:
        final_feature_df[f'{f}:{c}'] = chain_rows[c]



In [None]:
# Display the assembled per-frame feature table
final_feature_df

In [None]:
# Show the notebook's working directory (the relative ./bckup paths are resolved against it)
!pwd

### OLD Review Plots

#### RBD and CH COM 

In [None]:
# Frame-0 centres of mass of the C-alpha selections (features whose name contains '_CA').
# chainID is assigned as an integer when the CSVs are read, so it is cast to str before
# being joined to the feature name; str + int raises TypeError.
feature_df_F0 = (feature_df[(feature_df.frame == 'frame_0') & feature_df.feature.str.contains('_CA')]
    .assign(feature_chain = lambda df : df.feature + df.chainID.astype(str))
)
feature_chain_trace =  go.Scatter3d(
    x=feature_df_F0.COM_x,
    y=feature_df_F0.COM_y,
    z=feature_df_F0.COM_z,
    mode='markers',
    name='RBD_n_CH_COM',
    hovertext= feature_df_F0.feature_chain,
    marker=dict(
        size=18,
        color ='black',
        colorscale='Viridis',
        opacity=0.8
    ),
)
feature_df_F0

In [None]:
# Frame-0 atoms of the plotted selections, coloured by selection number (typeID)
frame_0_marker = {
    'size': 12,
    'color': frame_0_coord_df.typeID,
    'colorscale': 'Viridis',
    'opacity': 0.5,
}
frame_0_trace = go.Scatter3d(
    x=frame_0_coord_df.x,
    y=frame_0_coord_df.y,
    z=frame_0_coord_df.z,
    mode='markers',
    name='frame_0_scatter',
    hovertext=frame_0_coord_df.type,
    marker=frame_0_marker,
)
# Overlay the centre-of-mass trace from the previous cell on the atom scatter
fig = go.Figure(data=[frame_0_trace,feature_chain_trace]).update_layout(width=800, height=800)
fig.show()

#### Filtered GLY chain COMs close to the RBD
- see the black dots in the plot below

In [None]:
# Frame-0 rows of the glycan features (feature name contains 'GL') whose chain was shortlisted as close to the RBD
feature_df[(feature_df.frame == 'frame_0')  & feature_df.feature.str.contains('GL') & feature_df.chainID.astype(int).isin(GLY_chain_ids_next_to_RBD)].sort_values(by=['chainID'])

In [None]:
# Frame-0 centres of mass of the glycan chains shortlisted as close to the RBD.
# chainID is cast to str before being joined to the feature name; str + int raises TypeError.
feature_df_F0 = (feature_df[(feature_df.frame == 'frame_0')
                            & feature_df.feature.str.contains('GL')
                            & feature_df.chainID.astype(int).isin(GLY_chain_ids_next_to_RBD)]
    .sort_values(by=['chainID'])
    .assign(feature_chain = lambda df : df.feature + df.chainID.astype(str))
)
feature_chain_trace =  go.Scatter3d(
    x=feature_df_F0.COM_x,
    y=feature_df_F0.COM_y,
    z=feature_df_F0.COM_z,
    mode='markers',
    name='GLY_chains',
    hovertext= feature_df_F0.feature_chain,
    marker=dict(
        size=18,
        color ='black',
        colorscale='Viridis',
        opacity=0.8
    ),
)


In [None]:
# Overlay the glycan centre-of-mass markers on the frame-0 atom scatter
fig = go.Figure(data=[frame_0_trace,feature_chain_trace]).update_layout(width=800, height=800)
fig.show()

#### Backbone and its 6 chains with their COMs

In [None]:
# Frame-0 centres of mass of the features named '...backbone' and '...backbone_GLY'.
# chainID is cast to str before being joined to the feature name; str + int raises TypeError.
feature_df_F0_b = (feature_df[(feature_df.frame == 'frame_0') & feature_df.feature.str.endswith('backbone')]
    .assign(feature_chain = lambda df : df.feature + df.chainID.astype(str))
)
feature_chain_trace_b =  go.Scatter3d(
    x=feature_df_F0_b.COM_x,
    y=feature_df_F0_b.COM_y,
    z=feature_df_F0_b.COM_z,
    name='backbone_COM',
    mode='markers',
    hovertext= feature_df_F0_b.feature_chain,
    marker=dict(
        size=18,
        color ='cyan',
        colorscale='Viridis',
        opacity=0.4
    ),
)
# NOTE(review): the cell that reads the CSVs only loads 'backbone', 'RBD_CA', 'CH_CA' and 'GLY',
# so no feature name ends with 'backbone_GLY' and this selection is presumably empty - confirm.
feature_df_F0_b_GLY = (feature_df[(feature_df.frame == 'frame_0') & feature_df.feature.str.endswith('backbone_GLY')]
    .assign(feature_chain = lambda df : df.feature + df.chainID.astype(str))
)
feature_chain_trace_b_GLY =  go.Scatter3d(
    x=feature_df_F0_b_GLY.COM_x,
    y=feature_df_F0_b_GLY.COM_y,
    z=feature_df_F0_b_GLY.COM_z,
    name='backbone_GLY_COM',
    mode='markers',
    hovertext= feature_df_F0_b_GLY.feature_chain,
    marker=dict(
        size=18,
        color ='purple',
        colorscale='Viridis',
        opacity=0.8
    ),
)
feature_df_F0_b_GLY

In [None]:

# Overlay both backbone centre-of-mass traces on the frame-0 atom scatter
backbone_traces = [frame_0_trace, feature_chain_trace_b, feature_chain_trace_b_GLY]
fig = go.Figure(data=backbone_traces).update_layout(width=800, height=800)
fig.show()

In [None]:
# Report the chain and residue counts of the topology, and how many residues have a
# string form starting with "GLY".
# NOTE(review): "GLY" is the usual residue name of glycine, so this count may be glycine
# residues rather than glycans (which extract_glycan_residues_4m_pdb finds by segment id) - confirm.
# The backslash continuations sit inside the f-string, so the next lines' leading spaces are printed too.
print(f'count of chains = {dcd_traj.top.n_chains}, \
         count of residues = {dcd_traj.top.n_residues},\
        count of potential Glycans = {len([r for r in dcd_traj.top.residues if str(r).startswith("GLY")]) } '
)

In [None]:
# List the residues whose string form starts with 'GLY'
# NOTE(review): presumably glycine residues rather than glycans - confirm.
[r for r in dcd_traj.top.residues if str(r).startswith('GLY')]

In [None]:
# Atom ids whose residue name matches the pattern 'GLY*'
# NOTE(review): if '=~' is a regular-expression match, 'GLY*' means "GL" followed by zero or more "Y",
# not "starts with GLY" - confirm the intended pattern.
dcd_traj.top.select("resn =~ 'GLY*'")

In [None]:
# Atom ids of every chain in the full trajectory, keyed 'chainID_<n>'
atom_id_LUP_chains = {}
for chain_idx in range(dcd_traj.n_chains):
    atom_id_LUP_chains[f'chainID_{chain_idx}'] = get_atom_ids_for_feature(dcd_traj,f'chainid == {chain_idx}')

In [None]:
# Frame-0 coordinates of all atoms, one colour per chain
chain_coord_df = gen_xyz_Table_4_LUP(LUP=atom_id_LUP_chains,keyNames=list(atom_id_LUP_chains.keys()))
chain_scatter_style = dict(color='type', width=800, height=800, opacity=0.4)
fig = px.scatter_3d(chain_coord_df, x='x', y='y', z='z', **chain_scatter_style)


In [None]:
# Render the per-chain scatter built in the previous cell
fig.show()