In [3]:
from __future__ import absolute_import, division, print_function, unicode_literals

!pip install tensorflow-datasets
import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting tensorflow-datasets
[?25l  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d9/b8/457ad44e8748fbe5021b4ca7e7d589b5852881bbb11bca4d947952a13558/tensorflow_datasets-1.0.1-py3-none-any.whl (400kB)
[K    100% |████████████████████████████████| 409kB 5.3MB/s ta 0:00:01
Collecting requests (from tensorflow-datasets)
  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/7d/e3/20f3d364d6c8e5d2353c72a67778eb189176f08e873c9900e10c0287b84b/requests-2.21.0-py2.py3-none-any.whl
Collecting tensorflow-metadata (from tensorflow-datasets)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/08/b7/3fc74574aa9aff44491cce996711dd6094653c20d9e2800be4efb054e0da/tensorflow_metadata-0.13.0-py3-none-any.whl
Collecting wrapt (from tensorflow-datasets)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/b2/0f71ca90b0ade7fad27e3d20327c996c6252a2ffe88f50a95bba7434eda9/wrapt-1.11.1.tar.gz
Collecting future (from tensorfl

In [5]:
dataset, info = tfds.load('mnist', data_dir='gs://tfds-data/datasets', with_info=True, as_supervised=True)

mnist_train, mnist_test = dataset['train'], dataset['test']

def convert_types(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

In [6]:
mnist_train = mnist_train.map(convert_types).shuffle(10000).batch(32)
mnist_test = mnist_test.map(convert_types).batch(32)

In [8]:
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')
    
    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()

In [9]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()

In [12]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


In [13]:
@tf.function
def train_step(image, label):
    with tf.GradientTape() as tape:
        predictions = model(image)
        loss = loss_object(label, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(label, predictions)
    

In [14]:
@tf.function
def test_step(image, label):
    predictions = model(image)
    t_loss = loss_object(label, predictions)
    
    test_loss(t_loss)
    test_accuracy(label, predictions)

In [15]:
EPOCHS = 5
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

for epoch in range(EPOCHS):
    for image, label in mnist_train:
        train_step(image, label)
    
    for test_image, test_label in mnist_test:
        test_step(image, label)
    
    print(template.format(epoch+1,
                         train_loss.result(),
                         train_accuracy.result()*100,
                         test_loss.result(),
                         test_accuracy.result()*100))

Epoch 1, Loss: 0.13656781613826752, Accuracy: 95.92666625976562, Test Loss: 0.005389093887060881, Test Accuracy: 100.0
Epoch 2, Loss: 0.08821629732847214, Accuracy: 97.3308334350586, Test Loss: 0.04423443228006363, Test Accuracy: 98.4375
Epoch 3, Loss: 0.06631175428628922, Accuracy: 97.98055267333984, Test Loss: 0.04265563562512398, Test Accuracy: 97.91667175292969
Epoch 4, Loss: 0.05273852497339249, Accuracy: 98.38583374023438, Test Loss: 0.03585124760866165, Test Accuracy: 98.4375
Epoch 5, Loss: 0.0439818874001503, Accuracy: 98.65233612060547, Test Loss: 0.029077725484967232, Test Accuracy: 98.75


<BatchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>

In [22]:
for image, label in mnist_train.take(1):
    print(image)

tf.Tensor(
[[[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0