# Predict Helix Capping Residues #

The goal is to identify the residues just before an alpha helix begins or the residues just after the helix ends. This will improve secondary structure predictors because their predicted helices often extend too far or do not start at the right place.

The CapsDB has annotated sequences of structures of helix capping residues that can be used to train a deep neural net. We will use a Bidirectional LSTM with phi/psi features to see whether those are good predictors.

## 1. Download data ##

## 2. Generate Features ##
### MMTF Pyspark Imports ###

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.ml import ProteinSequenceEncoder
import numpy as np
import pandas as pd
import math
import os

### Custom imports ###

In [2]:
import secondaryStructureExtractorFull
#import mmtfToASA

### Configure Spark Context ###

In [3]:
# Local Spark session with 8 worker threads, registered under the app name "DeepCap".
spark = (
    SparkSession.builder
    .master("local[8]")
    .appName("DeepCap")
    .getOrCreate()
)

### Create SQLContext ###

In [4]:
from pyspark.sql import SQLContext
from pyspark.sql.functions import concat, col, lit, array_contains

# Wrap the Spark session in a SQLContext; the parquet reads below go through
# sqlContext.read.
# NOTE(review): SQLContext is deprecated in Spark 3.x in favour of using the
# SparkSession directly (spark.read) -- confirm the Spark version before changing.
sqlContext = SQLContext(spark)

### Read in filtered cap+MMTF data from parquet file ###

In [6]:
# Per-residue table produced by the earlier get_dataset step, stored as parquet.
parquetPath = '/home/ec2-user/SageMaker/ProteinFragmenter/data-parquet'
dataframe = sqlContext.read.parquet(parquetPath)

# Collect into pandas and discard the '__index_level_0__' column
# (presumably a saved pandas index -- verify against the writer).
data = dataframe.toPandas().drop(columns='__index_level_0__')

# CapsDB cap descriptors, kept as a Spark DataFrame until they are needed.
capsdb = sqlContext.read.parquet('caps_descriptors.parquet')

### Get Torsion angle and secondary structure info ###

In [7]:
# Preview the first ten residue rows (pdbId, chain, resi, resn, phi, psi).
data.head(10)

Unnamed: 0,pdbId,chain,resi,resn,phi,psi
0,2ygn,A,1,THR,,163.677383
1,2ygn,A,2,GLY,-66.660973,160.703186
2,2ygn,A,3,SER,-123.853607,-7.871733
3,2ygn,A,4,LEU,-74.896896,137.483932
4,2ygn,A,5,TYR,-134.41983,140.864288
5,2ygn,A,6,LEU,-139.275024,127.621544
6,2ygn,A,7,TRP,-152.167755,166.833832
7,2ygn,A,8,ILE,-108.079048,119.799377
8,2ygn,A,9,ASP,-61.78611,150.193756
9,2ygn,A,10,ALA,-47.469296,-38.584801


In [34]:
# Bring the cap descriptors into pandas and attach them to every residue of the
# same chain (the id column is 'pdbId' on the residue side, 'pdbid' on the cap side).
df1 = capsdb.toPandas()
df = data.merge(
    df1,
    left_on=['pdbId', 'chain'],
    right_on=['pdbid', 'chain'],
    how='inner',
)

# Keep only the residue columns plus the cap boundaries.
kept_columns = ['pdbId', 'chain', 'resi', 'resn', 'phi', 'psi', 'startcap', 'endcap']
df = df[kept_columns]


### Create labels

In [9]:
# A residue is labelled as a cap (1) when its number lies inside the inclusive
# [startcap, endcap] range of the cap row it was merged with, otherwise 0.
# The comparison is vectorized: a row-wise df.apply(..., axis=1) runs a Python
# lambda once per row of the merged residue x cap table.
df['is_cap'] = df['resi'].between(df['startcap'], df['endcap']).astype(int)

# The merged table can hold several rows per residue (one per cap row of the
# chain); a residue is a cap if any of them matched, hence max().
df_caps = df.groupby(["pdbId", "chain", "resi"])['is_cap'].max().reset_index()

In [10]:
# Attach the per-residue cap label to the residue table; the key columns have
# the same names on both sides.
residue_key = ['pdbId', 'chain', 'resi']
data_caps = data.merge(df_caps, on=residue_key, how='inner')

In [11]:
from Bio.PDB.Polypeptide import aa3

# Encode every residue name with the project helper; each result becomes one row
# whose columns are named after the entries of aa3.
# NOTE(review): assumes get_residue returns one value per aa3 entry -- confirm.
residue_vectors = [
    secondaryStructureExtractorFull.get_residue(residue_name)
    for residue_name in data_caps.resn
]
one_hot_encoded = pd.DataFrame(residue_vectors, columns=aa3)

# Index-aligned join: appends the indicator columns to the residue table.
data_caps = data_caps.join(one_hot_encoded)
data_caps.head()

Unnamed: 0,pdbId,chain,resi,resn,phi,psi,is_cap,ALA,CYS,ASP,...,MET,ASN,PRO,GLN,ARG,SER,THR,VAL,TRP,TYR
0,2ygn,A,1,THR,,163.677383,0,0,0,0,...,0,0,0,0,0,0,1,0,0,0
1,2ygn,A,2,GLY,-66.660973,160.703186,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,2ygn,A,3,SER,-123.853607,-7.871733,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
3,2ygn,A,4,LEU,-74.896896,137.483932,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2ygn,A,5,TYR,-134.41983,140.864288,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [13]:
# Show the first rows of data_caps (residue columns, is_cap, amino-acid indicators).
data_caps.head()

Unnamed: 0,pdbId,chain,resi,resn,phi,psi,is_cap,ALA,CYS,ASP,...,MET,ASN,PRO,GLN,ARG,SER,THR,VAL,TRP,TYR
0,2ygn,A,1,THR,,163.677383,0,0,0,0,...,0,0,0,0,0,0,1,0,0,0
1,2ygn,A,2,GLY,-66.660973,160.703186,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,2ygn,A,3,SER,-123.853607,-7.871733,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
3,2ygn,A,4,LEU,-74.896896,137.483932,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2ygn,A,5,TYR,-134.41983,140.864288,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


# Define functions for feature extraction

In [14]:

def is_cap(pdbId, chain, resi, is_cap):
    """Validate a binary cap flag and return it as a plain 0 or 1.

    pdbId, chain and resi identify the residue but are not used by the check.

    Raises:
        ValueError: if is_cap compares equal to neither 1 nor 0.
    """
    for allowed in (1, 0):
        if is_cap == allowed:
            return allowed
    raise ValueError("is_cap must be 0 or 1")

def angle_to_cos(angle):
    """Return the cosine of an angle given in degrees.

    An angle that is exactly 0 or NaN is treated as "no value" and yields 0
    (so an angle of 0 does not give cos(0) == 1).
    """
    no_value = angle == 0 or np.isnan(angle)
    if no_value:
        return 0
    radians = np.pi * angle/180
    return np.cos(radians)

def angle_to_sin(angle):
    """Return the sine of an angle given in degrees.

    An angle that is exactly 0 or NaN is treated as "no value" and yields 0.
    """
    no_value = angle == 0 or np.isnan(angle)
    if no_value:
        return 0
    radians = np.pi * angle/180
    return np.sin(radians)


# Process data into list of arrays

In [81]:
# One group of residue rows per (pdbId, chain) pair.
groups = data_caps.groupby(["pdbId", "chain"])
# The same (pdbId, chain) keys as a set, for quick membership tests.
groups2 = {(pdbid, chain) for (pdbid, chain), _rows in groups}
                           # num pdbs,    max len of seqs, num features

# Check max length of protein chains
# maxlen = 0
# for i, ((pdbid, chain), group) in enumerate(groups):
#     l = 0
#     for j, featuretuple in enumerate(group.itertuples()):
#         l += 1
#         if l > maxlen:
#             maxlen = l
# print(maxlen)

In [84]:
import os

if not os.path.isfile("pdb_seqres.txt"):
    !wget ftp://ftp.wwpdb.org/pub/pdb/derived_data/pdb_seqres.txt.gz & gunzip pdb_seqres.txt
      
with open("pdb_seqres.txt") as pdbseqs, open("capsdb_seqres.fasta", "w") as capsdbseqs:
    for line in pdbseqs:
        if line.startswith(">"):
            pdb, chain = line.split()[0].split("_")
            if (pdb[1:], chain) in groups2:
                
                sequence = next(pdbseqs)
                try:
                    groups.groups.keys()
                except KeyError:
                    continue
                capsdbseqs.write(line)
                capsdbseqs.write(sequence)
            
#Uncomment if docker is not installed
#!curl -fsSL https://get.docker.com -o get-docker.sh & sh get-docker.sh
    
!docker pull edraizen/usearch:latest
!docker run -v `pwd`:/data -w /data --entrypoint /opt/usearch/usearch edraizen/usearch:latest -cluster_fast capsdb_seqres.fasta -id 0.4 -centroids capsdb_centroids.fasta



latest: Pulling from edraizen/usearch
Digest: sha256:061556fcbb9e89a4421ad77db4f4f63c96745e39cac9a8e047830d091d6f4046
Status: Image is up to date for edraizen/usearch:latest
usearch v11.0.667_i86linux32, 4.0Gb RAM (62.9Gb total), 4 cores
(C) Copyright 2013-18 Robert C. Edgar, all rights reserved.
https://drive5.com/usearch

License: ed4bu@virginia.edu

00:00 43Mb    100.0% Reading capsdb_seqres.fasta
00:00 38Mb    100.0% DF                         
00:00 38Mb   6714 seqs, 6700 uniques, 6687 singletons (99.8%)
00:00 38Mb   Min size 1, median 1, max 3, avg 1.00
00:00 41Mb    100.0% DB
00:02 161Mb   100.0% 5563 clusters, max size 25, avg 1.2
00:02 161Mb   100.0% Writing centroids to capsdb_centroids.fasta
                                                                
      Seqs  6700
  Clusters  5563
  Max size  25
  Avg size  1.2
  Min size  1
Singletons  4933, 73.6% of seqs, 88.7% of clusters
   Max mem  161Mb
      Time  2.00s
Throughput  3350.0 seqs/sec.



In [85]:
# Collect [">" + pdb id, chain id] from every header line of the centroid FASTA.
clustered_pdbs = []
with open("capsdb_centroids.fasta") as capsdb_centroids:
    for fasta_line in capsdb_centroids:
        if fasta_line.startswith(">"):
            record_id = fasta_line.split()[0]
            clustered_pdbs.append(record_id.split("_"))


In [67]:
# Peek at one parsed header: element 0 still carries the leading ">", so [1:]
# yields the bare PDB id. Only meaningful while clustered_pdbs is the raw list
# read from the FASTA, i.e. before the ">" is stripped in place and the list is
# turned into a set of tuples.
clustered_pdbs[1][0][1:]

'12e8'

In [86]:

# Drop the first character (the FASTA ">") from the id of every entry, in place.
for entry in clustered_pdbs:
    entry[0] = entry[0][1:]
    

In [87]:
# Turn the [pdb id, chain] lists into hashable tuples collected in a set.
clustered_pdbs = {tuple(entry) for entry in clustered_pdbs}

In [88]:
# First line: chains of the dataset that are also cluster centroids.
# Second line: total number of (pdbId, chain) groups.
shared_chains = groups2.intersection(clustered_pdbs)
print(len(shared_chains))
print(len(groups))

5563
6714


In [90]:
# Build one feature array and two label arrays per clustered chain.
#
# Feature layout per residue (24 columns):
#   cos(phi), sin(phi), cos(psi), sin(psi), then the 20 amino-acid indicator
#   columns in the order listed in aa_columns.
aa_columns = ("ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU",
              "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR")
n_features = 4 + len(aa_columns)

train_chains = []
label_chains = []
laglabel_chains = []

pdb_order = []

for (pdbid, chain), group in groups:

    # Only chains kept as cluster centroids are used.
    if (pdbid, chain) not in clustered_pdbs:
        continue

    # Size the arrays by the chain's real residue count instead of a fixed
    # padded buffer followed by "drop all-zero rows": that trim also removed
    # interior residues whose features happened to be all zero and then
    # misaligned the labels, which were cut by count only.
    n_residues = len(group)
    train_chain = np.zeros((n_residues, n_features), dtype=float)
    label_chain = np.zeros((n_residues, 1), dtype=int)
    # One extra trailing row, preserving the n_residues + 1 length that the
    # previous slice produced for the lagged labels.
    # NOTE(review): confirm the consumer really expects n_residues + 1 rows.
    laglabel_chain = np.zeros((n_residues + 1, 1), dtype=int)

    # Populate arrays
    for j, featuretuple in enumerate(group.itertuples()):
        train_chain[j, :4] = (angle_to_cos(featuretuple.phi), angle_to_sin(featuretuple.phi),
                              angle_to_cos(featuretuple.psi), angle_to_sin(featuretuple.psi))
        train_chain[j, 4:] = [getattr(featuretuple, aa) for aa in aa_columns]
        label_chain[j, 0] = is_cap(featuretuple.pdbId, featuretuple.chain, featuretuple.resi, featuretuple.is_cap)
        # Lagged labels: position j-1 carries the label of residue j.
        if j > 0:
            laglabel_chain[j - 1, 0] = label_chain[j, 0]

    # Add chain data to lists of arrays. The lagged list now stores
    # laglabel_chain itself; previously a slice of label_chain was stored and
    # the lagged values were never used.
    train_chains.append(train_chain)
    label_chains.append(label_chain)
    laglabel_chains.append(laglabel_chain)
    pdb_order.append((pdbid, chain))

In [91]:
# (pdbId, chain) of each processed chain, in the same order as the entries of
# train_chains / label_chains / laglabel_chains.
pdb_order

[('1a1x', 'A'),
 ('1a62', 'A'),
 ('1a73', 'A'),
 ('1a8l', 'A'),
 ('1a92', 'A'),
 ('1a9x', 'A'),
 ('1ae9', 'A'),
 ('1ah7', 'A'),
 ('1al3', 'A'),
 ('1aoc', 'A'),
 ('1aol', 'A'),
 ('1atg', 'A'),
 ('1atz', 'A'),
 ('1ayl', 'A'),
 ('1ayo', 'A'),
 ('1azo', 'A'),
 ('1b0n', 'B'),
 ('1b12', 'A'),
 ('1b25', 'A'),
 ('1b5e', 'A'),
 ('1b6a', 'A'),
 ('1baz', 'A'),
 ('1bea', 'A'),
 ('1bf2', 'A'),
 ('1bgc', 'A'),
 ('1bgf', 'A'),
 ('1bkr', 'A'),
 ('1bm8', 'A'),
 ('1brt', 'A'),
 ('1btk', 'A'),
 ('1bu8', 'A'),
 ('1bxy', 'A'),
 ('1byf', 'A'),
 ('1byi', 'A'),
 ('1c1k', 'A'),
 ('1c30', 'B'),
 ('1c3c', 'A'),
 ('1c4q', 'A'),
 ('1c7k', 'A'),
 ('1c7s', 'A'),
 ('1c96', 'A'),
 ('1cb8', 'A'),
 ('1cc8', 'A'),
 ('1ccw', 'A'),
 ('1ccw', 'B'),
 ('1cew', 'I'),
 ('1chm', 'A'),
 ('1ci4', 'A'),
 ('1cl1', 'A'),
 ('1cl8', 'A'),
 ('1clv', 'I'),
 ('1cmc', 'A'),
 ('1coz', 'A'),
 ('1cq3', 'A'),
 ('1ctf', 'A'),
 ('1cv8', 'A'),
 ('1cvr', 'A'),
 ('1cxq', 'A'),
 ('1cy5', 'A'),
 ('1czy', 'A'),
 ('1d02', 'A'),
 ('1d0d', 'A'),
 ('1d0q'

# Write training data to pickle file

In [92]:
import os
import pickle

# Make sure the output directory exists before writing into it.
os.makedirs("pickled_data", exist_ok=True)

# Dump each list to its own pickle file; `with` closes the file even if
# pickle.dump raises.
pickle_targets = (
    ("pickled_data/train_chains.pickle", train_chains),
    ("pickled_data/label_chains.pickle", label_chains),
    ("pickled_data/laglabel_chains.pickle", laglabel_chains),
    ("pickled_data/pdb_order.pickle", pdb_order),
)
for pickle_path, payload in pickle_targets:
    with open(pickle_path, "wb") as pickle_out:
        pickle.dump(payload, pickle_out)

In [8]:
# Shut down the Spark session; none of the cells visible below this one use
# spark or sqlContext.
spark.stop()

# The code below reads in 1-dim (binary) labels and writes back out as 2-dim labels (one-hot)

In [60]:
import pickle

# Load the per-chain binary labels; the file is closed before it is rewritten.
with open("pickled_data/label_chains.pickle", "rb") as label_chain_in:
    labels = pickle.load(label_chain_in)

newlabels = []
for l in labels:
    if l.shape[1] == 2:
        # Already two-column: this cell overwrites its own input file, so a
        # second run would otherwise read column 0 ("not cap") as the cap flag
        # and silently invert every label.
        newlabels.append(l)
        continue
    temp = np.zeros([l.shape[0], 2], dtype=int)
    temp[:,1] = l[:,0]          # column 1: residue is a cap
    temp[:,0] = (l[:,0]+1)%2    # column 0: residue is not a cap
    newlabels.append(temp)

with open("pickled_data/label_chains.pickle", "wb") as pickle_out:
    pickle.dump(newlabels, pickle_out)

In [61]:
# Load the per-chain lagged labels; the file is closed before it is rewritten.
with open("pickled_data/laglabel_chains.pickle", "rb") as laglabel_chain_in:
    labelslag = pickle.load(laglabel_chain_in)

newlabelslag = []
for l in labelslag:
    if l.shape[1] == 2:
        # Already two-column: this cell overwrites its own input file, so a
        # second run would otherwise read column 0 ("not cap") as the cap flag
        # and silently invert every label.
        newlabelslag.append(l)
        continue
    temp = np.zeros([l.shape[0], 2], dtype=int)
    temp[:,1] = l[:,0]          # column 1: flag is set
    temp[:,0] = (l[:,0]+1)%2    # column 0: flag is not set
    newlabelslag.append(temp)

with open("pickled_data/laglabel_chains.pickle", "wb") as pickle_out:
    pickle.dump(newlabelslag, pickle_out)

# The code below reads in train/label and writes out lists sorted by chain length

In [31]:
import pickle

# Load the feature arrays and derive the permutation that orders chains by
# increasing length.
with open("pickled_data/train_chains.pickle", "rb") as train_chain_in:
    train = pickle.load(train_chain_in)

lens = [len(chain) for chain in train]
inds = range(len(train))
lenSeries = pd.Series(data=lens, index=inds).sort_values()
newInds = lenSeries.index.values
newlist = [train[i] for i in newInds]

with open("pickled_data/train_chains_sorted.pickle", "wb") as pickle_out:
    pickle.dump(newlist, pickle_out)

# now sort label list with the same permutation so it stays aligned with train
with open("pickled_data/label_chains.pickle", "rb") as label_chain_in:
    labels = pickle.load(label_chain_in)

newlist2 = [labels[i] for i in newInds]

with open("pickled_data/label_chains_sorted.pickle", "wb") as pickle_out:
    pickle.dump(newlist2, pickle_out)

# now sort laglabel list with the same permutation
with open("pickled_data/laglabel_chains.pickle", "rb") as laglabel_chain_in:
    labelslag = pickle.load(laglabel_chain_in)

newlist3 = [labelslag[i] for i in newInds]

with open("pickled_data/laglabel_chains_sorted.pickle", "wb") as pickle_out:
    pickle.dump(newlist3, pickle_out)

# NOTE(review): pdb_order.pickle is not reordered here; apply newInds to it when
# the sorted chains need to be mapped back to their (pdbId, chain) ids.