In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.ml import ProteinSequenceEncoder
import numpy as np
import pandas as pd
import math
import os
import py3Dmol
from mmtfPyspark.datasets import secondaryStructureExtractor
import secondaryStructureExtractorFull
from glob import glob
import re

In [None]:
# Load one section of the PDB sequence-file archive (the archive has 722 parts in total).
# Each structure is flat-mapped through StructureToPolymerChains and the result is filtered
# with ContainsLProteinChain.

pdb = (
    mmtfReader.read_sequence_file('full/part-00001')
    .flatMap(StructureToPolymerChains())
    .filter(ContainsLProteinChain())
)

In [None]:
# Two pandas tables are extracted from the same set of chains:
#   data   - from secondaryStructureExtractor: for each protein chain, the PDB-ID and DSSPQ3 code,
#            along with other secondary structure information.
#   data_1 - from secondaryStructureExtractorFull: for each protein chain residue, the amino acid
#            identity, phi angle, and psi angle.

data = (
    secondaryStructureExtractor
    .get_dataset(pdb)
    .toPandas()
)
data_1 = (
    secondaryStructureExtractorFull
    .get_dataset(pdb)
    .toPandas()
)

In [None]:
# This function classifies a residue as a cap or not. Its definition of a cap depends on the context of the residue.
# This definition of a cap has three criteria. First, a residue must not be helical or beta strand. Second, the residue
# immediately before, or immediately after, the residue under examination must be helical. Third, the two residues on
# the side of the residue under examination, opposite from the helix, must themselves not be helical. (This last
# requirement ensures that small 'kinks' in alpha helices are not classified as caps.)

def cap_identifier(q3sequence, residue_number):
    """Classify one residue of a Q3 secondary-structure sequence as a helix cap.

    A residue is a cap when all of the following hold:

    1. the residue itself is coil (``'C'``), i.e. neither helix nor strand;
    2. the residue immediately before or immediately after it is helical (``'H'``);
    3. the (up to) two residues on the side facing away from that helix are
       coil as well, so that a short kink inside a helix is not reported as
       a cap.

    Neighbours that would fall outside the sequence are treated as absent and
    never block a cap, which is how residues next to either chain end are
    handled.

    Args:
        q3sequence: Indexable sequence (e.g. a string) of one-letter Q3 states.
        residue_number: Zero-based position of the residue to classify.

    Returns:
        1 if the residue is a cap, otherwise 0. Positions outside the
        sequence, and sequences too short to contain a helix neighbour,
        yield 0 instead of raising.
    """
    last_position = len(q3sequence) - 1

    def state_at(position):
        # None marks "no residue here"; the explicit lower bound also stops
        # negative positions from wrapping around to the other chain end.
        if 0 <= position <= last_position:
            return q3sequence[position]
        return None

    if state_at(residue_number) != 'C':
        return 0

    before_1 = state_at(residue_number - 1)
    before_2 = state_at(residue_number - 2)
    after_1 = state_at(residue_number + 1)
    after_2 = state_at(residue_number + 2)

    coil_or_absent = ('C', None)

    # Every position gets the same test, including positions 2 and len-3,
    # whose +/-2 look-ups are still inside the sequence.

    # Helix starts right after this residue: look backwards for coil.
    if after_1 == 'H' and before_1 in coil_or_absent and before_2 in coil_or_absent:
        return 1

    # Helix ends right before this residue: look forwards for coil.
    if before_1 == 'H' and after_1 in coil_or_absent and after_2 in coil_or_absent:
        return 1

    return 0

In [None]:
# For the purposes of calculating a distance metric, it was convenient to normalize phi and psi angles, by the sine
# function, to be between -1 and 1.

def angle_to_sin(angle):
    """Return the sine of an angle given in degrees.

    An angle of exactly zero, or a NaN angle, maps to the integer 0; every
    other angle is converted to radians and passed to ``np.sin``, so the
    result lies between -1 and 1.
    """
    # Guard first: the zero test short-circuits before np.isnan is consulted.
    if angle != 0 and not np.isnan(angle):
        return np.sin(np.pi * angle/180)
    return 0

In [None]:
# This set of functions was created to calculate distance metrics for the identification of nearest-neighbor
# residues. The 'distance' function computes the distance between two residues.
# The 'nearest_neighbor' function identifies the single nearest neighbor, from the training dataset, to a given
# residue. The 'k_nearest_neighbors' function identifies, in an analogous manner, the k nearest neighbors. The
# 'compute_farthest' function is useful for the k-nearest-neighbors approach, as it identifies which of the nearest
# neighbors is the farthest from the residue in question; in turn, this residue is to be 'bumped' from the list if a
# closer residue is found.

def distance(residue_1, residue_2):
    """Distance between two residues, each given as an indexable record.

    Item 0 of a record is compared for equality (the residue identity);
    items 1 and 2 are numeric and compared by absolute difference. The
    result is the cube root of: the cubed item-1 gap, plus the cubed item-2
    gap, plus a fixed penalty of 2 when the identities differ.
    """
    phi_gap, psi_gap = (np.abs(residue_1[field] - residue_2[field]) for field in (1, 2))

    # 0 for identical residue identities, 1 otherwise.
    mismatch = 0 if residue_1[0] == residue_2[0] else 1

    return (phi_gap ** 3 + psi_gap ** 3 + mismatch * 2) ** (1. / 3)

def nearest_neighbor(training_data, test_residue):
    """Find the training residue closest to ``test_residue``.

    Args:
        training_data: DataFrame read by column position: columns 0-2 form
            the residue record handed to ``distance`` (identity and the two
            angle features), column 3 is the cap label, column 4 the PDB id
            and column 5 the residue number.
        test_residue: Record ``[identity, feature_1, feature_2]`` to match.

    Returns:
        ``[cap_prediction, residue_match, min_dist, pdb_id_match,
        res_num_match]`` for the closest training row. When several rows are
        equally close, the first one encountered wins.

    Raises:
        ValueError: If no training row yields a comparable distance (the
            frame is empty, or every distance is NaN).
    """
    best_match = None
    # Start from infinity rather than a fixed sentinel so that a neighbour is
    # found however large the distances are.
    min_dist = float('inf')

    # itertuples yields plain positional tuples: far cheaper than iterrows and
    # free of the deprecated integer lookup on a label-indexed Series.
    for values in training_data.itertuples(index=False, name=None):

        train_residue = [values[0], values[1], values[2]]
        dist = distance(train_residue, test_residue)

        if dist < min_dist:
            min_dist = dist
            best_match = [values[3], train_residue, dist, values[4], values[5]]

    if best_match is None:
        raise ValueError(
            "nearest_neighbor: training_data contains no row with a comparable distance"
        )

    return best_match

def k_nearest_neighbors(training_data, test_residue, k):
    """Collect the ``k`` training residues closest to ``test_residue``.

    The first ``k`` training rows seed the neighbour list. Each later row
    replaces the currently farthest neighbour when it is strictly closer;
    the replacement takes over the evicted row's label and is placed at the
    end of the list.

    Args:
        training_data: DataFrame read by column position; the first six
            columns are residue, sin_phi, sin_psi, is_cap, pdb_id, res_num.
        test_residue: Record ``[identity, feature_1, feature_2]`` to match.
        k: Number of neighbours to keep.

    Returns:
        DataFrame with columns ``residue, sin_phi, sin_psi, is_cap, pdb_id,
        res_num`` holding at most ``k`` rows (fewer when the training data is
        smaller; empty when ``k`` is below 1 or there is no training data).
    """
    columns = ['residue', 'sin_phi', 'sin_psi', 'is_cap', 'pdb_id', 'res_num']

    if k < 1:
        return pd.DataFrame(columns=columns)

    # Each entry is [row label, distance to test_residue, six-field record].
    # Caching the distance avoids re-measuring every neighbour on each pass.
    kept = []

    def farthest_slot():
        # First slot holding the largest cached distance. The builtin `max`
        # is deliberately not used: this notebook star-imports
        # pyspark.sql.functions, which exports a `max` of its own.
        slot_of_farthest = 0
        for slot in range(1, len(kept)):
            if kept[slot][1] > kept[slot_of_farthest][1]:
                slot_of_farthest = slot
        return slot_of_farthest

    worst = None

    # Rows are counted by position, so the seeding phase does not depend on
    # the frame carrying a default 0..n-1 index.
    for position, values in enumerate(training_data.itertuples(index=False, name=None)):

        training_residue = list(values[:6])
        dist = distance(training_residue, test_residue)

        if position < k:
            kept.append([position, dist, training_residue])
            continue

        if worst is None:
            worst = farthest_slot()

        if dist < kept[worst][1]:
            label = kept.pop(worst)[0]
            kept.append([label, dist, training_residue])
            worst = farthest_slot()

    if not kept:
        return pd.DataFrame(columns=columns)

    nearest = pd.DataFrame([entry[2] for entry in kept], columns=columns)
    nearest.index = [entry[0] for entry in kept]

    return nearest

def compute_farthest(nearest, test_residue):
    """Locate the row of ``nearest`` that lies farthest from ``test_residue``.

    Args:
        nearest: DataFrame of candidate neighbours; the first three columns,
            by position, form the record handed to ``distance``.
        test_residue: Record ``[identity, feature_1, feature_2]`` to measure
            against.

    Returns:
        Tuple ``(index_label, distance)`` of the farthest row. When several
        rows are equally far, the first one is reported.

    Raises:
        ValueError: If ``nearest`` has no row with a comparable distance
            (it is empty, or every distance is NaN).
    """
    found = False
    farthest_index = None
    # Start below every possible distance so that a row at distance 0 (an
    # exact match) is still selected.
    farthest_dist = float('-inf')

    # Positional tuples avoid integer lookups on a label-indexed row Series.
    for index, values in zip(nearest.index, nearest.itertuples(index=False, name=None)):

        dist = distance(values, test_residue)

        if dist > farthest_dist:
            found = True
            farthest_index = index
            farthest_dist = dist

    if not found:
        raise ValueError(
            "compute_farthest: 'nearest' contains no row with a comparable distance"
        )

    return (farthest_index, farthest_dist)

In [None]:
# This creates a sample training dataset, from a section of the PDB: every CSV found under PATH
# is read, tagged with the chain it came from, and stacked into one DataFrame.
# NOTE(review): PATH is an absolute path on one particular machine; point it at the local copy
# of the input data before running elsewhere.

PATH = "/home/ec2-user/SageMaker/ProteinFragmenter/inputdata_20190824"
EXT = "*.csv"
TRAINING_COLUMNS = ['residue', 'sin_phi', 'sin_psi', 'is_cap', 'pdb_id', 'res_num']
CHAIN_FILE_SUFFIX = "_dataframe.csv"


def chain_id_from_path(csv_path):
    """Derive the chain identifier from a per-chain CSV path.

    ``<id>_dataframe.csv`` yields ``<id>``. Any other file name yields the
    name without its extension, so an id is always derived from the file
    itself and never carried over from a previously processed file.
    """
    file_name = os.path.basename(csv_path)
    if file_name.endswith(CHAIN_FILE_SUFFIX):
        return file_name[:-len(CHAIN_FILE_SUFFIX)]
    return os.path.splitext(file_name)[0]


def load_training_data(csv_paths):
    """Read the given per-chain CSV files into one training DataFrame.

    Each file's rows receive a ``pdb_id`` column (see ``chain_id_from_path``)
    and a ``res_num`` column holding the row's index within its own file.

    Args:
        csv_paths: Iterable of CSV file paths.

    Returns:
        DataFrame whose leading columns are TRAINING_COLUMNS, followed by any
        additional columns present in the files, with a fresh 0..n-1 index.
        Empty (but with TRAINING_COLUMNS) when no paths are given.
    """
    frames = []
    for csv_path in csv_paths:
        chain_frame = pd.read_csv(csv_path)
        chain_frame['pdb_id'] = chain_id_from_path(csv_path)
        chain_frame['res_num'] = chain_frame.index
        frames.append(chain_frame)

    if not frames:
        return pd.DataFrame(columns=TRAINING_COLUMNS)

    # One concat at the end instead of growing the frame file by file:
    # DataFrame.append no longer exists in pandas 2.x and was quadratic anyway.
    combined = pd.concat(frames, ignore_index=True, sort=False)
    extra_columns = [column for column in combined.columns if column not in TRAINING_COLUMNS]
    return combined.reindex(columns=TRAINING_COLUMNS + extra_columns)


# Sorted so that the row order of the training set (and therefore tie-breaking
# between equally distant neighbours) does not depend on the filesystem.
all_proteins = sorted(
    file
    for path, subdir, files in os.walk(PATH)
    for file in glob(os.path.join(path, EXT))
)

training_data_full = load_training_data(all_proteins)

In [None]:
# This is a test implementation of the nearest neighbor function, with the protein 1B35, chain A. For each residue
# in that protein chain, it outputs the nearest neighbor from the training dataset, along with its cap prediction.
# NOTE(review): the directory below is spelled with a double underscore ('inputdata__20190826'), unlike the
# training directory ('inputdata_20190824') - confirm that this is the intended path.

test_sequence = pd.read_csv('inputdata__20190826/1B35.A_dataframe.csv')

# Rows are read as positional tuples (columns 0-2: residue record, column 3: cap label); integer lookups on a
# label-indexed row Series are deprecated in pandas.
for index, row in zip(test_sequence.index, test_sequence.itertuples(index=False, name=None)):

    test_residue = [row[0], row[1], row[2]]
    print("Residue number: ", index+1)
    print("Residue information: ", test_residue)
    print("Cap identity: ", row[3])

    # Closest residue in the whole training dataset, with its distance, cap label and origin.
    nn = nearest_neighbor(training_data_full, test_residue)
    print("Nearest neighbor: ", nn) # The nearest neighbor for that residue of the test sequence
    print("Cap prediction: ", nn[0]) # The best cap prediction for that residue of the test sequence

In [None]:
# Final add-on: view a protein chain in 3D and highlight selected residues. For now this is only a
# test visualization with a hard-coded structure and residue.

test_query = 'pdb:1KZ1'
viewer = py3Dmol.view(query=test_query)

# Styles are applied in this order: grey cartoon without a selection, then red cartoon for residue 28.
grey_cartoon = {'cartoon': {'color': 'grey'}}
red_cartoon = {'cartoon': {'color': 'red'}}
highlighted_residue = {'resi': 28}

viewer.setStyle(grey_cartoon)
viewer.setStyle(highlighted_residue, red_cartoon)
viewer.show()