<a href="https://colab.research.google.com/github/hummosa/Hypernetworks_Keras_TF2/blob/master/Hypernetworks_in_keras_and_tf2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Guide to Hypernetworks in Keras and Tensorflow 2.0

This is a keras implementation of hypernetworks, which are typically a pair of networks where one generates the parameters (weights) of the other [(Ha, et al., 2016)](https://arxiv.org/abs/1609.09106). Keras layer implementation exposes the parameters of a layer as two modifiable properties: ‘kernel’ and ‘bias’, which allows assiging them new values during inference. 

This code will separate the hypernetwork into two keras models: an inference model which will perform the inference task  (for e.g. classify handwritten digits), and a hyper model, which will generate the parameters of the inference model for each input example. We will demonstrate using a convolutional network as the inference model to classify MNIST digits.


In [0]:
!pip install tf-nightly-gpu-2.0-preview

Collecting tf-nightly-gpu-2.0-preview
[?25l  Downloading https://files.pythonhosted.org/packages/74/74/c6d43d2a9a26ea430acd804068bbaf2fdf14dd470f543d113588b1f80828/tf_nightly_gpu_2.0_preview-2.0.0.dev20190611-cp36-cp36m-manylinux1_x86_64.whl (343.8MB)
[K     |████████████████████████████████| 343.8MB 50kB/s 
Collecting tensorflow-estimator-2.0-preview (from tf-nightly-gpu-2.0-preview)
[?25l  Downloading https://files.pythonhosted.org/packages/ff/d5/3ef112818fe4181e9bf1a1233a0cd8f9ffb63f386f268701aa205bf04384/tensorflow_estimator_2.0_preview-1.14.0.dev2019061100-py2.py3-none-any.whl (436kB)
[K     |████████████████████████████████| 440kB 41.4MB/s 
[?25hCollecting google-pasta>=0.1.6 (from tf-nightly-gpu-2.0-preview)
[?25l  Downloading https://files.pythonhosted.org/packages/d0/33/376510eb8d6246f3c30545f416b2263eee461e40940c2a4413c711bdf62d/google_pasta-0.1.7-py3-none-any.whl (52kB)
[K     |████████████████████████████████| 61kB 21.6MB/s 
Collecting tb-nightly<1.15.0a0,>=1.14.0a0 

In [0]:
# Import tensorflow and check version
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

print('tensorflow version: {}'.format(tf.__version__))
tf.keras.backend.clear_session()

tensorflow version: 2.0.0-dev20190611


For this tutorial will use the MNIST dataset to demonstrate the setup. The following code will download and prepare the MNIST dataset.

In [0]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# convert to float32 and normalize. 
x_train = x_train.astype('float32') /255
x_test = x_test.astype('float32')   /255

# one-hot encode the labels 
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
# add a channel dimension to the images
x_train = x_train.reshape(x_train.shape[0], 28, 28,1)
x_test = x_test.reshape(x_test.shape[0], 28, 28,1)


# Define image dimensions
img_h = 28
img_w = 28
img_c = 1

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


### Inference model:
We now build the inference model, a simple convolutional network, with a fully connected layer on top, that we will use to classify MNIST digits. 


In [0]:
infer_model = tf.keras.models.Sequential(name='infer_model')
infer_model.add(tf.keras.layers.Input(shape=(img_h, img_w, img_c), name='input_x' ))
infer_model.add(tf.keras.layers.Conv2D(32, (3,3), activation='relu') )
infer_model.add(tf.keras.layers.MaxPool2D() )
infer_model.add(tf.keras.layers.Conv2D(32, (3,3), activation='relu') )
infer_model.add(tf.keras.layers.MaxPool2D() ) 
infer_model.add(tf.keras.layers.Flatten() )

infer_model.add(tf.keras.layers.Dense(10, activation= 'softmax', name='out_layer') )

infer_model.summary()

Model: "infer_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 32)        9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 800)               0         
_________________________________________________________________
out_layer (Dense)            (None, 10)                8010      
Total params: 17,578
Trainable params: 17,578
Non-trainable params: 0
___________________________________________________

Note that this model has a total of 17,578 parameters that need to be generated by the hyper model.  

### Hyper model:
Let us now define the hyper model with 2 convolutional layers and a fully connected layer on top to produce a latent embedding of size 784. The embedding is then fed into a stack of 3 transpose convolutional layers that produce a large number of values, which will be used as parameters for the inference model.

Note that the last layer uses a tanh activation function which produces values between -1 and 1. This allows generation of parameters with negative values. 

In [0]:
hyper_model_x = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(img_h, img_w, img_c)),
        tf.keras.layers.Conv2D(16, (3,3), activation='relu') ,
        tf.keras.layers.MaxPool2D() ,
        tf.keras.layers.Conv2D(8, (3,3), activation='relu') ,
        tf.keras.layers.MaxPool2D() ,
        tf.keras.layers.Flatten() ,
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(units=784, activation=tf.nn.relu),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2DTranspose(
            filters=16, 
            kernel_size=3,
            strides=(2, 2),
            padding="same",
            activation=tf.nn.relu),
        tf.keras.layers.Conv2DTranspose(
            filters=8,  
            kernel_size=3,
            strides=(2, 2),
            padding="same",
            activation=tf.nn.relu),
        tf.keras.layers.Conv2DTranspose(
            filters=2, kernel_size=3, strides=(1, 1), padding="SAME", activation='tanh'),
        tf.keras.layers.Flatten()
    ], name='hyper_model'
)


To apply parameters to the inference model, we define a function 'parametrize_model', which consumes a tensor of generated parameters to parametrize the weights and the biases of each layer in a model. 

In [0]:
def parameterize_model(model, weights):
    # function to parametrizes all the trainable variables of model using the stream of weight values in weights
    # This assumes weights are passed a single batch.
    weights = tf.reshape( weights, [-1] ) # reshape the parameters to a vector
    
    last_used = 0
    for i in range(len(model.layers)):
        # check to make sure only conv and fully connected layers are assigned weights.
        if 'conv' in model.layers[i].name or 'out' in model.layers[i].name or 'dense' in model.layers[i].name: 
            weights_shape = model.layers[i].kernel.shape
            no_of_weights = tf.reduce_prod(weights_shape)
            new_weights = tf.reshape(weights[last_used:last_used+no_of_weights], weights_shape) 
            model.layers[i].kernel = new_weights
            last_used += no_of_weights
            
            if model.layers[i].use_bias:
              weights_shape = model.layers[i].bias.shape
              no_of_weights = tf.reduce_prod(weights_shape)
              new_weights = tf.reshape(weights[last_used:last_used+no_of_weights], weights_shape) 
              model.layers[i].bias = new_weights
              last_used += no_of_weights


### The training loop:
We are now ready to define the main training loop. Eager execution is enabled by default in tensorflow 2.0, which provides more control over the training process. Note that the loss function is differentiated with respect to the hyper model parameters only. In fact, the parameters of the inference model are no longer considered trainable by keras (can check by running infer_model.summary()). This loop updates the parameters of the hyper model only.

In [0]:
# Define accuracy metrics for validation
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(1e-3) 

loss_accum = 0.0
batch_size = 1
for step in range(1, 6001):
  idx = np.random.randint(low=0, high=x_train.shape[0], size=batch_size)
  x, y = x_train[idx], y_train[idx]
  
  with tf.GradientTape() as tape:
    # Predict weights for the infer model
    generated_parameters = hyper_model_x(x)
    parameterize_model(infer_model, generated_parameters)    
    
    # Inference on the infer model
    preds = infer_model(x)

    loss = loss_fn( y, preds)
    loss_accum += loss
    train_acc_metric( y, tf.expand_dims(preds, 0)) # update the acc metric

    if step % 1000 == 0: 
      loss_accum = 0.0
      var = generated_parameters.numpy()
      print('statistics of the generated parameters: '+'Mean, {:2.3f}, var {:2.3f}, min {:2.3f}, max {:2.3f}'.format(var.mean(), var.var(), var.min(), var.max()))
      for val_step in range(500): # 
        idx = np.random.randint(low=0, high=x_test.shape[0], size=batch_size)
        x, y = x_test[idx], y_test[idx]
        generated_parameters = hyper_model_x(x)
        parameterize_model(infer_model, generated_parameters)    
        preds = infer_model(x)
        val_acc_metric( y, tf.expand_dims(preds, 0)) # update the acc metric
      print('\n Step: {}, validation set accuracy: {:2.2f}     loss: {:2.2f}'.format(step, float(val_acc_metric.result()), loss_accum))
      val_acc_metric.reset_states()
         
        
    # Train only hyper model
    grads = tape.gradient(loss, hyper_model_x.trainable_weights)
    optimizer.apply_gradients(zip(grads, hyper_model_x.trainable_weights))

  

statistics of the generated parameters: Mean, -0.074, var 0.033, min -0.991, max 0.995

 Step: 1000, validation set accuracy: 0.97     loss: 0.00
statistics of the generated parameters: Mean, -0.087, var 0.069, min -1.000, max 0.999

 Step: 2000, validation set accuracy: 0.96     loss: 0.00
statistics of the generated parameters: Mean, -0.069, var 0.028, min -0.990, max 0.980

 Step: 3000, validation set accuracy: 0.92     loss: 0.00
statistics of the generated parameters: Mean, -0.074, var 0.024, min -0.996, max 0.983

 Step: 4000, validation set accuracy: 0.95     loss: 0.00
statistics of the generated parameters: Mean, -0.089, var 0.062, min -1.000, max 0.999

 Step: 5000, validation set accuracy: 0.94     loss: 0.00
statistics of the generated parameters: Mean, -0.082, var 0.050, min -0.999, max 0.998

 Step: 6000, validation set accuracy: 0.97     loss: 0.00


Two issues arise in building a hypernetwork in keras. First is managing mini-batching to speed up training on GPUs, and the second relates to initializing the weights. 
The need for a different weight matrix for each input sample introduces significant challenges in using mini-batches during training. While it is possible to create custom keras layers that can handle storing a batch of weights and biases for each layer, we kept batch_size at 1 for the purposes of this tutorial.
On another front, the intial values of a neural network parameters may significantly impact training dynamics. Keras, by default, uses the [Glorot initializer ](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf), which factors in the number of connections coming from previous layers. In a hypernetwork setup, assigning weights directly to a network sidesteps this initilaization that keras normally handles automatically. Accordingly, it is important to moniter and the statistics of the generated parameters and consider how their mean, variance and range might affect training dynamics. For the purposes of this guide, we found it important to use a tanh activation function in the last layer of the hyper model, which provided values cenetered around 0 with a range from -1 to 1.




Finally let's look at a histogram of the generated parameters. There is a sharp peak around very small negative values  around -0.1.


In [0]:
_ = plt.hist(generated_parameters, bins=100)

####References:
1) Ha, D., Dai, A., & Le, Q. V. (2016). Hypernetworks. arXiv preprint arXiv:1609.09106.
    
2) Glorot, X. & Bengio, Y.. (2010). Understanding the difficulty of training deep feedforward neural networks. Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, in PMLR 9:249-256
