In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pickle
import torch
import random
from functools import partial

from e3nn.kernel import Kernel
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.non_linearities import rescaled_act
from e3nn.non_linearities.rescaled_act import relu, sigmoid
from e3nn.radial import CosineBasisModel
from e3nn.radial import GaussianRadialModel

torch.set_default_dtype(torch.float64)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# training/test structure selection uses `random`, and network weights are
# presumably initialised from torch's RNG.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
## load density data
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
picklename = "./density_data/dimer_data.pckl"
with open(picklename, mode='rb') as density_file:
    (dataset_coeffs, dataset_onehot, dataset_geom,
     dataset_typemap, Rs_out_list, coeff_by_type) = pickle.load(density_file)


In [None]:
from networks import *

## set arguments to network
maxradius = 3.0
numbasis = 20
radiallayers = 3
radialbasis = "Gaussian"
## set Rs_in based on onehot vector
# NOTE(review): assumes the one-hot width equals the number of atom types
# listed in dataset_typemap[0] -- confirm against dataset_onehot.
Rs_in = [(len(dataset_typemap[0]), 0)]

print("Rs_in:", Rs_in)
print("\nOxygen Rs_out:", Rs_out_list[0])
print("Hydrogen Rs_out:", Rs_out_list[1])

mydict = {"Rs_in": Rs_in, "Rs_out_list": Rs_out_list, "max_radius": maxradius,
          "number_of_basis": numbasis, "radial_layers": radiallayers,
          "basistype": radialbasis}

# alternative architecture: net = MixerNetwork(**mydict)
net = SplitNetwork(**mydict)

# Every later forward pass sends its inputs with `.to(device)`, so the
# parameters must live on the same device. Move them here, before the
# optimizer is built from net.parameters().
net.to(device)

print(net)

In [None]:
## set up training
# training schedule: total single-structure steps, and how many steps'
# gradients are accumulated before each optimizer update
max_steps = 2000
minibatch_size = 16

net.train()

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
optimizer.zero_grad()

print(device)


In [None]:
## Training loop: accumulate gradients over `minibatch_size` single-structure
## steps, then take one optimizer step. Every 100 steps report the current
## training loss and the loss on one randomly chosen held-out structure.

# Structures at the end of the dataset are held out for testing. Training
# draws from [0, n_total - n_test) and testing from [n_total - n_test, n_total),
# so the two never overlap.
n_test = 3000
n_total = len(dataset_geom)
assert 0 < n_test < n_total, "need at least one training and one test structure"


def _predict_and_loss(index):
    """Run `net` on structure `index` and return its MSE loss against the reference coefficients.

    Inputs and targets are preprocessed identically for training and testing
    (no rescaling), and all tensors are moved to `device`.
    """
    onehot = dataset_onehot[index]
    points = dataset_geom[index]
    atom_type_map = dataset_typemap[index]
    coeffs = dataset_coeffs[index]

    outputO, outputH = net(onehot.to(device), points.to(device), atom_type_map)
    # flat O-then-H coefficient vector shaped (1, 1, -1); assumes `coeffs`
    # uses the same ordering -- confirm against the dataset builder
    output = torch.cat((torch.flatten(outputO), torch.flatten(outputH)), 0).view(1, 1, -1)
    return loss_fn(output, coeffs.to(device))


# start from clean gradients so re-running this cell cannot reuse a leftover partial minibatch
optimizer.zero_grad()

for step in range(max_steps):
    i = random.randint(0, n_total - n_test - 1)
    loss = _predict_and_loss(i)
    step_loss = loss.item()
    loss.backward()

    if (step + 1) % minibatch_size == 0:
        optimizer.step()
        optimizer.zero_grad()

    if step % 100 == 0:
        print('\nStep {0}, Loss {1}'.format(step, step_loss))
        j = random.randint(n_total - n_test, n_total - 1)
        # evaluation only: don't build an autograd graph
        with torch.no_grad():
            test_loss = _predict_and_loss(j)
        print('\nTest Loss {0}'.format(test_loss.item()))



In [None]:
from density_analysis_utils import *

# electron-count check provided by density_analysis_utils.testnumelectrons
testnumelectrons(
    net, device, 2, "./density_data/a2.gbs",
    dataset_onehot, dataset_geom, dataset_typemap, coeff_by_type,
)

In [None]:
from density_analysis_utils import *
from e3nn.rs import dim, mul_dim

## define Gaussian Type Orbital basis functions
def basis(r, alpha, norm):
    """Evaluate normalised Gaussian radial functions ``norm * exp(-alpha * r**2)``.

    A trailing axis is added to ``r`` so it broadcasts against per-function
    ``alpha``/``norm`` vectors. Presumably ``r`` is (n_points,) and
    ``alpha``/``norm`` are (n_functions,), giving an (n_points, n_functions)
    result -- confirm against the callers.
    """
    return norm * torch.exp(-alpha * r.unsqueeze(-1) ** 2)

## get exponent alphas
alphaO, alphaH = get_exponents('./density_data/a2.gbs')

## get normalization constants
normO, normH = parse_whole_normfile('./density_data/a2_norm.dat')
# Build the tensors in the notebook's default dtype (float64). The legacy
# torch.FloatTensor constructor would silently round the constants to float32.
normO = torch.as_tensor(normO, dtype=torch.get_default_dtype())
normH = torch.as_tensor(normH, dtype=torch.get_default_dtype())

## get spherical harmonic normalization constants
Rs_out_O = Rs_out_list[0]
Rs_out_H = Rs_out_list[1]
sph_normsO, sph_normsH = get_spherical_harmonic_norms(Rs_out_O, Rs_out_H)

# bind each element's exponents and norms so the radial functions take only r
basis_on_r_O = partial(basis, alpha=alphaO, norm=normO)
basis_on_r_H = partial(basis, alpha=alphaH, norm=normH)

# one normalisation constant per output channel
assert mul_dim(Rs_out_O) == normO.shape[0]
assert mul_dim(Rs_out_H) == normH.shape[0]

In [None]:
# pick a fixed structure (by dataset index) to visualise
dimer_num = 4321
onehot = dataset_onehot[dimer_num]
points = dataset_geom[dimer_num]
atom_type_map = dataset_typemap[dimer_num]

# inference only: skip building the autograd graph
with torch.no_grad():
    outputO, outputH = net(onehot.to(device), points.to(device), atom_type_map)

# bring predictions back to host memory as numpy arrays for plotting
outputO = outputO.cpu().numpy()
outputH = outputH.cpu().numpy()

In [None]:
from spherical import plot_data_on_grid
import e3nn.o3 as o3

## get the functions
## Reconstruct the predicted density: for each atom, weight every basis
## function evaluated on a grid by its predicted coefficient, then sum.

def _atom_density_on_grid(coeff_row, basis_on_r, Rs_out, sph_norms, center):
    """Return ``(r, density)`` for a single atom.

    ``r`` and the per-basis-function values ``f`` come from
    ``plot_data_on_grid(5.0, basis_on_r, Rs_out, n=20, center=center)``;
    ``density`` is ``sum_j coeff_row[j] * f[:, j] / sph_norms[j]``.
    """
    r_grid, f_grid = plot_data_on_grid(5.0, basis_on_r, Rs_out,
                                       n=20, center=center)
    density = 0
    for j, c in enumerate(coeff_row):
        # each basis function's contribution is divided by its spherical-harmonic norm
        density += c * f_grid[:, j] / sph_norms[j]
    return r_grid, density


# per-type inputs, in atom_type_map order (type 0 = oxygen, type 1 = hydrogen)
per_type_inputs = [
    (outputO, basis_on_r_O, Rs_out_O, sph_normsO),
    (outputH, basis_on_r_H, Rs_out_H, sph_normsH),
]
# atom coordinates, extracted once rather than once per atom
atom_coords = points.detach().squeeze().cpu().numpy()

f_list = []
# zip stops after the two known types; any further type would only have added 0 to the sum
for atom_indices, (type_output, type_basis, type_Rs_out, type_sph_norms) in zip(atom_type_map, per_type_inputs):
    if len(atom_indices) == 0:
        continue
    # one coefficient row per atom of this type; unlike squeeze(), reshape keeps
    # the row axis when the type has only a single atom
    coeff_rows = type_output.reshape(len(atom_indices), -1)
    for count, atom in enumerate(atom_indices):
        # NOTE(review): the plotting cell uses `r` from the last call, and summing
        # densities across atoms assumes every call returns the same grid --
        # confirm how plot_data_on_grid uses `center`.
        r, tot_f = _atom_density_on_grid(coeff_rows[count], type_basis, type_Rs_out,
                                         type_sph_norms, atom_coords[atom])
        f_list.append(tot_f)

all_atom_f = sum(f_list)
print(all_atom_f.max())

In [None]:
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Render the summed density as a translucent volume, with atoms as black markers.
plot_max = float(all_atom_f.max())
# isosurfaces span +/-0.5% of the maximum density value
iso_limit = 0.005 * plot_max

density_volume = go.Volume(
    x=r[:, 0], y=r[:, 1], z=r[:, 2],
    value=all_atom_f,
    isomin=-iso_limit,
    isomax=iso_limit,
    opacity=0.3,             # needs to be small to see through all surfaces
    opacityscale="uniform",
    surface_count=50,        # needs to be a large number for good volume rendering
    colorscale='RdBu',
)
fig = go.Figure(data=density_volume)

atom_xyz = points.data.squeeze().numpy()
xs, ys, zs = atom_xyz[:, 0], atom_xyz[:, 1], atom_xyz[:, 2]
fig.add_scatter3d(
    x=xs, y=ys, z=zs,
    mode='markers',
    marker=dict(size=12, color='Black', opacity=1.0),
)

fig.show()