In [1]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Concatenate
import tensorflow_probability as tfp
import numpy as np
import os
import importlib
import logging
from functions import load_train_data

# Ask TensorFlow's native logger to hide INFO and WARNING messages.
# NOTE(review): this is assigned after `import tensorflow` has already run in
# this cell, so it may come too late to take effect -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Reload `logging` so that basicConfig() below is applied to a fresh module
# state, then log at INFO level.
importlib.reload(logging)
logging.basicConfig(level = logging.INFO)

# Memory cap handed to VirtualDeviceConfiguration(memory_limit=...) for the
# single logical GPU created below.
GPU_MEMORY_LIMIT = 40000

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to the first physical GPU and cap its memory.
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT)])
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Device configuration failed (e.g. attempted too late); report it and
        # carry on with TensorFlow's default device setup.
        print(e)
else:
    # Previously `gpus[0]` raised an uncaught IndexError on CPU-only machines.
    print("No GPU detected; TensorFlow will run on CPU.")

2 Physical GPUs, 1 Logical GPU


In [2]:
# Dataset variant to learn. Available choices:
#   3flavor_clean, 3flavor_poisson, nsi_clean, nsi_poisson
learn_target = "3flavor_clean"

# Training / validation arrays returned by the project helper for that variant.
(x_train, y_train,
 x_val, y_val) = load_train_data(learn_target)

Encoder 1 (parameter + spectrum)

In [3]:
def encoder1(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node):
    """Build the ``encoder_1`` model: (parameter, spectrum) -> (z_mean, z_log_var).

    The spectrum input is batch-normalised, concatenated with the parameter
    input, and passed through three ReLU dense layers followed by two linear
    heads of width ``latent_dim``.

    The input widths are taken from the module-level ``y_train`` (parameter)
    and ``x_train`` (spectrum) arrays, which must be loaded beforehand.

    Args:
        guassian_number: Unused here; kept so all three builders share one signature.
        latent_dim: Width of the ``z_mean`` and ``z_log_var`` outputs.
        latent_dim_2: Unused here.
        x_parameter_node: Sequence whose first three entries are the hidden-layer widths.
        x_spectrum_node: Unused here.
        x_latent_node: Unused here.

    Returns:
        A ``keras.Model`` named ``"encoder_1"`` taking
        ``[parameter, spectrum]`` and returning ``[z_mean, z_log_var]``.
    """
    encoder_parameter_inputs = layers.Input(shape=(len(y_train[0]),),name = 'encoder_parameter_inputs')
    # The shape is a 1-tuple (note the trailing comma), matching the other
    # model builders; previously a bare int was passed.
    encoder_spectrum_inputs = layers.Input(shape=(len(x_train[0]),),name = 'encoder_spectrum_inputs')
    encoder_spectrum_inputs_norm = layers.BatchNormalization()(encoder_spectrum_inputs)
    encoder1_inputs = Concatenate()([encoder_parameter_inputs, encoder_spectrum_inputs_norm])
    x_parameter = layers.Dense(x_parameter_node[0], activation="relu")(encoder1_inputs)
    x_parameter = layers.Dense(x_parameter_node[1], activation="relu")(x_parameter)
    x_parameter = layers.Dense(x_parameter_node[2], activation="relu")(x_parameter)

    # Linear heads (no activation) parameterising the latent distribution.
    z_mean = layers.Dense(latent_dim, name="z_mean")(x_parameter)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x_parameter)

    return keras.Model(inputs=[encoder_parameter_inputs, encoder_spectrum_inputs], outputs=[z_mean, z_log_var], name="encoder_1")

Encoder 2 (spectrum to latent)

In [4]:
def encoder2(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node):
    """Build the ``encoder_2`` model: spectrum -> (z_mean, z_log_var, z_weight).

    The spectrum (width taken from the module-level ``x_train``) is
    batch-normalised and fed through three ReLU dense layers whose widths are
    ``x_spectrum_node[0..2]``. Three linear heads follow:

    * ``z_mean``    -- width ``guassian_number * latent_dim``
    * ``z_log_var`` -- width ``guassian_number * latent_dim``
    * ``z_weight``  -- width ``guassian_number``

    ``latent_dim_2``, ``x_parameter_node`` and ``x_latent_node`` are accepted
    but not used by this builder.

    Returns:
        A ``keras.Model`` named ``"encoder_2"``.
    """
    spectrum_in = layers.Input(
        shape=(len(x_train[0]),),
        name='encoder_spectrum_inputs',
    )

    hidden = layers.BatchNormalization()(spectrum_in)
    # Exactly three hidden layers, indexed explicitly so a too-short width
    # list still raises IndexError.
    for idx in range(3):
        hidden = layers.Dense(x_spectrum_node[idx], activation="relu")(hidden)

    flat_width = guassian_number*latent_dim
    heads = [
        layers.Dense(flat_width, name="z_mean")(hidden),
        layers.Dense(flat_width, name="z_log_var")(hidden),
        layers.Dense(guassian_number, name="z_weight")(hidden),
    ]

    return keras.Model(inputs=spectrum_in, outputs=heads, name="encoder_2")

Decoder (latent + spectrum to parameter)

In [5]:
def decoder(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node):
    """Build the ``decoder`` model: (latent, spectrum) -> (z_mean, z_log_var).

    The latent vector (width ``latent_dim``) is concatenated with the
    batch-normalised spectrum (width taken from the module-level ``x_train``),
    then passed through three ReLU dense layers named ``dense_1``..``dense_3``
    with widths ``x_latent_node[0..2]``. Two linear heads of width
    ``latent_dim_2`` produce the outputs.

    ``guassian_number``, ``x_parameter_node`` and ``x_spectrum_node`` are
    accepted but not used by this builder.

    Returns:
        A ``keras.Model`` named ``"decoder"`` taking ``[latent, spectrum]``.
    """
    latent_in = keras.Input(shape=(latent_dim,), name='decoder_latent_inputs')
    spectrum_in = layers.Input(shape=(len(x_train[0]),), name='decoder_spectrum_inputs')

    spectrum_norm = layers.BatchNormalization()(spectrum_in)
    hidden = Concatenate()([latent_in, spectrum_norm])

    # Layer names are dense_1, dense_2, dense_3 (1-based).
    for idx in range(3):
        hidden = layers.Dense(
            x_latent_node[idx],
            activation="relu",
            name='dense_{}'.format(idx + 1),
        )(hidden)

    out_mean = layers.Dense(latent_dim_2, name="z_mean")(hidden)
    out_log_var = layers.Dense(latent_dim_2, name="z_log_var")(hidden)

    return keras.Model(
        inputs=[latent_in, spectrum_in],
        outputs=[out_mean, out_log_var],
        name="decoder",
    )

In [6]:
class create_cvae_model(keras.Model):
    """Conditional VAE built from ``encoder1``, ``encoder2`` and ``decoder``.

    Training (``train_step``) minimises ``reconstruction_loss + kl_loss``:

    * ``reconstruction_loss`` is the negative mean log-probability of the
      target ``y`` under the diagonal Gaussian produced by the decoder.
    * ``kl_loss`` is the mean of ``log q(z1) - log r(z1)`` for one sample
      ``z1`` drawn from the encoder-1 Gaussian ``q``, where ``r`` is the
      Gaussian mixture produced by encoder 2, multiplied by ``kl_scaling``.

    Note: the network heads named ``*_log_var`` are exponentiated and passed
    directly as ``scale_diag``, i.e. they act as the log of the scale.

    Args:
        guassian_number: Number of mixture components produced by encoder 2.
        latent_dim: Dimensionality of the latent space.
        latent_dim_2: Width of the decoder outputs.
        x_parameter_node: Hidden widths forwarded to ``encoder1``.
        x_spectrum_node: Hidden widths forwarded to ``encoder2``.
        x_latent_node: Hidden widths forwarded to ``decoder``.
        kl_scaling: Multiplier applied to the KL term of the loss.
        **kwargs: Forwarded to ``keras.Model.__init__``.
    """

    # Added to every exponentiated scale so `scale_diag` is never exactly zero.
    SMALL_CONSTANT = 1e-12

    def __init__(self, guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node, kl_scaling, **kwargs):
        super(create_cvae_model, self).__init__(**kwargs)
        # Keep the hyperparameters on the instance so the training/evaluation
        # code uses the values this model was built with, instead of whatever
        # module-level variables of the same name currently hold.
        self.gaussian_number = guassian_number
        self.latent_dim = latent_dim
        self.kl_scaling = kl_scaling

        self.encoder1 = encoder1(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node)
        self.encoder2 = encoder2(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node)
        self.decoder = decoder(guassian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node)
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self.val_reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")

    @property
    def metrics(self):
        """Metric objects tracked by this model (train and validation trackers)."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
            self.val_reconstruction_loss_tracker,
                ]

    def _latent_mixture(self, spectrum, training):
        """Return the encoder-2 Gaussian mixture over the latent space for `spectrum`."""
        z2_mean, z2_log_var, z2_weight = self.encoder2(spectrum, training=training)

        # encoder2 emits flat vectors; unfold them to (batch, component, latent).
        component_shape = (-1, self.gaussian_number, self.latent_dim)
        z2_mean = tf.reshape(z2_mean, component_shape)
        z2_log_var = tf.reshape(z2_log_var, component_shape)
        z2_weight = tf.reshape(z2_weight, (-1, self.gaussian_number))

        temp_var_r1 = self.SMALL_CONSTANT + tf.exp(z2_log_var)
        return tfp.distributions.MixtureSameFamily(
            mixture_distribution=tfp.distributions.Categorical(logits=z2_weight),
            components_distribution=tfp.distributions.MultivariateNormalDiag(
                loc=z2_mean,
                scale_diag=temp_var_r1))

    def _parameter_distribution(self, latent, spectrum, training):
        """Return the decoder's diagonal Gaussian over the targets given (latent, spectrum)."""
        reconstruction_mean, reconstruction_var = self.decoder([latent, spectrum], training=training)
        temp_var_r2 = self.SMALL_CONSTANT + tf.exp(reconstruction_var)
        return tfp.distributions.MultivariateNormalDiag(
            loc=reconstruction_mean,
            scale_diag=temp_var_r2)

    def train_step(self, data):
        """Run one optimisation step.

        Args:
            data: ``(x, y)`` where ``x`` is the pair ``(parameter, spectrum)``
                fed to encoder 1 (``x[1]`` is the spectrum) and ``y`` is the
                target scored by the decoder distribution.

        Returns:
            Dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        x, y = data
        spectrum = x[1]
        with tf.GradientTape() as tape:
            # training=True is passed explicitly so the BatchNormalization
            # layers inside the sub-models run in training mode.
            z1_mean, z1_log_var = self.encoder1(x, training=True)

            temp_var_q = self.SMALL_CONSTANT + tf.exp(z1_log_var)
            mvn_q = tfp.distributions.MultivariateNormalDiag(
                loc=z1_mean,
                scale_diag=temp_var_q)

            z1 = mvn_q.sample()

            bimix_gauss = self._latent_mixture(spectrum, training=True)
            reconstruction_parameter = self._parameter_distribution(z1, spectrum, training=True)

            # Single-sample estimate of KL(q || r) evaluated at z1.
            kl_loss = tf.reduce_mean(mvn_q.log_prob(z1) - bimix_gauss.log_prob(z1))*self.kl_scaling
            reconstruction_loss = -1.0*tf.reduce_mean(reconstruction_parameter.log_prob(y))
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def evaluate(self,
               x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None,
               callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False,return_dict=False,
               **kwargs):
        """Compute the reconstruction loss on ``(x, y)`` in a single pass.

        Overrides ``keras.Model.evaluate``. Only ``x`` and ``y`` are used; the
        remaining arguments are accepted for signature compatibility and
        ignored. The latent sample is drawn from the encoder-2 mixture
        (encoder 1 is not used), so only the spectrum ``x[1]`` is read.

        Returns:
            ``{"reconstruction_loss": <running mean>}``.
        """
        spectrum = x[1]

        # Inference mode for the BatchNormalization layers.
        bimix_gauss = self._latent_mixture(spectrum, training=False)
        z3 = bimix_gauss.sample()
        reconstruction_parameter = self._parameter_distribution(z3, spectrum, training=False)

        reconstruction_loss = -1.0*tf.reduce_mean(reconstruction_parameter.log_prob(y))

        self.val_reconstruction_loss_tracker.update_state(reconstruction_loss)
        return {"reconstruction_loss": self.val_reconstruction_loss_tracker.result()}

Model Building, Training, and Saving

In [7]:
# Number of mixture components and latent dimensionality.
gaussian_number, latent_dim = 10, 10
# Decoder output width: one entry per column of y_train.
latent_dim_2 = len(y_train[0])

# Hidden-layer widths. All three names refer to the same list object.
# An alternative tried for the spectrum / latent networks: [256, 64, 16]
x_parameter_node = x_spectrum_node = x_latent_node = [64, 64, 64]

# Loss weighting and optimiser step size.
kl_scaling = 1
lr = 1e-4

In [8]:
# Name of the module-level hyperparameter to sweep, and the values to try.
search_target = 'gaussian_number'
search_value = [10, 15, 20]

# Divisor applied to the spectra before they are fed to the network.
SPECTRUM_SCALE = 1000

if search_target not in globals():
    raise KeyError('unknown hyperparameter: {}'.format(search_target))

# makedirs(exist_ok=True) creates missing parents and tolerates re-running the cell.
tb_root = './tb_log/cvae/{}/{}'.format(learn_target, search_target)
os.makedirs(tb_root, exist_ok=True)

for num in search_value:
    # Apply the searched value. Previously `num` was only used in the output
    # paths, so every run trained the same configuration. Rebinding the
    # module-level variable named by `search_target` makes every reference
    # below (model construction, model_info.txt) see the value under test.
    # NOTE: this leaves the hyperparameter set to the last searched value.
    globals()[search_target] = num

    cvae = create_cvae_model(gaussian_number, latent_dim, latent_dim_2, x_parameter_node, x_spectrum_node, x_latent_node, kl_scaling)
    cvae.compile(optimizer=keras.optimizers.Adam(learning_rate=lr))

    log_dir = '{}/{}'.format(tb_root, num)
    os.makedirs(log_dir, exist_ok=True)

    # Record the exact configuration next to the TensorBoard logs.
    run_config = [
        ('gaussian_number', gaussian_number),
        ('latent_dim', latent_dim),
        ('latent_dim_2', latent_dim_2),
        ('x_parameter_node', x_parameter_node),
        ('x_spectrum_node', x_spectrum_node),
        ('x_latent_node', x_latent_node),
        ('kl_scaling', kl_scaling),
        ('lr', lr),
    ]
    with open(log_dir + "/model_info.txt", 'w') as f:
        for key, value in run_config:
            f.write('{} = {}\n'.format(key, value))

    tensorboard_callback = keras.callbacks.TensorBoard(log_dir = log_dir, histogram_freq = 1)
    cvae.fit(x = [y_train, x_train/SPECTRUM_SCALE],
                y = [y_train],
                batch_size=1000,
                epochs=100,
                # (inputs, targets) tuple, as documented for Model.fit.
                validation_data=([y_val, x_val/SPECTRUM_SCALE], y_val),
                verbose=1,
                shuffle=True,
                callbacks=[tensorboard_callback]
    )

    # Save the three sub-models; make sure the target directory exists first.
    path = "./cvae/{}/{}/{}/".format(learn_target, search_target, num)
    os.makedirs(path, exist_ok=True)
    cvae.encoder1.save(path + "encoder_1.h5")
    cvae.encoder2.save(path + "encoder_2.h5")
    cvae.decoder.save(path + "decoder.h5")

Epoch 1/100
Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
E