Permalink
Please
sign in to comment.
Showing
with
653 additions
and 0 deletions.
| @@ -0,0 +1,123 @@ | |||
| { | |||
| "cells": [ | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "# tensorflow2教程-keras模型保持和序列化\n", | |||
| "\n", | |||
| "## 1.保持序列模型和函数模型" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 4, | |||
| "metadata": {}, | |||
| "outputs": [ | |||
| { | |||
| "name": "stdout", | |||
| "output_type": "stream", | |||
| "text": [ | |||
| "Model: \"3_layer_mlp\"\n", | |||
| "_________________________________________________________________\n", | |||
| "Layer (type) Output Shape Param # \n", | |||
| "=================================================================\n", | |||
| "digits (InputLayer) [(None, 784)] 0 \n", | |||
| "_________________________________________________________________\n", | |||
| "dense_1 (Dense) (None, 64) 50240 \n", | |||
| "_________________________________________________________________\n", | |||
| "dense_2 (Dense) (None, 64) 4160 \n", | |||
| "_________________________________________________________________\n", | |||
| "predictions (Dense) (None, 10) 650 \n", | |||
| "=================================================================\n", | |||
| "Total params: 55,050\n", | |||
| "Trainable params: 55,050\n", | |||
| "Non-trainable params: 0\n", | |||
| "_________________________________________________________________\n", | |||
| "60000/60000 [==============================] - 1s 23us/sample - loss: 0.3136\n" | |||
| ] | |||
| } | |||
| ], | |||
| "source": [ | |||
| "# 构建一个简单的模型并训练\n", | |||
| "from __future__ import absolute_import, division, print_function\n", | |||
| "import tensorflow as tf\n", | |||
| "tf.keras.backend.clear_session()\n", | |||
| "from tensorflow import keras\n", | |||
| "from tensorflow.keras import layers\n", | |||
| "\n", | |||
| "inputs = keras.Input(shape=(784,), name='digits')\n", | |||
| "x = layers.Dense(64, activation='relu', name='dense_1')(inputs)\n", | |||
| "x = layers.Dense(64, activation='relu', name='dense_2')(x)\n", | |||
| "outputs = layers.Dense(10, activation='softmax', name='predictions')(x)\n", | |||
| "\n", | |||
| "model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')\n", | |||
| "model.summary()\n", | |||
| "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", | |||
| "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n", | |||
| "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n", | |||
| "\n", | |||
| "model.compile(loss='sparse_categorical_crossentropy',\n", | |||
| " optimizer=keras.optimizers.RMSprop())\n", | |||
| "history = model.fit(x_train, y_train,\n", | |||
| " batch_size=64,\n", | |||
| " epochs=1)\n", | |||
| "\n", | |||
| "predictions = model.predict(x_test)" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "markdown", | |||
| "metadata": {}, | |||
| "source": [ | |||
| "### 1.1保持全模型\n", | |||
| "可以对整个模型进行保存,其保持的内容包括:\n", | |||
| "- 该模型的架构\n", | |||
| "- 模型的权重(在训练期间学到的)\n", | |||
| "- 模型的训练配置(你传递给编译的),如果有的话\n", | |||
| "- 优化器及其状态(如果有的话)(这使您可以从中断的地方重新启动训练)\n" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": 5, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [ | |||
| "import numpy as np\n", | |||
| "model.save('the_save_model.h5')\n", | |||
| "new_model = keras.models.load_model('the_save_model.h5')\n", | |||
| "new_prediction = new_model.predict(x_test)\n", | |||
| "np.testing.assert_allclose(predictions, new_prediction, atol=1e-6) # 预测结果一样" | |||
| ] | |||
| }, | |||
| { | |||
| "cell_type": "code", | |||
| "execution_count": null, | |||
| "metadata": {}, | |||
| "outputs": [], | |||
| "source": [] | |||
| } | |||
| ], | |||
| "metadata": { | |||
| "kernelspec": { | |||
| "display_name": "Python 3", | |||
| "language": "python", | |||
| "name": "python3" | |||
| }, | |||
| "language_info": { | |||
| "codemirror_mode": { | |||
| "name": "ipython", | |||
| "version": 3 | |||
| }, | |||
| "file_extension": ".py", | |||
| "mimetype": "text/x-python", | |||
| "name": "python", | |||
| "nbconvert_exporter": "python", | |||
| "pygments_lexer": "ipython3", | |||
| "version": "3.6.8" | |||
| } | |||
| }, | |||
| "nbformat": 4, | |||
| "nbformat_minor": 2 | |||
| } | |||
Oops, something went wrong.
0 comments on commit
6a32897