In [1]:
import logging
import os
import pickle
import re
from datetime import datetime

import numpy as np
import tensorflow as tf
import yaml
from tqdm.notebook import tqdm

from model.activations import swish
from model.dmnet import DMNet
from training.data_container import DataContainer
from training.data_provider import DataProvider
from training.metrics import Metrics
from training.trainer import Trainer



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

 The versions of TensorFlow you are currently using is 2.10.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
# Configure the root logger: drop handlers left over from a previous run of this
# cell, then attach one stream handler that prefixes records with time and level.
formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)

logger = logging.getLogger()
logger.handlers = []
logger.addHandler(ch)
logger.setLevel('INFO')

# Reduce TensorFlow log noise: C++ log level, Python-side logger, AutoGraph verbosity.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is set after `tensorflow` was imported in the
# first cell; confirm it still takes effect, or set it before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.get_logger().setLevel('WARN')
tf.autograph.set_verbosity(2)

In [3]:
# Read the experiment configuration (model, data and training settings unpacked below).
with open('../configs/config_dmnet.yaml', mode='r') as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Unpack the configuration into notebook-level names used by the cells below.
# Any missing key raises KeyError here, before any expensive work starts.
model_name = config['model_name']

# DMNet constructor arguments (L and r_cut are also used by the data pipeline)
F, L, K = config['F'], config['L'], config['K']
r_cut = config['r_cut']
atoms = config["atoms"]
num_interaction_blocks = config['num_interaction_blocks']

# Data split, seed and source
num_train, num_valid = config['num_train'], config['num_valid']
data_seed = config['data_seed']
dataset = config['dataset']
logdir = config['logdir']

# Training length and EMA
num_steps, ema_decay = config['num_steps'], config['ema_decay']

# Learning-rate schedule
learning_rate, warmup_steps = config['learning_rate'], config['warmup_steps']
decay_rate, decay_steps = config['decay_rate'], config['decay_steps']

# Batching, evaluation/checkpoint cadence and run metadata
batch_size = config['batch_size']
evaluation_interval, save_interval = config['evaluation_interval'], config['save_interval']
restart, comment = config['restart'], config['comment']
target = config['target']

In [5]:
# Build the data container/provider from the config values.
# `dataset` is the raw config value (presumably a dataset file path - confirm).
data_container = DataContainer(L, dataset, target, r_cut)

data_provider = DataProvider(L, data_container, num_train, num_valid, batch_size,
                             seed=data_seed, randomized=True)
# NOTE(review): randomized=True - confirm the test split is not shuffled, otherwise
# the saved predictions may not line up with data_provider.idx['test'].

# Test-split input pipeline. It gets its own name so `dataset` (passed to
# DataContainer above) is not overwritten; this keeps the cell safe to re-run.
test_dataset = data_provider.get_dataset('test').prefetch(tf.data.AUTOTUNE)
# One-shot iterator: it is consumed by the prediction loop, so re-run this cell
# before re-running that loop.
dataset_iter = iter(test_dataset)

2024-02-06 14:10:21.576190: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
# Instantiate DMNet from the architecture values unpacked from the config.
dmnet_args = (F, L, K, r_cut, num_interaction_blocks, atoms)
model = DMNet(*dmnet_args)



In [7]:
# Wrap the model in a Trainer; used below only for batched prediction.
# max_grad_norm is hard-coded here rather than read from the config.
trainer = Trainer(
    model,
    learning_rate,
    warmup_steps,
    decay_steps,
    decay_rate,
    ema_decay,
    max_grad_norm=1000,
)

In [8]:
# Load the best checkpoint of the most recent run of `model_name` found in ../logging.
log_root = "../logging"
# Run directories look like "<YYYYmmdd_HHMMSS>_<model_name>...", so a lexicographic
# sort puts the newest run last.
# NOTE(review): substring matching also picks up other models whose name contains
# `model_name`; tighten if that can happen.
rel_files = sorted(f for f in os.listdir(log_root) if model_name in f)
if not rel_files:
    raise FileNotFoundError(f"No runs matching {model_name!r} found in {log_root!r}")
latest_run = rel_files[-1]
directory = os.path.join(log_root, latest_run)

# The 16 characters before the model name are the timestamp plus one separator.
stamp_match = re.search(r'(.{16})' + re.escape(model_name), latest_run)
if stamp_match is None:
    raise ValueError(f"No run timestamp found before {model_name!r} in {latest_run!r}")
stamp = stamp_match.group(1)[:15]  # drop the trailing separator
run_date = datetime.strptime(stamp, "%Y%m%d_%H%M%S").strftime("%Y.%m.%d %H:%M:%S")
print(f"Run date: {run_date}")

best_ckpt_file = os.path.join(directory, 'best', 'ckpt')
# Keep the returned load status (instead of echoing its repr) so it can be inspected.
load_status = model.load_weights(best_ckpt_file)

Run date: 2024.02.06 12:33:14


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x107baf430>

In [9]:
# Fresh metrics accumulator for this evaluation pass (re-run this cell before
# re-running the prediction loop, otherwise metrics keep accumulating).
metrics = Metrics('val', target)
# Shape of the target tensor; everything after the leading dimension is the
# per-sample shape used when building the prediction array.
# (The former 36 * nsamples preallocation was never used: the prediction cell
# rebuilds `preds_total` itself.)
target_shape = data_provider.shape_target

In [10]:
# Run the model over the test split, keeping predictions in the order the batches
# are produced so they line up with data_provider.idx['test'] when saved.
# Requires a fresh `dataset_iter` and `metrics` (re-run those cells before re-running).
steps_per_epoch = int(np.floor(data_provider.nsamples['test'] / batch_size))
# TODO(review): floor() skips a trailing partial batch; confirm whether the test
# dataset drops its remainder, otherwise the last samples get no prediction.
batch_preds = []
for _ in tqdm(range(steps_per_epoch)):
    preds, _inputs = trainer.predict_on_batch(dataset_iter, metrics)
    batch_preds.append(np.asarray(preds))

# Concatenate once at the end: appending in order (not prepending) and avoiding a
# per-step copy of the growing array.
if batch_preds:
    preds_total = np.concatenate(batch_preds, axis=0)
else:
    preds_total = np.zeros([0] + target_shape[1:], dtype=np.float32)

  0%|          | 0/800 [00:00<?, ?it/s]

In [11]:
# Save evaluation results into the directory of the evaluated run.
# NOTE: despite the .npz extension this file is a pickled dict, not a NumPy
# archive; the name is kept so existing readers still find it. Read it with pickle.load.
preds_log_file = os.path.join(directory, 'preds.npz')
log_dict = {
    "MAE": metrics.mean_mae,
    "logMAE": metrics.mean_log_mae,
    "pred_densities": preds_total,
    "data_idx": data_provider.idx['test'],
}
# Context manager guarantees the file is flushed and closed.
with open(preds_log_file, "wb") as preds_fh:
    pickle.dump(log_dict, preds_fh)
print("MAE:", metrics.mean_mae)

MAE: 0.004576398059725761
