In [1]:
import tensorflow as tf
import numpy as np
import params
import utils
import os
import tensorflow_probability as tfp
import data_preparation
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Short aliases used throughout the notebook.
tfd = tfp.distributions
keras = tf.keras
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Let TensorFlow allocate GPU memory on demand.  Iterating over the device
# list (instead of indexing gpus[0]) keeps the notebook runnable when no GPU
# is visible and applies the same setting to every GPU, which TensorFlow
# requires when more than one is present.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# One output unit per entry of params.brand_models.
NUM_CLASSES = len(params.brand_models)

In [2]:
class constrain_conv(tf.keras.models.Model, tf.keras.callbacks.Callback):
    """Callback that re-projects the first layer's kernel before every batch.

    After projection, every (input-channel, output-channel) slice of the
    kernel has its centre tap fixed to -1 and all remaining taps normalised
    so that they sum to 1.

    NOTE(review): this looks like a "constrained convolution" prediction-error
    filter -- confirm against the intended reference.
    NOTE(review): only the Callback hook is used here; the additional
    tf.keras.models.Model base is kept unchanged for compatibility -- confirm
    whether it is actually needed.

    Args:
        model: model whose first layer (``model.layers[0]``) is constrained.
            That layer must expose ``[kernel, bias]`` through ``get_weights``.
    """

    def __init__(self, model):
        super(constrain_conv, self).__init__()
        self.layer = model.layers[0]
        # Last kernel written by this callback; None until the first batch.
        self.pre_weights = None

    def on_batch_begin(self, batch, logs=None):
        """Re-apply the constraint if the kernel changed since the last call.

        Args:
            batch: index of the batch about to run (unused).
            logs: Keras log dict (unused).
        """
        weights, bias = self.layer.get_weights()[:2]
        # Any single differing tap means the optimizer has touched the kernel
        # since the last projection, so the constraint must be re-applied.
        if self.pre_weights is None or np.any(self.pre_weights != weights):
            # Centre tap derived from the kernel size rather than assumed 5x5.
            center_row = weights.shape[0] // 2
            center_col = weights.shape[1] // 2
            # Scale up before normalising (kept from the original procedure).
            weights = weights * 10000
            # Exclude the centre tap from the normalising sum.
            weights[center_row, center_col, :, :] = 0
            # Sum over the spatial axes -> shape (in_channels, out_channels);
            # broadcasting normalises every filter slice in one step.
            weights /= np.sum(weights, axis=(0, 1))
            weights[center_row, center_col, :, :] = -1
            self.pre_weights = weights
        self.layer.set_weights([weights, bias])

In [3]:
# Run the data-preparation step before counting files.
# NOTE(review): presumably this populates params.patches_dir -- confirm.
data_preparation.collect_split_extract()

train_dir = os.path.join(params.patches_dir, 'train')
val_dir = os.path.join(params.patches_dir, 'val')

# Filled in later from the per-class image counts.
class_weight = {}

# Number of training patches found for each entry of params.brand_models.
num_images_per_class = [
    len(os.listdir(os.path.join(train_dir, brand_model)))
    for brand_model in params.brand_models
]
train_size = sum(num_images_per_class)
val_size = sum(
    len(os.listdir(os.path.join(val_dir, brand_model)))
    for brand_model in params.brand_models
)

# Ceiling division: batches needed to cover the whole training set once.
num_batches = (train_size + params.BATCH_SIZE - 1) // params.BATCH_SIZE

100%|██████████| 2359/2359 [00:00<00:00, 267277.97it/s]


In [1]:
# "Balanced" class weights: train_size / (n_classes * images_in_class).
# Dividing by the number of classes (rather than a fixed 2.0, which is only
# correct for a two-class problem) gives every class weight 1.0 when the
# dataset is perfectly balanced, so the weighted loss keeps its natural scale.
n_classes = len(params.brand_models)
for n in range(n_classes):
    class_weight[n] = train_size / (n_classes * num_images_per_class[n])
    print('Weight for class {}: {:.2f}'.format(n, class_weight[n]))

def _posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    """Build a mean-field normal posterior for a variational layer.

    Args:
        kernel_size: number of kernel parameters.
        bias_size: number of bias parameters (default 0).
        dtype: dtype forwarded to the underlying VariableLayer.

    Returns:
        A ``tf.keras.Sequential`` made of a VariableLayer holding ``2 * n``
        values (``n = kernel_size + bias_size``; the first ``n`` are the
        locations, the last ``n`` the untransformed scales) followed by a
        DistributionLambda that turns them into an Independent Normal.
    """
    num_params = kernel_size + bias_size
    # Offset added to the untransformed scale inside the softplus below.
    scale_offset = np.log(np.expm1(1e-5))

    loc_initializer = keras.initializers.TruncatedNormal(
        mean=0., stddev=.05, seed=None)
    raw_scale_initializer = keras.initializers.Constant(
        np.log(np.expm1(1e-5)))
    variables = tfp.layers.VariableLayer(
        2 * num_params,
        dtype=dtype,
        initializer=tfp.layers.BlockwiseInitializer(
            [loc_initializer, raw_scale_initializer],
            sizes=[num_params, num_params]))

    def _to_distribution(flat_params):
        # Last n entries -> strictly positive scale; first n entries -> loc.
        raw_scale = flat_params[..., num_params:]
        scale = 1e-5 + tf.nn.softplus(scale_offset + raw_scale)
        loc = flat_params[..., :num_params]
        return tfd.Independent(
            tfd.Normal(loc=loc, scale=scale),
            reinterpreted_batch_ndims=1)

    return tf.keras.Sequential(
        [variables, tfp.layers.DistributionLambda(_to_distribution)])


def _make_prior_fn(kernel_size, bias_size=0, dtype=None):
    """Return a callable that yields a fixed unit-scale normal prior.

    Args:
        kernel_size: number of kernel parameters.
        bias_size: number of bias parameters (default 0).
        dtype: accepted for signature compatibility and discarded.

    Returns:
        A one-argument function (its argument is ignored) returning an
        Independent Normal with zero location of length
        ``kernel_size + bias_size`` and scale 1.
    """
    del dtype  # TODO(yovadia): Figure out what to do with this.
    prior_loc = tf.zeros(kernel_size + bias_size)

    def _unit_normal_prior(_):
        base = tfd.Normal(loc=prior_loc, scale=1)
        return tfd.Independent(base, reinterpreted_batch_ndims=1)

    return _unit_normal_prior


def make_divergence_fn_for_empirical_bayes(std_prior_scale, examples_per_epoch):
    """Create the divergence callable used with the empirical-Bayes prior.

    Args:
        std_prior_scale: scale of the LogNormal(0, scale) distribution
            evaluated at the prior's standard deviation.
        examples_per_epoch: divisor applied to the final value.

    Returns:
        A function ``(q, p, _)`` computing
        ``(KL(q, p) - sum(log_prob(p.stddev()))) / examples_per_epoch``.
    """
    def divergence_fn(q, p, _):
        stddev_prior = tfd.LogNormal(0., std_prior_scale)
        stddev_log_probs = stddev_prior.log_prob(p.stddev())
        penalised_kl = tfd.kl_divergence(q, p) - tf.reduce_sum(stddev_log_probs)
        return penalised_kl / examples_per_epoch

    return divergence_fn


def make_prior_fn_for_empirical_bayes(init_scale_mean=-1, init_scale_std=0.1):
    """Returns a prior function with stateful parameters for EB models.

    Args:
        init_scale_mean: mean of the random-normal initializer of the
            untransformed prior scale.
        init_scale_std: stddev of that same initializer.

    Returns:
        A function ``(dtype, shape, name, _, add_variable_fn)`` that creates
        the prior's variables through ``add_variable_fn`` and returns an
        Independent Normal prior built from them.
    """
    def prior_fn(dtype, shape, name, _, add_variable_fn):
        """A prior for the variational layers."""
        raw_scale_initializer = tf.compat.v1.initializers.random_normal(
            mean=init_scale_mean, stddev=init_scale_std)
        # Single shared (shape (1,)) scale parameter, created non-trainable.
        raw_scale = add_variable_fn(
            name=name + '_untransformed_scale',
            shape=(1,),
            initializer=raw_scale_initializer,
            dtype=dtype,
            trainable=False)
        # Per-weight location, zero-initialised and trainable.
        prior_loc = add_variable_fn(
            name=name + '_loc',
            initializer=keras.initializers.Zeros(),
            shape=shape,
            dtype=dtype,
            trainable=True)
        # The 1e-6 floor keeps the scale strictly positive after softplus
        # (presumably for numerical stability -- TODO confirm the constant).
        prior_scale = 1e-6 + tf.nn.softplus(raw_scale)
        normal = tfd.Normal(loc=prior_loc, scale=prior_scale)
        # Treat every batch dimension of the Normal as part of one event.
        event_ndims = tf.size(input=normal.batch_shape_tensor())
        return tfd.Independent(normal, reinterpreted_batch_ndims=event_ndims)

    return prior_fn

# Hyper-parameters for the empirical-Bayes prior and divergence.  These are
# plain floats: a trailing comma after the value would silently turn it into
# a one-element tuple.
# NOTE(review): init_prior_scale_std is negative although it is passed on as
# the `stddev` of a random-normal initializer -- confirm the sign is intended.
init_prior_scale_mean = -1.9994
init_prior_scale_std = -0.30840
std_prior_scale = 3.4210

In [1]:
# Prior and divergence callables handed to the Flipout layers of the model.
eb_prior_fn = make_prior_fn_for_empirical_bayes(
    init_prior_scale_mean,
    init_prior_scale_std,
)
divergence_fn = make_divergence_fn_for_empirical_bayes(
    std_prior_scale,
    train_size,
)


def kl_divergence_function(q, p, _):
    """KL(q, p) divided by the training-set size (third argument ignored)."""
    num_examples = tf.cast(train_size, dtype=tf.float32)
    return tfd.kl_divergence(q, p) / num_examples
# Keyword arguments shared by every Flipout layer ...
_flipout_common = dict(
    kernel_prior_fn=eb_prior_fn,
    kernel_divergence_fn=divergence_fn,
)
# ... and the extra ones shared by the convolutional Flipout layers.
_conv_flipout_common = dict(
    padding='SAME',
    activation=tf.nn.selu,
    **_flipout_common
)

model_layers = [
    # First layer: plain 5x5 convolution with 3 filters.
    tf.keras.layers.Conv2D(3, (5, 5), padding='same'),
    # Flipout conv stage 1: 96 filters, 7x7, stride 2, then 3x3 max-pool.
    tfp.layers.Convolution2DFlipout(
        96, kernel_size=7, strides=2, **_conv_flipout_common),
    tf.keras.layers.MaxPool2D(pool_size=[3, 3], strides=2, padding='SAME'),
    # Flipout conv stage 2: 64 filters, 5x5, stride 1, then 3x3 max-pool.
    tfp.layers.Convolution2DFlipout(
        64, kernel_size=5, strides=1, **_conv_flipout_common),
    tf.keras.layers.MaxPool2D(pool_size=[3, 3], strides=1, padding='SAME'),
    # Flipout conv stage 3: 128 filters, 1x1, stride 1, then 3x3 max-pool.
    tfp.layers.Convolution2DFlipout(
        128, kernel_size=1, strides=1, **_conv_flipout_common),
    tf.keras.layers.MaxPool2D(pool_size=[3, 3], strides=2, padding='SAME'),
    tf.keras.layers.Flatten(),
    # Disabled intermediate dense layer, kept for reference:
    #     tfp.layers.DenseFlipout(
    #         50, kernel_divergence_fn=kl_divergence_function,
    #         activation=tf.nn.selu),
    # Output layer: one unit per class, no activation.
    tfp.layers.DenseFlipout(NUM_CLASSES, **_flipout_common),
]
model = tf.keras.models.Sequential(model_layers)

In [1]:
def _patch_dataset(pattern_suffix, shuffle_buffer=None):
    """Build a batched dataset of parsed patches.

    Args:
        pattern_suffix: glob suffix appended to params.patches_dir.
        shuffle_buffer: if given, size of the shuffle buffer applied to the
            file list before parsing; None means no extra shuffle step.

    Returns:
        A tf.data.Dataset of batches of size params.BATCH_SIZE whose elements
        come from data_preparation._parse_image.
    """
    files = tf.data.Dataset.list_files(params.patches_dir + pattern_suffix)
    if shuffle_buffer is not None:
        files = files.shuffle(buffer_size=shuffle_buffer)
    parsed = files.map(data_preparation._parse_image,
                       num_parallel_calls=AUTOTUNE)
    return parsed.batch(params.BATCH_SIZE)


# Only the training pipeline gets the explicit shuffle step.
train_ds = _patch_dataset('/train/*/*', shuffle_buffer=1000)
val_ds = _patch_dataset('/val/*/*')
test_ds = _patch_dataset('/test/*/*')

# `learning_rate` is the supported argument name; `lr` is only a deprecated
# alias of it.  The loss receives raw logits (the last layer has no
# activation), hence from_logits=True.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'],
              experimental_run_tf_function=False)
# Build with a fixed 256x256 single-channel input and unspecified batch size.
model.build(input_shape=[None, 256, 256, 1])
# Callback that re-constrains the kernel of the model's first layer.
constrain_conv_layer = constrain_conv(model)

# Create a callback that saves the model's weights whenever the validation
# accuracy reaches a new maximum.
ckpts_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./ckpts/dense/',
                                                 save_weights_only=True,
                                                 monitor='val_accuracy', mode='max',
                                                 save_best_only=True,
                                                 verbose=1)
# One TensorBoard log directory per run, named after the start time.
logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

model.summary()

In [1]:
# Select the Flipout layers once; the three lists below stay aligned.
flipout_layers = [layer for layer in model.layers
                  if 'flipout' in layer.name]
names = [layer.name for layer in flipout_layers]
# Posterior means of each layer's kernel.
qm_vals = [layer.kernel_posterior.mean() for layer in flipout_layers]
# Posterior standard deviations (i.e. the value after the softplus).
qs_vals = [layer.kernel_posterior.stddev() for layer in flipout_layers]

utils.plot_weight_posteriors(names, qm_vals, qs_vals, fname="weight.png")
# Summary for the first Flipout layer.  The second number is an average
# standard deviation, so it is labelled as such rather than as a variance.
print("mean of mean is {}, mean stddev is {}".
      format(tf.reduce_mean(qm_vals[0]),
      tf.reduce_mean(qs_vals[0])))

In [1]:
# Train for 10 epochs with per-class loss weights, validating on val_ds.
training_callbacks = [
    constrain_conv_layer,   # re-constrains the first layer's kernel
    ckpts_callback,         # saves the best weights
    tensorboard_callback,   # writes TensorBoard logs
]
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
    class_weight=class_weight,
    callbacks=training_callbacks,
)

In [1]:
# Evaluate on the test dataset; the returned values (the loss and the
# compiled 'accuracy' metric) are displayed as the cell output.
model.evaluate(test_ds, verbose=1)