In [1]:
import tensorflow as tf

# Load the MNIST train/test splits through the Keras dataset helper
# (downloads the archive on first use).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the image arrays by 255.0 so the inputs become floats
# (presumably mapping 8-bit pixel values into the [0, 1] range).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Feed-forward classifier: flatten the 28x28 input, one ReLU hidden layer,
# dropout for regularisation, then a 10-unit output layer.
# The final Dense layer has no activation, so the model emits raw scores
# (logits) rather than probabilities.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [2]:
# Forward pass of the (not yet trained) model on the first training example;
# the result is one row of 10 raw scores, converted to a NumPy array.
first_example_scores = model(x_train[:1])
predictions = first_example_scores.numpy()
predictions

array([[-0.30474028, -0.33526403, -0.07552153, -0.28535366, -0.79481727,
        -0.5111413 , -0.12335837,  0.20205286,  0.6245407 ,  0.15871125]],
      dtype=float32)

In [3]:
# Convert the raw scores into a probability distribution over the classes.
tf.nn.softmax(logits=predictions).numpy()

array([[0.0790246 , 0.07664891, 0.09938269, 0.08057156, 0.04840882,
        0.06428704, 0.09474046, 0.13117762, 0.2001446 , 0.12561361]],
      dtype=float32)

以下で定義する損失 (`SparseCategoricalCrossentropy`) は、正解クラスに対してモデルが予測した確率の対数をとり、符号を反転させたものです。モデルが正解クラスを確信している (確率が 1 である) とき、この値は 0 になります。

この訓練されていないモデルはランダムに近い確率 (それぞれのクラスについて 1/10) を出力するため、最初の損失は `-tf.math.log(1/10) ~= 2.3` に近い値になるはずです。

In [4]:
# Sparse categorical cross-entropy taking integer labels; from_logits=True
# because the model's last layer applies no softmax.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Loss of the current model on the first training example.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

2.7443972

In [5]:
# Configure training: Adam optimizer, the loss defined above, and accuracy
# reported as the metric.
model.compile(loss=loss_fn, optimizer='adam', metrics=['accuracy'])

In [6]:
# Train on the full training split for 5 passes over the data.
model.fit(x=x_train, y=y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7efd0be55050>

In [7]:
# Evaluate on the held-out test split; returns [loss, accuracy].
# verbose=2 prints a single summary line instead of a progress bar.
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 1s - loss: 0.0751 - accuracy: 0.9770 - 822ms/epoch - 3ms/step


[0.07507123053073883, 0.9769999980926514]

In [8]:
# Wrap the trained model and append a Softmax layer so that calling the
# wrapper yields class probabilities instead of raw scores.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

# Probabilities for the first five test examples (one row per example).
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1.86656038e-07, 6.87014090e-09, 1.19874594e-05, 2.98295025e-04,
        9.51166554e-13, 7.70561940e-07, 3.16877470e-12, 9.99672294e-01,
        1.40722818e-07, 1.63789773e-05],
       [5.39137872e-08, 2.61673034e-04, 9.99735177e-01, 2.72810144e-06,
        2.57797364e-13, 1.51416994e-08, 1.00789173e-07, 1.90325208e-12,
        2.70519251e-07, 4.11650342e-13],
       [7.53684446e-07, 9.99536991e-01, 5.62434143e-05, 4.24991367e-06,
        9.32002695e-06, 6.21848756e-07, 4.62474782e-06, 2.68783333e-04,
        1.17929208e-04, 5.45374746e-07],
       [9.99915481e-01, 7.56222285e-09, 2.69252196e-05, 1.75659716e-06,
        5.07413347e-07, 3.64806306e-06, 3.31542324e-05, 1.01843425e-05,
        2.62100883e-08, 8.30546742e-06],
       [9.74946488e-07, 3.02113499e-08, 2.09354639e-05, 2.93184343e-07,
        9.99450862e-01, 1.28746467e-06, 1.05559195e-06, 5.30922298e-05,
        2.25972045e-07, 4.71185747e-04]], dtype=float32)>

In [11]:
# Check how much training data there is (number of training labels).
num_train_labels = len(y_train)
display(num_train_labels)

60000