In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

## 1. 讀入 Fashion MNIST dataset

In [2]:
from tensorflow.keras.datasets import fashion_mnist

In [3]:
# Load the Fashion MNIST data; Keras returns two (images, labels) pairs,
# one for the training split and one for the test split.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

In [4]:
# Human-readable names of the ten classes; the list index is the integer label.
class_names = [
    'T-shirt/top',  # 0
    'Trouser',      # 1
    'Pullover',     # 2
    'Dress',        # 3
    'Coat',         # 4
    'Sandal',       # 5
    'Shirt',        # 6
    'Sneaker',      # 7
    'Bag',          # 8
    'Ankle boot',   # 9
]

## 2. 整理資料

In [5]:
# Inspect the raw training images: 60000 samples of 28x28 pixels, no channel axis yet.
x_train.shape

(60000, 28, 28)

### Channel
CNN 需要知道每張圖有幾個 channel。Fashion MNIST 是灰階圖片, 所以只有一個 channel <br>
( 彩色圖片則有多個 channel, 例如紅、綠、藍 (RGB) 各是一個 channel ) <br>
因此需要轉換資料格式: (28, 28) --> (28, 28, 1)

In [6]:
# The trailing 1 is the single (greyscale) channel axis that Conv2D expects.
# Dividing by 255 rescales the 0-255 pixel values into the range [0, 1].
# -1 lets NumPy infer the number of samples instead of hard-coding 60000 / 10000.
# NOTE: this cell is not idempotent - running it twice divides by 255 twice,
# so re-run the data-loading cell before re-running this one.

x_train = x_train.reshape(-1, 28, 28, 1) / 255
x_test = x_test.reshape(-1, 28, 28, 1) / 255

In [7]:
# Spot-check one sample: each image should now carry a trailing channel axis.
x_train[3456].shape

(28, 28, 1)

In [8]:
# Show the integer label of that sample next to its human-readable class name.
print(y_train[3456], class_names[y_train[3456]])

6 Shirt


In [9]:
from tensorflow.keras.utils import to_categorical

In [10]:
# One-hot encode the integer labels, as required by the categorical_crossentropy loss.
# The class count comes from class_names rather than a hard-coded 10.
# NOTE: run this cell once per data load; the labels are overwritten in place,
# so a second run would try to encode already one-hot labels.

y_train = to_categorical(y_train, len(class_names))
y_test = to_categorical(y_test, len(class_names))

## 3. 打造函數學習機 CNN

In [11]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Flatten

In [12]:
# Start from an empty Sequential model; the layers are appended in the cells below.
model = Sequential()

In [13]:
# First convolution: 16 filters of size 3x3 over the 28x28x1 input.
# 'same' padding keeps the 28x28 spatial size.
model.add(
    Conv2D(
        filters=16,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
        input_shape=(28, 28, 1),
    )
)

In [14]:
# 2x2 max pooling halves each spatial dimension (28x28 -> 14x14).
model.add(MaxPooling2D(pool_size=(2,2)))

In [15]:
# Second convolution: 32 filters of size 3x3, spatial size preserved by 'same' padding.
model.add(
    Conv2D(
        filters=32,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
    )
)

In [16]:
# 2x2 max pooling halves each spatial dimension again (14x14 -> 7x7).
model.add(MaxPooling2D(pool_size=(2,2)))

In [17]:
# Third convolution: 64 filters of size 3x3, spatial size preserved by 'same' padding.
model.add(
    Conv2D(
        filters=64,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
    )
)

In [18]:
# Final 2x2 max pooling (7x7 -> 3x3; the odd last row/column is dropped).
model.add(MaxPooling2D(pool_size=(2,2)))

In [19]:
# Flatten the 3x3x64 feature maps into a 576-element vector for the dense layers.
model.add(Flatten())

In [20]:
# Fully connected hidden layer with 54 ReLU units.
model.add(Dense(54, activation='relu'))

In [21]:
# Output layer: one unit per class, sized from class_names rather than a hard-coded 10.
# softmax turns the outputs into values between 0 and 1 that can be read as
# class probabilities.

model.add(Dense(len(class_names), activation='softmax'))

In [22]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 28, 28, 16)        160       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 32)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 576)               0

In [23]:
# Parameter count of the first Conv2D layer, checked by hand:
# each filter has 3*3 weights for the single input channel plus 1 bias,
# and there are 16 filters.

16 * (3 * 3 * 1 + 1)

160

In [24]:
# Parameter count of the second Conv2D layer:
# each filter has 3*3 weights for each of the 16 input channels plus 1 bias,
# and there are 32 filters.

32 * (3 * 3 * 16 + 1)

4640

In [25]:
# Configure training: RMSprop optimizer, cross-entropy loss over the one-hot
# labels, and accuracy reported as the metric.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

## 4. 訓練

In [26]:
# Train for 12 epochs in mini-batches of 128 samples.
# No validation data is passed, so only training metrics are reported.
model.fit(x_train, y_train, batch_size=128, epochs=12)


Train on 60000 samples
Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


<tensorflow.python.keras.callbacks.History at 0x19303fe4948>

## 5. 結果

In [27]:
# Measure loss and accuracy on the test split, which was not used for training.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

print('\nTest accuracy:', test_acc)

10000/10000 - 2s - loss: 0.2566 - accuracy: 0.9083

Test accuracy: 0.9083


In [28]:
# Sequential.predict_classes() is deprecated and has been removed from newer
# TensorFlow releases, so take the argmax over the per-class probabilities
# returned by predict() instead. The result is one integer label per test image.
predicted_label = np.argmax(model.predict(x_test), axis=-1)

In [29]:
def predict(n):
    """Print the predicted and true class of test image ``n`` and draw the image.

    Reads the notebook-level ``predicted_label``, ``y_test`` (one-hot),
    ``x_test`` and ``class_names``.

    Args:
        n: Index of the sample within the test set.
    """
    guess = class_names[predicted_label[n]]
    print('prediction:', guess)
    # y_test is one-hot encoded, so argmax recovers the integer label.
    truth = class_names[np.argmax(y_test[n])]
    print('true label:', truth)
    # Drop the channel axis so imshow receives a plain 28x28 array.
    image = x_test[n].reshape(28, 28)
    plt.imshow(image, cmap='Greys')

In [30]:
from ipywidgets import interact_manual

In [31]:
# Slider over every valid test index; the upper bound follows the size of
# x_test instead of a hard-coded 9999.
interact_manual(predict, n=(0, len(x_test) - 1))

interactive(children=(IntSlider(value=4999, description='n', max=9999), Button(description='Run Interact', sty…

<function __main__.predict(n)>