Compute "ERI redundancies" on IPU #63

Open
AlexanderMath opened this issue Sep 1, 2023 · 3 comments

@AlexanderMath
Contributor

AlexanderMath commented Sep 1, 2023

Issue #55 outlines several redundancies in the electron repulsion integrals (ERI). We want to determine, on the IPU, which integrals are above a threshold (and therefore need to be computed).
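For reference, here is a sketch of the standard Schwarz screening bound that the comments below rely on. It uses chemists' notation $(ab|cd)$ and a threshold $\tau$; the symbols $I_{\max}$ and $\tau$ are only notation here, matching `I_max` and `tolerance` in the code further down:

$$
|(ab|cd)| \le \sqrt{(ab|ab)}\,\sqrt{(cd|cd)} \le \sqrt{(ab|ab)\,I_{\max}},
\qquad I_{\max} = \max_{cd}\,(cd|cd).
$$

So every integral involving the pair $ab$ is below $\tau$ whenever $(ab|ab)\,I_{\max} < \tau^2$, and that pair can be skipped. Only the $N^2$ diagonal integrals $(ab|ab)$ are needed to make this decision.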

Tasks

  1. Write Python for-loop code that takes O(N^2) time, following this post, and computes the right thing (ideally 10-20 lines of code in a single file, to minimize complexity).
  2. Port that code to JAX ("jaxify" it).
  3. Change the backend from CPU to IPU and, if needed, optimize the memory layout.
@AlexanderMath
Contributor Author

AlexanderMath commented Sep 4, 2023

@awf The following code checks the Cauchy-Schwarz inequality |(ab|cd)| <= sqrt(|(ab|ab)|) * sqrt(|(cd|cd)|) for every integral.

import pyscf
import numpy as np 

mol = pyscf.gto.Mole(atom=[["C", (0,0,0)], ["C", (1,2,3)]])
mol.build()
ERI = mol.intor("int2e_sph")  # full (N, N, N, N) tensor of (ab|cd); default aosym="s1"
N = mol.nao_nr()              # number of atomic orbitals

for a in range(N):
  for b in range(N):
    for c in range(N):
      for d in range(N):
        abcd      = np.abs(ERI[a,b,c,d])
        sqrt_abab = np.sqrt(np.abs(ERI[a,b,a,b]))
        sqrt_cdcd = np.sqrt(np.abs(ERI[c,d,c,d]))

        print(abcd, sqrt_abab*sqrt_cdcd)
        # Cauchy-Schwarz bound, with a 1e-9 absolute tolerance for round-off
        assert abcd <= sqrt_abab*sqrt_cdcd + 1e-9
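As a possible first step toward the "jaxify" task, the four nested loops above can be replaced by one broadcasted comparison. This is a minimal NumPy sketch, assuming `eri` is the full (N, N, N, N) tensor returned by `mol.intor("int2e_sph")`. The function name `schwarz_violations` is hypothetical, and `einsum("abab->ab", ...)` is used to read off the N^2 diagonal integrals (ab|ab):

```python
import numpy as np


def schwarz_violations(eri, atol=1e-9):
    """Count quartets (a, b, c, d) with |(ab|cd)| > sqrt((ab|ab)) * sqrt((cd|cd)) + atol.

    eri is assumed to be the full (N, N, N, N) ERI tensor (aosym="s1").
    """
    # diag[a, b] = |(ab|ab)|; abs guards against tiny negative round-off
    diag = np.abs(np.einsum("abab->ab", eri))
    q = np.sqrt(diag)
    # bound[a, b, c, d] = sqrt((ab|ab)) * sqrt((cd|cd)), built by broadcasting
    bound = q[:, :, None, None] * q[None, None, :, :]
    return int(np.count_nonzero(np.abs(eri) > bound + atol))
```

With PySCF available, `schwarz_violations(mol.intor("int2e_sph"))` should return 0 if the inequality holds everywhere. Unlike the loop version, this allocates a second N^4 array (`bound`), which is fine for small N but is the kind of memory-layout cost task 3 may need to revisit.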

Note: Computing the N^2 diagonal entries (ab|ab) at compile time will take under 1 ms using int2e_sph_cpu.cpp (if we add back the PRAGMA_OMP parallelization). This might be useful in certain scenarios.

Note: the IPU code int2e_sph.cpp can run DFT for a fixed N without recompiling. I think we will be able to compute the top 2% of integrals without recompiling: that is, compile one graph once, which can then handle any sparsity pattern with at most 2% nonzeros (spending FLOPs as if exactly 2% were nonzero).
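Here is a minimal sketch of the "compile once, any sparsity pattern up to a budget" idea. It assumes the kernel consumes a list of (i, j, k, l) quartets. The list is padded to a fixed capacity so the compiled graph always sees the same static shape, and a mask marks which slots hold real quartets. `pad_to_capacity` is a hypothetical helper, not part of int2e_sph.cpp:

```python
import numpy as np


def pad_to_capacity(indices, capacity, fill=0):
    """Pad a variable-length list of quartets to a fixed (capacity, 4) array.

    Hypothetical helper: the compiled graph always sees shape (capacity, 4),
    so a new sparsity pattern changes only the data, not the graph.
    """
    indices = np.asarray(indices, dtype=np.int32).reshape(-1, 4)
    k = indices.shape[0]
    if k > capacity:
        raise ValueError(f"{k} quartets exceed capacity {capacity}")
    padded = np.full((capacity, 4), fill, dtype=np.int32)
    padded[:k] = indices
    # True for real quartets; padded slots are still computed but must be discarded
    mask = np.zeros(capacity, dtype=bool)
    mask[:k] = True
    return padded, mask
```

The capacity would be chosen once, for example as 2% of the number of unique quartets, and the padded slots cost FLOPs exactly as described above. The fill quartet (0, 0, 0, 0) is a valid index, so results for padded slots have to be zeroed with the mask rather than ignored.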

AlexanderMath assigned mihaipgc and unassigned hatemhelal on Sep 22, 2023
@mihaipgc
Collaborator

The O(N^2) code from the post looks like this:

import pyscf
import numpy as np 

mol = pyscf.gto.Mole(atom=[["C", (0,0,0)], ["C", (10,2,3)]])
mol.build()
ERI = mol.intor("int2e_sph", aosym="s1")
N = mol.nao_nr()

tolerance = 1e-9

ERI[np.abs(ERI)<tolerance] = 0 
true_nonzero_indices = np.nonzero( ERI.reshape(-1) )[0]
true_nonzero_indices_4d = [np.unravel_index(c, (N, N, N, N)) for c in true_nonzero_indices]

screened_indices_4d = []

# find max value
I_max = 0
for a in range(N):
  for b in range(N):
    abab = np.abs(ERI[a,b,a,b])
    if abab > I_max:
        I_max = abab

# collect candidate pairs for s1
considered_indices = []
for a in range(N):
    for b in range(N):
        abab = np.abs(ERI[a,b,a,b])
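        # Cauchy-Schwarz gives |(ab|cd)| <= sqrt(abab*I_max), so the rigorous test is
        # abab*I_max >= tolerance**2; comparing against tolerance prunes more aggressively
        if abab*I_max>=tolerance: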
            considered_indices.append((a, b))

# generate s1 indices
for ab in considered_indices:
    a, b = ab
    for cd in considered_indices:
        c, d = cd
        screened_indices_4d.append((a, b, c, d))

        
print('N', N)
print('I_max', I_max)
print('ERI.reshape(-1).shape', ERI.reshape(-1).shape)
print('len(considered_indices)', len(considered_indices))
print('len(screened_indices_4d)', len(screened_indices_4d))
print('len(true_nonzero_indices_4d)', len(true_nonzero_indices_4d))

check_s1 = [(item in screened_indices_4d) for item in true_nonzero_indices_4d]
assert np.array(check_s1).all()
print('PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_4d]!')

Output:

N 10
I_max 3.5419481332225047
ERI.reshape(-1).shape (10000,)
len(considered_indices) 50
len(screened_indices_4d) 2500
len(true_nonzero_indices_4d) 1468
PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_4d]!
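The pair-screening loops above can also be written without Python loops. This sketch follows the `xnp=np` convention used by `get_i_j` later in this thread, so that passing `jax.numpy` as `xnp` is intended to work for the screening step. That is an assumption and has not been tested here. Under `jax.jit`, the `nonzero` step would also need a static size, since `jnp.nonzero` takes `size=` and `fill_value=` arguments. One more difference from the loop: the sketch uses the rigorous `tolerance**2` test, so it may keep more pairs than the 50 reported above. The function names are hypothetical:

```python
import numpy as np


def screen_pairs_s1(eri, tolerance=1e-9, xnp=np):
    """Return a boolean (N, N) mask of pairs (a, b) to keep, plus I_max.

    A pair is kept when (ab|ab) * I_max >= tolerance**2, i.e. when the
    Cauchy-Schwarz bound sqrt((ab|ab) * I_max) can reach the threshold.
    """
    diag = xnp.abs(xnp.einsum("abab->ab", eri))  # (ab|ab), shape (N, N)
    i_max = xnp.max(diag)
    keep = diag * i_max >= tolerance**2
    return keep, i_max


def screened_quartets_s1(keep, xnp=np):
    """All (kept pair) x (kept pair) combinations as an (M*M, 4) array."""
    pairs = xnp.stack(xnp.nonzero(keep), axis=1)  # (M, 2) rows (a, b), loop order
    m = pairs.shape[0]
    ab = xnp.repeat(pairs, m, axis=0)  # a, b fixed while c, d vary
    cd = xnp.tile(pairs, (m, 1))
    return xnp.concatenate([ab, cd], axis=1)
```

The rows come out in the same order as the nested `for ab ... for cd ...` loops, so the result can be compared directly with `screened_indices_4d`.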

Note: this version and the one above may not "PASS" for every atom configuration, but will likely be close enough. This also means the comparison between "screened" and "true" should not be treated as exact; a better test might be to compute the error against the full ERI tensor (will do).
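Here is one possible shape for that test. It keeps only the screened quartets, treats everything else as zero, and reports the largest absolute deviation from the full tensor. This is a sketch under the assumption that the screened indices are available as an integer array of (a, b, c, d) rows. The name `screening_error` is hypothetical:

```python
import numpy as np


def screening_error(eri, quartets):
    """Max absolute error from computing only the screened quartets.

    eri: full (N, N, N, N) tensor; quartets: (K, 4) integer (a, b, c, d) rows.
    Entries not listed are treated as zero, as a screened kernel would return.
    """
    quartets = np.asarray(quartets, dtype=np.int64).reshape(-1, 4)
    approx = np.zeros_like(eri)
    a, b, c, d = quartets.T
    approx[a, b, c, d] = eri[a, b, c, d]
    return float(np.max(np.abs(eri - approx)))
```

For example, `screening_error(ERI, np.array(screened_indices_4d))` should be at most about `tolerance` if the screening is safe. A relative error, or an error in a downstream quantity, could be substituted if the absolute one turns out to be too strict.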

@mihaipgc
Copy link
Collaborator

But it turns out we can do better if we build the 8-fold permutational symmetry of the ERIs (PySCF's aosym="s8") into the screening strategy above:

import pyscf
import numpy as np 

def get_i_j(val, xnp=np, dtype=np.uint64):
    i = (xnp.sqrt(1 + 8*val.astype(dtype)) - 1)//2 # no need for floor, integer division acts as floor. 
    j = (((val - i) - (i**2 - val))//2)
    return i, j

def c2ijkl(c):
    ij, kl = get_i_j(c)
    i, j = get_i_j(ij)
    k, l = get_i_j(kl)
    return (int(i), int(j), int(k), int(l))

mol = pyscf.gto.Mole(atom=[["C", (0,0,0)], ["C", (10,2,3)]])
mol.build()
ERI = mol.intor("int2e_sph", aosym="s1")
ERI_s8 = mol.intor("int2e_sph", aosym="s8")
N = mol.nao_nr()

tolerance = 1e-9

ERI[np.abs(ERI)<tolerance] = 0 
true_nonzero_indices = np.nonzero( ERI.reshape(-1) )[0]
true_nonzero_indices_4d = [np.unravel_index(c, (N, N, N, N)) for c in true_nonzero_indices]

ERI_s8[np.abs(ERI_s8)<tolerance] = 0
true_nonzero_indices_s8 = np.nonzero( ERI_s8.reshape(-1) )[0]
true_nonzero_indices_s8_4d = [c2ijkl(c) for c in true_nonzero_indices_s8]

screened_indices_s8_4d = []

# find max value
I_max = 0
for a in range(N):
  for b in range(N):
    abab      = np.abs(ERI[a,b,a,b])
    if abab > I_max:
        I_max = abab

# collect candidate pairs for s8
considered_indices = []
for a in range(N):
    for b in range(a, N):
        abab = np.abs(ERI[a,b,a,b])
        # the rigorous Cauchy-Schwarz test would be abab*I_max >= tolerance**2 (this is more aggressive)
        if abab*I_max>=tolerance:
            considered_indices.append((a, b)) # collect candidate pairs for s8

# generate s8 indices
for ab in considered_indices:
    a, b = ab
    for cd in considered_indices:
        c, d = cd
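        # emit (i,j,k,l) = (d,c,b,a) with i>=j and k>=l, matching PySCF's s8 packing;
        # when b == d this also keeps non-canonical entries with c < a, which are
        # redundant copies of canonical quartets emitted elsewhere in this loop
        if b<=d:
            screened_indices_s8_4d.append((d, c, b, a))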

        
print('N', N)
print('I_max', I_max)
print('ERI.reshape(-1).shape', ERI.reshape(-1).shape)
print('ERI_s8.shape', ERI_s8.shape)
print('len(considered_indices)', len(considered_indices))
print('len(screened_indices_s8_4d)', len(screened_indices_s8_4d))
print('len(true_nonzero_indices_s8_4d)', len(true_nonzero_indices_s8_4d))

check_s8 = [(item in screened_indices_s8_4d) for item in true_nonzero_indices_s8_4d]
assert np.array(check_s8).all()
print('PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_s8_4d]!')

Output:

N 10
I_max 3.5419481332225047
ERI.reshape(-1).shape (10000,)
ERI_s8.shape (1540,)
len(considered_indices) 30
len(screened_indices_s8_4d) 505
len(true_nonzero_indices_s8_4d) 291
PASSED [(item in screened_indices_4d) for item in true_nonzero_indices_s8_4d]!
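A possible vectorized variant emits each canonical s8 quartet exactly once. It sorts the kept pairs by their lower-triangular compound index (the same i*(i+1)/2 + j formula that `get_i_j` above inverts) and takes every (row >= column) combination of pairs, which gives i >= j, k >= l, and ij >= kl. This is a sketch assuming pairs are given as (i, j) with i >= j. The loop's `considered_indices` store (a, b) with a <= b, so they would be passed as `np.array(considered_indices)[:, ::-1]`. The helper names are hypothetical:

```python
import numpy as np


def tril_index(i, j):
    # lower-triangular compound index for i >= j (the inverse of get_i_j)
    return i * (i + 1) // 2 + j


def screened_quartets_s8(keep_pairs, xnp=np):
    """Canonical s8 quartets (i, j, k, l) with i >= j, k >= l and ij >= kl.

    keep_pairs: (P, 2) integer array of kept pairs (i, j) with i >= j.
    Each unordered pair of pairs is emitted exactly once: P*(P+1)/2 rows.
    """
    pairs = xnp.asarray(keep_pairs).reshape(-1, 2)
    ij = tril_index(pairs[:, 0], pairs[:, 1])
    pairs = pairs[xnp.argsort(ij)]   # ascending compound index
    p = pairs.shape[0]
    r, c = xnp.tril_indices(p)       # r >= c, hence ij(r) >= ij(c)
    return xnp.concatenate([pairs[r], pairs[c]], axis=1)
```

For the 30 pairs reported above, this would give 30*31/2 = 465 quartets instead of 505. The difference is the c < a entries that the `b<=d` test also emits when b == d.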

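This directly computes a list much closer to the set of distinct nonzero ERI values we are aiming for: 505 screened quartets for 291 true nonzeros, compared with 2500 screened for 1468 true nonzeros in the s1 version. The testing caveat above applies here as well.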
