In [1]:
import numpy                      as np
import tensorflow                 as tf
import tensorflow_hub             as hub
import tensorflow_datasets        as tfds

# Data

In [2]:
# Download (or reuse the cached copy of) MNIST and get both splits as
# (image, label) tuples, together with the dataset metadata.
mnist_splits, ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = mnist_splits

## Pre-process images

In [3]:
@tf.function
def normalize(img, label):
    """Scale an image to floating point values by dividing by 255.

    Args:
        img: image tensor; it is cast to float32 before scaling.
        label: passed through unchanged.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = tf.cast(img, tf.float32) / 255.
    return scaled, label

In [4]:
# Training input pipeline: normalise -> cache -> shuffle -> batch -> prefetch.
# NOTE: this cell rebinds `ds_train`; re-run the loading cell before running
# it a second time, otherwise the images are normalised and batched twice.
ds_train = ds_train.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Cache the normalised images so the map is not recomputed every epoch.
ds_train = ds_train.cache()
# A buffer as large as the split gives a true uniform shuffle; the previous
# 1000-element buffer only shuffled within a small sliding window.
ds_train = ds_train.shuffle(buffer_size=ds_info.splits['train'].num_examples).batch(128)
# Let tf.data choose the prefetch depth; a fixed 1000 asks for 1000 batches,
# more than the whole dataset contains.
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

In [5]:
# Evaluation input pipeline: normalise -> batch -> cache -> prefetch.
# No shuffle: the aggregate evaluation metrics do not depend on example order,
# so shuffling only costs time and memory.
# NOTE: this cell rebinds `ds_test`; re-run the loading cell before running
# it a second time, otherwise the images are normalised and batched twice.
ds_test = ds_test.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
# Caching after batching is safe here because the batches are identical
# on every pass.
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

## Model

In [6]:
# Small CNN: one 3x3 convolution, 2x2 max-pooling, then two dense layers
# ending in a 10-way softmax.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        activation='relu',
        input_shape=(28, 28, 1),
    ),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=64, activation='relu'),
    tf.keras.layers.Dense(units=10, activation='softmax'),
])

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 26, 26, 64)        640       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 10816)             0         
_________________________________________________________________
dense (Dense)                (None, 64)                692288    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
Total params: 693,578
Trainable params: 693,578
Non-trainable params: 0
_________________________________________________________________


## Optimizer, loss and accuracy

In [7]:
def set_optimizer():
    """Return a new Adam optimizer built with its default arguments."""
    adam = tf.optimizers.Adam()
    return adam


def set_losses():
    """Return two independent sparse categorical cross-entropy losses.

    Returns:
        A ``(train_loss, val_loss)`` tuple of separate loss objects.
    """
    make_loss = tf.losses.SparseCategoricalCrossentropy
    return make_loss(), make_loss()


def set_accuracy():
    """Return two independent sparse categorical accuracy metrics.

    Returns:
        A ``(train_acc, val_acc)`` tuple of separate metric objects, so the
        training and validation counts never mix.
    """
    make_metric = tf.metrics.SparseCategoricalAccuracy
    return make_metric(), make_metric()

In [8]:
# Build the objects used by the custom training loop: one loss and one
# accuracy metric per phase, plus the optimizer.
train_loss, val_loss = set_losses()
train_acc, val_acc = set_accuracy()
optimizer = set_optimizer()

## One training step (single batch)

In [9]:
def one_epoch(model, optimizer, X, y, losses, accuracy):
    """Run one optimisation step on a single batch.

    Despite its name, this processes one batch, not a whole epoch.

    Args:
        model: callable model exposing ``trainable_variables``.
        optimizer: optimizer whose ``apply_gradients`` updates the model.
        X: batch of inputs fed to ``model``.
        y: batch of labels matching ``X``.
        losses: loss callable invoked as ``losses(y, pred)``.
        accuracy: metric updated with ``(y, pred)`` for this batch.

    Returns:
        The loss value computed for this batch.
    """
    with tf.GradientTape() as tape:
        # training=True so layers that behave differently while training
        # run in training mode.
        pred = model(X, training=True)
        # Use the loss passed in by the caller rather than a notebook global.
        loss = losses(y, pred)

    grad = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grad, model.trainable_variables))

    # Likewise update the metric that was passed in, not the global one.
    accuracy.update_state(y, pred)

    return loss

## Full training loop

In [14]:
def _reset_metric(metric):
    """Clear a Keras metric's accumulated state.

    Newer Keras releases name the method ``reset_state``; older ones only
    provide ``reset_states``. Call whichever this installation has.
    """
    reset = getattr(metric, 'reset_state', None)
    if reset is None:
        reset = metric.reset_states
    reset()


@tf.function
def train(model, device, num_epochs, optimizer, train_data, train_loss, train_acc, val_data, val_loss, val_acc):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        model: model to optimise.
        device: device string (e.g. ``'/gpu:0'``) the compute is placed on.
        num_epochs: Python int; the epoch loop is a plain ``range`` loop.
        optimizer: optimizer passed on to ``one_epoch``.
        train_data: iterable of ``(x, y)`` training batches.
        train_loss: loss callable used for training batches.
        train_acc: metric tracking training accuracy.
        val_data: iterable of ``(x, y)`` validation batches.
        val_loss: loss callable used for validation batches.
        val_acc: metric tracking validation accuracy.

    Prints the loss and running accuracy after every training batch, and the
    mean validation loss and accuracy at the end of every epoch.
    """
    step = 0
    loss = 0.

    for i in range(num_epochs):
        # Start each epoch from clean metrics so the printed accuracies
        # describe this epoch only instead of mixing in earlier ones.
        _reset_metric(train_acc)
        _reset_metric(val_acc)

        for x, y in train_data:
            step += 1
            with tf.device(device_name=device):
                loss = one_epoch(model, optimizer, x, y, train_loss, train_acc)

            # print loss after each batch (128)
            tf.print('step ', step,
                     ': loss', loss,
                     ": acc", train_acc.result())

        # Accumulate an example-weighted sum so the reported validation loss
        # is the mean over the whole set, not just the last batch's value.
        val_loss_sum = tf.constant(0., dtype=tf.float32)
        val_examples = tf.constant(0., dtype=tf.float32)
        for x, y in val_data:
            with tf.device(device_name=device):
                pred = model(x, training=False)
                batch_size = tf.cast(tf.shape(y)[0], tf.float32)
                val_loss_sum += tf.cast(val_loss(y, pred), tf.float32) * batch_size
                val_examples += batch_size
                val_acc.update_state(y, pred)

        # divide_no_nan keeps the output finite if val_data is empty.
        tf.print('val_loss', tf.math.divide_no_nan(val_loss_sum, val_examples),
                 ": val_ acc", val_acc.result())
          




In [None]:
# Use the first GPU if TensorFlow can see one, otherwise fall back to the CPU.
# tf.test.is_gpu_available() is deprecated; querying the physical device list
# is the supported replacement.
device = '/gpu:0' if tf.config.list_physical_devices('GPU') else '/cpu:0'
# Number of passes over the training set.
EPOCHS = 2


In [12]:
# Pass the EPOCHS constant from the configuration cell instead of repeating
# the literal, so the two cannot drift apart.
train(model, device, EPOCHS, optimizer, ds_train, train_loss, train_acc, ds_test, val_loss, val_acc)

step  1 : loss 2.30282164 : acc 0.078125
step  2 : loss 2.16940808 : acc 0.15234375
step  3 : loss 2.04881334 : acc 0.239583328
step  4 : loss 1.90542412 : acc 0.3046875
step  5 : loss 1.80432487 : acc 0.364062488
step  6 : loss 1.56639624 : acc 0.41015625
step  7 : loss 1.46044123 : acc 0.446428567
step  8 : loss 1.42975104 : acc 0.469726562
step  9 : loss 1.19611955 : acc 0.502604187
step  10 : loss 1.06518841 : acc 0.53125
step  11 : loss 1.0008018 : acc 0.552556813
step  12 : loss 0.984683394 : acc 0.571614563
step  13 : loss 0.853455 : acc 0.588942289
step  14 : loss 0.772534311 : acc 0.604910731
step  15 : loss 0.832049251 : acc 0.614583313
step  16 : loss 0.75705874 : acc 0.624023438
step  17 : loss 0.768032432 : acc 0.632352948
step  18 : loss 0.63091749 : acc 0.642361104
step  19 : loss 0.601264417 : acc 0.650082231
step  20 : loss 0.568347573 : acc 0.660937488
step  21 : loss 0.509435 : acc 0.669642866
step  22 : loss 0.529275298 : acc 0.677201688
step  23 : loss 0.449446619 

In [13]:
# Display the dataset metadata object returned by tfds.load(with_info=True).
ds_info

tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)