Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dependencies and configuration

Install the experimental JAX build for IPU (and its add-ons).

Install `pyscf-ipu`:

In [1]:
# PySCF IPU dependencies: editable install of the parent directory ("..")
# with its "ipu" extra, using %pip so it targets the kernel's environment.
%pip install -e "..[ipu]"

Looking in indexes: https://awf%40graphcore.ai:****@artifactory.sourcevertex.net:443/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://awf%40graphcore.ai:****@artifactory.sourcevertex.net:443/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Looking in links: https://graphcore-research.github.io/jax-experimental/wheels.html
Collecting tessellate-ipu@ git+https://github.com/graphcore-research/tessellate-ipu.git@main (from -r ../requirements_ipu.txt (line 12))
  Cloning https://github.com/graphcore-research/tessellate-ipu.git (to revision main) to /tmp/pip-install-vg38zvhk/tessellate-ipu_3b9064933d6a400fad6bde34e593ba46
  Running command git clone --filter=blob:none --quiet https://github.com/graphcore-research/tessellate-ipu.git /tmp/pip-install-vg38zvhk/tessellate-ipu_3b9064933d6a400fad6bde34e593ba46
  Resolved https://github.com/graphcore-research/tessellate-ipu.git to

# Download and preprocess the GDB11 dataset

In [2]:
load = False
gdb_filename = "../gdb/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")
if load:
  # Download and extract GDB11 dataset.
  !wget -p -O ./gdb/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./gdb/gdb11.tgz --directory ./gdb/

  from  gdb import sortgdb

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)

import os
os.system(f'ls -l {out_filename}')
assert os.path.getsize(out_filename) == 6985727

-rw-r--r-- 1 awf awf 6985727 Sep 14 17:55 ../gdb/gdb11_size09_sorted.csv


In [3]:
import os

# Configure the IPU runtime through environment variables. These are set
# ahead of the JAX import below, matching the original ordering.
os.environ.update({
    # Enable the "IPU model" setting.
    # NOTE(review): presumably selects an emulated IPU rather than real
    # hardware -- confirm against the JAX-for-IPU documentation.
    "JAX_IPU_USE_MODEL": "True",
    # PySCF IPU setup: use a single device per process.
    "JAX_IPU_DEVICE_COUNT": "1",
    # JAX/XLA IPU compilation cache.
    "TF_POPLAR_FLAGS": "\n  --executable_cache_path=/tmp/ipu-ef-cache\n",
})

# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu

# Create a DFT dataset using PySCF IPU

In the following example, we use only a single IPU. Multiple IPUs can be used by simply launching a collection of PySCF IPU processes instead of a single one.

In [4]:
# Equivalent to command line:
# python density_functional_theory.py  -generate  -save  -fname <dataset_name>
#        -level 0  -plevel 0  -num_conformers <num_conformers>
#        -gdb 9  -float32

import time
from pyscf_ipu.dft import get_args, process_args, _eigh, recompute

# Start from the default arguments (empty argv) and override them in code.
args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0
args.num_conformers = 64 # TODO 1000

args.gdb = 9
args.float32 = True

process_args(args)

if args.nan:
    jax.config.update("jax_debug_nans", True)

backend = args.backend
eigh = _eigh

t0 = time.time()
print("loading gdb data")

# GDB subsets available as pre-sorted CSV files, one entry per line.
gdb_sizes_on_disk = (7, 8, 9, 10)
if args.gdb in gdb_sizes_on_disk:
    # Read through a context manager so the file handle is closed promptly.
    with open(f"../gdb/gdb11_size{args.gdb:02d}_sorted.csv", "r") as gdb_file:
        args.smiles = gdb_file.read().split("\n")

# Used as example data for quick testing: one molecule repeated num_conformers times.
example_smiles = {6: "c1ccccc1", 5: "CCCCC", 4: "CCCC"}
if args.gdb in example_smiles:
    args.smiles = [example_smiles[args.gdb]] * args.num_conformers


print("DONE!", time.time()-t0)


[BASIS] STO-3G
loading gdb data
DONE! 0.0337986946105957


In [6]:
from rdkit import Chem
# Import the submodule explicitly: `Chem.AllChem` is only available as an
# attribute if something else has already imported `rdkit.Chem.AllChem`.
from rdkit.Chem import AllChem

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft

print("Length GDB: ", len(args.smiles))

if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

# Scan at most 1000 entries starting at args.id and keep the first molecule
# that RDKit can both parse and embed in 3D.
string = None
first_index = int(args.id)
for i in range(first_index, min(first_index+1000, len(args.smiles))):
    smile = args.smiles[i]

    print(smile)

    b = Chem.MolFromSmiles(smile)
    # MolFromSmiles returns None for SMILES it cannot parse; skip those.
    if b is None: continue
    if not args.nohs: b = Chem.AddHs(b, explicitOnly=False)
    atoms = [atom.GetSymbol() for atom in b.GetAtoms()]

    # EmbedMolecule returns -1 when no conformer could be generated.
    e = AllChem.EmbedMolecule(b)
    if e == -1: continue

    # Scale the conformer coordinates by angstrom_to_bohr before formatting.
    locs = b.GetConformer().GetPositions() * angstrom_to_bohr
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print(string)
    break

# Fail with a clear message instead of a NameError on `string` below.
if string is None:
    raise RuntimeError(
        f"No molecule could be parsed and embedded among entries starting at index {first_index}."
    )

recompute(args, None, 0, 0, our_fun=jax_dft, str=string)

Length GDB:  444285
FC(F)(C#N)C(=O)C#N
F  -0.117612 2.317325 -2.503908; C  0.968662 1.271562 -0.342472; F  2.409504 2.965082 0.942392; C  2.610618 -0.880487 -1.051059; N  3.903357 -2.565436 -1.577449; C  -1.139314 0.305476 1.216695; O  -1.381402 1.142178 3.410163; C  -2.892418 -1.531279 0.316869; N -4.361394 -3.024422 -0.411232; 
	 444285
>>>  27767 55533
_27767_55533
HIT
8_GDB9_f32True_grid0_backendcpu_27767_55533
NO!
HIT
6_GDB9_f32True_grid0_backendcpu_27767_55533
NO!
HIT
5_GDB9_f32True_grid0_backendcpu_27767_55533
NO!
HIT
10_GDB9_f32False_grid0_backendcpu_27767_55533
(94, 16)
HIT
7_GDB9_f32True_grid0_backendcpu_27767_55533
NO!
HIT
9_GDB9_f32True_grid0_backendcpu_27767_55533
NO!
11_GDB9_f32False_grid0_backendcpu_27767_55533


  0%|          | 0/27767 [00:00<?, ?it/s]

[PAD] Last molecule had grisize=13416 we're using 14757. 
[CC(C(F)C=O)=C(F)F]
[conformers] 64


  else: E_xc, V_xc, E_coulomb, vj, vk = xc((args, indxs), density_matrix.astype(np.float64), dms.astype(np.float64), cycle, ao.astype(np.float64), electron_repulsion.astype(np.float64), weights.astype(np.float64), vj.astype(np.float64), vk.astype(np.float64), hyb, _num_calls)


Matmul
25.0 (50, 50, 50, 50) 15
11.8056 (4, 14757, 50) 20
Matmul


  else: E_xc, V_xc, E_coulomb, vj, vk = xc((args, indxs), density_matrix.astype(np.float64), dms.astype(np.float64), cycle, ao.astype(np.float64), electron_repulsion.astype(np.float64), weights.astype(np.float64), vj.astype(np.float64), vk.astype(np.float64), hyb, num_calls)
[1 / 64] Hs=    5 -560.141430 2545.8 0.0 0.0 0.1 2.5 0.3 678.0 0.1 52.6 2.5 0.3 0.5 0.2 0.0 0.0 3.6 0.1 0.0 0.0 0.2 0.4 1.5 0.2 2.5 3291.4 [1 ; 0]:   0%|          | 0/27767 [00:04<?, ?it/s]

Matmul
25.0 (50, 50, 50, 50) 15
11.8056 (4, 14757, 50) 20
Matmul


[1 / 64] Hs=    5 -560.141430 2545.8 0.0 0.0 0.1 2.5 0.3 678.0 0.1 52.6 2.5 0.3 0.5 0.2 0.0 0.0 3.6 0.1 0.0 0.0 0.2 0.4 1.5 0.2 2.5 3291.4 [1 ; 0]:   0%|          | 0/27767 [00:06<?, ?it/s]


KeyboardInterrupt: 

# Loading & visualizing generated data

As the dataset is being created in the background, we can load the data that has already been generated.

In [7]:
import pandas as pd

In [8]:
# Output DFT dataset is a compressed CSV file.
# NOTE: it may take a couple of minutes before the file is generated.
# NOTE(review): the path argument of `read_csv` was missing here, which made
# the cell a syntax error. The recorded output refers to a `dft_process` name
# that is not defined anywhere in this notebook, so the real location cannot
# be recovered from here -- set it explicitly.
dft_dataset_path = None  # TODO: path to the gzip-compressed CSV written by the generation step
if dft_dataset_path is None:
    raise ValueError(
        "Set `dft_dataset_path` to the generated gzip-compressed CSV before running this cell."
    )
df = pd.read_csv(dft_dataset_path, compression="gzip")

NameError: name 'dft_process' is not defined

In [None]:
# Display the DataFrame loaded from the generated dataset.
df

In [None]:
# HLgap data: the "hlgap" column of the loaded dataset
# (presumably the HOMO-LUMO gap -- confirm against the dataset schema).
df["hlgap"]