<a href="https://colab.research.google.com/github/forMwish/MyDeepLearn/blob/master/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. 准备

## 1.1 挂载 google drive


In [None]:
# Mount Google Drive at /gdrive (Colab-only API).
from google.colab import drive
drive.mount('/gdrive')
# Make the mount point the working directory: later cells use relative paths such as
# './history/...' and './model/...', which therefore resolve under /gdrive.
# NOTE(review): Drive files usually appear under /gdrive/MyDrive; confirm that writing
# directly under /gdrive is intended.
%cd /gdrive
# Flag recording that Drive was mounted (not read by any visible cell).
google_drive = 1

## 1.2 数据集处理

In [None]:
from keras.datasets import mnist
import numpy as np

NUM_CLASSES = 10        # digits 0-9; matches the 10-unit softmax output layer
NUM_VALIDATION = 10000  # training samples held out for validation during fit()

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Flatten each image into one feature vector (the first Dense layer takes 28*28 inputs)
# and scale to floats; assumes raw pixel values in 0-255. Using len() instead of
# hard-coded 60000/10000 keeps this correct if a different-sized split is loaded.
train_images = train_images.reshape(len(train_images), -1).astype(np.float32) / 255
test_images = test_images.reshape(len(test_images), -1).astype(np.float32) / 255

from keras.utils import to_categorical
# One-hot encode the integer labels. An explicit num_classes guarantees NUM_CLASSES
# columns even if some digit happens to be absent from the labels being encoded.
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)

# Split the training set: the first NUM_VALIDATION samples for validation, the rest for training.
validation_image = train_images[:NUM_VALIDATION]
validation_labels = train_labels[:NUM_VALIDATION]
partial_train_images = train_images[NUM_VALIDATION:]
partial_train_labels = train_labels[NUM_VALIDATION:]

# 2. 模型

## 2.1 构建

In [None]:
from keras import models
from keras import layers

from keras import optimizers
from keras import losses
from keras import metrics


# Two-layer fully connected classifier:
# flattened 28*28 input -> 512 ReLU units -> 10-way softmax.
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
network.compile(
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras releases reject.
    optimizer=optimizers.RMSprop(learning_rate=0.001),
    loss=losses.categorical_crossentropy,
    # Keras expects a list of metrics. Passing the function itself keeps the History keys
    # 'categorical_accuracy' / 'val_categorical_accuracy' that the plotting cell reads.
    metrics=[metrics.categorical_accuracy],
)
network.summary()

## 2.2 训练

In [None]:
# Train on the remaining training samples, reporting validation metrics on the
# held-out split after every epoch. Keras keeps the per-epoch record on network.history.
network.fit(
    x=partial_train_images,
    y=partial_train_labels,
    batch_size=128,
    epochs=50,
    validation_data=(validation_image, validation_labels),
)


# 3. model 数据存储

## 3.1 history

In [None]:
import pickle
import os

save_path = './history/mnist_base.pickle'
save_dir = os.path.dirname(save_path)

# makedirs(exist_ok=True) creates missing parent directories and avoids the
# check-then-create race. The guard skips makedirs('') when save_path has no directory part.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# network.history is populated by the preceding fit() call; its .history dict maps
# metric names to per-epoch value lists. The context manager closes the file even
# if pickling fails.
with open(save_path, 'wb') as fp:
    pickle.dump(network.history.history, fp)

## 3.2 model

In [None]:
import os  # imported here so the cell does not depend on the history-save cell having run

save_path = './model/mnist_base'
save_dir = os.path.dirname(save_path)

# Create the parent directory (and any missing ancestors) if needed; the guard skips
# makedirs('') when save_path has no directory component.
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
# NOTE(review): with no file extension, older tf.keras writes a SavedModel directory
# here, while Keras 3 requires a '.keras' or '.h5' suffix. Confirm against the installed
# version, and keep the model-restore cell's path in sync.
network.save(save_path)

# 4. model 数据恢复

## 4.1 history

In [None]:
import pickle

history_path = './history/mnist_base.pickle'

# Restore the metric dict written by the history-save cell. The context manager closes
# the file even if unpickling fails.
# pickle.load can execute arbitrary code, so only load files this notebook produced.
with open(history_path, 'rb') as fp:
    history = pickle.load(fp)

## 4.2 model

In [None]:
model_path = './model/mnist_base'
# Rebuild the network from the path written by the model-save cell, replacing any
# in-memory `network`.
network = models.load_model(filepath=model_path)

# 5. history 分析

In [None]:
import matplotlib.pyplot as plt

# Per-epoch curves from the restored History dict. The accuracy keys follow the
# metric name passed to compile().
loss = history['loss']
accuracy = history['categorical_accuracy']
val_loss = history['val_loss']
val_accuracy = history['val_categorical_accuracy']

epoch = range(1, len(loss) + 1)  # 1-based epoch numbers for the x-axis

plt.figure(figsize=(20, 10))

# Left panel: training vs. validation loss.
plt.subplot(1, 2, 1)
plt.title('loss')
plt.plot(epoch, loss, 'bo', label='loss')
plt.plot(epoch, val_loss, 'r.', label='val_loss')
plt.xlabel('epoch')
plt.legend()

# Right panel: training vs. validation accuracy.
plt.subplot(1, 2, 2)
plt.title('accuracy')
plt.plot(epoch, accuracy, 'bo', label='accuracy')
plt.plot(epoch, val_accuracy, 'r.', label='val_accuracy')
plt.xlabel('epoch')
plt.legend()  # without this call the labels set above are not displayed

plt.show()

## 5.1 测试集验证

In [None]:
# Final check on the test split. evaluate() yields the loss followed by the single
# compiled metric (categorical accuracy).
test_loss, test_acc = network.evaluate(x=test_images, y=test_labels)
print(test_acc)

## 其他

In [None]:
import matplotlib.pyplot as plt

# Un-flatten training sample #4 back into a 28x28 grid and display it with the binary colormap.
digit = train_images[4].reshape((28, 28))
print(digit.shape)
plt.imshow(digit, cmap='binary')
plt.show()



In [None]:
# Scratch cell: print the even numbers below 6 (0, 2, 4), one per line.
for j in [0, 2, 4]:
    print(j)