Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# DFT dataset generation using PySCF IPU

This notebook shows how to generate DFT datasets on the Graphcore IPU. It is based on
"PySCFIPU: Repurposing Density Functional Theory to Suit Deep Learning", Mathiasen et al., (SynS & ML) Workshop, ICML 2023:

https://icml.cc/virtual/2023/28485

Density Functional Theory (DFT) accurately predicts the properties of molecules given their atom types and positions,
and often serves as ground truth for molecular property prediction tasks.
Research in other areas of machine learning has shown that the generalisation performance
of neural networks tends to improve with increased dataset size. However,
the computational cost of DFT has limited the size of DFT datasets.

PySCF IPU allowed us to create QM10X, a dataset with 100 million conformers, in 3000 IPU-hours.

This notebook, running on 4 IPUs on Paperspace, creates 100K conformers in about 1 hour;
it can be run on multiple 16-IPU systems (paid instances) to generate the full dataset.
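
The figures above imply a per-IPU throughput that can be checked with a little arithmetic. The sketch below is only a back-of-the-envelope estimate: it assumes that throughput scales linearly with the number of IPUs, which the text above does not state, and the variable names are purely illustrative.

```python
# Figures quoted in the text above.
qm10x_conformers = 100_000_000
qm10x_ipu_hours = 3_000
notebook_conformers = 100_000
notebook_ipus = 4
notebook_hours = 1

# Conformers generated per IPU-hour in each setting.
paper_rate = qm10x_conformers / qm10x_ipu_hours
notebook_rate = notebook_conformers / (notebook_ipus * notebook_hours)

# Wall-clock hours for the full dataset on one 16-IPU system,
# ASSUMING linear scaling with IPU count.
hours_on_16_ipus = qm10x_ipu_hours / 16

print(f"paper: {paper_rate:,.0f} conformers per IPU-hour")
print(f"notebook: {notebook_rate:,.0f} conformers per IPU-hour")
print(f"full dataset on 16 IPUs: {hours_on_16_ipus:.1f} hours")
```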

Install `pyscf-ipu`:

In [None]:
# PySCF IPU dependencies 
%pip install -e "..[ipu]"
print('install done')

# Download and preprocess the GDB11 dataset

In [None]:
import os

gdb_filename = "./data/gdb11_size09.smi"
out_filename = gdb_filename.replace(".smi", "_sorted.csv")

# Skip the download if the sorted file already exists with the expected size in bytes.
loaded = os.path.exists(out_filename) and os.path.getsize(out_filename) == 6985727

if loaded:
  print(f'Found {out_filename}')
else:
  print(f'Did not find {out_filename}, or size was wrong')

  # Download and extract GDB11 dataset.
  !mkdir -p ./data
  !wget -p -O ./data/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
  !tar -xvf ./data/gdb11.tgz --directory ./gdb/

  from gdb import sortgdb

  # Filter & sort GDB11 dataset (size 9).
  gdb_sorted = sortgdb.sort_gdb(gdb_filename, keep_only_atoms_count=9)
  # Save output as csv.
  gdb_sorted.to_csv(out_filename, index=False, header=False)



In [None]:
import os

# PySCF IPU setup: use a single device per process.
os.environ["JAX_IPU_DEVICE_COUNT"] = "1"
# JAX/XLA IPU compilation cache.
os.environ['TF_POPLAR_FLAGS'] = """
  --executable_cache_path=/tmp/ipu-ef-cache
"""

# First import of JAX and TessellateIPU may take a few minutes...
import jax
import tessellate_ipu
print('import done')

# Create a DFT dataset using PySCF IPU

In the following example, we use only a single IPU. Multiple IPUs can be used by simply launching a collection of PySCF IPU processes instead of a single one.
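
One way to launch such a collection is to give each process its own slice of the molecule list. The sketch below only builds the command lines and does not run them. The flags are copied from the command line quoted in the next cell, except `-id`: the namespace has an `id` field, but the command-line spelling is an assumption. The helper `build_commands` and the stride of 1000 molecules per process are hypothetical choices, not part of `pyscf_ipu`.

```python
import shlex

def build_commands(num_processes, molecules_per_process=1000,
                   fname="notebook_dataset"):
    """Build one dft.py command line per process (hypothetical helper)."""
    commands = []
    for rank in range(num_processes):
        # Each process starts at a different offset into the molecule list.
        start = rank * molecules_per_process
        commands.append([
            "python", "pyscf_ipu/dft.py", "-generate", "-save",
            "-fname", fname,
            "-level", "0", "-plevel", "0",
            "-num_conformers", "1000", "-gdb", "9", "-float32",
            "-id", str(start),  # ASSUMPTION: flag name inferred from args.id
        ])
    return commands

commands = build_commands(num_processes=4)
for cmd in commands:
    print(shlex.join(cmd))
```

Each list could then be handed to `subprocess.Popen`, with the environment set up so that every process sees a single IPU.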

In [None]:
# Equivalent to command line:
# python pyscf_ipu/dft.py  -generate  -save  -fname "notebook_dataset"
#        -level 0  -plevel 0  -num_conformers 1000
#        -gdb 9  -float32

import time
from pyscf_ipu.dft import get_args, process_args

args = get_args([])
args.backend = 'cpu'
args.generate = True
args.save = True
args.fname = "notebook_dataset"

args.level = 0
args.plevel = 0

args.float32 = True

quick = True # Set to False to generate full dataset (takes some time)

if quick:
  args.id = 1
  args.num_conformers = 32 # Set to 1000 for full dataset
  args.limit = 33 # Comment out for full dataset
else:
  args.num_conformers = 1000

process_args(args)

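# Load the sorted GDB11 (size 9) SMILES strings.
gdb = 'gdb11_size09'
with open(f"../gdb/{gdb}_sorted.csv", "r") as f:
  # splitlines() avoids a trailing empty entry after the final newline.
  args.smiles = f.read().splitlines()
print(f'Loaded {len(args.smiles)} molecules from {gdb}')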

In [None]:
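from rdkit import Chem
from rdkit.Chem import AllChem  # EmbedMolecule lives in the AllChem submodule

from pyscf_ipu.dft import angstrom_to_bohr, get_atom_string, jax_dft, recompute

print("Length GDB: ", len(args.smiles))

if args.limit != -1:
    args.smiles = args.smiles[:args.limit]

# Find the first molecule (starting at args.id) that RDKit can embed in 3D.
string = None
for i in range(int(args.id), min(int(args.id)+1000, len(args.smiles))):
    smile = args.smiles[i]

    print('Trying', smile)

    # Parse the SMILES string and add hydrogens.
    b = Chem.MolFromSmiles(smile)
    b = Chem.AddHs(b, explicitOnly=False)

    # EmbedMolecule returns -1 when no 3D conformer could be generated.
    e = AllChem.EmbedMolecule(b)
    if e == -1:
        print('Did not embed', smile)
        continue

    # Atom positions in bohr (RDKit returns angstrom).
    locs = b.GetConformer().GetPositions() * angstrom_to_bohr
    atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
    atom_string, string = get_atom_string(" ".join(atoms), locs)

    print('Conformer: ', string)
    break

if string is None:
    raise RuntimeError('No molecule in the selected range could be embedded.')

recompute(args, None, 0, 0, our_fun=jax_dft, str=string)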

# Loading & visualizing generated data

After the dataset has been created, we can load the data.
(You may wish to open a second notebook and view the data there while
it is still being generated in this one.)
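
Because the output file can take a few minutes to appear, a second notebook may want to poll for it rather than fail on the first attempt. The sketch below is one way to do that, assuming the `<rootpath>/<run>/data.csv` layout used by the loading cell in this notebook. The helper names `latest_data_file` and `wait_for_data`, and the default timeout and polling interval, are hypothetical.

```python
import os
import time

def latest_data_file(rootpath, name="data.csv"):
    """Return the data file of the most recently modified run, or None."""
    if not os.path.isdir(rootpath):
        return None
    runs = [os.path.join(rootpath, d) for d in os.listdir(rootpath)]
    # Keep only run directories that already contain a data file.
    runs = [r for r in runs if os.path.isfile(os.path.join(r, name))]
    if not runs:
        return None
    return os.path.join(max(runs, key=os.path.getmtime), name)

def wait_for_data(rootpath, timeout_s=600.0, poll_s=5.0):
    """Poll until a data file appears; return its path, or None on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        path = latest_data_file(rootpath)
        if path is not None or time.monotonic() >= deadline:
            return path
        time.sleep(poll_s)
```

For example, `wait_for_data("./data/generated/notebook_dataset/")` would return the path of the newest `data.csv` once a run has written one, which can then be passed to `pd.read_csv`.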

In [None]:
import pandas as pd

In [None]:
# Output DFT dataset is a compressed CSV file.
# NOTE: it may take a couple of minutes before the file is generated.
rootpath = f'./data/generated/{args.fname}/'
# Sort run directories by modification time; the last entry is the newest run.
paths = sorted(os.listdir(rootpath), key=lambda x: os.path.getmtime(os.path.join(rootpath, x)))
filename = os.path.join(rootpath, paths[-1], "data.csv")

df = pd.read_csv(filename, compression="gzip")

In [None]:
df

In [None]:
# HOMO-LUMO gap (the "hlgap" column).
df["hlgap"]