In [1]:
# Version Check
import sys
import tensorflow as tf
import tensorflow_datasets as tfds
import PIL
import pandas as pd
import numpy as np
import scipy
# Report the interpreter and library versions this notebook was executed with.
print(f"python {sys.version}")
print(f"tensorflow {tf.__version__}")
print(f"tensorflow-datasets {tfds.__version__}")
print(f"Pillow {PIL.__version__}")
print(f"pandas {pd.__version__}")
print(f"numpy {np.__version__}")
print(f"scipy {scipy.__version__}")
print()
# Report whether TensorFlow can see a GPU and how the installed binary was built.
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")
print(f"Built with CUDA: {tf.test.is_built_with_cuda()}")
print(f"Built with GPU support: {tf.test.is_built_with_gpu_support()}")

python 3.8.0 (default, Nov  6 2019, 16:00:02) [MSC v.1916 64 bit (AMD64)]
tensorflow 2.6.2
tensorflow-datasets 4.4.0
Pillow 8.3.2
pandas 1.3.3
numpy 1.19.5
scipy 1.7.1

Num GPUs Available: 1
Built with CUDA: True
Built with GPU support: True


In [2]:
from tqdm import tqdm

In [17]:
# @title parameters (work in colab)
# Number of epochs used to train the teacher model (passed to t_model.fit).
t_epoch = 10 # @param {type:"slider", min:1, max:100, step:1}
# Number of passes of the custom student training loop.
s_epoch = 5 # @param {type:"slider", min:1, max:100, step:1}
# Learning rate handed to every Adam optimizer created in this notebook
# (a plain constant, not exposed as a form field).
learning_rate = 0.01
# Mini-batch size for both teacher fitting and the student training loop.
batch_size = 64 # @param [32, 64, 128, 256] {type:"raw"}
# Divisor applied to teacher and student logits before the softmax in the distillation loss.
temperature = 3 # @param {type:"slider", min:1, max:10, step:1}
# Weight of the hard-label loss; the distillation loss is weighted by (1 - alpha).
alpha = 0.5 # @param {type:"slider", min:0.1, max:0.9, step:0.1}

In [18]:
# MNIST: load the train/test splits, then convert the images to float32,
# divide by 255 and add a trailing channel axis so each array is (N, 28, 28, 1).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = (x_train.astype('float32') / 255.0).reshape(-1, 28, 28, 1)
x_test = (x_test.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

In [19]:
# teacher model

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, LeakyReLU, MaxPooling2D, Flatten, Dense

# Teacher: two strided convolutions (256 and 512 filters) followed by a
# 10-unit dense head. The head has no activation, so the model emits logits.
teacher_input = Input(shape=(28, 28, 1))
conv_a = Conv2D(256, (3, 3), strides=(2, 2), padding='same')(teacher_input)
activated = LeakyReLU(alpha=0.2)(conv_a)
pooled = MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same')(activated)
conv_b = Conv2D(512, (3, 3), strides=(2, 2), padding='same')(pooled)
flat = Flatten()(conv_b)
teacher_logits = Dense(10)(flat)

t_model = Model(inputs=[teacher_input], outputs=[teacher_logits])

# Print the layer-by-layer summary of the teacher.
t_model.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 256)       2560      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 256)       0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 256)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 512)         1180160   
_________________________________________________________________
flatten_2 (Flatten)          (None, 25088)             0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                2508

In [20]:
# student model
# A deliberately small network: flatten -> Dense(28) -> Dense(10) logits.
# NOTE(review): neither Dense layer has an activation, so the student is
# effectively a linear classifier -- confirm this is intended.
i = Input(shape=(28, 28, 1))
out = Flatten()(i)
out = Dense(28)(out)
out = Dense(10)(out)

# s_model_1 is the student trained with distillation.
s_model_1 = Model(inputs=[i], outputs=[out])

# s_model_2 is the baseline student trained on hard labels only.
# clone_model() builds new layers with freshly initialised weights, so copy
# s_model_1's weights over: both students then start from the same point and
# any accuracy gap comes from the training signal, not from initialisation.
s_model_2 = tf.keras.models.clone_model(s_model_1)
s_model_2.set_weights(s_model_1.get_weights())

s_model_1.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 28)                21980     
_________________________________________________________________
dense_5 (Dense)              (None, 10)                290       
Total params: 22,270
Trainable params: 22,270
Non-trainable params: 0
_________________________________________________________________


In [21]:
# Compile

# The teacher, the distilled student (s_model_1) and the stand-alone student
# (s_model_2) share one configuration: Adam at `learning_rate`, sparse
# categorical cross-entropy on logits, and sparse categorical accuracy.
# Each model receives its own optimizer, loss and metric objects.
for clf in (t_model, s_model_1, s_model_2):
    clf.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

In [22]:
# Fit

# Train the teacher on the hard labels; the object returned by fit() is kept in t_history.
t_history = t_model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=t_epoch,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [24]:
# Soft-target (distillation) loss: KL divergence between two probability distributions.
d_loss = tf.keras.losses.KLDivergence()

# Hard-label loss for the students: cross-entropy between integer labels and raw logits.
s_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [25]:
# Sanity check: display the shape of the preprocessed training images.
x_train.shape

(60000, 28, 28, 1)

In [31]:
# Distillation experiment: train both students on the same mini-batches --
# s_model_1 with hard labels plus the teacher's softened predictions,
# s_model_2 with hard labels only -- and evaluate both on the test set after
# every epoch.
batch_count = x_train.shape[0] // batch_size

# One optimizer per student. A single shared Adam instance would advance its
# step counter twice per mini-batch, which skews Adam's bias correction for
# both models and makes the comparison unfair.
opt = tf.keras.optimizers.Adam(learning_rate)           # distilled student (s_model_1)
opt_baseline = tf.keras.optimizers.Adam(learning_rate)  # stand-alone student (s_model_2)

for e in range(s_epoch):
    for _ in range(batch_count):
        # Draw a random mini-batch (indices are sampled with replacement) and
        # gather it once instead of re-indexing the arrays for every use.
        batch_idx = np.random.randint(0, x_train.shape[0], size=batch_size)
        x_batch = x_train[batch_idx]
        y_batch = y_train[batch_idx]

        # Teacher logits are fixed targets computed outside any gradient tape.
        # Calling the model directly avoids the per-call overhead of predict()
        # on a single small batch.
        t_pred = t_model(x_batch, training=False)
        soft_targets = tf.nn.softmax(t_pred / temperature, axis=1)

        # --- case 1: student guided by the teacher ---
        with tf.GradientTape() as tape:
            s_pred_1 = s_model_1(x_batch, training=True)
            student_loss = s_loss(y_batch, s_pred_1)
            # Soft-target gradients shrink as 1/T^2, so rescale by T^2 to keep
            # the balance set by `alpha` independent of the temperature.
            distillation_loss = d_loss(
                soft_targets,
                tf.nn.softmax(s_pred_1 / temperature, axis=1),
            ) * temperature ** 2
            loss = alpha * student_loss + (1 - alpha) * distillation_loss

        weights_1 = s_model_1.trainable_variables
        grads_1 = tape.gradient(loss, weights_1)
        opt.apply_gradients(zip(grads_1, weights_1))

        # --- case 2: student trained on hard labels only ---
        with tf.GradientTape() as tape:
            s_pred_2 = s_model_2(x_batch, training=True)
            student_loss = s_loss(y_batch, s_pred_2)

        weights_2 = s_model_2.trainable_variables
        grads_2 = tape.gradient(student_loss, weights_2)
        opt_baseline.apply_gradients(zip(grads_2, weights_2))

    # End-of-epoch comparison on the held-out test set.
    print("epoch {}".format(e))
    print("case1. when the teacher teachs")
    s_model_1.evaluate(x_test, y_test)
    print("case2. when the student studies alone")
    s_model_2.evaluate(x_test, y_test)
    print("\n")

epoch 0
case1. when the teacher teachs
case2. when the student studies alone


epoch 1
case1. when the teacher teachs
case2. when the student studies alone


epoch 2
case1. when the teacher teachs
case2. when the student studies alone


epoch 3
case1. when the teacher teachs
case2. when the student studies alone


epoch 4
case1. when the teacher teachs
case2. when the student studies alone


