In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
# Select TensorFlow 2.x when running on Google Colab. The `%tensorflow_version`
# line magic is presumably unavailable outside Colab; any error it raises there
# is deliberately swallowed so the notebook falls back to whatever TensorFlow
# is installed.
try:
  # Colab only
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# Short alias for the Keras layers module, used by the model-definition cells.
layer = keras.layers
print("tensorflow version check : ", tf.__version__)

# tf.test.is_gpu_available() is deprecated in TF 2.x; list the visible GPU
# devices instead and report whether there is at least one (still a bool).
print("gpu check", len(tf.config.list_physical_devices('GPU')) > 0)

# Fix the global TensorFlow seed so that weight initialisation and
# Dataset.shuffle are repeatable across "Restart & Run All".
SEED = 42
tf.random.set_seed(SEED)


# 1. Load the dataset into memory

In [None]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print("mnist dataset on memory")

# Scale pixel values into [0, 1]. The raw images are integer arrays (presumably
# uint8); dividing them directly would produce float64 arrays, so cast to
# float32 first -- enough precision for [0, 1] pixels at half the memory.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Add a channels dimension so each image matches the model's
# Input(shape=(28, 28, 1)).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
assert x_train.shape[1:] == (28, 28, 1) and x_test.shape[1:] == (28, 28, 1)
print("The shape of train dataset : ", x_train.shape)
print("The shape of test dataset : ", x_test.shape)

# Integer class ids for the sparse categorical loss/metric.
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

# Only the training pipeline is shuffled; evaluation metrics do not depend on
# sample order, so the test set is just batched (deterministic and cheaper).
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(5000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# 2. model definition

In [None]:
# Baseline CNN: two conv + max-pool stages, a 128-unit ReLU hidden layer and a
# 10-way softmax output, built with the Keras functional API.
inputs = keras.Input(shape=(28, 28, 1))
feature = layer.Conv2D(filters=32, kernel_size=3, activation='relu')(inputs)
feature = layer.MaxPool2D((2, 2))(feature)
feature = layer.Conv2D(filters=64, kernel_size=3, activation='relu')(feature)
feature = layer.MaxPool2D((2, 2))(feature)
flatten = layer.Flatten()(feature)
embedding = layer.Dense(units=128, activation='relu')(flatten)
prob = layer.Dense(units=10, activation='softmax')(embedding)
model = keras.Model(inputs=inputs, outputs=prob)

In [None]:
model.summary()

# Labels are integer class ids, hence the "sparse" loss and accuracy variants.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy],
)

In [None]:
# Train on the in-memory NumPy arrays, then evaluate on the batched test pipeline.
print('========================Training===============================')
model.fit(x=x_train, y=y_train, epochs=5)

print('========================Evaluation===============================')
test_loss, test_acc = model.evaluate(x=test_ds)

# del model
# keras.backend.clear_session()

# 3. Problem: out-of-distribution (non-digit) images

In [None]:
# Fashion-MNIST, a dataset unrelated to the MNIST digits, used here as
# out-of-distribution input for the digit classifier.
(fashine_train_x, fashine_train_y), (fashine_test_x, fashine_test_y) = (
    tf.keras.datasets.fashion_mnist.load_data()
)

In [None]:
# Second Fashion-MNIST training image, with the same /255.0 scaling used for the
# MNIST inputs so it can be fed to the digit model.
tmp_data = np.true_divide(fashine_train_x[1], 255.0)

In [None]:
# `plt` is already imported in the setup cell; just render the sample with a
# title so the figure stands on its own, and end with plt.show() so the cell
# does not also print the AxesImage repr.
plt.imshow(tmp_data, cmap='gray')
plt.title('Fashion-MNIST training image #1 (out-of-distribution input)')
plt.show()

In [None]:
# Turn the single 28x28 image into a batch of one with a channel axis, then show
# the baseline model's class probabilities for it.
tmp_data = tf.reshape(tmp_data, shape=(1, 28, 28, 1))
model.predict(x=tmp_data)

# A modified model that addresses this problem

In [None]:
# Same CNN as the first model, except the output Dense layer has no activation
# and the softmax is applied as a separate step. Keeping the raw linear scores
# available lets us inspect their sign: the intent is to filter out images of
# classes the model was never trained on.

inputs = keras.Input(shape=(28, 28, 1))
feature = layer.Conv2D(filters=32, kernel_size=3, activation='relu')(inputs)
feature = layer.MaxPool2D((2, 2))(feature)
feature = layer.Conv2D(filters=64, kernel_size=3, activation='relu')(feature)
feature = layer.MaxPool2D((2, 2))(feature)
flatten = layer.Flatten()(feature)
embedding = layer.Dense(units=128, activation='relu')(flatten)
scores = layer.Dense(units=10, activation=None)(embedding)  # raw linear class scores
prob = tf.keras.activations.softmax(scores)  # softmax applied separately
model = keras.Model(inputs=inputs, outputs=prob)

In [None]:
model.summary()

# Labels are integer class ids, hence the "sparse" loss and accuracy variants.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy],
)

In [None]:
# Train on the in-memory NumPy arrays, then evaluate on the batched test pipeline.
print('========================Training===============================')
model.fit(x=x_train, y=y_train, epochs=5)

print('========================Evaluation===============================')
test_loss, test_acc = model.evaluate(x=test_ds)

# del model
# keras.backend.clear_session()

### 아래 코드를 실행해서 기존 MNIST와 상관 없는 Fashion-MNIST 이미지를 넣어 테스트해 봅니다.
### score 값이 모두 큰 음수로 나오는 것을 확인할 수 있어요!

In [None]:
# Expose both the raw linear class scores and the softmax probabilities of the
# modified model. Build it from the tensor handles created with that model
# (`inputs`, `scores`, `prob`) instead of looking layers up by auto-generated
# names such as 'input_2' / 'dense_3' / 'tf_op_layer_Softmax': those names
# depend on how many layers were created earlier in the kernel session, are
# reset by keras.backend.clear_session(), and the op-layer naming may differ
# between TF versions, so a name lookup fails on other run orders.
test_model = tf.keras.Model(inputs=inputs, outputs=[scores, prob])

In [None]:
# [scores, probabilities] for the batched Fashion-MNIST sample.
test_model.predict(x=tmp_data)

### 학습시켰던 클래스를 가지고 있는 mnist 영상에 대해서는 어떨까요?
### 해당 클래스의 스코어 값은 양수로 표현됩니다.

In [None]:
# First MNIST test image, keeping its trailing channel axis.
tmp_data = x_test[0, ...]

In [None]:
# Drop the channel axis for display, label the figure, and end with plt.show()
# so the cell does not also print the AxesImage repr.
plt.imshow(tmp_data[:, :, 0], cmap='gray')
plt.title('MNIST test image #0 (in-distribution input)')
plt.show()

In [None]:
# [scores, probabilities] for the MNIST sample, batched as a single example.
test_model.predict(tmp_data.reshape((1, 28, 28, 1)))