In [None]:
from Bio.PDB import *
from Bio import PDB
from torch import nn
import numpy as np
from ase import Atoms, Atom
import dask.dataframe as dd
from ordered_set import OrderedSet
import os
pdb_parser = PDB.PDBParser()
local_folder="/Users/jessihoernschemeyer/pKaSchNet"
pkPDB_CSV = f"{local_folder}/pkas.csv"
#PKAD_CSV = f"{local_folder}/WT_pka.csv"

# Build the dask dataframe from the PYPKA csv.
# 'residue_number' is read as int32 rather than uint8: uint8 only covers
# 0..255, while PDB residue numbers may be negative or larger than 255, so the
# narrow type would corrupt (or reject) those rows.
dk = dd.read_csv(
    pkPDB_CSV,
    delimiter=';',
    na_filter=False,
    dtype={
        'idcode': 'category',
        'residue_number': 'int32',
        'pk': 'float32',
        'residue_name': 'category',
        'chain': 'category',
    },
)

# Rename columns to match the dataframe built from PKAD.
dk = dk.rename(columns={
    'idcode': 'PDB ID',
    'residue_number': 'Res ID',
    'residue_name': 'Res Name',
    'pk': 'pKa',
    'chain': 'Chain',
})
dk = dk.sort_values(['PDB ID', 'Res ID'], ascending=[True, True]) # sort by PDB ID, then residue number
dk = dk.compute() # materialise the full pypka database
dff = dk.reset_index() # the same full database, with a fresh integer index

In [None]:
#%%capture
def get_cutout(dask_df, distance_cutoff, start=3, stop=4): #"PARENT" FUNCTION
    """Download PDB entries listed in the PYPKA table and save a cutout around each listed residue.

        Each selected PDB ID is downloaded and screened with check_atoms_protein(); a structure it rejects
        is skipped together with all of its PYPKA entries. For every remaining PYPKA row, all atoms within
        distance_cutoff of the residue's geometric center are collected and written to
        "<pdb>_<chain>_<resid>_<resname>.pdb" through atoms_to_structure(). The downloaded file is deleted
        once its structure has been handled, including when the structure is skipped.

        inputs  | dask_df:                  the full pypka database, indexable by the columns
                |                           "PDB ID", "Chain", "Res ID", "Res Name" and "pKa"
                | distance_cutoff (int):    cutoff radius (Å) from the titratable residue's COG for the neighbor search
                | start, stop (int):        slice bounds into the ordered list of unique PDB IDs; the defaults
                |                           select only the entry at index 3. Pass stop=None to run to the end.
        returns | list of cutouts (each a list of atoms) for the last PDB ID visited; empty when that
                | entry was skipped or when no PDB ID was selected."""

    pdb_parser = PDB.PDBParser()
    downloader = PDBList()
    pdbs = list(OrderedSet(list(dask_df["PDB ID"])))

    structure_list = [] # keeps the return value defined when the selection is empty
    for pdbname in pdbs[start:stop]: # the full set of pdbs in pypka has 121294 entries
        structure_list = [] # per pdb
        pdb_path = downloader.retrieve_pdb_file(pdbname.lower(), obsolete=False, pdir='PDB', file_format='pdb')
        if not pdb_path or not os.path.isfile(pdb_path):
            print(f"pdb {pdbname} could not be downloaded, skipping")
            continue

        try:
            raw_structure = pdb_parser.get_structure("", pdb_path)
            structure = check_atoms_protein(raw_structure, raw_structure.get_atoms())
            if structure == 0: # skip entire pdb and all its entries in pypka db if there are undesirables in pdb
                continue

            ns = PDB.NeighborSearch(list(structure.get_atoms())) # neighbor search over the screened structure

            # Sub-table holding only this pdb's PYPKA rows; selected by column name so the
            # lookup does not depend on the column order of dask_df.
            pdb_df = dask_df[dask_df["PDB ID"] == pdbname].drop(columns=["PDB ID", "pKa"])

            for chain, res_id, res_name in zip(pdb_df["Chain"], pdb_df["Res ID"], pdb_df["Res Name"]):
                res_id = int(res_id)
                try:
                    residue = structure[0][chain][res_id]
                except KeyError:
                    print(f"residue not found in pdb {pdbname}, skipping")
                    continue

                # Best effort per residue: report the failure and carry on with the next row.
                try:
                    center = residue.center_of_mass(geometric=True)
                    cutout = ns.search(center, distance_cutoff, "A")
                    structure_list.append(cutout)
                    atoms_to_structure(cutout, f"{pdbname}_{chain}_{res_id}_{res_name}") # save as pdb
                except Exception as err:
                    print(f"cutout for {pdbname} {chain} {res_id} failed ({err!r}), skipping")
        finally:
            # Remove exactly the file the downloader reported, whatever happened above.
            os.remove(pdb_path)
    return structure_list




def check_atoms_protein(structure, struc_atoms):
    """internal function. checks every atom in the entire protein for metals, undesirables

    The structure is rejected (0 is returned) as soon as an atom whose element is one of
    MG/MN/FE/CO/NI/CU/ZN is seen, or a sulfur atom that belongs to a hetero residue.
    Hetero CA/CL/K/NA atoms lying closer than 3 (coordinate units) to the geometric center
    of a HIS/CYS/LYS/ARG/ASP/GLU/TYR/MET residue are detached from their parent residue.

    inputs  | structure:    object providing get_residues(); returned when the check passes
            | struc_atoms:  iterable of the atoms to inspect
    returns | 0 when the structure must be skipped, otherwise the (possibly modified) structure"""
    metals = {"MG", "MN", "FE", "CO", "NI", "CU", "ZN"}
    salts = {'CA', 'CL', 'K', 'NA'}
    titratable = {"HIS", "CYS", "LYS", "ARG", "ASP", "GLU", "TYR", "MET"} #MET is NTR. IS CTR EXCLUDED?? CHECK
    titratable_centers = None # built on first use; most structures never need it

    # Iterate over a snapshot: ions are detached below, and changing the structure
    # while a live atom iterator walks it can make the iterator skip atoms.
    for atom in list(struc_atoms):
        element = atom.element

        if element in metals:
            print(f"{element} present, pdb skipped")
            return 0

        atomid = atom.get_full_id()
        # Only hetero residues (non-blank hetero flag) are examined further.
        # NOTE(review): atomid[1] looks like the model id, which would make the second
        # half of this test always pass for parsed structures -- kept unchanged.
        if atomid[3][0] == ' ' or atomid[1] == ' ':
            continue

        if element == 'S': #means that it is hetero and Sulfur, exclude.
            print(f"{atomid}, hetero sulfur. pdb skipped ")
            return 0

        if element not in salts:
            continue

        if titratable_centers is None:
            titratable_centers = [
                (res, res.center_of_mass(geometric=True))
                for res in structure.get_residues()
                if res.get_resname() in titratable
            ]

        for res, center in titratable_centers:
            d = np.linalg.norm(center - atom.get_coord())
            if d < 3:
                atom.get_parent().detach_child(atom.get_id())
                print(f"salt {atom} deleted, {d} from {res}")
                # The ion is gone; detaching it again for another nearby residue would fail.
                break
    return structure

def atoms_to_structure(cutout, filename):
    """Internal function. Rebuilds a structure/model/chain/residue hierarchy from the
    atoms of a cutout and saves it as "<filename>.pdb".

    inputs  | cutout:          iterable of atoms; each atom's parent residue supplies the
            |                  residue id, residue name and chain id used in the new hierarchy
            | filename (str):  id of the new structure and output file name without extension

    Copies of the atoms are added to the new hierarchy, so the atoms passed in keep
    their original parent residues and can safely appear in several cutouts."""
    chain_dict = {}

    structure = Structure.Structure(filename)
    model = Model.Model(0)
    structure.add(model)

    for atom in cutout:
        source_res = atom.get_parent()  # residue the atom belongs to in the source structure
        res_id = source_res.get_id()
        chain_id = source_res.get_full_id()[2]

        chain = chain_dict.get(chain_id)
        if chain is None:
            chain = Chain.Chain(chain_id) # first atom seen for this chain
            chain_dict[chain_id] = chain
            model.add(chain)

        # Direct id lookup instead of listing every residue of the chain per atom.
        if chain.has_id(res_id):
            residue = chain[res_id]
        else:
            residue = Residue.Residue(res_id, source_res.get_resname(), '')
            chain.add(residue)

        # Adding the atom itself would re-parent it to the new residue.
        residue.add(atom.copy())

    # save the pdb
    io = PDBIO()
    io.set_structure(structure)
    io.save(f"{filename}.pdb")

# Module-level lists that max_overlap() refills on every call.
pdb_atoms, parent, tuples, out = [], [], [], []
resnames = []


def max_overlap(structure_list):
    """Report how often atoms of titratable residues occur across a list of cutouts.

    inputs  | structure_list:  list of cutouts, each an iterable of atoms providing
            |                  get_full_id() and get_parent()
    returns | sorted list of unique (count, chain, resnum) tuples for atoms of
            | HIS/CYS/LYS/ARG/ASP/GLU/TYR residues seen more than once; empty when
            | nothing overlaps.

    The counts are re-evaluated after each cutout is added, so an atom shared by three
    cutouts contributes both a count-2 and a count-3 tuple. The module-level lists
    pdb_atoms, parent, tuples and resnames are cleared at the start of each call so
    results from an earlier call cannot leak into this one."""
    from collections import Counter

    titratable = {"HIS", "CYS", "LYS", "ARG", "ASP", "GLU", "TYR"}

    for scratch in (pdb_atoms, parent, tuples, resnames):
        scratch.clear()

    for structure in structure_list: #each structure is a cutout/pdb
        for atom in structure:
            parent.append(atom.get_parent())
            pdb_atoms.append(atom.get_full_id())

        # One counting pass per cutout replaces a list.count() scan per atom.
        occurrences = Counter(pdb_atoms)
        for full_id, parent_res in zip(pdb_atoms, parent):
            # NOTE(review): the residue name is sliced out of the parent's string form
            # (characters 9..11); this relies on that text layout staying unchanged.
            resname = str(parent_res)[9:12]
            resnames.append(resname)
            if resname in titratable: #if the atom is part of a titratable residue...
                count = occurrences[full_id] #how many cutouts seen so far contain this atom
                if count > 1:
                    tuples.append((count, full_id[2], full_id[3][1])) #(...,chain, resnum)

    return sorted(set(tuples))

# Run the cutout pipeline on the full PYPKA table with a cutoff radius of 10 and keep the returned cutouts.
sl = get_cutout(dask_df=dff, distance_cutoff=10)
