# Prepare Input Data
We are going to prepare the data needed to train an MPNN to predict solvation energies in several solvents.

In [1]:
%matplotlib inline
from matplotlib import pyplot as plt
from moldesign.score.mpnn.data import convert_nx_to_dict, make_type_lookup_tables, make_tfrecord
from moldesign.utils.conversions import convert_smiles_to_nx
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import tensorflow as tf
import pandas as pd
import numpy as np
import json



## Get the Data
The data is stored on a [GitHub page](https://github.com/globus-labs/g4mp2-atomization-energy) from a previous project. This notebook reads a local copy, `g4mp2_data.json.gz`, from the working directory.

In [2]:
# Read the gzip-compressed, line-delimited JSON file from the working directory
data = pd.read_json(
    'g4mp2_data.json.gz',
    lines=True,
)
print(f'Downloaded {len(data)} training entries')

Downloaded 130258 training entries


Convert the SMILES to a networkx object

In [3]:
%%time
data['nx'] = data['smiles_0'].apply(lambda x: convert_smiles_to_nx(x, add_hs=True))

CPU times: user 40 s, sys: 1.15 s, total: 41.1 s
Wall time: 41.1 s


## Save the Data as TF Records
We'll make training, validation, and test sets

In [4]:
# Rows flagged `in_holdout` are reserved as the test set
test_set = data.query('in_holdout')
# These are the held-out test entries, not training entries
print(f'Set aside {len(test_set)} test entries')

Set aside 13026 training entries


In [5]:
# Everything outside the hold-out set is divided 90/10 into training and validation subsets;
# the fixed random_state keeps the split repeatable
non_holdout = data.query('not in_holdout')
train_set, val_set = train_test_split(
    non_holdout,
    test_size=0.1,
    random_state=1,
)
print(f'Split off {len(train_set)} training and {len(val_set)} validation entries')

Split off 105508 training and 11724 validation entries


Get all of the atom and bond types observed in the dataset

In [6]:
%%time
atom_types, bond_types = make_type_lookup_tables(data['nx'])

CPU times: user 2.23 s, sys: 20 ms, total: 2.25 s
Wall time: 2.25 s


In [7]:
# Write each lookup table to its own JSON file in the working directory
for filename, lookup_table in (
    ('atom_types.json', atom_types),
    ('bond_types.json', bond_types),
):
    with open(filename, 'w') as fp:
        json.dump(lookup_table, fp)

List out the solvation energy columns

In [8]:
# Collect, in alphabetical order, every column whose name begins with `sol_`
sol_cols = sorted(column for column in data.columns if column.startswith('sol_'))
print(f'Found {len(sol_cols)} columns: {sol_cols}')

Found 5 columns: ['sol_acetone', 'sol_acn', 'sol_dmso', 'sol_ethanol', 'sol_water']


Save their dielectric constants

In [9]:
# Dielectric constant of each solvent, keyed by its solvation-energy column so that
# the value/solvent pairing is explicit.
# Ethanol was previously 20.493, a duplicate of the acetone value; 24.852 is the value
# listed for ethanol in Gaussian's SCRF solvent table, which the other four constants match.
dielectric_constants = {
    'sol_acetone': 20.493,
    'sol_acn': 35.688,
    'sol_dmso': 46.826,
    'sol_ethanol': 24.852,
    'sol_water': 78.3553,
}

# The saved list must line up element-by-element with `sol_cols`, so fail loudly
# if the columns found in the data differ from the solvents listed here
assert sol_cols == ['sol_acetone', 'sol_acn', 'sol_dmso', 'sol_ethanol', 'sol_water']
with open('dielectric_constants.json', 'w') as fp:
    # Written as a plain list in the same order as `sol_cols`
    json.dump([dielectric_constants[column] for column in sol_cols], fp)

Save the data in TF format

In [10]:
# For each subset, write one TFRecord file with the graph and solvation energies of every
# entry, plus a CSV copy of the same (shuffled) table
for name, dataset in zip(['train', 'valid', 'test'], [train_set, val_set, test_set]):
    # Shuffle contents with a fixed seed so that re-running the notebook
    # writes the records in the same order every time
    dataset = dataset.sample(frac=1., random_state=1)
    with tf.io.TFRecordWriter(f'{name}_data.proto') as writer:
        # `total` lets tqdm show a percentage and time estimate (iterrows has no length)
        for _, entry in tqdm(dataset.iterrows(), desc=name, total=len(dataset)):
            record = convert_nx_to_dict(entry['nx'], atom_types, bond_types)
            # Solvation energies are stored in the same order as `sol_cols`
            record['solv_energies'] = np.array(entry[sol_cols].values, dtype=np.float32)
            writer.write(make_tfrecord(record))
    # CSV copy, in the same row order as the TFRecord file
    dataset.to_csv(f'{name}_data.csv', index=False)

train: 105508it [01:10, 1495.80it/s]
valid: 11724it [00:07, 1485.92it/s]
test: 13026it [00:08, 1480.91it/s]
