In [1]:
from megnet.models import MEGNetModel
import numpy as np
from operator import itemgetter
import json

def predict(model, graph):
    """
    Predict a single scalar target for one graph.

    Parameters
    ----------
    model : object exposing ``graph_convertor.graph_to_input`` and ``predict``
        e.g. a loaded MEGNetModel.
    graph : dict
        Graph dictionary accepted by ``model.graph_convertor.graph_to_input``
        (for example the output of ``get_graph_from_doc``).

    Returns
    -------
    The first element of the flattened model output.
    """
    model_inputs = model.graph_convertor.graph_to_input(graph)
    flat_output = model.predict(model_inputs).ravel()
    return flat_output[0]

def get_graph_from_doc(doc, state=None):
    """
    Convert a json document into a megnet graph dictionary.

    Parameters
    ----------
    doc : dict
        Document with an ``'atoms'`` list (each entry has a ``'type'``) and an
        ``'atom_pairs'`` list (each entry has ``'a_idx'``, ``'b_idx'`` and
        ``'spatial_distance'``). ``'atom_pairs'`` may be empty.
    state : list, optional
        Value stored under the graph's ``'state'`` key. Defaults to a fresh
        ``[[0, 0]]`` on every call (the value previously hard-coded here).

    Returns
    -------
    dict
        ``{'atom', 'bond', 'index1', 'index2', 'state'}``. Every atom pair is
        stored in both directions, and the bonds are ordered by ``index1``.
        ``bond``, ``index1`` and ``index2`` are always lists, which are empty
        when the document has no atom pairs.
    """
    atom = [a['type'] for a in doc['atoms']]

    pairs = doc['atom_pairs']
    src = [p['a_idx'] for p in pairs]
    dst = [p['b_idx'] for p in pairs]
    dist = [p['spatial_distance'] for p in pairs]

    # Each pair is an undirected bond: add both a->b and b->a.
    index1 = src + dst
    index2 = dst + src
    bond = dist + dist

    # Order the directed bonds by source atom. sorted() is stable, so bonds
    # sharing a source atom keep a deterministic order. Indexing with a
    # comprehension (rather than itemgetter(*order)) also works when there
    # are zero bonds.
    order = sorted(range(len(index1)), key=index1.__getitem__)
    index1 = [index1[k] for k in order]
    index2 = [index2[k] for k in order]
    bond = [bond[k] for k in order]

    if state is None:
        # Build a new list per call so callers never share a mutable default.
        state = [[0, 0]]

    return {'atom': atom, 'bond': bond, 'index1': index1, 'index2': index2, 'state': state}

# Load the per-target scaler statistics. The prediction loop reads 'mean',
# 'std' and 'is_per_atom' from each entry to invert the (x - mean) / std
# scaling of model outputs.
# JSON is UTF-8 by specification, so pass the encoding explicitly; otherwise
# decoding would depend on the platform's locale default.
with open('../mvl_models/qm9/scaler.json', 'r', encoding='utf-8') as f:
    scaler = json.load(f)

# Load an example QM9 document and convert it into a megnet graph.
with open('../megnet/data/tests/qm9/000001.json', 'r', encoding='utf-8') as f:
    doc = json.load(f)
graph = get_graph_from_doc(doc)

Using TensorFlow backend.


In [2]:
# Names of the QM9 targets; each has its own saved model and scaler entry.
names = ['mu', 'alpha', 'HOMO', 'LUMO', 'gap', 'R2', 'ZPVE', 'U0', 'U', 'H', 'G', 'Cv', 'omega1']

y_pred = []
y_true = []

for target in names:
    model = MEGNetModel.from_file('../mvl_models/qm9/' + target + '.hdf5')
    scaled = predict(model, graph)

    # Targets flagged 'is_per_atom' were scaled per atom, so the extensive
    # (whole-molecule) value is recovered by multiplying by the atom count.
    stats = scaler[target]
    multiplier = len(graph['atom']) if stats['is_per_atom'] else 1

    # Undo x_scaled = (x - mean) / std, then apply the atom-count multiplier.
    value = (scaled * stats['std'] + stats['mean']) * multiplier

    y_pred.append(value)
    y_true.append(doc['mol_info'][target])


print('Target MEGNet QM9')
for row in zip(names, y_pred, y_true):
    print(*row)

Target MEGNet QM9
mu -0.008367357600964215 0.0
alpha 13.127069931402847 13.21
HOMO -10.556634190047586 -10.54958839
LUMO 3.241194954064075 3.18637297
gap 13.622226037945168 13.73596136
R2 35.9745626512748 35.3641
ZPVE 1.215309852340181 1.2176516143
U0 -17.165518999471423 -17.1717476062
U -17.35292190545371 -17.2863862853
H -17.420152179695215 -17.3892155206
G -16.10713760463493 -16.1515096204
Cv 6.42743049756236 6.469
omega1 3151.62572191244 3151.7078
