# Convert H5 file to ASE DB
Convert the data to ASE database files, with the molecules partitioned into training, test, and validation sets.

In [1]:
from typing import Iterator
from pathlib import Path
from sklearn.model_selection import train_test_split
from ase.calculators.singlepoint import SinglePointCalculator
from ase import Atoms, units
from ase.db import connect
from multiprocessing import Pool
from shutil import rmtree
from tqdm import tqdm
import numpy as np
import h5py

Configuration

In [2]:
methods_to_save = ['ccsd(t)_cbs', 'wb97x_dz']  # Levels of theory whose energies (and forces, when present) to save
test_split = 0.1  # Fraction of full dataset to reserve for testing
valid_split = 0.1  # Fraction of non-test data to hold out for validation

## Open the File
The data are stored in an HDF5 file in which each key is a different molecule. The data for each molecule follow the layout in [this SciData paper](https://www.nature.com/articles/s41597-020-0473-z/tables/2).

In [3]:
# Open read-only: older h5py releases default to append mode, which could modify the raw data file
original_data = h5py.File('raw-data/ani1x-release.h5', 'r')
print(f'Loaded dataset with {len(original_data)} molecules')

Loaded dataset with 3114 molecules


Split the molecule names into training, test, and validation sets

In [4]:
# Iterating an h5py Group yields its member names (one per molecule), in the same order as .keys()
all_names = list(original_data)

In [5]:
# Hold out the test set first, then carve the validation set out of the remaining names
remaining_names, test_names = train_test_split(
    all_names, test_size=test_split, shuffle=True, random_state=1
)
train_names, valid_names = train_test_split(
    remaining_names, test_size=valid_split, shuffle=True, random_state=1
)

## Save every molecule to ASE database
Record the composition of the molecule as its name, and save the energy and (when available) the forces.

In [6]:
# Root output directory; each method gets its own sub-directory below it
data_dir = Path('data')
data_dir.mkdir(exist_ok=True)

In [7]:
def page_to_atoms(page: h5py.Group, method: str) -> Iterator[Atoms]:
    """Convert one molecule's page of the HDF5 file into ASE Atoms records

    A page holds many conformations of the same molecule. One Atoms object is
    yielded per conformation that has an energy at the requested level of theory.

    ANI-1x stores energies in Hartree and forces in Hartree/Å (see the SciData
    data-layout table), so both are converted to ASE's eV and eV/Å with ``units.Ha``.

    Args:
        page: Page (group) of the H5Py database holding a single molecule
        method: Level of theory to read, e.g. ``'wb97x_dz'``. Reads
            ``{method}.energy`` and, if present, ``{method}.forces``
    Yields:
        ASE Atoms object with the energies and forces attached through a
        SinglePointCalculator. Conformations whose energy is NaN are skipped.
    """
    # Read each dataset into memory once: indexing an h5py dataset row by row
    # issues a separate HDF5 read for every conformation
    numbers = page['atomic_numbers'][()]
    all_coords = page['coordinates'][()]
    all_energies = page[f'{method}.energy'][()]

    # Get the forces if they are available
    force_name = f'{method}.forces'
    all_forces = page[force_name][()] if force_name in page else None

    for i, coords in enumerate(all_coords):
        # Skip if energy not done
        if np.isnan(all_energies[i]):
            continue

        # Make the atoms object
        atoms = Atoms(numbers=numbers, positions=coords)
        atoms.center()

        # Add the calculator, converting Hartree -> eV and Hartree/Å -> eV/Å
        atoms.calc = SinglePointCalculator(
            atoms=atoms,
            energy=all_energies[i] * units.Ha,
            forces=None if all_forces is None else all_forces[i] * units.Ha
        )
        yield atoms

In [8]:
# Map each molecule name to the stem of the database file it belongs to.
# A dict lookup avoids scanning the name lists for every molecule;
# setdefault keeps the first assignment, matching train > test > valid precedence
split_of_name = {}
for split_label, split_names in [('train', train_names), ('test', test_names), ('valid', valid_names)]:
    for name in split_names:
        split_of_name.setdefault(name, split_label)

for method in methods_to_save:
    # Prepare the output directory, discarding output from any earlier run
    method_dir = data_dir / method
    if method_dir.exists():
        rmtree(method_dir)
    method_dir.mkdir()

    # Loop over each composition
    for composition, page in tqdm(original_data.items(), desc=method):
        # Determine which database to write to
        split_label = split_of_name.get(composition)
        if split_label is None:
            raise ValueError(f'Molecule {composition} is not assigned to the train, test, or validation split')
        db_name = method_dir / f'{split_label}.db'

        with connect(db_name) as db:
            for atoms in page_to_atoms(page, method):
                db.write(atoms, method=method, name=composition)

ccsd(t)_cbs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3114/3114 [07:49<00:00,  6.64it/s]
wb97x_dz: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3114/3114 [1:29:23<00:00,  1.72s/it]
