# Prepare Models
SchNetPack expects its training data to be stored in an ASE database.

In [1]:
from jcesr_ml.schnetpack import make_schnetpack_data
from jcesr_ml.benchmark import load_benchmark_data
import pickle as pkl
import pandas as pd
import numpy as np
import json
import os

In [2]:
# Property columns kept from every dataset and stored in the SchNetPack database
useful_columns = [
    'u0',
    'g4mp2_0k',
    'g4mp2_atom',
    'u0_atom',
]

In [3]:
# Project root: two parent-directory hops up from the current working directory
project_root = os.path.join(os.pardir, os.pardir)

## Load Training Data
Read in the list of datasets used to train the models.

In [4]:
# Load the list of dataset-description paths (relative to the project root).
# JSON is UTF-8 by specification, so don't rely on the platform's default
# locale encoding when decoding it.
with open(os.path.join('..', 'datasets.json'), encoding='utf-8') as fp:
    datasets = json.load(fp)

Dataset paths are relative to the root of this project (`project_root`).

In [5]:
# Filters for the sugar dataset: drop one known outlier and the small molecules
SUGAR_OUTLIER_NAME = 'syringol-4-propylsyringol.xyz'
SUGAR_MIN_HEAVY_ATOMS = 9  # keep only molecules with strictly more heavy atoms


def _load_training_subset(path):
    """Load one dataset and keep only the columns needed for training.

    Args:
        path (str): Path to a dataset description JSON file, relative to
            ``project_root``.
    Returns:
        (pd.DataFrame) The ``useful_columns`` plus ``xyz`` columns of the data.
        If ``path`` contains "sugar", the outlier and small molecules are removed.
    """
    # Load in the dataset description
    with open(os.path.join(project_root, path), encoding='utf-8') as fp:
        desc = json.load(fp)

    # The description names the pandas reader to use (e.g., "csv" -> pd.read_csv).
    # NOTE(review): any pd.read_* named in the JSON is called, including
    # read_pickle, so only load dataset descriptions from trusted sources.
    load_fn = getattr(pd, 'read_' + desc['dataset']['format'])
    dataset = load_fn(
        os.path.join(project_root, desc['dlhub']['files']['data']),
        **desc['dataset'].get('read_options', {})
    )

    # If the sugar dataset, remove an outlier and small molecules
    # NOTE(review): this is a substring test on the whole path, not just the file name
    if 'sugar' in path:
        dataset = dataset.query(
            'name != @SUGAR_OUTLIER_NAME and n_heavy_atoms > @SUGAR_MIN_HEAVY_ATOMS'
        )

    # Return only the needed columns; the full frame is released when we return
    return dataset[useful_columns + ['xyz']]


all_data = pd.concat([_load_training_subset(path) for path in datasets])
print('Loaded {} training entries'.format(len(all_data)))

Loaded 130324 training entries


## Make an ASE Database
The current version of SchNetPack relies on storing data in an ASE SQLite database, so we must convert our data into that format.

In [6]:
# Write the training data (with the selected property columns) to an ASE
# database file, then pickle the object make_schnetpack_data returns
db = make_schnetpack_data(
    all_data,
    'train_dataset.db',
    properties=useful_columns,
)
with open('train_dataset.pkl', mode='wb') as db_file:
    pkl.dump(db, db_file)