Notebook for developing the basic framework.

In [None]:
from petls import PersistentSheafLaplacian, sheaf_simplex_tree, Complex
# from Bio.PDB import PDBParser
# from preprocess import get_ligand_data, read_pdb_biopython
import pandas as pd
import gudhi as gd
import numpy as np
import os

In [None]:
def atom_distance(a, b):
    """Return the Euclidean (EUC) distance between two atoms.

    Only the first three entries of ``a`` and ``b`` are read, as x, y and z;
    any further entries are ignored.
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    dz = a[2] - b[2]
    squared_length = np.power(dx, 2.0) + np.power(dy, 2.0) + np.power(dz, 2.0)
    return np.sqrt(squared_length)

def filter_heavy_atoms(df, include_list):
    """Keep only the rows of ``df`` whose ``element`` value is in ``include_list``."""
    wanted = df["element"].isin(include_list)
    return df[wanted]


def add_radii(df, elts, radii):
    """Attach a radius column ``r`` to ``df`` by matching on ``element``.

    ``elts`` and ``radii`` are paired positionally into a lookup table that is
    merged onto ``df`` with pandas' default merge, so rows whose element is
    not listed in ``elts`` do not appear in the result.
    """
    element_radius_pairs = list(zip(elts, radii))
    lookup = pd.DataFrame(element_radius_pairs, columns=["element", "r"])
    return pd.merge(df, lookup, on="element")

In [None]:
import argparse
import logging
import sys
from collections import OrderedDict
from collections.abc import Sequence
from io import StringIO
from os import PathLike
from pathlib import Path

# from pdb2pqr.main import build_main_parser, main_driver

# def run_pdb2pqr(args: Sequence[str | PathLike]):
#     """Run PDB2PQR with a list of arguments.

#     Logger is not set up so that it can be called multiple times.

#     :param args:  list of command-line arguments
#     :type args:  list
#     :return:  results of PDB2PQR run
#     :rtype:  tuple
#     """
#     args_strlist = [str(arg) for arg in args]
#     parser = build_main_parser()
#     args_parsed = parser.parse_args(args_strlist)
#     return main_driver(args_parsed)

In [None]:

# from pdbfixer import PDBFixer
# from openmm.app import PDBFile
# def run_pdbfixer(pdbid):

# # pdbid = "1a0q"
#     prefix = f"../data/v2007/{pdbid}/{pdbid}_"
#     filename = f"{prefix}protein.pdb"
#     fixer = PDBFixer(filename=filename)
#     fixer.findMissingResidues()
#     fixer.findNonstandardResidues()
#     # fixer.replaceNonstandardResidues()
#     # fixer.removeHeterogens(False)
#     fixer.findMissingAtoms()
#     print("Nonstandard residues:", fixer.nonstandardResidues)
#     print("Missing residues:",fixer.missingResidues)
#     print("Missing atoms:", fixer.missingAtoms)
#     fixer.addMissingAtoms()
#     fixer.addMissingHydrogens(7.0)
#     # fixer.addSolvent(fixer.topology.getUnitCellDimensions())
#     PDBFile.writeFile(fixer.topology, fixer.positions, open(f'{pdbid}.pdb', 'w'))

In [None]:
import pandas as pd
# Load the refined-set index (whitespace-separated, no header row); column 0 is
# taken as the list of PDB ids.
# The separator is a raw string so that "\s" reaches the regex engine unchanged
# instead of being an invalid string escape sequence.
df = pd.read_csv("../data/v2007/INDEX.2007.refined.csv", header=None, index_col=None, sep=r'\s+')
pdbids = df[0].tolist()
# failed_pdbids = []
# successful_pdbids = []
# problematic_pdbids = ['2hdr', '1cet', '1ux7', '1k1y']
# out of the first 100, Failed pdbids:  ['2hdr', '1cet', '1ux7', '1k1y']
# for pdbid in pdbids[0:100]:
#     print("Run preprocess pipeline on ", pdbid)
#     status = preprocess(pdbid)
#     if status:
#         successful_pdbids.append(pdbid)
#     if not status:
#         failed_pdbids.append(pdbid)
# print("Failed pdbids: ", failed_pdbids)
# print("Len(failed_pdbids): ", len(failed_pdbids))
# print("Len(successful_pdbids): ", len(successful_pdbids))

In [None]:
from readin import mol2_to_pqre
def check_mol2_okay(pdbids):
    """Try to read the protein and the ligand of every complex and report failures.

    For each id, ``mol2_to_pqre`` is called with "charged" and then with
    "ligand"; the id counts as successful only if both calls return without
    raising. Progress is printed every 50 ids and a summary at the end.

    :param pdbids: sequence of PDB ids to check
    :return: ``(successful_pdbids, failed_pdbids)`` as two lists
    """
    failed_pdbids = []
    successful_pdbids = []
    for i, pdbid in enumerate(pdbids):
        if i % 50 == 0:
            print(f"Progress: {i} of {len(pdbids)}")
        try:
            # The parsed data is discarded: only whether the reads succeed matters.
            mol2_to_pqre(pdbid, "charged")
            mol2_to_pqre(pdbid, "ligand")
        except Exception as e:
            # Deliberately broad: any reader failure just marks this id as bad,
            # but the reason is reported instead of being thrown away.
            failed_pdbids.append(pdbid)
            print("FAILED:", pdbid, "-", repr(e))
        else:
            successful_pdbids.append(pdbid)
    print(f"{len(successful_pdbids)} successful reads: {successful_pdbids}")
    print(f"{len(failed_pdbids)} failed reads: {failed_pdbids}")
    return successful_pdbids, failed_pdbids
# 10 failed reads: ['1w2g', '1bzj', '1sts', '1mfa', '1g7f', '2bsu', '1gt1', '966c', '9abp', '1gar']
# throw away for now
# TODO: once model basically works, figure out how to get these back in the dataset

In [None]:
import numpy as np
def new_xyzq(pdbid, pro_elements, lig_elements, cutoff=10.0):
    """Build the point cloud of protein atoms near the ligand, followed by the ligand.

    The protein ("charged") and ligand tables are read with ``mol2_to_pqre``
    and restricted to the given element lists. A protein atom is kept when it
    lies strictly closer than ``cutoff`` to at least one of the filtered
    ligand atoms.

    :param pdbid: id of the complex to load
    :param pro_elements: element symbols to keep for the protein
    :param lig_elements: element symbols to keep for the ligand (distance test only)
    :param cutoff: distance threshold, in the units of the coordinates
    :return: ``(combined, ligand)`` where ``combined`` holds columns 0, 1, 2, 4
        of the kept protein rows followed by the ligand rows, and ``ligand``
        is the unfiltered ligand table as returned by ``mol2_to_pqre``
    """
    protein = mol2_to_pqre(pdbid, "charged")
    ligand = mol2_to_pqre(pdbid, "ligand")

    filtered_pro = np.array(filter_heavy_atoms(protein, pro_elements))
    filtered_lig = np.array(filter_heavy_atoms(ligand, lig_elements))
    print("filtered_ligand=", filtered_lig)

    keep_indices = []
    for p_index, pro_atom in enumerate(filtered_pro):
        if p_index % 400 == 0:
            print("checking pqr index: ", p_index, flush=True)
        # any() stops at the first ligand atom within range instead of
        # measuring the distance to every remaining ligand atom.
        if any(atom_distance(pro_atom, lig_atom) < cutoff for lig_atom in filtered_lig):
            keep_indices.append(p_index)

    pro_cutoff = filtered_pro[keep_indices]
    # NOTE(review): the *unfiltered* ligand is concatenated and returned here,
    # although the distance test above uses filtered_lig. Confirm whether
    # lig_elements is meant to restrict the output as well.
    combined = np.concatenate([pro_cutoff, ligand])
    # Columns 0-2 are used as x, y, z downstream; column 4 is presumably the
    # charge (the "q" in the function name) - TODO confirm against mol2_to_pqre.
    combined = combined[:, [0, 1, 2, 4]]
    return combined, ligand

In [None]:
# command line works for some: pdb2pqr --ff=AMBER --ligand=1a99_ligand.mol2 1a99_protein.pdb 1a99_charged.pqr

In [None]:
import math
def get_rough_diameter(ptcloud):
    """Return the diagonal length of the axis-aligned bounding box of ``ptcloud``.

    Columns 0, 1 and 2 of ``ptcloud`` are read as x, y and z. The per-axis
    ranges and the resulting diameter are printed as a side effect.
    """
    bounds = [(np.min(ptcloud[:, axis]), np.max(ptcloud[:, axis])) for axis in range(3)]
    for label, (low, high) in zip("xyz", bounds):
        print(f"{label} range: ", low, " to ", high, flush=True)
    span_x, span_y, span_z = [high - low for low, high in bounds]
    diameter = math.sqrt(span_x**2 + span_y**2 + span_z**2)
    # diameter = max(max_x-min_x, max_y-min_y, max_z-min_z)
    print("diameter = ", diameter, flush=True)
    return diameter

In [None]:
from gudhi import AlphaComplex
def get_alpha_complex(ptcloud, max_filtration):
    """Build the alpha complex of ``ptcloud`` and trim it.

    The returned gudhi simplex tree contains no simplices of dimension above 2
    and none with a filtration value above ``max_filtration``.

    NOTE(review): gudhi's alpha-complex filtration values are squared radii by
    default, so ``max_filtration`` is presumably on that squared scale -
    confirm for the installed gudhi version.
    """
    tree = AlphaComplex(points=ptcloud).create_simplex_tree()
    tree.prune_above_dimension(2)
    tree.prune_above_filtration(max_filtration)
    return tree

In [None]:
def get_extra_data(combined):
    """Map every vertex to the values stored in its row of ``combined``.

    :param combined: iterable of rows, one per vertex
    :return: dict whose keys are 1-tuples ``(row_index,)`` and whose values
        are the row's entries as a plain list, in their original order
    """
    # Keys are 1-tuples so that tuple(simplex) of a single-vertex simplex can
    # be used directly as the lookup key. The per-vertex debug print was
    # removed: it emitted one line per atom on every run.
    return {(i,): [*vertex] for i, vertex in enumerate(combined)}

In [None]:
import petls
def my_restriction(simplex: list[int], coface: list[int], sst: "petls.sheaf_simplex_tree") -> float:
    """Restriction coefficient from ``simplex`` into ``coface``, built from charges.

    * Vertex into edge: entry 3 of the *other* endpoint's extra data (the
      charge, per the inline comments) divided by the Euclidean distance
      between the two endpoints.
    * Edge into triangle: a product over the triangle's boundary edges - the
      entry-3 value of one triangle vertex for the edge equal to ``simplex``,
      and one over the filtration value for each of the other edges.

    Coulomb's constant is not applied. Any other simplex size falls through
    and returns ``None``.

    :param simplex: vertex list of the face
    :param coface: vertex list of the coface containing ``simplex``
    :param sst: sheaf simplex tree; ``extra_data``, ``st`` and ``coface_index``
        are the only members used
    """
    # The sst annotation is a string so that defining this function does not
    # itself evaluate a petls attribute.
    from math import sqrt

    if len(simplex) == 1:
        # The sibling is the other endpoint of the edge `coface`.
        if simplex == [coface[0]]:
            sibling = [coface[1]]
        else:
            sibling = [coface[0]]

        coords_simplex = sst.extra_data[tuple(simplex)][0:3]
        coords_sibling = sst.extra_data[tuple(sibling)][0:3]
        # Pair x with x, y with y and z with z. The y and z terms used to be
        # crossed (y against z and z against y), which gave a wrong distance.
        distance = sqrt((coords_simplex[0] - coords_sibling[0])**2
                        + (coords_simplex[1] - coords_sibling[1])**2
                        + (coords_simplex[2] - coords_sibling[2])**2)
        return sst.extra_data[tuple(sibling)][3] / distance  # charge / distance
    elif len(simplex) == 2:
        coeff = 1.0
        for (sibling, _) in sst.st.get_boundaries(coface):
            if sibling == simplex:
                # presumably coface_index gives the position in `coface` of the
                # vertex that is not in `simplex` - TODO confirm against petls
                opposite_vertex = coface[sst.coface_index(simplex, coface)]
                coeff = coeff * sst.extra_data[tuple([opposite_vertex])][3]  # charge
            else:
                coeff = coeff / sst.st.filtration(sibling)
        return coeff


In [None]:
def get_laplacian_complex(combined_xyzq, max_filtration=None):
    """Build the persistent sheaf Laplacian for a point cloud with per-vertex data.

    :param combined_xyzq: 2-D array whose first three columns are x, y, z; the
        full rows are stored as each vertex's extra data
    :param max_filtration: filtration cut-off for the alpha complex; when
        ``None``, a quarter of the rough bounding-box diameter is used
    :return: ``(laplacian_complex, stree)`` - the ``PersistentSheafLaplacian``
        and the simplex tree it was built on
    """
    xyz = combined_xyzq[:, 0:3]
    if max_filtration is None:
        max_filtration = get_rough_diameter(xyz) / 4
    stree = get_alpha_complex(xyz, max_filtration=max_filtration)
    print("Number of simplices:", stree.num_simplices())
    sstree = sheaf_simplex_tree(stree, get_extra_data(combined_xyzq), my_restriction)
    return PersistentSheafLaplacian(sstree), stree

In [None]:
def get_selected_spectra_requests(dims=(0,), min_filt=1.0, max_filt=2.0, steps=1, delta=1.0):
    """Build the list of spectra requests ``(dim, a, a + delta)``.

    Filtration values are sampled uniformly from ``min_filt`` to ``max_filt``
    (both included) in ``steps`` intervals, and one request is emitted per
    dimension and sampled value. The defaults reproduce the previously
    hard-coded configuration.

    :param dims: dimensions to request
    :param min_filt: first sampled filtration value
    :param max_filt: last sampled filtration value
    :param steps: number of intervals between ``min_filt`` and ``max_filt``
    :param delta: offset added to each sampled value to form the third entry
    :return: list of ``(dim, filt, filt + delta)`` tuples
    :raises ValueError: if ``steps`` is smaller than 1
    """
    if steps < 1:
        # Guards the division below, which would otherwise fail with a less
        # helpful ZeroDivisionError.
        raise ValueError("steps must be at least 1")
    stepsize = (max_filt - min_filt) / steps
    sampled_filtrations = [min_filt + i * stepsize for i in range(steps + 1)]
    print("sampled_filtrations:", sampled_filtrations)
    requests = [(dim, filt, filt + delta) for dim in dims for filt in sampled_filtrations]
    print("number of requests:", len(requests), flush=True)
    print("requests:", requests, flush=True)
    return requests

In [None]:
def min_nonzero(X, tol=1e-2):
    """Return the smallest value of ``X`` that is strictly greater than ``tol``.

    :param X: iterable of numbers (for example one spectrum)
    :param tol: values at or below this threshold are treated as zero; the
        default keeps the previously hard-coded 1e-2
    :return: the smallest value above ``tol``, or ``np.nan`` if there is none
    """
    return min((s for s in X if s > tol), default=np.nan)

In [None]:
import time
# pro_ele_rad = [1.70]#, 1.55, 1.52, 1.80]
# #     pro_elements = ["C", "N", "O", "S"]
# pro_elements = ["C"] # for simplicity for now
# #     lig_elements = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I"]
# lig_elements = ["C", "N", "O", "S"]#, "P", "F", "Cl", "Br", "I"]
# lig_ele_rad = [1.70, 1.55, 1.52, 1.80]#, 1.80, 1.47, 1.75, 1.85, 1.98]  # last is hydrogen when needed: 1.20

def pipeline(pdbid, cutoff=30, max_filtration=8):
    """Run the whole analysis for one complex and return its spectra.

    Steps: build the protein+ligand point cloud, build the persistent sheaf
    Laplacian on its alpha complex, plot the persistence diagram, and compute
    the spectra listed by ``get_selected_spectra_requests``.

    :param pdbid: id of the complex to process (previously ignored because it
        was overwritten with "1a0q" inside the function)
    :param cutoff: protein atoms within this distance of the ligand are kept
    :param max_filtration: filtration cut-off passed to the alpha complex
    :return: ``(requested_spectra, stree, combined, laplacian_complex, ligand)``
        with ``ligand`` converted to a numpy array
    """
    pro_elements = ["C", "N", "O", "S"]
    lig_elements = ["C", "N", "O", "S"]  # , "P", "F", "Cl", "Br", "I"]
    combined, ligand = new_xyzq(pdbid, pro_elements, lig_elements, cutoff=cutoff)
    print("Combined:", combined)

    print(pdbid, "building laplacian complex...", flush=True)
    t = time.time()
    laplacian_complex, stree = get_laplacian_complex(combined, max_filtration=max_filtration)
    # The elapsed time is now measured and interpolated; the message used to
    # print the literal text "{complex_time}".
    complex_time = time.time() - t
    print(pdbid, f"got laplacian complex in {complex_time} seconds", flush=True)

    dgms = stree.persistence()
    gd.plot_persistence_diagram(dgms)
    print(stree.num_simplices(), flush=True)
    requests = get_selected_spectra_requests()
    # Verbose only while the spectra are computed.
    laplacian_complex.verbose = True
    print(stree.num_simplices(), flush=True)
    requested_spectra = laplacian_complex.spectra(request_list=requests)
    laplacian_complex.verbose = False

    # processed_spectra = round_zeros_PH(stree, requested_spectra)

    return requested_spectra, stree, combined, laplacian_complex, np.array(ligand)



In [None]:
# Run the full pipeline once for complex 1a0q; the names bound here are read by the cells below.
spectra, stree, combined, laplacian_complex, ligand = pipeline("1a0q")

In [None]:
# Number of dimension-1 persistence intervals, then how many of them have a
# second entry (death value) above 10.
stree.persistence()  # gudhi needs persistence computed before intervals can be queried
dgm1 = stree.persistence_intervals_in_dimension(1)
print(len(dgm1))
print(sum(1 for interval in dgm1 if interval[1] > 10))

In [None]:
# Dimension-0 intervals, and how many of them have a death value of at least 2.
dgm0 = stree.persistence_intervals_in_dimension(0)
# Count using the loop variable. The previous comprehension yielded `_`, which
# is only defined by accident (IPython's last-output variable) and raises a
# NameError under plain Python.
betti0_1_2 = sum(1 for x in dgm0 if x[1] >= 2)
print(dgm0)
print(betti0_1_2)

In [None]:
# Fetch one Laplacian matrix from the complex for inspection.
# NOTE(review): the arguments are presumably (dimension, filtration a, filtration b) - confirm against petls.
L = laplacian_complex.get_L(0,1,2)

In [None]:
# Symmetry check: the largest entry of L - L.T, which is 0 when L is a symmetric real array.
np.max(L - L.T)

In [None]:
# Copy of L with a tiny eps added along the main diagonal; its eigenvalues are
# compared with those of L in the next cell.
eps = 1e-9
L_eps = L + eps * np.eye(*L.shape[:2])

In [None]:
from scipy.linalg import eigvalsh, eigvals
from scipy.sparse.linalg import eigsh, eigs
import matplotlib.pyplot as plt
# Eigenvalues of L and of its eps-shifted copy, with the absolute shift per eigenvalue.
e = eigvalsh(L)
e_eps = eigvalsh(L_eps)
# e = eigsh(L,k=L.shape[0]-1)[0]
# e_eps = eigsh(L_eps, k=L.shape[0]-1)[0]
diff = e_eps - e
# Relative shift per eigenvalue; an exactly-zero eigenvalue falls back to the absolute shift.
rel_diff = [shift / base if base != 0 else shift for shift, base in zip(diff, e)]


In [1]:
# Show the eigenvalues of L.
# NOTE(review): the saved output of this cell is a NameError - it was last run before `e` was defined.
print(e)

NameError: name 'e' is not defined

In [None]:
# Spectrum of L without its first entry (eigvalsh returns eigenvalues in ascending order).
e_eigvalsh = eigvalsh(L)[1:]
print(e_eigvalsh.shape)

In [None]:
# diff_sparse = e-e_eigvalsh

In [None]:
# NOTE(review): `diff_sparse` is never assigned in this notebook - its only
# definition is commented out in the previous cell - so this cell raises a
# NameError on a fresh kernel.
print(diff_sparse)
print(max(abs(diff_sparse)))

In [None]:
# print(e)
# print(e_eps)
# print(diff)
# One row per eigenvalue: (original, eps-shifted, absolute shift, relative shift).
e_eps_diffs = list(zip(e, e_eps, diff, rel_diff))
# The loop variable must not be named `e`: that rebinds the eigenvalue array to
# the last printed tuple, and the cells below still read `e` as the spectrum.
for row in e_eps_diffs:
    print(row)

In [None]:
# e = sorted(eigs(L,which="LM",ncv=200)[0].real)
# Split the spectrum `e` by sign, print each part, and plot a histogram of
# log10 of the strictly positive eigenvalues.
true_betti0 = betti0_1_2
# Everything after the first betti0_1_2 + 1 entries of e.
true_nonneg = e[betti0_1_2+1:]
print(true_nonneg)
nonneg = [eig for eig in e if eig > 0]  # strictly positive, despite the name
neg = [eig for eig in e if eig < 0]
# NOTE(review): exact comparison with 0; eigenvalues that are only numerically
# close to zero are not collected here.
betti = [eig for eig in e if eig == 0]
print("Betti=",betti)
for eig in neg:
    print("Negative eig ", eig)
for eig in nonneg:
    print(eig)
logs = np.log10(nonneg)
plt.hist(logs)

In [None]:
# Dump the raw spectra returned by pipeline() for inspection.
print(spectra)

In [None]:
# Entry 3 of the first item in `spectra`; the plotting cell further down
# unpacks each item as (dim, filt, _, spectrum), so this is its spectrum.
s = spectra[0][3]
nonneg = [eig for eig in s if eig >= 0]
for eig in nonneg:
    print(eig)
# NOTE(review): the filter above keeps exact zeros, for which log10 gives -inf
# (with a numpy warning).
logs = np.log10(nonneg)

In [None]:
# Print the log10 values one per line.
for l in logs:
    print(l)

In [None]:
import matplotlib.pyplot as plt
# plt.hist(nonneg)
# plt.hist(logs)

In [None]:

# One row per request of the given dimension:
# [filtration value, smallest eigenvalue accepted by min_nonzero, number of eigenvalues below 1e-10].
dim0_spectra = [[filt, min_nonzero(spectrum), len([s for s in spectrum if s < 1e-10])] for (dim, filt, _, spectrum) in spectra if dim == 0]
dim1_spectra = [[filt, min_nonzero(spectrum), len([s for s in spectrum if s < 1e-10])] for (dim, filt, _, spectrum) in spectra if dim == 1]
import matplotlib.pyplot as plt
d0 = np.array(dim0_spectra)
d1 = np.array(dim1_spectra)
# Plot dimension 1 when it was requested; otherwise fall back to dimension 0.
# With no dimension-1 requests d1 is an empty 1-D array and d1[:, 0] would
# raise an IndexError.
d = d1 if d1.size else d0
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()

# Two y-axes over the same x-axis: column 1 in blue on the left, column 2 in red on the right.
ax1.plot(d[:,0],d[:,1],color='b',label='lambda')
ax2.plot(d[:,0],d[:,2],color='r',label='betti')
plt.xlim(0,8)
ax1.legend(loc="upper left")
ax2.legend(loc="upper right")
# ax1.set_ylim(-0.001, 0.05)
# ax2.set_ylim(-0.001, 200)
# ax1.set_ylim(-0.001,0.035)
# ax2.set_ylim(-0.001,401)

In [None]:
# summaries = petls.PLutil.summaries(requested_spectra,func=min_nonzero)
# Show the dimension-1 summary rows built in the plotting cell above.
print(dim1_spectra)

In [None]:
# print(summaries[0])

In [None]:
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# im = petls.PLutil.plot_summary(ax =ax,summary=summaries[0][1])
# fig.colorbar(im, ax=ax)

In [None]:
#plot alpha complex
import numpy as np
import plotly.graph_objects as go
import plotly

# from plotly.graph_objs import graph_objs as go
import ipywidgets as widgets

plotly.offline.init_notebook_mode()
from plotly.offline import iplot

# Slider for the filtration threshold of the interactive plot: 0.0 to 8 in
# steps of 0.01, starting at 0.005 and displayed with four decimals.
alpha = widgets.FloatSlider(
    value = 0.005,
    min = 0.0,
    max = 8,
    step = 0.01,
    description = 'Alpha:', 
    readout_format = '.4f'
)

In [None]:
# !pip list

In [None]:
# Inspect the ligand array returned by pipeline().
print(ligand)

In [None]:
# Same inspection as the previous cell (duplicate).
print(ligand)

In [None]:
# pip install anywidget, plotly, ipywidgets, nbformat
# Module-level inputs read by plot_alpha_1_skeleton below: the first three
# columns (x, y, z) of every vertex and of the ligand.
ptcloud = combined[:,0:3]
# NOTE: this rebinds `ligand` to its first three columns only; the earlier
# cells that print `ligand` saw the full rows.
ligand = ligand[:,0:3]
def plot_alpha_1_skeleton(alpha_thresh):
    """Draw the part of the alpha complex with filtration value <= ``alpha_thresh``.

    Reads three module-level names: ``stree`` (the simplex tree), ``ptcloud``
    (x, y, z per vertex, indexed by vertex id) and ``ligand`` (x, y, z of the
    ligand atoms, drawn in red on top of the complex).

    :param alpha_thresh: largest filtration value to include
    :return: the ``go.FigureWidget`` that was displayed with ``iplot``
    """
    show_triangles = True  # Set to True to show triangles
    # A simplex tree produced by get_alpha_complex is pruned above dimension 2,
    # so the tetrahedra branch only fires for trees built some other way.
    show_tetrahedra = True  # Set to True to show tetrahedra

    # Collect the simplices under the filtration threshold, by dimension.
    edges = []
    vertices = set()
    triangles = []
    pyramids = []

    for simplex, filtration_value in stree.get_filtration():
        if filtration_value <= alpha_thresh:
            if len(simplex) == 1:  # vertex
                vertices.add(simplex[0])
            elif len(simplex) == 2:  # edge
                i, j = simplex
                edges.append((i, j))
                vertices.add(i)
                vertices.add(j)
            elif len(simplex) == 3 and show_triangles:
                i, j, k = simplex
                triangles.append((i, j, k))
                vertices.update([i, j, k])
            elif len(simplex) == 4 and show_tetrahedra:
                i, j, k, l = simplex
                pyramids.append((i, j, k, l))
                vertices.update([i, j, k, l])

    # Coordinates and the old -> new index map must use the SAME vertex order.
    # Previously the coordinates followed the set's iteration order while the
    # map followed sorted order, so mesh corners could land on wrong points.
    ordered_vertices = sorted(vertices)
    vertex_coords = ptcloud[ordered_vertices]
    idx_map = {v: i for i, v in enumerate(ordered_vertices)}  # map old -> new indices

    # Scatter plot for the vertices of the complex.
    vertex_trace = go.Scatter3d(
        x=vertex_coords[:, 0],
        y=vertex_coords[:, 1],
        z=vertex_coords[:, 2],
        mode='markers',
        marker=dict(size=3, color='blue'),
    )
    # Ligand atoms, highlighted in red.
    ligand_trace = go.Scatter3d(
        x=ligand[:, 0],
        y=ligand[:, 1],
        z=ligand[:, 2],
        mode='markers',
        marker=dict(size=3, color='red'),
    )

    # Line segments for the edges; a None entry separates consecutive segments.
    edge_x = []
    edge_y = []
    edge_z = []
    for i, j in edges:
        edge_x += [ptcloud[i, 0], ptcloud[j, 0], None]
        edge_y += [ptcloud[i, 1], ptcloud[j, 1], None]
        edge_z += [ptcloud[i, 2], ptcloud[j, 2], None]

    edge_trace = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        mode='lines',
        line=dict(width=3, color='black'),
    )

    # Triangle mesh
    mesh_trace = None
    if show_triangles and triangles:
        # Re-index vertices for Mesh3d
        i_list, j_list, k_list = zip(*[(idx_map[i], idx_map[j], idx_map[k]) for i, j, k in triangles])
        mesh_trace = go.Mesh3d(
            x=vertex_coords[:, 0],
            y=vertex_coords[:, 1],
            z=vertex_coords[:, 2],
            i=i_list,
            j=j_list,
            k=k_list,
            color='violet',
            opacity=0.7,
            showscale=False
        )

    # Tetrahedra mesh: each tetrahedron is drawn as its four triangular faces.
    tetra_trace = None
    if pyramids:
        # Re-index vertices for Mesh3d
        tuples = []
        for i, j, k, l in pyramids:
            tuples.append((idx_map[i], idx_map[j], idx_map[k]))  # leave out l
            tuples.append((idx_map[i], idx_map[j], idx_map[l]))  # leave out k
            tuples.append((idx_map[i], idx_map[k], idx_map[l]))  # leave out j
            tuples.append((idx_map[j], idx_map[k], idx_map[l]))  # leave out i

        i_list, j_list, k_list = zip(*tuples)

        tetra_trace = go.Mesh3d(
            x=vertex_coords[:, 0],
            y=vertex_coords[:, 1],
            z=vertex_coords[:, 2],
            i=i_list,
            j=j_list,
            k=k_list,
            color='purple',
            opacity=1.0,
            name='Tetrahedra',
            showscale=False
        )

    # Draw order: triangles first (underneath), then edges, vertices, ligand,
    # and finally tetrahedra.
    traces = [edge_trace, vertex_trace, ligand_trace]
    if mesh_trace is not None:
        traces.insert(0, mesh_trace)
    if tetra_trace is not None:
        traces.append(tetra_trace)

    fig = go.FigureWidget(
        data=traces,
        layout=go.Layout(
            scene=dict(
                aspectmode='data'
            ),
            showlegend=False,
            margin=dict(l=0.00, r=0.00, b=0, t=0, pad=0),
            plot_bgcolor='rgba(0,0,0,0)',
            title=None,
        )
    )

    # plotly.io.write_image(fig, 'alpha_complex.png', scale=1, width=1080, height=1080)
    iplot(fig)
    return fig
    
    # if filename:
    #     plotly.io.write_image(fig, filename)

# plot_alpha_1_skeleton(alpha_thresh=0.4, ptcloud=combined, stree=simplex_tree)
# Bind the Alpha slider to the plot function so the figure is redrawn for each slider value.
dummy = widgets.interact(plot_alpha_1_skeleton, alpha_thresh=alpha);
# plot_alpha_1_skeleton(torus, stree, alpha_thresh=0.8,show_triangles=True)
# plot_alpha_1_skeleton(alpha_thresh=0.4)