In [1]:
import importlib
import logging
import os

# Set before importing TensorFlow so the native (C++) log filter also applies
# to the messages emitted while TensorFlow itself is being imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import keras.backend as K
from functions import load_train_data

# Reload `logging` before configuring it; presumably this clears handlers that
# were installed earlier in the kernel so that basicConfig() takes effect.
importlib.reload(logging)
logging.basicConfig(level = logging.INFO)

# Upper bound passed to VirtualDeviceConfiguration(memory_limit=...).
GPU_MEMORY_LIMIT = 20000

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to the first physical GPU and cap its memory.
    first_gpu = gpus[0]
    try:
        tf.config.experimental.set_visible_devices(first_gpu, 'GPU')
        tf.config.experimental.set_virtual_device_configuration(
            first_gpu,
            [tf.config.experimental.VirtualDeviceConfiguration(
                memory_limit=GPU_MEMORY_LIMIT)])
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Device configuration is rejected once the runtime is initialised;
        # report it and continue with whatever devices are already set up.
        print(e)
else:
    # Previously `gpus[0]` raised an uncaught IndexError on GPU-less machines.
    print("No GPU detected; running on CPU")

2 Physical GPUs, 1 Logical GPU


In [2]:
# Dataset variant to train on. Options listed by the author:
#   3flavor_clean, 3flavor_poisson, nsi_clean, nsi_poisson
# The same name selects the loss weights (`alpha`) and the directory the
# weights are saved to further down in this notebook.
learn_target = 'nsi_poisson'

# load_train_data comes from the project-local `functions` module and returns
# the training and validation inputs/targets for the chosen variant.
x_train, y_train, x_val, y_val = load_train_data(learn_target)

In [3]:
def prior(kernel_size, bias_size, dtype=None):
    """Build the fixed prior over a DenseVariational layer's weights.

    The prior is a diagonal multivariate normal with zero mean and a standard
    deviation of 0.1 for every kernel and bias entry. It has no trainable
    variables.

    Args:
        kernel_size: Number of kernel parameters of the layer.
        bias_size: Number of bias parameters of the layer.
        dtype: Floating-point dtype of the distribution parameters. Defaults
            to float32 when None, which is what `tf.zeros`/`tf.ones` produced
            before this argument was honoured.

    Returns:
        A `tf.keras.Sequential` that maps any input tensor to the prior
        distribution (the input value itself is ignored).
    """
    n = kernel_size + bias_size
    # Use the dtype requested by the caller so the prior matches the posterior
    # (which already forwards `dtype`); the two are compared in the KL term.
    param_dtype = tf.float32 if dtype is None else dtype
    prior_model = tf.keras.Sequential([
            tfp.layers.DistributionLambda(
                lambda t: tfp.distributions.MultivariateNormalDiag(
                    loc=tf.zeros(n, dtype=param_dtype),
                    scale_diag=tf.ones(n, dtype=param_dtype)/10))])
    return prior_model

def posterior(kernel_size, bias_size, dtype=None):
    """Build the trainable posterior over a DenseVariational layer's weights.

    A `VariableLayer` holds the raw parameters and a `MultivariateNormalTriL`
    layer turns them into a multivariate normal over all kernel and bias
    entries.

    Args:
        kernel_size: Number of kernel parameters of the layer.
        bias_size: Number of bias parameters of the layer.
        dtype: dtype forwarded to the `VariableLayer`.

    Returns:
        A `tf.keras.Sequential` producing the posterior distribution.
    """
    total_params = kernel_size + bias_size
    tril_normal = tfp.layers.MultivariateNormalTriL

    # Trainable vector sized to hold every parameter the distribution needs.
    raw_parameters = tfp.layers.VariableLayer(
        tril_normal.params_size(total_params), dtype=dtype
    )
    # Interprets that vector as a multivariate normal over the weights.
    as_distribution = tril_normal(total_params)

    return tf.keras.Sequential([raw_parameters, as_distribution])

In [4]:
def create_bnn_model(train_size, input_dim=None, output_dim=None):
    """Build the regression network with one variational (Bayesian) layer.

    Architecture: BatchNormalization -> Dense(256, relu) -> Dense(128, relu)
    -> DenseVariational(32, sigmoid) -> Dense(16, relu) -> Dense(output_dim,
    linear).

    Args:
        train_size: Number of training samples; the KL term of the
            variational layer is weighted by ``1 / train_size``.
        input_dim: Number of input features. When None, it is taken from the
            notebook-level ``x_train`` (``len(x_train[0])``).
        output_dim: Number of regression targets. When None, it is taken from
            the notebook-level ``y_train`` (``len(y_train[0])``).

    Returns:
        An uncompiled `tf.keras.Model`.

    Raises:
        ValueError: If ``train_size`` is not positive.
    """
    if train_size <= 0:
        # Guards against ZeroDivisionError / a negative KL weight below.
        raise ValueError("train_size must be a positive number of samples")

    # Fall back to the notebook globals so existing calls keep working.
    if input_dim is None:
        input_dim = len(x_train[0])
    if output_dim is None:
        output_dim = len(y_train[0])

    inputs = tf.keras.Input(shape=(input_dim,), name = 'input')
    features = tf.keras.layers.BatchNormalization()(inputs)
    features = tf.keras.layers.Dense(256, activation='relu')(features)
    features = tf.keras.layers.Dense(128, activation='relu')(features)

    # One entry per variational layer; extend the list to stack more of them.
    for units in [32]:
        features = tfp.layers.DenseVariational(
            units=units,
            make_prior_fn=prior,
            make_posterior_fn=posterior,
            kl_weight=1 / train_size,
            activation="sigmoid",
        )(features)
    features = tf.keras.layers.Dense(16, activation='relu')(features)
    features = tf.keras.layers.Dense(output_dim, activation='linear')(features)
    outputs = features
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# Per-target loss weights: every output gets weight 1 except the second one,
# which gets 20. '3flavor' variants use four weights, the others use seven.
if '3flavor' in learn_target:
    alpha = np.array([1, 20, 1, 1])
else:
    alpha = np.array([1, 20, 1, 1, 1, 1, 1])

# Normalise so the weights add up to one.
alpha = alpha / alpha.sum()

def custom_mse(y_true, y_pred):
    """Weighted squared-error loss, one value per sample.

    Squares the element-wise prediction error, scales each target column by
    the module-level ``alpha`` weights and sums over axis 1.

    Args:
        y_true: Ground-truth targets.
        y_pred: Model predictions, same shape as ``y_true``.

    Returns:
        Tensor holding the weighted sum of squared errors along axis 1.
    """
    weighted_sq_err = K.square(y_pred - y_true) * alpha
    return K.sum(weighted_sq_err, axis=1)

In [5]:
# Build the network; the training-set size sets the KL weight of the
# variational layer.
bnn = create_bnn_model(train_size=len(x_train))

# Optimise the weighted loss; plain (unweighted) MSE is tracked as a metric.
bnn.compile(
    loss=custom_mse,
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    metrics=[tf.keras.metrics.MeanSquaredError()],
)

Instructions for updating:
`scale_identity_multiplier` is deprecated; please combine it into `scale_diag` directly instead.


In [6]:
# Train for 100 epochs in shuffled mini-batches of 1024 samples, evaluating
# on the validation split after every epoch.
bnn.fit(
    x=x_train,
    y=y_train,
    batch_size=1024,
    epochs=100,
    verbose=1,
    shuffle=True,
    validation_data=(x_val, y_val),
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x7f4ca4488c40>

In [7]:
# Make sure the per-target output directory exists before writing, so that a
# missing folder cannot fail the save after the whole training run.
weights_dir = os.path.join('./bnn', learn_target)
os.makedirs(weights_dir, exist_ok=True)
bnn.save_weights(os.path.join(weights_dir, 'weight_1.h5'))