In [None]:
import tensorflow as tf

# Record the TensorFlow version up front: some APIs used below
# (e.g. tf.keras.layers.experimental.preprocessing) are version-dependent.
tf_version = tf.__version__
print("version :", tf_version)

In [None]:
# Load CIFAR-10 through Keras as (images, labels) pairs for the train and test splits.
dset = tf.keras.datasets.cifar10
(
    (train_img, train_label),
    (test_img, test_label),
) = dset.load_data()

In [None]:
# CIFAR-10 class names, indexed by integer label (0 -> 'airplane', ..., 9 -> 'truck').
cls = 'airplane automobile bird cat deer dog frog horse ship truck'.split()

In [None]:
import matplotlib.pyplot as plt

# Preview the first 16 test images in a 4x4 grid, each captioned with its class name.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    # No cmap: these are RGB (32x32x3) images, and matplotlib ignores cmap for RGB(A)
    # data, so the former cmap=plt.cm.binary had no effect.
    ax.imshow(test_img[idx])
    # Labels are indexed as label[i][0], i.e. one integer class id per row.
    ax.set_xlabel(cls[test_label[idx][0]])
fig.suptitle('CIFAR-10 test samples')
plt.show()

In [None]:
# ResNet50/101/152 (and other architectures) are available in tf.keras.applications:
# https://www.tensorflow.org/api_docs/python/tf/keras/applications
from tensorflow.keras.applications.resnet50 import ResNet50

# ImageNet-pretrained convolutional backbone only (include_top=False), built for
# 32x32 RGB CIFAR-10 images. The 10-class CIFAR-10 head is added separately when the
# full model is assembled. `classes` is intentionally not passed: Keras uses it only
# when include_top=True, so classes=10 had no effect here.
pretrained_model = ResNet50(include_top=False, input_shape=(32, 32, 3), weights='imagenet')

In [None]:
# Inspect the backbone. CIFAR-10 images are 32x32 with 10 labels, but with
# include_top=False the backbone has no 10-way classifier of its own; that head
# (Dense(10) with softmax) is added when the full model is assembled.
pretrained_model.summary()

In [None]:
from tensorflow.keras.layers import Dense, Flatten, MaxPooling2D
from tensorflow.keras import Input

In [None]:
# Input tensor of size (32, 32, 3): raw CIFAR-10 pixels.
inputs = Input(shape=(32, 32, 3))

# Resize to the spatial size the backbone was built for. Deriving it from the backbone
# keeps the two in sync; since the backbone takes 32x32 input here, this is a no-op.
backbone_h, backbone_w = pretrained_model.input_shape[1:3]
x = tf.keras.layers.experimental.preprocessing.Resizing(backbone_h, backbone_w)(inputs)

# Bug fix: preprocess the resized tensor `x`, not `inputs`. Previously the Resizing
# output was computed and then silently discarded.
x = tf.keras.applications.resnet50.preprocess_input(x)

# training=False keeps the backbone (notably its BatchNormalization layers) in
# inference mode, even when some of its layers are unfrozen for fine-tuning later.
x = pretrained_model(x, training=False)
x = Flatten()(x)

# Dense(10): softmax over the 10 CIFAR-10 classes.
outputs = Dense(10, activation='softmax')(x)

model = tf.keras.Model(inputs, outputs)

In [None]:
# Summary of the assembled model: pretrained backbone plus the Flatten/Dense(10) head.
model.summary()

In [None]:
# Unfreeze the whole backbone first (setting trainable = False here would freeze all
# of it). The following cell then re-freezes the lower layers; the layer count printed
# here tells you where that boundary can go.
pretrained_model.trainable = True
num_layers = len(pretrained_model.layers)
print(num_layers)

In [None]:
# Freeze the first N_FROZEN backbone layers; the remaining (top) layers keep their
# current trainable flag and are fine-tuned.
N_FROZEN = 160

for layer in pretrained_model.layers[:N_FROZEN]:
    layer.trainable = False

# Show the frozen / unfrozen boundary. The printout deliberately starts one layer
# before the boundary, so the last frozen layer is listed first, followed by the
# layers left trainable.
for layer in pretrained_model.layers[max(N_FROZEN - 1, 0):]:
    print(layer.name, layer.trainable)

In [None]:
# Changes to `trainable` only take effect after compiling again. Fine-tune with a
# small learning rate (1e-4), adjusting it between training runs as needed.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Re-inspect the model: the trainable / non-trainable parameter counts should now
# reflect the backbone layers frozen above.
model.summary()

In [None]:
# The summary above shows how the trainable / non-trainable parameter counts changed.
# Train the model.

# Labels are already integer class indices (0-9, matching `cls`). One-hot encode them
# to match the 'categorical_crossentropy' loss used in compile().
train_y = tf.keras.utils.to_categorical(train_label, num_classes=10)
test_y = tf.keras.utils.to_categorical(test_label, num_classes=10)

# NOTE(review): the test split doubles as validation data here. If hyperparameters
# (e.g. the learning rate) are tuned on these numbers, hold out part of the training
# set instead (e.g. validation_split=0.1) and keep the test set for a final evaluation.
model.fit(train_img, train_y, epochs=5, validation_data=(test_img, test_y), batch_size=64)