In [15]:
import tensorflow as tf
import numpy as np
import os
import ast
import logging
import string
import random
import yaml

from datetime import datetime

from model.densnet import DensNet
from model.dumbnet import DumbNet
from model.activations import swish
from training.metrics import Metrics
from training.trainer import Trainer
from training.data_container import DataContainer
from training.data_provider import DataProvider

In [16]:
# set up logger
# Configure the root logger (getLogger() without a name) so that the plain
# `logging.info(...)` calls in later cells are emitted through the handler below.
logger = logging.getLogger()

# Drop handlers left over from a previous execution of this cell; otherwise
# every re-run would attach one more StreamHandler and duplicate each message.
logger.handlers = []
ch = logging.StreamHandler()
formatter = logging.Formatter(
    fmt='%(asctime)s (%(levelname)s): %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel('INFO')
# NOTE(review): tensorflow has already been imported when this line runs (see
# the import cell). TF's native logging may read this variable only once at
# start-up -- confirm it still takes effect here, or set it before the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# Python-side TensorFlow logger: warnings and above only.
tf.get_logger().setLevel('WARN')
tf.autograph.set_verbosity(2)

In [17]:
# Load the run configuration. The file is opened explicitly as UTF-8 so that
# parsing does not depend on the platform's default locale encoding.
# safe_load is used, so only plain Python objects are constructed.
with open('../configs/config_densnet.yaml', 'r', encoding='utf-8') as c:
    config = yaml.safe_load(c)

In [18]:
# Best-effort conversion of string config values that spell a Python literal
# into the corresponding object; every other string is kept unchanged.
# Iterate over a snapshot of the items because values are replaced in the loop.
for key, val in list(config.items()):
    if isinstance(val, str):
        try:
            config[key] = ast.literal_eval(val)
        except (ValueError, TypeError, SyntaxError):
            # Not a Python literal (literal_eval is documented to raise
            # TypeError as well, e.g. for an unhashable dict key):
            # leave the original string in place.
            pass

In [19]:
# Unpack the configuration into notebook-level names used by the cells below.
# A key missing from the config raises KeyError here.

# Selects which model is built in the "Initialize model" cell.
model_name = config['model_name']

# Model hyperparameters handed to the DensNet constructor (the number of grid
# points is width_ticks * length_ticks) and the cutoff passed to DataContainer.
s_type_per_atom = config['s_type_per_atom']
p_type_per_atom = config['p_type_per_atom']
emb_size = config['emb_size']
num_interaction_blocks = config['num_interaction_blocks']
width_ticks = config['width_ticks']
length_ticks = config['length_ticks']
cutoff = config['cutoff']

# Data settings: split sizes and seed for DataProvider, the dataset path for
# DataContainer, and the parent folder under which the run directory is created.
num_train = config['num_train']
num_valid = config['num_valid']
data_seed = config['data_seed']
dataset = config['dataset']
logdir = config['logdir']

# Last step of the training loop and the ema_decay argument of Trainer.
num_steps = config['num_steps']
ema_decay = config['ema_decay']

# Learning-rate arguments forwarded to Trainer.
learning_rate = config['learning_rate']
warmup_steps = config['warmup_steps']
decay_rate = config['decay_rate']
decay_steps = config['decay_steps']

# Loop control: batch size, how often (in steps) to evaluate and to checkpoint,
# an existing run directory to continue from (None starts a new run), a free-text
# suffix for the run directory name, and the target(s) passed to Metrics/DataContainer.
batch_size = config['batch_size']
evaluation_interval = config['evaluation_interval']
save_interval = config['save_interval']
restart = config['restart']
comment = config['comment']
target = config['target']

***Create directories***

In [20]:
# Used for creating a random "unique" id for this run
def id_generator(size=8, chars=string.ascii_uppercase + string.ascii_lowercase + string.digits):
    """Return a random string of `size` characters drawn from `chars`.

    Args:
        size: Number of characters to generate; 0 yields an empty string.
        chars: Non-empty sequence of candidate characters (default: ASCII
            letters and digits).

    Returns:
        str: The generated identifier. It is random, not guaranteed unique.
    """
    # Build one OS-entropy-backed generator for the whole id instead of
    # constructing a new SystemRandom instance for every character.
    rng = random.SystemRandom()
    return ''.join(rng.choice(chars) for _ in range(size))

# Create directories
# A new run gets a unique directory name built from the current timestamp, the
# model name, a random id, the dataset file name, the target names and the
# comment. A restarted run reuses the directory given in `restart`.
if restart is None:
    # '-'.join(target) assumes `target` is a sequence of names; a plain string
    # would have its characters joined instead -- TODO confirm config format.
    directory = "../" + logdir + "/" + "_".join([
        datetime.now().strftime("%Y%m%d_%H%M%S"),
        model_name,
        id_generator(),
        os.path.basename(dataset),
        '-'.join(target),
        comment,
    ])
else:
    directory = restart
logging.info(f"Directory: {directory}")

# exist_ok=True makes the cell safe to re-run and avoids the race between a
# separate existence check and the creation call.
os.makedirs(directory, exist_ok=True)
best_dir = os.path.join(directory, 'best')
os.makedirs(best_dir, exist_ok=True)
log_dir = os.path.join(directory, 'logs')
os.makedirs(log_dir, exist_ok=True)

# Files holding the best validation metrics and the matching model weights.
best_loss_file = os.path.join(best_dir, 'best_loss.npz')
best_ckpt_file = os.path.join(best_dir, 'ckpt')
# Periodic step checkpoints share the folder used for the summary logs.
step_ckpt_folder = log_dir

2023-10-09 14:12:32 (INFO): Directory: ../../logging/20231009_141232_densnet_XtBpMVxL_md_h2.npz_densities_final


***Create summary writer and metrics***

In [21]:
# Summary file writer that stores its events in the run's log directory.
summary_writer = tf.summary.create_file_writer(log_dir)

# Per-split state dictionaries; each starts out with its own Metrics object
# (tagged 'train' / 'val') and is extended by the dataset cell.
train = {'metrics': Metrics('train', target)}
validation = {'metrics': Metrics('val', target)}

*Load Dataset*

In [22]:
# Wrap the raw dataset and set up the provider that splits and batches it.
data_container = DataContainer(dataset, target, cutoff)
data_provider = DataProvider(
    data_container,
    num_train,
    num_valid,
    batch_size,
    seed=data_seed,
    randomized=True,
)

# For each split (train first, then validation): build the prefetching
# pipeline and keep one iterator over it, which the training loop pulls
# batches from.
for split_state, split_name in ((train, 'train'), (validation, 'val')):
    split_state['dataset'] = (
        data_provider
        .get_dataset(split_name)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    split_state['dataset_iter'] = iter(split_state['dataset'])

*Initialize model*

In [23]:
# Build the network selected by `model_name`.
if model_name != "dumbnet":
    model = DensNet(
        num_interaction_blocks=num_interaction_blocks,
        # The grid is width_ticks x length_ticks points.
        num_grid_points=width_ticks * length_ticks,
        emb_size=emb_size,
        s_type_per_atom=s_type_per_atom,
        p_type_per_atom=p_type_per_atom,
        activation=swish,
    )
else:
    # NOTE(review): DumbNet is imported but never instantiated; this branch
    # leaves `model` as None, which the Trainer, the checkpoint and
    # model.save_weights in later cells then receive. Presumably a DumbNet
    # instance was intended -- confirm.
    model = None

*Save/load best recorded loss*

In [24]:
if os.path.isfile(best_loss_file):
    # Continue from the best validation metrics recorded earlier. np.load
    # returns an NpzFile that keeps the archive open, so read every entry
    # inside a `with` block to make sure the file handle is released.
    with np.load(best_loss_file) as loss_file:
        metrics_best = {k: v.item() for k, v in loss_file.items()}
else:
    # No record yet: take the metric names from the validation metrics and
    # start each at +inf so that the first evaluation counts as an improvement.
    metrics_best = validation['metrics'].result()
    for key in metrics_best.keys():
        metrics_best[key] = np.inf
    metrics_best['step'] = 0
    np.savez(best_loss_file, **metrics_best)

*Initialize trainer*

In [25]:
# All optimisation settings come from the config; only max_grad_norm is
# hard-coded here.
trainer = Trainer(
    model,
    learning_rate,
    warmup_steps,
    decay_steps,
    decay_rate,
    ema_decay,
    max_grad_norm=1000,
)

*Set up checkpointing and load latest checkpoint*

In [26]:
# Set up checkpointing
# The checkpoint tracks a step counter (starting at 1), the trainer's
# optimizer and the model. The manager writes to step_ckpt_folder and keeps
# at most the 3 most recent checkpoints.
ckpt = tf.train.Checkpoint(
    step=tf.Variable(1),
    optimizer=trainer.optimizer,
    model=model,
)
manager = tf.train.CheckpointManager(
    ckpt,
    step_ckpt_folder,
    max_to_keep=3,
)

# Restore latest checkpoint
# `ckpt_restored` stays None when log_dir holds no checkpoint; the training
# loop uses that to decide whether to start from step 1.
ckpt_restored = tf.train.latest_checkpoint(log_dir)
if ckpt_restored is not None:
    ckpt.restore(ckpt_restored)

*Training loop*

In [27]:
# Main training loop: one optimisation step per iteration, periodic
# checkpointing, and periodic evaluation on the validation split using the
# trainer's averaged variables. Summaries go to `summary_writer`.
with summary_writer.as_default():
    steps_per_epoch = int(np.ceil(num_train / batch_size))
    # Number of batches in one validation pass (loop-invariant, so computed once).
    num_valid_batches = int(np.ceil(num_valid / batch_size))

    # Continue from the restored step counter, otherwise start at step 1.
    # NOTE(review): checkpoints are saved after the training step has run, so
    # resuming at the stored value repeats that step once -- confirm whether
    # `+ 1` is intended.
    if ckpt_restored is not None:
        step_init = ckpt.step.numpy()
    else:
        step_init = 1
    for step in range(step_init, num_steps + 1):
        # Update step number
        ckpt.step.assign(step)
        tf.summary.experimental.set_step(step)

        # Perform training step
        trainer.train_on_batch(train['dataset_iter'], train['metrics'])

        # Save progress
        if (step % save_interval == 0):
            manager.save()

        # Evaluate model and log results
        if (step % evaluation_interval == 0):

            # Save backup variables and load averaged variables. Everything
            # that runs with the averaged variables is wrapped in try/finally
            # so the training variables are restored even if evaluation fails
            # or is interrupted from the notebook.
            trainer.save_variable_backups()
            try:
                trainer.load_averaged_variables()

                # Compute results on the validation set
                for _batch_idx in range(num_valid_batches):
                    trainer.test_on_batch(validation['dataset_iter'], validation['metrics'])

                # Update and save best result
                if validation['metrics'].mean_mae < metrics_best['mean_mae_val']:
                    metrics_best['step'] = step
                    metrics_best.update(validation['metrics'].result())

                    np.savez(best_loss_file, **metrics_best)
                    # Saved while the averaged variables are loaded.
                    model.save_weights(best_ckpt_file)

                for key, val in metrics_best.items():
                    if key != 'step':
                        tf.summary.scalar(key + '_best', val)

                # Zero-based index of the epoch that contains `step`; steps are
                # 1-based, so the last step of an epoch must not roll over
                # into the next one.
                epoch = (step - 1) // steps_per_epoch
                logging.info(
                    f"{step}/{num_steps} (epoch {epoch + 1}): "
                    f"Loss: train={train['metrics'].loss:.6f}, val={validation['metrics'].loss:.6f}; "
                    f"logMAE: train={train['metrics'].mean_log_mae:.6f}, "
                    f"val={validation['metrics'].mean_log_mae:.6f}"
                )

                train['metrics'].write()
                validation['metrics'].write()

                # Start the next evaluation window with empty accumulators.
                train['metrics'].reset_states()
                validation['metrics'].reset_states()
            finally:
                # Restore backup variables
                trainer.restore_variable_backups()

2023-10-09 14:12:36 (INFO): 100/10000 (epoch 4):Loss: train=0.499737, val=1.104903;logMAE: train=-0.693673, val=0.099757
2023-10-09 14:12:38 (INFO): 200/10000 (epoch 8):Loss: train=0.058403, val=0.966555;logMAE: train=-2.840384, val=-0.034017
2023-10-09 14:12:41 (INFO): 300/10000 (epoch 12):Loss: train=0.032842, val=0.836789;logMAE: train=-3.416045, val=-0.178183
2023-10-09 14:12:44 (INFO): 400/10000 (epoch 15):Loss: train=0.016356, val=0.734644;logMAE: train=-4.113172, val=-0.308369
2023-10-09 14:12:46 (INFO): 500/10000 (epoch 19):Loss: train=0.015583, val=0.643530;logMAE: train=-4.161548, val=-0.440787
2023-10-09 14:12:49 (INFO): 600/10000 (epoch 23):Loss: train=0.018584, val=0.557410;logMAE: train=-3.985455, val=-0.584454
2023-10-09 14:12:53 (INFO): 700/10000 (epoch 26):Loss: train=0.018285, val=0.484553;logMAE: train=-4.001658, val=-0.724529
2023-10-09 14:12:56 (INFO): 800/10000 (epoch 30):Loss: train=0.012898, val=0.426132;logMAE: train=-4.350688, val=-0.853005
2023-10-09 14:12:58