In [1]:
import tensorflow as tf
from tensorflow import keras

## Callback 함수 및 모델 저장

In [2]:
# Load the MNIST handwritten-digit dataset that ships with Keras
# (it is downloaded on first use).
mnist = keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities from [0, 255] to [0, 1].
# Cast to float32 before dividing: dividing the integer arrays by a Python
# float would materialize float64 copies, which the float32 model layers
# cast down again anyway — float32 halves the memory for no loss.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
class MyModel(keras.Model):
  """Fully connected MNIST classifier written with the Keras subclassing API.

  Architecture:
    Flatten -> Dense(units1, relu) -> Dense(units2, relu)
    -> Dropout(dropout_rate) -> Dense(num_classes, softmax)

  The defaults reproduce the original exercise model exactly.

  Args:
    units1: width of the first hidden Dense layer (default 256).
    units2: width of the second hidden Dense layer (default 128).
    dropout_rate: fraction of activations dropped after the second hidden
      layer while training (default 0.2).
    num_classes: number of output classes (default 10).
  """

  def __init__(self, units1=256, units2=128, dropout_rate=0.2, num_classes=10):
    super().__init__(name='subclassing_exercise')
    # Keep these attribute names stable: TensorFlow-format checkpoints written
    # by save_weights() match variables to objects by attribute path.
    self.flatten = keras.layers.Flatten()
    self.fc1 = keras.layers.Dense(units1, activation='relu')
    self.fc2 = keras.layers.Dense(units2, activation='relu')
    self.dropout = keras.layers.Dropout(rate=dropout_rate)
    self.fc3 = keras.layers.Dense(num_classes, activation='softmax')

  def call(self, x, training=None):
    """Run the forward pass.

    Args:
      x: batch of input images; flattened before the dense layers.
      training: whether the model is training. It is forwarded explicitly to
        the Dropout layer so dropout is active during fit() and disabled at
        inference. When None, Keras infers it from the calling context.

    Returns:
      Softmax class probabilities for each example in the batch.
    """
    x = self.flatten(x)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.dropout(x, training=training)
    return self.fc3(x)

In [4]:
# Fix TensorFlow's global random seed before the model is created/trained so
# that weight initialization and dropout masks are reproducible across
# Restart & Run All (GPU kernels may still add some nondeterminism).
SEED = 42
tf.random.set_seed(SEED)

model = MyModel()

In [5]:
# Callbacks used during training:
# - EarlyStopping stops training once val_loss has not improved for 2
#   consecutive epochs. restore_best_weights=True rolls the model back to the
#   weights from the epoch with the lowest val_loss when stopping triggers;
#   without it the model (and everything saved from it below) would keep the
#   weights of the last, non-improving epoch.
# - TensorBoard writes training logs to ./logs for visualization.
callbacks = [
  keras.callbacks.EarlyStopping(patience=2, monitor='val_loss',
                                restore_best_weights=True),
  keras.callbacks.TensorBoard('./logs')
]

In [6]:
# Configure training: the Adam optimizer, sparse categorical cross-entropy
# (integer class labels; the model already ends in a softmax, so the loss
# receives probabilities rather than logits), and accuracy as the metric.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [7]:
# Train for up to 10 epochs (EarlyStopping may end training sooner).
# Validation uses a 10% hold-out of the *training* data, not the test set:
# EarlyStopping picks the stopping point (and best weights) by val_loss, so
# validating on (x_test, y_test) would leak the test set into model selection
# and make any later test-set score optimistically biased.
# Note: validation_split takes the last 10% of the arrays, before shuffling.
history = model.fit(x_train, y_train,
                    batch_size=32, epochs=10,
                    callbacks=callbacks,
                    validation_split=0.1)
history

Epoch 1/10
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10


<tensorflow.python.keras.callbacks.History at 0x7f1d5f295e48>

## 모델 저장
* 모델 전체를 저장하는 방법 (`model.save`, SavedModel 형식)
* 모델의 가중치만 저장하는 방법 (`model.save_weights`, TensorFlow 체크포인트 또는 HDF5 형식)

In [8]:
# Save the entire model (not just its weights) as a TensorFlow SavedModel
# directory at ./my_model.
model.save(filepath='./my_model')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: ./my_model/assets


In [9]:
# Remove the in-memory model so the next cell demonstrably restores it from
# disk instead of reusing the existing Python object.
del model

In [10]:
# Rebuild the model from the SavedModel directory at ./my_model.
model = keras.models.load_model(filepath='./my_model')

In [11]:
# Print the layer table of the reloaded model. For this subclassed model Keras
# lists output shapes as "multiple" instead of concrete shapes; the parameter
# counts still match the original model.
model.summary()

Model: "subclassing_exercise"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  200960    
_________________________________________________________________
dense_1 (Dense)              multiple                  32896     
_________________________________________________________________
dropout (Dropout)            multiple                  0         
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
_________________________________________________________________


In [12]:
# Save only the weights. The path has no .h5 suffix, so this writes a
# TensorFlow checkpoint; loading it back returns a CheckpointLoadStatus.
model.save_weights(filepath='./weights/subclassing')
model.load_weights(filepath='./weights/subclassing')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1d5b97af60>

In [13]:
# Save only the weights, this time in HDF5 format, to test.h5 in the current
# working directory.
model.save_weights(filepath='test.h5', save_format='h5')