<a href="https://colab.research.google.com/github/mmanngard/sparse-learning-mnist/blob/main/sparse_learning_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sparse Structural Learning on MNIST
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model and how to enforce sparsity.


Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/datasets/keras_example"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/datasets/docs/keras_example.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt


## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [None]:
# Load both MNIST splits as (image, label) tuples. `with_info=True` also
# returns the DatasetInfo, whose split sizes are used for the shuffle buffer.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize images. Here `normalize_img` additionally binarizes every pixel to 0 or 1 using a threshold.
* `tf.data.Dataset.cache`: Because the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).

In [None]:
def normalize_img(image, label, threshold=0.5):
  """Scale `uint8` pixels to [0, 1] and binarize them at `threshold`.

  Returns `(image, label)`, where every image pixel is a `float32` 1.0 if
  its scaled value is above `threshold` and 0.0 otherwise. The label is
  passed through unchanged.
  """
  scaled = tf.cast(image, tf.float32) / 255.0
  # Casting the boolean mask gives exactly 1.0 / 0.0 in float32.
  return tf.cast(scaled > threshold, tf.float32), label

# Normalize -> cache (in memory) -> full-size shuffle -> batch -> prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [None]:
# Normalize -> batch -> cache the batches -> prefetch (no shuffling needed).
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Training utils

In [None]:
def sparse_structure_learning(model, _lambda, eps, epochs):
  """Run one round of reweighted-l1 training, warm-started from `model`.

  The reweights come from the kernel of `model.layers[2]` (the second Dense
  layer; index 0 is Flatten). Each entry is `1/|w|` where `|w| > eps` and
  `1E8` elsewhere, so near-zero weights are penalized very heavily.

  A new model with the architecture in the global `model_structure` is built.
  It gets a `WeightedL1` kernel regularizer on that same layer, is initialised
  with `model`'s weights, and is trained on the global `ds_train` for `epochs`
  epochs, validating on `ds_test`.

  Args:
    model: trained Keras model. It is assumed to be built from the same
      `model_structure`, so that its weight list matches the new model.
    _lambda: regularization strength passed to `WeightedL1`.
    eps: magnitudes at or below this get the fixed reweight `1E8`.
    epochs: number of training epochs for the new model.

  Returns:
    `(new_model, reweights)`: the trained model and the reweighting matrix
    that was applied to its second Dense kernel.
  """
  # np.divide with `where=` only divides where |w| > eps. The equivalent
  # np.where(mask, 1/|w|, 1E8) evaluates 1/|w| on every element and emits a
  # divide-by-zero RuntimeWarning for weights that were pruned to exactly 0.
  weights, _ = model.layers[2].get_weights()
  abs_weights = np.abs(weights)
  reweights = np.full_like(abs_weights, 1E8)
  np.divide(1.0, abs_weights, out=reweights, where=abs_weights > eps)

  regularizer_l1_weighted = WeightedL1(reweights=reweights, _lambda=_lambda)

  # Same architecture, but the weighted-l1 penalty sits on the kernel of the
  # second Dense layer (layers[2]), i.e. the layer the reweights came from.
  new_model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=model_structure['input_shape']),
    tf.keras.layers.Dense(model_structure['layer_structure'][0], activation=model_structure['layer_activation'][0]),
    tf.keras.layers.Dense(model_structure['layer_structure'][1], activation=model_structure['layer_activation'][1], kernel_regularizer=regularizer_l1_weighted),
    tf.keras.layers.Dense(model_structure['layer_structure'][2], activation=model_structure['layer_activation'][2])
  ])
  new_model.compile(
      optimizer=tf.keras.optimizers.Adam(0.001),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

  # Warm-start from the given model's weights.
  new_model.set_weights(model.get_weights())

  new_model.fit(
      ds_train,
      epochs=epochs,
      validation_data=ds_test,
  )

  return new_model, reweights



# Custom weighted l1 regularizer
class WeightedL1(tf.keras.regularizers.Regularizer):
    """Elementwise-reweighted L1 penalty: `_lambda * sum(|reweights * W|)`.

    Args:
      reweights: array of per-weight multipliers, converted to a float32
        tensor. It must broadcast against the regularized kernel; in this
        notebook it is built with the kernel's shape.
      _lambda: scalar regularization strength.
    """

    def __init__(self, reweights, _lambda):
        self.reweights = tf.convert_to_tensor(reweights, dtype=tf.float32)  # Reweighting matrix
        self._lambda = _lambda  # Regularization strength (lambda)

    def __call__(self, weights):
        # Element-wise product between the reweighting matrix and the weights
        weighted_weights = self.reweights * weights

        # Compute the L1 norm of the element-wise product and multiply by lambda
        weighted_l1 = self._lambda * tf.reduce_sum(tf.abs(weighted_weights))
        return weighted_l1

    def get_config(self):
        """Return the constructor arguments so `from_config` can rebuild this.

        Keys must match `__init__` parameter names, because the base
        `Regularizer.from_config` calls `cls(**config)`. The tensor is
        exported as a nested list so the config stays JSON-serializable.
        """
        return {'reweights': self.reweights.numpy().tolist(), '_lambda': self._lambda}



def inspect_weights(model, eps):
  '''Print, per layer, how many kernel entries have magnitude below `eps`.

  For each layer with weights, the first array from `layer.get_weights()`
  is inspected (the kernel, for Dense layers). Any bias is ignored. Layers
  without weights (e.g. Flatten) are reported as such.

  Args:
    model: object with a `layers` attribute (e.g. a Keras model).
    eps: magnitudes strictly below this are counted as zero.
  '''
  for layer in model.layers:
      layer_weights = layer.get_weights()
      if not layer_weights:  # some layers, like Flatten, have no weights
          print(f"Layer {layer.name} has no weights.")
          continue
      # Index instead of unpacking `weights, biases`, so that layers with a
      # single weight array (e.g. use_bias=False) don't raise.
      weights = layer_weights[0]
      n_weights = weights.size  # valid for kernels of any rank, not only 2-D
      n_zero = np.sum(np.abs(weights) < eps)
      sparsity = n_zero / n_weights * 100 if n_weights else 0.0
      print(f"Layer {layer.name}:")
      print(f"Weights shape: {weights.shape}")
      print(f"# zero weights/total: {n_zero}/{n_weights}, {sparsity :.2f} % sparse")


## Step 2: Create and train the initial model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [None]:
model_structure = {
    'n_layers': 3,
    'input_shape': (28,28),
    'layer_structure': [128, 128, 10],
    'sparse_layer': [False, True, False],
    'layer_activation': ['relu', 'relu', 'softmax'],
}

reweighting = {
    'reweights': [None, None, None],
    'lambda': 1/1000,
    'eps': 1E-4
}

solver_opts = {
    'epochs': 6
}

# Flatten input followed by one Dense layer per (units, activation) pair in
# model_structure, created in order. This initial model has no regularization.
model_ini = tf.keras.models.Sequential(
    [tf.keras.layers.Flatten(input_shape=model_structure['input_shape'])]
    + [
        tf.keras.layers.Dense(units, activation=activation)
        for units, activation in zip(
            model_structure['layer_structure'],
            model_structure['layer_activation'],
        )
    ]
)

# The last layer applies softmax, so the loss receives probabilities.
model_ini.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model_ini.fit(ds_train, epochs=solver_opts['epochs'], validation_data=ds_test)


## Step 3: Apply reweighted $\ell_1$ regularization

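Structure in an ANN can be revealed by structural sparsity learning techniques. Here the kernel $W$ of the second hidden layer is penalized by the reweighted term $\lambda \sum_{ij} r_{ij} |W_{ij}|$, where $r_{ij} = 1/|\tilde W_{ij}|$ is computed from the starting weights $\tilde W$ of each refinement round (and $r_{ij} = 10^8$ when $|\tilde W_{ij}| \le \epsilon$).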

In [None]:
# ####################
ITERATIONS = 12
_lambda = 0.0001
eps = 1E-3
epochs = 1
# ####################

# Plot the sorted magnitudes of the second Dense kernel (layers[2]): once for
# the initial model and once per iteration, to show the near-zero region grow.
fig, ax = plt.subplots(1, 1, figsize=(6,6))
ax.set_xlabel('weight index (sorted by magnitude)')
ax.set_ylabel('|w|')
weights, biases = model_ini.layers[2].get_weights()
ax.plot( np.sort( np.abs(weights.reshape(-1)) ) )
print('---- iteration 0 ----')
inspect_weights(model_ini, eps=eps)
val_loss, val_accuracy = model_ini.evaluate(ds_test)
print(f'val loss: {val_loss}')
print(f'val accuracy: {val_accuracy}')
print(' ')

# Each round must start from the previous round's pruned model, so that the
# reweights 1/|w| are recomputed from the latest weights. That feedback is
# what makes the l1 penalty *iteratively* reweighted. Restarting every round
# from model_ini would recompute the same reweights each time.
# This also keeps `model` defined for later cells when ITERATIONS == 0.
model = model_ini
for i in range(ITERATIONS):
  model, reweights = sparse_structure_learning(model, _lambda=_lambda, eps=eps, epochs=epochs)

  # Hard-threshold: weights that the penalty drove below eps become exactly 0.
  weights, biases = model.layers[2].get_weights()
  weights[np.abs(weights) < eps] = 0
  model.layers[2].set_weights([weights, biases])

  ax.plot( np.sort( np.abs(weights.reshape(-1)) ) )

  print(f'---- iteration {i+1} ----')
  inspect_weights(model, eps=eps)
  val_loss, val_accuracy = model.evaluate(ds_test)
  print(f'val loss: {val_loss}')
  print(f'val accuracy: {val_accuracy}')

  print(' ')


# Test on my own handwritten numbers

In [None]:
from tensorflow.keras.preprocessing import image

# Load and preprocess your 28x28 image
def preprocess_image(img_path, threshold=0.5):
    # Load the image in grayscale mode (if it's not already in 28x28 size)
    img = image.load_img(img_path, target_size=(28, 28), color_mode='grayscale')
    img = image.img_to_array(img)
    img = tf.cast(img, tf.float32) / 255.0
    img = tf.where(img > threshold, 1.0, 0.0)
    img = tf.reshape(img, (1, 28, 28, 1))
    return img


img_pths = [
    'my_zero.png', 'my_one.png', 'my_two.png', 'my_three.png', 'my_four.png',
    'my_five.png', 'my_six.png', 'my_seven.png', 'my_eight.png', 'my_nine.png',
]

# Classify each hand-drawn digit and show it in its own figure, titled with
# the model's predicted class.
for i, img_pth in enumerate(img_pths):
    processed_image = preprocess_image(img_pth)
    prediction = model.predict(processed_image)
    # One input row -> one row of class scores; take the most likely digit.
    predicted_class = np.argmax(prediction, axis=1)[0]

    plt.imshow(tf.squeeze(processed_image).numpy(), cmap='gray')
    plt.title(f'Prediction: {predicted_class}')
    plt.tight_layout()
    plt.show()



In [None]:
# ds_test is already batched (batch(128) in the test pipeline), so shuffling
# it reorders whole batches. The first element is then one random batch.
ds_test_shuffled = ds_test.shuffle(10000)
random_batch = next(iter(ds_test_shuffled))
images, labels = random_batch

# Pick one example from that batch. images.shape[0] is used instead of 128
# because the final batch can be smaller.
random_index = np.random.randint(0, images.shape[0])
random_image, random_label = images[random_index], labels[random_index]

# Reshape into a batch of one before calling the model.
processed_image = tf.reshape(random_image, (1, 28, 28, 1))
prediction = model.predict(processed_image)
predicted_class = np.argmax(prediction, axis=1)[0]

# Drop the singleton dimensions for display.
plt.imshow(tf.squeeze(random_image).numpy(), cmap='gray')
plt.title(f"True Label: {random_label.numpy()}, Predicted Class: {predicted_class}")
plt.show()
