# Data Aquisition
In this file, I download and format data to be provided to machine learning models. There are several steps:


## 1. Aquire a list of PDB IDs to download.

In [121]:
import datetime
import requests
import tqdm
import prody as pr
pr.confProDy(verbosity='error')
import numpy as np
import re
import pickle
from joblib import Parallel, delayed
from multiprocessing import Pool
import sys
    
AA_MAP = {'A': 15,'C': 0,'D': 1,'E': 17,'F': 8,'G': 10,'H': 11,'I': 5,'K': 4,'L': 12,'M': 19,'N': 9,'P': 6,'Q': 3,'R': 13,'S': 2,'T': 7,'V': 16,'W': 14,'Y': 18}
CUR_DIR = "/home/jok120/projML/data/"
pr.pathPDBFolder(CUR_DIR + "pdbgz/")
np.set_printoptions(suppress=True) # suppresses scientific notation when printing
np.set_printoptions(threshold=np.nan) # suppresses '...' when printing


today = datetime.datetime.today()
day = today.day
month = today.month
suffix = "_{0:02d}{1:02d}".format(month, day)
suffix

'_1208'

In [122]:
url = 'http://www.rcsb.org/pdb/rest/search'

# Helix Only Dataset
alpha_helix_query = """<orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.EntriesOfEntitiesQuery</queryType>
    <description>Entries of :Secondary structure has:  1 or more Alpha Helices and between 85 and 100 percent of elements are Alpha Helical  and 0 or less Beta Sheets and 0 or less percent of elements are Beta Sheet 
and
Oligomeric state Search : Min Number of oligomeric state=1 Max Number of oligomeric state=1
and
Sequence Length is between 9 and 60 
</description>
    <queryId>72D8FBC1</queryId>
    <resultCount>86</resultCount>
    <runtimeStart>2018-11-29T13:52:29Z</runtimeStart>
    <runtimeMilliseconds>2</runtimeMilliseconds>
    <parent><![CDATA[<orgPdbCompositeQuery version="1.0">
    <resultCount>1890</resultCount>
    <queryId>1A24C4C2</queryId>
 <queryRefinement>
  <queryRefinementLevel>0</queryRefinementLevel>
  <orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.SecondaryStructureQuery</queryType>
    <description>Secondary structure has:  1 or more Alpha Helices and between 85 and 100 percent of elements are Alpha Helical  and 0 or less Beta Sheets and 0 or less percent of elements are Beta Sheet </description>
    <queryId>F9D5DD03</queryId>
    <resultCount>1890</resultCount>
    <runtimeStart>2018-11-29T13:35:16Z</runtimeStart>
    <runtimeMilliseconds>637</runtimeMilliseconds>
    <polyStats.helixPercent.comparator>between</polyStats.helixPercent.comparator>
    <polyStats.helixCount.comparator>between</polyStats.helixCount.comparator>
    <polyStats.sheetPercent.comparator>between</polyStats.sheetPercent.comparator>
    <polyStats.sheetCount.comparator>between</polyStats.sheetCount.comparator>
    <polyStats.helixPercent.min>85</polyStats.helixPercent.min>
    <polyStats.helixPercent.max>100</polyStats.helixPercent.max>
    <polyStats.helixCount.min>1</polyStats.helixCount.min>
    <polyStats.sheetPercent.max>0</polyStats.sheetPercent.max>
    <polyStats.sheetCount.max>0</polyStats.sheetCount.max>
  </orgPdbQuery>
 </queryRefinement>
 <queryRefinement>
  <queryRefinementLevel>1</queryRefinementLevel>
  <conjunctionType>and</conjunctionType>
  <orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.BiolUnitQuery</queryType>
    <description>Oligomeric state Search : Min Number of oligomeric state=1 Max Number of oligomeric state=1</description>
    <queryId>1BB8A37D</queryId>
    <resultCount>59551</resultCount>
    <runtimeStart>2018-11-29T13:35:17Z</runtimeStart>
    <runtimeMilliseconds>1060</runtimeMilliseconds>
    <oligomeric_statemin>1</oligomeric_statemin>
    <oligomeric_statemax>1</oligomeric_statemax>
  </orgPdbQuery>
 </queryRefinement>
 <queryRefinement>
  <queryRefinementLevel>2</queryRefinementLevel>
  <conjunctionType>and</conjunctionType>
  <orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.SequenceLengthQuery</queryType>
    <description>Sequence Length is between 9 and 60 </description>
    <queryId>95F56C17</queryId>
    <resultCount>31099</resultCount>
    <runtimeStart>2018-11-29T13:35:18Z</runtimeStart>
    <runtimeMilliseconds>2545</runtimeMilliseconds>
    <v_sequence.chainLength.min>9</v_sequence.chainLength.min>
    <v_sequence.chainLength.max>60</v_sequence.chainLength.max>
  </orgPdbQuery>
 </queryRefinement>
</orgPdbCompositeQuery>]]></parent>
  </orgPdbQuery>"""

# Retrieves all PDB IDs that have resolution < 3.0 Angstroms.
# Must contain 1 protein.
query = """<orgPdbCompositeQuery version="1.0">
    <resultCount>118342</resultCount>
    <queryId>3DE6E672</queryId>
 <queryRefinement>
  <queryRefinementLevel>0</queryRefinementLevel>
  <orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.ResolutionQuery</queryType>
    <description>Resolution is 3.0 or less</description>
    <queryId>C285C563</queryId>
    <resultCount>118342</resultCount>
    <runtimeStart>2018-05-30T17:49:17Z</runtimeStart>
    <runtimeMilliseconds>1631</runtimeMilliseconds>
    <refine.ls_d_res_high.comparator>between</refine.ls_d_res_high.comparator>
    <refine.ls_d_res_high.max>3.0</refine.ls_d_res_high.max>
  </orgPdbQuery>
 </queryRefinement>
 <queryRefinement>
  <queryRefinementLevel>1</queryRefinementLevel>
  <conjunctionType>and</conjunctionType>
  <orgPdbQuery>
    <version>head</version>
    <queryType>org.pdb.query.simple.ChainTypeQuery</queryType>
    <description>Chain Type: there is a Protein chain</description>
    <queryId>6631AA3E</queryId>
    <resultCount>137581</resultCount>
    <runtimeStart>2018-05-30T17:49:19Z</runtimeStart>
    <runtimeMilliseconds>1502</runtimeMilliseconds>
    <containsProtein>Y</containsProtein>
    <containsDna>?</containsDna>
    <containsRna>?</containsRna>
    <containsHybrid>?</containsHybrid>
  </orgPdbQuery>
 </queryRefinement>
</orgPdbCompositeQuery>"""

header = {'Content-Type': 'application/x-www-form-urlencoded'}
response = requests.post(url, data=alpha_helix_query, headers=header)
if response.status_code != 200:
    print ("Failed to retrieve results.")
    
PDB_IDS = response.text.split("\n")    
print ("Retrieved {0} PDB IDs.".format(len(PDB_IDS)))

Retrieved 87 PDB IDs.


## 2. Set amino acid encoding and angle downloading methods.

In [123]:
def seq_to_onehot(seq):
    """ Given an AA sequence, returns a vector of one-hot vectors."""
    vector_array = []
    for aa in seq:
        one_hot = np.zeros(len(AA_MAP), dtype=bool)
        one_hot[AA_MAP[aa]] = 1
        vector_array.append(one_hot)
    return np.asarray(vector_array)

In [124]:
def get_bond_angles(res, next_res):
    """ Given 2 residues, returns the ncac, cacn, and cnca bond angles between them."""
    atoms = res.backbone.copy()
    atoms_next = next_res.backbone.copy()
    ncac = pr.calcAngle(atoms[0], atoms[1], atoms[2], radian=True)
    cacn = pr.calcAngle(atoms[1], atoms[2], atoms_next[0], radian=True)
    cnca = pr.calcAngle(atoms[2], atoms_next[0], atoms_next[1], radian=True)
    return ncac, cacn, cnca

In [125]:
def get_angles_from_chain(chain, pdb_id):
    """ Given a ProDy Chain object (from a Hierarchical View), return a numpy array of 
        angles. Returns None if the PDB should be ignored due to weird artifacts. Also measures
        the bond angles along the peptide backbone, since they account for significat variation.
        i.e. [[phi, psi, omega, ncac, cacn, cnca, chi1, chi2, chi3, chi4, chi5], [...] ...] """
    PAD_CHAR = 0
    OUT_OF_BOUNDS_CHAR = 0
    dihedrals = []
    sequence = ""
    
    try:
        if chain.nonstdaa:
            print("Non-standard AAs found.")
            return None
        sequence = chain.getSequence()
        length = len(sequence)
        chain = chain.select("protein and not hetero").copy()
    except Exception as e:
        print("Problem loading sequence.", e)
        return None

    all_residues = list(chain.iterResidues())
    prev = all_residues[0].getResnum()
    for i, res in enumerate(all_residues):   
        if (not res.isstdaa):
            print("Found a non-std AA. Why didn't you catch this?", chain)
            print(res.getNames())
            return None
        if res.getResnum() != prev:
            print('\rNon-continuous!!', pdb_id, end="")
            return None
        else:
            prev = res.getResnum() + 1
        try:
            phi = pr.calcPhi(res, radian=True, dist=None)
        except:
            phi = OUT_OF_BOUNDS_CHAR
        try:
            psi = pr.calcPsi(res, radian=True, dist=None)
        except:
            psi = OUT_OF_BOUNDS_CHAR
        try:
            omega = pr.calcOmega(res, radian=True, dist=None)
        except:
            omega = OUT_OF_BOUNDS_CHAR
#         if phi == 0 and psi == 0 and omega == 0:
#             return None
            
        if i == len(all_residues) - 1:
            BONDANGLES = [0, 0, 0]
        else:
            try:
                BONDANGLES = list(get_bond_angles(res, all_residues[i+1]))
            except Exception as e:
                print("Bond angle issue with", pdb_id, e)
                return None

        BACKBONE = [phi,psi,omega]
                  
        def compute_single_dihedral(atoms):
            return pr.calcDihedral(atoms[0],atoms[1],atoms[2],atoms[3],radian=True)
        
        def compute_all_res_dihedrals(atom_names):
            atoms = [res.select("name " + an) for an in atom_names]
            if None in atoms:
                return None
            res_dihedrals = []
            if len(atom_names) > 0:
                for i in range(len(atoms)-3):      
                    a = atoms[i:i+4]
                    res_dihedrals.append(compute_single_dihedral(a))
            return BACKBONE + BONDANGLES + res_dihedrals + (5 - len(res_dihedrals))*[PAD_CHAR]

        if res.getResname()=="ARG":
            atom_names = ["CA","C","CB","CG","CD","NE","CZ","NH1"]             
        elif res.getResname()=="HIS":
            atom_names = ["CA","C","CB","CG","ND1"]            
        elif res.getResname()=="LYS":
            atom_names = ["CA","C","CB","CG","CD","CE","NZ"]                   
        elif res.getResname()=="ASP":
            atom_names = ["CA","C","CB","CG","OD1"]            
        elif res.getResname()=="GLU":
            atom_names = ["CA","C","CB","CG","CD","OE1"]            
        elif res.getResname()=="SER":
            atom_names = ["CA","C","CB", "OG"]       
        elif res.getResname()=="THR":
            atom_names = ["CA","C","CB","CG2"]                    
        elif res.getResname()=="ASN":
            atom_names = ["CA","C","CB","CG","ND2"]                    
        elif res.getResname()=="GLN":
            atom_names = ["CA","C","CB","CG","CD","NE2"]                    
        elif res.getResname()=="CYS":
            atom_names = ["CA","C","CB","SG"]         
        elif res.getResname()=="GLY":
            atom_names = []                    
        elif res.getResname()=="PRO":
            atom_names = []                    
        elif res.getResname()=="ALA":
            atom_names = []            
        elif res.getResname()=="VAL":
            atom_names = ["CA","C","CB","CG1"]        
        elif res.getResname()=="ILE":
            atom_names = ["CA","C","CB","CG1","CD1"]        
        elif res.getResname()=="LEU":
            atom_names = ["CA","C","CB","CG","CD1"]        
        elif res.getResname()=="MET":
            atom_names = ["CA","C","CB","CG","SD","CE"]                    
        elif res.getResname()=="PHE":
            atom_names = ["CA","C","CB","CG", "CD1"]         
        elif res.getResname()=="TRP":
            atom_names = ["CA","C","CB","CG","CD1"]                          
        elif res.getResname()=="TYR":
            atom_names = ["CA","C","CB","CG","CD1"]
        else:
            continue
            
        calculated_dihedrals = compute_all_res_dihedrals(atom_names)
        if calculated_dihedrals == None:
            return None
        dihedrals.append(calculated_dihedrals)

    # No normalization
    dihedrals_np = np.asarray(dihedrals)
    # Check for NaNs - they shouldn't be here, but certainly should be excluded if they are.
    if np.any(np.isnan(dihedrals_np)):
        print("NaNs found")
        return None
    return dihedrals_np, sequence

## 3a. Iterate through all chains in `PDB_IDS`, saving all results to disk.

### Remove empty-string PDB ids.

In [126]:
PDB_IDS = list(filter(lambda x: x != "", PDB_IDS))
len(PDB_IDS)

86

## 3b. Parallelized method of downloading data (not yet implemented).

In [127]:
%time
def work(pdb_id):
    pdb_dihedrals = []
    pdb_sequences = []
    ids = []
    try:
        pdb_hv = pr.parsePDB(pdb_id).getHierView()
        for chain in pdb_hv:
            chain_id = chain.getChid()
            dihedrals_sequence = get_angles_from_chain(chain, pdb_id)
            if dihedrals_sequence is None:
                continue 
            dihedrals, sequence = dihedrals_sequence
            pdb_dihedrals.append(dihedrals)
            pdb_sequences.append(sequence)
            ids.append(pdb_id + "_" + chain_id)
    except Exception as e:
        print("Whoops, returning where I am.", e)
    if len(pdb_dihedrals) == 0:
        return None
    else:
        return pdb_dihedrals, pdb_sequences, ids


def _foo(i):
    return work(PDB_IDS[i])


with Pool(16) as p:
    results = list(tqdm.tqdm(p.imap(_foo, range(len(PDB_IDS))), total=len(PDB_IDS)))

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 7.15 µs


  0%|          | 0/86 [00:00<?, ?it/s]

Non-standard AAs found.


100%|██████████| 86/86 [00:03<00:00, 25.01it/s]


## 4. Save Python lists of data to disk. 

In [128]:
with open("raw_aquired_helices_only" + suffix + ".pkl", "wb") as F:
    pickle.dump(results, F)
len(results)

86

In [157]:
with open("raw_aquired_py3" + suffix + ".pkl", "wb") as F:
    pickle.dump(results, F)
len(results)

86

In [189]:
with open("raw_aquired" + "_1207" + ".pkl", "rb") as F:
    results = pickle.load(F)
len(results)

121481

In [190]:
MAX_LEN = 500
results_onehots = []
c = 0
for r in results:
    if not r:
        # PDB failed to download
        continue
    ang, seq, i = r
    if len(seq[0]) > MAX_LEN:
        continue
    for j in range(len(ang)):
        results_onehots.append((ang[j], seq_to_onehot(seq[j]), i[j]))
        c += 1
c

129364

In [160]:
results_onehots[0]

(array([[ 0.        ,  2.75926986, -3.12618806,  1.92773927,  2.03690809,
          2.12002397,  2.11688401,  2.72434335, -1.25779762,  0.        ,
          0.        ],
        [-2.52922554,  2.39936229,  3.12111854,  1.92163996,  2.03487097,
          2.11627554, -2.28186913,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-1.65842877,  2.61421349,  3.09431138,  1.9264074 ,  2.02928154,
          2.10008816, -0.74494379,  1.74459812,  0.        ,  0.        ,
          0.        ],
        [-1.44308421,  2.90906143,  3.13190164,  1.8854647 ,  2.04695474,
          2.11976509,  2.24119261,  0.        ,  0.        ,  0.        ,
          0.        ],
        [-1.09137911, -0.63291584,  3.1296441 ,  1.89041526,  2.03350427,
          2.11062713, -2.30934125,  1.6085956 ,  0.29658616,  0.        ,
          0.        ],
        [-1.0410247 , -0.73365845,  3.12767338,  1.94291148,  2.03118101,
          2.1212867 ,  0.        ,  0.        ,  0.        ,  0.       

In [165]:
all_ohs = []
all_angs = []
all_ids = []
for r in results_onehots:
    a, oh, i = r
    all_ohs.append(oh)
    all_angs.append(a)
    all_ids.append(i)

In [166]:
from sklearn.model_selection import train_test_split

In [167]:
ohs_ids = list(zip(all_ohs, all_ids))

In [None]:
# To only have a training set, use this cell
X_train, X_test, X_val = ohs_ids, [], ohs_ids
y_train, y_test, y_val = all_angs, [], all_angs

In [168]:
X_train, X_test, y_train, y_test = train_test_split(ohs_ids, all_angs, test_size=0.20, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.14, random_state=42)

In [169]:
list(map(len, [X_train, y_train, X_test, y_test, X_val, y_val]))

[36984, 36984, 10752, 10752, 6021, 6021]

Remove and save ids.

In [170]:
X_train_labels = [x[1] for x in X_train]
X_test_labels = [x[1] for x in X_test]
X_val_labels = [x[1] for x in X_val]

In [171]:
X_train = [x[0] for x in X_train]
X_test = [x[0] for x in X_test]
X_val = [x[0] for x in X_val]

In [173]:
data = {"train": {"seq": X_train,
                  "ang": angle_list_to_sin_cos(y_train),
                  "ids": X_train_labels},
        "valid": {"seq": X_val,
                  "ang": angle_list_to_sin_cos(y_val),
                  "ids": X_val_labels},
        "test":  {"seq": X_test,
                  "ang": angle_list_to_sin_cos(y_test),
                  "ids": X_test_labels},
       "settings": {"max_len": max(map(len, all_ohs))}}

In [174]:
data["train"]["ang"][0]

array([[ 1.        ,  0.        ,  0.99793888, -0.06417165, -0.99905134,
        -0.04354786, -0.44020316,  0.8978982 , -0.47205415,  0.88156955,
        -0.50556458,  0.86278877, -0.53191824,  0.84679572,  0.55098613,
         0.8345144 ,  1.        ,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ],
       [-0.61309058, -0.79001262, -0.78241815,  0.62275343, -0.99751088,
         0.07051276, -0.39918713,  0.91686947, -0.45700449,  0.88946439,
        -0.51101688,  0.85957068, -0.64986208, -0.76005215, -0.80901521,
        -0.58778771, -0.93122258,  0.36445098,  1.        ,  0.        ,
         1.        ,  0.        ],
       [ 0.45311041, -0.8914544 , -0.74198854,  0.67041256, -0.9993844 ,
         0.03508318, -0.31564632,  0.94887692, -0.44249117,  0.89677286,
        -0.49906102,  0.86656684,  0.83758557, -0.54630615,  0.61312919,
        -0.78998265,  1.        ,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ],
       [ 0.44107772

In [149]:
# data_file_name = "all_helices_trig_1208.pkl"
# data_file_name = "data" + "_1208" + ".pkl"

In [None]:
data

In [None]:
import pickle

In [None]:
with open("data.pkl", "rb") as datafile:
    data = pickle.load(datafile)

In [175]:
import torch

In [150]:
torch.save(data, data_file_name)

In [None]:
X_train_labels

In [172]:
def angle_list_to_sin_cos(angs, reshape=True):
    new_list = []
    new_pad_char = np.array([1, 0])
    for a in angs:   
        new_mat = np.zeros((a.shape[0], a.shape[1], 2))
        new_mat[:, :, 0] = np.cos(a)
        new_mat[:, :, 1] = np.sin(a)
#         new_mat = (new_mat != new_pad_char) * new_mat
        if reshape:
            new_list.append(new_mat.reshape(-1, 22))
        else:
            new_list.append(new_mat)
    return new_list
        

In [176]:
data_trig = {"train": {"seq": X_train,
                  "ang": angle_list_to_sin_cos(y_train),
                  "ids": X_train_labels},
        "valid": {"seq": X_val,
                  "ang": angle_list_to_sin_cos(y_val),
                  "ids": X_val_labels},
        "test":  {"seq": X_test,
                  "ang": angle_list_to_sin_cos(y_test),
                  "ids": X_test_labels},
       "settings": {"max_len": max(map(len, all_ohs))}}

In [177]:
data_trig["train"]["ang"][3]

array([[ 1.        ,  0.        , -0.97292414,  0.23112469, -0.99818488,
         0.06022408, -0.40819144,  0.91289635, -0.40841525,  0.91279625,
        -0.57571118,  0.81765313,  1.        ,  0.        ,  1.        ,
         0.        ,  1.        ,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ],
       [ 0.49209873, -0.87053939, -0.86364721,  0.50409672, -0.99765601,
        -0.06842875, -0.36280029,  0.93186692, -0.46702333,  0.88424499,
        -0.52697015,  0.84988379, -0.51807791, -0.85533343,  0.69953572,
         0.71459763,  1.        ,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ],
       [ 0.40566786, -0.91402056,  0.68824459, -0.72547872, -0.992339  ,
         0.1235448 , -0.36196718,  0.93219084, -0.43180886,  0.90196514,
        -0.53749303,  0.84326819, -0.68185243, -0.73148976,  0.02328633,
         0.99972884,  1.        ,  0.        ,  1.        ,  0.        ,
         1.        ,  0.        ],
       [ 0.61787361

In [155]:
with open("data_1208_trig.pkl", "wb") as f:
    pickle.dump(data_trig, f)

In [178]:
torch.save(data_trig, "data_1208_trig.pkl")

In [185]:
t = data_trig["train"]["ang"]
print(len(t))
for x in t:
    if np.any(np.isnan(x)):
        print("Problem")

36984


In [188]:
np.mean(list(map(len, data_trig["train"]["ang"])))

213.5910393683755