In [2]:

import torch
from schnetpack.representation.schnet import SchNet
from torch.nn import Sequential
from schnetpack.model import NeuralNetworkPotential
from torch import nn
from schnetpack.nn import Dense
from Bio.PDB import *
import pandas as pd
import MDAnalysis as mda
from MDAnalysis import Merge
import numpy as np
np.set_printoptions(threshold=1000) #1000 is default
import random
import math
from IPython.display import display, HTML
import re
import os
from ase import Atoms, Atom
import dask.dataframe as dd
from schnetpack.nn.radial import GaussianRBF
from schnetpack.nn.cutoff import CosineCutoff

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
dff

Unnamed: 0,index,PDB ID,Chain,Res Name,Res ID,pKa
0,947317,107l,A,NTR,1,8.12807
1,2525126,107l,A,GLU,5,3.47286
2,1774291,107l,A,ASP,10,1.21981
3,3149365,107l,A,GLU,11,4.16394
4,2152859,107l,A,LYS,16,10.31040
...,...,...,...,...,...,...
12628143,2567963,2n9a,A,NTR,1,7.25285
12628144,2939746,2n9a,A,LYS,8,10.34350
12628145,2646745,2n9a,A,CTR,11,2.79366
12628146,2632960,6uoq,A,NTR,24,7.93290


In [3]:
header = """
   _____      __    _   __     __  ____             __
  / ___/_____/ /_  / | / /__  / /_/ __ \____ ______/ /__
  \__ \/ ___/ __ \/  |/ / _ \/ __/ /_/ / __ `/ ___/ //_/
 ___/ / /__/ / / / /|  /  __/ /_/ ____/ /_/ / /__/ ,<
/____/\___/_/ /_/_/ |_/\___/\__/_/    \__,_/\___/_/|_|
"""



dictionary = {
    1 : {
        'Atomic_No': 1,
        'Atomic_Mass': 1,
        'Element_Name': 'Hydrogen',
        'Element_Symbol': 'H',
        'Group_No': 1,
        'Period_No': 1,
        'Relative_Atomic_Mass': 1.0079
    },
    2 : {
        'Atomic_No': 2,
        'Atomic_Mass': 4,
        'Element_Name': 'Helium',
        'Element_Symbol': 'He',
        'Group_No': 18,
        'Period_No': 1,
        'Relative_Atomic_Mass': 4.0026
    },
    3 : {
        'Atomic_No': 3,
        'Atomic_Mass': 7,
        'Element_Name': 'Lithium',
        'Element_Symbol': 'Li',
        'Group_No': 1,
        'Period_No': 2,
        'Relative_Atomic_Mass': 6.941
    },
    4 : {
        'Atomic_No': 4,
        'Atomic_Mass': 9,
        'Element_Name': 'Beryllium',
        'Element_Symbol': 'Be',
        'Group_No': 2,
        'Period_No': 2,
        'Relative_Atomic_Mass': 9.0122
    },
    5 : {
        'Atomic_No': 5,
        'Atomic_Mass': 11,
        'Element_Name': 'Boron',
        'Element_Symbol': 'B',
        'Group_No': 13,
        'Period_No': 2,
        'Relative_Atomic_Mass': 10.811
    },
    6 : {
        'Atomic_No': 6,
        'Atomic_Mass': 12,
        'Element_Name': 'Carbon',
        'Element_Symbol': 'C',
        'Group_No': 14,
        'Period_No': 2,
        'Relative_Atomic_Mass': 12.0107
    },
    7 : {
        'Atomic_No': 7,
        'Atomic_Mass': 14,
        'Element_Name': 'Nitrogen',
        'Element_Symbol': 'N',
        'Group_No': 15,
        'Period_No': 2,
        'Relative_Atomic_Mass': 14.0067
    },
    8 : {
        'Atomic_No': 8,
        'Atomic_Mass': 16,
        'Element_Name': 'Oxygen',
        'Element_Symbol': 'O',
        'Group_No': 16,
        'Period_No': 2,
        'Relative_Atomic_Mass': 15.9994
    },
    9 : {
        'Atomic_No': 9,
        'Atomic_Mass': 19,
        'Element_Name': 'Fluorine',
        'Element_Symbol': 'F',
        'Group_No': 17,
        'Period_No': 2,
        'Relative_Atomic_Mass': 18.9984
    },
    10 : {
        'Atomic_No': 10,
        'Atomic_Mass': 20,
        'Element_Name': 'Neon',
        'Element_Symbol': 'Ne',
        'Group_No': 18,
        'Period_No': 2,
        'Relative_Atomic_Mass': 20.1797
    },
  11 : {
        'Atomic_No': 11,
        'Atomic_Mass': 23,
        'Element_Name': 'Sodium',
        'Element_Symbol': 'Na',
        'Group_No': 1,
        'Period_No': 3,
        'Relative_Atomic_Mass': 22.9897
    },
    12 : {
        'Atomic_No': 12,
        'Atomic_Mass': 24,
        'Element_Name': 'Magnesium',
        'Element_Symbol': 'Mg',
        'Group_No': 2,
        'Period_No': 3,
        'Relative_Atomic_Mass': 24.305
    },
    13 : {
        'Atomic_No': 13,
        'Atomic_Mass': 27,
        'Element_Name': 'Aluminium',
        'Element_Symbol': 'Al',
        'Group_No': 13,
        'Period_No': 3,
        'Relative_Atomic_Mass': 26.9815
    },
    14 : {
        'Atomic_No': 14,
        'Atomic_Mass': 28,
        'Element_Name': 'Silicon',
        'Element_Symbol': 'Si',
        'Group_No': 14,
        'Period_No': 3,
        'Relative_Atomic_Mass': 28.0855
    },
    15 : {
        'Atomic_No': 15,
        'Atomic_Mass': 31,
        'Element_Name': 'Phosphorus',
        'Element_Symbol': 'P',
        'Group_No': 15,
        'Period_No': 3,
        'Relative_Atomic_Mass': 30.9738
    },
    16 : {
        'Atomic_No': 16,
        'Atomic_Mass': 32,
        'Element_Name': 'Sulphur',
        'Element_Symbol': 'S',
        'Group_No': 16,
        'Period_No': 3,
        'Relative_Atomic_Mass': 32.065
    },
    17 : {
        'Atomic_No': 17,
        'Atomic_Mass': 35.5,
        'Element_Name': 'Chlorine',
        'Element_Symbol': 'Cl',
        'Group_No': 17,
        'Period_No': 3,
        'Relative_Atomic_Mass': 35.453
    },
    18 : {
        'Atomic_No': 18,
        'Atomic_Mass': 40,
        'Element_Name': 'Argon',
        'Element_Symbol': 'Ar',
        'Group_No': 18,
        'Period_No': 3,
        'Relative_Atomic_Mass': 39.948
    },
    19 : {
        'Atomic_No': 19,
        'Atomic_Mass': 39,
        'Element_Name': 'Potassium',
        'Element_Symbol': 'K',
        'Group_No': 1,
        'Period_No': 4,
        'Relative_Atomic_Mass': 39.0983
    },
    20 : {
        'Atomic_No': 20,
        'Atomic_Mass': 40,
        'Element_Name': 'Calcium',
        'Element_Symbol': 'Ca',
        'Group_No': 2,
        'Period_No': 4,
        'Relative_Atomic_Mass': 40.078
    },
    21 : {
        'Atomic_No': 21,
        'Atomic_Mass': 45,
        'Element_Name': 'scandium',
        'Element_Symbol': 'Sn',
        'Group_No': 3,
        'Period_No': 4,
        'Relative_Atomic_Mass': 44.956
    },
    22 : {
        'Atomic_No': 22,
        'Atomic_Mass': 48,
        'Element_Name': 'Titanium',
        'Element_Symbol': 'Ti',
        'Group_No': 4,
        'Period_No': 4,
        'Relative_Atomic_Mass': 47.867
    },
    23 : {
        'Atomic_No': 23,
        'Atomic_Mass': 51,
        'Element_Name': 'Vanadium',
        'Element_Symbol': 'V',
        'Group_No': 5,
        'Period_No': 4,
        'Relative_Atomic_Mass': 50.942
    },
    24 : {
        'Atomic_No': 24,
        'Atomic_Mass': 52,
        'Element_Name': 'Chiromium',
        'Element_Symbol': 'Cr',
        'Group_No': 6,
        'Period_No': 4,
        'Relative_Atomic_Mass': 51.996
    },
    25 : {
        'Atomic_No': 25,
        'Atomic_Mass': 55,
        'Element_Name': 'Manganese',
        'Element_Symbol': 'Mn',
        'Group_No': 7,
        'Period_No': 4,
        'Relative_Atomic_Mass': 54.938
    },
    26 : {
        'Atomic_No': 26,
        'Atomic_Mass': 56,
        'Element_Name': 'Iron',
        'Element_Symbol': 'Fe',
        'Group_No': 8,
        'Period_No': 4,
        'Relative_Atomic_Mass': 55.845
    },
    27 : {
        'Atomic_No': 27,
        'Atomic_Mass': 59,
        'Element_Name': 'Cobalt',
        'Element_Symbol': 'Co',
        'Group_No': 9,
        'Period_No': 4,
        'Relative_Atomic_Mass': 58.933
    },
    28 : {
        'Atomic_No': 28,
        'Atomic_Mass': 59,
        'Element_Name': 'Nickel',
        'Element_Symbol': 'Ni',
        'Group_No': 10,
        'Period_No': 4,
        'Relative_Atomic_Mass': 58.693
    },
    29 : {
        'Atomic_No': 29,
        'Atomic_Mass': 64,
        'Element_Name': 'Copper',
        'Element_Symbol': 'Cu',
        'Group_No': 11,
        'Period_No': 4,
        'Relative_Atomic_Mass': 63.546
    },
    30 : {
        'Atomic_No': 30,
        'Atomic_Mass': 65,
        'Element_Name': 'Zinc',
        'Element_Symbol': 'Zn',
        'Group_No': 12,
        'Period_No': 4,
        'Relative_Atomic_Mass': 65.38
    },
    31 : {
        'Atomic_No': 31,
        'Atomic_Mass': 70,
        'Element_Name': 'Gallium',
        'Element_Symbol': 'Ga',
        'Group_No': 13,
        'Period_No': 4,
        'Relative_Atomic_Mass': 69.723
    },
    32 : {
        'Atomic_No': 32,
        'Atomic_Mass': 73,
        'Element_Name': 'Germanium',
        'Element_Symbol': 'Ge',
        'Group_No': 14,
        'Period_No': 4,
        'Relative_Atomic_Mass': 72.64
    },
    33 : {
        'Atomic_No': 33,
        'Atomic_Mass': 74.922,
        'Element_Name': 'Arsenic',
        'Element_Symbol': 'As',
        'Group_No': 15,
        'Period_No': 4,
        'Relative_Atomic_Mass': 74.922
    },
    34 : {
        'Atomic_No': 34,
        'Atomic_Mass': 79,
        'Element_Name': 'Selenium',
        'Element_Symbol': 'Se',
        'Group_No': 16,
        'Period_No': 4,
        'Relative_Atomic_Mass': 78.96
    },
    35 : {
        'Atomic_No': 35,
        'Atomic_Mass': 80,
        'Element_Name': 'Bromine',
        'Element_Symbol': 'Br',
        'Group_No': 17,
        'Period_No': 4,
        'Relative_Atomic_Mass': 80
    },
    36 : {
        'Atomic_No': 36,
        'Atomic_Mass': 84,
        'Element_Name': 'Krypton',
        'Element_Symbol': 'Kr',
        'Group_No': 18,
        'Period_No': 4,
        'Relative_Atomic_Mass': 83.798
    },
    37 : {
        'Atomic_No': 37,
        'Atomic_Mass': 85,
        'Element_Name': 'Rubidium',
        'Element_Symbol': 'Rb',
        'Group_No': 1,
        'Period_No': 5,
        'Relative_Atomic_Mass': 85.468
    },
    38 : {
        'Atomic_No': 38,
        'Atomic_Mass': 88,
        'Element_Name': 'Strontium',
        'Element_Symbol': 'Sr',
        'Group_No': 2,
        'Period_No': 5,
        'Relative_Atomic_Mass': 87.62
    },
    39 : {
        'Atomic_No': 39,
        'Atomic_Mass': 88.906,
        'Element_Name': 'Yttrium',
        'Element_Symbol': 'Y',
        'Group_No': 3,
        'Period_No': 5,
        'Relative_Atomic_Mass': 88.906
    },
    40 : {
        'Atomic_No': 40,
        'Atomic_Mass': 91,
        'Element_Name': 'Zirconium',
        'Element_Symbol': 'Zr',
        'Group_No': 4,
        'Period_No': 5,
        'Relative_Atomic_Mass': 91.224
    }
}

In [4]:
local_folder="/Users/jessihoernschemeyer/pKaSchNet"
pkPDB_CSV = f"{local_folder}/pkas.csv"
PKAD_CSV = f"{local_folder}/WT_pka.csv"

def download_PDB(pdbname, input_df):
    """downloads PDBs from biopython. iterations = the number is the number of pdbs we make neighborhoods for """
    PDBList().retrieve_pdb_file(pdbname,obsolete=False, pdir='PDB',file_format = 'pdb')
    return mda.Universe(f"PDB/pdb{pdbname}.ent") #md universe 'u'

def titratable_res_info(resid_chain_in, u):
    """universe --> extract names, pos, res names of TITRATABLE INPUT RESIDUE . res_atoms.names[1]) #ALPHA CARBON!! """
    x_tit, y_tit, z_tit, tit_atomnames =[],[],[], []

    res_atoms=u.select_atoms(resid_chain_in) 
    
    for j in range(len(res_atoms.positions)): #len resatom positions
        x_tit.append(res_atoms.positions[j,0])
        y_tit.append(res_atoms.positions[j,1])
        z_tit.append(res_atoms.positions[j,2])

    for i in range(len(res_atoms.names)):
        tit_atomname=res_atoms.names[i][0]
        tit_atomnames.append(tit_atomname)

    return list(res_atoms.resnames), list(np.repeat(res_atoms.residues.resids, len(res_atoms.resnames))), tit_atomnames, x_tit, y_tit, z_tit

#called/used in 'neighborhood'
def MDAvicinity(in_cutoff, resid_chain_in, u, res_name_tit, res_ids_tit, atomname_tit, x_tit, y_tit, z_tit): 

    #MD
    vicinity=(((((u.select_atoms(f'around {in_cutoff} {resid_chain_in}')).select_atoms('not resname HOH')).select_atoms('resname [!0-9]**')).select_atoms('resname *[!0-9]*')).select_atoms('resname **[!0-9]')).residues
    
    #append vicinity info from mda onto titratable info
    for i in range(len(vicinity.atoms.names)): 
        atomname_tit.append(vicinity.atoms.names[i][0]) #get first letter of atom

        res_name_tit.append(vicinity.atoms.resnames[i])

        res_ids_tit.append(vicinity.atoms.resids[i]) 

        x_tit.append(vicinity.atoms.positions[i,0])
        y_tit.append(vicinity.atoms.positions[i,1])
        z_tit.append(vicinity.atoms.positions[i,2])

    return res_name_tit, res_ids_tit, atomname_tit, np.array(x_tit), np.array(y_tit), np.array(z_tit)


one_res_Atoms, one_res_atoms, list_Atoms_objs =[], [], []

def neighborhoods(input_df, in_cutoff):  
    for i in range(1): 
        z = []
        pdbname=str.lower(input_df.iloc[i]['PDB ID']) #get pdb name from inputdf
        next_pdb = input_df.iloc[i+1]['PDB ID']

        u = download_PDB(pdbname, input_df) #MD
        
        if bool(isinstance((input_df.iloc[1]['Res ID']), float) == True): #get residue number from input df
            val=math.trunc(input_df.iloc[i]['Res ID'])      
        else:
            val=input_df.iloc[i]['Res ID'] 

        resid_chain_in=f"resid {val} and segid {input_df.iloc[i]['Chain']}"   #atom selection string
        ID = "-".join([pdbname,str(input_df.iloc[i]['Res Name']),str(val), str(input_df.iloc[i]['Chain'])])

        #Get info for titratable residue
        res_name_tit, res_ids_tit, atomname_tit, x_tit, y_tit, z_tit = titratable_res_info(resid_chain_in, u) 
       
        #get info for the vicinity. this returns entire cutout info.
        resnam_vicinity, res_ids_vicinity, atomname_vicinity, x_vic_tit, y_vic_tit, z_vic_tit = MDAvicinity(in_cutoff, resid_chain_in, u, res_name_tit, res_ids_tit, atomname_tit, x_tit, y_tit, z_tit)
        

        pos=np.vstack((x_vic_tit, y_vic_tit, z_vic_tit))
   

        #get the z values
        for i in range(len(atomname_vicinity)):
            for d in dictionary.values():
                if d['Element_Symbol'] == atomname_vicinity[i][0]:
                    z=np.append(z, int(d['Atomic_No']))
        

        for i in range(len(z)):
            atom=Atom(z[i], #the first element of the atom names gives the regular atom name 
                position=(pos[0][i],pos[1][i], pos[2][i]))
            one_res_atoms.append(atom)
        one_res_Atoms=Atoms(one_res_atoms)
        list_Atoms_objs.append(one_res_Atoms)

        df = pd.DataFrame({"PDB": pdbname,
                    "Res Name": resnam_vicinity,
                    "Res No": res_ids_vicinity,
                    "Atom Name": atomname_vicinity,
                    "Z": z,
                    "x": x_vic_tit,
                    "y":y_vic_tit,
                    "z":z_vic_tit})
        
        df = df.astype({'Res No': 'uint8',
                       "PDB": "category",
                       "Res Name": "category",
                       "Atom Name": "category",
                       "Z": "uint8",
                       })

        if pdbname != next_pdb:
            os.remove(f"{local_folder}/PDB/pdb{pdbname}.ent")

        df.to_csv(ID, sep=',', index=False, encoding='utf-8')

        
        
    return Atoms(one_res_atoms), z


plz=dd.read_csv(pkPDB_CSV, delimiter=';', na_filter=False, dtype={'idcode':'category', 
                                                                  'residue_number':'uint8',
                                                                  'pk': 'float32',
                                                                  'residue_name':'category',
                                                                  'chain': 'category',
                                                                  'residue_name': 'category'
                                                                  })



                                                                  
plz=plz.rename(columns={'idcode': 'PDB ID', 'residue_number': 'Res ID', 'residue_name': 'Res Name', 'residue_number': 'Res ID', 'pk': 'pKa', 'chain' : 'Chain'}) #rename columns to match df from pkad 
plz=plz.sort_values(['PDB ID', 'Res ID'], ascending=[True, True]) #sorts both



In [5]:
full=plz.compute()
dff = full.reset_index()

In [6]:
one_res_atoms, z=neighborhoods(dff, 10)
z = one_res_atoms.get_atomic_numbers()

Structure exists: 'PDB/pdb107l.ent' 


In [7]:

def repeat_number(num, n):
    return [num for _ in range(n)]

# Example usage:
number = 10
times = len(z)
r = repeat_number(number, times)

In [8]:
one_res_atoms.set_cell([[1,0,0], [0,1,0], [0,0,1]])

In [9]:

from matscipy.neighbours import neighbour_list as msp_neighbor_list #quantities, atoms, cutoff, positions, cell, pbc, numbers, cell_origin
#'x', 'f_ij', 'idx_i', 'idx_j', and 'rcut_ij'

d, i, j = msp_neighbor_list('dij',  one_res_atoms, r)

inputs = {'Z':torch.tensor(z).long(), 'R':torch.tensor(d).float(), 'idx_i':torch.tensor(i).long(), 'idx_j': torch.tensor(j).long()}


In [11]:
weights = torch.load('tensor_dict.pth')
output_weight = torch.load('output_tensor.pth')

In [12]:
r=10 #this is a "copy" of the model
Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))

In [13]:
torch.set_printoptions(profile="full")
for keys, weight in weights.items():
    left = f"Model.{keys}" 
    right = f"torch.nn.Parameter(torch.{weight})"
    execu = f"{left} = {right}"
    st = execu.replace("\n       ","")
    st2 = st.replace("representation.","")
    try:
        exec(st2)
    finally:
        right2 = f"torch.nn.Parameter({weight})"
        E=f"{left} = {right2}"
        s = execu.replace("\n       ","")
        s2 = s.replace("representation.","")
        exec(s2)



In [14]:
modelll=nn.Sequential(Dense(128,64), Dense(64,1))

In [15]:
with torch.no_grad():
    modelll[0].weight = nn.Parameter(output_weight.get('dense1_weight'))
    modelll[0].bias = nn.Parameter(output_weight.get('dense1_bias'))
    modelll[1].weight = nn.Parameter(output_weight.get('dense2_weight'))
    modelll[1].bias = nn.Parameter(output_weight.get('dense2_bias'))


In [19]:
outputs = Model(inputs)
E = modelll(outputs.get('scalar_representation'))

In [22]:
E.detach().numpy()

array([[-24.968233 ],
       [-34.642582 ],
       [-27.53616  ],
       [-23.076633 ],
       [-22.824928 ],
       [-29.696007 ],
       [-43.31691  ],
       [-33.813667 ],
       [-31.105223 ],
       [-34.6015   ],
       [-35.87332  ],
       [-37.538483 ],
       [-27.315813 ],
       [-20.648262 ],
       [-26.059317 ],
       [-16.576513 ],
       [-37.14406  ],
       [-41.554157 ],
       [-30.114893 ],
       [-41.912163 ],
       [-31.138823 ],
       [-27.01241  ],
       [-18.218912 ],
       [-29.475584 ],
       [-29.92536  ],
       [-21.432056 ],
       [-18.074795 ],
       [-20.619051 ],
       [-20.77631  ],
       [-13.896519 ],
       [-16.067476 ],
       [-15.32528  ],
       [ -8.027823 ],
       [ -8.806393 ],
       [ -8.797975 ],
       [-26.563004 ],
       [-32.203    ],
       [-33.79385  ],
       [-33.87273  ],
       [-30.016605 ],
       [-19.771484 ],
       [-23.131031 ],
       [-24.639902 ],
       [-15.701159 ],
       [-33.33359  ],
       [-5

In [23]:
Sym = one_res_atoms.get_chemical_symbols()

In [29]:
out = {}
for i in range(len(Sym)):
    out[E[i]] = Sym[i]


In [None]:
out

In [17]:
#dummy model
dm_model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))

o = dm_model(inputs)
modelll(o.get('scalar_representation'))


tensor([[ 0.8105],
        [-0.7500],
        [-0.3105],
        [-0.2324],
        [-0.2299],
        [-0.0066],
        [ 0.5426],
        [-0.1250],
        [ 1.0211],
        [-0.5111],
        [-0.5129],
        [-0.0301],
        [-0.5932],
        [-0.5599],
        [-0.0209],
        [ 1.0835],
        [ 0.7194],
        [-0.5708],
        [-0.4243],
        [ 0.3846],
        [-0.9400],
        [-0.7370],
        [-0.6473],
        [ 0.0464],
        [ 1.1522],
        [-0.4586],
        [-0.6325],
        [-0.1977],
        [-0.4990],
        [-0.6901],
        [-0.5964],
        [-0.6512],
        [-0.6714],
        [-0.6052],
        [-0.5588],
        [ 0.7246],
        [-0.5413],
        [-0.6438],
        [-0.1895],
        [-0.4725],
        [-0.4425],
        [-0.3022],
        [-0.3247],
        [-0.2020],
        [ 0.9255],
        [-0.4148],
        [-0.6679],
        [ 0.2350],
        [-0.1922],
        [-0.3122],
        [ 0.4548],
        [ 0.0316],
        [ 0.

In [None]:
len(z)

len(z)

In [209]:
from matscipy.neighbours import neighbour_list as msp_neighbor_list #quantities, atoms, cutoff, positions, cell, pbc, numbers, cell_origin
#'x', 'f_ij', 'idx_i', 'idx_j', and 'rcut_ij'
from ase import Atom, Atoms

a = Atoms([1,8,1], [[1,0,0], [.5, 1,0], [1,0,0]], pbc=[True, True, True], cell=[[1,0,0], [0,1,0], [0,0,1]])
d, i, j = msp_neighbor_list('dij',  a, r)

rcut_ij = CosineCutoff(r)(torch.tensor(d).float()) #r, d_ij
f=GaussianRBF(n_rbf, r)(torch.tensor(d).float()) #rbf, r 
embedding=Embedding(9,n_feats)
x=embedding(z) #dictionary size, embedding_dim(#features)

lin1 = Linear(n_filters, n_filters // 2)
lin2 = Linear(n_filters // 2, 1)

n_interactions = 50
interactions = snn.replicate_module(
            lambda: SchNetInteraction(
                n_atom_basis=n_feats,
                n_rbf=n_rbf,
                n_filters=n_filters,
            ),n_interactions,0)

for interaction in interactions:
    v = SchNetInteraction(n_feats,n_rbf,n_filters,  activation=shifted_softplus)(x, f, torch.tensor(i).long(), torch.tensor(j).long(), rcut_ij) #features #rbf #filters
    x = x + v






from typing import Callable, Dict

import torch
from torch import nn

import schnetpack.properties as structure
from schnetpack.nn import Dense, scatter_add
from schnetpack.nn.activations import shifted_softplus

import schnetpack.nn as snn

from matscipy.neighbours import neighbour_list as msp_neighbor_list #quantities, atoms, cutoff, positions, cell, pbc, numbers, cell_origin
#'x', 'f_ij', 'idx_i', 'idx_j', and 'rcut_ij'
from ase import Atom, Atoms

d, i, j = msp_neighbor_list('dij',  a, r)


__all__ = ["SchNet", "SchNetInteraction"]


class SchNetInteraction(nn.Module):
    r"""SchNet interaction block for modeling interactions of atomistic systems."""

    def __init__(
        self,
        n_atom_basis: int,
        n_rbf: int,
        n_filters: int,
        activation: Callable = shifted_softplus,
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
            n_rbf (int): number of radial basis functions.
            n_filters: number of filters used in continuous-filter convolution.
            activation: if None, no activation function is used.
        """
        super(SchNetInteraction, self).__init__()
        self.in2f = Dense(n_atom_basis, n_filters, bias=False, activation=None)
        self.f2out = nn.Sequential(
            Dense(n_filters, n_atom_basis, activation=activation),
            Dense(n_atom_basis, n_atom_basis, activation=None),
        )
        self.filter_network = nn.Sequential(
            Dense(n_rbf, n_filters, activation=activation), Dense(n_filters, n_filters)
        )

    def forward(
        self,
        x: torch.Tensor,
        f_ij: torch.Tensor,
        idx_i: torch.Tensor,
        idx_j: torch.Tensor,
        rcut_ij: torch.Tensor,
    ):
        """Compute interaction output.

        Args:
            x: input values
            Wij: filter
            idx_i: index of center atom i
            idx_j: index of neighbors j

        Returns:
            atom features after interaction
        """
        x = self.in2f(x)
        Wij = self.filter_network(f_ij)
        Wij = Wij * rcut_ij[:, None]

        # continuous-filter convolution
        x_j = x[idx_j]
        x_ij = x_j * Wij
        x = scatter_add(x_ij, idx_i, dim_size=x.shape[0])

        x = self.f2out(x)
        return x


class SchNet(nn.Module):
    """SchNet architecture for learning representations of atomistic systems

    References:

    .. [#schnet1] Schütt, Arbabzadah, Chmiela, Müller, Tkatchenko:
       Quantum-chemical insights from deep tensor neural networks.
       Nature Communications, 8, 13890. 2017.
    .. [#schnet_transfer] Schütt, Kindermans, Sauceda, Chmiela, Tkatchenko, Müller:
       SchNet: A continuous-filter convolutional neural network for modeling quantum
       interactions.
       In Advances in Neural Information Processing Systems, pp. 992-1002. 2017.
    .. [#schnet3] Schütt, Sauceda, Kindermans, Tkatchenko, Müller:
       SchNet - a deep learning architecture for molceules and materials.
       The Journal of Chemical Physics 148 (24), 241722. 2018.

    """

    def __init__(
        self,
        n_atom_basis: int,
        n_interactions: int,
        radial_basis: nn.Module,
        cutoff_fn: Callable,
        n_filters: int = None,
        shared_interactions: bool = False,
        max_z: int = 100,
        activation: Callable = shifted_softplus,
    ):
        """
        Args:
            n_atom_basis: number of features to describe atomic environments.
                This determines the size of each embedding vector; i.e. embeddings_dim.
            n_interactions: number of interaction blocks.
            radial_basis: layer for expanding interatomic distances in a basis set
            cutoff_fn: cutoff function
            n_filters: number of filters used in continuous-filter convolution
            shared_interactions: if True, share the weights across
                interaction blocks and filter-generating networks.
            max_z: maximal nuclear charge
            activation: activation function
        """
        super().__init__()
        self.n_atom_basis = n_atom_basis
        self.size = (self.n_atom_basis,)
        self.n_filters = n_filters or self.n_atom_basis
        self.radial_basis = radial_basis
        self.cutoff_fn = cutoff_fn
        self.cutoff = cutoff_fn.cutoff

        # layers
        self.embedding = nn.Embedding(max_z, self.n_atom_basis, padding_idx=0)

        self.interactions = snn.replicate_module(
            lambda: SchNetInteraction(
                n_atom_basis=self.n_atom_basis,
                n_rbf=self.radial_basis.n_rbf,
                n_filters=self.n_filters,
                activation=activation,
            ),
            n_interactions,
            shared_interactions,
        )

    def forward(self, inputs: Dict[str, torch.Tensor]):
        atomic_numbers = inputs['Z']  # Use "Z" instead of Z
        d_ij = inputs['R']  # Use "Rij" instead of Rij
        idx_i = inputs['idx_i']  # Use "idx_i" instead of idx_i
        idx_j = inputs['idx_j']  # Use "idx_j" instead of idx_j

        # compute atom and pair features
        x = self.embedding(atomic_numbers)
        #d_ij = torch.norm(r_ij, dim=1)

        f_ij = self.radial_basis(d_ij)
        rcut_ij = self.cutoff_fn(d_ij) #here

        # compute interaction block to update atomic embeddings
        for interaction in self.interactions:
            v = interaction(x, f_ij, idx_i, idx_j, rcut_ij)
            x = x + v

        inputs["scalar_representation"] = x
        return inputs


In [342]:
from schnetpack.representation.schnet import SchNet
r=10 #this is a "copy" of the model
Model = SchNet(n_atom_basis=128, n_interactions=6, radial_basis=GaussianRBF(50, r), cutoff_fn=CosineCutoff(r))

SchNet(
  (radial_basis): GaussianRBF()
  (cutoff_fn): CosineCutoff()
  (embedding): Embedding(100, 128, padding_idx=0)
  (interactions): ModuleList(
    (0-5): 6 x SchNetInteraction(
      (in2f): Dense(
        in_features=128, out_features=128, bias=False
        (activation): Identity()
      )
      (f2out): Sequential(
        (0): Dense(in_features=128, out_features=128, bias=True)
        (1): Dense(
          in_features=128, out_features=128, bias=True
          (activation): Identity()
        )
      )
      (filter_network): Sequential(
        (0): Dense(in_features=50, out_features=128, bias=True)
        (1): Dense(
          in_features=128, out_features=128, bias=True
          (activation): Identity()
        )
      )
    )
  )
)

In [None]:
from matscipy.neighbours import neighbour_list as msp_neighbor_list #quantities, atoms, cutoff, positions, cell, pbc, numbers, cell_origin
#'x', 'f_ij', 'idx_i', 'idx_j', and 'rcut_ij'

a = Atoms([1,8,1], [[1,0,0], [.5, 1,0], [1,0,0]], pbc=[True, True, True], cell=[[1,0,0], [0,1,0], [0,0,1]])
d, i, j = msp_neighbor_list('dij',  a, r)

rcut_ij = CosineCutoff(r)(torch.tensor(d).float()) #r, d_ij
f=GaussianRBF(n_rbf, r)(torch.tensor(d).float()) #rbf, r 
embedding=Embedding(9,n_feats)
x=embedding(z) #dictionary size, embedding_dim(#features)

lin1 = Linear(n_filters, n_filters // 2)
lin2 = Linear(n_filters // 2, 1)

n_interactions = 50
interactions = snn.replicate_module(
            lambda: SchNetInteraction(
                n_atom_basis=n_feats,
                n_rbf=n_rbf,
                n_filters=n_filters,
            ),n_interactions,0)

for interaction in interactions:
    v = SchNetInteraction(n_feats,n_rbf,n_filters,  activation=shifted_softplus)(x, f, torch.tensor(i).long(), torch.tensor(j).long(), rcut_ij) #features #rbf #filters
    x = x + v
    print(x)
    h = lin1(x)
    h = shifted_softplus(h)
    out = lin2(h)

print(out)
    
#x = reset_parameters(out)


def reset_parameters(x):
    r"""Resets all learnable parameters of the module."""

    h = embedding(z)
    d, i, j = msp_neighbor_list('dij',  a, r)

    for interaction in interactions:
        v = SchNetInteraction(n_feats,n_rbf,n_filters, activation=shifted_softplus)(x, f, torch.tensor(i).long(), torch.tensor(j).long(), rcut_ij) #features #rbf #filters
        x = x + v

    h = lin1(x)
    h = shifted_softplus(h)
    out = lin2(h)

    return out

#print(x)

