Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell is executed, so edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Dependencies and configuration

Install the experimental JAX build for IPU (and its add-ons).  

Install `pyscf-ipu`:

In [4]:
# Install pyscf-ipu from the parent directory in editable mode (-e),
# together with the dependencies of its `cpu` extra.
%pip install -e "..[cpu]"

Looking in indexes: https://awf%40graphcore.ai:****@artifactory.sourcevertex.net:443/api/pypi/pypi-virtual/simple, https://pypi.python.org/simple/
Obtaining file:///home/awf/dev/gc-gh-public/pyscf-ipu
  Preparing metadata (setup.py) ... [?25ldone
Installing collected packages: pyscf-ipu
  Attempting uninstall: pyscf-ipu
    Found existing installation: pyscf-ipu 0.0.1
    Uninstalling pyscf-ipu-0.0.1:
      Successfully uninstalled pyscf-ipu-0.0.1
  Running setup.py develop for pyscf-ipu
Successfully installed pyscf-ipu
Note: you may need to restart the kernel to use updated packages.


# Download and preprocess GDB 11 dataset

In [5]:
load = False
gdb_filename = "../gdb/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")
if load:
  # Download and extract GDB11 dataset.
  !wget -p -O ./gdb/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./gdb/gdb11.tgz --directory ./gdb/

  from  gdb import sortgdb

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)

import os
os.system(f'ls -l {out_filename}')
assert os.path.getsize(out_filename) == 6985727

-rw-r--r-- 1 awf awf 6985727 Sep 14 17:55 ../gdb/gdb11_size09_sorted.csv


In [6]:
import os

# Set the IPU-related environment variables ahead of the JAX import below.
os.environ.update({
    # NOTE(review): presumably selects the IPU model (software emulation)
    # rather than IPU hardware - confirm against the JAX IPU documentation.
    "JAX_IPU_USE_MODEL": "True",
    # PySCF IPU setup: use a single device per process.
    "JAX_IPU_DEVICE_COUNT": "1",
    # JAX/XLA IPU compilation cache.
    "TF_POPLAR_FLAGS": "\n  --executable_cache_path=/tmp/ipu-ef-cache\n",
})

# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu

# Create a DFT dataset using PySCF IPU

In the following example, we use only a single IPU. Multiple IPUs can be used by simply launching a collection of PySCF IPU processes instead of a single one.

In [28]:
# Equivalent to command line:
# python pyscf_ipu/dft.py  -generate  -save  -fname "notebook_dataset"
#        -level 0  -plevel 0  -num_conformers 1000
#        -gdb 9  -float32

import time
from pyscf_ipu.dft import get_args, process_args

# get_args is given an empty argument list (presumably yielding the parser
# defaults - confirm in pyscf_ipu.dft); the fields below are then overridden.
args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0

args.float32 = True

quick = True # Set to False to generate full dataset (takes some time)

if quick:
  args.id = 1
  args.num_conformers = 32 # Set to 1000 for full dataset
  args.limit = 33 # Comment out for full dataset
else:
  args.num_conformers = 1000

process_args(args)

# Load GDB09 data
gdb = 'gdb11_size09'
# Use a context manager so the file handle is closed deterministically.
with open(f"../gdb/{gdb}_sorted.csv", "r") as smiles_file:
  # splitlines() does not produce the empty trailing entry that split("\n")
  # yields when the file ends with a newline, so no blank "molecule" is loaded.
  args.smiles = smiles_file.read().splitlines()
print(f'Loaded {len(args.smiles)} molecules from {gdb}')

[BASIS] STO-3G
Loaded 444285 molecules from gdb11_size09


In [31]:
from rdkit import Chem
# AllChem is a submodule: import it explicitly rather than relying on some
# other module having imported it as a side effect.
from rdkit.Chem import AllChem

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft, recompute

print("Length GDB: ", len(args.smiles))

# -1 means "no limit"; otherwise keep only the first args.limit SMILES.
if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

# Scan up to 1000 SMILES starting at args.id and stop at the first one for
# which RDKit can embed a 3D conformer.
first_index = int(args.id)
for i in range(first_index, min(first_index + 1000, len(args.smiles))):
    smile = args.smiles[i]

    print('Trying', smile)

    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles returns None when the SMILES cannot be parsed.
        print('Could not parse', smile)
        continue
    mol = Chem.AddHs(mol, explicitOnly=False)

    # EmbedMolecule returns -1 when no conformer could be generated.
    embed_status = AllChem.EmbedMolecule(mol)
    if embed_status == -1:
        print('Did not embed', mol)
        continue

    # Scale the conformer coordinates by the angstrom_to_bohr factor.
    locs = mol.GetConformer().GetPositions() * angstrom_to_bohr
    atoms = [atom.GetSymbol() for atom in mol.GetAtoms()]
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print('Conformer: ', string)
    break
else:
    # No break happened: without this, `string` would be undefined (or stale
    # from an earlier run of this cell) when passed to recompute below.
    raise RuntimeError(
        f"No molecule could be embedded starting from index {first_index} "
        f"(of {len(args.smiles)} SMILES)")

recompute(args, None, 0, 0, our_fun=jax_dft, str=string)

Length GDB:  33
Trying FC(F)=C(F)C#CC#N
Conformer:  F  -5.775981 1.584354 -0.408267; C  -3.341397 1.198118 0.222585; F  -2.364088 2.509684 2.171222; C  -1.963741 -0.431030 -1.051220; F  -2.981350 -1.737195 -3.011672; C  0.644124 -0.851237 -0.382575; C  2.853464 -1.167069 0.178425; C  5.417701 -1.538671 0.866345; N 7.511269 -1.873382 1.415159; 
recompute: _0_32
HIT
15_GDB-1_f32True_grid0_backendcpu_0_32
(15, 16)
HIT
16_GDB-1_f32True_grid0_backendcpu_0_32
(24, 16)
17_GDB-1_f32True_grid0_backendcpu_0_32


  0%|          | 0/33 [00:00<?, ?it/s]

[PAD] Last molecule had grisize=9816 we're using 10797. 
[FC(F)(C#N)C(=O)C#N]
[conformers] 32
16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[1 / 32] Hs=    0 -529.163460 4052.5 0.0 0.0 0.1 4.3 0.3 374.6 0.1 24.7 2.6 0.3 0.3 0.1 0.0 0.0 2.2 0.2 0.0 0.0 0.2 0.2 0.5 0.2 2.2 4465.6 [1 ; 0]:   0%|          | 0/33 [00:05<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[2 / 32] Hs=    0 -529.159224 4257.5 0.0 0.0 0.1 3.1 0.2 374.8 0.1 25.1 2.3 0.4 0.4 0.1 0.0 0.0 2.5 0.1 0.0 0.0 0.2 0.4 0.5 0.2 2.4 4670.4 [1 ; 0]:   0%|          | 0/33 [00:09<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[3 / 32] Hs=    0 -529.167999 4195.1 0.0 0.0 0.1 4.1 0.3 437.8 0.1 22.5 2.2 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 4.4 0.2 2.3 4672.6 [1 ; 0]:   0%|          | 0/33 [00:14<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[4 / 32] Hs=    0 -529.163292 3937.3 0.0 0.0 0.1 2.6 0.2 383.6 0.1 27.6 1.7 0.4 0.3 0.1 0.0 0.0 2.5 0.1 0.0 0.0 0.2 0.3 3.3 0.1 2.1 4362.6 [1 ; 0]:   0%|          | 0/33 [00:18<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[5 / 32] Hs=    0 -529.166403 4101.1 0.0 0.0 0.1 3.9 0.3 353.5 0.1 21.6 2.4 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.9 4490.1 [1 ; 0]:   0%|          | 0/33 [00:23<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[6 / 32] Hs=    0 -529.168985 4040.0 0.0 0.0 0.1 3.2 0.2 365.3 0.1 25.0 2.3 0.4 0.3 0.1 0.0 0.0 2.4 0.1 0.0 0.0 0.3 0.4 9.7 0.2 2.6 4452.7 [1 ; 0]:   0%|          | 0/33 [00:27<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[7 / 32] Hs=    0 -529.163254 4055.1 0.0 0.0 0.1 2.7 0.2 364.9 0.1 25.4 1.8 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 0.4 0.1 2.2 4456.4 [1 ; 0]:   0%|          | 0/33 [00:32<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[8 / 32] Hs=    0 -529.157264 3967.8 0.0 0.0 0.1 2.9 0.4 356.3 0.1 25.3 1.8 0.3 0.3 0.1 0.0 0.0 2.3 0.1 0.0 0.0 0.2 0.3 1.7 0.1 2.5 4362.6 [1 ; 0]:   0%|          | 0/33 [00:36<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[9 / 32] Hs=    0 -529.162430 3915.9 0.0 0.0 0.1 3.8 0.2 369.7 0.1 25.6 2.0 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.4 0.2 2.4 4323.9 [1 ; 0]:   0%|          | 0/33 [00:40<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[10 / 32] Hs=    0 -529.171086 4086.9 0.0 0.0 0.1 3.5 0.3 398.5 0.1 21.9 2.3 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 3.1 0.2 2.4 4522.7 [1 ; 0]:   0%|          | 0/33 [00:45<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[11 / 32] Hs=    0 -529.170307 3991.6 0.0 0.0 0.1 2.9 0.3 370.8 0.1 27.2 1.9 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.8 4401.9 [1 ; 0]:   0%|          | 0/33 [00:49<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[12 / 32] Hs=    0 -529.164063 3826.7 0.0 0.0 0.1 2.7 0.2 354.6 0.1 20.1 2.3 0.3 0.3 0.1 0.0 0.0 2.6 0.2 0.0 0.0 0.2 0.4 0.5 0.2 2.2 4213.8 [1 ; 0]:   0%|          | 0/33 [00:53<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[13 / 32] Hs=    0 -529.168340 4072.0 0.0 0.0 0.1 2.6 0.2 353.4 0.1 25.4 1.8 0.3 0.3 0.1 0.0 0.0 2.3 0.1 0.0 0.0 0.2 0.3 0.4 0.2 2.8 4462.6 [1 ; 0]:   0%|          | 0/33 [00:58<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[14 / 32] Hs=    0 -529.157547 3979.7 0.0 0.0 0.1 3.0 0.4 360.8 0.0 24.8 2.1 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.8 4377.9 [1 ; 0]:   0%|          | 0/33 [01:02<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[15 / 32] Hs=    0 -529.167086 4130.9 0.0 0.0 0.1 2.4 0.2 377.9 0.1 19.9 2.1 0.2 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.5 4540.2 [1 ; 0]:   0%|          | 0/33 [01:07<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[16 / 32] Hs=    0 -529.065189 4047.9 0.0 0.0 0.1 3.0 0.2 386.4 0.1 20.6 2.2 0.2 0.3 0.1 0.0 0.0 2.4 0.1 0.0 0.0 0.2 0.3 0.5 0.2 3.3 4468.1 [1 ; 0]:   0%|          | 0/33 [01:11<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[17 / 32] Hs=    0 -529.165016 3786.4 0.0 0.0 0.1 3.9 0.2 376.2 0.1 25.1 1.9 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.5 4200.5 [1 ; 0]:   0%|          | 0/33 [01:16<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[18 / 32] Hs=    0 -529.167073 4365.5 0.0 0.0 0.2 3.3 0.2 363.1 0.1 25.1 2.1 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.4 0.2 2.5 4766.2 [1 ; 0]:   0%|          | 0/33 [01:20<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[19 / 32] Hs=    0 -529.162768 4112.2 0.0 0.0 0.1 2.5 0.2 348.3 0.1 21.5 2.2 0.3 0.4 0.2 0.1 0.0 2.3 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.8 4494.5 [1 ; 0]:   0%|          | 0/33 [01:25<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[20 / 32] Hs=    0 -529.167329 3848.5 0.0 0.0 0.1 2.7 0.2 358.9 0.1 25.4 2.0 0.3 0.3 0.1 0.0 0.0 2.5 0.1 0.0 0.0 0.2 0.4 0.6 0.2 2.5 4245.1 [1 ; 0]:   0%|          | 0/33 [01:29<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[21 / 32] Hs=    0 -529.166210 3957.2 0.0 0.0 0.1 2.6 0.2 380.4 0.1 25.4 2.0 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.4 4374.6 [1 ; 0]:   0%|          | 0/33 [01:33<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[22 / 32] Hs=    0 -529.160071 4023.8 0.0 0.0 0.1 3.1 0.2 369.8 0.0 19.6 2.3 0.3 0.3 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 0.4 0.2 2.9 4425.8 [1 ; 0]:   0%|          | 0/33 [01:38<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[23 / 32] Hs=    0 -529.165280 3848.3 0.0 0.0 0.1 3.9 0.2 364.4 0.1 26.4 1.8 0.2 0.2 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 9.0 0.2 2.4 4260.0 [1 ; 0]:   0%|          | 0/33 [01:42<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[24 / 32] Hs=    0 -529.162137 4056.7 0.0 0.0 0.1 3.7 0.2 377.4 0.1 26.4 2.0 0.3 0.3 0.1 0.0 0.0 2.5 0.1 0.0 0.0 0.2 0.4 1.7 0.2 2.5 4474.9 [1 ; 0]:   0%|          | 0/33 [01:47<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[25 / 32] Hs=    0 -529.168544 4403.8 0.0 0.0 0.1 4.3 0.3 466.8 0.1 25.3 2.2 0.3 0.3 0.1 0.0 0.0 2.5 0.1 0.0 0.0 0.2 0.4 0.6 0.2 2.7 4910.3 [1 ; 0]:   0%|          | 0/33 [01:51<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[26 / 32] Hs=    0 -529.160784 3945.7 0.0 0.0 0.1 2.6 0.2 356.6 0.1 20.4 2.4 0.4 0.4 0.1 0.0 0.0 2.1 0.1 0.0 0.0 0.2 0.3 0.4 0.1 1.9 4334.1 [1 ; 0]:   0%|          | 0/33 [01:56<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[27 / 32] Hs=    0 -529.165990 3901.7 0.0 0.0 0.1 2.6 0.2 365.1 0.1 24.7 1.9 0.2 0.2 0.1 0.0 0.0 2.3 0.1 0.0 0.0 0.2 0.3 2.3 0.2 3.6 4305.9 [1 ; 0]:   0%|          | 0/33 [02:00<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[28 / 32] Hs=    0 -529.166220 3915.2 0.0 0.0 0.1 2.5 0.2 370.8 0.0 27.3 1.8 0.3 0.3 0.1 0.0 0.0 2.4 0.1 0.0 0.0 0.2 0.3 0.5 0.2 2.2 4324.5 [1 ; 0]:   0%|          | 0/33 [02:04<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


[29 / 32] Hs=    0 -529.162915 3808.7 0.0 0.0 0.1 2.5 0.2 366.2 0.1 26.0 1.9 0.3 0.3 0.1 0.0 0.0 2.2 0.1 0.0 0.0 0.2 0.3 1.2 0.1 1.9 4212.4 [1 ; 0]:   0%|          | 0/33 [02:09<?, ?it/s]

16.4025 (45, 45, 45, 45) 15
7.77384 (4, 10797, 45) 20


# Loading & visualizing generated data

After the dataset has been created, we can load the data.
(You may wish to spin up a new notebook, and view the data as 
it's being generated in this one).

In [10]:
import pandas as pd

In [11]:
# The generated DFT dataset is stored as a gzip-compressed CSV file.
# NOTE: it may take a couple of minutes before the file is generated.
rootpath = f'./data/generated/{args.fname}/'
# Entries of rootpath ordered by modification time, oldest first.
paths = sorted(
    os.listdir(rootpath),
    key=lambda entry: os.path.getmtime(os.path.join(rootpath, entry)),
)
# The most recently modified entry is the one whose data.csv gets loaded.
filename = os.path.join(rootpath, paths[-1], "data.csv")

df = pd.read_csv(filename, compression="gzip")

In [12]:
# Show the loaded DataFrame as the cell output.
df

Unnamed: 0.1,Unnamed: 0,smile,atoms,atom_positions,energies,std,pyscf_energies,pyscf_hlgap,pyscf_homo,pyscf_lumo,times,homo,lumo,hlgap,N,basis
0,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.310676011466988, -1.501283183213607, -1.07...","[-15234.328062231103, -15216.222079159548, -15...",0.001670,[0.0],0,0,0,[3.6225e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-3.627531,1.996017,5.623548,50,STO-3G
1,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.8138999079144, -1.8577504582035314, -0.781...","[-15233.707002997247, -15214.106921784089, -15...",0.001685,[0.0],0,0,0,[3.4137e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.810637,2.080500,5.891137,50,STO-3G
2,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-3.3157674827315, -0.5013038903612104, -0.791...","[-15233.268688412909, -15213.38878456933, -152...",0.001627,[0.0],0,0,0,[3.4426e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.781198,2.014855,5.796053,50,STO-3G
3,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-2.2718425484681184, -1.511352692749988, -0.9...","[-15233.92559508189, -15205.603897519642, -152...",0.002464,[0.0],0,0,0,[3.3526e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.801813,1.896403,5.698216,50,STO-3G
4,0,CC(C(F)C=O)=C(F)F,CCCFCOCFFHHHHH,"[-1.8202534347923944, -2.188123753845587, 0.66...","[-15234.38905939559, -15219.393843797065, -152...",0.001875,[0.0],0,0,0,[3.355e+03 0.000e+00 0.000e+00 1.000e-01 3.700...,-3.760773,1.895937,5.656710,50,STO-3G
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
68,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-5.829719643481505, -1.0442410230540844, -0.3...","[-12921.907973375575, -12847.172828548093, -12...",0.002065,[0.0],0,0,0,[3.5821e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.557515,0.994091,4.551606,50,STO-3G
69,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-3.879745074030985, -3.9453784376908647, 1.89...","[-12896.409026279161, -12448.225227459361, -12...",0.005064,[0.0],0,0,0,[3.8648e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-1.520286,-0.384428,1.135859,50,STO-3G
70,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[4.603785661294676, 2.201512234181421, -1.8361...","[-12921.780569965858, -12843.884053121841, -12...",0.001627,[0.0],0,0,0,[3.5696e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.408105,1.005516,4.413621,50,STO-3G
71,0,COC(=O)C=C(F)C#C,COCOCCFCCHHHHH,"[-1.8221122964701737, -0.056891745197611164, -...","[-12921.990923514906, -12848.056616998725, -12...",0.000762,[0.0],0,0,0,[3.6236e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-3.538141,0.812941,4.351082,50,STO-3G


In [13]:
# HOMO-LUMO gap ("hlgap") column; in the rows displayed above it equals lumo - homo.
df["hlgap"]

0     5.623548
1     5.891137
2     5.796053
3     5.698216
4     5.656710
        ...   
68    4.551606
69    1.135859
70    4.413621
71    4.351082
72    3.920660
Name: hlgap, Length: 73, dtype: float64