In [24]:
from __future__ import print_function

from absl import app as absl_app
import time
import tensorflow as tf
import numpy as np

from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
import tensorflow_model_optimization as tfmot

ConstantSparsity = pruning_schedule.ConstantSparsity
keras = tf.keras
l = keras.layers


batch_size = 128
num_classes = 10
epochs = 1


def build_sequential_model(input_shape):
    return tf.keras.Sequential([
      l.Conv2D(
          32, 5, padding='same', activation='relu', input_shape=input_shape),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      l.Conv2D(64, 5, padding='same', activation='relu'),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      l.Dense(1024, activation='relu'),
      l.Dropout(0.4),
      l.Dense(num_classes, activation='softmax')
  ])

def build_functional_model(input_shape):
    inp = tf.keras.Input(shape=input_shape)
    x = l.Conv2D(32, 5, padding='same', activation='relu')(inp)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    x = l.BatchNormalization()(x)
    x = l.Conv2D(64, 5, padding='same', activation='relu')(x)
    x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x)
    x = l.Flatten()(x)
    x = l.Dense(1024, activation='relu')(x)
    x = l.Dropout(0.4)(x)
    out = l.Dense(num_classes, activation='softmax')(x)

    return tf.keras.models.Model([inp], [out])


def build_layerwise_model_pruned(input_shape, **pruning_params):
    return tf.keras.Sequential([
      prune.prune_low_magnitude(
          l.Conv2D(32, 5, padding='same', activation='relu'),
          input_shape=input_shape,
          **pruning_params),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.BatchNormalization(),
      prune.prune_low_magnitude(
          l.Conv2D(64, 5, padding='same', activation='relu'), **pruning_params),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      prune.prune_low_magnitude(
          l.Dense(1024, activation='relu'), **pruning_params),
      l.Dropout(0.4),
      prune.prune_low_magnitude(
          l.Dense(num_classes, activation='softmax'), **pruning_params)
  ])


def train_and_save(models, x_train, y_train, x_test, y_test,accuracys,speeds):
    for model in models:
        model.compile(
            loss=tf.keras.losses.categorical_crossentropy,
            optimizer='adam',
            metrics=['accuracy'])

    # Print the model summary.
        model.summary()

    # Add a pruning step callback to peg the pruning step to the optimizer's
    # step. Also add a callback to add pruning summaries to tensorboard
        callbacks = [
        pruning_callbacks.UpdatePruningStep(),
        pruning_callbacks.PruningSummaries(log_dir='/tmp/mnist_train/')
    ]

        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            verbose=1,
            callbacks=callbacks,
            validation_data=(x_test, y_test))
        score = model.evaluate(x_test, y_test, verbose=0)
        
        accuracys.append(score[0])
        
        start=time.time()
        model.predict(x_test)
        stop=time.time()
        speeds.append(stop-start)

    return accuracys,speeds


def main(unused_argv):
    accuracys=[]
    speeds=[]
  # input image dimensions
    img_rows, img_cols = 28, 28
    
  # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    num_images = x_train.shape[0] * (1 - 1/7)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
    
    if tf.keras.backend.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

  # convert class vectors to binary class matrices
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
#PARAMETRES DU PRUNING ???
#     pruning_params = {
#       'pruning_schedule': ConstantSparsity(0.80, begin_step=2000, frequency=100)
#   }
    pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
    }
    
    layerwise_model_pruned = build_layerwise_model(input_shape, **pruning_params)
    sequential_model = build_sequential_model(input_shape)
    sequential_model_pruned = prune.prune_low_magnitude(
      sequential_model, **pruning_params)
    functional_model = build_functional_model(input_shape)
    functional_model_pruned = prune.prune_low_magnitude(
      functional_model, **pruning_params)

    models = [layerwise_model_pruned,sequential_model_pruned,functional_model_pruned,sequential_model, functional_model]
    accuracys,speeds=train_and_save(models, x_train, y_train, x_test, y_test,accuracys,speeds)
    return models,accuracys,speeds

models,accuracys,speeds=main('test')

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Model: "sequential_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_conv2d_4 (None, 28, 28, 32)        1634      
_________________________________________________________________
max_pooling2d_42 (MaxPooling (None, 14, 14, 32)        0         
_________________________________________________________________
batch_normalization_21 (Batc (None, 14, 14, 32)        128       
_________________________________________________________________
prune_low_magnitude_conv2d_4 (None, 14, 14, 64)        102466    
_________________________________________________________________
max_pooling2d_43 (MaxPooling (None, 7, 7, 64)          0         
_________________________________________________________________
flatten_21 (Flatten)         (None, 3136)              0         
______________________________________________



In [25]:
models

[<tensorflow.python.keras.engine.sequential.Sequential at 0x7f7b28429828>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f7af0433080>,
 <tensorflow.python.keras.engine.functional.Functional at 0x7f7af011e780>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f7af045b898>,
 <tensorflow.python.keras.engine.functional.Functional at 0x7f7af00ff780>]

In [26]:
accuracys

[0.1323108822107315,
 0.1447369009256363,
 0.17748628556728363,
 0.032473865896463394,
 0.03346732631325722]

In [27]:
speeds

[1.661036491394043,
 1.6880574226379395,
 1.6942851543426514,
 1.369884967803955,
 1.3599293231964111]