In [1]:
# Imports and classes
import xml.etree.ElementTree as ET
import os
from shutil import copyfile

class PDBTM:
    def __init__(self, id = str, chains = []): 
        self.id = id
        self.chains = chains
    
class Chain:
    def __init__(self, id = str, num_tm = int, typ = str, seq = str, regions = []): 
        self.id = id
        self.num_tm = num_tm
        self.typ = typ
        self.seq = seq
        self.regions = regions
    
class Region:
    def __init__(self, seq_beg = int, pdb_beg = int, seq_end = int, pdb_end = int, typ = str):
        self.seq_beg = seq_beg
        self.pdb_beg = pdb_beg
        self.seq_end = seq_end
        self.pdb_end = pdb_end
        self.typ = typ

In [2]:
def parse_pdbtm_xmls(paths):
    '''
    Parses PDBTM XML files to a list od PDBTM Objects.
    A PDBTM Object contains the PDB ID as well as a list of Chain objects. 
    A Chain Object contains the chain ID, the tm number, the type, the sequence and a list of Region objects.
    A Region object contains the begin and end indices of the sequence and the pdb file, as well as the type.
    '''
    ns = {'pdbtm': 'http://pdbtm.enzim.hu'}
    pdbtms = []
    for path in paths:
        pdbtm_xml = ET.parse(path) 
        pdbtm_root = pdbtm_xml.getroot()
        pdbtm_id = pdbtm_root.attrib.get('ID')
        
        if pdbtm_root.attrib.get('TMP') == 'yes':
            chains = []
            for chain_xml in pdbtm_root.findall('pdbtm:CHAIN', ns):
                chain_id = chain_xml.attrib.get('CHAINID')
                num_tm = chain_xml.attrib.get('NUM_TM')
                typ = chain_xml.attrib.get('TYPE')
                seq = chain_xml.find('pdbtm:SEQ', ns)

                regions = []
                for region_xml in chain_xml.findall('pdbtm:REGION', ns):
                    seq_beg = region_xml.attrib.get('seq_beg')
                    pdb_beg = region_xml.attrib.get('pdb_beg')
                    seq_end = region_xml.attrib.get('seq_end')
                    pdb_end = region_xml.attrib.get('pdb_end')
                    typ_region = region_xml.attrib.get('type')
                    region = Region(seq_beg, pdb_beg, seq_end, pdb_end, typ_region)
                    regions.append(region)
            chain = Chain(chain_id, num_tm, typ, seq.text.replace(" ", "").replace("\n", ""), regions)
            chains.append(chain) 
            
            pdbtm = PDBTM(pdbtm_id, chains)
            pdbtms.append(pdbtm)
        
        else:
            print(pdbtm_id, "is no TMP and was ignored.")
    
    return pdbtms

# Parse pdbtm xml files
paths = []
for file in os.listdir("pdbtm_xmls"):
    if file.endswith(".xml"):
        path = os.path.join("pdbtm_xmls", file)
        paths.append(path.strip())        
pdbtms = parse_pdbtm_xmls(paths)

In [3]:
def get_list_of_pdbtm_ids():
    pdbtms = []
    for file in os.listdir("pdbtm_xmls"):
        if file.endswith(".xml"):
            pdbid = file.replace(".xml", "")
            pdbtms.append(pdbid)
    
    return pdbtms

pdbtm_ids = get_list_of_pdbtm_ids()
print("There are",len(pdbtm_ids), "transmembrane proteins in the pdbtm database.")

There are 3284 transmembrane proteins in the pdbtm database.


In [51]:
aas = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'H', 'I', 'G', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

def isValid(helix):
    if set(helix).issubset(set(aas)):
        return True
    else:
        return False

def get_tm_helices(pdbtms):
    pdbtm_helices = []
    for pdbtm in pdbtms:
        for chain in pdbtm.chains:
            if chain.typ == "alpha":
                for region in chain.regions:
                    if region.typ == "H":
                        helix = chain.seq[int(region.seq_beg):int(region.seq_end)]
                        if (len(helix) > 0) and isValid(helix):
                            pdbtm_helices.append([pdbtm.id, chain.id, int(region.pdb_beg), int(region.pdb_end), int(region.pdb_end)-int(region.pdb_beg), helix, int(1)])
    return pdbtm_helices

tm_helices = get_tm_helices(pdbtms)
print("There are", len(tm_helices), "transmembrane helices in the transmembrane proteins of the pdbtm database.")

There are 14973 transmembrane helices in the transmembrane proteins of the pdbtm database.


In [52]:
import pandas as pd
import numpy as np

# Create dataframe, remove duplicates and save it as CSV
columns = ["PDB ID", "Chain", "Helix Start", "Helix End", "Helix Length", "Helix Sequence", "Is Transmembrane"]
indices = range(1,len(tm_helices)+1)
tm_helices = np.array(tm_helices)
df_tm = pd.DataFrame(tm_helices,columns=columns)
df_tm = df_tm.drop_duplicates()
df_tm.to_csv("tm_helices.csv", sep=',', encoding='utf-8', index=False)

In [53]:
def get_nontm_helices(pdbtms):
    
    regions = []
    pdbtm_helices = []
    for pdbtm in pdbtms:
        for chain in pdbtm.chains:
            if chain.typ == "alpha":
                for region in chain.regions:
                    if region.typ == "1" or region.typ != "2":
                        helix = chain.seq[int(region.seq_beg):int(region.seq_end)]
                        if (len(helix) > 0 ) and isValid(helix):
                            pdbtm_helices.append([pdbtm.id, chain.id, region.pdb_beg, region.pdb_end, int(region.pdb_end)-int(region.pdb_beg), helix, int(0)])
    return pdbtm_helices

nontm_helices = get_nontm_helices(pdbtms)
print("There are", len(nontm_helices), "non-transmembrane helices in the transmembrane proteins of the pdbtm database.")

There are 30671 non-transmembrane helices in the transmembrane proteins of the pdbtm database.


In [54]:
# Create dataframe, remove duplicates and save it as CSV
columns = ["PDB ID", "Chain", "Helix Start", "Helix End", "Helix Length", "Helix Sequence", "Is Transmembrane"]
indices = range(1,len(nontm_helices)+1)
nontm_helices = np.array(nontm_helices)
df_nontm = pd.DataFrame(nontm_helices,columns=columns)
df_nontm = df_nontm.drop_duplicates()
df_nontm.to_csv("nontm_helices.csv", sep=',', encoding='utf-8', index=False)

In [58]:
# Combine dataframes
tm_helices = pd.read_csv("tm_helices.csv")
samples=len(tm_helices.axes[0])
print("TM helices:", samples)

nontm_helices = pd.read_csv("nontm_helices.csv")
samples=len(nontm_helices.axes[0])
print("NON-TM helices:", samples)  

frames = [tm_helices, nontm_helices]
training_data = pd.concat(frames)
training_data = training_data.drop_duplicates()
training_data = training_data.dropna()
training_data.sort_values(["PDB ID", "Chain", "Is Transmembrane"], ascending=[True, True, True], inplace=True)
training_data.to_csv("training_data.csv", sep=',', encoding='utf-8', index=False)

TM helices: 14973
NON-TM helices: 30671
