In [1]:
# 모듈 임포트
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

## 데이터셋 준비

In [2]:
train_datasets = tfds.load('cats_vs_dogs',split='train[:80%]')
valid_datasets = tfds.load('cats_vs_dogs',split='train[80%:]')

def preprocessing(data):
    x = data['image']
    y = data['label']
    
    x = x / 255
    x = tf.image.resize(x, size=(224, 224))
    return x, y

batch_size = 32

train_data = train_datasets.map(preprocessing).batch(batch_size)
valid_data = train_datasets.map(preprocessing).batch(batch_size)

### Transfer Layer 구현

In [3]:
transfer_model = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
transfer_model.trainable = False

### 모델 정의

In [4]:
model = tf.keras.models.Sequential([
                                    transfer_model,
                                    tf.keras.layers.Flatten(),
                                    tf.keras.layers.Dropout(0.5),
                                    tf.keras.layers.Dense(512, activation='relu'),
                                    tf.keras.layers.Dense(128, activation='relu'),
                                    tf.keras.layers.Dense(2, activation='softmax')                                 
])

In [5]:

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
vgg16 (Functional)           (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dropout (Dropout)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 512)               12845568  
_________________________________________________________________
dense_1 (Dense)              (None, 128)               65664     
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 258       
Total params: 27,626,178
Trainable params: 12,911,490
Non-trainable params: 14,714,688
___________________________________

## 모델 컴파일,체크포인트,학습

In [6]:
# 컴파일
optimizer = tf.keras.optimizers.Adam(lr=0.001)
model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['acc'])

  "The `lr` argument is deprecated, use `learning_rate` instead.")


In [7]:
# 체크 포인트
checkpoint_path = 'my_checkpoint.ckpt'
checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                save_weights_only = True,
                                                save_best_only = True,
                                                monitor = 'val_loss',
                                                verbose = 1)

In [8]:
#학습
model.fit(train_data,
          validation_data=(valid_data),
          epochs=10,
          callbacks = [checkpoint])

Epoch 1/10

Epoch 00001: val_loss improved from inf to 0.19849, saving model to my_checkpoint.ckpt
Epoch 2/10

Epoch 00002: val_loss improved from 0.19849 to 0.11818, saving model to my_checkpoint.ckpt
Epoch 3/10

Epoch 00003: val_loss improved from 0.11818 to 0.09453, saving model to my_checkpoint.ckpt
Epoch 4/10

Epoch 00004: val_loss improved from 0.09453 to 0.08257, saving model to my_checkpoint.ckpt
Epoch 5/10

Epoch 00005: val_loss did not improve from 0.08257
Epoch 6/10

Epoch 00006: val_loss did not improve from 0.08257
Epoch 7/10

Epoch 00007: val_loss improved from 0.08257 to 0.06813, saving model to my_checkpoint.ckpt
Epoch 8/10

Epoch 00008: val_loss did not improve from 0.06813
Epoch 9/10

Epoch 00009: val_loss improved from 0.06813 to 0.06459, saving model to my_checkpoint.ckpt
Epoch 10/10

Epoch 00010: val_loss did not improve from 0.06459


<tensorflow.python.keras.callbacks.History at 0x7fbbd4fdce10>

In [9]:
# 가중치 적용
model.load_weights(checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbbe406ffd0>