In [None]:
# install pyscf
import numpy
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from pyscf import gto
from scipy.io import savemat
from utils.buildBreadthFirst3d import buildBreadthFirst3d

# Molecule: water in the 6-31G basis. No `unit=` is passed, so PySCF reads the
# coordinates below in its default input unit (Angstrom).
mol = gto.M(
    atom = '''
    o    0    0.       0.
    h    0    -0.757   0.587
    h    0    0.757    0.587''',
    basis = '6-31g',
    verbose = 0,
)

# Cubic computational domain as a (6, 1) column of bounds
# [xmin, xmax, ymin, ymax, zmin, zmax]; rows are consumed pairwise per axis.
half_width = 2.5
dom = numpy.array([[bound] for _axis in range(3) for bound in (-half_width, half_width)])

# tree1
def pyscffunc1(x, y, z, mol):
  """Evaluate every atomic orbital of `mol` at the points (x, y, z).

  x, y and z are flattened (C order) and paired element-wise into an (npts, 3)
  coordinate array, which is handed to mol.eval_gto('GTOval_sph', ...).
  The result is returned as a freshly constructed numpy array with one row
  per point (column count is whatever eval_gto returns, presumably one per AO).
  """
  points = numpy.stack([numpy.ravel(component) for component in (x, y, z)], axis=1)
  ao_values = mol.eval_gto('GTOval_sph', points)
  return numpy.array(ao_values)
# Number of AO basis functions (13 for H2O/6-31G); func1 returns one column per AO.
# Derived from the basis instead of hard-coded so it tracks `basis`.
nd1 = mol.nao_nr()
func1 = lambda x, y, z: pyscffunc1(x, y, z, mol)
tree1 = {
    'domain': dom,
    'tol': 1.0e-4,
    'nSteps': 15,
    'level': numpy.array([0]),
    'height': numpy.array([0]),
    'id': numpy.array([0]),
    'parent': numpy.array([0]),
    'children': numpy.zeros((8, 1)),
    'coeffs': [],
    'col': numpy.array([0]),
    'row': numpy.array([0]),
    'n': 6,
    # Nuclear positions, one column per atom (rows x, y, z). PySCF evaluates AOs on
    # grid coordinates in Bohr, while the molecule is specified in Angstrom, so the
    # previously hard-coded Angstrom numbers did not sit on the H nuclei. Take the
    # positions in Bohr directly from `mol`.
    'checkpts': mol.atom_coords().T.copy(),
    # One initially empty row per output component of func1; presumably
    # buildBreadthFirst3d appends columns here -- TODO confirm.
    'rint': numpy.empty((nd1, 0)),
    'vmax': numpy.empty((nd1, 0)),
}
tree1, rint = buildBreadthFirst3d(tree1, func1)
numpts = 51  # this needs to be consistent with the resolution in plot3dtree.m
# dom rows (0,1), (2,3), (4,5) are the lower/upper bounds of x, y, z.
axis_pts = [numpy.linspace(dom[2 * a, 0], dom[2 * a + 1, 0], numpts) for a in range(3)]
xx, yy, zz = numpy.meshgrid(*axis_pts, indexing='ij')
v1 = func1(xx.flatten(), yy.flatten(), zz.flatten())
savemat('tree1.mat', {
    'numpts': numpts, 'v': v1, 'xx': xx, 'yy': yy, 'zz': zz, 'rint': rint,
    'fdomain': tree1['domain'], 'fn': tree1['n'], 'flevel': tree1['level'],
    'fchildren': tree1['children'], 'fheight': tree1['height'], 'fid': tree1['id'],
    'frint': tree1['rint'], 'ftol': tree1['tol'], 'fcheckpts': tree1['checkpts'],
})

# tree2
# Output dimension of func2: one column per ordered pair of AOs. Derived from the
# basis instead of hard-coding 13**2 so it stays correct if `basis` changes.
nd2 = mol.nao_nr() ** 2
def pyscffunc2(x, y, z, mol):
  """Evaluate all pairwise products of atomic orbitals of `mol` at (x, y, z).

  x, y and z are flattened and stacked into an (npts, 3) coordinate array for
  mol.eval_gto('GTOval_sph', ...), giving an (npts, nao) array phi.
  Returns a float (npts, nao**2) array whose column j*nao + k holds
  phi[:, j] * phi[:, k] -- the row-major flattening of the per-point outer product.
  """
  ao_values = numpy.array(mol.eval_gto('GTOval_sph', numpy.column_stack([x.flatten(), y.flatten(), z.flatten()])))
  nrows, ncols = ao_values.shape
  # Broadcasting forms the (nrows, ncols, ncols) outer product in one vectorized
  # step; a C-order reshape maps (j, k) to column j*ncols + k, the same layout
  # the former double loop produced. Explicit ncols*ncols (not -1) keeps nrows == 0 valid.
  pair_products = ao_values[:, :, None] * ao_values[:, None, :]
  return pair_products.reshape(nrows, ncols * ncols).astype(float, copy=False)
func2 = lambda x, y, z: pyscffunc2(x, y, z, mol)
tree2 = {
    'domain': dom,
    'tol': 1.0e-4,
    'nSteps': 15,
    'level': numpy.array([0]),
    'height': numpy.array([0]),
    'id': numpy.array([0]),
    'parent': numpy.array([0]),
    'children': numpy.zeros((8, 1)),
    'coeffs': [],
    'col': numpy.array([0]),
    'row': numpy.array([0]),
    'n': 6,
    # Nuclear positions, one column per atom (rows x, y, z). PySCF evaluates AOs on
    # grid coordinates in Bohr, while the molecule is specified in Angstrom, so the
    # previously hard-coded Angstrom numbers did not sit on the H nuclei. Take the
    # positions in Bohr directly from `mol`.
    'checkpts': mol.atom_coords().T.copy(),
    # One initially empty row per output component of func2 (nd2 of them);
    # presumably buildBreadthFirst3d appends columns here -- TODO confirm.
    'rint': numpy.empty((nd2, 0)),
    'vmax': numpy.empty((nd2, 0)),
}
tree2, rint = buildBreadthFirst3d(tree2, func2)
numpts = 51  # this needs to be consistent with the resolution in plot3dtree.m
# dom rows (0,1), (2,3), (4,5) are the lower/upper bounds of x, y, z.
axis_pts = [numpy.linspace(dom[2 * a, 0], dom[2 * a + 1, 0], numpts) for a in range(3)]
xx, yy, zz = numpy.meshgrid(*axis_pts, indexing='ij')
v2 = func2(xx.flatten(), yy.flatten(), zz.flatten())
savemat('tree2.mat', {
    'numpts': numpts, 'v': v2, 'xx': xx, 'yy': yy, 'zz': zz, 'rint': rint,
    'fdomain': tree2['domain'], 'fn': tree2['n'], 'flevel': tree2['level'],
    'fchildren': tree2['children'], 'fheight': tree2['height'], 'fid': tree2['id'],
    'frint': tree2['rint'], 'ftol': tree2['tol'], 'fcheckpts': tree2['checkpts'],
})


# tree3
nd3 = 1  # func3 returns a single scalar value per point
def pyscffunc3(x, y, z, mol):
  """Evaluate sum_j phi_j(r)**2 over all atomic orbitals of `mol` at (x, y, z).

  x, y and z are flattened and stacked into an (npts, 3) coordinate array for
  mol.eval_gto('GTOval_sph', ...). Returns a 1-D float array of length npts
  (all zeros when there are no AO columns).
  """
  ao_values = numpy.array(mol.eval_gto('GTOval_sph', numpy.column_stack([x.flatten(), y.flatten(), z.flatten()])))
  # One contraction over the AO axis replaces a Python loop that allocated a new
  # npts-long array for every AO.
  return numpy.einsum('ij,ij->i', ao_values, ao_values).astype(float, copy=False)
func3 = lambda x, y, z: pyscffunc3(x, y, z, mol)
tree3 = {
    'domain': dom,
    'tol': 1.0e-4,
    'nSteps': 15,
    'level': numpy.array([0]),
    'height': numpy.array([0]),
    'id': numpy.array([0]),
    'parent': numpy.array([0]),
    'children': numpy.zeros((8, 1)),
    'coeffs': [],
    'col': numpy.array([0]),
    'row': numpy.array([0]),
    'n': 6,
    # Nuclear positions, one column per atom (rows x, y, z). PySCF evaluates AOs on
    # grid coordinates in Bohr, while the molecule is specified in Angstrom, so the
    # previously hard-coded Angstrom numbers did not sit on the H nuclei. Take the
    # positions in Bohr directly from `mol`.
    'checkpts': mol.atom_coords().T.copy(),
    # One initially empty row per output component of func3 (nd3 of them);
    # presumably buildBreadthFirst3d appends columns here -- TODO confirm.
    'rint': numpy.empty((nd3, 0)),
    'vmax': numpy.empty((nd3, 0)),
}
tree3, rint = buildBreadthFirst3d(tree3, func3)
numpts = 51  # this needs to be consistent with the resolution in plot3dtree.m
# dom rows (0,1), (2,3), (4,5) are the lower/upper bounds of x, y, z.
axis_pts = [numpy.linspace(dom[2 * a, 0], dom[2 * a + 1, 0], numpts) for a in range(3)]
xx, yy, zz = numpy.meshgrid(*axis_pts, indexing='ij')
v3 = func3(xx.flatten(), yy.flatten(), zz.flatten())
savemat('tree3.mat', {
    'numpts': numpts, 'v': v3, 'xx': xx, 'yy': yy, 'zz': zz, 'rint': rint,
    'fdomain': tree3['domain'], 'fn': tree3['n'], 'flevel': tree3['level'],
    'fchildren': tree3['children'], 'fheight': tree3['height'], 'fid': tree3['id'],
    'frint': tree3['rint'], 'ftol': tree3['tol'], 'fcheckpts': tree3['checkpts'],
})


In [None]:
# plot, move to plot3dtree.py when ready
import plotly.graph_objects as go

tree = tree2  # swap in tree1 or tree3 to plot a different tree

# Each leaf box is drawn from rows of tree['domain'] (rows 0/1, 2/3, 4/5 hold the
# lower/upper x, y, z bounds). Per box: a 10-point polyline tracing the bottom
# face, one vertical edge and the top face; a NaN break; a 6-point polyline that
# covers the remaining three vertical edges; another NaN break. Together these
# draw all 12 edges of the box.
_wire_rows = {
  'x': ([0, 1, 1, 0, 0, 0, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]),
  'y': ([2, 2, 3, 3, 2, 2, 2, 3, 3, 2], [2, 2, 3, 3, 3, 3]),
  'z': ([4, 4, 4, 4, 4, 5, 5, 5, 5, 5], [4, 5, 5, 4, 4, 5]),
}

def _wire_coords(domain, box_ids, first_rows, second_rows):
  """Return one flattened coordinate stream for the wireframes of `box_ids`.

  Builds an (18, nboxes) array (10 path rows, NaN, 6 path rows, NaN) with one
  column per box, then flattens its transpose so each box's samples are contiguous.
  """
  nan_row = numpy.full((1, box_ids.size), numpy.nan)
  first = domain[numpy.array(first_rows)[:, None], box_ids]
  second = domain[numpy.array(second_rows)[:, None], box_ids]
  return numpy.vstack([first, nan_row, second, nan_row]).T.flatten()

leaf_mask = tree['height'] == 0
ids = tree['id'][leaf_mask]
xdata = _wire_coords(tree['domain'], ids, *_wire_rows['x'])
ydata = _wire_coords(tree['domain'], ids, *_wire_rows['y'])
zdata = _wire_coords(tree['domain'], ids, *_wire_rows['z'])

trace = go.Scatter3d(
  x=xdata, y=ydata, z=zdata,
  mode='lines+markers',
  line=dict(color='blue', width=2),
  marker=dict(symbol='circle', size=1),
)
fig = go.Figure(data=trace)

scene_axes = dict(
  xaxis=dict(title='X-axis'),
  yaxis=dict(title='Y-axis'),
  zaxis=dict(title='Z-axis'),
)
fig.update_layout(scene=scene_axes, title='3D Line Plot', showlegend=False, template='plotly_white')

fig.show()