In [1]:
from preprocess import get_ligand_data, read_pdb_biopython
import numpy as np
import pandas as pd
from petls import Rips, dFlag

In [2]:
# Element selections and the radius assigned to each selected element.
# Each radius list is positional: entry i belongs to entry i of the matching
# element list (add_radii pairs the two lists with zip).
# NOTE(review): the values look like van der Waals radii in angstroms -- confirm.
# The commented-out variants record the full element/radius lists; only a
# subset is active here. The same names are assigned again in the cell that
# runs the 1a99 example.
#     pro_ele_rad = [1.70, 1.55, 1.52, 1.80]
pro_ele_rad = [1.70]#, 1.55, 1.52, 1.80]
#     pro_elements = ["C", "N", "O", "S"]
pro_elements = ["C"] # for simplicity for now
#     lig_elements = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I"]
lig_elements = ["C", "N", "O", "S"]#, "P", "F", "Cl", "Br", "I"]
lig_ele_rad = [1.70, 1.55, 1.52, 1.80]#, 1.80, 1.47, 1.75, 1.85, 1.98]  # last is hydrogen when needed: 1.20

In [3]:
def atom_distance(a, b):
    """Return the Euclidean distance ("EUC") between two atoms.

    Only entries 0, 1 and 2 of each argument are read (the x, y, z
    coordinates); anything stored after them is ignored.
    """
    squared_sum = sum(np.power(a[axis] - b[axis], 2.0) for axis in range(3))
    return np.sqrt(squared_sum)

def filter_heavy_atoms(df, include_list):
    """Return the rows of ``df`` whose ``element`` value appears in ``include_list``.

    The original row index labels are kept on the returned frame.
    """
    wanted = df.element.isin(include_list)
    return df.loc[wanted]


def add_radii(df, elts, radii):
    """Attach a radius column ``r`` to ``df`` by matching on ``element``.

    ``elts`` and ``radii`` are paired positionally with ``zip`` (extra entries
    in the longer list are dropped). The lookup is merged with pandas' default
    join, so rows whose element is not listed in ``elts`` do not appear in the
    result.
    """
    radius_table = pd.DataFrame(
        [(symbol, radius) for symbol, radius in zip(elts, radii)],
        columns=["element", "r"],
    )
    return pd.merge(df, radius_table, on="element")

In [4]:
def make_edge(pro,lig,p_index,l_index,distance):
    """Build the directed edge(s) between a protein carbon and a ligand atom.

    ``pro[4]`` and ``lig[4]`` are read as the element symbols. Each edge is a
    ``[source, sink, distance]`` list:

    * ligand H or S: one edge from the ligand atom to the protein atom
    * ligand C:      both directions
    * ligand N or O: one edge from the protein atom to the ligand atom

    Returns ``None`` when the protein atom is not carbon and ``[]`` for any
    other ligand element.
    """
    pro_element = pro[4]
    lig_element = lig[4]
    if pro_element != 'C':
        return None

    ligand_to_protein = [l_index, p_index, distance]
    protein_to_ligand = [p_index, l_index, distance]

    if lig_element in ("H", "S"):
        return [ligand_to_protein]
    if lig_element == "C":
        return [ligand_to_protein, protein_to_ligand]  # both edges
    if lig_element in ("N", "O"):
        return [protein_to_ligand]
    return []

def compare_element(element_1,element_2):
    """Three-way comparison of two element symbols in the order H, S, C, N, O.

    Returns -1 when ``element_1`` comes first, 1 when it comes last and 0 when
    both symbols are the same. Symbols outside the five listed raise KeyError.
    """
    # H -> S -> C -> N -> O
    rank = {symbol: position for position, symbol in enumerate("HSCNO")}
    first = rank[element_1]
    second = rank[element_2]
    return (first > second) - (first < second)

def make_edge_ligand(atom_a,atom_b,a_index,b_index,distance):
    """Build the directed edge(s) between two ligand atoms.

    ``atom_a[4]`` and ``atom_b[4]`` are read as element symbols and ranked with
    ``compare_element``. The edge, a ``[source, sink, distance]`` list, runs
    from the lower-ranked element to the higher-ranked one; atoms of equal
    rank get an edge in each direction.
    """
    ordering = compare_element(atom_a[4], atom_b[4])
    forward = [a_index, b_index, distance]
    backward = [b_index, a_index, distance]

    if ordering == -1:
        return [forward]
    if ordering == 1:
        return [backward]
    return [forward, backward]
    

In [5]:
def reindex_edges(edges):
    """Relabel edge endpoints with consecutive integers starting at 0.

    Vertices and edges are originally numbered with respect to the whole
    molecules, not the cutoff-filtered ones. Each distinct endpoint gets the
    next unused integer the first time it is seen (source before sink, edges
    in input order).

    Returns ``[new_edges, vertex_count]`` where every new edge is
    ``[new_source, new_sink, edge[2]]``.
    """
    relabel = {}
    compact_edges = []
    for edge in edges:
        # setdefault only assigns a fresh label when the vertex is new.
        new_source = relabel.setdefault(edge[0], len(relabel))
        new_sink = relabel.setdefault(edge[1], len(relabel))
        compact_edges.append([new_source, new_sink, edge[2]])
    return [compact_edges, len(relabel)]
            

In [6]:
def output_graph(edges,vertex_count,cutoff,no_cutoff,pdbid):
    """Write a weighted directed graph to a ``.flag`` text file.

    The file is created relative to the current working directory and named
    ``{pdbid}_{cutoff}_1sigfig.flag`` (``no_cutoff`` replaces the cutoff value
    in the name when ``no_cutoff`` is true).

    Layout: a ``dim 0`` line, one ``"0 "`` token per vertex, a ``dim 1`` line,
    then one ``source sink weight`` line per edge with the weight written with
    one digit after the decimal point. No trailing newline is written.
    """
    if no_cutoff:
        cutoff_str = "no_cutoff"
    else:
        cutoff_str = f"{cutoff}"

    with open(f'{pdbid}_{cutoff_str}_1sigfig.flag', 'w') as f:
        f.write("dim 0\n" + "0 " * vertex_count + "\ndim 1")
        f.writelines(
            "\n%s %s %.1f" % (edge[0], edge[1], edge[2]) for edge in edges
        )
            

In [7]:
def get_ligand_data_mol2(pdbid, directory):
    """Load the atom table from ``{directory}/{pdbid}/{pdbid}_ligand.mol2``.

    Returns a DataFrame with columns ``x``, ``y``, ``z`` and ``element``, one
    row per atom record. The coordinates are left as the text read from the
    file; ``element`` is the part of the atom-type field before the first
    ``"."``.

    Raises ``Exception`` when the number of atom records differs from the
    first count on the line two below ``@<TRIPOS>MOLECULE``.
    """
    mol2_path = f"{directory}/{pdbid}/{pdbid}_ligand.mol2"
    with open(mol2_path) as handle:
        lines = handle.readlines()

    # The counts line sits two lines below the MOLECULE header.
    counts = lines[lines.index("@<TRIPOS>MOLECULE\n") + 2].split()

    # Atom records are everything between the ATOM and BOND headers.
    first_atom = lines.index("@<TRIPOS>ATOM\n") + 1
    bond_header = lines.index("@<TRIPOS>BOND\n")
    atom_rows = [record.split() for record in lines[first_atom:bond_header]]
    if int(counts[0]) != len(atom_rows):
        raise Exception("incorrect number of ligand atoms")

    # Fields 2-4 are the coordinates and field 5 is the atom type.
    atoms = pd.DataFrame(atom_rows).iloc[:, [2, 3, 4, 5]]
    atoms.columns = ["x", "y", "z", "element"]
    atoms["element"] = atoms["element"].str.split(".").str[0]
    return atoms

In [8]:
def get_ligand_bonds_mol2(pdbid, directory):
    """Load the bond table from ``{directory}/{pdbid}/{pdbid}_ligand.mol2``.

    Returns a 2-D integer NumPy array with one row per bond record, holding
    the record's second and third whitespace-separated fields (the two atom
    ids of the bond in the Tripos mol2 layout). A ligand that declares zero
    bonds yields an empty ``(0, 2)`` array.

    The bond table runs from the line after ``@<TRIPOS>BOND`` to the next
    ``@<TRIPOS>`` section header, or to the end of the file when BOND is the
    last section. It does not have to be followed by
    ``@<TRIPOS>SUBSTRUCTURE``. Blank lines inside the table are ignored.

    Raises ``Exception`` when the number of bond records differs from the
    second count on the line two below ``@<TRIPOS>MOLECULE``, and
    ``ValueError`` when the MOLECULE or BOND header is missing.
    """
    # read in the entire file
    with open(f"{directory}/{pdbid}/{pdbid}_ligand.mol2") as ligand_file:
        lines = ligand_file.readlines()

    # get meta information (the counts line is two below the MOLECULE header)
    index_meta = lines.index("@<TRIPOS>MOLECULE\n") + 2
    meta = lines[index_meta].split()

    # locate the bond table: it ends at the next section header or at EOF
    index_bonds_start = lines.index("@<TRIPOS>BOND\n") + 1
    index_bonds_end = len(lines)
    for position in range(index_bonds_start, len(lines)):
        if lines[position].startswith("@<TRIPOS>"):
            index_bonds_end = position
            break

    bond_rows = [
        line.split()
        for line in lines[index_bonds_start:index_bonds_end]
        if line.strip()  # tolerate blank padding lines
    ]
    if len(bond_rows) != int(meta[1]):
        raise Exception("incorrect number of ligand bonds")
    if not bond_rows:
        # An empty frame has no columns 1 and 2 to select, so build the
        # empty result directly.
        return np.empty((0, 2), dtype=np.int64)

    bonds = pd.DataFrame(bond_rows).iloc[:, [1, 2]].apply(pd.to_numeric)
    return bonds.to_numpy()



In [9]:
def get_ligand_data(pdbid, directory, lig_elements, lig_ele_rad, filetype):
    """Load a ligand, keep the selected elements and return it as an array.

    Note that this definition replaces the ``get_ligand_data`` imported from
    ``preprocess`` in the first cell.

    ``filetype`` must be ``"mol2"`` or ``"sdf"``; anything else raises
    ``Exception``.
    NOTE(review): ``get_ligand_data_sdf`` is not defined in the cells shown
    here, so confirm it exists before using the ``"sdf"`` branch.

    Returns ``(ligand_np, select_indices)``:

    * ``ligand_np``: array with columns x, y, z, r, element for the atoms
      whose element is in ``lig_elements``
    * ``select_indices``: positions of those atoms in the unfiltered table
    """
    if filetype not in ("mol2", "sdf"):
        raise Exception(f"invalid ligand filetype {filetype}. use 'mol2' or 'sdf'")
    # The conditional expression only looks up the reader that is needed.
    reader = get_ligand_data_mol2 if filetype == "mol2" else get_ligand_data_sdf
    all_atoms = reader(pdbid, directory)

    # Record the positions before filtering discards the other rows.
    is_selected = all_atoms.element.isin(lig_elements)
    select_indices = list(np.where(is_selected)[0])

    heavy_atoms = filter_heavy_atoms(all_atoms, lig_elements)
    with_radii = add_radii(heavy_atoms, lig_elements, lig_ele_rad)

    # convert position vector from text to numeric
    coordinate_columns = ["x", "y", "z"]
    with_radii[coordinate_columns] = with_radii[coordinate_columns].apply(pd.to_numeric)

    ordered = with_radii[["x", "y", "z", "r", "element"]]
    return ordered.to_numpy(), select_indices

In [10]:
def compute_features(pdbid,pro_elements, pro_ele_rad, directory, lig_elements, lig_ele_rad,cutoff=4,no_cutoff = False):
    """Print the Rips spectra of a ligand together with its nearby protein atoms.

    Steps:

    1. Load the protein with ``read_pdb_biopython`` and the ligand (mol2) with
       ``get_ligand_data``.
    2. Keep every protein atom that lies closer than ``cutoff`` to at least
       one ligand atom. With ``no_cutoff`` true, every protein atom is kept as
       long as the ligand has at least one atom.
    3. Drop columns 3 and 4 from both arrays, stack protein rows above ligand
       rows and hand the result to ``Rips``.
    4. Print ``spectra(dim, a, b)`` for ``dim`` in 0..2 over consecutive
       windows of width 0.25 from 0 up to ``cutoff``.

    Nothing is returned; all results go to stdout.

    NOTE(review): assumes the protein rows use the same x, y, z, r, element
    column layout as the ligand rows -- confirm against read_pdb_biopython.
    NOTE(review): ``Rips(points, 2, cutoff)`` presumably means maximum
    dimension 2 and maximum filtration value ``cutoff`` -- confirm against the
    petls documentation.
    """
    print(pdbid,flush=True)

    # read in the protein and ligand data
    protein = read_pdb_biopython(pdbid, pro_elements, pro_ele_rad)
    # The second return value (positions of the selected ligand atoms) is not used here.
    ligand, _ = get_ligand_data(pdbid, directory, lig_elements, lig_ele_rad, "mol2")

    keep_indices = []
    print("len(protein) = ", len(protein), flush=True)
    for p_index, pro_atom in enumerate(protein):
        if p_index % 400 == 0:
            print("checking protein index: ", p_index, flush=True)
        # any() stops at the first ligand atom in range instead of scanning
        # the whole ligand for every protein atom.
        if any(no_cutoff or atom_distance(pro_atom, lig_atom) < cutoff
               for lig_atom in ligand):
            keep_indices.append(p_index)

    protein = protein[keep_indices]
    # Remove columns 3 and 4 so only the first three columns are passed on.
    protein = np.delete(protein,[3,4],1)
    ligand = np.delete(ligand,[3,4],1)
    combined = np.concatenate([protein,ligand])
    print("len(combined)=",len(combined),flush=True)
    prl = Rips(combined, 2, cutoff)  #0:2853 is the range where errors occurred, takes over 1 min to do 0:200 and grows exponentially
    print("built laplacian",flush=True)

    stepsize = 0.25
    num_steps = cutoff/stepsize
    for i in range(round(num_steps)):
        for dim in range(3):
            print(prl.spectra(dim,i*stepsize,(i+1)*stepsize),flush=True)
    # spec = prl.spectra()
    print("computed spectra",flush=True)
    # print("first few spectra: ",spec[0:10])

In [11]:
# Single-structure run for 1a99. This cell repeats the loading and selection
# steps of compute_features at top level so that `combined` and `cutoff` stay
# available to the timing cells that follow.
directory = "./data"
pro_elements = ["C"]
pro_ele_rad = [1.70]
lig_elements = ["C", "N", "O", "S"]#, "P", "F", "Cl", "Br", "I"]
lig_ele_rad = [1.70, 1.55, 1.52, 1.80]#, 1.80, 1.47, 1.75, 1.85, 1.98]  # last is hydrogen when needed: 1.20
pdbid = "1a99"
cutoff = 7
no_cutoff = False
# compute_features("1a99",["C"], [1.70], directory, ["C"], [1.55],cutoff=7,no_cutoff=False)
print(pdbid,flush=True)

# read in the protein and ligand data
protein = read_pdb_biopython(pdbid, pro_elements, pro_ele_rad)
ligand, select_indices = get_ligand_data(pdbid, directory, lig_elements, lig_ele_rad, "mol2")

# Keep each protein atom that is within `cutoff` of at least one ligand atom.
keep_indices = []
print("len(protein) = ", len(protein), flush=True)
for p_index, pro_atom in enumerate(protein):
    if p_index % 400 == 0:
        print("checking protein index: ", p_index, flush=True)
    keep = False
    for l_list_index, lig_atom in enumerate(ligand):
        d = atom_distance(pro_atom,lig_atom)
        if d < cutoff or no_cutoff:
            keep = True
            break  # one ligand atom in range is enough; skip the rest
    if keep:
        keep_indices.append(p_index)

protein = protein[keep_indices]
# Drop columns 3 and 4 so only the first three columns remain, then stack
# the kept protein rows above the ligand rows.
protein = np.delete(protein,[3,4],1)
ligand = np.delete(ligand,[3,4],1)
combined = np.concatenate([protein,ligand])
print("len(combined)=",len(combined),flush=True)



1a99
len(protein) =  3446
checking protein index:  0
checking protein index:  400
checking protein index:  800
checking protein index:  1200
checking protein index:  1600
checking protein index:  2000
checking protein index:  2400
checking protein index:  2800
checking protein index:  3200
len(combined)= 83


In [12]:
import time

In [13]:
# Time the Rips construction together with every spectra query and print.
# perf_counter() is a monotonic high-resolution clock, so the measured
# duration is not affected by system clock adjustments.
start = time.perf_counter()
prl = Rips(combined, 2, cutoff)  #0:2853 is the range where errors occurred, takes over 1 min to do 0:200 and grows exponentially
print("built laplacian",flush=True)
max_r = cutoff
stepsize = 0.25
num_steps = max_r/stepsize
# Query dimensions 0..2 over consecutive windows of width `stepsize`.
for i in range(round(num_steps)):
    for dim in range(3):
        print(prl.spectra(dim,i*stepsize,(i+1)*stepsize),flush=True)
# spec = prl.spectra()
print("computed spectra",flush=True)
end = time.perf_counter()
rips_time = end - start

built laplacian
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[]
[]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[]
[]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

In [None]:
# Time the dFlag construction from the saved .flag file together with every
# spectra query and print. perf_counter() is a monotonic high-resolution
# clock, so the measured duration is not affected by system clock adjustments.
start = time.perf_counter()
dflag = dFlag("./data/flag/1a99_7_1sigfig.flag",2)
cutoff = 7
max_r = cutoff
stepsize = 0.25
num_steps = max_r/stepsize
# Query dimensions 0..2 over consecutive windows of width `stepsize`.
for i in range(round(num_steps)):
    for dim in range(3):
        print(dflag.spectra(dim,i*stepsize,(i+1)*stepsize),flush=True)

end = time.perf_counter()
flag_time = end - start

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
nontrivial filtration algorithm
end nontrivial filtration alg
end prepare graph filtration
cell count in dimension 1 is 497
cell count in dimension 2 is 798
cell count in dimension 3 is 0
[]
[]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

In [19]:
# Report both elapsed times and how many times longer the Rips run took
# than the dFlag run.
print("flag time = ", flag_time)
print("rips time = ", rips_time)
print("flag speedup factor = {:0.1f}".format(rips_time/flag_time))

flag time =  0.29808807373046875
ripse time =  19.465643405914307
flag speedup factor = 65.3


In [16]:
# import seaborn as sns
# import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = [15, 5]
# plt.rcParams['figure.dpi'] = 600
# sns.set_theme()
# sns.set_style("whitegrid")
# sns.color_palette("colorblind")

# fig,axes = plt.subplots(2,2)
# #TODO have a palette for fill under dimension
# sns.lineplot(ax=axes[0,0],data=df,x="filtration",y="betti_0",color="black",drawstyle='steps-post')
# sns.lineplot(ax=axes[1,0],data=df,x="filtration",y="betti_1",color="black",drawstyle='steps-post')
# sns.lineplot(ax=axes[0,1],data=df,x="filtration",y="lambda_0",color="black",drawstyle='steps-post')
# sns.lineplot(ax=axes[1,1],data=df,x="filtration",y="lambda_1",color="black",drawstyle='steps-post')
# for type_id in range(0,2):
#     for dim in range(0,2):        
#         l = axes[dim,type_id].lines[0]
#         x = l.get_xydata()[:,0]
#         y = l.get_xydata()[:,1]
#         axes[dim,type_id].fill_between(x,y,alpha=0.6,step="post")

# plt.show()