In [1]:
import tensorflow as tf

In [13]:
(X_train,y_train),(X_test,y_test) = tf.keras.datasets.mnist.load_data()

In [16]:
X_train = X_train.reshape(-1,28,28,1) / 255
X_test = X_test.reshape(-1,28,28,1) / 255

In [None]:
# Teacher
# Student

In [2]:
# inner
class Distiller(tf.keras.Model):
    self.teacher = tf.keras.Sequential()
    self.student = tf.keras.Sequential()

In [None]:
# outer

In [9]:
teacher = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ],
    name="teacher",
)

In [11]:
teacher.summary()

Model: "teacher"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_6 (Conv2D)           (None, 14, 14, 256)       2560      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 14, 14, 256)       0         
                                                                 
 max_pooling2d_3 (MaxPooling  (None, 14, 14, 256)      0         
 2D)                                                             
                                                                 
 conv2d_7 (Conv2D)           (None, 7, 7, 512)         1180160   
                                                                 
 flatten_3 (Flatten)         (None, 25088)             0         
                                                                 
 dense_3 (Dense)             (None, 10)                250890    
                                                           

# Offline

In [18]:
teacher.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [20]:
# Train and evaluate teacher on data.
teacher.fit(X_train, y_train, epochs=5)
teacher.evaluate(X_test, y_test)

Epoch 1/5


2022-10-26 20:11:22.010167: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.13452531397342682, 0.9688000082969666]

In [10]:
student = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.LeakyReLU(alpha=0.2),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ],
    name="student",
)

In [12]:
student.summary()

Model: "student"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_8 (Conv2D)           (None, 14, 14, 16)        160       
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 14, 14, 16)        0         
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 14, 14, 16)       0         
 2D)                                                             
                                                                 
 conv2d_9 (Conv2D)           (None, 7, 7, 32)          4640      
                                                                 
 flatten_4 (Flatten)         (None, 1568)              0         
                                                                 
 dense_4 (Dense)             (None, 10)                15690     
                                                           

# KD 비교용

In [15]:
student_scratch = tf.keras.models.clone_model(student)

In [None]:
student_scratch.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
) # Multi-task-loss 지원함

In [None]:
student_scratch.fit(X_train, y_train, epochs=5)
student_scratch.evaluate(X_test, y_test)

In [None]:
# 2개 이상의 모델이 상호작용해서 학습
# fit function - 
# overriding

In [21]:
# function overloading이 지원한다면 다른 함수
def a():
    pass

def a(x):
    pass

In [49]:
teacher(X_train)

<tf.Tensor: shape=(60000, 10), dtype=float32, numpy=
array([[-14.082058  , -14.64071   , -16.501438  , ...,  -6.3327193 ,
          2.1890817 ,  -0.6540017 ],
       [ 20.622925  , -17.089869  ,   7.894786  , ...,  -5.4389305 ,
          2.7200747 ,  -3.3014212 ],
       [-19.907085  ,  -1.3406717 ,   3.2002792 , ...,  -0.32502627,
          3.074411  ,   2.383148  ],
       ...,
       [-15.1645155 , -11.464357  , -24.730722  , ..., -18.912764  ,
         10.860882  ,   1.5428874 ],
       [ -2.6363785 ,  -5.8864026 ,  -1.8576436 , ...,  -6.6641684 ,
         -3.5393956 , -10.297437  ],
       [-13.003723  , -13.508585  ,  -4.78139   , ...,  -6.4610333 ,
         20.233877  ,   1.1128802 ]], dtype=float32)>

In [100]:
# outer
class Distiller(tf.keras.Model):
    
    def __init__(self, teacher, student):
        super().__init__(self)
        self.teacher = teacher
        self.student = student
        
    def compile(self, optimizer, metrics, student_loss, distill_loss, temperature, alpha): # parameter가 달라도 overriding (function overloading 지원 안함)
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss = student_loss
        self.distill_loss = distill_loss
        self.temperature = temperature
        self.alpha = alpha
        
    
    def train_step(self, data):
        X, y = data
        t_prediction = self.teacher(X, training=False)
        with tf.GradientTape() as tape:
            s_prediction = self.student(X, training=True)
            s_loss = self.student_loss(y, s_prediction)
            d_loss = self.distill_loss(
                tf.nn.softmax(t_prediction / self.temperature, axis=1),
                tf.nn.softmax(s_prediction / self.temperature, axis=1)
            )
            loss = s_loss + self.alpha*d_loss
        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.student.trainable_variables))
        self.compiled_metrics.update_state(y, s_prediction)
        results = { m.name: m.result() for m in self.metrics }
        results.update({'s_loss': s_loss, 'd_loss': d_loss, 'total_loss': loss})
        return results
    
    def call(self, x):
        return self.student(x)
    
    def test_step(self, data):
        X, y = data
        y_pred = self.student(X, training=False)
        s_loss = self.student_loss(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)
        results = { m.name: m.result() for m in self.metrics }
        results.update({'s_loss': s_loss})
        return results
    

In [101]:
kd = Distiller(teacher, student)

In [102]:
kd.compile(optimizer='rmsprop', 
           metrics=['acc'], 
           student_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
           distill_loss = tf.keras.losses.KLDivergence(), temperature = 2, alpha=0.2)

In [90]:
kd.fit(X_train,y_train, epochs=2)

Epoch 1/2
 196/1875 [==>...........................] - ETA: 27s - acc: 0.9861 - s_loss: 0.0474 - d_loss: 0.0608 - total_loss: 0.0596

KeyboardInterrupt: 

In [103]:
kd.evaluate(X_test,y_test)



[0.9833999872207642, 0.0014986938331276178]