In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split

from schnet.atoms import stats_per_atom
from schnet.models import SchNet
from schnet.data import ASEReader, DataProvider
from schnet.models import SchNet

In [2]:
# Open the QM9 database; the energy U0 is the property used as target.
data_reader = ASEReader(
    'qm9.db',       # path to database
    ['energy_U0'],  # target property
    [],             # NOTE(review): third positional argument, meaning not visible here -- see ASEReader
    [(None, 3)],    # NOTE(review): fourth positional argument, presumably a shape spec -- confirm
)

In [3]:
std = 1004.55 # standard deviation of energy values for the current training data
mean = -10560.7 # mean of energy values for the current training data

class AtomGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches read from ``data_reader``.

    Each batch is fetched by indexing ``data_reader`` with a slice of ``ids``.
    Every returned array gets a new leading axis of length 1, and the target
    is normalized with the module-level ``mean`` and ``std``.

    Args:
        data_reader: object supporting ``data_reader[indices]`` and returning
            a dict-like of arrays (must contain the key ``tar``).
        ids: sample indices served by this generator (any sequence; it is
            converted to a NumPy array so it can be permuted).
        tar: key of the target property inside a batch.
        batch_size: number of indices per batch.
        shuffle: if True, the index order is re-permuted at construction time
            and after every epoch.
    """
    def __init__(self,data_reader,ids,tar='energy_U0',batch_size=16,shuffle=False):
        # Initialize the Keras base class (a no-op on old TF versions, expected
        # by newer Keras releases).
        super().__init__()
        self.data = data_reader
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.tar = tar
        # on_epoch_end() uses fancy indexing, which only works on ndarrays;
        # coerce so that plain lists/tuples of ids are accepted as well.
        self.index = np.asarray(ids)
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.index)//self.batch_size

    def __getitem__(self, index):
        """Return the ``(inputs, target)`` pair for batch number ``index``."""
        start = self.batch_size*index
        stop = start + self.batch_size
        return self.get_batch(self.index[start:stop])

    def on_epoch_end(self):
        """Re-permute the sample order when shuffling is enabled."""
        if self.shuffle:
            # Indexing with a permutation builds a new array, so the caller's
            # ``ids`` object is never modified.
            self.index = self.index[np.random.permutation(len(self.index))]

    def get_batch(self,idx):
        """Read the samples ``idx`` and build one training batch.

        Returns:
            A tuple ``(x_b, y_b)``: ``x_b`` maps every key of the reader's
            batch (including the target key) to its array with a leading axis
            of length 1; ``y_b`` is the target with the same extra axis,
            normalized as ``(value - mean) / std``.
        """
        batch = self.data[idx]
        x_b = {k:x[None,...] for k,x in batch.items()}
        # for training we normalize the target values to std==1 and mean==0
        y_b = (batch[self.tar][None,...]-mean)/std
        return x_b,y_b


In [4]:
# Experiment configuration.
tar = 'energy_U0'   # target property key; also selects the atomref column
batch_size = 24
intensive = True    # passed to stats_per_atom when statistics are recomputed

# Number of entries in the database.
N = len(data_reader)

# Hold out 10% of the entries for validation; fixed seed for a reproducible split.
train_idx, val_idx = train_test_split(np.arange(N),test_size=0.1,random_state=42)

# Pass the configured target explicitly so the providers follow ``tar``
# instead of silently falling back to AtomGenerator's default.
# NOTE(review): shuffle is left at its default (False), so batches keep the
# same order every epoch -- confirm this is intended for training.
train_provider = AtomGenerator(data_reader, batch_size=batch_size, ids=train_idx, tar=tar)

val_provider = AtomGenerator(data_reader, batch_size=batch_size, ids=val_idx, tar=tar)

# Column of the 'atom_ref' table that belongs to each supported target.
_ATOMREF_COLUMN = {
    'energy_U0': 1,
    'energy_U': 2,
    'enthalpy_H': 3,
    'free_G': 4,
    'Cv': 5,
}

# Best-effort load of the reference table. The load happens only inside the
# try block, so a missing or unreadable file is reported instead of aborting
# the cell before the handler can run.
try:
    # NpzFile keeps the archive open; the context manager closes it again.
    with np.load('atomref.npz') as atomref_file:
        atomref = atomref_file['atom_ref']
    atomref_column = _ATOMREF_COLUMN.get(tar)
    if atomref_column is not None:
        # Keep the result 2-D: a single column selected with a length-1 slice.
        atomref = atomref[:, atomref_column:atomref_column + 1]
    # For any other target the full table is kept unchanged.
except Exception as e:
    print(e)
    # No usable reference values.
    # NOTE(review): confirm that stats_per_atom accepts atomref=None.
    atomref = None


# Optional: determine the statistics (mean and std) of the training data.
# Disabled by default; switch the flag to True to recompute and print them.
COMPUTE_TRAIN_STATS = False
if COMPUTE_TRAIN_STATS:
    E = data_reader.get_property(tar, train_idx)
    Z = data_reader.get_atomic_numbers(train_idx)
    mean_energy_per_atom, stddev_energy_per_atom = stats_per_atom(
        E, Z, intensive, atomref)
    # sep='' reproduces plain string concatenation of the four pieces.
    print('Energy statistics: mu/atom=', mean_energy_per_atom,
          ', std/atom=', stddev_energy_per_atom, sep='')



# Build the network. The four positional hyper-parameters are interpreted by
# schnet.models.SchNet (their meaning is not visible in this notebook).
schnet = SchNet(
    6, 64, 64, 20,
    filter_pool_mode='mean',
)

In [5]:
import warnings

warnings.filterwarnings("ignore") # (action='once')

In [6]:
# Set LOAD_PRETRAINED to True if pretrained weights shall be loaded.
weight_path = 'schnet_weights.h5'
LOAD_PRETRAINED = True
if LOAD_PRETRAINED:
    # Run one batch of inputs through the model first: this is necessary for
    # building the model before its weights can be restored.
    schnet.predict(train_provider[1][0])
    schnet.load_weights(weight_path)

In [14]:
# Configure training: mean-squared-error loss on the normalized target,
# mean absolute error as an additional metric.
# ``learning_rate`` replaces the deprecated ``lr`` alias, which newer Keras
# releases no longer accept.
schnet.compile(
    loss='mse',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=['mae'],
)

In [None]:
# Train on the training provider for two epochs, evaluating on the
# validation provider.
schnet.fit(
    x=train_provider,
    validation_data=val_provider,
    epochs=2,
)

  ...
    to  
  ['...']
  ...
    to  
  ['...']
Train for 5020 steps, validate for 557 steps
Epoch 1/2

In [13]:
# Set SAVE_WEIGHTS to True if the model weights shall be saved.
SAVE_WEIGHTS = True
if SAVE_WEIGHTS:
    schnet.save_weights('schnet_weights_1.h5')

In [9]:
# Set EXPORT_SAVED_MODEL to True if the model shall be saved in the
# TensorFlow SavedModel format.
EXPORT_SAVED_MODEL = False
if EXPORT_SAVED_MODEL:
    schnet.save('schnet', save_format='tf')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: model.tf/assets
