In [84]:
from Bio.PDB import *
from Bio import PDB
from torch import nn
import numpy as np
from ase import Atoms, Atom
import dask.dataframe as dd
from ordered_set import OrderedSet
import os


local_folder="/Users/jessihoernschemeyer/pKaSchNet"
pkPDB_CSV = f"{local_folder}/pkas.csv"
def read_database(path):
    """csv --> dask df"""
    #make the dask data frame from the PYPKA csv
    dk=dd.read_csv(pkPDB_CSV, delimiter=';', na_filter=False, dtype={'idcode':'category', 
                                                                    'residue_number':'uint8',
                                                                    'pk': 'float32',
                                                                    'residue_name':'category',
                                                                    'chain': 'category',
                                                                    'residue_name': 'category'
                                                                    })
                                                            
    dk=dk.rename(columns={'idcode': 'PDB ID', 'residue_number': 'Res ID', 'residue_name': 'Res Name', 'residue_number': 'Res ID', 'pk': 'pKa', 'chain' : 'Chain'}) #rename columns to match df from pkad 
    dk=dk.sort_values(['PDB ID', 'Res ID'], ascending=[True, True]) 
    dk=dk.compute() 
    dff = dk.reset_index() 

    return dff

def check_atoms_protein(structure, struc_atoms): 
    """internal function. checks every atom in the entire protein for metals, undesirables"""
    pdb_residues=[]
    for atom in struc_atoms: 
        resname, atomid=atom.get_parent().get_resname(), atom.get_full_id()[2:]
        element=atomid[2][0]

        if element in ["MG", "MN", "FE", "CO", "NI", "CU", "ZN"]:
            return 0,0#print(f"{element} present, pdb skipped")
        
        else:
            #atomid=atom.get_full_id() #('', 0, 'B', (' ', 177, ' '), ('OH', ' '))
            if atomid[1][0] not in [' ']:
                if element == 'S': #check 4 hetero sulfur, exclude.
                    print(f"{atomid}, hetero sulfur. pdb skipped ")
                    return 0,0
                
                if element in ['CA', 'CL', 'K', 'NA']: #other salt
                    for res in structure.get_residues():
                        if resname in ["GLU", "HIS", "ASP", "ARG", "TYR", "CYS", "LYS"]: #if the other salt is part of the residue (<3Ã¥ from geometric center), delete atom from residue
                            if np.linalg.norm(res.center_of_mass(geometric=True) - atom.get_coord()) < 3:
                                atom.get_parent().detach_child(atom.get_id()) #print(f"salt {atom} deleted, {d} from {res}")

        pdb_residues.append(atomid[0] + str(atomid[1][1]))
    
    return structure, set(pdb_residues) #('', 0, 'B', ('W', 371, ' '), ('O', ' '))

def atoms_to_structure(cutout, filename): 
    """Internal function (or not), cutout --> save to harddrive
    input: cutout: list of biopython atom objects (NOT ASE)"""
    chain_dict = {}

    structure = Structure.Structure(filename)
    model = Model.Model(0)
    structure.add(model)

    for atom in cutout:
        res = atom.get_parent()
        res_id, resname, chain_id = res.get_id(), res.get_resname(), res.get_full_id()[2]

        #make acidic GLH and ASH straight here. so change their name before saving 
        if resname == "GLU":
            resname="GLH"
            
        if resname=="ASP":
            resname="ASH"

        if resname=="HIS":
            resname="HIP"
            
        
        if chain_id not in chain_dict:
            chain = Chain.Chain(chain_id) #make new chain
            chain_dict[chain_id] = chain
            model.add(chain) #add it

        else:
            chain = chain_dict[chain_id]

        if res_id in [res.get_id() for res in chain.get_residues()]:
            residue = [res for res in chain.get_residues() if res.get_id() == res_id][0] 
        else:
            residue = Residue.Residue(res_id, resname, '') #make new res
            chain.add(residue)

        residue.add(atom)
    # save the pdb
    io = PDBIO()
    io.set_structure(structure)
    io.save(f"{filename}.pdb")

#dask_df = read_database(local_folder + pkPDB_CSV)

#TODO: #TODO: THIS EXCLUDES NTR AND CTR! (navigating pypka error)

dont forget removed disordered at site

In [107]:
#pdb_parser, pdbs = PDB.PDBParser(), list(OrderedSet(list(dask_df["PDB ID"])))
def generate_cutout_around_protonatable_site(residue, distance_cutoff, ns, counter, fname):
    """Residue wise resolurion. ns is neighbor search set up for the entire protein, residue is the single data point / 1 of several residues in a pdb & in pypka.
    input is one residue. output is the cutout around its titratable site, both of which can be plural e.g. his, mb asp and glu.
    residue (biopython Residue object): a single protonable residue """
    protonatable_sites = ["OE2", "SG", "NZ", "OD2", "NE2", "ND1", "OH"]
    cuts = []
    for atom in residue.get_atoms():
        if atom.is_disordered(): 
            atomN = str(atom)[16:-1]
            if atomN in protonatable_sites:
                return #dont make cutout of titratable site is disordered
        atomN = str(atom)[6:-1]#, residue.get_resname()[0]
        
        if atomN == protonatable_sites[0]: #glu
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut)) #counter is id!
            
            continue
        if atomN  == protonatable_sites[1]: #CYS
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut))
            return cuts
        if atomN == protonatable_sites[2]: #LYS
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut))
            return cuts
        if atomN ==protonatable_sites[3]: #ASP
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut))
            continue
########HIS 
        if atomN == protonatable_sites[4]: #his eps #TODO: maybe one always comes first in biopython?
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((str(counter) + 'E', center, residue.get_resname()[0], cut))
            continue
        if atomN == protonatable_sites[5]: #HIS #TODO
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((str(counter) + 'D', center, residue.get_resname()[0], cut))
            continue

        if atomN == protonatable_sites[6]: #TYR
            center = atom.get_coord()
            cut = ns.search(center, 6, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut))
            return cuts

    return cuts #plural because of sites with multiple sites.


In [199]:
#%%capture
def merge_or_not_cutouts(cutouts_apdb, distance_cutoff): #TODO: reduce dtypes #PDB WISE!
    """
    in: all of the cutouts from the pdb. returns the merged or solo cutout for each input residue of cutouts_apdb. len in = len out"""
    #protein wise ..
    dp_ids,centers,cuts, Ds_lite, counter, cutouts, cuts_dict,resnames,redunant_merged_is =[],[],[],[], 0, [], {},[],[]
    for site in cutouts_apdb:
        dp_ids.append(site[0]) #1,2,3,4,5,6,7,8,9E,9D,... #is counter from generate_around_prot_site
        centers.append(site[1]) 
        resnames.append(site[2])
        cuts.append(site[3])
        

    num_residues=len(centers)
    distances = np.zeros((num_residues, num_residues))
    print(dp_ids)
    for i in range(num_residues): #residue wise..
        for j in range(i + 1, num_residues):
            distance = np.linalg.norm(centers[i] - centers[j]).astype(np.float32)
            if distance < distance_cutoff:
                distances[i, j] = distance
                distances[j, i] = distance

    for i in range(len(distances)):
        Ds_lite.append(np.array([distances[:,i][j] for j in range(num_residues) if distances[:,i][j] != 0])) #goes column wise and just gets the nonzeros
    
#residuewise.... 
    for a_residues_distance_array in Ds_lite: #ds lite is nontrivial distances
        if a_residues_distance_array.size != 0: #if not empty
            closest_cutout_i = int(np.where((distances[:,counter])==np.min(a_residues_distance_array))[0])
            pair_i = frozenset((dp_ids[counter],dp_ids[closest_cutout_i]))#key #frozen set is immutable thus can be used as a dict key #also order doesnt matter, 2-1=1-2
            #print("gonna be merged", pair_i) #LIST(ORDEREDSET) AUTOMATICALLY MAKES THEM ASCENDING!!
            if not bool(cuts_dict.keys()): #if there are not any pairs yet generated, make some..
                print("if none yet", pair_i)
                if pair_i not in cuts_dict.keys(): #if not in merged dict, make it
                    closest_cutout = cuts[closest_cutout_i]
                    print("making", pair_i)
                    cutout = (list(set(cuts[counter] + closest_cutout)),resnames[counter] + resnames[closest_cutout_i])
                    cuts_dict[pair_i] = cutout[0]
                else: #null
                    #print(pair_i)
                    print("null, pair_i")
                    cutout = 0 #"cutout = dict entry
                    redunant_merged_is.append(counter)
            else:
                if pair_i not in cuts_dict.keys(): #if not in merged dict, make it
                    closest_cutout = cuts[closest_cutout_i]
                    print("making", pair_i)
         
                    cutout = (list(set(cuts[counter] + closest_cutout)),resnames[counter] + resnames[closest_cutout_i])
                    cuts_dict[pair_i] = cutout[0]
                else: #null
                    print("null", pair_i)
                    cutout = 0
                    redunant_merged_is.append(counter)
            counter += 1 #the counter goes up each time a 

        else: #solo cutout
            cutout = cuts[counter]
            counter += 1
        cutouts.append(cutout)
    #print(list(cuts_dict.keys()))
    if len(cutouts) != len(dp_ids):
        return #this will make an exception if something went wrong

    return cutouts,redunant_merged_is #merged or solo

def get_cutout(dask_df, distance_cutoff): #"PARENT" FUNCTION
    """for each protein in dask_df (the entire PYPKA database), it iterates residue wise through the 121,294 proteins in PYPKA database and downloads
    the structure from RCSB with biopython. Then, it checks and skips the structure if metals & hetero sulfurs are present, and deletes non-sulfur
    salts from titratable residues.
    Then, for each structure residue represented in PYPKA, generates a cutout for each residue, appends the structure to cutouts_apdb"""
    pdbname="11as"  #for now
    cutouts_apdb, all_tit_res =[],[]
    for i in range(18,19): #will equal len of set of pdbs in pypka, == 121294 
        cutouts_apdb, fnames, cutouts_1_datapoint, counter, pdbname =[],[], [],0, pdbs[i]
        Structure = pdb_parser.get_structure("",  PDBList().retrieve_pdb_file(str.lower(pdbname),obsolete=False, pdir='PDB',file_format = 'pdb'))
        structure, pdb_residues = check_atoms_protein(Structure, Structure.get_atoms())
        if structure==0: #skip entire pdb and all its entries in pypka db if there are undesirables in pdb
            continue

        ns = PDB.NeighborSearch(list(structure.get_atoms())) #set up ns , entire protein
        pdb_df = dask_df[dask_df.iloc[:, 1] == pdbname].drop(columns = ["PDB ID", "pKa"]) #make a subdf containing only residue entries which are in PYPKA (dask_df) 
        dp_ids=[]
        for j in range(len(pdb_df)):  #go through each residue in a pdb #each j is a datapoint!
            chain, res_id =pdb_df.iloc[j]['Chain'], int(pdb_df.iloc[j]['Res ID'])
            
            try: 
                residue=structure[0][chain][res_id] #a datapoint #TODO: make ID?
                pypka_resname, PDBresname = pdb_df.iloc[j]['Res Name'], residue.get_resname() #pypka error

                if pypka_resname not in [PDBresname]: #ACHTUNG! this navigates the pyka error. #TODO: mail him #TODO: THIS EXCLUDES NTR AND CTR!
                    continue

                #generate the cutout solo cutouts
                cutouts_1_datapoint=generate_cutout_around_protonatable_site(residue, distance_cutoff, ns, counter,  f"{pdbname}_{chain}_{res_id}_{PDBresname}") #can be multiple
                #returns empty if disordered

                cutouts_apdb.append(*cutouts_1_datapoint) #append each residue/data point cutouts here #it will error here if disordered
                fnames.append(f"{pdbname}_{PDBresname[0].lower()}_{chain}_{res_id}-{cutouts_1_datapoint[0][0]}") #{cutouts_1_datapoint[0][0] is the

                all_tit_res.append(chain + str(res_id))
                #print(f"{pdbname}_{res_id}_{chain}_{PDBresname[0]}-{cutouts_1_datapoint[0][0]}")
                counter+=1 #counter doesnt reach here if it fails

            except: 
                counter += 1

                #print(f"{pdbname}_{res_id}_{chain}_{PDBresname}")
                print("except") #delete
                pass #means pypka res not found in PDB

        #os.remove(f"{local_folder}/PDB/pdb{pdbname}.ent")  
        if cutouts_apdb:
            merged_and_solos, greaterN_pair_i =merge_or_not_cutouts(cutouts_apdb, distance_cutoff)#make a merged cutout or not based off radius criteria
            #print(merged_and_solos) #!IIS will be generated in ascending order and will always be the the higher cutout# .
            counter=0
            print(greaterN_pair_i)
            for cut, fname in zip(merged_and_solos, fnames[:]):

                if type(cut)==tuple: #means it is a merged cutout. second argument is the pairid #it is still 1-to-1 here but ima destroy it
                    pairid=cut[1]
                    print(pairid)
                    #fnames.index()
                    #Fname="".join([fname,'~',str(pairid[0]),"_",str(pairid[1])])
                    print(counter)
                    Fname="".join([fname,'+',fnames[greaterN_pair_i[counter]],"~",pairid])
                    #del greaterN_pair_i[0]
                    fnames[fnames.index(fname)] = Fname
                    atoms_to_structure(cut[0], Fname) #save as pdb)
                    counter += 1
                elif cut==0:
                    del fnames[fnames.index(fname)]
                    continue
                else:
                    atoms_to_structure(cut, fname) 
            #lines = str(all_tit_res)[1:-1].strip()
            #with open(f"{pdbname}", 'w') as file: #save all the residue names 
                #file.write(lines)

    return fnames,[c[1] for c in cutouts_apdb]
#os.remove(f"{local_folder}/PDB/pdb{pdbname}.ent")"""

fs, cs = get_cutout(dask_df, 8)

Structure exists: 'PDB/pdb194l.ent' 
except
except
[1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
if none yet frozenset({2, 4})
making frozenset({2, 4})
null frozenset({2, 4})
making frozenset({5, 14})
making frozenset({17, 6})
making frozenset({8, 10})
null frozenset({8, 10})
making frozenset({11, 12})
null frozenset({11, 12})
null frozenset({5, 14})
null frozenset({17, 6})
[2, 8, 10, 12, 15]
LA
TL
TL


  closest_cutout_i = int(np.where((distances[:,counter])==np.min(a_residues_distance_array))[0])


GA
TA


IndexError: list index out of range

In [183]:
fs

['194l_G_A_7-1',
 '194l_L_A_13-2',
 '194l_A_A_18-4',
 '194l_T_A_20-5',
 '194l_T_A_23-6',
 '194l_L_A_33-7',
 '194l_G_A_35-8',
 '194l_A_A_48-9',
 '194l_A_A_52-10',
 '194l_T_A_53-11~194l_A_A_66-12',
 '194l_A_A_87-13',
 '194l_L_A_96-14',
 '194l_L_A_97-15',
 '194l_A_A_101-16',
 '194l_L_A_116-17',
 '194l_A_A_119-18']

In [176]:
fs

['194l_A_7_GLU-1',
 '194l_A_13_LYS-2',
 '194l_A_18_ASP-4',
 '194l_A_20_TYR-5',
 '194l_A_23_TYR-6',
 '194l_A_33_LYS-7',
 '194l_A_35_GLU-8',
 '194l_A_48_ASP-9',
 '194l_A_52_ASP-10',
 '194l_A_53_TYR-11~194l_A_66_ASP-12',
 '194l_A_66_ASP-12~194l_A_66_ASP-12',
 '194l_A_87_ASP-13',
 '194l_A_96_LYS-14',
 '194l_A_97_LYS-15',
 '194l_A_101_ASP-16',
 '194l_A_116_LYS-17',
 '194l_A_119_ASP-18']

In [None]:
fname

In [166]:
fs[10]

'194l_A_66_ASP-12~11_12'

In [134]:
fs

['194l_A_7_GLU-1',
 '194l_A_13_LYS-2',
 '194l_A_18_ASP-4',
 '194l_A_20_TYR-5',
 '194l_A_23_TYR-6',
 '194l_A_33_LYS-7',
 '194l_A_35_GLU-8',
 '194l_A_48_ASP-9',
 '194l_A_52_ASP-10',
 '194l_A_53_TYR-11~11_12',
 '194l_A_66_ASP-12~11_12',
 '194l_A_87_ASP-13',
 '194l_A_96_LYS-14',
 '194l_A_97_LYS-15',
 '194l_A_101_ASP-16',
 '194l_A_116_LYS-17',
 '194l_A_119_ASP-18']

In [125]:
!cat 194l_A_66_ASP-12~9_10.pdb

ATOM      1  NE  ARG A  68      15.618  12.907  25.348  1.00 21.04           N  
ATOM      2  NH1 ARG A  68      17.096  12.044  23.798  1.00 21.06           N  
ATOM      3  NH2 ARG A  68      15.677  13.774  23.202  1.00 18.29           N  
ATOM      4  CD  ARG A  68      16.058  12.007  26.404  1.00 19.61           C  
ATOM      5  CG  ARG A  68      15.052  11.889  27.516  1.00 18.53           C  
ATOM      6  N   ARG A  68      13.708  10.820  30.066  1.00 16.80           N  
ATOM      7  C   ARG A  68      14.893  13.023  30.298  1.00 18.53           C  
ATOM      8  CZ  ARG A  68      16.125  12.908  24.114  1.00 21.25           C  
ATOM      9  CE2 TYR A  53       9.642  12.506  22.253  1.00 11.57           C  
ATOM     10  CE1 TYR A  53       9.296  14.165  23.955  1.00 11.47           C  
ATOM     11  CG  TYR A  53       7.722  13.983  22.122  1.00 11.04           C  
ATOM     12  OH  TYR A  53      11.165  12.736  24.066  1.00 12.47           O  
ATOM     13  CD2 TYR A  53  

In [117]:
fs

['194l_A_7_GLU-0',
 '194l_A_13_LYS-1',
 '194l_A_18_ASP-2',
 '194l_A_20_TYR-3',
 '194l_A_23_TYR-4',
 '194l_A_33_LYS-5',
 '194l_A_35_GLU-6',
 '194l_A_48_ASP-7',
 '194l_A_52_ASP-8',
 '194l_A_53_TYR-9~9_10',
 '194l_A_66_ASP-10~9_10',
 '194l_A_87_ASP-11',
 '194l_A_96_LYS-12',
 '194l_A_97_LYS-13',
 '194l_A_101_ASP-14',
 '194l_A_116_LYS-15',
 '194l_A_119_ASP-16']

In [102]:
# Path to the original file
source=local_folder + "/194l_A_53_TYR-9~9_10.pdb"

# Path to the duplicate file
destination=local_folder + "/194l_A_66_ASP-10~9_10.pdb"

# Copy the file
!cp "{source}" "{destination}"

join > +

In [93]:
fs

['194l_A_7_GLU-0',
 '194l_A_13_LYS-1',
 '194l_A_18_ASP-2',
 '194l_A_20_TYR-3',
 '194l_A_23_TYR-4',
 '194l_A_33_LYS-5',
 '194l_A_35_GLU-6',
 '194l_A_48_ASP-7',
 '194l_A_52_ASP-8',
 '194l_A_53_TYR-9910',
 '194l_A_66_ASP-10910',
 '194l_A_87_ASP-11',
 '194l_A_96_LYS-12',
 '194l_A_97_LYS-13',
 '194l_A_101_ASP-14',
 '194l_A_116_LYS-15',
 '194l_A_119_ASP-16']

make sure that TIP3p and ff14sb in same

In [None]:
!cat 194l_A_7_GLU-0.pdb

In [59]:
#%%capture

import time
from collections import Counter
from collections import defaultdict
from itertools import chain, product
#protonate("194l", )
def amber(input_pdb):
    skript = f"""source leaprc.protein.ff14SB
    source leaprc.water.tip3p
    loadOff "/Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/lib/amino19.lib"
    mol = loadpdb "/Users/jessihoernschemeyer/pKaSchNet/{input_pdb}.pdb"
    savepdb mol "/Users/jessihoernschemeyer/pKaSchNet/prot_{input_pdb}.pdb"

    quit"""
    with open("ascript.py","w") as file: 
        file.writelines(skript)
    return


for f in fs[:3]:
    amber(f)
    !tleap -s -f /Users/jessihoernschemeyer/pKaSchNet/ascript.py

-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/prep to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/lib to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/parm to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd to search path.
-s: Ignoring all leaprc startup files.
-f: Source /Users/jessihoernschemeyer/pKaSchNet/ascript.py.

Welcome to LEaP!
Sourcing: /Users/jessihoernschemeyer/pKaSchNet/ascript.py
----- Source: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd/leaprc.protein.ff14SB
----- Source of /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd/leaprc.protein.ff14SB done
Log file: ./leap.log
Loading parameters: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/parm/parm10.dat
Reading title:
PARM99 + frcmod.ff99SB + frcmod.parmbsc0 + OL3 for RNA
Loading parameters: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/p

In [None]:
!cat 194l_A_18_ASP-2.pdb


In [None]:
!cat 194l_A_13_LYS-1.pdb

In [14]:
amber("194l_A_7_GLU-0")

how to put not the same cutout through schnet!! maybe have to add a pair id.
will just replace the pdb in the end instead of prot

In [84]:
fname='prot_194l_A_18_ASP-2' + 'AT'
fname
    

'prot_194l_A_18_ASP-2AT'

In [None]:
#acidic
protonatable_sites = ["HE2", "HG", "HZ1", "OD2", "HE2", "HD1", "HH"] # glu cys lys asp hie hid tyr
protonatable_sites = ["G":"HE2", "HG", "HZ1", "OD2", "HE2", "HD1", "HH"] 
in="194l_A_7_GLU-0_0.pdb"

In [None]:
!cat 194l_A_7_GLU-0.pdb

In [None]:
!sed '/HH/d' /Users/jessihoernschemeyer/pKaSchNet/194l_A_53_TYR-9.pdb #

!sed '/delete_this/d' file > newfile

if a solo cutout we can use sed

In [58]:
from Bio.PDB import *
from Bio import PDB
from ase import Atoms, Atom
import torch
from matscipy.neighbours import neighbour_list as msp_neighbor_list
pdb_parser = PDB.PDBParser()
def PDB_to_schnet_input_and_names_map(cut,r):

    pos, names, B, a = [],[], [], []
    z_symbol = {'H' : 1,
        'C' : 6,
        'N' : 7,
        'O' : 8,
        'S': 16}
    #struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/{file}')

    for atom in cut:
        id=atom.get_full_id()
        res, name = id[3][1], id[4][0]
        names.append(name)
        pos.append(atom.get_coord())
        a.append(atom)


    z=[z_symbol.get(name[0]) for name in names]
    #Z IS MADE FROM THE NAMES
    atoms = Atoms([z_symbol.get(name[0]) for name in names], pos)
    atoms.set_cell([[1,0,0], [0,1,0], [0,0,1]])

    d, i, j = msp_neighbor_list('dij',  atoms, [r for i in range(len(atoms))])
    inputs = {'Z':torch.tensor(z).long(), 'R':torch.tensor(d).float(), 'idx_i':torch.tensor(i).long(), 'idx_j': torch.tensor(j).long()}

  
    return inputs, [names, a]

In [66]:
!cat 194l_A_18_ASP-2.pdb

ATOM      1  CE  LYS A  13     -16.811  20.282  11.682  1.00 34.74           C  
ATOM      2  CD  LYS A  13     -15.333  19.969  11.652  1.00 30.15           C  
ATOM      3  CG  LYS A  13     -14.897  19.284  12.931  1.00 25.07           C  
ATOM      4  CD2 LEU A  25     -11.936  22.147  11.830  1.00 15.53           C  
ATOM      5  CB  LEU A  25     -11.288  24.145  13.113  1.00 14.78           C  
ATOM      6  N   ASH A  18     -12.669  21.775  17.303  1.00 20.30           N  
ATOM      7  OD2 ASH A  18     -15.738  23.817  14.132  1.00 29.04           O  
ATOM      8  CG  ASH A  18     -14.738  23.940  14.877  1.00 26.01           C  
ATOM      9  CB  ASH A  18     -14.230  22.676  15.588  1.00 22.91           C  
ATOM     10  OD1 ASH A  18     -14.154  25.049  15.039  1.00 26.14           O  
ATOM     11  O   ASH A  18     -14.928  23.245  18.476  1.00 20.09           O  
ATOM     12  C   ASH A  18     -13.993  23.754  17.846  1.00 21.32           C  
ATOM     13  CA  ASH A  18  

In [68]:
!cat prot_194l_A_13_LYS-1.pdb

ATOM      1  N   LEU     1     -16.599  20.501   9.740  1.00  0.00
ATOM      2  H   LEU     1     -16.579  19.602  10.200  1.00  0.00
ATOM      3  CA  LEU     1     -16.662  20.530   8.292  1.00  0.00
ATOM      4  HA  LEU     1     -17.566  21.052   7.978  1.00  0.00
ATOM      5  CB  LEU     1     -15.454  21.253   7.705  1.00  0.00
ATOM      6  HB2 LEU     1     -15.429  22.278   8.073  1.00  0.00
ATOM      7  HB3 LEU     1     -14.541  20.737   8.003  1.00  0.00
ATOM      8  CG  LEU     1     -15.558  21.262   6.183  1.00  0.00
ATOM      9  HG  LEU     1     -15.583  20.237   5.814  1.00  0.00
ATOM     10  CD1 LEU     1     -16.835  21.984   5.766  1.00  0.00
ATOM     11 HD11 LEU     1     -16.811  23.010   6.134  1.00  0.00
ATOM     12 HD12 LEU     1     -16.910  21.991   4.678  1.00  0.00
ATOM     13 HD13 LEU     1     -17.699  21.468   6.185  1.00  0.00
ATOM     14  CD2 LEU     1     -14.351  21.985   5.596  1.00  0.00
ATOM     15 HD21 LEU     1     -13.438  21.469   5.894  1.00  

In [70]:
cs[1]

array([-17.552,  18.998,  11.717], dtype=float32)

In [64]:
%%capture
!cat 194l_A_18_ASP-2.pdb
!cat 194l_A_13_LYS-1.pdb
struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/prot_194l_A_18_ASP-2.pdb')


ns = PDB.NeighborSearch(list(struct.get_atoms())) #set up ns , entire protein
cut1 = ns.search(cs[2], 5, "A")
input1, extras1 = PDB_to_schnet_input_and_names_map(cut1,6)

struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/prot_194l_A_13_LYS-1.pdb')

len([atom for atom in struct.get_atoms()])

ns = PDB.NeighborSearch(list(struct.get_atoms())) #set up ns , entire protein
cut2 = ns.search(cs[1], 5, "A")



In [65]:
for atom in cut1:
    print(atom.get_parent())
print("")
for atom in cut2:
    print(atom.get_parent())

<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue LEU het=  resseq=2 icode= >
<

In [None]:
len(cut1)

46

In [None]:
len(input1.get('Z'))

46

In [None]:
from schnetpack.representation.schnet import SchNet
from torch.nn import Sequential
from schnetpack.model import NeuralNetworkPotential
from torch import nn
from schnetpack.nn import Dense
from schnetpack.nn.radial import GaussianRBF
from schnetpack.nn.cutoff import CosineCutoff

In [None]:

torch.set_printoptions(profile="full")
weights = torch.load('tensor_dict.pth')
r=10
output_weight = torch.load('output_tensor.pth')

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

#input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
#outputs1 = Model(input1)
#E1 = modelll(outputs1.get('scalar_representation'))
#Eatoms1 = [[atom for atom in extras1[1]], [e for e in E1]]
#Eatoms1_dict = dict(zip([atom for atom in extras1[1]], [e for e in E1]))

NameError: name 'torch' is not defined

In [None]:
outputs1 = Model(input1)

: 

In [None]:
r=10
input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
len(input1.get('Z'))

62

In [None]:
from schnetpack.representation.schnet import SchNet
from torch.nn import Sequential
from schnetpack.model import NeuralNetworkPotential
from torch import nn
from schnetpack.nn import Dense
from schnetpack.nn.radial import GaussianRBF
from schnetpack.nn.cutoff import CosineCutoff
torch.set_printoptions(profile="full")
weights = torch.load('tensor_dict.pth')
r=10
output_weight = torch.load('output_tensor.pth')

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
outputs1 = Model(input1)
E1 = modelll(outputs1.get('scalar_representation'))
Eatoms1 = [[atom for atom in extras1[1]], [e for e in E1]]
Eatoms1_dict = dict(zip([atom for atom in extras1[1]], [e for e in E1]))

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

input2, _ = PDB_to_schnet_input_and_names_map(cut2)
outputs2 = Model(input2)
E2 = modelll(outputs2.get('scalar_representation'))

Eatoms2 = [[atom for atom in _[1]], [e for e in E2]]
Eatoms2_dict = dict(zip([atom for atom in _[1]], [e for e in E2]))

common_entries = set(Eatoms1[0]).intersection(set(Eatoms2[0]))
for x in list(common_entries):
    print(Eatoms2_dict.get(x), Eatoms1_dict.get(x))


NameError: name 'torch' is not defined

In [None]:
print(len(modelll.state_dict()))

NameError: name 'modelll' is not defined

In [None]:
common_entries = set(Eatoms1[0]).intersection(set(Eatoms2[0]))
for x in list(common_entries):
    print(Eatoms2_dict.get(x), Eatoms1_dict.get(x))

In [None]:
type(cut[0].get_full_id())

tuple