Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

This notebook shows how to generate DFT datasets on the Graphcore IPU based on
"PySCFIPU: Repurposing Density Functional Theory to Suit Deep Learning", Mathiasen et al, (SynS & ML) Workshop, ICML 2023

https://icml.cc/virtual/2023/28485

Density Functional Theory (DFT) accurately predicts the properties of molecules given their atom types and positions,
and often serves as ground truth for molecular property prediction tasks.
Research in other areas of machine learning has shown that generalisation performance 
of Neural Networks tends to improve with increased dataset size, however, 
the computational cost of DFT has limited the size of DFT datasets.

PySCF_IPU allowed us to create QM10X, a dataset with 100 million conformers, in 3000 IPU-hours.

This notebook, running on 4 IPUs on Paperspace, creates 100K conformers in about 1 hour;
it can be run on multiple 16-IPU systems (paid instances) to generate the full dataset.

Install `pyscf-ipu`:

In [1]:
# PySCF IPU dependencies 
%pip install -e "..[ipu]"
print('install done - restart kernel if packages where installed')

Obtaining file:///notebooks
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting jax@ https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jax-0.3.16%2Bipu-py3-none-any.whl
  Using cached https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jax-0.3.16%2Bipu-py3-none-any.whl (1.2 MB)
Collecting jaxlib@ https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jaxlib-0.3.15%2Bipu.sdk320-cp38-none-manylinux2014_x86_64.whl
  Downloading https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jaxlib-0.3.15%2Bipu.sdk320-cp38-none-manylinux2014_x86_64.whl (111.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.3/111.3 MB[0m [31m52.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tessellate-ipu@ git+https://github.com/graphcore-research/tessellate-ipu.git@main
  Clo

# Download and preprocess GDB 11 dataset

In [3]:
import os
from  gdb import sortgdb # Restart kernel if gdb not found - it was installed above

gdb_filename = "./data/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")

loaded = os.path.exists(out_filename) and os.path.getsize(out_filename) == 6985727

if loaded:
  print(f'Found {out_filename}')
else:
  print(f'Did not find {out_filename}, or size was wrong')

  # Download and extract GDB11 dataset.
  !mkdir -p ./data
  !wget -p -O ./data/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./data/gdb11.tgz --directory ./data

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)
  print('done')


Found ./data/gdb11_size09_sorted.csv


In [5]:
import os

# PySCF IPU setup: use a single device per process.
os.environ["JAX_IPU_DEVICE_COUNT"] = "1"
# JAX/XLA IPU compilation cache.
os.environ['TF_POPLAR_FLAGS'] = """
  --executable_cache_path=/tmp/ipu-ef-cache
"""

# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu
print('import done')

import done


# Create a DFT dataset using PySCF IPU

In the following example, we use only a single IPU. Multiple IPUs can be used by simply launching a collection of PySCF IPU processes instead of a single one.

In [7]:
# Equivalent to command line:
# python pyscf_ipu/dft.py  -generate  -save  -fname "notebook_dataset"
#        -level 0  -plevel 0  -num_conformers 1000
#        -gdb 9  -float32

import time
from pyscf_ipu.dft import get_args, process_args

args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0

args.float32 = True

quick = True # Set to False to generate full dataset (takes some time)

if quick:
  args.id = 1
  args.num_conformers = 32 # Set to 1000 for full dataset
  args.limit = 33 # Comment out for full dataset
else:
  args.num_conformers = 1000

process_args(args)

# Load GDB09 data
args.gdb = 1 # positive => split for multi ipus
args.split = [0,1] # Partition 0 of 1 (i.e. single IPU)
gdb = 'gdb11_size09'
args.smiles = open(f"data/{gdb}_sorted.csv", "r").read().split("\n")
print(f'Loaded {len(args.smiles)} molecules from {gdb}')

[BASIS] STO-3G
Loaded 444285 molecules from gdb11_size09


In [8]:
from rdkit import Chem

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft, recompute

print("Length GDB: ", len(args.smiles))

if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

for i in range(int(args.id), min(int(args.id)+1000, len(args.smiles))):
    smile = args.smiles[i]

    print('Trying', smile)

    b = Chem.MolFromSmiles(smile)
    b = Chem.AddHs(b, explicitOnly=False)

    e = Chem.AllChem.EmbedMolecule(b)
    if e == -1:
       print('Did not embed', b) 
       continue

    locs = b.GetConformer().GetPositions() * angstrom_to_bohr
    atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print('Conformer: ', string)
    break

recompute(args, None, 0, 0, our_fun=jax_dft, str=string)

Length GDB:  444285
Trying FC(F)=C(F)C#CC#N
Conformer:  F  -3.303279 -2.551231 -0.665592; C  -3.551114 -0.180504 -1.594233; F  -5.642254 0.510763 -2.775206; C  -1.480470 1.359424 -1.330889; F  -1.829597 3.736696 -2.292413; C  0.721395 0.525248 -0.394711; C  2.921869 -0.329714 0.586346; C  5.124244 -1.156335 1.526403; N 7.039206 -1.914347 2.335843; 
recompute: 33
>>>  0 32
_0_32
0_GDB1_f32True_grid0_backendcpu_0_32


  0%|          | 0/33 [00:00<?, ?it/s]

[PAD] Last molecule had grisize=9816 we're using 10797. 
[FC(F)(C#N)C(=O)C#N]
[conformers] 32
dft_iter: 16.4025 (45, 45, 45, 45) 15
dft_iter: 7.77384 (4, 10797, 45) 20


[23 / 32] Hs=    0 -529.165139 1379.4 0.0 0.0 0.1 3.7 0.2 282.8 0.0 27.3 2.1 0.2 0.3 0.1 0.0 0.0 2.8 0.1 0.0 0.0 0.2 0.3 0.7 0.2 2.3 1702.8 [1 ; 0]:   0%|          | 0/33 [00:44<?, ?it/s]


KeyboardInterrupt: 

# Loading & visualizing generated data

After the dataset has been created, we can load the data.
(You may wish to spin up a new notebook, and view the data as 
it's being generated in this one).

In [10]:
import pandas as pd

In [11]:
# Output DFT dataset is a compressed CSV file.
# NOTE: it may take a couple of minutes before the file is generated.
rootpath = f'./data/generated/{args.fname}/'
paths = sorted(os.listdir(rootpath), key=lambda x: os.path.getmtime(rootpath + x))
filename = os.path.join(rootpath, paths[-1], "data.csv")

df = pd.read_csv(filename, compression="gzip")

In [12]:
df

Unnamed: 0.1,Unnamed: 0,smile,atoms,atom_positions,energies,std,pyscf_energies,pyscf_hlgap,pyscf_homo,pyscf_lumo,times,homo,lumo,hlgap,N,basis
0,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[0.15332033500393688, -2.3606980417798886, -2....","[-14387.150902075824, -14321.477650378294, -14...",0.000685,[0.0],0,0,0,[4.5475e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-5.872426,-0.87247,4.999956,45,STO-3G
1,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.3403320213763945, -1.054070549382328, 2.95...","[-14386.940987492488, -14327.730510746002, -14...",0.00141,[0.0],0,0,0,[1.386e+03 8.000e-01 2.000e-01 2.000e-01 5.600...,-5.924247,-1.223614,4.700633,45,STO-3G
2,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.508312250101616, -2.9848902777921023, 0.05...","[-14386.867679381283, -14319.960709000135, -14...",0.001749,[0.0],0,0,0,[1.432e+03 1.000e-01 0.000e+00 1.000e-01 3.200...,-5.924836,-1.235037,4.689799,45,STO-3G
3,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-3.2551803300125353, -2.1329334256170753, 0.8...","[-14387.076018358, -14329.75193126052, -14398....",0.001727,[0.0],0,0,0,[1.4513e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-5.899619,-1.169639,4.72998,45,STO-3G
4,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.84609364909322, -2.7599543764642225, -0.10...","[-14386.926915515363, -14322.835535730032, -14...",0.00168,[0.0],0,0,0,[1.3103e+03 0.0000e+00 0.0000e+00 1.0000e-01 2...,-5.838394,-1.148967,4.689428,45,STO-3G
5,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.5512611780706536, -0.09244560716693372, 2....","[-14387.064401315674, -14325.209987703642, -14...",0.001245,[0.0],0,0,0,[1.3518e+03 1.0000e-01 0.0000e+00 1.0000e-01 3...,-5.920997,-1.230427,4.690571,45,STO-3G
6,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.4575617356051964, -3.0864061240906833, -0....","[-14386.692958603096, -14324.748326944484, -14...",0.001462,[0.0],0,0,0,[1.3116e+03 1.0000e-01 0.0000e+00 1.0000e-01 2...,-5.964935,-1.271967,4.692968,45,STO-3G
7,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.7357029092813003, -2.8385241218542117, -0....","[-14386.7613926158, -14324.300083556986, -1439...",0.001496,[0.0],0,0,0,[1.295e+03 0.000e+00 0.000e+00 1.000e-01 2.900...,-5.921196,-1.042411,4.878785,45,STO-3G
8,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-2.42430136622039, -0.706409910822915, -2.953...","[-14387.090781291115, -14328.945259046752, -14...",0.001376,[0.0],0,0,0,[1.4328e+03 1.0000e-01 0.0000e+00 1.0000e-01 3...,-5.934455,-1.174914,4.75954,45,STO-3G
9,0,FC(F)(C#N)C(=O)C#N,FCFCNCOCN,"[-3.1701050619634006, -2.335481199371203, 0.06...","[-14387.030638933082, -14324.197525311902, -14...",0.001235,[0.0],0,0,0,[1.3675e+03 0.0000e+00 0.0000e+00 1.0000e-01 3...,-5.927952,-1.266243,4.661709,45,STO-3G


In [13]:
# HLgap data.
df["hlgap"]

0     4.999956
1     4.700633
2     4.689799
3     4.729980
4     4.689428
5     4.690571
6     4.692968
7     4.878785
8     4.759540
9     4.661709
10    4.663195
11    4.838700
12    4.670790
13    4.722942
14    4.548279
15    5.448868
16    4.756420
17    4.681641
18    4.663575
19    4.763903
20    4.765332
21    4.805719
22    4.738974
Name: hlgap, dtype: float64