In [8]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import numpy as np

# MNIST 데이터 로드
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# 모델 정의 함수
def create_large_model():
    model = models.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(10)
    ])
    return model

def create_model():
    model = models.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(128, activation='relu'),
        layers.Dense(10)
    ])
    return model

# 큰 모델 생성 및 훈련
large_model = create_large_model()
large_model.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])
large_model.fit(train_images, train_labels, epochs=3)

# 큰 모델에서 소프트 타겟 생성
temperature = 50.0
large_model_logits = large_model.predict(train_images)
soft_targets = tf.nn.softmax(large_model_logits / temperature).numpy()

# 작은 모델 생성
small_model = create_model()

# 작은 모델의 손실 함수 정의
def distillation_loss(y_true, y_pred, soft_targets, temperature):
    soft_targets_batch = tf.gather(soft_targets, tf.range(tf.shape(y_pred)[0]))
    soft_loss = tf.keras.losses.categorical_crossentropy(soft_targets_batch, y_pred / temperature, from_logits=True)
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    return soft_loss * (temperature ** 2) + hard_loss

# 작은 모델 컴파일
small_model.compile(optimizer='adam',
                    loss=lambda y_true, y_pred: distillation_loss(y_true, y_pred, soft_targets, temperature),
                    metrics=['accuracy'])

# 작은 모델 훈련
batch_size = 32
small_model.fit(train_images, train_labels, epochs=3, batch_size=batch_size)

# 성능 평가
test_loss, test_acc = small_model.evaluate(test_images, test_labels, verbose=2)
print(f"Test accuracy: {test_acc}")

Epoch 1/3
Epoch 2/3
Epoch 3/3
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
313/313 - 1s - loss: 5755.9971 - accuracy: 0.8873 - 1s/epoch - 3ms/step
Test accuracy: 0.8873000144958496
