In [1]:
import os
import zipfile
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

In [2]:
# Toy dataset: a single sample with 20 features and a one-hot label over 20 classes.
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
# Class labels must be integer indices in [0, num_classes). Drawing them from a
# continuous normal distribution would truncate most draws to class 0 and
# produce negative values, which are not valid class indices.
y_train = tf.keras.utils.to_categorical(
    np.random.randint(0, 20, size = 1), num_classes = 20
)

def setup_model():
  """Build a new, uncompiled Sequential model.

  The model is a 20-unit Dense layer that takes inputs of shape `input_shape`,
  followed by a Flatten layer.

  Returns:
    The freshly constructed `tf.keras.Sequential` instance.
  """
  layers = [
      tf.keras.layers.Dense(20, input_shape = input_shape),
      tf.keras.layers.Flatten(),
  ]
  return tf.keras.Sequential(layers)

def setup_pretrained_weights():
  """Fit a fresh model on the toy data and save its weights to a temp path.

  The model comes from `setup_model()`, is compiled with categorical
  cross-entropy and the Adam optimizer, and is fitted on
  `(x_train, y_train)` using the default `fit` settings.

  Returns:
    The temporary path (suffix '.tf') that was passed to `save_weights`.
    The caller owns this path; it is not cleaned up here because it is
    needed later to reload the weights.
  """
  model = setup_model()

  model.compile(
      loss = tf.keras.losses.categorical_crossentropy,
      optimizer = 'adam',
      metrics = ['accuracy']
  )

  model.fit(x_train, y_train)

  # mkstemp returns an already-open OS-level file descriptor together with the
  # path. Only the path is needed, so close the descriptor immediately rather
  # than leaking it for the lifetime of the kernel.
  fd, pretrained_weights = tempfile.mkstemp('.tf')
  os.close(fd)

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  """Return the size, in bytes, of the model after saving and compressing it.

  The model is saved (without optimizer state) to a temporary '.h5' file,
  which is then stored in a temporary zip archive using DEFLATE compression.
  The size of that archive is the result.

  Args:
    model: Object exposing `save(path, include_optimizer = ...)`.

  Returns:
    Size of the compressed archive in bytes.
  """
  temp_paths = []
  try:
    for suffix in ('.h5', '.zip'):
      # mkstemp hands back an open descriptor; close it right away so it is
      # not leaked and so other writers can open the path on every platform.
      fd, path = tempfile.mkstemp(suffix)
      os.close(fd)
      temp_paths.append(path)
    keras_file, zipped_file = temp_paths

    model.save(keras_file, include_optimizer = False)

    with zipfile.ZipFile(zipped_file, 'w', compression = zipfile.ZIP_DEFLATED) as f:
      f.write(keras_file)

    return os.path.getsize(zipped_file)
  finally:
    # Both files are scratch space for the measurement only; remove them so
    # repeated calls do not accumulate files in the temp directory.
    for path in temp_paths:
      if os.path.exists(path):
        os.remove(path)

# NOTE(review): the result of this first call is discarded. It still builds a
# model, which presumably advances Keras' auto-generated model/layer name
# counters (the recorded summary output shows "sequential_2" / "dense_2").
# Confirm before removing it.
setup_model()
# Train the toy model once and keep the path of the saved weights for later cells.
pretrained_weights = setup_pretrained_weights()



### Prune whole model (Sequential and Functional)

Tips for better model accuracy:
- Try "Prune some layers" to skip pruning the layers that reduce accuracy the most.
- It's generally better to fine-tune a pretrained model with pruning than to train with pruning from scratch.

To make the whole model train with pruning, apply `tfmot.sparsity.keras.prune_low_magnitude` to the model.

In [3]:
# Build a fresh model and restore the weights saved by setup_pretrained_weights().
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended: fine-tune from pretrained weights instead of starting from scratch.

# Apply pruning to the whole model at once.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Print the layer-by-layer summary of the pruning-enabled model.
model_for_pruning.summary()

Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________
