# **Distilling the Knowledge in a Neural Network**

Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.

Ref.

*https://keras.io/examples/vision/knowledge_distillation/*

## **Default Setting**

In [None]:
import tensorflow as tf

import numpy as np

from adabelief_tf import AdaBeliefOptimizer
from pathlib import Path

print(f"tf.__version__: {tf.__version__}")

tf.__version__: 2.4.1


In [None]:
!nvidia-smi

Tue Apr 20 09:59:51 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Graphics Device     On   | 00000000:0A:00.0 Off |                  N/A |
|  0%   40C    P8    23W / 220W |     16MiB /  7979MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
tf.keras.mixed_precision.set_global_policy("mixed_float16")

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: Graphics Device, compute capability 8.6


In [None]:
# If you wanna avoid below error, you need to run below codes when you start kernel.

# UnknownError: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, 
# so try looking to see if a warning log message was printed above. [Op:Conv2D]

# Ref: https://blog.naver.com/vft1500/221793591386

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices("GPU")
        print(f"{len(gpus)} Physical GPUs, {len(logical_gpus)} Logical GPUs.")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs.


In [None]:
class HParams(object):
    def __init__(self):
        self.seed = 42
        
        self.num_classes = 10
        
        self.vl_size = 10_000

        self.global_batch_size = 256
        self.buffer_size = 20_000
        self.auto = tf.data.experimental.AUTOTUNE

        self.image_size = [112, 112]

        self.init_lr = 1e-3
        self.epochs = 10
        
        self.steps_per_epoch = None
        self.validation_steps = None
        self.steps_per_execution = 16

HPARAMS = HParams()

## **Prepare Dataset**

In [None]:
@tf.function
def resizing_and_rescaling(images, labels):
    images = tf.expand_dims(images, axis = -1)
    images = tf.image.convert_image_dtype(images, tf.float32)
    images = tf.image.resize(images, HPARAMS.image_size)
    # images = tf.image.grayscale_to_rgb(images)
    labels = tf.cast(labels, tf.int32)
    return images, labels


def get_shapes(element_spec):
    return [get_shapes(e) if isinstance(e, tuple) else e.shape for e in element_spec]

In [None]:
def get_dataset(
    batch_size = HPARAMS.global_batch_size
):
    ## Load dataset from tfds.
    (tr_X, tr_Y), (ts_X, ts_Y) = tf.keras.datasets.mnist.load_data()

    ## Train test split.
    tr_X, vl_X = tr_X[HPARAMS.vl_size:], tr_X[:HPARAMS.vl_size]
    tr_Y, vl_Y = tr_Y[HPARAMS.vl_size:], tr_Y[:HPARAMS.vl_size]

    ## Building.
    tr_ds = tf.data.Dataset.from_tensor_slices((tr_X, tr_Y)
                ).cache(
                ).repeat(
                ).shuffle(HPARAMS.buffer_size, reshuffle_each_iteration = True,
                ).batch(batch_size
                ).map(resizing_and_rescaling, num_parallel_calls = HPARAMS.auto
                ).prefetch(HPARAMS.auto)
    
    vl_ds = tf.data.Dataset.from_tensor_slices((vl_X, vl_Y)
                ).cache(
                ).repeat(
                # ).shuffle(HPARAMS.buffer_sz, reshuffle_each_iteration = True,
                ).batch(batch_size
                ).map(resizing_and_rescaling, num_parallel_calls = HPARAMS.auto
                ).prefetch(HPARAMS.auto)

    ts_ds = tf.data.Dataset.from_tensor_slices((ts_X, ts_Y)
                ).cache(
                # ).shuffle(HPARAMS.buffer_size, reshuffle_each_iteration = True,
                ).batch(batch_size
                ).map(resizing_and_rescaling, num_parallel_calls = HPARAMS.auto
                ).prefetch(HPARAMS.auto)

    steps_per_epoch  = np.ceil(np.shape(tr_X)[0] / batch_size)
    validation_steps = np.ceil(np.shape(vl_Y)[0] / batch_size)
    
    HPARAMS.steps_per_epoch  = steps_per_epoch
    HPARAMS.validation_steps = validation_steps
    
    print(f"# of training data: {tr_X.shape[0]}")
    print(f"# of validation data: {vl_X.shape[0]}")
    print(f"# of test data: {ts_X.shape[0]}\n")

    print(f"Global batch size: {batch_size}")
    print(f"Steps per epoch: {steps_per_epoch} (total {steps_per_epoch * HPARAMS.epochs} batches)")
    print(f"Validation steps: {validation_steps} (total {validation_steps * HPARAMS.epochs} batches)\n")
    
    print(f"Steps per execution: {HPARAMS.steps_per_execution}\n")

    print(f"tr_ds.element_spec: {get_shapes(tr_ds.element_spec)}")
    print(f"vl_ds.element_spec: {get_shapes(vl_ds.element_spec)}")
    print(f"ts_ds.element_spec: {get_shapes(ts_ds.element_spec)}\n")

    return tr_ds, vl_ds, ts_ds

## **Modeling**

### **Baseline**

In [None]:
def bn_ReLU_conv2D(x, filters, kernel_size):
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation(tf.nn.relu6)(x)
    x = tf.keras.layers.Conv2D(filters, kernel_size, padding = "same")(x)
    
    return x


def transition_block(x):
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Conv2D(x.shape[-1] // 2, 1, padding = "same")(x)
    x = tf.keras.layers.AveragePooling2D((2, 2), strides = 2)(x)

    return x


def dense_block(x, num_conv, growth_rate):
    for i in range(num_conv):
        residual = x
        x = bn_ReLU_conv2D(x, 4 * growth_rate, 1)
        x = bn_ReLU_conv2D(x, growth_rate, 3)
        x = tf.keras.layers.Concatenate(axis = -1)([x, residual])

    return x

In [None]:
def create_NN(
    model_name, 
    growth_rate = 32, 
    embedding_dim = HPARAMS.num_classes
):
    ## DenseNet-121
    x = model_input = tf.keras.layers.Input(shape = (*HPARAMS.image_size, 1)) ## grayscale, not rgb

    ## Entry Flow
    x = tf.keras.layers.Conv2D(2 * growth_rate, 7, strides = 2, padding = "same")(x)
    x = tf.keras.layers.MaxPooling2D((3, 3), strides = 2, padding = "same")(x)


    ## Middle Flow
    for i, num_conv in enumerate([6, 12, 24, 16]):
        x = dense_block(x, num_conv, growth_rate)
        if i is not 3: 
            x = transition_block(x)

    ## Exit Flow
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(embedding_dim)(x)
    model_output = x = tf.keras.layers.Activation("linear", dtype = tf.float32)(x) ## no classifier!

    return tf.keras.Model(
        inputs = model_input,
        outputs = model_output,
        name = model_name
    )

In [None]:
# tmp = create_NN("tmp")
# tmp.summary()

In [None]:
# del tmp

### **Loss Function**

### **Teacher and Student Model**

In [None]:
class DistillationModelWrapper(tf.keras.Model):
    def __init__(
        self, 
        teacher, 
        student, 
        **kwargs
    ):
        super(DistillationModelWrapper, self).__init__(**kwargs)
        ## Assert teacher and Student have no classification layer (i.e. softmax).
        self.teacher = teacher
        self.student = student
        
    def compile(
        self, 
        optimizer,
        student_loss_fn,
        distillation_loss_fn,
        metrics,
        alpha = 0.9,
        temperature = 10,
        **kwargs,
    ):
        super(DistillationModelWrapper, self).compile(
            optimizer = optimizer,
            metrics = metrics,
            **kwargs
        )
        # self.optimizer = optimizer
        self.student_loss_fn = student_loss_fn ## sparse categorical crossentropy
        self.distillation_loss_fn = distillation_loss_fn ## Kullback–Leibler divergence
        # self.metrics = metrics ## accuracy
        self.alpha = alpha
        self.temperature = temperature
        
    @tf.function
    def train_step(self, x):
        inp, tar = x

        teacher_pred = self.teacher(inp, training = False)
        
        with tf.GradientTape() as tape:
            student_pred = self.student(inp)
            
            ## Calculate losses.
            student_loss = self.student_loss_fn(
                tar, 
                student_pred,
            )
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_pred / self.temperature),
                tf.nn.softmax(student_pred / self.temperature),
            )
            
            loss = (1. - self.alpha) * student_loss + self.alpha * distillation_loss
            scaled_loss = self.optimizer.get_scaled_loss(loss)
            
        scaled_grads = tape.gradient(scaled_loss, self.student.trainable_variables)
        grads = self.optimizer.get_unscaled_gradients(scaled_grads)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_weights))

        self.compiled_metrics.update_state(tar, student_pred)
        
        results = {m.name: m.result() for m in self.metrics}
        results.update({
            "student_loss": student_loss,
            "distillation_loss": distillation_loss,
        })
        
        return results
    
    @tf.function
    def test_step(self, x):
        inp, tar = x
        
        teacher_pred = self.teacher(inp, training = False)
        student_pred = self.student(inp, training = False)
        
        student_loss = self.student_loss_fn(
            tar, 
            student_pred,
        )
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_pred / self.temperature),
            tf.nn.softmax(student_pred / self.temperature),
        )
        
        loss = (1. - self.alpha) * student_loss + self.alpha * distillation_loss
        
        self.compiled_metrics.update_state(tar, student_pred)
        
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        
        return results

## **Fit**

### **Callbacks**

In [None]:
def get_callbacks(model_name, is_distiller = False):    
    ## Checkpoint callback.
    if is_distiller:
        ckpt_path = Path(f"ckpt/{model_name}/" + "cp-{epoch:03d}-{val_student_loss:.4f}.ckpt")
        ckpt_path.parent.mkdir(parents = True, exist_ok = True)

        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            ckpt_path, 
            verbose = 0, 
            monitor = "val_student_loss", 
            save_weights_only = True, 
            save_best_only = True,
        )
    else:
        ckpt_path = Path(f"ckpt/{model_name}/" + "cp-{epoch:03d}-{val_loss:.4f}.ckpt")
        ckpt_path.parent.mkdir(parents = True, exist_ok = True)

        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            ckpt_path, 
            verbose = 0, 
            monitor = "val_loss", 
            save_weights_only = True, 
            save_best_only = True,
        )

    ## TensorBoard callback.
    log_dir = Path(f"logs/fit/{model_name}")
    tb_callback = tf.keras.callbacks.TensorBoard(
        log_dir = log_dir, 
        histogram_freq = 1,
    )

    return [cp_callback, tb_callback]

### **Checkpoints**

In [None]:
def load_latest_checkpoint(ckpt_folder, growth_rate = 32, is_distiller = False):
    latest = tf.train.latest_checkpoint(ckpt_folder)
    print(f"Load latest checkpoints: {latest}...")
    
    model_name = ckpt_folder.split("/")[-1]
    
    if is_distiller:
        model = DistillationModelWrapper(
            teacher = create_NN(f"{model_name}-teacher-latest", growth_rate = 32),
            student = create_NN(f"{model_name}-student-latest", growth_rate = 12),
        )

        ckpt = tf.train.Checkpoint(model)
        ckpt.restore(latest).expect_partial()
        
        model.compile(
            optimizer = AdaBeliefOptimizer(
                learning_rate = HPARAMS.init_lr, 
                epsilon = 1e-14,
                weight_decay = 1e-5,
                rectify = True,
                print_change_log = False,
            ),
            student_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(),
            distillation_loss_fn = tf.keras.losses.KLDivergence(),
            metrics = [tf.keras.metrics.SparseCategoricalAccuracy()],
            steps_per_execution = HPARAMS.steps_per_execution,
        )
    else:
        model = create_NN(
            model_name = f"{model_name}-latest",
            growth_rate = growth_rate,
        )
        
        ckpt = tf.train.Checkpoint(model)
        ckpt.restore(latest).expect_partial()

        model.compile(
            optimizer = AdaBeliefOptimizer(
                learning_rate = HPARAMS.init_lr, 
                epsilon = 1e-14,
                weight_decay = 1e-5,
                rectify = False,
                print_change_log = False,
            ),
            loss = tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics = [tf.keras.metrics.SparseCategoricalAccuracy()],
            steps_per_execution = HPARAMS.steps_per_execution,
        )

    print(f"Restored model: {model.name}\n")
    
    return model

### **Baseline with Small Model**

Using the growth_rate parameter defined in DenseNet, it is made thinner than the normal model. The actual number of parameters is about 1M, which is significantly less than the 7M parameters of the large model.

In [None]:
!rm -rf ./ckpt
!rm -rf ./logs

In [None]:
%%time
tr_ds, vl_ds, ts_ds = get_dataset()

student = create_NN(
    model_name = "student",
    growth_rate = 12,
)

print(f"Training model: {student.name} (# of params: {student.count_params() / 1e+6:.2f}M)")

student.compile(
    optimizer = AdaBeliefOptimizer(
        learning_rate = HPARAMS.init_lr, 
        epsilon = 1e-14,
        weight_decay = 1e-5,
        rectify = True,
        print_change_log = False,
    ),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), ## 
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name = "acc")],
    steps_per_execution = HPARAMS.steps_per_execution,
)

_ = student.fit(
    tr_ds,
    validation_data = vl_ds,
    steps_per_epoch = HPARAMS.steps_per_epoch,
    validation_steps = HPARAMS.validation_steps,
    epochs = HPARAMS.epochs,
    verbose = 0,
    callbacks = get_callbacks(student.name),
)

print("Done.")

# of training data: 50000
# of validation data: 10000
# of test data: 10000

Global batch size: 256
Steps per epoch: 196.0 (total 1960.0 batches)
Validation steps: 40.0 (total 400.0 batches)

Steps per execution: 16

tr_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
vl_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
ts_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]

Training model: student (# of params: 1.03M)
Done.
CPU times: user 6min 17s, sys: 31.6 s, total: 6min 48s
Wall time: 7min 42s


In [None]:
latest_student = load_latest_checkpoint(f"./ckpt/{student.name}", growth_rate = 12)
latest_student.evaluate(ts_ds, verbose = 2)

Load latest checkpoints: ./ckpt/student/cp-004-0.0639.ckpt...
Restored model: student-latest

40/40 - 3s - loss: 1.2587 - sparse_categorical_accuracy: 0.9811


[1.2587331533432007, 0.9811000227928162]

### **Baseline with Large Model**

In [None]:
%%time
## Reducing the batch size in half.
tr_ds, vl_ds, ts_ds = get_dataset(batch_size = HPARAMS.global_batch_size // 2)

teacher = create_NN(
    model_name = "teacher",
    growth_rate = 32,
)

print(f"Training model: {teacher.name} (# of params: {teacher.count_params() / 1e+6:.2f}M)")

teacher.compile(
    optimizer = AdaBeliefOptimizer(
        learning_rate = HPARAMS.init_lr, 
        epsilon = 1e-14,
        weight_decay = 1e-5,
        rectify = True,
        print_change_log = False,
    ),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), ##
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name = "acc")],
    steps_per_execution = HPARAMS.steps_per_execution,
)

_ = teacher.fit(
    tr_ds,
    validation_data = vl_ds,
    steps_per_epoch = HPARAMS.steps_per_epoch,
    validation_steps = HPARAMS.validation_steps,
    epochs = HPARAMS.epochs,
    verbose = 0,
    callbacks = get_callbacks(teacher.name),
)

print("Done.")

# of training data: 50000
# of validation data: 10000
# of test data: 10000

Global batch size: 128
Steps per epoch: 391.0 (total 3910.0 batches)
Validation steps: 79.0 (total 790.0 batches)

Steps per execution: 16

tr_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
vl_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
ts_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]

Training model: teacher (# of params: 7.05M)
Done.
CPU times: user 10min 35s, sys: 53.5 s, total: 11min 28s
Wall time: 14min 33s


In [None]:
latest_teacher = load_latest_checkpoint(f"./ckpt/{teacher.name}", growth_rate = 32)
latest_teacher.evaluate(ts_ds, verbose = 2)

Load latest checkpoints: ./ckpt/teacher/cp-005-0.0824.ckpt...
Restored model: teacher-latest

79/79 - 4s - loss: 1.2586 - sparse_categorical_accuracy: 0.9830


[1.2585575580596924, 0.9829999804496765]

### **Distiller Model**

With the weights of the latest teacher model fixed, we train a new student model from scratch. By referring to the method introduced in the paper and the keras tutorial, the weight of the soft label was increased (loss weight = 0.9). Specifically, by using kl-divergence as the distiller loss function, the difference between the two entropy probability distributions was calculated.

In [None]:
## About twice as much VRAM is required, reducing the batch size in half.
tr_ds, vl_ds, ts_ds = get_dataset(batch_size = HPARAMS.global_batch_size // 2)

latest_teacher.trainable = False ## freeze
distiller = DistillationModelWrapper(
    teacher = latest_teacher,
    student = create_NN("distiller-student", growth_rate = 12),
    name = "distiller",
)

print(f"Training model: {distiller.name}... ") #"(# of params: {distiller.count_params() / 1e+6:.2f}M)")

distiller.compile(
    optimizer = AdaBeliefOptimizer(
        learning_rate = HPARAMS.init_lr, 
        epsilon = 1e-14,
        weight_decay = 1e-5,
        rectify = True,
        print_change_log = False,
    ),
    student_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
    distillation_loss_fn = tf.keras.losses.KLDivergence(),
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()],
    steps_per_execution = HPARAMS.steps_per_execution,
)

distiller.fit(
    tr_ds,
    validation_data = vl_ds,
    steps_per_epoch = HPARAMS.steps_per_epoch,
    validation_steps = HPARAMS.validation_steps,
    epochs = HPARAMS.epochs,
    verbose = 2,
    callbacks = get_callbacks(distiller.name, is_distiller = True),
)

# of training data: 50000
# of validation data: 10000
# of test data: 10000

Global batch size: 128
Steps per epoch: 391.0 (total 3910.0 batches)
Validation steps: 79.0 (total 790.0 batches)

Steps per execution: 16

tr_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
vl_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]
ts_ds.element_spec: [TensorShape([None, 112, 112, 1]), TensorShape([None])]

Training model: distiller... 
Epoch 1/10
391/391 - 139s - sparse_categorical_accuracy: 0.7782 - student_loss: 0.2169 - distillation_loss: 0.0370 - val_sparse_categorical_accuracy: 0.9667 - val_student_loss: 0.1647
Epoch 2/10
391/391 - 64s - sparse_categorical_accuracy: 0.9730 - student_loss: 0.1113 - distillation_loss: 0.0228 - val_sparse_categorical_accuracy: 0.9792 - val_student_loss: 0.0144
Epoch 3/10
391/391 - 64s - sparse_categorical_accuracy: 0.9847 - student_loss: 0.0489 - distillation_loss: 0.0208 - val_sparse_categorical_accuracy: 0.9780 -

<tensorflow.python.keras.callbacks.History at 0x7f65923c5e10>

In [None]:
latest_distiller = load_latest_checkpoint(f"./ckpt/{distiller.name}", is_distiller = True)
latest_distiller.evaluate(ts_ds, verbose = 2)

Load latest checkpoints: ./ckpt/distiller/cp-008-0.0037.ckpt...
Restored model: distillation_model_wrapper_4

79/79 - 5s - sparse_categorical_accuracy: 0.9928 - student_loss: 1.2945


[0.9927999973297119, 1.2945350408554077]

## **Commit to Tensorboard Dev.**

In [None]:
!tensorboard dev upload --logdir ./logs \
    --name "Experiment of 'Distilling the Knowledge in a Neural Network'" \
    --description "Implemented training results from the paper 'https://arxiv.org/abs/1503.02531'" \
    --one_shot

In [6]:
from IPython import display

display.IFrame(
    src = "https://tensorboard.dev/experiment/up3gbYoJTNWgPkAmzYexNw/",
    width = "100%",
    height = "1000px"
)