Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

In [1]:
# Reload edited Python modules automatically, so changes to the editable
# pyscf_ipu install are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dependencies and configuration

Install the experimental JAX build for IPU (and its add-ons).

Install `pyscf-ipu`:

In [4]:
# PySCF IPU dependencies: editable (-e) install of the parent directory
# (the pyscf-ipu repository) together with its "cpu" extra.
%pip install -e "..[cpu]"

Looking in indexes: https://awf%40graphcore.ai:****@artifactory.sourcevertex.net:443/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Obtaining file:///home/awf/dev/gc-gh-public/pyscf-ipu
  Preparing metadata (setup.py) ... [?25ldone
Installing collected packages: pyscf-ipu
  Attempting uninstall: pyscf-ipu
    Found existing installation: pyscf-ipu 0.0.1
    Uninstalling pyscf-ipu-0.0.1:
      Successfully uninstalled pyscf-ipu-0.0.1
  Running setup.py develop for pyscf-ipu
Successfully installed pyscf-ipu
Note: you may need to restart the kernel to use updated packages.


# Download and preprocess GDB 11 dataset

In [5]:
load = False
gdb_filename = "../gdb/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")
if load:
  # Download and extract GDB11 dataset.
  !wget -p -O ./gdb/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./gdb/gdb11.tgz --directory ./gdb/

  from  gdb import sortgdb

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)

import os
os.system(f'ls -l {out_filename}')
assert os.path.getsize(out_filename) == 6985727

-rw-r--r-- 1 awf awf 6985727 Sep 14 17:55 ../gdb/gdb11_size09_sorted.csv


In [6]:
import os
# PySCF IPU setup: run on the IPU Model (software emulation) instead of
# physical IPU hardware.
# NOTE(review): meaning inferred from the variable name; confirm against the
# JAX-IPU docs.
os.environ["JAX_IPU_USE_MODEL"] = "True"

# PySCF IPU setup: use a single device per process.
os.environ["JAX_IPU_DEVICE_COUNT"] = "1"
# JAX/XLA IPU compilation cache.
os.environ['TF_POPLAR_FLAGS'] = """
  --executable_cache_path=/tmp/ipu-ef-cache
"""

# These variables are set before the first JAX import below; presumably the
# IPU backend reads them when it initialises.
# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu

# Create a DFT dataset using PySCF IPU

In the following example, we use only a single IPU. Multiple IPUs can be used by simply launching a collection of PySCF IPU processes instead of a single one.

In [20]:
# Equivalent to command line:
# python density_functional_theory.py  -generate  -save  -fname <dataset_name>
#        -level 0  -plevel 0  -num_conformers <num_conformers>
#        -gdb 9  -float32

import time
from pyscf_ipu.dft import get_args, process_args, _eigh, recompute

args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0
args.num_conformers = 32  # TODO 1000
args.id = 1
args.limit = 33  # Only keep the first 33 SMILES (applied in the next cell).

args.gdb = 9
args.float32 = True

process_args(args)

if args.nan:
    jax.config.update("jax_debug_nans", True)

backend = args.backend
eigh = _eigh

t0 = time.time()
print("loading gdb data")


def _read_smiles(path):
    """Return the SMILES strings stored one per line in `path`.

    The file is closed after reading. splitlines() also drops a trailing empty
    entry and any '\\r' left by Windows line endings.
    """
    with open(path, "r") as f:
        return f.read().splitlines()


# Sorted GDB-11 subsets, e.g. ../gdb/gdb11_size09_sorted.csv for args.gdb == 9.
if args.gdb in (7, 8, 9, 10):
    args.smiles = _read_smiles(f"../gdb/gdb11_size{args.gdb:02d}_sorted.csv")

# Used as example data for quick testing: one molecule repeated num_conformers times.
example_smiles = {6: "c1ccccc1", 5: "CCCCC", 4: "CCCC"}
if args.gdb in example_smiles:
    args.smiles = [example_smiles[args.gdb]] * args.num_conformers


print("DONE!", time.time()-t0)


[BASIS] STO-3G
loading gdb data
DONE! 0.03591132164001465


In [21]:
from rdkit import Chem
# AllChem is a submodule; `from rdkit import Chem` alone does not guarantee
# that `Chem.AllChem` is loaded, so import it explicitly.
from rdkit.Chem import AllChem

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft

print("Length GDB: ", len(args.smiles))

if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

# Scan at most 1000 SMILES starting at args.id and keep the first one for
# which RDKit can embed a 3D conformer.
start = int(args.id)
stop = min(start + 1000, len(args.smiles))
for i in range(start, stop):
    smile = args.smiles[i]

    print(smile)

    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        continue  # unparsable SMILES; try the next one
    if not args.nohs:
        mol = Chem.AddHs(mol, explicitOnly=False)
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]

    if AllChem.EmbedMolecule(mol) == -1:
        continue  # embedding failed; try the next one

    # Conformer positions scaled by angstrom_to_bohr before building the atom string.
    locs = mol.GetConformer().GetPositions() * angstrom_to_bohr
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print(string)
    break
else:
    # Without this, `string` would be undefined, or stale from a previous run
    # of this cell, and silently passed to recompute below.
    raise RuntimeError(f"No embeddable molecule found in args.smiles[{start}:{stop}]")

# NOTE(review): the positional arguments (None, 0, 0) are forwarded to
# pyscf_ipu.dft.recompute; their meaning is not visible in this notebook.
recompute(args, None, 0, 0, our_fun=jax_dft, str=string)

Length GDB:  444285
FC(F)=C(F)C#CC#N
F  -5.976908 -0.272344 1.162637; C  -3.645289 -0.453388 0.099800; F  -3.255826 -1.468704 -2.208022; C  -1.700322 0.395499 1.375347; F  -1.910616 1.443580 3.702194; C  0.784915 0.247064 0.338877; C  2.897062 0.159965 -0.536063; C  5.393510 0.038812 -1.547242; N 7.413474 -0.090484 -2.387528; 
	 33
>>>  2 3
_2_3
14_GDB9_f32True_grid0_backendcpu_2_3


  0%|          | 0/2 [00:00<?, ?it/s]

[PAD] Last molecule had grisize=9816 we're using 10797. 
[Fc1nnc(=O)oc1F]
[conformers] 32
16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[1 / 32] Hs=    0 -565.726910 3257.8 0.0 0.0 0.1 4.6 0.3 381.1 0.1 25.4 1.7 0.2 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.9 0.2 2.9 3678.5 [1 ; 0]:   0%|          | 0/2 [00:04<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[2 / 32] Hs=    0 -565.719822 3499.4 0.0 0.0 0.1 4.2 0.3 372.5 0.1 25.7 2.1 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 2.7 0.2 2.6 3913.3 [1 ; 0]:   0%|          | 0/2 [00:08<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[3 / 32] Hs=    0 -565.720691 3446.0 0.0 0.0 0.1 4.0 0.3 383.4 0.0 20.2 1.7 0.2 0.3 0.1 0.0 0.0 2.0 0.1 0.0 0.0 0.2 0.3 12.9 0.2 2.6 3874.6 [1 ; 0]:   0%|          | 0/2 [00:11<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[4 / 32] Hs=    0 -565.718995 3385.9 0.0 0.0 0.1 3.5 0.4 382.0 0.1 19.9 1.7 0.2 0.4 0.4 0.0 0.0 2.6 0.1 0.0 0.0 0.2 0.4 0.5 0.2 2.2 3800.8 [1 ; 0]:   0%|          | 0/2 [00:15<?, ?it/s] 

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[5 / 32] Hs=    0 -565.721342 3377.4 0.0 0.0 0.1 2.5 0.2 356.6 0.0 18.4 1.9 0.3 0.3 0.1 0.0 0.0 2.3 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.6 3764.0 [1 ; 0]:   0%|          | 0/2 [00:19<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[6 / 32] Hs=    0 -565.727083 3305.1 0.0 0.0 0.1 2.5 0.2 352.6 0.0 21.3 2.0 0.3 0.3 0.1 0.0 0.0 2.3 0.1 0.0 0.0 0.2 0.4 0.4 0.2 2.1 3690.2 [1 ; 0]:   0%|          | 0/2 [00:23<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[7 / 32] Hs=    0 -565.718795 3350.3 0.0 0.0 0.1 3.9 0.3 378.8 0.1 23.0 1.8 0.2 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.5 3764.9 [1 ; 0]:   0%|          | 0/2 [00:27<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[8 / 32] Hs=    0 -565.730291 3149.6 0.0 0.0 0.1 3.9 0.3 414.5 0.1 25.5 1.8 0.3 0.3 0.1 0.0 0.0 2.9 0.5 0.0 0.0 0.2 0.3 0.6 0.2 6.6 3607.8 [1 ; 0]:   0%|          | 0/2 [00:30<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[9 / 32] Hs=    0 -565.723428 3571.9 0.0 0.0 0.1 2.7 0.2 383.7 0.0 24.0 1.7 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.5 3990.9 [1 ; 0]:   0%|          | 0/2 [00:34<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[10 / 32] Hs=    0 -565.721805 3678.1 0.0 0.0 0.1 3.0 0.2 372.1 0.0 21.8 2.0 0.3 0.3 0.1 0.0 0.0 2.4 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.5 4084.2 [1 ; 0]:   0%|          | 0/2 [00:38<?, ?it/s]

# Loading & visualizing generated data

After the dataset has been created, we can load the data.
(You may wish to load and view the data from a separate notebook
while it is still being generated in this one.)

In [10]:
import pandas as pd

In [11]:
import os

# Output DFT dataset is a gzip-compressed CSV file inside a run directory.
# NOTE: it may take a couple of minutes before the file is generated.
rootpath = os.path.join('./data/generated', args.fname)
# Run directories under rootpath, oldest first by modification time; stray
# files are ignored so they cannot be mistaken for the latest run.
paths = sorted(
    (p for p in os.listdir(rootpath) if os.path.isdir(os.path.join(rootpath, p))),
    key=lambda p: os.path.getmtime(os.path.join(rootpath, p)),
)
if not paths:
    raise FileNotFoundError(f"No generated runs found under {rootpath!r} yet; try again shortly.")
filename = os.path.join(rootpath, paths[-1], "data.csv")

df = pd.read_csv(filename, compression="gzip")

In [12]:
# Preview the generated dataset (pandas truncates the middle rows).
df

Unnamed: 0.1,Unnamed: 0,smile,atoms,atom_positions,energies,std,pyscf_energies,pyscf_hlgap,pyscf_homo,pyscf_lumo,times,homo,lumo,hlgap,N,basis
0,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.310676011466988, -1.501283183213607, -1.07...","[-15234.328062231103, -15216.222079159548, -15...",0.001670,[0.0],0,0,0,[3.6225e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-3.627531,1.996017,5.623548,50,STO-3G
1,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.8138999079144, -1.8577504582035314, -0.781...","[-15233.707002997247, -15214.106921784089, -15...",0.001685,[0.0],0,0,0,[3.4137e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.810637,2.080500,5.891137,50,STO-3G
2,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-3.3157674827315, -0.5013038903612104, -0.791...","[-15233.268688412909, -15213.38878456933, -152...",0.001627,[0.0],0,0,0,[3.4426e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.781198,2.014855,5.796053,50,STO-3G
3,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.2718425484681184, -1.511352692749988, -0.9...","[-15233.92559508189, -15205.603897519642, -152...",0.002464,[0.0],0,0,0,[3.3526e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.801813,1.896403,5.698216,50,STO-3G
4,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-1.8202534347923944, -2.188123753845587, 0.66...","[-15234.38905939559, -15219.393843797065, -152...",0.001875,[0.0],0,0,0,[3.355e+03 0.000e+00 0.000e+00 1.000e-01 3.700...,-3.760773,1.895937,5.656710,50,STO-3G
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
68,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-5.829719643481505, -1.0442410230540844, -0.3...","[-12921.907973375575, -12847.172828548093, -12...",0.002065,[0.0],0,0,0,[3.5821e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.557515,0.994091,4.551606,50,STO-3G
69,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-3.879745074030985, -3.9453784376908647, 1.89...","[-12896.409026279161, -12448.225227459361, -12...",0.005064,[0.0],0,0,0,[3.8648e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-1.520286,-0.384428,1.135859,50,STO-3G
70,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[4.603785661294676, 2.201512234181421, -1.8361...","[-12921.780569965858, -12843.884053121841, -12...",0.001627,[0.0],0,0,0,[3.5696e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.408105,1.005516,4.413621,50,STO-3G
71,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-1.8221122964701737, -0.056891745197611164, -...","[-12921.990923514906, -12848.056616998725, -12...",0.000762,[0.0],0,0,0,[3.6236e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.538141,0.812941,4.351082,50,STO-3G


In [13]:
# HOMO-LUMO gap per row (the `hlgap` column, i.e. lumo - homo).
df.loc[:, "hlgap"]

0     5.623548
1     5.891137
2     5.796053
3     5.698216
4     5.656710
        ...   
68    4.551606
69    1.135859
70    4.413621
71    4.351082
72    3.920660
Name: hlgap, Length: 73, dtype: float64