In [1]:
import tensorflow as tf
import numpy as np
import os
import ast
import logging
import string
import random
import yaml
from tqdm.notebook import tqdm
import pickle

from datetime import datetime

from convnet.model.convnet import ConvNet
from convnet.model.dumbnet import DumbNet
from convnet.model.activations import swish
from convnet.training.metrics import Metrics
from convnet.training.trainer import Trainer
from convnet.training.data_container import DataContainer
from convnet.training.data_provider import DataProvider


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
# Set up logger
logger = logging.getLogger()
logger.handlers = []
ch = logging.StreamHandler()
formatter = logging.Formatter(
        fmt='%(asctime)s (%(levelname)s): %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel('INFO')

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.get_logger().setLevel('WARN')
tf.autograph.set_verbosity(2)

In [3]:
with open('config_convnet.yaml', 'r') as c:
    config = yaml.safe_load(c)

In [4]:
model_name = config['model_name']

num_train = config['num_train']
num_valid = config['num_valid']
data_seed = config['data_seed']
dataset = config['dataset']
logdir = config['logdir']

num_steps = config['num_steps']
ema_decay = config['ema_decay']

learning_rate = config['learning_rate']
warmup_steps = config['warmup_steps']
decay_rate = config['decay_rate']
decay_steps = config['decay_steps']

batch_size = config['batch_size']
evaluation_interval = config['evaluation_interval']
save_interval = config['save_interval']
restart = config['restart']
comment = config['comment']
target = config['target']

In [5]:
data_container = DataContainer(dataset, target)

data_provider = DataProvider(data_container, num_train, num_valid, batch_size, seed=data_seed, randomized=True)

dataset = data_provider.get_dataset('test').prefetch(tf.data.experimental.AUTOTUNE)
dataset_iter = iter(dataset)

(9600, 5, 5, 1)


In [6]:
model = ConvNet(activation=swish) if "dumb" not in model_name else DumbNet(activation=swish)

In [7]:
trainer = Trainer(model, learning_rate, warmup_steps, decay_steps, decay_rate, ema_decay, max_grad_norm=1000)

In [8]:
# Load the trained model from your own training run
files = os.listdir("logging")
rel_files = [f for f in files if model_name in f and "dumb" + model_name not in f]
rel_files.sort()
#directory = "logging/20230810_165359_densnet_Ael8lbyb_md_h2.npz_densities_final"  # Fill this in
# Get latest run
directory = f"logging/{rel_files[-1]}"
best_ckpt_file = os.path.join(directory, 'best', 'ckpt')
model.load_weights(best_ckpt_file)

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fb8d02f0160>

In [9]:
metrics = Metrics('val', target)
target_shape = data_provider.shape_target
preds_total = np.zeros([data_provider.nsamples['test']] + target_shape[1:], dtype=np.float32)

In [10]:
steps_per_epoch = int(np.ceil(data_provider.nsamples['test'] / batch_size) - 1)

for step in tqdm(range(steps_per_epoch)):
    preds = trainer.predict_on_batch(dataset_iter, metrics)

    batch_start = step * batch_size
    batch_end = min((step + 1) * batch_size, data_provider.nsamples['test'])
    preds_total[batch_start:batch_end] = preds.numpy()

  0%|          | 0/283 [00:00<?, ?it/s]

[[[0.00452879956]
  [0.00547274854]
  [0.00605123211]
  [0.00605014618]
  [0.00572067266]]

 [[0.00421398226]
  [0.00504264841]
  [0.00556757068]
  [0.00561437709]
  [0.00538389711]]

 [[0.00365508418]
  [0.00426648557]
  [0.00468906155]
  [0.00482327]
  [0.00477782171]]

 [[0.00307539501]
  [0.00350267626]
  [0.00383364339]
  [0.00402524602]
  [0.00411973801]]

 [[0.00257279677]
  [0.00288692024]
  [0.00315475743]
  [0.00335997972]
  [0.0035154589]]]
[[[0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]]

 [[0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]]

 [[0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]]

 [[0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]]

 [[0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]
  [0.00589939393]]]
[[[0.0124294991]
  [0.0140661485]
  [0.0152114592]
  [0.0157112665]
  [0.0157088451]

In [11]:
preds_log_file = os.path.join(directory, 'preds.npz')
print(preds_log_file)
log_dict = {}
log_dict["MAE"] = metrics.mean_mae
log_dict["logMAE"] = metrics.mean_log_mae
log_dict["pred_densities"] = preds_total
log_dict["data_idx"] = data_provider.idx['test']
log_dict["no_weights"] = sum([tf.reduce_prod(w.shape) for w in model.trainable_weights])
pickle.dump(log_dict, open(preds_log_file, "wb"))

logging/20230912_132124_convnet_JZrMddsl_md_h2.npz_corrs_final/preds.npz
