In [None]:
# Source: https://keras.io/getting_started/intro_to_keras_for_researchers/

# HyperNetworks
A hypernetwork is a deep neural network whose weights are generated by another network (usually smaller).

Let's implement a really trivial hypernetwork: we'll use a small 2-layer network (16 hidden units) to generate the weights of a 2-layer classifier (a 64-unit hidden layer followed by a 10-class output layer).

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
# Problem dimensions: length of one flattened input vector and number of labels.
input_dim = 784
classes = 10

# The hypernetwork itself: the model that maps an input to class logits.
# Its kernels and biases are supplied externally rather than learned directly.
outer_model = keras.Sequential()
outer_model.add(keras.layers.Dense(64, activation=tf.nn.relu))
outer_model.add(keras.layers.Dense(classes))

# Flag every layer as built up front so that calling `outer_model` does not
# allocate fresh variables; the weights are attached by hand before each call.
for dense_layer in outer_model.layers:
    dense_layer.built = True


In [3]:
# Count the scalars the weight generator has to emit. A dense layer needs
# input_dim * output_dim kernel entries plus output_dim bias entries.
hidden_layer_coeffs = 64 * input_dim + 64
output_layer_coeffs = classes * 64 + classes
num_weights_to_generate = output_layer_coeffs + hidden_layer_coeffs

# The weight generator: maps an input to one flat vector holding every kernel
# and bias value of `outer_model`. The sigmoid keeps each value in (0, 1).
inner_model = keras.Sequential()
inner_model.add(keras.layers.Dense(16, activation=tf.nn.relu))
inner_model.add(
    keras.layers.Dense(num_weights_to_generate, activation=tf.nn.sigmoid)
)

In [5]:
# Objective and optimizer. The outer model's last layer has no activation,
# so the loss is told to expect raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Load the MNIST training split; the second split returned is discarded.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

# Flatten every image into a 784-vector, cast to float32 and divide by 255.
flat_images = x_train.reshape(60000, 784).astype("float32") / 255

# Shuffle with a 1024-element buffer and feed one example per step:
# this experiment uses a batch size of 1.
dataset = (
    tf.data.Dataset.from_tensor_slices((flat_images, y_train))
    .shuffle(buffer_size=1024)
    .batch(1)
)


@tf.function
def train_step(x, y):
    """Run one gradient step that trains `inner_model` through `outer_model`.

    `inner_model` turns `x` into a flat vector of coefficients. That vector is
    cut into the kernels and biases of the two layers of `outer_model`, which
    then classifies the same `x`. Only the weights of `inner_model` are
    updated.

    Args:
        x: Input batch. Each slice of the generated vector is reshaped straight
            to a weight shape with no batch axis, so the batch must hold
            exactly one example.
        y: Integer label(s) for `x`, passed to the sparse cross-entropy loss.

    Returns:
        The loss value for this step.
    """
    # Target layer, attribute name and shape of every generated tensor, listed
    # in the order in which they are packed inside the inner model's output.
    layout = (
        (outer_model.layers[0], "kernel", (input_dim, 64)),
        (outer_model.layers[0], "bias", (64,)),
        (outer_model.layers[1], "kernel", (64, classes)),
        (outer_model.layers[1], "bias", (classes,)),
    )

    with tf.GradientTape() as tape:
        # Generate all outer-model coefficients for this input.
        flat_weights = inner_model(x)

        # Walk through the flat vector, carving out one tensor at a time and
        # attaching it to the outer model in place of a trainable variable.
        offset = 0
        for target_layer, attr_name, shape in layout:
            size = np.prod(shape)
            chunk = flat_weights[:, offset : offset + size]
            setattr(target_layer, attr_name, tf.reshape(chunk, shape))
            offset += size

        # Forward pass through the outer model using the generated weights.
        preds = outer_model(x)
        loss = loss_fn(y, preds)

    # Gradients flow back through the generated weights into the inner model,
    # which is the only model that owns trainable variables here.
    trainable = inner_model.trainable_weights
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))
    return loss

In [6]:
# Index of the last training step to run. Training the model to convergence
# is left as an exercise to the reader.
max_steps = 10000

losses = []  # Per-step loss values, kept to report a running average.
for step, (x, y) in enumerate(dataset):
    loss = train_step(x, y)

    # Logging: the value printed is the mean loss over all steps so far,
    # not the loss of the current step alone.
    losses.append(float(loss))
    if step % 100 == 0:
        print("Step:", step, "Loss:", sum(losses) / len(losses))

    # Stop once step index `max_steps` has been processed.
    if step >= max_steps:
        break

Step: 0 Loss: 4.6243977546691895
Step: 100 Loss: 2.6400829122798277
Step: 200 Loss: 2.2352249455818933
Step: 300 Loss: 2.048253087060816
Step: 400 Loss: 1.9136626055450394
Step: 500 Loss: 1.779765464322277
Step: 600 Loss: 1.7298317057680053
Step: 700 Loss: 1.637547222141946
Step: 800 Loss: 1.5927312879532527
Step: 900 Loss: 1.5260903194025564
Step: 1000 Loss: 1.4715679380608213
Step: 1100 Loss: 1.4332551930922335
Step: 1200 Loss: 1.3952297396319036
Step: 1300 Loss: 1.3691044495889846
Step: 1400 Loss: 1.343293452130987
Step: 1500 Loss: 1.3189049076194086
Step: 1600 Loss: 1.3087429432102997
Step: 1700 Loss: 1.2856774886782134
Step: 1800 Loss: 1.2653956043660983
Step: 1900 Loss: 1.2507410198291078
Step: 2000 Loss: 1.2228993528042147
Step: 2100 Loss: 1.2129999295344809
Step: 2200 Loss: 1.196503938620133
Step: 2300 Loss: 1.1844519183344342
Step: 2400 Loss: 1.1654086992530375
Step: 2500 Loss: 1.1604776558167191
Step: 2600 Loss: 1.154352454285118
Step: 2700 Loss: 1.1478716704938412
Step: 2800