In [1]:
print("Loading numpy...")
import numpy as np
print("Loading torch...")
import torch
torch.set_default_dtype(torch.float64)
print("Loading e3nn...")
import e3nn
import torch_geometric as tg
print("Loading time...")
import time
from collections.abc import Mapping
print("Loading sparse_kernel_conv...")
from sparse_kernel_conv import SparseKernelConv, DummyConvolution
print("Loading laurent...")
from laurent import LaurentPolynomial
print("Loading functools...")
from functools import partial
print("Loading variable_networks...")
from variable_networks import VariableParityNetwork
print("Loading diagnostics...")
from diagnostics import print_parameter_size, count_parameters, get_object_size
print("Loading collections...")
from collections import deque
print("Loading copy...")
from copy import copy
print("Loading datetime...")
from datetime import timedelta
print("Loading re...")
import re
print("Loading sys...")
import sys
print("Loading os...")
import os
import traceback
print("Loading math...")
import math
print("Loading glob...")
from glob import glob
print("done loading modules.\n", flush=True)

# Force CPU execution. CUDA_VISIBLE_DEVICES is only read when CUDA is first
# initialised; torch initialises CUDA lazily, so hiding the GPUs here (after
# `import torch`) works as long as nothing has touched CUDA yet.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
device = "cpu"
# Smoke test: allocate a small random tensor on the chosen device.
# (The former bare `torch.device(device)` call was a discarded no-op.)
temp_tensor = torch.rand(10).to(device)
print("test tensor:")
print(temp_tensor)

Loading numpy...
Loading torch...
Loading e3nn...
Loading time...
Loading sparse_kernel_conv...
Loading laurent...
Loading functools...
Loading variable_networks...
Loading diagnostics...
Loading collections...
Loading copy...
Loading datetime...
Loading re...
Loading sys...
Loading os...
Loading math...
Loading glob...
done loading modules.

test tensor:
tensor([0.4655, 0.4685, 0.7253, 0.1272, 0.8389, 0.1653, 0.2154, 0.1551, 0.2570,
        0.5635])


In [2]:
# Read the trained model checkpoint from disk.
# The checkpoint is a pickled dict whose `model_kwargs` contains Python class
# objects (e.g. the `kernel` entry), so it cannot be read with torch's
# weights-only unpickler.
# SECURITY: full unpickling can execute arbitrary code -- only load
# checkpoints that come from a trusted source.
model_filename = "offline/fluidstack_7-e003_b172869-checkpoint.torch"

print(f"Loading model from {model_filename}...", end="", flush=True)
try:
    # torch >= 2.6 defaults to weights_only=True, which rejects the class
    # objects stored in this checkpoint; request full unpickling explicitly.
    model_dict = torch.load(model_filename, map_location=torch.device('cpu'), weights_only=False)
except TypeError:
    # torch < 1.13 has no `weights_only` argument and always fully unpickles.
    model_dict = torch.load(model_filename, map_location=torch.device('cpu'))
print("done.", flush=True)

Loading model from offline/fluidstack_7-e003_b172869-checkpoint.torch...done.


In [3]:
model_kwargs = model_dict['model_kwargs']
# List every architecture hyperparameter stored in the checkpoint,
# one "name : value" line each.
for name in model_kwargs:
    print(name, ":", model_kwargs[name])

kernel : <class 'sparse_kernel_conv.SparseKernelConv'>
convolution : <class 'sparse_kernel_conv.DummyConvolution'>
batch_norm : False
muls : [[30, 20, 10, 5, 5], [30, 20, 10, 5], [30, 30, 15], [30, 30, 15]]
lmaxes : [4, 3, 2, 2]
max_radius : 3.0
number_of_basis : 20
radial_h : 20
radial_layers : 1
n_norm : 8.0
batch_norm_momentum : 0.02
radial_model : None
Rs_in : [(7, 0, 1)]
Rs_out : [(1, 0, 1)]


In [4]:
# Rebuild the network architecture from the checkpoint's hyperparameters;
# the trained weights are loaded into it in the next cell.
architecture = dict(model_kwargs)
model = VariableParityNetwork(**architecture)

In [5]:
# Copy the trained weights into the freshly built network. The returned value
# lists any missing/unexpected keys and is shown as the cell output.
load_report = model.load_state_dict(model_dict["state_dict"])
load_report

<All keys matched successfully>

In [6]:
# Ordered element list stored with the model; its order fixes the column
# order of the one-hot input encoding built further down.
element_key = "all_elements"
all_elements = model_dict[element_key]
print(all_elements)

[6, 1, 7, 8, 16, 9, 17]


In [7]:
# Read one stored geometry of acetone (with its per-atom reference shieldings)
# and the molecule's atomic numbers.
import h5py

# Which geometry (first axis of the "data" dataset) to evaluate.
GEOMETRY_INDEX = 0

with h5py.File("../acetone/acetone-b3lyp_d3bj-631gd-gas-NMR-pcSseg_1.hdf5", "r") as h5:
    # Indexing raises KeyError for a missing dataset; h5.get() would return
    # None and np.array(None) would only fail later with a confusing error.
    geoms_and_shieldings = np.asarray(h5["data"][()])
with h5py.File("../acetone/acetone-b3lyp_d3bj-631gd-gas-equilibrium_geometry.hdf5", "r") as h5:
    atomic_numbers = np.asarray(h5["atomic_numbers"][()])

# last axis of each atom row: x, y, z, shielding
geometry = torch.tensor(geoms_and_shieldings[GEOMETRY_INDEX, :, :3], dtype=torch.float64)
shieldings = torch.tensor(geoms_and_shieldings[GEOMETRY_INDEX, :, -1], dtype=torch.float64)

# The two files must describe the same atoms in the same count.
if geometry.shape[0] != len(atomic_numbers):
    raise ValueError(
        f"geometry has {geometry.shape[0]} atoms but {len(atomic_numbers)} atomic numbers were read"
    )

print(geometry)
print(atomic_numbers)

tensor([[ 0.0187,  0.1518,  0.0397],
        [-0.0137,  1.3979,  0.0135],
        [-1.3000, -0.5873, -0.0089],
        [-1.0145, -1.4166, -0.5953],
        [-1.5199, -1.2240,  0.9375],
        [-2.2508,  0.0196, -0.4406],
        [ 1.2645, -0.5870, -0.0262],
        [ 1.5569, -1.0289,  1.0086],
        [ 1.2408, -1.4952, -0.7166],
        [ 2.1978,  0.0966, -0.3300]])
[6 8 6 1 1 1 6 1 1 1]


In [8]:
# generates one-hots for a list of atomic numbers
def get_one_hots(atomic_symbols, elements=None):
    """Return a one-hot encoding of each atom's element.

    Args:
        atomic_symbols: iterable of element identifiers, one per atom. In this
            notebook these are atomic numbers (e.g. ``[6, 8, 6, 1, ...]``).
        elements: ordered sequence of the element identifiers the model knows;
            column ``j`` of the result corresponds to ``elements[j]``.
            Defaults to the global ``all_elements`` read from the checkpoint.

    Returns:
        float64 tensor of shape ``(n_atoms, len(elements))`` with exactly one
        1.0 per row.

    Raises:
        ValueError: if an atom's element is not in ``elements`` (instead of
            silently encoding it as an all-zero row).
    """
    if elements is None:
        elements = all_elements
    atoms = list(atomic_symbols)
    # Map element -> column once instead of comparing against every element.
    column_of = {element: column for column, element in enumerate(elements)}
    one_hots = torch.zeros((len(atoms), len(elements)), dtype=torch.float64)
    for row, symbol in enumerate(atoms):
        if symbol not in column_of:
            raise ValueError(
                f"atom {row} has element {symbol!r}, which is not one of the "
                f"model's elements {list(elements)}"
            )
        one_hots[row, column_of[symbol]] = 1.0
    return one_hots
one_hots = get_one_hots(atomic_numbers)
# Show the element order, the input atomic numbers and the resulting encoding.
for shown in (all_elements, atomic_numbers, one_hots):
    print(shown)

[6, 1, 7, 8, 16, 9, 17]
[6 8 6 1 1 1 6 1 1 1]
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.]])


In [9]:
# Package the atoms into the graph object the e3nn model consumes.
# DataNeighbors presumably connects atoms closer than r_max (the model's
# max_radius); the next cell reads data.x, data.edge_index and data.edge_attr.
import e3nn.point.data_helpers as dh

data = dh.DataNeighbors(
    x=one_hots,
    pos=geometry,
    Rs_in=model_kwargs["Rs_in"],
    Rs_out=model_kwargs["Rs_out"],
    r_max=model_kwargs["max_radius"],
)

In [10]:
# Inference: eval() puts layers such as batch norm into inference mode (harmless
# here, where batch_norm is False) and no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(
        data.x,
        data.edge_index,
        data.edge_attr,
        n_norm=model_kwargs["n_norm"],
    )
print(output)      # model prediction, one row per atom
print(shieldings)  # reference values, as a sanity check

tensor([[ -32.3423],
        [-445.4471],
        [ 149.4786],
        [  30.5736],
        [  27.0376],
        [  26.9489],
        [ 152.5420],
        [  27.5167],
        [  28.2830],
        [  27.9769]])
tensor([  91.3635, -402.2257,  199.6065,   31.0486,   28.9889,   28.2823,
         200.1423,   29.0782,   29.3665,   28.2578])


In [23]:
# Per-atom table: model prediction vs. reference shielding and their difference.
print("  #     Z   predicted   expected   residual")
row_template = "{:3d}    {:2d}    {:8.2f}   {:8.2f}   {:8.2f}"
for i, atomic_number in enumerate(atomic_numbers):
    predicted = float(output[i])
    expected = float(shieldings[i])
    error = predicted - expected
    print(row_template.format(i, atomic_number, predicted, expected, error))

  #     Z   predicted   expected   residual
  0     6      -32.34      91.36    -123.71
  1     8     -445.45    -402.23     -43.22
  2     6      149.48     199.61     -50.13
  3     1       30.57      31.05      -0.47
  4     1       27.04      28.99      -1.95
  5     1       26.95      28.28      -1.33
  6     6      152.54     200.14     -47.60
  7     1       27.52      29.08      -1.56
  8     1       28.28      29.37      -1.08
  9     1       27.98      28.26      -0.28
