In [1]:
import math
import torch
from torch.autograd.gradcheck import zero_gradients
from torch.autograd import grad

from projects.NeuralForceField.train import *
from projects.NeuralForceField.graphs import * 
from projects.NeuralForceField.tensorgrad import *

import numpy as np

In [2]:
# ethanol data
# The .npz provides atomic numbers `z`, coordinates `R`, forces `F` and energies `E`.
# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
DATA_PATH = '/home/wwj/data/ethanol_ccsd_t-train.npz'
train = np.load(DATA_PATH)

# R is stacked along its last axis below, so it is used as (n_samples, n_atoms, 3).
coords = np.asarray(train.f.R)
n_samples, n_atoms = coords.shape[:2]

# Repeat the atomic numbers once per sample, derived from the data instead of
# assuming exactly 1000 frames, and prepend them as a leading column so that
# xyz_data[i, j] == (z_j, x_ij, y_ij, z_ij).
atomic_numbers = np.broadcast_to(
    np.asarray(train.f.z).reshape(1, -1, 1), (n_samples, n_atoms, 1)
)
xyz_data = np.concatenate((atomic_numbers, coords), axis=-1)
force_data = train.f.F
energy_data = train.f.E.squeeze()

# batch_size / cutoff must stay in sync with par["batch_size"] / par["cutoff"].
graph_data = load_graph_data(xyz_data=xyz_data, energy_data=energy_data, batch_size=2, cutoff=5.0,
                             force_data=force_data, au_flag=False, subtract_mean_flag = True)

In [3]:
# Hyperparameters handed to `Model(par=...)`; key names must match what Model reads.
par = dict(
    n_filters=256,
    n_gaussians=32,
    n_atom_basis=256,
    optim=1e-4,            # NOTE(review): key says "optim" but the value looks like a learning rate — confirm in Model
    scheduler=True,
    train_percentage=0.8,
    T=10,
    batch_size=2,          # keep equal to batch_size passed to load_graph_data
    cutoff=5.0,            # keep equal to cutoff passed to load_graph_data
    max_epoch=1000,
    trainable_gauss=True,
    rho=0.1,
    eps=1e-5,
)

In [4]:
# Build the model wrapper on GPU 3; logs for this run go under ./log/ as "hessian_test".
model = Model(
    # data and hyperparameters
    par=par,
    graph_data=graph_data,
    graph_batching=False,
    # runtime / bookkeeping
    device=3,
    job_name="hessian_test",
    root='./log/',
)

In [5]:
xyz, a, r, f, u, N = model.parse_batch(1)

# Atoms per molecule, taken from the loaded dataset (xyz_data is
# (n_samples, n_atoms, 4)) rather than hard-coding 9 for ethanol.
n_atoms = xyz_data.shape[1]
xyz = xyz.reshape(-1, n_atoms, 3)
r = r.reshape(-1, n_atoms)

# compute neural network hessian
# NOTE(review): presumably the second derivative of the predicted energy w.r.t. xyz,
# one square matrix per molecule in the batch — confirm against Neural_hess.
hessian = Neural_hess(xyz=xyz, r=r, model=model.model, device=3)
hessian

tensor([[[ 30.5576,  -7.3615,  -1.7901,  ...,  11.0745,  -0.4226,   1.5851],
         [ -7.3615,  10.0585,   2.2802,  ...,   0.2150,   2.2880,   0.0945],
         [ -1.7901,   2.2802,  -8.6448,  ...,   1.5633,   0.1467,   2.0002],
         ...,
         [ 11.0745,   0.2150,   1.5633,  ..., -20.3048,   9.0847,  -8.3442],
         [ -0.4226,   2.2880,   0.1467,  ...,   9.0847,  -3.2379,  -2.6050],
         [  1.5851,   0.0945,   2.0002,  ...,  -8.3442,  -2.6050,  -5.4760]],

        [[ 17.6158, -10.8864,   9.7178,  ...,  15.4629,   6.6377,  -8.6067],
         [-10.8864,   7.7233,   3.4968,  ...,   6.8723,   3.0124,  -3.9666],
         [  9.7178,   3.4968, -14.1441,  ...,  -8.7672,  -4.1231,   4.9059],
         ...,
         [ 15.4629,   6.8723,  -8.7672,  ...,  11.4558,  -8.5331,   2.4387],
         [  6.6377,   3.0124,  -4.1231,  ...,  -8.5331,  -7.4124,   2.3447],
         [ -8.6067,  -3.9666,   4.9059,  ...,   2.4387,   2.3447,  -7.7900]]],
       device='cuda:3')