In [1]:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

In [2]:
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

In [3]:
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """
    "Frozen state" and "inference mode" are two separate concepts.
    `layer.trainable = False` is to freeze the layer, so the layer will use
    stored moving `var` and `mean` in the "inference mode", and both `gama`
    and `beta` will not be updated !
    """
    def call(self, x, training=False):
        if not training:
            training = tf.constant(False)
        training = tf.logical_and(training, self.trainable)
        return super().call(x, training)

In [4]:
input_layer = tf.keras.layers.Input([28, 28])
net = tf.keras.layers.Reshape(target_shape=(28, 28, 1))(input_layer)

net = tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=None)(net)
net = BatchNormalization()(net)
net = tf.keras.layers.ReLU()(net)
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(net)

net = tf.keras.layers.Conv2D(filters=24, kernel_size=(3, 3), activation=None)(net)
net = BatchNormalization()(net)
net = tf.keras.layers.ReLU()(net)
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(net)

net = tf.keras.layers.Flatten()(net)
logits = tf.keras.layers.Dense(10)(net)

model = tf.keras.Model(input_layer, logits)

In [5]:
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [6]:
model.fit(
    train_images,
    train_labels,
    epochs=1,
    validation_split=0.1,
)

Train on 54000 samples, validate on 6000 samples


<tensorflow.python.keras.callbacks.History at 0x1db91aa4f08>

In [7]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

Baseline test accuracy: 0.9738


In [8]:
tf.keras.models.save_model(model, '../Prune_Test/original.h5', include_optimizer=False)

剪枝

In [11]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

In [12]:
batch_size = 128
epochs = 1
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

In [13]:
def apply_pruning_to_dense_conv(layer):
    if isinstance(layer, tf.keras.layers.Dense) or isinstance(layer, tf.keras.layers.Conv2D):
        return tfmot.sparsity.keras.prune_low_magnitude(layer)
    return layer

def apply_pruning_to_conv(layer):
    if isinstance(layer, tf.keras.layers.Conv2D):
        print('find it')
        return tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
    return layer

In [14]:
model_for_pruning = tf.keras.models.clone_model(
    model,
    clone_function=apply_pruning_to_conv,
)

find it
find it
Instructions for updating:
Please use `layer.add_weight` method instead.


In [15]:
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
                          metrics=['accuracy'])

In [16]:
model_for_pruning.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
batch_normalization (BatchNo (None, 26, 26, 12)        48        
_________________________________________________________________
re_lu (ReLU)                 (None, 26, 26, 12)        0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 12)        0         
_________________________________________________________________
prune_low_magnitude_conv2d_1 (None, 11, 11, 24)        5210  

In [17]:
logdir = os.path.join('D:\\coursera\\YoLoSerirs\\Prune_Test\\log')
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
logdir

'D:\\coursera\\YoLoSerirs\\Prune_Test\\log'

In [18]:
model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)

Train on 54000 samples, validate on 6000 samples


<tensorflow.python.keras.callbacks.History at 0x1db94664b88>

In [19]:
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.9738
Pruned test accuracy: 0.8513


In [20]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

tf.keras.models.save_model(model_for_export, '../Prune_Test/prune.h5', include_optimizer=False)

In [21]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

pruned_tflite_file = '../Prune_Test/prune_tfl.tflite'

with open(pruned_tflite_file, 'wb') as f:
    f.write(pruned_tflite_model)

In [22]:
def get_gzipped_model_size(file):
    # Returns size of gzipped model, in bytes.
    import zipfile
    
    name = file.split('/')[-1].split('.')[0]
    zipped_file = os.path.join('../Prune_Test/', name+'.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)

    return os.path.getsize(zipped_file)

In [23]:
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/original.h5')))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/prune.h5')))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/prune_tfl.tflite')))

Size of gzipped baseline Keras model: 36356.00 bytes
Size of gzipped pruned Keras model: 29738.00 bytes
Size of gzipped pruned TFlite model: 27178.00 bytes
