# 目录
### 1. 导入模块
### 2. 导入fashion_mnist数据
### 3. 定义模型（图）
### 4. tf.Seesion 训练

## 1. 导入模块

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn

from tensorflow import keras
import tensorflow as tf
import sys
import os
import time
import datetime

for module in [np, pd, mpl, sklearn, keras, tf]:
    print(module.__name__, module.__version__)

numpy 1.17.2
pandas 0.25.1
matplotlib 3.1.1
sklearn 0.21.3
tensorflow.python.keras.api._v1.keras 2.2.4-tf
tensorflow 1.15.0


## 2. 导入fashion_mnist数据

In [2]:
fashion_mnist = keras.datasets.fashion_mnist

(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()

x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

(55000, 28, 28) (55000,)
(5000, 28, 28) (5000,)
(10000, 28, 28) (10000,)


In [3]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1, 28*28)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1, 28*28)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1, 28*28)

## 3. 定义模型（图）

In [4]:
# 网络结构
hidden_units = [100, 100]
class_num = 10

x = tf.placeholder(tf.float32, shape=(None, 28*28))
y = tf.placeholder(tf.int64, shape=(None))

input_for_next_layer = x
for hidden_unit in hidden_units:
    input_for_next_layer = tf.layers.dense(input_for_next_layer, units=hidden_unit, activation=tf.nn.relu)

logits = tf.layers.dense(input_for_next_layer, class_num)

# 计算损失
loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)

# 预测
prediction = tf.argmax(logits, axis=1)

# 计算准确率
correct_prediction = tf.equal(prediction, y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 训练操作
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


## 4. tf.Seesion 训练

In [5]:
init = tf.global_variables_initializer()
epochs = 10
batch_size = 32

train_step_per_epoch = x_train_scaled.shape[0] // batch_size

def valid_with_sess(sess, x, y, images, labels, accuracy, batch_size):
    valid_step_per_epoch = images.shape[0] // batch_size
    accuracy_list = []
    for valid_step in range(valid_step_per_epoch):
        # 取数据
        batch_x = images[valid_step*batch_size:(valid_step+1)*batch_size]
        batch_y = labels[valid_step*batch_size:(valid_step+1)*batch_size]
        # 推理
        accuracy_value = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
        accuracy_list.append(accuracy_value)
    return np.mean(accuracy_list)


with tf.Session() as sess:
    sess.run(init)
    for epoch in range(epochs):
        for step in range(train_step_per_epoch):
            # 取数据
            batch_x = x_train_scaled[step*batch_size:(step+1)*batch_size]
            batch_y = y_train[step*batch_size:(step+1)*batch_size]
            # 推理和反向传播
            loss_value, prediction_value, accuracy_value, _ = sess.run(
                [loss, prediction, accuracy, train_op], 
                feed_dict={x: batch_x, y: batch_y}
            )
            print("\r[train] epoch: {}, loss: {:.5f}, acc: {:.2f}".format(epoch, loss_value, accuracy_value), end="")
            
        valid_accuracy = valid_with_sess(sess, x, y, x_valid_scaled, y_valid, accuracy, batch_size)
        print("\t[valid] accuracy: {:.2f}".format(valid_accuracy))

[train] epoch: 0, loss: 0.40328, acc: 0.81	[valid] accuracy: 0.87
[train] epoch: 1, loss: 0.41917, acc: 0.81	[valid] accuracy: 0.88
[train] epoch: 2, loss: 0.39985, acc: 0.84	[valid] accuracy: 0.88
[train] epoch: 3, loss: 0.39487, acc: 0.81	[valid] accuracy: 0.88
[train] epoch: 4, loss: 0.37143, acc: 0.84	[valid] accuracy: 0.87
[train] epoch: 5, loss: 0.27811, acc: 0.91	[valid] accuracy: 0.88
[train] epoch: 6, loss: 0.32157, acc: 0.88	[valid] accuracy: 0.88
[train] epoch: 7, loss: 0.21142, acc: 0.88	[valid] accuracy: 0.87
[train] epoch: 8, loss: 0.25593, acc: 0.84	[valid] accuracy: 0.88
[train] epoch: 9, loss: 0.24053, acc: 0.88	[valid] accuracy: 0.88
