In [1]:
from Bio.PDB import *
from Bio import PDB
from torch import nn
import numpy as np
from ase import Atoms, Atom
import dask.dataframe as dd
from ordered_set import OrderedSet
import os


local_folder="/Users/jessihoernschemeyer/pKaSchNet"
pkPDB_CSV = f"{local_folder}/pkas.csv"
def read_database(path):
    """csv --> dask df"""
    #make the dask data frame from the PYPKA csv
    dk=dd.read_csv(pkPDB_CSV, delimiter=';', na_filter=False, dtype={'idcode':'category', 
                                                                    'residue_number':'uint8',
                                                                    'pk': 'float32',
                                                                    'residue_name':'category',
                                                                    'chain': 'category',
                                                                    'residue_name': 'category'
                                                                    })
                                                            
    dk=dk.rename(columns={'idcode': 'PDB ID', 'residue_number': 'Res ID', 'residue_name': 'Res Name', 'residue_number': 'Res ID', 'pk': 'pKa', 'chain' : 'Chain'}) #rename columns to match df from pkad 
    dk=dk.sort_values(['PDB ID', 'Res ID'], ascending=[True, True]) 
    dk=dk.compute() 
    dff = dk.reset_index() 

    return dff

def check_atoms_protein(structure, struc_atoms): 
    """internal function. checks every atom in the entire protein for metals, undesirables"""
    pdb_residues=[]
    for atom in struc_atoms: 
        resname, atomid=atom.get_parent().get_resname(), atom.get_full_id()[2:]
        element=atomid[2][0]

        if element in ["MG", "MN", "FE", "CO", "NI", "CU", "ZN"]:
            return 0#,0#print(f"{element} present, pdb skipped")
        
        else:
            #atomid=atom.get_full_id() #('', 0, 'B', (' ', 177, ' '), ('OH', ' '))
            if atomid[1][0] not in [' ']:
                if element == 'S': #check 4 hetero sulfur, exclude.
                    print(f"{atomid}, hetero sulfur. pdb skipped ")
                    return 0#,0
                
                if element in ['CA', 'CL', 'K', 'NA']: #other salt
                    for res in structure.get_residues():
                        if resname in ["GLU", "HIS", "ASP", "ARG", "TYR", "CYS", "LYS"]: #if the other salt is part of the residue (<3Ã¥ from geometric center), delete atom from residue
                            if np.linalg.norm(res.center_of_mass(geometric=True) - atom.get_coord()) < 3:
                                atom.get_parent().detach_child(atom.get_id()) #print(f"salt {atom} deleted, {d} from {res}")
    
    return structure#, set(pdb_residues) #('', 0, 'B', ('W', 371, ' '), ('O', ' '))

def atoms_to_structure(cutout, filename): 
    """Internal function (or not), cutout --> save to harddrive
    input: cutout: list of biopython atom objects (NOT ASE)"""
    chain_dict = {}

    structure = Structure.Structure(filename)
    model = Model.Model(0)
    structure.add(model)

    for atom in cutout:
        res = atom.get_parent()
        res_id, resname, chain_id = res.get_id(), res.get_resname(), res.get_full_id()[2]

        #make acidic GLH and ASH straight here. so change their name before saving 
        if resname == "GLU":
            resname="GLH"
            
        if resname=="ASP":
            resname="ASH"

        if resname=="HIS":
            resname="HIP"
            
        
        if chain_id not in chain_dict:
            chain = Chain.Chain(chain_id) #make new chain
            chain_dict[chain_id] = chain
            model.add(chain) #add it

        else:
            chain = chain_dict[chain_id]

        if res_id in [res.get_id() for res in chain.get_residues()]:
            residue = [res for res in chain.get_residues() if res.get_id() == res_id][0] 
        else:
            residue = Residue.Residue(res_id, resname, '') #make new res
            chain.add(residue)

        residue.add(atom)
    # save the pdb
    io = PDBIO()
    io.set_structure(structure)
    io.save(f"{filename}.pdb")

dask_df = read_database(local_folder + pkPDB_CSV)

In [145]:
#pdb_parser, pdbs = PDB.PDBParser(), list(OrderedSet(list(dask_df["PDB ID"])))
def generate_cutout_around_protonatable_site(residue, distance_cutoff, ns, counter, resname):
    """Residue wise resolurion. ns is neighbor search set up for the entire protein, residue is the single data point / 1 of several residues in a pdb & in pypka.
    input is one residue. output is the cutout around its titratable site, both of which can be plural e.g. his, mb asp and glu.
    residue (biopython Residue object): a single protonable residue """
    protonatable_sites = ["OE2", "SG", "NZ", "OD2", "NE2", "ND1", "OH"]
    cuts = []
    if resname==0:
        #first atom is N and NTR
        #atoms=residue.
        print("ntr section")
        Natom = residue.get_atom('N')
        print(Natom)
        center = Natom.get_coord()
        cut = ns.search(center, distance_cutoff, "A")
        print('ntr', cut)
        cuts.append((counter, center, 'NTR', cut)) #counter is id!
        return cuts
    
    elif resname==1:
        print("ctr section")
        try:
            Catom = residue['OXT']
        except:
            Catom = residue['C']
        print(Catom)
        center = Catom.get_coord()
        cut = ns.search(center, distance_cutoff, "A")
        print('ctr',cut)
        cuts.append((counter, center, 'CTR', cut)) #counter is id!
        return cuts
        #residue.get_atoms()
        
     #for atom in residue.get_atoms():
       # if atom.is_disordered(): 
            #atomN = str(atom)[16:-1]
            #if atomN in ["N", "OXT"]: #can remove this after schnet
                #return #dont make cutout of titratable site is disordered


    else:
        if resname=="G": 
            atom=residue[protonatable_sites[0]]
            if atom.is_disordered(): 
                return
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, residue.get_resname()[0], cut)) #counter is id!
            return cuts
            
        if resname=="C": 
            atom=residue[protonatable_sites[1]]
            if atom.is_disordered(): 
                return
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, resname, cut)) #counter is id!
            return cuts
        if resname=="L": 
            atom=residue[protonatable_sites[2]]
            if atom.is_disordered(): 
                return
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, resname, cut)) #counter is id!
            return cuts
        if resname=="A": 
            atom=residue[protonatable_sites[3]]
            if atom.is_disordered(): 
                return
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, resname, cut)) #counter is id!
            return cuts
        if resname=="H":
            atom1=residue[protonatable_sites[4]], atom2=residue[protonatable_sites[5]]
            if atom1.is_disordered(): 
                return
            if atom2.is_disordered():
                return
            
            center1,center2=atom1.get_coord(), atom2.get_coord()
            cut1=ns.search(center1, distance_cutoff, "A")
            cut2=ns.search(center2, distance_cutoff, "A")
            cuts.append((counter+.1, center1, resname, cut1), (counter+.2, center2, resname, cut2))
        if resname=="T":
            atom=residue[protonatable_sites[6]]
            if atom.is_disordered(): 
                return
            center = atom.get_coord()
            cut = ns.search(center, distance_cutoff, "A")
            cuts.append((counter, center, resname, cut)) #counter is id!
            return cuts

    return cuts #plural because of sites with multiple sites.


In [146]:
#%%capture
def merge_or_not_cutouts(cutouts_apdb, distance_cutoff): #TODO: reduce dtypes #PDB WISE!
    """
    in: all of the cutouts from the pdb. returns the merged or solo cutout for each input residue of cutouts_apdb. len in = len out"""
    #protein wise ..
    dp_ids,centers,cuts, Ds_lite, cutouts, resnames,redunant_merged_is, done_pairs =[],[],[],[], [],[],[],[]

    for site in cutouts_apdb:
        dp_ids.append(site[0]) #1,2,3,4,5,6,7,8,9.1, 9.2....
        centers.append(site[1]) 
        resnames.append(site[2])
        cuts.append(site[3])

    num_residues=len(centers)
    distances = np.zeros((num_residues, num_residues))
    
    for i in range(num_residues):
        for j in range(i + 1, num_residues):
            distance = np.linalg.norm(centers[i] - centers[j]).astype(np.float32)
            if distance < distance_cutoff:
                distances[i, j] = distance
                distances[j, i] = distance

    Ds_lite = [distances[i][distances[i] != 0] for i in range(num_residues)] #nonzero entries for easier searching
#residuewise...
    for i in range(len(Ds_lite)): #=len IDs 
        a_residues_distance_array=Ds_lite[i]
        if a_residues_distance_array.any(): #if not empty
            index=dp_ids[i]
            closest_cutout_i = int(np.where((distances[:,int(index)])==np.min(a_residues_distance_array))[0])
            pair_i = frozenset((index,dp_ids[closest_cutout_i]))#key #frozen set is immutable thus can be used as a dict key #also order doesnt matter, 2-1=1-2

            if not done_pairs: #if there are any yet merged
                cutout = (list(set(cuts[i] + cuts[closest_cutout_i])),resnames[i] + resnames[closest_cutout_i])
                done_pairs.append(pair_i)
                redunant_merged_is.append(closest_cutout_i)
                
            else: #if there are already some generated
                if pair_i not in done_pairs: #if that mergedcut hasnt yet been made
                    cutout = (list(set(cuts[i] + cuts[closest_cutout_i])),resnames[i] + resnames[closest_cutout_i])
                    done_pairs.append(pair_i)
                    redunant_merged_is.append(closest_cutout_i)

                else: #null
                    cutout = None

        else: #solo cutout
            cutout = cuts[i]

        cutouts.append(cutout)

    if len(cutouts) != len(dp_ids): #delete?
        return #this will make an exception if something went wrong
    
    return cutouts,redunant_merged_is #merged or solo

def get_cutout(dask_df, distance_cutoff): #"PARENT" FUNCTION
    """for each protein in dask_df (the entire PYPKA database), it iterates residue wise through the 121,294 proteins in PYPKA database and downloads
    the structure from RCSB with biopython. Then, it checks and skips the structure if metals & hetero sulfurs are present, and deletes non-sulfur
    salts from titratable residues.
    Then, for each structure residue represented in PYPKA, generates a cutout for each residue, appends the structure to cutouts_apdb"""
    #pdbname="11as"  #for now #delete
    cutouts_apdb, all_tit_res =[],[]
    for i in range(24,25): #will equal len of set of pdbs in pypka, == 121294 
        cutouts_apdb, fnames, cutouts_1_datapoint, counter, pdbname, newfnames =[],[], [],0, pdbs[i],[]
        Structure = pdb_parser.get_structure("",  PDBList().retrieve_pdb_file(str.lower(pdbname),obsolete=False, pdir='PDB',file_format = 'pdb'))
        structure= check_atoms_protein(Structure, Structure.get_atoms())
        if not structure: #skip entire pdb and all its entries in pypka db if there are undesirables in pdb
            continue
            
        ns = PDB.NeighborSearch(list(structure.get_atoms())) #set up ns , entire protein
        pdb_df = dask_df[dask_df.iloc[:, 1] == pdbname].drop(columns = ["PDB ID", "pKa"]) #make a subdf containing only residue entries which are in PYPKA (dask_df) 
        for j in range(len(pdb_df)):  #go through each residue in a pdb #each j is a datapoint!
            chain, res_id =pdb_df.iloc[j]['Chain'], int(pdb_df.iloc[j]['Res ID'])
            try: 
                residue=structure[0][chain][res_id] #a datapoint #TODO: make ID?
                pypka_resname, PDBresname = pdb_df.iloc[j]['Res Name'], residue.get_resname() #pypka error
                if pypka_resname=='NTR':
                    resname=0
                elif pypka_resname=='CTR':
                    resname=1
                elif pypka_resname==PDBresname:
                    resname=pypka_resname[0]
                    print(resname)
                #pypka error: if not in ntr or == pdbresname, it will error and pass.
                    #elif pypka_resname == PDBresname: #ACHTUNG! this navigates the pyka error. #TODO: mail him #TODO: THIS EXCLUDES NTR AND CTR!

                cutouts_1_datapoint=generate_cutout_around_protonatable_site(residue, distance_cutoff, ns, counter, resname) #can be multiple #returns empty if disordered
                if cutouts_1_datapoint: #cutouts_1_datapoint DNE if titratable site is disordered
                    cutouts_apdb.append(*cutouts_1_datapoint) #append each residue/data point cutouts here #it will error here if disordered
                    fnames.append(f"{counter}_{pdbname}_{pypka_resname}_{chain}_{res_id}") 
                    counter+=1
                else:
                    continue #pypka error

            except: 
                print("hi")
                pass #means pypka res not found in PDB
        
        #os.remove(f"{local_folder}/PDB/pdb{pdbname}.ent")  
        if cutouts_apdb:
            merged_and_solos, greaterN_pair_i =merge_or_not_cutouts(cutouts_apdb, distance_cutoff)#make a merged cutout or not based off radius criteria
            for cut, fname in zip(merged_and_solos, fnames):
                if type(cut)==tuple: #means it is a merged cutout. second argument is the pairid #it is still 1-to-1 here but ima destroy it
                    Fname="".join([fname,'+',fnames[greaterN_pair_i[0]],"~",cut[1]]) #cut1 is pairid AT, HH,...
                    newfnames.append(Fname)
                    del greaterN_pair_i[0]
                    atoms_to_structure(cut[0], Fname) #save as pdb)
                elif not cut:
                    continue
                else:
                    newfnames.append(fname)
                    atoms_to_structure(cut, fname) 

    return newfnames,[c[1] for c in cutouts_apdb] #centers


fs,centers = get_cutout(dask_df, 5)

Structure exists: 'PDB/pdb1a0f.ent' 
ntr section
hi
ntr section
hi
L
L
T
T




L
L
C
C
H
hi
H
hi
G
G
L
L
A
A
A
A
L
L
L
L
G
G
A
A
A
A
T
T
L
L
A
A
A
A
G
G
T
T
A
A
A
A
T
T
L
L
G
G
T
T
G
G
H
hi
H
hi
L
L
A
A
G
G
G
G
T
T
L
L
G
G
L
L
L
L
T
T
G
G
L
L
A
A
G
G
H
hi
H
hi
C
C
A
A
T
T
T
T
L
L
G
G
G
G
H
hi
H
hi
G
G
G
G
A
A
G
G
L
ctr section
<Atom OXT>
ctr [<Atom CG>, <Atom CD2>, <Atom O>, <Atom O>, <Atom C>, <Atom N>, <Atom CG>, <Atom CA>, <Atom CB>, <Atom CD1>, <Atom OXT>, <Atom O>, <Atom C>, <Atom O>]
ctr section
<Atom OXT>
ctr [<Atom O>, <Atom CD1>, <Atom O>, <Atom O>, <Atom OXT>, <Atom O>, <Atom C>, <Atom CB>, <Atom CA>, <Atom O>, <Atom CG>, <Atom N>, <Atom C>]
L


  closest_cutout_i = int(np.where((distances[:,int(index)])==np.min(a_residues_distance_array))[0])


make sure that TIP3p and ff14sb in same

In [131]:
#%%capture

import time
from collections import Counter
from collections import defaultdict
from itertools import chain, product
#protonate("194l", )
def amber(input_pdb):
    skript = f"""source leaprc.protein.ff14SB
    source leaprc.water.tip3p
    loadOff "/Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/lib/amino19.lib"
    mol = loadpdb "/Users/jessihoernschemeyer/pKaSchNet/{input_pdb}.pdb"
    savepdb mol "/Users/jessihoernschemeyer/pKaSchNet/prot_{input_pdb}.pdb"

    quit"""
    with open("ascript.py","w") as file: 
        file.writelines(skript)
    return


for f in fs:
    amber(f)
    !tleap -s -f /Users/jessihoernschemeyer/pKaSchNet/ascript.py

-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/prep to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/lib to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/parm to search path.
-I: Adding /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd to search path.
-s: Ignoring all leaprc startup files.
-f: Source /Users/jessihoernschemeyer/pKaSchNet/ascript.py.

Welcome to LEaP!
Sourcing: /Users/jessihoernschemeyer/pKaSchNet/ascript.py
----- Source: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd/leaprc.protein.ff14SB
----- Source of /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/cmd/leaprc.protein.ff14SB done
Log file: ./leap.log
Loading parameters: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/parm/parm10.dat
Reading title:
PARM99 + frcmod.ff99SB + frcmod.parmbsc0 + OL3 for RNA
Loading parameters: /Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/p

In [None]:
#acidic
protonatable_sites = ["HE2", "HG", "HZ1", "OD2", "HE2", "HD1", "HH"] # glu cys lys asp hie hid tyr
protonatable_sites = ["G":"HE2", "HG", "HZ1", "OD2", "HE2", "HD1", "HH"] 
in="194l_A_7_GLU-0_0.pdb"

In [None]:
!sed '/HH/d' /Users/jessihoernschemeyer/pKaSchNet/194l_A_53_TYR-9.pdb #

!sed '/delete_this/d' file > newfile

if a solo cutout we can use sed

In [58]:
from Bio.PDB import *
from Bio import PDB
from ase import Atoms, Atom
import torch
from matscipy.neighbours import neighbour_list as msp_neighbor_list
pdb_parser = PDB.PDBParser()
def PDB_to_schnet_input_and_names_map(cut,r):

    pos, names, B, a = [],[], [], []
    z_symbol = {'H' : 1,
        'C' : 6,
        'N' : 7,
        'O' : 8,
        'S': 16}
    #struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/{file}')

    for atom in cut:
        id=atom.get_full_id()
        res, name = id[3][1], id[4][0]
        names.append(name)
        pos.append(atom.get_coord())
        a.append(atom)


    z=[z_symbol.get(name[0]) for name in names]
    #Z IS MADE FROM THE NAMES
    atoms = Atoms([z_symbol.get(name[0]) for name in names], pos)
    atoms.set_cell([[1,0,0], [0,1,0], [0,0,1]])

    d, i, j = msp_neighbor_list('dij',  atoms, [r for i in range(len(atoms))])
    inputs = {'Z':torch.tensor(z).long(), 'R':torch.tensor(d).float(), 'idx_i':torch.tensor(i).long(), 'idx_j': torch.tensor(j).long()}

  
    return inputs, [names, a]

In [66]:
!cat 194l_A_18_ASP-2.pdb

ATOM      1  CE  LYS A  13     -16.811  20.282  11.682  1.00 34.74           C  
ATOM      2  CD  LYS A  13     -15.333  19.969  11.652  1.00 30.15           C  
ATOM      3  CG  LYS A  13     -14.897  19.284  12.931  1.00 25.07           C  
ATOM      4  CD2 LEU A  25     -11.936  22.147  11.830  1.00 15.53           C  
ATOM      5  CB  LEU A  25     -11.288  24.145  13.113  1.00 14.78           C  
ATOM      6  N   ASH A  18     -12.669  21.775  17.303  1.00 20.30           N  
ATOM      7  OD2 ASH A  18     -15.738  23.817  14.132  1.00 29.04           O  
ATOM      8  CG  ASH A  18     -14.738  23.940  14.877  1.00 26.01           C  
ATOM      9  CB  ASH A  18     -14.230  22.676  15.588  1.00 22.91           C  
ATOM     10  OD1 ASH A  18     -14.154  25.049  15.039  1.00 26.14           O  
ATOM     11  O   ASH A  18     -14.928  23.245  18.476  1.00 20.09           O  
ATOM     12  C   ASH A  18     -13.993  23.754  17.846  1.00 21.32           C  
ATOM     13  CA  ASH A  18  

In [68]:
!cat prot_194l_A_13_LYS-1.pdb

ATOM      1  N   LEU     1     -16.599  20.501   9.740  1.00  0.00
ATOM      2  H   LEU     1     -16.579  19.602  10.200  1.00  0.00
ATOM      3  CA  LEU     1     -16.662  20.530   8.292  1.00  0.00
ATOM      4  HA  LEU     1     -17.566  21.052   7.978  1.00  0.00
ATOM      5  CB  LEU     1     -15.454  21.253   7.705  1.00  0.00
ATOM      6  HB2 LEU     1     -15.429  22.278   8.073  1.00  0.00
ATOM      7  HB3 LEU     1     -14.541  20.737   8.003  1.00  0.00
ATOM      8  CG  LEU     1     -15.558  21.262   6.183  1.00  0.00
ATOM      9  HG  LEU     1     -15.583  20.237   5.814  1.00  0.00
ATOM     10  CD1 LEU     1     -16.835  21.984   5.766  1.00  0.00
ATOM     11 HD11 LEU     1     -16.811  23.010   6.134  1.00  0.00
ATOM     12 HD12 LEU     1     -16.910  21.991   4.678  1.00  0.00
ATOM     13 HD13 LEU     1     -17.699  21.468   6.185  1.00  0.00
ATOM     14  CD2 LEU     1     -14.351  21.985   5.596  1.00  0.00
ATOM     15 HD21 LEU     1     -13.438  21.469   5.894  1.00  

In [70]:
cs[1]

array([-17.552,  18.998,  11.717], dtype=float32)

In [108]:
%%capture
!cat 194l_A_18_ASP-2.pdb
!cat 194l_A_13_LYS-1.pdb
struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/prot_194l_A_18_ASP-2.pdb')


ns = PDB.NeighborSearch(list(struct.get_atoms())) #set up ns , entire protein
cut1 = ns.search(cs[2], 5, "A")
input1, extras1 = PDB_to_schnet_input_and_names_map(cut1,6)

struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/prot_194l_A_13_LYS-1.pdb')

len([atom for atom in struct.get_atoms()])

ns = PDB.NeighborSearch(list(struct.get_atoms())) #set up ns , entire protein
cut2 = ns.search(cs[1], 5, "A")



NameError: name 'cs' is not defined

In [111]:
a=struct.get_residues()
for b in a:
    print(b['N'])

<Atom N>
<Atom N>
<Atom N>
<Atom N>


In [65]:
for atom in cut1:
    print(atom.get_parent())
print("")
for atom in cut2:
    print(atom.get_parent())

<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue LYS het=  resseq=1 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue ASN het=  resseq=4 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue ASH het=  resseq=3 icode= >
<Residue LEU het=  resseq=2 icode= >
<Residue LEU het=  resseq=2 icode= >
<

In [None]:
len(cut1)

46

In [None]:
len(input1.get('Z'))

46

In [None]:
from schnetpack.representation.schnet import SchNet
from torch.nn import Sequential
from schnetpack.model import NeuralNetworkPotential
from torch import nn
from schnetpack.nn import Dense
from schnetpack.nn.radial import GaussianRBF
from schnetpack.nn.cutoff import CosineCutoff

In [None]:

torch.set_printoptions(profile="full")
weights = torch.load('tensor_dict.pth')
r=10
output_weight = torch.load('output_tensor.pth')

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

#input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
#outputs1 = Model(input1)
#E1 = modelll(outputs1.get('scalar_representation'))
#Eatoms1 = [[atom for atom in extras1[1]], [e for e in E1]]
#Eatoms1_dict = dict(zip([atom for atom in extras1[1]], [e for e in E1]))

NameError: name 'torch' is not defined

In [None]:
outputs1 = Model(input1)

: 

In [None]:
r=10
input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
len(input1.get('Z'))

62

In [None]:
from schnetpack.representation.schnet import SchNet
from torch.nn import Sequential
from schnetpack.model import NeuralNetworkPotential
from torch import nn
from schnetpack.nn import Dense
from schnetpack.nn.radial import GaussianRBF
from schnetpack.nn.cutoff import CosineCutoff
torch.set_printoptions(profile="full")
weights = torch.load('tensor_dict.pth')
r=10
output_weight = torch.load('output_tensor.pth')

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

input1, extras1 = PDB_to_schnet_input_and_names_map(cut1)
outputs1 = Model(input1)
E1 = modelll(outputs1.get('scalar_representation'))
Eatoms1 = [[atom for atom in extras1[1]], [e for e in E1]]
Eatoms1_dict = dict(zip([atom for atom in extras1[1]], [e for e in E1]))

Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)

modelll=nn.Sequential(Dense(128,64), Dense(64,1))

with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))

input2, _ = PDB_to_schnet_input_and_names_map(cut2)
outputs2 = Model(input2)
E2 = modelll(outputs2.get('scalar_representation'))

Eatoms2 = [[atom for atom in _[1]], [e for e in E2]]
Eatoms2_dict = dict(zip([atom for atom in _[1]], [e for e in E2]))

common_entries = set(Eatoms1[0]).intersection(set(Eatoms2[0]))
for x in list(common_entries):
    print(Eatoms2_dict.get(x), Eatoms1_dict.get(x))


NameError: name 'torch' is not defined

In [None]:
print(len(modelll.state_dict()))

NameError: name 'modelll' is not defined

In [None]:
common_entries = set(Eatoms1[0]).intersection(set(Eatoms2[0]))
for x in list(common_entries):
    print(Eatoms2_dict.get(x), Eatoms1_dict.get(x))

In [None]:
type(cut[0].get_full_id())

tuple