<a href="https://colab.research.google.com/github/ferova/MorphNetTutorials/blob/master/notebooks/MorphNet_Keras_ResnetTutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using MorphNet on a Keras-defined model

**Author:** [ferova](https://github.com/ferova)<br>
**Date created:** 2020/08/29<br>
**Last modified:** 2020/12/10<br>
**Description:** MorphNet applied to a ResNet50 model taken from the Keras model zoo and trained with the Keras API.

## Install MorphNet in the notebook





Here we install MorphNet. We also import TensorFlow through its v1 compatibility module and disable eager execution. Both steps are mandatory for MorphNet to work.

In [None]:
!git clone https://github.com/google-research/morph-net
!pip install morph-net/ 

In [None]:
# tf already refers to the v1 compatibility API, so call the function on it directly.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

## Data preparation

We download the training data (CIFAR-10) and define the variables needed to create and train the model. BATCH_SIZE is the training batch size, EPOCHS is the number of training epochs, and the model structure is saved in MorphNet's format every SAVE_MODEL_EVERY epochs.



In [None]:
# Load CIFAR-10, scale pixel values to [0, 1], and one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [None]:
NUM_CLASSES = y_train.shape[1]
INPUT_SIZE_X = x_train.shape[1]
INPUT_SIZE_Y = x_train.shape[2]

BATCH_SIZE = 2048
SAVE_MODEL_EVERY = 20
EPOCHS = 120
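
Keras passes zero-based epoch indices to callbacks, so the export callback defined later in this notebook fires on every epoch index that is a multiple of SAVE_MODEL_EVERY. As a minimal sketch, assuming the constants above, the following lists those epochs (the helper name `saved_epochs` is hypothetical and not part of MorphNet or Keras):

```python
def saved_epochs(epochs, save_every):
    """Return the zero-based epoch indices at which a structure file is written."""
    return [epoch for epoch in range(epochs) if epoch % save_every == 0]


# Same values as EPOCHS and SAVE_MODEL_EVERY in this notebook.
print(saved_epochs(120, 20))  # [0, 20, 40, 60, 80, 100]
```

This is why a file for epoch 100 exists at the end of training, while one for epoch 120 does not.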

## Model definition

The model is taken from the Keras model zoo. We build our model directly from the output tensor of the base model so that all of its layers are available in a single model instead of being nested as a sub-model. We use GlobalAveragePooling2D because MorphNet does not accept Flatten layers.

In [None]:
base_model = tf.keras.applications.ResNet50(
    include_top=False, input_shape=(INPUT_SIZE_X, INPUT_SIZE_Y, 3))

x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(1024, activation="relu")(x)
logits = tf.keras.layers.Dense(NUM_CLASSES)(x)
model = tf.keras.Model(inputs=base_model.input, outputs=logits)
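
For intuition about the tensor shapes flowing through this head, here is a framework-free sketch. It assumes the standard ResNet50 properties (an overall stride of 32 and 2048 output channels) and an input size that is a multiple of 32. The helper `head_shapes` is hypothetical and only does arithmetic; it does not build or inspect the Keras model.

```python
def head_shapes(input_size, num_classes, backbone_stride=32,
                backbone_channels=2048, hidden_units=1024):
    """Per-example shapes after each stage of the model head (batch axis omitted)."""
    spatial = input_size // backbone_stride
    return {
        "backbone_output": (spatial, spatial, backbone_channels),
        # Global average pooling averages over the two spatial axes.
        "pooled": (backbone_channels,),
        "hidden": (hidden_units,),
        "logits": (num_classes,),
    }


shapes = head_shapes(input_size=32, num_classes=10)  # CIFAR-10 settings
print(shapes["backbone_output"])  # (1, 1, 2048)
print(shapes["logits"])           # (10,)
```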

## Loss definition

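We import the regularizer. Here we use `GroupLassoFlopsRegularizer` from the **flop_regularizer** module, which targets the FLOP cost of the network. Note that MorphNet's documentation recommends `GammaFlopsRegularizer` for networks with Batch Normalization, so it may be worth trying as an alternative for this model.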

A couple of things to take into account are:

*   The *regularization strength* is set to 1e-6. This parameter determines how aggressive MorphNet is. Higher values produce an overall smaller network.
*   We also set the *threshold value*, that is, the value below which MorphNet treats a weight as being 0. Here we use 1e-2.
*   For the output boundary, we use the TensorFlow operation that produces the output of our model.
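
To make the threshold concrete, here is a rough, framework-free sketch of the idea. Each output channel gets a score (for a group lasso style regularizer, a norm of that channel's weights), and channels whose score does not exceed the threshold are counted as dead. The functions `channel_norms` and `count_alive` and the sample weights are hypothetical illustrations, not MorphNet's actual implementation.

```python
import math


def channel_norms(weights_per_channel):
    """L2 norm of each output channel's weights (one list of floats per channel)."""
    return [math.sqrt(sum(w * w for w in ws)) for ws in weights_per_channel]


def count_alive(norms, threshold=1e-2):
    """Count the channels whose norm exceeds the threshold."""
    return sum(1 for norm in norms if norm > threshold)


weights = [
    [0.30, -0.40],   # norm 0.5, alive
    [0.001, 0.002],  # norm of about 0.0022, treated as zero
    [0.0, 0.05],     # norm 0.05, alive
]
print(count_alive(channel_norms(weights)))  # 2
```

Raising the threshold in this sketch marks more channels as dead, which mirrors the role the threshold plays in the bullet above.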

Finally, we define the *MorphNet loss* and the *cost*.


In [None]:
from morph_net.network_regularizers import flop_regularizer
from morph_net.tools import structure_exporter

regularization_strength = 1e-6

network_regularizer = flop_regularizer.GroupLassoFlopsRegularizer(
    output_boundary=[model.output.op],
    threshold=1e-2)

morph_net_loss = network_regularizer.get_regularization_term() * regularization_strength

cost = network_regularizer.get_cost()

We then add the new loss to the model. We define metrics for the cost and the new loss so we can keep track of them.

In [None]:
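# Register the MorphNet term as an extra loss on the model.
model.add_loss(lambda: morph_net_loss)

# Metrics that ignore their inputs and just report the MorphNet tensors.
def cost_metric(y_true, y_pred):
    return cost

def morphloss_metric(y_true, y_pred):
    return morph_net_loss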

## Structure exporter definition

Here we define the *structure exporter*. The **export_structure** function saves the current structure of the network in the following steps:

1.   It creates a StructureExporter object from the *network_regularizer* we defined before.
2.   It then creates a dictionary containing all of the tensors in the exporter and evaluates them.
3.   It populates the exporter with the evaluated values and saves the current state to a file. Here it saves the structure in *morphnet-log/models/learned_structure* with alive_EPOCH.json as the file name.

We then create a callback from our function that is called at the end of every epoch. This callback is created as a lambda callback so it can use a previously defined function.

**NOTE:**

We create the function as a callback so it has access to the tensors during training.





In [None]:
def export_structure(epoch, logs):
    """Save the current alive counts every SAVE_MODEL_EVERY epochs."""
    if epoch % SAVE_MODEL_EVERY == 0:
        exporter = structure_exporter.StructureExporter(
            network_regularizer.op_regularizer_manager)

        # Evaluate every tensor the exporter tracks.
        values = {}
        for key, item in exporter.tensors.items():
            values[key] = tf.keras.backend.eval(item)
        exporter.populate_tensor_values(values)

        # Saved under morphnet-log/models/learned_structure/ as alive_<epoch>.json.
        exporter.create_file_and_save_alive_counts(
            'morphnet-log/models/', str(epoch) + '.json')

# LambdaCallback calls export_structure(epoch, logs) at the end of every epoch.
export_structure_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=export_structure)

callback_list = [export_structure_callback]

## Model training

We define the optimizer we want to use together with the loss of the model. Keras automatically adds the two losses together: the loss passed to compile and the MorphNet loss registered with add_loss.

In [None]:
opt = tf.keras.optimizers.Adam(learning_rate=5e-4)

model.compile(optimizer=opt, loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['categorical_accuracy', cost_metric, morphloss_metric])
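
As a rough picture of what is being minimized, the sketch below adds a task loss to a scaled regularization term. It assumes the total is a plain sum of the two, as described above. The function `total_loss` and the numbers are invented for illustration; the raw FLOP regularization term can be a large number, which is one reason the strength is kept so small.

```python
def total_loss(task_loss, regularization_term, strength=1e-6):
    """Combine the task loss with the scaled MorphNet regularization term."""
    return task_loss + strength * regularization_term


# Hypothetical values: a cross-entropy of 2.0 and a raw regularization term of 5e5.
print(total_loss(2.0, 5e5))
```

With these made-up numbers the regularizer contributes about 0.5 to the total, so both terms influence the gradient. A much larger strength would let the FLOP term dominate the classification loss.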


We can also use TensorBoard to monitor the training process of our network. If you don't want to use it, skip the next two cells. Note that the cell that creates the TensorBoard callback has to be run before the call to the model's fit method.

In [None]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='morphnet-log/', write_graph=False)
callback_list.append(tensorboard_callback)

In [None]:
%load_ext tensorboard
%tensorboard --logdir morphnet-log/

We train the model. Note how the model overfits rapidly but the regularization provided by MorphNet results in a better validation accuracy. 

In [None]:
model.fit(x=x_train, y=y_train, validation_data=(x_test, y_test),
          batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callback_list)

## Model structure

Finally, we can retrieve the structure we want to use. Here we load the structure saved at epoch 100.

In [None]:
import json
with open('morphnet-log/models/learned_structure/alive_100.json') as json_file:
    structure = json.load(json_file)

In [None]:
structure
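
The loaded structure is a plain dictionary. Assuming it maps operation names to the number of output channels still alive, a sketch for summarizing it could look like the following. The helper `summarize_structure` is hypothetical, and the sample entries are made up rather than taken from this training run.

```python
def summarize_structure(structure):
    """Return the total alive count and the sorted names of fully pruned ops."""
    counts = {name: int(alive) for name, alive in structure.items()}
    total = sum(counts.values())
    pruned = sorted(name for name, alive in counts.items() if alive == 0)
    return total, pruned


# Made-up sample in the assumed format: op name -> alive channel count.
sample = {
    "conv1_conv/Conv2D": 64,
    "conv2_block1_1_conv/Conv2D": 41,
    "conv2_block1_2_conv/Conv2D": 0,
}
total, pruned = summarize_structure(sample)
print(total)   # 105
print(pruned)  # ['conv2_block1_2_conv/Conv2D']
```

Calling `summarize_structure(structure)` on the dictionary loaded above gives the same kind of summary for the learned network, provided the file follows the assumed format.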