## SchNet evaluation
* Example script for inference with a trained SchNet model

In [1]:
import os
import numpy as np
import torch

from ase.io import read, write

from schnetpack.environment import SimpleEnvironmentProvider
from schnetpack import AtomsLoader, AtomsData

In [16]:
# set up arguments
class Args:
    """Fixed settings for this inference run; later cells read them as ``args.<name>``."""

    # Input structure file and the zero-based frame of it to evaluate.
    datapath = './data/W10_geoms_all_lowestE.xyz'
    index = 0

    # Directory containing the "best_model" checkpoint file.
    modelpath = './model/'
    # Device string handed to torch.load(map_location=...); despite the name it is 'cpu' here.
    cuda = 'cpu'


args = Args()

### Encode water clusters in ASE database
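* Pass in an .xyz file of water clusters; a single frame (selected by `args.index`) is read and converted into the input tensors SchNet expects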
* https://sites.uw.edu/wdbase/database-of-water-clusters/

In [13]:
def data_loader(datapath, idx=0):
    """Read one structure from a file and wrap it as a single-batch SchNet loader.

    Args:
        datapath: Path to a structure file readable by ``ase.io.read``.
        idx: Zero-based index of the frame to load.

    Returns:
        A tuple ``(loader, property_list)``. ``loader`` is an ``AtomsLoader`` with
        ``batch_size=1`` over the one converted structure. ``property_list`` is a
        dict with key ``'energy'`` holding a length-1 float32 array with the
        reference energy parsed from the frame's metadata.
    """
    # A slice index makes ase return a list; keep its single element.
    at = read(datapath, index=f'{idx}:{idx+1}')[0]

    # The reference energy is parsed from the first key of ``at.info``.
    # NOTE(review): presumably the .xyz comment line holds only the energy value,
    # which ase then exposes as an info key -- confirm against the data file.
    reference_energy = float(next(iter(at.info)))
    property_list = {'energy': np.array([reference_energy], dtype=np.float32)}

    # Positions are shifted so the centre of mass sits at the origin.
    atom_positions = at.positions.astype(np.float32)
    atom_positions -= at.get_center_of_mass()

    environment_provider = SimpleEnvironmentProvider()
    nbh_idx, offsets = environment_provider.get_environment(at)

    # ``np.int`` was removed in NumPy 1.24; use an explicit 64-bit integer dtype,
    # which is what torch.LongTensor stores.
    new_dataset = {
        '_atomic_numbers': torch.LongTensor(at.numbers.astype(np.int64)),
        '_positions': torch.FloatTensor(atom_positions),
        '_cell': torch.FloatTensor(at.cell.astype(np.float32)),
        '_neighbors': torch.LongTensor(nbh_idx.astype(np.int64)),
        '_cell_offset': torch.FloatTensor(offsets.astype(np.float32)),
        '_idx': torch.LongTensor(np.array([idx], dtype=np.int64)),
    }

    return AtomsLoader([new_dataset], batch_size=1), property_list


# NOTE: this rebinds the name `data_loader` from the function defined above to the
# loader object it returns, so the function is no longer reachable under that name
# once this line has run. Later cells iterate over `data_loader` as the loader.
data_loader, property_list =  data_loader(args.datapath, idx=args.index)

### Load pre-trained SchNet model
* best_model trained on 500k water clusters
* https://aip.scitation.org/doi/full/10.1063/5.0009933

In [17]:
# load model
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints that come from a trusted source.
model = torch.load(os.path.join(args.modelpath, "best_model"), map_location=args.cuda)

# model created using DataParallel: unwrap the inner module and wrap it again.
# .eval() switches the network to inference mode (it returns the module itself),
# so any train-only layer behaviour is disabled for the prediction below.
model = torch.nn.DataParallel(model.module).eval()

In [18]:
# Inspect the loaded network. `parameters` is referenced without being called, so the
# cell displays the bound method's repr (which embeds the module's layer tree) rather
# than the parameter tensors themselves.
model.module.parameters

<bound method Module.parameters of AtomisticModel(
  (representation): SchNet(
    (embedding): Embedding(100, 100, padding_idx=0)
    (distances): AtomDistances()
    (distance_expansion): GaussianSmearing()
    (interactions): ModuleList(
      (0): SchNetInteraction(
        (filter_network): Sequential(
          (0): Dense(in_features=25, out_features=100, bias=True)
          (1): Dense(in_features=100, out_features=100, bias=True)
        )
        (cutoff_network): HardCutoff()
        (cfconv): CFConv(
          (in2f): Dense(in_features=100, out_features=100, bias=False)
          (f2out): Dense(in_features=100, out_features=100, bias=True)
          (filter_network): Sequential(
            (0): Dense(in_features=25, out_features=100, bias=True)
            (1): Dense(in_features=100, out_features=100, bias=True)
          )
          (cutoff_network): HardCutoff()
          (agg): Aggregate()
        )
        (dense): Dense(in_features=100, out_features=100, bias=True)
   

In [19]:
# predict energy
# Gradients are not needed for inference, so run the forward pass under no_grad.
with torch.no_grad():
    # After the loop, out_data holds the model output for the last batch yielded.
    for batch in data_loader:
        out_data = model(batch)

# Detach and copy to host memory before converting to NumPy, so this also works
# when the model output lives on a CUDA device; on CPU these calls are no-ops.
predicted_energy = out_data['energy'].detach().cpu().numpy()[0]

print(f"Actual:    {property_list['energy']}")
print(f"Predicted: {predicted_energy}")

Actual:    [-94.670654]
Predicted: [-94.671135]
