In [1]:
import tensorflow as tf
import numpy as np
import os
import ast
import logging
import string
import random
import yaml
from tqdm.notebook import tqdm
import pickle

from datetime import datetime

from corrnet.model.corrnet import CorrNet
from corrnet.model.activations import swish
from corrnet.training.metrics import Metrics
from corrnet.training.trainer import Trainer
from corrnet.training.data_container import DataContainer
from corrnet.training.data_provider import DataProvider


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
# Set up logging: send every record of the root logger through a single
# stream handler that prefixes messages with a timestamp and the level name.
logger = logging.getLogger()
logger.handlers.clear()  # remove any handlers already attached to the root logger
logger.setLevel('INFO')

formatter = logging.Formatter(fmt='%(asctime)s (%(levelname)s): %(message)s',
                              datefmt='%Y-%m-%d %H:%M:%S')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

# Tune TensorFlow's own logging.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is assigned after `import tensorflow` has
# already run in the first cell -- confirm it still takes effect at this point.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.autograph.set_verbosity(2)
tf.get_logger().setLevel('WARN')

In [3]:
# Read the run configuration; safe_load only builds plain Python objects.
with open('config_corrnet.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

In [4]:
# Unpack the loaded configuration into notebook-level variables.
# Every key below must be present in the YAML file; a missing key raises KeyError.
model_name = config['model_name']

# Model settings
num_basis_fct = config['num_basis_fct']
emb_size = config['emb_size']
num_interaction_blocks = config['num_interaction_blocks']
ao_vals = config['ao_vals']
num_grid_points = config['num_grid_points']
# NOTE(review): the variable is spelled `num_featuers` while the config key is
# 'num_features'. The model-construction cell uses the misspelled name as a
# keyword argument, so do not rename one without checking the other.
num_featuers = config['num_features']

# Dataset selection, split sizes and output location
num_train = config['num_train']
num_valid = config['num_valid']
data_seed = config['data_seed']
dataset = config['dataset']
logdir = config['logdir']

# Values handed to the Trainer (ema_decay) / training length
num_steps = config['num_steps']
ema_decay = config['ema_decay']

# Learning-rate settings handed to the Trainer
learning_rate = config['learning_rate']
warmup_steps = config['warmup_steps']
decay_rate = config['decay_rate']
decay_steps = config['decay_steps']

# Batching, bookkeeping and the prediction target
batch_size = config['batch_size']
evaluation_interval = config['evaluation_interval']
save_interval = config['save_interval']
restart = config['restart']
comment = config['comment']
target = config['target']

In [5]:
# Wrap the configured dataset and target in a DataContainer.
# NOTE(review): the meaning of the literal 0.2 is defined by DataContainer and
# is not visible from this notebook -- consider moving it into the config.
data_container = DataContainer(dataset, target, 0.2)

# Initialize DataProvider (splits dataset into training, validation and test set based on data_seed)
data_provider = DataProvider(data_container, num_train, num_valid, batch_size,
                             seed=data_seed, randomized=True)

# Initialize the test pipeline under its own name. The original rebound
# `dataset` here, replacing the config value that DataContainer receives above,
# so running this cell a second time passed a tf.data pipeline to DataContainer.
test_dataset = data_provider.get_dataset('test').prefetch(tf.data.experimental.AUTOTUNE)
dataset_iter = iter(test_dataset)

In [6]:
# Keyword arguments common to both model configurations.
model_kwargs = dict(ao_vals=ao_vals,
                    num_interaction_blocks=num_interaction_blocks,
                    num_grid_points=num_grid_points,
                    activation=swish)
if model_name == "corrnet":
    # Only this configuration forwards the feature count. The keyword spelling
    # is kept exactly as in the original call; verify it against CorrNet's signature.
    model_kwargs["num_featuers"] = num_featuers
# NOTE(review): any other model_name also builds a CorrNet, just without the
# feature count -- confirm that this fallback is intended.
model = CorrNet(**model_kwargs)

In [7]:
# Wrap the model in a Trainer; it is used further down for batch-wise prediction.
# Arguments stay positional because Trainer's parameter names are not visible here.
trainer = Trainer(
    model,
    learning_rate,
    warmup_steps,
    decay_steps,
    decay_rate,
    ema_decay,
    max_grad_norm=1000,
)

In [8]:
# Load the trained model from your own training run.
# NOTE(review): the run directory root is hard-coded as "logging" although the
# config also provides `logdir` -- confirm whether they are meant to match.
files = os.listdir("logging")
# Keep only the runs whose directory name mentions the configured model.
rel_files = sorted(f for f in files if model_name in f)
if not rel_files:
    # Fail with an actionable message instead of a bare IndexError below.
    raise FileNotFoundError(
        f"No run directory containing '{model_name}' found under 'logging'; "
        "train a model first or set `directory` by hand.")

# To load a specific run instead, set the directory explicitly, e.g.:
# directory = "logging/20230810_165359_densnet_Ael8lbyb_md_h2.npz_densities_final"
# Get latest run: the last entry in lexicographic order. This picks the newest
# run only if directory names start with a sortable timestamp, as in the example above.
directory = f"logging/{rel_files[-1]}"
best_ckpt_file = os.path.join(directory, 'best', 'ckpt')
model.load_weights(best_ckpt_file)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f3063310dc0>

In [9]:
# Metrics accumulator plus a preallocated float32 array for all test predictions.
metrics = Metrics('val', target)
target_shape = data_provider.shape_target
num_test_samples = data_provider.nsamples['test']
# Same shape as the target, with the leading entry replaced by the test-set size.
preds_shape = [num_test_samples] + target_shape[1:]
preds_total = np.zeros(preds_shape, dtype=np.float32)

In [10]:
# Run the model over the test set batch by batch and collect the predictions.
num_test = data_provider.nsamples['test']
# NOTE(review): the "- 1" means the final batch is never requested, so the
# samples it would cover are not predicted. Confirm whether this is a deliberate
# workaround; if not, drop the "- 1" so that every test sample is covered.
steps_per_epoch = int(np.ceil(num_test / batch_size) - 1)

for step in tqdm(range(steps_per_epoch)):
    preds = trainer.predict_on_batch(dataset_iter, metrics)

    # Copy this batch's predictions into its slot of the preallocated array.
    batch_start = step * batch_size
    batch_end = min((step + 1) * batch_size, num_test)
    preds_total[batch_start:batch_end] = preds.numpy()

# Report the skipped tail explicitly instead of silently leaving it unfilled.
num_predicted = max(0, min(steps_per_epoch * batch_size, num_test))
if num_predicted < num_test:
    logging.warning(
        "Predicted %d of %d test samples; preds_total[%d:] was not filled "
        "and keeps its initial values.",
        num_predicted, num_test, num_predicted)

  0%|          | 0/6 [00:00<?, ?it/s]

[[0.0237280838 0.0230795257 0.0226791259 ... 0.0202298053 0.0200720243 0.019778762]
 [0.0230795257 0.0224486943 0.0220592376 ... 0.0196768641 0.0195233971 0.0192381497]
 [0.0226791259 0.0220592376 0.0216765385 ... 0.0193354953 0.01918469 0.0189043917]
 ...
 [0.0202298053 0.0196768641 0.0193354953 ... 0.0172472838 0.0171127655 0.0168627389]
 [0.0200720243 0.0195233971 0.01918469 ... 0.0171127655 0.0169792958 0.0167312194]
 [0.019778762 0.0192381497 0.0189043917 ... 0.0168627389 0.0167312194 0.0164867677]]
[[0.012291858 0.0112094199 0.00751666119 ... 0.0263416599 0.0237666722 0.023891015]
 [0.0112094199 0.0101272212 0.00643464085 ... 0.0252597 0.022684712 0.0228079818]
 [0.00751666119 0.00643464085 0.00274182227 ... 0.0215676557 0.0189919528 0.0191160571]
 ...
 [0.0263416599 0.0252597 0.0215676557 ... 0.0403924175 0.0378174298 0.0379412919]
 [0.0237666722 0.022684712 0.0189919528 ... 0.0378174298 0.0352419652 0.0353657082]
 [0.023891015 0.0228079818 0.0191160571 ... 0.0379412919 0.035365

In [11]:
# Bundle the accumulated metrics, the predictions and the test-sample indices,
# then write them into the run directory.
preds_log_file = os.path.join(directory, 'preds.npz')
log_dict = {
    "MAE": metrics.mean_mae,
    "logMAE": metrics.mean_log_mae,
    "pred_densities": preds_total,
    "data_idx": data_provider.idx['test'],
}
# NOTE: despite the ".npz" file name the content is a pickle, not a NumPy
# archive; read it back with pickle.load (only from trusted files).
# The context manager ensures the file handle is flushed and closed.
with open(preds_log_file, "wb") as preds_file:
    pickle.dump(log_dict, preds_file)