In [1]:
import re 
from rdkit      import Chem
from rdkit.Chem import AllChem
from pyscf.gto.mole import Mole
import os
num_threads = 15
os.environ["OMP_NUM_THREADS"] = str(num_threads)
angstrom_to_bohr = 1.88973

def get_atom_string(atoms, locs):
    atom_string = atoms 
    atoms = re.findall('[a-zA-Z][^A-Z]*', atoms)
    str = ""
    for atom, loc in zip(atoms, locs): 
      str += "%s %4f %4f %4f; "%((atom,) + tuple(loc) )
    return atom_string, str 

smiles = [a for a in open("gdb/gdb11_size10_sorted.csv", "r").read().split("\n")]
smile = smiles[3000000//2:3000000//2+200][1]

print(">>>", smile)

atoms = [a for a in list(smile.upper()) if a == "C" or a == "N" or a == "O" or a == "F"]

b = Chem.MolFromSmiles(smile)
b = Chem.AddHs(b, explicitOnly=False)   

AllChem.EmbedMolecule(b)

atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
num_hs = len([a for a in atoms if a == "H"])
print(num_hs)
locs =  b.GetConformer().GetPositions() * angstrom_to_bohr  

atom_string, string = get_atom_string(" ".join(atoms), locs)

mol = Mole() 
mol = mol.build(mol, atom=string, unit="Bohr", basis="sto3g", spin=0, verbose=0)


>>> NC1CC(C=C1)=C(F)CO
10


# Initial Code (copied into one place from pyscf) 
includes test case and cpu timing

In [3]:
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape)
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(times[1:] - times[:-1])
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 

def original_becke(g):
    '''Becke, JCP 88, 2547 (1988); DOI:10.1063/1.454033'''
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    return g

from pyscf.data.elements import charge as elements_proton

def treutler_atomic_radii_adjust(mol, atomic_radii):
    '''Treutler atomic radii adjust function: [JCP 102, 346 (1995); DOI:10.1063/1.469408]'''
# JCP 102, 346 (1995)
# i > j
# fac(i,j) = \frac{1}{4} ( \frac{ra(j)}{ra(i)} - \frac{ra(i)}{ra(j)}
# fac(j,i) = -fac(i,j)
    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5
    #:return lambda i,j,g: g + a[i,j]*(1-g**2)
    def fadjust(i, j, g):
        g1 = g**2
        g1 -= 1.
        g1 *= -a[i,j]
        g1 += g
        return g1
    return fadjust


# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    if callable(radii_adjust) and atomic_radii is not None:
        f_radii_adjust = radii_adjust(mol, atomic_radii)
    else:
        f_radii_adjust = None
    atm_coords = numpy.asarray(mol.atom_coords() , order='C')
    atm_dist = gto.inter_distance(mol)


    coords_all = []
    weights_all = []
    for ia in range(mol.natm):  
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]
        #pbecke = gen_grid_partition(coords+atm_coords[ia])
        ngrids = coords.shape[0]
        #grid_dist = numpy.empty((mol.natm,ngrids))
        coords = coords+atm_coords[ia]
        grid_dist = numpy.empty((mol.natm,ngrids))
        for _ia in range(mol.natm):
            dc = coords - atm_coords[_ia]
            grid_dist[_ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc))

        pbecke = numpy.ones((mol.natm,ngrids))
        for i in range(mol.natm):
            for j in range(i): # O(natm**3) calls to 'original_becke'. 
                g = 1/atm_dist[i,j] * (grid_dist[i]-grid_dist[j]) # there are only n**2 values of this but we compute it n**3 times. 
                g = f_radii_adjust(i, j, g) #  get this function 
                #g = becke_scheme(g)# gets passed the one which returns None 
                print(g.shape)
                g = original_becke(g) # g is ~ (1100) for {C,N,O,F} and ~(560) for {H}
                #print(g)
                pbecke[i] *= .5 * (1-g)
                pbecke[j] *= .5 * (1+g)

        weights = vol * pbecke[ia] * (1./pbecke.sum(axis=0)) # only use (natm, natm) for this normalization. 
        coords_all.append(coords)
        weights_all.append(weights)

    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, treutler_atomic_radii_adjust, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  assert np.allclose(grids1.weights, grids2.weights)


(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)
(1086,)


KeyboardInterrupt: 

# Vectorized version

In [33]:
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape)
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(times[1:] - times[:-1])
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 

def original_becke(g):
    '''Becke, JCP 88, 2547 (1988); DOI:10.1063/1.454033'''
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    return g

from pyscf.data.elements import charge as elements_proton

def treutler_atomic_radii_adjust(mol, atomic_radii):
    '''Treutler atomic radii adjust function: [JCP 102, 346 (1995); DOI:10.1063/1.469408]'''
# JCP 102, 346 (1995)
# i > j
# fac(i,j) = \frac{1}{4} ( \frac{ra(j)}{ra(i)} - \frac{ra(i)}{ra(j)}
# fac(j,i) = -fac(i,j)
    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5
    #:return lambda i,j,g: g + a[i,j]*(1-g**2)
    def fadjust(i, j, g):
        g1 = g**2
        g1 -= 1.
        g1 *= -a[i,j]
        g1 += g
        return g1
    return fadjust


# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)

    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5
    #:return lambda i,j,g: g + a[i,j]*(1-g**2)
    def fadjust(i, j, g):
        g1 = g**2
        g1 -= 1.
        g1 *= -a[i,j]
        g1 += g
        return g1

    atm_coords = mol.atom_coords() 
    atm_dist   = gto.inter_distance(mol)

    coords_all  = []
    weights_all = []
    natm = mol.natm
    for ia in range(mol.natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)] # there are only n_atm_types different ones here! 5 for {C,N,O,F,H}
        ngrids      = coords.shape[0]
        coords      = coords+atm_coords[ia]

        # Block 1:  For loop 
        #grid_dist   = numpy.empty((mol.natm,ngrids))
        #for _ia in range(mol.natm):
        #    #dc = coords - atm_coords[_ia]
        #    #assert np.allclose(dc, dcs[_ia])
        #    dc = dcs[_ia]
        #    grid_dist[_ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc))
        #grid_dist = np.sqrt(np.einsum('nij,nij->ni', dcs, dcs))
        # Block 1:  Vectorized 
        dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)
        grid_dist = np.linalg.norm(dcs, axis=2) 

        pbecke = numpy.ones((mol.natm,ngrids))

        #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) 
        grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
        atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
        for i in range(mol.natm):
            for j in range(i): 
                #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                #g = 1 / atm_dist[i,j] * grid_dists[i, j] #( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                g = atm_grid_dists[i, j] #( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                #g = fadjust(i, j, g) 
                g1 = g**2
                g1 -= 1.
                g1 *= -a[i,j]
                g1 += g
                g = g1 

                #g = original_becke(g)
                g = (3 - g**2) * g * .5
                g = (3 - g**2) * g * .5
                g = (3 - g**2) * g * .5


                pbecke[i] *= .5 * (1-g)
                pbecke[j] *= .5 * (1+g)

        weights = vol * pbecke[ia] * (1./pbecke.sum(axis=0))
        coords_all.append(coords)
        weights_all.append(weights)

    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, treutler_atomic_radii_adjust, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  assert np.allclose(grids1.weights, grids2.weights)

  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists


(16904, 3)
[5.65052032e-05 2.83718109e-05 1.21593475e-05 4.29153442e-06
 3.88860703e-04 1.00135803e-04 5.26905060e-05 2.48193741e-04
 2.05833912e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[5.74588776e-05 2.83718109e-05 1.14440918e-05 4.05311584e-06
 3.50952148e-04 1.10387802e-04 5.38825989e-05 2.30789185e-04
 1.81314945e-02]
[0.   0.2  0.02 0.   0.  ]
(16904, 3)
[6.03199005e-05 3.33786011e-05 1.43051147e-05 4.29153442e-06
 3.61919403e-04 9.96589661e-05 5.38825989e-05 2.30073929e-04
 1.70199871e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[5.55515289e-05 2.69412994e-05 1.19209290e-05 4.29153442e-06
 3.56912613e-04 1.00374222e-04 5.72204590e-05 2.50577927e-04
 1.71055794e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[5.62667847e-05 2.74181366e-05 1.12056732e-05 3.81469727e-06
 3.63111496e-04 1.00612640e-04 5.26905060e-05 2.49385834e-04
 1.70495510e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[5.36441803e-05 2.64644623e-05 1.09672546e-05 4.29153442e-06
 3.77416611e-04 1.02996826e-04 5.31673431e-

KeyboardInterrupt: 

In [62]:
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape)
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(times[1:] - times[:-1])
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 


from pyscf.data.elements import charge as elements_proton



# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)

    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5
    #:return lambda i,j,g: g + a[i,j]*(1-g**2)
    def fadjust(i, j, g):
        g1 = g**2
        g1 -= 1.
        g1 *= -a[i,j]
        g1 += g
        return g1

    atm_coords = mol.atom_coords() 
    atm_dist   = gto.inter_distance(mol)

    coords_all  = []
    weights_all = []
    natm = mol.natm
    for ia in range(mol.natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)] # there are only n_atm_types different ones here! 5 for {C,N,O,F,H}
        ngrids      = coords.shape[0]
        coords      = coords+atm_coords[ia]

        # Block 1:  For loop 
        #grid_dist   = numpy.empty((mol.natm,ngrids))
        #for _ia in range(mol.natm):
        #    #dc = coords - atm_coords[_ia]
        #    #assert np.allclose(dc, dcs[_ia])
        #    dc = dcs[_ia]
        #    grid_dist[_ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc))
        #grid_dist = np.sqrt(np.einsum('nij,nij->ni', dcs, dcs))
        # Block 1:  Vectorized 
        dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)
        grid_dist = np.linalg.norm(dcs, axis=2) 

        pbecke    = numpy.ones((mol.natm,ngrids))

        #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) 
        grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
        atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists

        '''
        #g = fadjust(i, j, g) 
        #g1 = g**2
        #g1 -= 1.
        #g1 *= -a[i,j]
        #g1 += g
        #g = (atm_grid_dists[i,j]**2 - 1.)*-a[i,j] + atm_grid_dists[i,j]
        '''
        atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
        '''
        #g = original_becke(g)
        g = (3 - g**2) * g * .5
        g = (3 - g**2) * g * .5
        g = (3 - g**2) * g * .5

        '''
        f = lambda g: (3 - g**2) * g * .5
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)

        #atm_grid_dists = atm_grid_dists**2 - 1
        for i in range(mol.natm):
            for j in range(i): 
                #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                #g = 1 / atm_dist[i,j] * grid_dists[i, j] #( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                #g = atm_grid_dists[i, j] # ( grid_dist[i] - grid_dist[j] ) # There are only n**2 values of this but we compute it n**3 times. 
                #g = fadjust(i, j, g) 
                #g1 = g**2
                #g1 -= 1.
                #g1 *= -a[i,j]
                #g1 += g
                #g = (atm_grid_dists[i,j]**2 - 1.)*-a[i,j] + atm_grid_dists[i,j]
                #assert np.allclose(g, atm_grid_dists_radii_adjusted[i,j])
                g = atm_grid_dists_radii_adjusted[i,j]

                #g = original_becke(g)
                #g = (3 - g**2) * g * .5
                #g = (3 - g**2) * g * .5
                #g = (3 - g**2) * g * .5

                pbecke[i] *= .5 * (1 - g)
                pbecke[j] *= .5 * (1 + g)

        weights = vol * pbecke[ia] * (1./pbecke.sum(axis=0))
        coords_all.append(coords)
        weights_all.append(weights)

    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, treutler_atomic_radii_adjust, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  assert np.allclose(grids1.weights, grids2.weights)

  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists


(16904, 3)
[5.50746918e-05 2.71797180e-05 1.21593475e-05 4.52995300e-06
 3.94105911e-04 1.31607056e-04 7.67707825e-05 2.34365463e-04
 1.69250965e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.98430634e-05 2.76565552e-05 1.23977661e-05 4.29153442e-06
 4.68730927e-04 1.38759613e-04 7.03334808e-05 2.49624252e-04
 1.77855492e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[5.74588776e-05 2.81333923e-05 1.19209290e-05 4.76837158e-06
 4.63724136e-04 1.00851059e-04 5.76972961e-05 2.35319138e-04
 1.75397396e-02]
[0.   0.19 0.02 0.   0.  ]
(16904, 3)
[6.07967377e-05 2.76565552e-05 1.09672546e-05 3.81469727e-06
 3.91244888e-04 9.53674316e-05 5.50746918e-05 2.33411789e-04
 2.20179558e-02]
[0.   0.19 0.03 0.   0.  ]
(16904, 3)
[5.50746918e-05 2.59876251e-05 1.14440918e-05 4.05311584e-06
 3.88622284e-04 1.10626221e-04 5.72204590e-05 2.32219696e-04
 1.78749561e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.84125519e-05 2.78949738e-05 2.50339508e-05 4.76837158e-06
 4.01020050e-04 9.46521759e-05 5.53131104e-

KeyboardInterrupt: 

In [90]:
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape)
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(times[1:] - times[:-1])
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 


from pyscf.data.elements import charge as elements_proton



# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)

    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5
    #:return lambda i,j,g: g + a[i,j]*(1-g**2)
    def fadjust(i, j, g):
        g1 = g**2
        g1 -= 1.
        g1 *= -a[i,j]
        g1 += g
        return g1

    atm_coords = mol.atom_coords() 
    atm_dist   = gto.inter_distance(mol)

    coords_all  = []
    weights_all = []
    natm = mol.natm
    for ia in range(mol.natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)] # there are only n_atm_types different ones here! 5 for {C,N,O,F,H}
        ngrids      = coords.shape[0]
        coords      = coords+atm_coords[ia]

        # Block 1:  For loop 
        #grid_dist   = numpy.empty((mol.natm,ngrids))
        #for _ia in range(mol.natm):
        #    #dc = coords - atm_coords[_ia]
        #    #assert np.allclose(dc, dcs[_ia])
        #    dc = dcs[_ia]
        #    grid_dist[_ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc))
        #grid_dist = np.sqrt(np.einsum('nij,nij->ni', dcs, dcs))
        # Block 1:  Vectorized 
        dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)
        grid_dist = np.linalg.norm(dcs, axis=2) 

        pbecke    = numpy.ones((mol.natm,ngrids))

        #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) 
        grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
        atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists

        '''
        #g = fadjust(i, j, g) 
        #g1 = g**2
        #g1 -= 1.
        #g1 *= -a[i,j]
        #g1 += g
        #g = (atm_grid_dists[i,j]**2 - 1.)*-a[i,j] + atm_grid_dists[i,j]
        '''
        atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
        '''
        #g = original_becke(g)
        g = (3 - g**2) * g * .5
        g = (3 - g**2) * g * .5
        g = (3 - g**2) * g * .5

        '''
        f = lambda g: (3 - g**2) * g * .5
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)

        atm_grid_dists_radii_adjusted_m1 = 0.5*(1-atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted_p1 = 0.5*(1+atm_grid_dists_radii_adjusted)

        #atm_grid_dists = atm_grid_dists**2 - 1
        for i in range(mol.natm):
            for j in range(i, mol.natm): 
              atm_grid_dists_radii_adjusted_m1[i,j] = 1 
              atm_grid_dists_radii_adjusted_p1[i,j] = 1

        # Block 2: for loops 
        #for i in range(mol.natm):
        #    #pbecke[i] *= np.prod(atm_grid_dists_radii_adjusted_m1[i,:], axis=0)
        #    pbecke *= atm_grid_dists_radii_adjusted_p1[i, :]
        #    #for j in range(mol.natm): 
        #    #    #pbecke[i] *= atm_grid_dists_radii_adjusted_m1[i,j]
        #    #    pbecke[j] *= atm_grid_dists_radii_adjusted_p1[i,j] 
        # Block 2: vectorized (could apply log tricks!)
        pbecke *= np.prod(atm_grid_dists_radii_adjusted_m1, axis=1)
        pbecke *= np.prod(atm_grid_dists_radii_adjusted_p1, axis=0)

        norm = pbecke.sum(axis=0)
        weights = vol * pbecke[ia] * (1. / norm )
        coords_all.append(coords)
        weights_all.append(weights)


    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, treutler_atomic_radii_adjust, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  assert np.allclose(grids1.weights, grids2.weights)

  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists


(16904, 3)
[5.74588776e-05 2.71797180e-05 1.62124634e-05 6.19888306e-06
 4.79698181e-04 1.03950500e-04 5.36441803e-05 2.72750854e-04
 1.81984901e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.55515289e-05 2.67028809e-05 1.26361847e-05 4.76837158e-06
 3.99827957e-04 8.74996185e-05 5.17368317e-05 2.28881836e-04
 1.80420876e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.67436218e-05 2.64644623e-05 1.16825104e-05 4.29153442e-06
 3.87191772e-04 8.86917114e-05 5.14984131e-05 2.44379044e-04
 1.77793503e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.79357147e-05 2.71797180e-05 1.23977661e-05 4.52995300e-06
 3.91483307e-04 9.05990601e-05 4.91142273e-05 2.72512436e-04
 1.93643570e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.86509705e-05 2.76565552e-05 1.28746033e-05 5.00679016e-06
 3.99589539e-04 9.63211060e-05 6.24656677e-05 2.46524811e-04
 1.74753666e-02]
[0.   0.18 0.02 0.   0.  ]
(16904, 3)
[5.93662262e-05 2.88486481e-05 1.23977661e-05 5.48362732e-06
 4.11033630e-04 9.25064087e-05 5.31673431e-

KeyboardInterrupt: 

# V2 

In [112]:
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape)
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(times[1:] - times[:-1])
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 


from pyscf.data.elements import charge as elements_proton



# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)

    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5

    atm_coords = mol.atom_coords() 
    atm_dist   = gto.inter_distance(mol)
    coords_all  = []
    weights_all = []
    natm = mol.natm

    precompute_coords_m_atm_coords = {}

    for ia in range(natm):
      if mol.atom_symbol(ia) not in precompute_coords_m_atm_coords: 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        precompute_coords_m_atm_coords[ia] = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)


    for ia in range(natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        ngrids      = coords.shape[0]

        #dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) + atm_coords[ia].reshape(1,1,3) # atm_coords here is only difference!

        # we might be able to ... ? compute using inner product square formula;  then, when grid_dist is used in grid_dist - grid_dist the <2ab> terms cancel out! 

        # oh lol this is the third im can just write out naively! 
        dcs       = precompute_coords_m_atm_coords[ia] + atm_coords[ia].reshape(1,1,3)
        print(dcs.shape) # (natm, atm_gridsize, 3)
        #grid_dist = np.linalg.norm(dcs, axis=2) #+np.linalg.norm() # this is not linear, so we can't precompute the coords-atm_coords; 
        # perhaps lg transofrm can allow us to seperate? 
        grid_dist = np.sqrt( dcs[:, :, 0]**2 + dcs[:,:,1]**2 + dcs[:,:,2]**2 )
        # however, ||a+b||=||a||+||b||-2<a,b>  we could precompute ||a||

        pbecke    = numpy.ones((natm,ngrids))

        #g = 1 / atm_dist[i,j] * ( grid_dist[i] - grid_dist[j] ) 
        grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
        atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
        atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
        f = lambda g: (3 - g**2) * g * .5
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted = f(atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted_m1 = 0.5*(1-atm_grid_dists_radii_adjusted)
        atm_grid_dists_radii_adjusted_p1 = 0.5*(1+atm_grid_dists_radii_adjusted)

        #atm_grid_dists = atm_grid_dists**2 - 1
        for i in range(natm):
            for j in range(i, natm): 
              atm_grid_dists_radii_adjusted_m1[i,j] = 1 
              atm_grid_dists_radii_adjusted_p1[i,j] = 1

        pbecke *= np.prod(atm_grid_dists_radii_adjusted_m1, axis=1)
        pbecke *= np.prod(atm_grid_dists_radii_adjusted_p1, axis=0)

        norm = pbecke.sum(axis=0)
        weights = vol * pbecke[ia] * (1. / norm )
        coords_all.append(coords+atm_coords[ia])
        weights_all.append(weights)


    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, treutler_atomic_radii_adjust, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  assert np.allclose(grids1.weights, grids2.weights)

(20, 1086, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(16904, 3)
[6.98566437e-05 2.83718109e-05 1.33514404e-05 4.76837158e-06
 4.38451767e-04 1.35183334e-04 7.67707825e-05 3.32355499e-04
 2.09133625e-02]
[0.   0.18 0.02 0.   0.  ]


  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists


(20, 1086, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(16904, 3)
[9.94205475e-05 3.74317169e-05 1.57356262e-05 4.76837158e-06
 4.73022461e-04 1.34944916e-04 7.65323639e-05 2.96354294e-04
 1.89578533e-02]
[0.   0.19 0.02 0.   0.  ]
(20, 1086, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 1098, 3)
(20, 1086, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(20, 596, 3)
(16904, 3)
[8.34465027e-05 2.98023224e-05 2.55107880e-05 4.52995300e-06
 4.15802002e-04 9.67979431e-05 5.91278076e-05 2.33411789e-04
 1.77588463e-02]
[0.   0.18 0.02 0.   0.  ]
(20, 1086, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 1098, 3)
(20, 108

KeyboardInterrupt: 

# Naive vectorize (not trick to reduce natm to n_distinct_atm)

In [22]:
import jax
#print(jax.config.num_omp_threads)
print(jax.config.xla_num_threads)


AttributeError: 'Config' object has no attribute 'xla_num_threads'

In [64]:
#!source /nethome/alexm/poplar_sdk-ubuntu_20_04-3.1.0+1205-58b501c780/enable
#!bash ../source.sh #source /nethome/alexm/poplar_sdk-ubuntu_20_04-3.1.0+1205-58b501c780/enable
import re 
from rdkit      import Chem
from rdkit.Chem import AllChem
from pyscf.gto.mole import Mole
import os
num_threads = 15
os.environ["OMP_NUM_THREADS"] = str(num_threads)
angstrom_to_bohr = 1.88973

def get_atom_string(atoms, locs):
    atom_string = atoms 
    atoms = re.findall('[a-zA-Z][^A-Z]*', atoms)
    str = ""
    for atom, loc in zip(atoms, locs): 
      str += "%s %4f %4f %4f; "%((atom,) + tuple(loc) )
    return atom_string, str 

smiles = [a for a in open("gdb/gdb11_size10_sorted.csv", "r").read().split("\n")]
smile = smiles[3000000//2:3000000//2+200][1]

print(">>>", smile)

atoms = [a for a in list(smile.upper()) if a == "C" or a == "N" or a == "O" or a == "F"]

b = Chem.MolFromSmiles(smile)
b = Chem.AddHs(b, explicitOnly=False)   

AllChem.EmbedMolecule(b)

atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
num_hs = len([a for a in atoms if a == "H"])
print(num_hs)
locs =  b.GetConformer().GetPositions() * angstrom_to_bohr  

atom_string, string = get_atom_string(" ".join(atoms), locs)

mol = Mole() 
mol = mol.build(mol, atom=string, unit="Bohr", basis="sto3g", spin=0, verbose=0)


import os
os.environ["OMP_NUM_THREADS"] = "4"
import jax 
from jax import config 
config.update('jax_enable_x64', True) # perhaps it's the loss computation in the end? 
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    print(box_ids.shape) # this last call is the one that takes 20 ms. 
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    print(np.around(times[1:] - times[:-1], 2))
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 


from pyscf.data.elements import charge as elements_proton
def f(g): 
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    return g 

# vol is same size but coords increases a little in size. 
# could vmap this over natm if we pad coords! 
def iter(dcs, vol, ia, natm, ngrids, atm_dist, a):
    #print("COMPILING")
    import jax.numpy as np
    from jax import lax 
    grid_dist      = np.linalg.norm(dcs, axis=2)
    pbecke         = np.ones((natm,ngrids))
    grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
    atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
    atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
    atm_grid_dists_radii_adjusted =  f(atm_grid_dists_radii_adjusted) 
    atm_grid_dists_radii_adjusted_m1 = 0.5*(1-atm_grid_dists_radii_adjusted)
    atm_grid_dists_radii_adjusted_p1 = 0.5*(1+atm_grid_dists_radii_adjusted)

    '''for i in range(natm):
        for j in range(i, natm): 
            #atm_grid_dists_radii_adjusted_m1[i,j] = 1 
            #atm_grid_dists_radii_adjusted_p1[i,j] = 1

            atm_grid_dists_radii_adjusted_m1 = atm_grid_dists_radii_adjusted_m1.at[i,j].set( 1 )
            atm_grid_dists_radii_adjusted_p1 = atm_grid_dists_radii_adjusted_p1.at[i,j].set( 1 )'''
    #atm_grid_dists_radii_adjusted_m1[np.triu_indices(natm)] = 1
    #atm_grid_dists_radii_adjusted_p1[np.triu_indices(natm)] = 1

    atm_grid_dists_radii_adjusted_m1 = atm_grid_dists_radii_adjusted_m1.at[np.triu_indices(natm)].set(1)
    atm_grid_dists_radii_adjusted_p1 = atm_grid_dists_radii_adjusted_p1.at[np.triu_indices(natm)].set(1)

    pbecke *= np.prod(atm_grid_dists_radii_adjusted_m1, axis=1)
    pbecke *= np.prod(atm_grid_dists_radii_adjusted_p1, axis=0)

    norm = pbecke.sum(axis=0)
    weights = vol * pbecke[ia] * (1. / norm )
    return weights 

#iter = jax.jit(iter, static_argnums=(3,4))

# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII
                  ):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)

    charges = [elements_proton(x) for x in mol.elements]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5

    atm_coords = mol.atom_coords() 
    atm_dist   = gto.inter_distance(mol)
    coords_all  = []
    weights_all = []
    natm = mol.natm

    precompute_coords_m_atm_coords = {}

    # try jax.jit this? 
    #f = jax.jit(f)

    print(natm)

    for ia in range(natm):
      #print(mol.atom_symbol(ia), ia)
      if mol.atom_symbol(ia) not in precompute_coords_m_atm_coords: 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        precompute_coords_m_atm_coords[ia] = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)
        #print(coords.shape, vol.shape) 

    '''
    # so the first C,N,O,F have slightly different sizes.
    # the subsequent 10 hydrogens are all the same! 
    
    (20, 1086, 3) (1086,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1086, 3) (1086,) (20, 20) (20, 20)
    /tmp/ipykernel_1183754/2275182240.py:114: RuntimeWarning: divide by zero encountered in divide
    atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
    Output exceeds the size limit. Open the full output data in a text editor(20, 1098, 3) (1098,) (20, 20) (20, 20)
    (20, 1086, 3) (1086,) (20, 20) (20, 20)
    (20, 596, 3) (596,) (20, 20) (20, 20)
    (20, 596, 3) (596,) (20, 20) (20, 20)
    
    '''

    # we want to vectorize this part aswell. 
    for ia in range(0, 10):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        ngrids      = coords.shape[0]

        #dcs       = precompute_coords_m_atm_coords[ia] + atm_coords[ia].reshape(1,1,3)
        dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) + atm_coords[ia].reshape(1,1,3)
        #grid_dist = np.sqrt( dcs[:, :, 0]**2 + dcs[:,:,1]**2 + dcs[:,:,2]**2 )
        #grid_dist = np.linalg.norm(dcs, axis=2)
        
        print(dcs.shape, vol.shape, atm_dist.shape, a.shape)
        weights =  iter(dcs, vol, ia, natm, ngrids, atm_dist, a)  # vectorize this and do all in one call! 

        coords_all.append(coords+atm_coords[ia])
        weights_all.append(weights)

    batched_iter = jax.vmap(iter, in_axes=(0, None, 0, None, None, None, None), out_axes=(0))

    cat_input = []

    # this is all hydrogen. 
    for ia in range(10, natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        ngrids      = coords.shape[0]

        dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) + atm_coords[ia].reshape(1,1,3)
        _dcs      = dcs.reshape((1, 20, 596, 3))
        cat_input.append(_dcs)
        weights =  batched_iter(_dcs, vol, np.array(ia).reshape(1, 1 ), natm, ngrids, atm_dist, a)[0,0]  # vectorize this and do all in one call! 
        print(weights.shape)

        coords_all.append(coords+atm_coords[ia])
        #weights_all.append(weights)

    input = np.concatenate(cat_input, axis=0)
    # so we know the tuple at compile time; is that ok? 
    _weights =  batched_iter(input, vol, np.arange(10, natm).reshape(-1, 1), natm, ngrids, atm_dist, a)  # vectorize this and do all in one call! 
    print(_weights.shape)

    for i, w in enumerate(_weights):
       #print(w.shape)
       #print(np.max(np.abs( w -  weights_all[10+i])))
       #assert np.allclose(w, weights_all[10+i])
       weights_all.append(w.reshape(-1))

    coords_all = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)

    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  #self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme, concat=True) 
  #self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.radii_adjust, self.atomic_radii) 
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, None, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 2))

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()     
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)     
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  print(np.max(np.abs(grids1.weights - grids2.weights)))
  assert np.allclose(grids1.weights, grids2.weights, atol=1e-4)

>>> NC1CC(C=C1)=C(F)CO
10
20
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1086, 3) (1086,) (20, 20) (20, 20)


  atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists


(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(10, 1, 596)
(16904, 3)
[0.   0.   0.   0.   0.   0.   0.   0.   0.02]
[0.   0.82 0.02 0.   0.  ]
3.481659405224491e-13
20
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(596,)
(10, 1, 596)
(16904, 3)
[0.   0.   0.   0.   0.   0.   0.   0.   0.02]
[0.   0.78 0.02 0.   0.  ]
3.481659405224491e-13
20
(20, 1086, 3) (1086,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) (1098,) (20, 20) (20, 20)
(20, 1098, 3) 

KeyboardInterrupt: 

# Cleanup vectorized with jax.jit and final {CNOF} and {H} batching through jax.vmap

In [27]:
#!source /nethome/alexm/poplar_sdk-ubuntu_20_04-3.1.0+1205-58b501c780/enable
#!bash ../source.sh #source /nethome/alexm/poplar_sdk-ubuntu_20_04-3.1.0+1205-58b501c780/enable
import re 
from rdkit      import Chem
from rdkit.Chem import AllChem
from pyscf.gto.mole import Mole
import os
num_threads = 15
os.environ["OMP_NUM_THREADS"] = str(num_threads)
angstrom_to_bohr = 1.88973

def get_atom_string(atoms, locs):
    atom_string = atoms 
    atoms = re.findall('[a-zA-Z][^A-Z]*', atoms)
    str = ""
    for atom, loc in zip(atoms, locs): 
      str += "%s %4f %4f %4f; "%((atom,) + tuple(loc) )
    return atom_string, str 

smiles = [a for a in open("gdb/gdb11_size10_sorted.csv", "r").read().split("\n")]
smile = smiles[3000000//2:3000000//2+200][1]

print(">>>", smile)

atoms = [a for a in list(smile.upper()) if a == "C" or a == "N" or a == "O" or a == "F"]

b = Chem.MolFromSmiles(smile)
b = Chem.AddHs(b, explicitOnly=False)   

AllChem.EmbedMolecule(b)

atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
num_hs = len([a for a in atoms if a == "H"])
print(num_hs)
locs =  b.GetConformer().GetPositions() * angstrom_to_bohr  

atom_string, string = get_atom_string(" ".join(atoms), locs)

mol = Mole() 
mol = mol.build(mol, atom=string, unit="Bohr", basis="sto3g", spin=0, verbose=0)


import os
os.environ["OMP_NUM_THREADS"] = "4"
import jax 
from jax import config 
config.update('jax_enable_x64', True) # perhaps it's the loss computation in the end? 
import time
import numpy 
import pyscf.dft
import numpy as np 
import numpy 
from pyscf import gto 
from pyscf import lib

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    times = []
    times.append(time.time())
    atom_coords = mol.atom_coords()
    times.append(time.time())
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    times.append(time.time())
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    times.append(time.time())
    #tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    times.append(time.time())
    frac_coords = (coords - boundary[0]) * (1./box_size)
    times.append(time.time())
    box_ids = numpy.floor(frac_coords).astype(int)
    times.append(time.time())
    box_ids[box_ids<-1] = -1
    times.append(time.time())
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    times.append(time.time()) # this is the one that takes 20 ms? 
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    times.append(time.time())
    times = np.array(times)
    #print(np.around(times[1:] - times[:-1], 2))
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 


from pyscf.data.elements import charge as elements_proton
def f(g): 
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    return g 

# vol is same size but coords increases a little in size. 
# could vmap this over natm if we pad coords! 
def iter(vol, ia, natm, ngrids, atm_dist, a, atm_coords, coords):
    #print("COMPILING")
    import jax.numpy as np
    dcs         = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) + atm_coords[ia].reshape(1,1,3)
    grid_dist      = np.linalg.norm(dcs, axis=2)
    pbecke         = np.ones((natm,ngrids))
    grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
    atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
    atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
    atm_grid_dists_radii_adjusted =  f(atm_grid_dists_radii_adjusted) 
    atm_grid_dists_radii_adjusted_m1 = 0.5*(1-atm_grid_dists_radii_adjusted)
    atm_grid_dists_radii_adjusted_p1 = 0.5*(1+atm_grid_dists_radii_adjusted)

    atm_grid_dists_radii_adjusted_m1 = atm_grid_dists_radii_adjusted_m1.at[np.triu_indices(natm)].set(1)
    atm_grid_dists_radii_adjusted_p1 = atm_grid_dists_radii_adjusted_p1.at[np.triu_indices(natm)].set(1)

    pbecke *= np.prod(atm_grid_dists_radii_adjusted_m1, axis=1)
    pbecke *= np.prod(atm_grid_dists_radii_adjusted_p1, axis=0)

    norm = pbecke.sum(axis=0)
    weights = vol * pbecke[ia] * (1. / norm )
    return weights 


def iter2(vol, ia, natm, ngrids, atm_dist, a, atm_coords, coords):
    #print("COMPILING")
    import jax.numpy as np
    #print(">", dcs.shape)
    dcs         = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) + atm_coords[ia].reshape(1,1,3)
    #print(dcs.shape)
    grid_dist      = np.linalg.norm(dcs, axis=2)
    pbecke         = np.ones((natm,ngrids))
    grid_dists     = grid_dist.reshape(natm, 1, ngrids) - grid_dist.reshape(1, natm, ngrids)
    atm_grid_dists = (1/atm_dist).reshape(natm, natm, 1) * grid_dists
    atm_grid_dists_radii_adjusted = (atm_grid_dists**2 - 1) * -a.reshape(natm, natm, 1) + atm_grid_dists
    atm_grid_dists_radii_adjusted =  f(atm_grid_dists_radii_adjusted) 
    atm_grid_dists_radii_adjusted_m1 = 0.5*(1-atm_grid_dists_radii_adjusted)
    atm_grid_dists_radii_adjusted_p1 = 0.5*(1+atm_grid_dists_radii_adjusted)

    atm_grid_dists_radii_adjusted_m1 = atm_grid_dists_radii_adjusted_m1.at[np.triu_indices(natm)].set(1)
    atm_grid_dists_radii_adjusted_p1 = atm_grid_dists_radii_adjusted_p1.at[np.triu_indices(natm)].set(1)

    pbecke *= np.prod(atm_grid_dists_radii_adjusted_m1, axis=1)
    pbecke *= np.prod(atm_grid_dists_radii_adjusted_p1, axis=0)

    norm = pbecke.sum(axis=0)
    weights = vol * pbecke[ia] * (1. / norm )
    return weights[0]

iter = jax.jit(iter, static_argnums=(2,3))
batched_iter = jax.vmap(iter, in_axes=(None, 0, None, None, None, None, None, 0), out_axes=(0))# for hyrodgen
iter2 = jax.jit(iter2, static_argnums=(2,3))
batched_iter_no_h = jax.vmap(iter2, in_axes=(0, 0, None, None, None, None, None, 0), out_axes=(0))

# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  atomic_radii=radi.BRAGG_RADII
                  ):
    ''' Generate the mesh grid coordinates and weights for DFT numerical integration.
    Returns:
        coord   (N, 3) 
        weights (N, ) 
    '''
    #f_radii_adjust = treutler_atomic_radii_adjust(mol, atomic_radii)
    # [0.    0.    0.007 0.    0.042 0.011 0.    0.004 0.    0.015 0.001 0.001]

    times = []

    times.append(time.time()) # 0 ms

    #print(mol.elements) # ['N', 'C', 'C', 'C', 'C', 'C', 'C', 'F', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H', 'H']
    num_non_hs = len([a for a in mol.elements if a.lower() != "h"])
    charges = [elements_proton(x) for x in mol.elements]
    #print(charges) # [7, 6, 6, 6, 6, 6, 6, 9, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    #print(atomic_radii) # [3.77945036 0.66140414 2.64561657 2.74010288 1.98421243 1.60626721 1.32280829 1.22832198 1.13383567 0.94486306  ... ]
    rad = numpy.sqrt(atomic_radii[charges]) + 1e-200
    rr = rad.reshape(-1,1) * (1./rad)
    a = .25 * (rr.T - rr)
    a[a<-.5] = -.5
    a[a>0.5] = 0.5

    times.append(time.time()) # 0 ms
    atm_coords  = mol.atom_coords() # (20, 3)
    atm_dist    = np.linalg.norm( atm_coords.reshape(-1, 1, 3) - atm_coords.reshape(1, -1, 3) , axis=2)
    coords_all  = []
    weights_all = []
    natm        = mol.natm
    times.append(time.time()) # 7ms 

    noh_vols_batch = []
    noh_coords_batch = []
    noh_sizes = []
    # reduced from 7->0ms by moving dcs computation into batched_iter.
    noh_ngrids = 1100 
    for ia in range(0, num_non_hs):  
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)] 
        noh_sizes.append(coords.shape[0])
        noh_coords_batch.append( np.concatenate([coords, np.zeros((noh_ngrids-coords.shape[0], 3))], axis=0).reshape(1, -1))
        noh_vols_batch.append( np.concatenate((vol, np.zeros(noh_ngrids-vol.shape[0]))).reshape(1, -1) )

    noh_vols_batch = np.concatenate(noh_vols_batch)
    noh_indxs = np.arange(0, 10).reshape(-1, 1)
    noh_coords_batch = np.concatenate(noh_coords_batch)

    times.append(time.time())
    h_coords_batch = []
    for ia in range(num_non_hs, natm): 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  
        ngrids      = coords.shape[0]
        h_coords_batch.append(coords.reshape(1, -1))

    h_coords_batch = np.concatenate(h_coords_batch)
    h_indxs = np.arange(10, natm).reshape(-1, 1)
    times.append(time.time()) 
    _weights =  batched_iter(vol, h_indxs, natm, ngrids, atm_dist, a, atm_coords, h_coords_batch)
    times.append(time.time()) # 42ms ( 29ms with taskset -c 0-14 )
    w1 = batched_iter_no_h(noh_vols_batch, noh_indxs, natm, noh_ngrids, atm_dist, a, atm_coords, noh_coords_batch)
    times.append(time.time())

    # 11 ms  
    for ia in range(0, num_non_hs):  weights_all.append( w1[ia].reshape(-1)[:noh_sizes[ia]] )

    # 1ms 
    for i, w in enumerate(_weights):
       weights_all.append(w.reshape(-1))

    for ia in range(0, natm):  # The first two here is just O(natm**2 * grid_size). 
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]  # there are only 5 different ones here (not 10)
        ngrids      = coords.shape[0]
        coords_all.append(coords+atm_coords[ia])


    times.append(time.time()) # 1ms 
    coords_all  = numpy.vstack(coords_all)
    weights_all = numpy.hstack(weights_all)
    #print(weights_all.shape)

    times.append(time.time())

    times = np.array(times)
    print("\t", np.around(times[1:] - times[:-1], 3), times[-1]-times[0])
    return coords_all, weights_all

def build_grid(self):
  mol = self.mol

  times = []
  times.append(time.time()) # below 0
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time()) # below 0.13
  self.coords, self.weights = _get_partition(mol, atom_grids_tab, self.atomic_radii) 
  times.append(time.time()) # below is 0.02
  idx = arg_group_grids(mol, self.coords)
  times.append(time.time()) # below is 0

  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time()) # below is 0 

  # this actually does do smth?
  if self.alignment > 1:
      def _padding_size(ngrids, alignment):
          if alignment <= 1:
              return 0
          return (ngrids + alignment - 1) // alignment * alignment - ngrids

      padding = _padding_size(self.size, self.alignment)
      #logger.debug(self, 'Padding %d grids', padding)
      if padding > 0:
          self.coords = numpy.vstack(
              [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)])
          self.weights = numpy.hstack([self.weights, numpy.zeros(padding)])
  
  times.append(time.time())
  self.screen_index = self.non0tab = None

  times = np.array(times)
  print(np.around(times[1:]-times[:-1], 3), times[-1]-times[0])

  return self

for _ in range(100):  
  t0 =time.time()
  grids1            = pyscf.dft.gen_grid.Grids(mol) 
  grids1.level      = 0 
  grids1.build()
  #print(time.time()-t0)
  #print(grids1.coords.shape)

  t0 =time.time()
  grids2            = pyscf.dft.gen_grid.Grids(mol) 
  grids2.level      = 0 
  build_grid(grids2)
  #print(grids2.coords.shape)
  #print(time.time()-t0)

  assert np.allclose(grids1.coords, grids2.coords)
  #print(np.max(np.abs(grids1.weights - grids2.weights)))
  assert np.allclose(grids1.weights, grids2.weights)

>>> NC1CC(C=C1)=C(F)CO
10
	 [0.    0.    0.    0.    0.601 0.608 0.008 0.   ] 1.2169206142425537
[1.000e-03 1.217e+00 1.800e-02 0.000e+00 0.000e+00] 1.2372937202453613
	 [0.    0.    0.    0.    0.021 0.037 0.008 0.   ] 0.06640243530273438
[0.001 0.067 0.019 0.001 0.   ] 0.0876307487487793
	 [0.    0.    0.    0.    0.025 0.044 0.008 0.   ] 0.07828497886657715
[0.001 0.079 0.018 0.    0.   ] 0.0984032154083252
	 [0.    0.    0.    0.    0.02  0.056 0.008 0.   ] 0.0848546028137207
[0.001 0.085 0.018 0.    0.   ] 0.10498166084289551
	 [0.    0.    0.    0.    0.019 0.053 0.008 0.   ] 0.07955336570739746
[0.001 0.08  0.019 0.    0.   ] 0.10108232498168945
	 [0.    0.    0.    0.    0.025 0.038 0.008 0.   ] 0.07160735130310059
[0.001 0.072 0.018 0.    0.   ] 0.09174513816833496
	 [0.    0.    0.    0.    0.019 0.055 0.008 0.   ] 0.08260273933410645
[0.001 0.083 0.018 0.    0.   ] 0.1031186580657959
	 [0.    0.    0.    0.    0.021 0.049 0.008 0.   ] 0.07844901084899902
[0.001 0.079 0.019 0

KeyboardInterrupt: 

# Other

In [None]:
# for grid stuff / ao, we can minimize grid size by min num_atoms
# if we want 1M, we may use 11 atoms with less hydrogens! 

In [16]:
!ls -lahS /a/scratch/alexm/research/mostrecent/

total 78G
-rw-r--r--  1 alexm all       27G Apr 12 09:16 gdb11_sto3g.parquet
-rw-r--r--  1 alexm all       20G Apr 19 08:01 gdb12.parquet
-rw-r--r--  1 alexm all       19G Apr 24 21:33 gdb11_sto3g_converged_atoms.parquet
-rw-r--r--  1 alexm all      2.8G Apr 24 20:36 gdb10_sto3g_converged.parquet
-rw-r--r--  1 alexm all      2.8G Apr 24 21:16 gdb10_sto3g_converged_atoms.parquet
-rw-r--r--  1 alexm all      2.6G Apr 27 14:23 bigrun4_v5.parquet
-rw-r--r--  1 alexm all      2.5G Apr 10 18:55 gbd10_sto3g.csv
-rw-r--r--  1 alexm all      368M Apr 24 20:31 gdb9_sto3g_converged.parquet
-rw-r--r--  1 alexm all      366M Apr 24 21:14 gdb9_sto3g_converged_atoms.parquet
-rw-r--r--  1 alexm all      330M Apr 10 11:13 gbd9_sto6g.csv
-rw-r--r--  1 alexm all      329M Apr 10 11:12 gbd9_sto3g.csv
-rw-r--r--  1 alexm all      255M Apr 12 09:11 gdb11_sto3g.csv
-rw-r--r--  1 alexm all      4.6M Apr 18 11:54 gdb11_sto3g_pyscf.parquet
drwxrwx--- 21 root  research  12K Apr 26 17:41 ..
drwxr-xr-x  3 alexm al

In [2]:
import pandas as pd 
df = pd.read_parquet("/a/scratch/alexm/research/mostrecent/gdb10_sto3g_converged_atoms.parquet")

In [None]:
df[:1]

In [20]:
df[10**6]

In [None]:
# atmmpt at making O(natm**2) algorithm 
import numpy as np 
import numpy 
from pyscf import gto 

GROUP_BOX_SIZE = 1.2
GROUP_BOUNDARY_PENALTY = 4.2
def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
    '''
    Parition the entire space into small boxes according to the input box_size.
    Group the grids against these boxes.
    '''
    import numpy 
    atom_coords = mol.atom_coords()
    boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY]
    # how many boxes inside the boundary
    boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int)
    tot_boxes = numpy.prod(boxes + 2)
    #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes)
    # box_size is the length of each edge of the box
    box_size = (boundary[1] - boundary[0]) / boxes
    frac_coords = (coords - boundary[0]) * (1./box_size)
    box_ids = numpy.floor(frac_coords).astype(int)
    box_ids[box_ids<-1] = -1
    box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0]
    box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1]
    box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2]
    rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3]
    return rev_idx.argsort(kind='stable')

from pyscf.dft import radi
import numpy 

def original_becke(g):
    '''Becke, JCP 88, 2547 (1988); DOI:10.1063/1.454033'''
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    g = (3 - g**2) * g * .5
    return g

def original_becke_v2(g):
    return ((3 - ((3 - ((3 - g**2) * g * 0.5)**2) * ((3 - g**2) * g * 0.5) * 0.5)**2) * ((3 - ((3 - g**2) * g * 0.5)**2) * ((3 - g**2) * g * 0.5) * 0.5) * 0.5)


# TODO: refactor this to be jnp
# will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit
def _get_partition(mol, atom_grids_tab,
                  radii_adjust=None, atomic_radii=radi.BRAGG_RADII,
                  becke_scheme=original_becke, concat=True):
    '''Generate the mesh grid coordinates and weights for DFT numerical integration.
    We can change radii_adjust, becke_scheme functions to generate different meshgrid.

    Kwargs:
        concat: bool
            Whether to concatenate grids and weights in return

    Returns:
        grid_coord and grid_weight arrays.  grid_coord array has shape (N,3);
        weight 1D array has N elements.
    '''
    if callable(radii_adjust) and atomic_radii is not None:
        f_radii_adjust = radii_adjust(mol, atomic_radii)
    else:
        f_radii_adjust = None
    atm_coords = numpy.asarray(mol.atom_coords() , order='C')
    atm_dist = gto.inter_distance(mol)

    from pyscf import lib

    def gen_grid_partition(coords):
        ngrids = coords.shape[0]
        #grid_dist = numpy.empty((mol.natm,ngrids))
        grid_dist = numpy.empty((mol.natm,ngrids))
        for ia in range(mol.natm):
            dc = coords - atm_coords[ia]
            grid_dist[ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc))
        pbecke = numpy.ones((mol.natm,ngrids))
        for i in range(mol.natm):
            for j in range(i):
                g = 1/atm_dist[i,j] * (grid_dist[i]-grid_dist[j])
                if f_radii_adjust is not None:
                    g = f_radii_adjust(i, j, g)
                #g = becke_scheme(g)# gets passed the one which returns None 
                g = original_becke(g)
                #print(g)
                pbecke[i] *= .5 * (1-g)
                pbecke[j] *= .5 * (1+g)
        return pbecke

    coords_all = []
    weights_all = []
    for ia in range(mol.natm):
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)]
        coords = coords + atm_coords[ia]
        pbecke = gen_grid_partition(coords)
        weights = vol * pbecke[ia] * (1./pbecke.sum(axis=0))
        coords_all.append(coords)
        weights_all.append(weights)

    if concat:
        coords_all = numpy.vstack(coords_all)
        weights_all = numpy.hstack(weights_all)
    return coords_all, weights_all

def get_partition(self, mol, atom_grids_tab=None,
                      radii_adjust=None, atomic_radii=radi.BRAGG_RADII,
                      becke_scheme=original_becke, concat=True):
        if atom_grids_tab is None:
            atom_grids_tab = self.gen_atomic_grids(mol)
        return _get_partition(mol, atom_grids_tab, radii_adjust, atomic_radii, becke_scheme, concat=concat)

from pyscf import lib


# this can all totally be done on ipu 
# now fully vectorized! 
def gen_grid_partition(coords, atm_dist, atm_coords, ia): 
  # coords is the only thing that changes .

  dcs       = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3)
  grid_dist = numpy.sqrt( numpy.einsum('nij,nij->ni', dcs, dcs) )
  diffs     = grid_dist.reshape(mol.natm, 1, -1) - grid_dist.reshape(1, mol.natm, -1)

  gss       = 1 / atm_dist.reshape( mol.natm, mol.natm, -1) * diffs.reshape(mol.natm, mol.natm, -1)  # diffs is the only ne that chagnes 
  gss       = original_becke(gss)   # can we compute this once? 
  gss1      = 0.5*(1-gss)
  gss2      = 0.5*(1+gss)

  # zero out j<i and then make them into ones. 
  # previously there was a nested forloop over (i,j), now it's vectorized. 
  gss1 = np.transpose( np.tril( np.transpose( gss1, (2,0,1)), k=-1), (1,2,0))
  gss2 = np.transpose( np.tril( np.transpose( gss2, (2,0,1)), k=-1), (1,2,0))
  O = np.triu(np.ones((mol.natm, mol.natm)), k=0).reshape( mol.natm, mol.natm, 1)
  gss1 += O 
  gss2 += O

  pbecke = numpy.ones((mol.natm,coords.shape[0]))
  pbecke *= np.prod(gss1, axis=1) 
  pbecke *= np.prod(gss2, axis=0)

  print(pbecke) #(20, 5130) -> return something (5130) normalzies os if we summed over all the atoms it'd be 1. 

  return pbecke[ia] * (1./pbecke.sum(axis=0))  # we only return 1 atm but compute all the others...  https://watermark.silverchair.com/2547_1_online.pdf?token=AQECAHi208BE49Ooan9kkhW_Ercy7Dm3ZL_9Cf3qfKAc485ysgAAAp8wggKbBgkqhkiG9w0BBwagggKMMIICiAIBADCCAoEGCSqGSIb3DQEHATAeBglghkgBZQMEAS4wEQQMq2FNi7X5mdCnF6-0AgEQgIICUv5lFXu6cPxCw8aDUX-oQbr_-9riOxoJAlRJEIEF7p8HrtasEp-apBCgErsWUgh3OWZpPes6B6FJBqyS-TkND4cgR7VuRJUVxG1uHPmOrx78lWOxF1meALbY8gwdzRb6Q62-eIEoXEOe54OXpNgBgqom2-Z1EDtiwncNd72Z7HAY0Fkp324entSRJcI_QkHD4RujCqCoEyclvsOlnXMAWFQMCN2NSHunHexaoBVqCjic3wL0tmEIlPCsCIsJAO5OBMmEc8GXniR1m_O1dfM8igjk6fNrNRLIMZUNSZKN48nQIYgL3CZs0m2IyP9y9_l_tWfm0fIKsoxhnbeK-S-fOaJ84w-CBRkWomJjpOzhE2bCU299US7qdasCqrQFyMkxKe6fgB8_aWftsVTIvYG6cqgcoLUkl1k6U5GpDFVp-9ofHAkpuQYUEFPnyrBNTHbEYC04vEKBDGgd02diGglB5-Rv37dWHDLfGHfJrSGgUPPXuycAEYOPAG-WxS0d6y6TU78xegLJxTX9mpIwmCRtLMmhrjFQsFFVqcGnH8Eb7eIt63D6cyqBeyezivhsbc2CadkmnnGXr2B4ajjI9t_PsoVqfm3CXWhN56ec5OecCTINMJDqtESugeP2yG0pySoO-ZDcdmzS1bX9WxO0nQWawa2og3rxrI-G-iSTPMu_Ey8V0DED8jAd1Wp6Pv7KNM0UxiK00ZMRcKXlc6mkf49pfo-20DvdufXLerhZDKEX7YLjZukt57LKGhhvBaIEFwK5FYmaKP-qPxRUGFK2mV1FzFYGWw
  # look like the initial article also says to do this. 
  

def _get_partition_v2(mol, atom_grids_tab, radii_adjust=None, atomic_radii=radi.BRAGG_RADII, concat=True):
    times = []
    times.append(time.time())
    atm_coords = numpy.asarray(mol.atom_coords() , order='C')
    atm_dist = gto.inter_distance(mol)
    times.append(time.time())

    # even if we do pbekce on cpu jax.jit compiling will make much fater! 
    # the same may be true for the Quax integrals! 


    # I think we can make an algorithm that precomputes a bunch of stuff, and then computes normalziation factor in lienar time! 
    # this bascially makes things from O(natmt**3) to O(natm**2)
    times.append(time.time())
    coords_all = []
    weights_all = []
    sum = 0
    # vectorize this for loop? 
    # anyway, I imagine this is ok for 
    for ia in range(mol.natm):  # each of thees etake 50 ms!  ;; this is O(natm**3)
        coords, vol = atom_grids_tab[mol.atom_symbol(ia)] # ahh, so this algorithm only works within atom types, so it's O(atom_types, natm**2) 

        print(coords.shape, atm_coords[ia]) # (5130, 3), 3
        coords = coords + atm_coords[ia] # (5130, 3);; this removes distance to only bee between grid points and everything but the current atom_ai. 
        dcs    = coords.reshape(1, -1, 3) - atm_coords.reshape(-1, 1, 3) # (1, 5130, 3) - (20, 1, 3) thonly difference here between runs is whether we include atm_coords[ia] above! 
        print(dcs.shape)

        # Main algorithm: 
        grid_dist = numpy.sqrt( numpy.einsum('nij,nij->ni', dcs, dcs) )
        diffs     = grid_dist.reshape(mol.natm, 1, -1) - grid_dist.reshape(1, mol.natm, -1)
        gss       = 1 / atm_dist.reshape( mol.natm, mol.natm, -1) * diffs.reshape(mol.natm, mol.natm, -1)  # diffs is the only ne that chagnes 
        gss       = original_becke(gss)   # can we compute this once? 
        gss1      = 0.5*(1-gss)
        gss2      = 0.5*(1+gss)
        gss1 = np.transpose( np.tril( np.transpose( gss1, (2,0,1)), k=-1), (1,2,0))
        gss2 = np.transpose( np.tril( np.transpose( gss2, (2,0,1)), k=-1), (1,2,0))
        O = np.triu(np.ones((mol.natm, mol.natm)), k=0).reshape( mol.natm, mol.natm, 1)
        gss1 += O 
        gss2 += O
        pbecke = numpy.ones((mol.natm,coords.shape[0]))
        pbecke *= np.prod(gss1, axis=1) 
        pbecke *= np.prod(gss2, axis=0)

        # output: 
        pbecke = pbecke[ia] * (1./pbecke.sum(axis=0)) # 5130 / np.sum((20, 5130), axis=0)

        #normalized = pbecke[ia] * (1./pbecke.sum(axis=0)) 
        weights = vol * pbecke 
        coords_all.append(coords)
        weights_all.append(weights)

    times.append(time.time()) 
    if concat:
        coords_all = numpy.vstack(coords_all)
        weights_all = numpy.hstack(weights_all)

    times.append(time.time())
    times = np.array(times)
    print("\t", np.sum(times[1:]-times[:-1]), np.around(times[1:]-times[:-1], 2))
    return coords_all, weights_all

def build_grid(self):
  mol = self.mol
  times = []

  times.append(time.time())
  atom_grids_tab            = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune)
  times.append(time.time())
  self.coords, self.weights = _get_partition_v2(mol, atom_grids_tab) # 800ms now, made it slower, lol. 
  times.append(time.time())
  idx = arg_group_grids(mol, self.coords) 
  times.append(time.time())
  self.coords  = self.coords[idx]
  self.weights = self.weights[idx]
  times.append(time.time())
  times = np.array(times)
  print("@", times[1:]- times[:-1])

  self.screen_index = self.non0tab = None

  return self



import time

for _ in range(100):  
  t0 =time.time() 
  g1            = pyscf.dft.gen_grid.Grids(mol) 
  g1.level      = 1 
  g1 = build_grid(g1)  
  t1 = time.time()

  g2            = pyscf.dft.gen_grid.Grids(mol) 
  g2.level      = 1 
  g2 = g2.build()
  t2 = time.time()


  ao              = mol.eval_gto('GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1', g1.coords, 4) 
  print(t2-t1, t1-t0, time.time()-t1)

  assert np.allclose(g1.coords, g2.coords)
  assert np.allclose(g1.weights, g2.weights)

(5180, 3) [-6.069884  0.50082   0.393247]
(20, 5180, 3)
(5120, 3) [-3.608663 -0.71908   0.54549 ]
(20, 5120, 3)
(5120, 3) [-1.728852  1.042963  1.701109]
(20, 5120, 3)


  gss       = 1 / atm_dist.reshape( mol.natm, mol.natm, -1) * diffs.reshape(mol.natm, mol.natm, -1)  # diffs is the only ne that chagnes
  gss       = 1 / atm_dist.reshape( mol.natm, mol.natm, -1) * diffs.reshape(mol.natm, mol.natm, -1)  # diffs is the only ne that chagnes


(5120, 3) [0.682019 0.442767 0.327844]
(20, 5120, 3)
(5120, 3) [-0.189712 -0.620789 -2.057385]
(20, 5120, 3)
(5120, 3) [-2.58479  -1.286769 -1.988916]
(20, 5120, 3)
(5120, 3) [3.019168 0.781612 1.068374]
(20, 5120, 3)
(5300, 3) [3.473776 1.797958 3.351123]
(20, 5300, 3)
(5120, 3) [ 5.161451  0.066148 -0.561195]
(20, 5120, 3)
(5204, 3) [ 6.571349 -1.74609   0.783382]
(20, 5204, 3)
(2484, 3) [-5.915439  1.969593 -0.928895]
(20, 2484, 3)
(2484, 3) [-7.505124 -0.779241 -0.039157]
(20, 2484, 3)
(2484, 3) [-3.742775 -2.505998  1.654412]
(20, 2484, 3)
(2484, 3) [-1.501425  0.710758  3.768463]
(20, 2484, 3)
(2484, 3) [-2.220073  3.076167  1.452334]
(20, 2484, 3)
(2484, 3) [ 0.962201 -0.871978 -3.767765]
(20, 2484, 3)
(2484, 3) [-3.689738 -2.137775 -3.555492]
(20, 2484, 3)
(2484, 3) [ 4.560978 -0.683934 -2.400397]
(20, 2484, 3)
(2484, 3) [ 6.389054  1.775628 -0.918618]
(20, 2484, 3)
(2484, 3) [ 7.936479 -0.812761  1.667917]
(20, 2484, 3)
	 1.123810052871704 [0.   0.   1.12 0.  ]
@ [0.00342584 1

ValueError: operands could not be broadcast together with shapes (76364,3) (76368,3) 

In [15]:
# Perform the calculation
import time

for _ in range(100): # ~ 100ms using all threads!  9.7ms using 4 threads ;; 35ms using 1 thread 2-5 ms using 20 threads ; ~3 ms using 16 threeds
  t0 =time.time()
  #mol.intor("int2e_sph") # this has #pragma OMP inside it's code 
  mol.intor("int1e_kin")   # this doesn't have OMP inside (the libcint code anyway)
  mol.intor("int1e_ovlp") # this also doesn't have OMp
  mol.intor("int1e_nuc") # this also doesn't have omp 
  print(time.time()-t0)

# so perhaps the omp code livs in pyscf? 

0.014376401901245117
0.006399869918823242
0.0041615962982177734
0.004090070724487305
0.003934621810913086
0.003927946090698242
0.003795146942138672
0.0038750171661376953
0.00420832633972168
0.003873586654663086
0.003798246383666992
0.004973888397216797
0.0037887096405029297
0.004019975662231445
0.003830432891845703
0.0039441585540771484
0.0037925243377685547
0.0037467479705810547
0.004841804504394531
0.00405430793762207
0.003945112228393555
0.003799915313720703
0.004167795181274414
0.0038340091705322266
0.003916025161743164
0.003950357437133789
0.00398564338684082
0.003941059112548828
0.0038042068481445312
0.0038263797760009766
0.0038230419158935547
0.003815889358520508
0.003909111022949219
0.0038077831268310547
0.0038864612579345703
0.0038471221923828125
0.003876924514770508
0.003877878189086914
0.003877401351928711
0.003955841064453125
0.0037293434143066406
0.003952503204345703
0.003779888153076172
0.003812074661254883
0.0038607120513916016
0.003923892974853516
0.003939390182495117
0