<a href="https://colab.research.google.com/github/mmanngard/sparse-learning-mnist/blob/main/sparse_learning_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sparse Structural Learning on MNIST
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model and how to enforce sparsity in its weights.


Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/datasets/keras_example"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/datasets/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np


## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [None]:
# Load MNIST as (image, label) tuples for both splits, together with the
# dataset metadata (used below for the shuffle buffer size).
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to convert and normalize the images.
* `tf.data.Dataset.cache`: As the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).

In [None]:
def normalize_img(image, label):
  """Scale a `uint8` image to `float32` values in [0, 1]; the label passes through."""
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# Training pipeline: normalize -> cache -> full-dataset shuffle -> batch -> prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [None]:
# Evaluation pipeline: normalize -> batch -> cache (batches never change) -> prefetch.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Utils

In [None]:
def inspect_weights(model, eps=1E-6):
  """Print a per-layer sparsity report for the kernels of `model`.

  A kernel entry counts as zero when its magnitude is at most `eps`.

  Args:
    model: a Keras model; every entry of `model.layers` is inspected.
    eps: magnitude threshold at or below which a weight is treated as zero.
  """
  for layer in model.layers:
    params = layer.get_weights()
    if not params:  # some layers, like Flatten, have no weights
      print(f"Layer {layer.name} has no weights.")
      continue
    # Assumes the first parameter array is the kernel (true for Dense); the
    # bias, if any, is not part of the report. Indexing instead of unpacking
    # also works for layers without a bias.
    weights = params[0]
    n_weights = weights.size  # total entries, for kernels of any rank
    # This counts the (near-)ZERO entries, i.e. the sparse part of the kernel.
    n_zero = int(np.sum(np.abs(weights) <= eps))
    sparsity = 100.0 * n_zero / n_weights if n_weights else 0.0
    print(f"Layer {layer.name}:")
    print(f"Weights shape: {weights.shape}")
    print(f"# zero weights (|w| <= {eps}): {n_zero}/{n_weights}, {sparsity:.2f} % sparse")

# Custom weighted l1 regularizer
class WeightedL1(tf.keras.regularizers.Regularizer):
    """Weighted L1 penalty ``lambda * sum(|R * W|)`` for a kernel ``W``.

    ``R`` (``reweights``) is multiplied element-wise with the kernel, so it must
    be broadcast-compatible with the kernel it regularizes. In this notebook it
    is produced by `get_reweights` from a kernel of the same shape.
    """

    def __init__(self, reweights, _lambda=0.001):
        # Reweighting matrix, stored as a float32 tensor.
        self.reweights = tf.convert_to_tensor(reweights, dtype=tf.float32)
        # Regularization strength (lambda).
        self._lambda = _lambda

    def __call__(self, weights):
        # Element-wise product between the reweighting matrix and the weights.
        weighted_weights = tf.multiply(self.reweights, weights)

        # L1 norm of the element-wise product, scaled by lambda.
        weighted_l1 = self._lambda * tf.reduce_sum(tf.abs(weighted_weights))
        return weighted_l1

    def get_config(self):
        # Keys must match the __init__ argument names so that the default
        # `from_config` (which calls `cls(**config)`) can rebuild this object.
        # The tensor is converted to a nested list to keep the config
        # JSON-serializable.
        return {'reweights': self.reweights.numpy().tolist(), '_lambda': self._lambda}

def get_reweights(weights, eps=1E-6):
  """Compute reweighted-L1 coefficients from a weight array.

  Entries with |w| > eps get ``1 / w``. The sign is kept, but `WeightedL1`
  takes the absolute value of the product, so it does not affect the penalty.
  Entries with |w| <= eps get the large constant ``1 / eps``, i.e. a heavy
  penalty.

  Args:
    weights: array-like of weights (e.g. a Dense kernel).
    eps: magnitude threshold at or below which a weight is treated as zero.

  Returns:
    A NumPy array with the same shape as `weights`.
  """
  weights = np.asarray(weights)
  # np.where evaluates both branches, so 1 / weights is also computed for the
  # exact zeros produced by truncation; those results are discarded, so
  # silence the expected divide-by-zero warning.
  with np.errstate(divide='ignore'):
    inverse = 1 / weights
  return np.where(np.abs(weights) > eps, inverse, 1 / eps)

def truncate_weights(model, layer_idx, eps=1E-4):
  """Zero out the small kernel entries of one layer, in place.

  Args:
    model: a Keras model whose `model.layers[layer_idx]` holds its kernel as
      the first parameter array (as a Dense layer does).
    layer_idx: index into `model.layers`.
    eps: kernel entries with |w| <= eps are set to 0.

  Returns:
    The same `model` object (modified in place).
  """
  layer = model.layers[layer_idx]
  params = layer.get_weights()
  # Only the kernel is truncated; the remaining arrays (e.g. the bias) are
  # written back unchanged. Indexing instead of unpacking also supports
  # layers built with use_bias=False.
  params[0] = np.where(np.abs(params[0]) <= eps, 0, params[0])
  layer.set_weights(params)

  return model

def evaluate_model(model, ds_test):
  """Evaluate `model` on `ds_test` and return the tuple ``(val_loss, val_accuracy)``.

  `model.evaluate` must yield exactly two values (loss plus one metric);
  otherwise the unpacking below raises ValueError.
  """
  metrics = model.evaluate(ds_test)
  loss, accuracy = metrics
  return (loss, accuracy)

def create_model(model_structure, reweighting):
  """Build and compile a Sequential MLP described by `model_structure`.

  Args:
    model_structure: dict with
      'input_shape'      -- shape passed to the leading Flatten layer,
      'n_layers'         -- number of Dense layers,
      'layer_structure'  -- units of each Dense layer,
      'layer_activation' -- activation of each Dense layer,
      'sparse_layer'     -- per Dense layer, whether to regularize its kernel.
    reweighting: dict with 'reweights' (per Dense layer, a reweighting matrix
      or None) and 'lambda' (regularization strength).

  Returns:
    A compiled `tf.keras.Sequential` model. `model.layers[0]` is the Flatten
    layer, so Dense layer i is `model.layers[i + 1]`.
  """
  model = tf.keras.models.Sequential()

  # Input layer: flatten each input into a vector.
  model.add(tf.keras.layers.Flatten(input_shape=model_structure['input_shape']))

  last_activation = None
  for i in range(model_structure['n_layers']):
    n_nodes = model_structure['layer_structure'][i]
    activation = model_structure['layer_activation'][i]
    last_activation = activation

    regularizer = None  # Dense's default: no kernel regularization
    if model_structure['sparse_layer'][i]:
      reweights = reweighting['reweights'][i]
      _lambda = reweighting['lambda']
      if reweights is None:
        # No reweighting matrix computed yet: use a plain (uniform) L1
        # penalty of the same strength instead of failing inside WeightedL1.
        regularizer = tf.keras.regularizers.L1(_lambda)
      else:
        regularizer = WeightedL1(reweights, _lambda=_lambda)
    model.add(tf.keras.layers.Dense(n_nodes, activation=activation, kernel_regularizer=regularizer))

  # A softmax output layer already produces probabilities. Declaring them as
  # logits would make the loss apply softmax a second time.
  from_logits = last_activation not in ('softmax', tf.keras.activations.softmax)

  model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

  return model

def structure_learning(model, model_structure, solver_opts, reweighting,
                       train_data=None, val_data=None):
  """Run one reweighted-L1 step: recompute reweights, rebuild, and retrain.

  For every layer flagged in `model_structure['sparse_layer']`, the current
  kernel of `model` is converted into a reweighting matrix with
  `get_reweights` and stored in `reweighting['reweights']`. A new model with
  those regularizers is then built once and trained.

  Args:
    model: trained model whose kernels seed the reweighting. `model.layers[0]`
      is assumed to be the Flatten layer, so Dense layer i is `model.layers[i + 1]`.
    model_structure: architecture description (see `create_model`).
    solver_opts: dict with 'epochs', the number of training epochs.
    reweighting: dict with 'reweights' and 'lambda'; updated in place.
    train_data: training dataset; defaults to the global `ds_train`.
    val_data: validation dataset; defaults to the global `ds_test`.

  Returns:
    (model_new, reweighting): the newly trained model and the updated dict.
  """
  train_data = ds_train if train_data is None else train_data
  val_data = ds_test if val_data is None else val_data

  # Update the reweighting matrix of EVERY sparse layer before rebuilding.
  for i in range(model_structure['n_layers']):
    if model_structure['sparse_layer'][i]:
      kernel = model.layers[i + 1].get_weights()[0]
      reweighting['reweights'][i] = get_reweights(kernel, eps=1E-6)

  # Build and train once. The new model starts from a fresh initialization.
  # To warm-start instead, call model_new.set_weights(model.get_weights())
  # before fitting.
  model_new = create_model(model_structure, reweighting)
  model_new.fit(
    train_data,
    epochs=solver_opts['epochs'],
    validation_data=val_data,
  )

  return model_new, reweighting

## Step 2: Create and train the initial model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [None]:
# Baseline dense model. Its output layer applies softmax, so the loss receives
# probabilities (from_logits=False) rather than logits; declaring them as
# logits would apply softmax twice.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)


In [None]:
# Sparsity report for the baseline model, with a 1e-4 zero threshold.
inspect_weights(model=model, eps=1E-4)
# model.weights[2] is the kernel of the second Dense layer (Flatten has no
# weights, and each Dense contributes kernel then bias).
print(model.weights[2].shape)

# Retrain the model with sparse optimization


In [None]:
# Architecture of the retrained models: three Dense layers, of which only the
# second one (index 1) is regularized towards sparsity.
model_structure = {
    'n_layers': 3,
    'input_shape': (28,28),
    'layer_structure': [128, 128, 10],
    'sparse_layer': [False, True, False],
    'layer_activation': ['relu', 'relu', 'softmax'],
}

# Reweighting matrices (filled in by structure_learning) and L1 strength.
reweighting = {
    'reweights': [None, None, None],
    'lambda': 1/(128 * 128)
}

solver_opts = {
    'epochs': 6
}

ITERATIONS = 4  # number of reweight -> retrain -> truncate rounds
eps = 1E-4      # kernel entries with |w| <= eps are zeroed after each round

for it in range(ITERATIONS):
  print(f'############ ITERATION {it} ############')
  # Recompute reweights from the current model and train a new one.
  model, reweighting = structure_learning(model, model_structure, solver_opts, reweighting)

  # Set small weights to zero in sparse layers (layers[0] is Flatten, hence i + 1).
  for i in range(model_structure['n_layers']):
    if model_structure['sparse_layer'][i]:
      model = truncate_weights(model, i + 1, eps=eps)

  # Re-calculate validation metrics after truncation.
  val_loss, val_accuracy = evaluate_model(model, ds_test)

  print(f"Validation accuracy: {val_accuracy}")
  print(f"Validation loss: {val_loss}")
  inspect_weights(model, eps=eps)

