Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

In [1]:
%load_ext autoreload
# Re-import edited modules (e.g. the editable pyscf_ipu install) before executing each cell.
%autoreload 2

## Dependencies and configuration

Install experimental JAX for IPU (and its add-ons). Both are pulled in by the `ipu` extra of `pyscf-ipu`.  

Install `pyscf-ipu`:

In [2]:
# PySCF IPU dependencies: editable install of the parent directory (`..`) with its `ipu` extra.
# `%pip` (not `!pip`) so packages land in the environment of the running kernel.
%pip install -e "..[ipu]"

Looking in indexes: https://<user>:****@<private-index>/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://<user>:****@<private-index>/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Looking in links: https://graphcore-research.github.io/jax-experimental/wheels.html
Collecting tessellate-ipu@ git+https://github.com/graphcore-research/tessellate-ipu.git@main (from -r ../requirements_ipu.txt (line 12))
  Cloning https://github.com/graphcore-research/tessellate-ipu.git (to revision main) to /tmp/pip-install-m_6gehta/tessellate-ipu_3d246e1e1a604be897536c721c6b8e8c
  Running command git clone --filter=blob:none --quiet https://github.com/graphcore-research/tessellate-ipu.git /tmp/pip-install-m_6gehta/tessellate-ipu_3d246e1e1a604be897536c721c6b8e8c
  Resolved https://github.com/graphcore-research/tessellate-ipu.git to

# Download and preprocess GDB 11 dataset

In [13]:
load = False
gdb_filename = "../gdb/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")
if load:
  # Download and extract GDB11 dataset.
  !wget -p -O ./gdb/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./gdb/gdb11.tgz --directory ./gdb/

  from  gdb import sortgdb

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)

import os
os.system(f'ls -l {out_filename}')
assert os.path.getsize(out_filename) == 6985727

-rw-r--r-- 1 awf awf 6985727 Sep 14 17:55 ../gdb/gdb11_size09_sorted.csv


In [14]:
import os

# PySCF IPU / JAX IPU configuration. Everything is passed through environment variables,
# which are set here before JAX and TessellateIPU are first imported below.
#  - JAX_IPU_USE_MODEL: run on the IPU Model rather than IPU hardware
#    (NOTE(review): presumably Poplar's software emulator -- confirm; use "False" on real IPUs).
#  - JAX_IPU_DEVICE_COUNT: use a single device per process.
#  - TF_POPLAR_FLAGS: location of the JAX/XLA IPU compilation (executable) cache.
os.environ.update({
    "JAX_IPU_USE_MODEL": "True",
    "JAX_IPU_DEVICE_COUNT": "1",
    "TF_POPLAR_FLAGS": """
  --executable_cache_path=/tmp/ipu-ef-cache
""",
})

# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu

# Create a DFT dataset using PySCF IPU

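In the following example, we use only a single IPU (`JAX_IPU_DEVICE_COUNT=1`). Multiple IPUs can be used simply by launching several PySCF IPU processes instead of a single one. Note that the configuration below currently sets `args.backend = 'cpu'`; change the backend to run the computation on IPU.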

In [16]:
# Equivalent to command line:
# python density_functional_theory.py  -generate  -save  -fname <dataset_name>
#        -level 0  -plevel 0  -num_conformers <num_conformers>
#        -gdb 9  -float32

import time

import jax
from pyscf_ipu.dft import get_args, process_args, _eigh, recompute

# Start from the default command-line arguments, then override those used for dataset generation.
args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0
args.num_conformers = 64  # TODO 1000

args.gdb = 9
args.float32 = True

process_args(args)

# float64 runs always enable JAX x64 mode; float32/float16 runs only do so when `enable64` is set.
low_precision = args.float32 or args.float16
if args.enable64 or not low_precision:
    jax.config.update('jax_enable_x64', True)

# Numerical thresholds (currently identical for every precision).
EPSILON_B3LYP = 1e-20
CLIP_RHO_MIN = 1e-9
CLIP_RHO_MAX = 1e12

if args.nan:
    jax.config.update("jax_debug_nans", True)

backend = args.backend
eigh = _eigh

# Sorted GDB-11 CSVs written by the preprocessing step, keyed by `args.gdb`
# (e.g. 9 -> ../gdb/gdb11_size09_sorted.csv).
GDB_SORTED_CSV = {n: f"../gdb/gdb11_size{n:02d}_sorted.csv" for n in (7, 8, 9, 10)}
# Single-molecule example data for quick testing, replicated `num_conformers` times.
EXAMPLE_SMILES = {6: "c1ccccc1", 5: "CCCCC", 4: "CCCC"}

t0 = time.time()
print("loading gdb data")

if args.gdb in GDB_SORTED_CSV:
    # One SMILES string per line. Empty lines are dropped, notably the one that would
    # otherwise come from the file's trailing newline.
    with open(GDB_SORTED_CSV[args.gdb], "r") as f:
        args.smiles = [line for line in f.read().splitlines() if line]
elif args.gdb in EXAMPLE_SMILES:
    args.smiles = [EXAMPLE_SMILES[args.gdb]] * args.num_conformers

print("DONE!", time.time()-t0)


[BASIS] STO-3G
loading gdb data


FileNotFoundError: [Errno 2] No such file or directory: 'gdb/gdb11_size09_sorted.csv'

In [38]:
from rdkit import Chem
# `Chem.AllChem` is a submodule; it is not reliably reachable through `from rdkit import Chem`.
from rdkit.Chem import AllChem

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft

print("Length GDB: ", len(args.smiles))

if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

# Scan at most 1000 entries starting at `args.id`. Stop at the first SMILES that RDKit can
# parse and embed in 3D, and build the atom string that is passed to `recompute` below.
string = None  # reset explicitly so a value left over from a previous run is never reused
start = int(args.id)
for i in range(start, min(start + 1000, len(args.smiles))):
    smile = args.smiles[i]

    print(smile)

    mol = Chem.MolFromSmiles(smile)
    if mol is None: continue  # RDKit could not parse this SMILES
    if not args.nohs: mol = Chem.AddHs(mol, explicitOnly=False)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]

    embed_status = AllChem.EmbedMolecule(mol)
    if embed_status == -1: continue  # 3D embedding failed

    locs = mol.GetConformer().GetPositions() * angstrom_to_bohr
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print(string)
    break

if string is None:
    raise RuntimeError(f"No embeddable molecule found in args.smiles[{start}:{start + 1000}]")

recompute(args, None, 0, 0, our_fun=jax_dft, str=string)

Length GDB:  64
c1ccccc1
C  2.494448 0.606609 -0.073509; C  1.814968 -1.916764 -0.032200; C  -0.733273 -2.548026 0.043001; C  -2.489353 -0.601320 0.073324; C  -1.800524 1.916390 0.031808; C  0.738779 2.542368 -0.043100; H  4.476328 1.116906 -0.132153; H  3.274478 -3.352021 -0.059057; H  -1.245857 -4.518913 0.074777; H  -4.466875 -1.107240 0.131838; H  -3.293532 3.343278 0.059642; H 1.230413 4.518733 -0.074370; 


args(generate=True, num_conformers=64, nohs=False, verbose=False, choleskycpu=False, resume=False, density_mixing=False, skip_minao=False, num=10, id=0, its=20, step=1, spin=0, str='', numerror=False, ipumult=False, skippyscf=False, skipus=False, float32=True, float16=False, basis='STO-3G', xc='b3lyp', skip=0, backend='cpu', benchmark=False, nan=False, skipdiis=False, skipeigh=False, methane=False, H=False, forloop=False, he=False, level=0, plevel=0, C=-1, gdb=6, skiperi=False, randeri=False, save=True, fname='notebook_dataset', multv=2, intv=1, randomSeed=43, scale_eri=1, scale_w=1, scale_ao=1, scale_overlap=1, scale_cholesky=1, scale_ghamil=1, scale_eigvects=1, scale_sdf=1, scale_vj=1, scale_errvec=1, sk=[-2], debug=False, jit=False, enable64=False, rattled_std=0, profile=False, pyscf=False, uniform_pyscf=-1, threads=1, threads_int=1, split=[1, 16], limit=-1, seperate=False, gname='', checkc=False, geneigh=False, smiles=['c1ccccc1', 'c1ccccc1', 'c1ccccc1', 'c1ccccc1', 'c1ccccc1', 'c1

# Loading & visualizing generated data

If the dataset is being generated in the background (e.g. by a separate PySCF IPU process), we can already load the data that has been written so far.

In [None]:
import pandas as pd

In [None]:
# Output DFT dataset is a compressed CSV file.
# NOTE: it may take a couple of minutes before the file is generated.
# NOTE(review): no `dft_process` object is defined anywhere in this notebook, so reading
# `dft_process.path` fails with NameError on a fresh kernel. Set DATASET_PATH to the gzip CSV
# written by the generation run (its location is decided by pyscf_ipu.dft -- TODO confirm).
# When a `dft_process` object exists, its `.path` is used as a fallback.
DATASET_PATH = None  # e.g. "path/to/generated_dataset.csv.gz"
if DATASET_PATH is None and "dft_process" in globals():
    DATASET_PATH = dft_process.path
if DATASET_PATH is None:
    raise FileNotFoundError(
        "No DFT dataset path: set DATASET_PATH to the gzip CSV produced by the generation run "
        "(`dft_process` is not defined in this notebook)."
    )
df = pd.read_csv(DATASET_PATH, compression="gzip")

In [None]:
# Preview the first rows only; use `df.shape` to see how many rows have been generated so far.
df.head()

In [None]:
# The `hlgap` column of the generated dataset (presumably the HOMO-LUMO gap).
df.loc[:, "hlgap"]