In [1]:
import tensorflow as tf
import numpy as np
import os
import logging
import string
import random
import yaml
from datetime import datetime

from dimenet.model.dimenet import DimeNet
from dimenet.model.activations import swish
from dimenet.training.trainer import Trainer
from dimenet.training.data_container import DataContainer
from dimenet.training.data_provider import DataProvider

In [2]:
# Set up logger: one timestamped console handler on the root logger
log_formatter = logging.Formatter(fmt='%(asctime)s (%(levelname)s): %(message)s',
                                  datefmt='%Y-%m-%d %H:%M:%S')
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)

logger = logging.getLogger()
logger.handlers = [console_handler]  # replace any handlers left over from an earlier run
logger.setLevel('INFO')

# Quiet TensorFlow: C++ backend messages, the Python-side TF logger, and AutoGraph output.
# NOTE(review): tensorflow is already imported in the first cell; the C++ log-level env var
# may only take full effect when set before `import tensorflow` — confirm.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='1')
tf_logger = tf.get_logger()
tf_logger.setLevel('WARN')
tf.autograph.set_verbosity(2)

### Load config

In [3]:
# Read hyperparameters from the YAML config next to this notebook (safe_load: no arbitrary objects)
with open('config.yaml', mode='r') as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Model architecture sizes
num_features, num_blocks = config['num_features'], config['num_blocks']
num_bilinear, num_spherical, num_radial = (
    config['num_bilinear'], config['num_spherical'], config['num_radial'])

# Cutoff and envelope settings passed to both the data container and the model
cutoff, envelope_exponent = config['cutoff'], config['envelope_exponent']

# Layer counts around the skip connections and in the output block
num_before_skip, num_after_skip, num_dense_output = (
    config['num_before_skip'], config['num_after_skip'], config['num_dense_output'])

# Dataset split settings; presumably these must match training so the test split is the
# same held-out set — confirm against the training config
num_train, num_valid, data_seed, dataset = (
    config['num_train'], config['num_valid'], config['data_seed'], config['dataset'])

# Batching and the list of target keys to predict
batch_size, targets = config['batch_size'], config['targets']

In [None]:
# Directory of the training run to load the trained model from; checkpoints are read
# from `<directory>/logs`. Fill this in before running the notebook.
directory = ''

# Fail loudly on the unset placeholder instead of silently reading a relative `logs/` folder
if not directory:
    raise ValueError("Set `directory` to the training run directory before running this notebook.")

log_dir = os.path.join(directory, 'logs')

### Load dataset

In [6]:
# Load the dataset with the configured cutoff and target keys
data_container = DataContainer(dataset, cutoff=cutoff, target_keys=targets)

# Split into training, validation and test sets based on data_seed, and batch them
data_provider = DataProvider(data_container, num_train, num_valid, batch_size,
                             seed=data_seed, randomized=True)

# Only the test split is needed for prediction
test_dataset = data_provider.get_dataset('test').prefetch(tf.data.experimental.AUTOTUNE)
test = {'dataset': test_dataset, 'dataset_iter': iter(test_dataset)}

### Initialize model

In [8]:
# Architecture hyperparameters; these should match the training run, otherwise the
# checkpoint variables restored below will not line up with the model's variables
dimenet_kwargs = dict(
    num_features=num_features,
    num_blocks=num_blocks,
    num_bilinear=num_bilinear,
    num_spherical=num_spherical,
    num_radial=num_radial,
    cutoff=cutoff,
    envelope_exponent=envelope_exponent,
    num_before_skip=num_before_skip,
    num_after_skip=num_after_skip,
    num_dense_output=num_dense_output,
    num_targets=len(targets),  # one output per configured target key
    activation=swish,
)
model = DimeNet(**dimenet_kwargs)

### Set up checkpointing and load latest checkpoint

In [11]:
# Set up checkpointing. Only the model weights are needed for prediction, so the optimizer
# state stored in the training checkpoint is not tracked here (no trainer/optimizer exists in
# this notebook); `expect_partial()` below silences warnings about those unmatched values.
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=model)

# Restore latest checkpoint
ckpt_restored = tf.train.latest_checkpoint(log_dir)
if ckpt_restored is None:
    # restore(None) is a silent no-op, which would evaluate a randomly initialized model
    raise FileNotFoundError(f"No checkpoint found in '{log_dir}'.")
ckpt.restore(ckpt_restored).expect_partial()

### Functions for prediction

In [12]:
def calculate_mae(targets, preds):
    """Compute the mean absolute error between `targets` and `preds`.

    Returns:
        A tuple ``(mean_mae, mae)``: ``mae`` is the absolute error averaged over
        axis 0 (the sample axis of a batch), and ``mean_mae`` is the mean of ``mae``
        over its remaining entries.
    """
    abs_err = tf.abs(targets - preds)
    per_target_mae = tf.reduce_mean(abs_err, axis=0)
    return tf.reduce_mean(per_target_mae), per_target_mae

@tf.function
def test_on_batch(dataset_iter):
    """Evaluate `model` in inference mode on the next batch drawn from `dataset_iter`.

    Returns:
        ``(loss, mean_mae, mae, preds)``, where the loss is simply the mean MAE.
    """
    batch_inputs, batch_targets = next(dataset_iter)
    batch_preds = model(batch_inputs, training=False)
    batch_mean_mae, batch_mae = calculate_mae(batch_targets, batch_preds)
    return batch_mean_mae, batch_mean_mae, batch_mae, batch_preds

def update_average(avg, tmp, num):
    """Incrementally update a running mean.

    Returns ``avg + (tmp - avg) / num``. When ``num`` is the number of values seen so
    far (including ``tmp``), this is the arithmetic mean of all of them. Works
    element-wise for arrays/tensors as well as scalars.
    """
    delta = tmp - avg
    return avg + delta / num

### Evaluation loop over the test set

In [None]:
# Running averages over the test batches; the first update overwrites these zeros
loss_avg = mae_avg = mean_mae_avg = 0.
# Buffer holding every test-set prediction: one row per test sample, one column per target
preds_total = np.zeros((data_provider.nsamples['test'], len(targets)), dtype=np.float32)

In [14]:
num_test = data_provider.nsamples['test']
steps_per_epoch = int(np.ceil(num_test / batch_size))
num_seen = 0  # number of test samples evaluated so far
for step in range(steps_per_epoch):

    # Evaluate the model on the next test batch
    loss, mean_mae, mae, preds = test_on_batch(test['dataset_iter'])

    # Rows of `preds_total` covered by this batch (the final batch may be smaller)
    batch_start = num_seen
    num_seen += preds.shape[0]

    # Update aggregates weighted by batch size, so a smaller final batch is not
    # over-weighted. update_average(avg, x, n) returns avg + (x - avg) / n, so passing
    # n = samples_seen / batch_len gives the exact per-sample running mean
    # (for full batches this equals step + 1, as before).
    num_eff = num_seen / (num_seen - batch_start)
    loss_avg = update_average(loss_avg, loss, num_eff)
    mae_avg = update_average(mae_avg, mae, num_eff)
    mean_mae_avg = update_average(mean_mae_avg, mean_mae, num_eff)
    preds_total[batch_start:num_seen] = preds.numpy()

# Every test sample must have received exactly one prediction row
assert num_seen == num_test, f"Evaluated {num_seen} samples, expected {num_test}"
mean_log_mae_avg = np.mean(np.log(mae_avg))