In [1]:
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow_model_optimization as tfmot

In [2]:
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

In [3]:
input_layer = tf.keras.layers.Input([28, 28])
net = tf.keras.layers.Reshape(target_shape=(28, 28, 1))(input_layer)
net = tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=None)(net)
net = tf.keras.layers.ReLU()(net)
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(net)
net = tf.keras.layers.Flatten()(net)
logits = tf.keras.layers.Dense(10)(net)

model = tf.keras.Model(input_layer, logits)

In [4]:
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [5]:
model.fit(
    train_images,
    train_labels,
    epochs=2,
    validation_split=0.1,
)

Train on 54000 samples, validate on 6000 samples
Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x19a55f5da08>

测试nni

In [7]:
from nni.compression.tensorflow import FPGMPruner
from nni.compression.tensorflow import LevelPruner

In [9]:
def show_prun(model):
    for i, w in enumerate(model.get_weights()):
        print("{} -- Total:{}, Zeros: {:.2f}%".format(model.weights[i].name, w.size, np.sum(w == 0) / w.size * 100))

In [12]:
config_list = [{ 'sparsity': 0.8, 'op_types': ['default'] }]
pruner = LevelPruner(model, config_list)
model_prun = pruner.compress()

In [14]:
show_prun(model_prun)

conv2d/kernel:0 -- Total:108, Zeros: 0.00%
conv2d/bias:0 -- Total:12, Zeros: 0.00%
dense/kernel:0 -- Total:20280, Zeros: 0.00%
dense/bias:0 -- Total:10, Zeros: 0.00%


--------------------------

In [6]:
_, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

Baseline test accuracy: 0.9751


In [7]:
tf.keras.models.save_model(model, '../Prune_Test/original.h5', include_optimizer=False)

剪枝

In [8]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

In [9]:
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}

In [10]:
model_for_pruning = prune_low_magnitude(model, **pruning_params)

Instructions for updating:
Please use `layer.add_weight` method instead.


In [11]:
model_for_pruning.compile(optimizer='adam',
                          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
                          metrics=['accuracy'])

In [12]:
model_for_pruning.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
prune_low_magnitude_reshape  (None, 28, 28, 1)         1         
_________________________________________________________________
prune_low_magnitude_conv2d ( (None, 26, 26, 12)        230       
_________________________________________________________________
prune_low_magnitude_re_lu (P (None, 26, 26, 12)        1         
_________________________________________________________________
prune_low_magnitude_max_pool (None, 13, 13, 12)        1         
_________________________________________________________________
prune_low_magnitude_flatten  (None, 2028)              1         
_________________________________________________________________
prune_low_magnitude_dense (P (None, 10)                40572 

In [13]:
logdir = os.path.join('D:\\coursera\\YoLoSerirs\\Prune_Test\\log')
callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]
logdir

'D:\\coursera\\YoLoSerirs\\Prune_Test\\log'

In [14]:
model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)

Train on 54000 samples, validate on 6000 samples
Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x163a004a2c8>

In [15]:
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy) 
print('Pruned test accuracy:', model_for_pruning_accuracy)

Baseline test accuracy: 0.9751
Pruned test accuracy: 0.9686


In [16]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

tf.keras.models.save_model(model_for_export, '../Prune_Test/prune.h5', include_optimizer=False)

In [17]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

pruned_tflite_file = '../Prune_Test/prune_tfl.tflite'

with open(pruned_tflite_file, 'wb') as f:
    f.write(pruned_tflite_model)

In [18]:
def get_gzipped_model_size(file):
    # Returns size of gzipped model, in bytes.
    import zipfile
    
    name = file.split('/')[-1].split('.')[0]
    zipped_file = os.path.join('../Prune_Test/', name+'.zip')
    with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(file)

    return os.path.getsize(zipped_file)

In [19]:
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/original.h5')))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/prune.h5')))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size('../Prune_Test/prune_tfl.tflite')))

Size of gzipped baseline Keras model: 78333.00 bytes
Size of gzipped pruned Keras model: 25895.00 bytes
Size of gzipped pruned TFlite model: 24566.00 bytes
