In [19]:
from keras.datasets import mnist
from keras.utils import to_categorical

from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import optimizers

from keras.callbacks import EarlyStopping

In [20]:
# MNIST 데이터셋 불러오기
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [21]:
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [22]:
# 60000, 가로, 세로, 필터
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

(60000, 28, 28, 1) (60000,)
(10000, 28, 28, 1) (10000,)


In [23]:
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
print(y_train.shape, y_train.shape)
print(y_train[0], y_train[0])

(60000, 10) (60000, 10)
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]


In [24]:
model = Sequential()

model.add(Conv2D(32, kernel_size=(5, 5), padding='same', input_shape = (28, 28, 1), activation='relu', kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (5, 5), padding='same', activation='relu', kernel_initializer='he_normal'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_4 (Conv2D)           (None, 28, 28, 32)        832       
                                                                 
 max_pooling2d_4 (MaxPoolin  (None, 14, 14, 32)        0         
 g2D)                                                            
                                                                 
 conv2d_5 (Conv2D)           (None, 14, 14, 64)        51264     
                                                                 
 max_pooling2d_5 (MaxPoolin  (None, 7, 7, 64)          0         
 g2D)                                                            
                                                                 
 flatten_2 (Flatten)         (None, 3136)              0         
                                                                 
 dense_4 (Dense)             (None, 1000)             

In [25]:
early_stopping = EarlyStopping(monitor = 'val_loss', min_delta = 0, patience = 3)
model.compile(loss = 'categorical_crossentropy', optimizer = optimizers.Adam(0.001), metrics=['accuracy'])

In [26]:
# fit() 메서드로 모델 훈련 시키기
hist = model.fit(
    x_train, y_train,
    epochs = 20, batch_size = 128,
    validation_data=(x_test, y_test),
    callbacks = [early_stopping]
    )

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20


In [27]:
# 테스트 데이터로 정확도 측정하기
test_loss, test_acc = model.evaluate(x_test, y_test)
print('test_acc: ', test_acc)

test_acc:  0.977400004863739


In [29]:
predit = model.predict(x_test)
print(predit[0:5])
print(y_test[0 :5])

[[1.99541106e-21 1.11698841e-18 2.08070417e-18 2.09591628e-14
  3.25046346e-17 5.05995327e-22 1.52657210e-27 1.00000000e+00
  6.20936879e-19 1.06020336e-14]
 [7.15499933e-23 1.95536553e-13 1.00000000e+00 8.91125293e-19
  9.95941547e-24 4.23268816e-29 6.17904952e-20 6.34448605e-18
  8.90940890e-21 1.35757045e-26]
 [6.51498780e-34 1.00000000e+00 1.68548481e-29 1.16487300e-35
  1.96321493e-30 5.99472915e-33 8.89466791e-29 4.17323440e-31
  1.29049620e-32 3.52937794e-34]
 [9.99974012e-01 2.56116950e-09 5.06117317e-07 7.63309629e-07
  1.35880063e-09 1.28466227e-09 2.41160833e-05 8.56218152e-11
  1.92729033e-08 6.09889980e-07]
 [1.05830449e-21 2.07284497e-26 4.98499547e-25 5.52011312e-27
  1.00000000e+00 4.83015922e-21 1.96580280e-20 1.41546286e-25
  6.83383260e-24 9.10228140e-18]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]
