# Pruning

## Pruning for on-device inference with XNNPACK

Welcome to the guide on Keras weight pruning for improving the latency of on-device inference via [XNNPACK](https://github.com/google/XNNPACK).


This guide shows how to use the newly introduced `tfmot.sparsity.keras.PruningPolicy` API and demonstrates how it can accelerate mostly convolutional models on modern CPUs with [XNNPACK sparse inference](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md#sparse-inference).


The guide covers the following steps of the model creation process:

* Build and train the dense baseline
* Fine-tune the model with pruning
* Convert to TFLite
* Run an on-device benchmark

This guide doesn't cover best practices for fine-tuning with pruning. For more detail on that topic, please check out our [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md).


### Setup

```python
# pip install -q tensorflow
# pip install -q tensorflow-model-optimization
# pip install -q tensorflow-datasets
```

```python
import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

%load_ext tensorboard
```

### Build and train the dense model

We build and train a simple baseline CNN for image classification on the [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.


```python
# Load the CIFAR10 dataset: the first 90% of the training split for training,
# the last 10% for validation, and the test split for testing.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    "cifar10",
    split=["train[:90%]", "train[90%:]", "test"],
    as_supervised=True,
    with_info=True,
)


# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(img, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.image.convert_image_dtype(img, tf.float32), label
```
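
For `uint8` inputs, `tf.image.convert_image_dtype` rescales pixel values from the integer range [0, 255] to floating-point values in [0, 1]. The NumPy sketch below reproduces that mapping for illustration only; the helper name `normalize_uint8` is hypothetical, and its results may differ from TensorFlow's in the last bit because of floating-point rounding.

```python
import numpy as np


def normalize_uint8(img: np.ndarray) -> np.ndarray:
    """Map uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    if img.dtype != np.uint8:
        raise TypeError(f"expected uint8 input, got {img.dtype}")
    # Divide by the largest uint8 value so that 0 -> 0.0 and 255 -> 1.0.
    return img.astype(np.float32) / np.float32(255.0)


# A toy 2x2 single-channel "image".
pixels = np.array([[0, 64], [128, 255]], dtype=np.uint8)
scaled = normalize_uint8(pixels)
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0
```

Keeping the inputs in [0, 1] rather than [0, 255] keeps the first convolution's activations on a scale that the default weight initializers expect.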



```python
# Load the data in batches of 128 images.
batch_size = 128


def prepare_dataset(ds, buffer_size=None):
    # Apply normalize_img to every (image, label) element in parallel.
    ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache the preprocessed data in memory so later epochs skip the map step.
    ds = ds.cache()
    if buffer_size:
        # Shuffle after caching so each epoch sees a different order.
        ds = ds.shuffle(buffer_size)
    # Group elements into batches of batch_size (the last batch may be smaller).
    ds = ds.batch(batch_size)
    # Prepare upcoming batches asynchronously while the model trains on the current one.
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


ds_train = prepare_dataset(ds_train, ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)
```
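
The split strings passed to `tfds.load` determine how many batches each pipeline yields. CIFAR10 has 50,000 training and 10,000 test images, so `train[:90%]` holds 45,000 examples and `train[90%:]` holds 5,000. Because `ds.batch` keeps the final partial batch by default, the number of batches is the ceiling of examples divided by `batch_size`. A small sketch of that arithmetic (the helper `num_batches` is hypothetical):

```python
import math

CIFAR10_TRAIN_EXAMPLES = 50_000
CIFAR10_TEST_EXAMPLES = 10_000
BATCH_SIZE = 128


def num_batches(num_examples: int, batch_size: int = BATCH_SIZE) -> int:
    """Batches produced by ds.batch(batch_size) when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


# Integer arithmetic mirrors the 'train[:90%]' / 'train[90%:]' split.
train_examples = CIFAR10_TRAIN_EXAMPLES * 90 // 100      # 45,000
val_examples = CIFAR10_TRAIN_EXAMPLES - train_examples   # 5,000

for name, n in [("train", train_examples),
                ("val", val_examples),
                ("test", CIFAR10_TEST_EXAMPLES)]:
    print(f"{name}: {n} examples -> {num_batches(n)} batches")
```

The 352 training batches correspond to the `352/352` steps per epoch in the training log.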

```python
# Build the dense baseline model.
dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),

    # Zero-pad 1 pixel on each side (32x32 -> 34x34) so the 3x3 'valid'
    # convolution below keeps the information at the image borders.
    keras.layers.ZeroPadding2D(padding=1),

    keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.ZeroPadding2D(padding=1),

    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),

    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])
```
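
It can help to trace how the spatial size shrinks through this stack. The plain-Python sketch below applies the usual Keras output-size rules, ceil(n / s) for `'same'` padding and floor((n - k) / s) + 1 for `'valid'` padding, to the layers above. `BatchNormalization` and `ReLU` are omitted because they do not change shapes. The helpers `conv_out` and `trace_dense_model` are hypothetical, not part of Keras.

```python
def conv_out(size: int, kernel: int, stride: int, padding: str) -> int:
    """Output length along one spatial axis, following Keras padding semantics."""
    if padding == "same":
        return -(-size // stride)  # ceil(size / stride)
    if padding == "valid":
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown padding: {padding}")


def trace_dense_model(size: int = 32, channels: int = 3):
    """Return (layer, height/width, channels) after each shape-changing layer."""
    trace = [("input", size, channels)]
    size += 2  # ZeroPadding2D(padding=1) adds 1 pixel on each side
    trace.append(("ZeroPadding2D", size, channels))
    size, channels = conv_out(size, 3, 2, "valid"), 8
    trace.append(("Conv2D 3x3, stride 2", size, channels))
    size = conv_out(size, 3, 1, "same")  # depthwise keeps the channel count
    trace.append(("DepthwiseConv2D 3x3", size, channels))
    size, channels = conv_out(size, 1, 1, "valid"), 16
    trace.append(("Conv2D 1x1", size, channels))
    size += 2
    trace.append(("ZeroPadding2D", size, channels))
    size = conv_out(size, 3, 2, "valid")
    trace.append(("DepthwiseConv2D 3x3, stride 2", size, channels))
    size, channels = conv_out(size, 1, 1, "valid"), 32
    trace.append(("Conv2D 1x1", size, channels))
    # GlobalAveragePooling2D then collapses the 8x8 grid to a 32-vector,
    # and Dense(10) produces one logit per CIFAR10 class.
    return trace


for layer, hw, ch in trace_dense_model():
    print(f"{layer:<30} {hw}x{hw}x{ch}")
```

Running it shows the 32x32 input shrinking to 16x16 after the first strided convolution and to 8x8 after the strided depthwise convolution, with the 1x1 convolutions only changing the channel count.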

```python
# Compile the dense model.
dense_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
```
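
`from_logits=True` tells the loss that the final `Dense(10)` layer outputs raw scores (logits) rather than probabilities, so the loss applies the softmax itself. For integer labels, the per-example loss is the negative log-softmax of the true class's logit, and Keras averages it over the batch by default. The NumPy sketch below illustrates that computation; the function name is hypothetical and this is not the Keras implementation.

```python
import numpy as np


def sparse_ce_from_logits(logits, labels) -> float:
    """Mean cross-entropy between integer labels and unnormalized logits."""
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of each example's true class.
    picked = log_probs[np.arange(labels.shape[0]), labels]
    return float(-picked.mean())


# Uniform logits over 10 classes give a loss of log(10), roughly 2.303,
# a useful sanity check for an untrained 10-class classifier.
print(sparse_ce_from_logits(np.zeros((4, 10)), [0, 3, 7, 9]))
```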

```python
# Train the dense model for 10 epochs.
dense_model.fit(
    ds_train,
    epochs=10,
    validation_data=ds_val,
    verbose=2)
```

```
Epoch 1/10
352/352 - 2s - loss: 1.4151 - accuracy: 0.4879 - val_loss: 1.4340 - val_accuracy: 0.4856 - 2s/epoch - 6ms/step
Epoch 2/10
352/352 - 3s - loss: 1.4071 - accuracy: 0.4920 - val_loss: 1.5042 - val_accuracy: 0.4682 - 3s/epoch - 8ms/step
Epoch 3/10
352/352 - 2s - loss: 1.3999 - accuracy: 0.4938 - val_loss: 1.4051 - val_accuracy: 0.4970 - 2s/epoch - 6ms/step
Epoch 4/10
352/352 - 3s - loss: 1.3931 - accuracy: 0.4970 - val_loss: 1.4201 - val_accuracy: 0.4706 - 3s/epoch - 8ms/step
Epoch 5/10
352/352 - 2s - loss: 1.3864 - accuracy: 0.5008 - val_loss: 1.3930 - val_accuracy: 0.4928 - 2s/epoch - 6ms/step
Epoch 6/10
352/352 - 3s - loss: 1.3771 - accuracy: 0.5041 - val_loss: 1.4876 - val_accuracy: 0.4630 - 3s/epoch - 7ms/step
Epoch 7/10
352/352 - 2s - loss: 1.3739 - accuracy: 0.5050 - val_loss: 1.4081 - val_accuracy: 0.4882 - 2s/epoch - 7ms/step
Epoch 8/10
352/352 - 3s - loss: 1.3681 - accuracy: 0.5085 - val_loss: 1.4423 - val_accuracy: 0.4754 - 3s/epoch - 7ms/step
Epoch 9/10
352/352 - 1s - loss: 1.3
```

```python
# Evaluate the dense model on the test set; returns [loss, accuracy].
dense_model.evaluate(ds_test, verbose=1)
```

```
[1.4520046710968018, 0.4724000096321106]
```

### Build the sparse model

Following the instructions in the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md), we apply the `tfmot.sparsity.keras.prune_low_magnitude` function with parameters that target on-device acceleration via pruning, namely the `tfmot.sparsity.keras.PruneForLatencyOnXNNPack` policy.
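
As its name suggests, `prune_low_magnitude` removes the weights with the smallest absolute values. tfmot does this gradually during fine-tuning, updating binary masks according to a pruning schedule. The NumPy sketch below shows only the one-shot core of the idea: zeroing the smallest-magnitude fraction of a single weight tensor. The function name and the example kernel, shaped (1, 1, 16, 32) to resemble the last 1x1 `Conv2D(filters=32)` kernel of the dense model, are illustrative assumptions, not tfmot's implementation.

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, sparsity: float) -> np.ndarray:
    """Return a copy of `weights` with the `sparsity` fraction of smallest-|w| entries zeroed."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    num_pruned = int(round(sparsity * weights.size))
    if num_pruned == 0:
        return weights.copy()
    magnitudes = np.abs(weights).ravel()
    # argpartition places the num_pruned smallest magnitudes in the first slots.
    pruned_idx = np.argpartition(magnitudes, num_pruned - 1)[:num_pruned]
    mask = np.ones(weights.size, dtype=bool)
    mask[pruned_idx] = False
    return np.where(mask, weights.ravel(), 0.0).reshape(weights.shape)


# Hypothetical 1x1 convolution kernel: (kernel_h, kernel_w, in_channels, out_channels).
rng = np.random.default_rng(0)
kernel = rng.normal(size=(1, 1, 16, 32))
pruned = magnitude_prune(kernel, sparsity=0.75)
print(kernel.size, np.count_nonzero(pruned))  # 512 total weights, 128 kept
```

At 75% sparsity only 128 of the kernel's 512 weights stay non-zero; sparse weight tensors like this are what XNNPACK sparse inference is designed to exploit.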


```python
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 5 epochs.
end_epochs = 5

num_iterations_per_epoch = len(ds_train)  # number of batches per epoch
end_step = num_iterations_per_epoch * end_epochs

# Define parameters for pruning.
```

```python
len(ds_train)
```

```
352
```
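
With 352 batches per epoch, `end_step` works out to 352 * 5 = 1760 training steps. A pruning schedule uses a step count like this to decide how quickly sparsity ramps up. The sketch below implements a polynomial ramp modeled on tfmot's `PolynomialDecay` schedule, sparsity = final + (initial - final) * (1 - p)^power, where p is the fraction of the way from `begin_step` to `end_step`, clamped to [0, 1]. The initial and final sparsities of 0.25 and 0.75, `begin_step=0`, and `power=3` are hypothetical choices for illustration.

```python
NUM_ITERATIONS_PER_EPOCH = 352  # len(ds_train) from the cell above
END_EPOCHS = 5
END_STEP = NUM_ITERATIONS_PER_EPOCH * END_EPOCHS  # 1760


def polynomial_sparsity(step: int,
                        initial_sparsity: float = 0.25,  # hypothetical
                        final_sparsity: float = 0.75,    # hypothetical
                        begin_step: int = 0,
                        end_step: int = END_STEP,
                        power: int = 3) -> float:
    """Target sparsity at `step` for a polynomial ramp from initial to final sparsity."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power


for epoch in range(7):
    step = epoch * NUM_ITERATIONS_PER_EPOCH
    print(f"epoch {epoch} (step {step}): target sparsity {polynomial_sparsity(step):.3f}")
```

Under these assumptions most of the sparsity is reached in the first couple of epochs, and the target stays at 0.75 once training passes `end_step`; tying `end_step` to `end_epochs` is what makes pruning finish after 5 epochs.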