In [23]:
import numpy as np
import os, re
import tensorflow as tf
from tensorflow import keras
import time
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
import argparse
import utils
from GSGM import GSGM
from GSGM_distill import GSGM_distill
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_addons as tfa
import horovod.tensorflow.keras as hvd

In [24]:
# Global TensorFlow seed, making the synthetic data and dataset shuffling repeatable.
# NOTE(review): every Horovod process executes this same line, so all ranks share one seed.
SEED = 1233
tf.random.set_seed(SEED)

In [25]:
# Initialise Horovod, then restrict this process to the single GPU picked by its local rank.
hvd.init()
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # TensorFlow requires memory growth to be configured before the devices are initialised.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    # One GPU per Horovod process on this node.
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

In [26]:
# Hyper-parameters come from the JSON config; the event count and particle count are fixed here.
config = utils.LoadJson('gaussian_config_jet.json')

event_count = 100_000
particle_count = 1  # each synthetic event carries a single "particle"

batch_size = config['BATCH']
particle_feature_dim = config['NUM_FEAT']  # features per particle
particle_types = config['NUM_COND']        # width of the conditioning vector
jet_feature_dim = config['NUM_JET']        # features per jet

# NOTE(review): data_size is event_count multiplied by batch_size, so any expression of the
# form int(data_size * f / config['BATCH']) reduces to f * event_count. Confirm whether that
# (a step count) or a raw event count is the intended meaning.
data_size = event_count * batch_size

In [27]:
def _two_gaussian_mixture(batch_size, event_shape, mean=2., stddev=1.):
    """Sample ``batch_size`` rows from an equal mixture of N(+mean, stddev) and N(-mean, stddev).

    Args:
        batch_size: number of rows to return (the leading dimension).
        event_shape: trailing shape of each row, e.g. ``[particle_count, feature_dim]``.
        mean: absolute location of the two Gaussian components.
        stddev: standard deviation shared by both components.

    Returns:
        A tensor of shape ``[batch_size] + event_shape``. The first ``batch_size // 2`` rows
        come from the +mean component and the rest from the -mean component.
    """
    n_positive = batch_size // 2
    # The second half takes the remainder so an odd batch_size still yields batch_size rows
    # (two `batch_size // 2` halves would drop one row and desync from the mask/conditional).
    n_negative = batch_size - n_positive
    event_shape = list(event_shape)
    return tf.concat([tf.random.normal([n_positive] + event_shape, mean=mean, stddev=stddev),
                      tf.random.normal([n_negative] + event_shape, mean=-mean, stddev=stddev)], axis=0)


def generate_particles(batch_size, particle_count, particle_feature_dim, mean=2., stddev=1.):
    """Synthetic particle features of shape ``[batch_size, particle_count, particle_feature_dim]``.

    Rows are drawn from a two-component Gaussian mixture centred at ``+mean`` and ``-mean``.
    """
    return _two_gaussian_mixture(batch_size, [particle_count, particle_feature_dim], mean, stddev)


def generate_jets(batch_size, jet_feature_dim, mean=2., stddev=1.):
    """Synthetic jet features of shape ``[batch_size, jet_feature_dim]``.

    Rows are drawn from a two-component Gaussian mixture centred at ``+mean`` and ``-mean``.
    """
    return _two_gaussian_mixture(batch_size, [jet_feature_dim], mean, stddev)


def generate_conditional(batch_size, particle_types):
    """Constant all-ones conditioning input of shape ``(batch_size, particle_types)``."""
    shape = (batch_size, particle_types)
    return tf.ones(shape=shape)


def generate_mask(batch_size, particle_count):
    """All-ones particle mask of shape ``(batch_size, particle_count, 1)`` (every particle is kept)."""
    mask_shape = (batch_size, particle_count, 1)
    return tf.ones(shape=mask_shape)

In [28]:
def generate_batches(batch_size, particle_count, particle_feature_dim, particle_types, jet_feature_dim):
    """Wrap ``batch_size`` synthetic events as four per-example ``tf.data`` datasets.

    Despite the name, nothing is batched here: ``batch_size`` is the number of events, and each
    returned dataset yields one event per element.

    Returns:
        ``(particles, jets, conditionals, masks)`` datasets, in that order, ready for
        ``tf.data.Dataset.zip``.
    """
    tensors = (
        generate_particles(batch_size, particle_count, particle_feature_dim),
        generate_jets(batch_size, jet_feature_dim),
        generate_conditional(batch_size, particle_types),
        generate_mask(batch_size, particle_count),
    )
    return tuple(tf.data.Dataset.from_tensor_slices(tensor) for tensor in tensors)

In [29]:
def make_dataset(n_events, shard_across_ranks=False):
    """Build an infinite, shuffled, batched pipeline over ``n_events`` synthetic events.

    Each element is the zipped ``(particles, jets, conditionals, masks)`` tuple from
    ``generate_batches``; batches contain ``batch_size`` elements.

    Args:
        n_events: number of synthetic events to generate.
        shard_across_ranks: if True, keep only this Horovod rank's shard of the events
            (every ``hvd.size()``-th event starting at ``hvd.rank()``). The global seed is the
            same on every process, so without sharding all ranks would generate identical
            events in an identical shuffled order and train on the same batches.
    """
    dataset = tf.data.Dataset.zip(generate_batches(n_events, particle_count, particle_feature_dim,
                                                   particle_types, jet_feature_dim))
    if shard_across_ranks:
        dataset = dataset.shard(hvd.size(), hvd.rank())
    return dataset.shuffle(event_count).repeat().batch(batch_size)


training_data = make_dataset(int(0.8 * batch_size), shard_across_ranks=True)
# Validation stays unsharded: every rank evaluates the same events and
# MetricAverageCallback averages the per-rank results.
test_data = make_dataset(int(0.1 * batch_size))

In [30]:
model = GSGM(config = config, npart = particle_count)
model_name = config['MODEL_NAME']
checkpoint_folder = '../checkpoints_{}/checkpoint'.format(model_name)

# Optimizer steps per epoch actually taken by model.fit (keep in sync with its
# steps_per_epoch). The cosine schedule must decay over the steps really run: deriving it
# from int(data_size*0.8/BATCH) (= 0.8*event_count steps/epoch) while fit runs only 50
# steps/epoch would keep the LR at essentially its initial value for the whole run.
train_steps_per_epoch = 50

# Initial LR scales linearly with the number of Horovod workers.
lr_schedule = tf.keras.experimental.CosineDecay(
    initial_learning_rate=config['LR']*hvd.size(),
    decay_steps=config['MAXEPOCH']*train_steps_per_epoch
)

opt = tf.keras.optimizers.Adamax(learning_rate=lr_schedule)
# Wrap so gradients are all-reduced across ranks. average_aggregated_gradients only has an
# effect when backward_passes_per_step > 1 (the default is 1).
opt = hvd.DistributedOptimizer(opt, average_aggregated_gradients=True)

In [31]:
# Callback order matters:
# 1. broadcast rank-0 variables to all workers;
# 2. average epoch metrics across ranks;
# 3. let EarlyStopping (monitoring val_loss by default) act on the averaged value.
callbacks = [
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    hvd.callbacks.MetricAverageCallback(),
    EarlyStopping(patience=100, restore_best_weights=True),
]

model.compile(
    optimizer=opt,
    weighted_metrics=[],
    experimental_run_tf_function=False,
)

In [32]:
# Only rank 0 writes checkpoints so Horovod workers do not race on the same files. The
# isinstance guard keeps the cell idempotent: re-running it will not register a second
# ModelCheckpoint on the shared `callbacks` list.
if hvd.rank() == 0 and not any(isinstance(cb, ModelCheckpoint) for cb in callbacks):
    # save_freq='epoch' replaces the deprecated period=1: save weights after every epoch.
    checkpoint = ModelCheckpoint(checkpoint_folder, mode='auto',
                                 save_freq='epoch', save_weights_only=True)
    callbacks.append(checkpoint)



In [35]:
history = model.fit(
    training_data,
    validation_data=test_data,
    epochs=config['MAXEPOCH'],
    # Debug-sized epochs. Full-size values would be int(data_size*0.8/config['BATCH'])
    # training steps and int(data_size*0.1/config['BATCH']) validation steps.
    steps_per_epoch=50,
    validation_steps=10,
    callbacks=callbacks,
    # Only rank 0 prints progress so the Horovod workers' logs do not interleave.
    verbose=int(hvd.rank() == 0),
)

Epoch 1/250
Epoch 2/250
Epoch 3/250
Epoch 4/250
Epoch 5/250
Epoch 6/250
Epoch 7/250
Epoch 8/250
Epoch 9/250
Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
 7/50 [===>..........................] - ETA: 1s - loss: 2.6623 - loss_part: 0.9922 - loss_jet: 1.6701

KeyboardInterrupt: 