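# Saving Keras models

Train a small MLP on MNIST, then persist it in several ways (HDF5, SavedModel, architecture-only config/JSON, raw weights, per-epoch checkpoints) and check that the restored models reproduce the original predictions.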

In [1]:
from numpy.testing import assert_allclose

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.backend import clear_session
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import load_model
from tensorflow.keras.models import model_from_json
from tensorflow.keras.optimizers import RMSprop

# Global seed for TensorFlow random ops (e.g. weight initialisation) so that
# Restart & Run All gives repeatable training runs.
SEED = 42

# Start from a clean Keras session, then seed. Seeding comes second so the
# reset cannot discard it.
clear_session()
tf.random.set_seed(SEED)

In [2]:
# Functional-API MLP: a 784-long input vector, two 64-unit ReLU hidden layers,
# and a 10-unit output layer with no activation (raw logits, matching the
# `from_logits=True` loss used at compile time).
inputs = Input(shape=(784,), name='digits')

hidden = inputs
for hidden_name in ('dense_1', 'dense_2'):
    hidden = Dense(64, activation='relu', name=hidden_name)(hidden)

outputs = Dense(10, name='predictions')(hidden)

model = Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')
model.summary()

Model: "3_layer_mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
digits (InputLayer)          [(None, 784)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
predictions (Dense)          (None, 10)                650       
Total params: 55,050
Trainable params: 55,050
Non-trainable params: 0
_________________________________________________________________


In [3]:
# Load MNIST and flatten each image into a 784-long float vector so it matches
# Input(shape=(784,)). Dividing by 255 presumably maps uint8 pixels into [0, 1].
# The number of samples comes from the arrays themselves rather than being
# hard-coded, so the cell still works on a subset of the data.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(len(x_train), 784).astype('float32') / 255
x_test = x_test.reshape(len(x_test), 784).astype('float32') / 255

# The output layer emits raw logits, so the loss must apply the softmax itself.
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
              optimizer=RMSprop())

history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=1)

# Reset metric state before saving so the reloaded models start from the same
# (empty) metric state as this one; metric accumulators are not restored along
# with the saved weights.
model.reset_metrics()

Train on 60000 samples


In [4]:
# Reference outputs of the trained model; each restored copy is compared against these.
predictions = model.predict(x=x_test)

## HDF5 (`.h5`) format

In [5]:
# Write the whole model to a single HDF5 file, then rebuild it purely from that file.
h5_path = 'model_save.h5'
model.save(h5_path)
new_model = load_model(h5_path)

# The rebuilt model must reproduce the original model's predictions.
new_predictions = new_model.predict(x_test)
assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)

## TensorFlow SavedModel format

In [6]:
# Export to the SavedModel format (a directory on disk; save_format='tf' selects
# it explicitly), then rebuild the model from that directory.
savedmodel_dir = 'model_save'
model.save(savedmodel_dir, save_format='tf')
new_model = load_model(savedmodel_dir)

# The rebuilt model must reproduce the original model's predictions.
new_predictions = new_model.predict(x_test)
assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: model_save/assets


## Model architecture

In [7]:
# Architecture only: rebuild an untrained copy of the same layer graph from the
# model's config (no trained weights are carried over).
config = model.get_config()
reinitialized_model = Model.from_config(config=config)

In [8]:
# Same architecture-only round trip, but serialised as a JSON string.
json_config = model.to_json()
reinitialized_model = model_from_json(json_string=json_config)

## Model weights

In [9]:
# Weights only: take the trained parameters (a list of NumPy arrays) from `model`.
weights = model.get_weights()

# Load them into the architecture-only copy rebuilt above. Setting them back on
# `model` itself would be a no-op and would demonstrate nothing. With the
# weights loaded, the copy must predict exactly like the trained model.
reinitialized_model.set_weights(weights)
assert_allclose(predictions, reinitialized_model.predict(x_test), rtol=1e-6, atol=1e-6)

## Checkpoints during training

In [10]:
# The checkpoint filename template embeds the epoch number and the validation
# loss, so validation data is required (provided here via validation_split).
# NOTE(review): the callback is not configured as weights-only, so despite the
# "weights." prefix these files likely contain the full model. Confirm the
# intent, or pass save_weights_only=True.
checkpoint_callback = ModelCheckpoint(
    filepath="weights.{epoch:02d}-{val_loss:.2f}.hdf5",
)

history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=5,
    validation_split=0.2,
    callbacks=[checkpoint_callback],
)

Train on 48000 samples, validate on 12000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
