# Import TensorFlow 2.x.

In [None]:
# Colab line magic that selects the TensorFlow 2.x runtime. Where the magic is
# unavailable (presumably any non-Colab IPython kernel) it raises, so every
# exception is deliberately swallowed to keep the notebook portable.
# NOTE(review): newer Colab images may no longer ship this magic — confirm.
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

import numpy as np
# Seed both random number generators. NumPy's seed does not affect TensorFlow
# ops, so TensorFlow's global seed is set as well; this makes TF-side randomness
# (e.g. Dense weight initialisers, Dataset.shuffle) repeatable across runs.
np.random.seed(7)
tf.random.set_seed(7)

print(tf.__version__)

# Import TensorFlow Datasets.

In [None]:
import tensorflow_datasets as tfds

### Load the MNIST dataset.
* train split
* test split

In [None]:
# Fetch MNIST (downloading on first use) as two tf.data pipelines of
# (image, label) tuples — one per requested split — plus the dataset metadata.
(train_dataset, test_dataset), dataset_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)

### Normalize image pixel values to the [0, 1] range.

In [None]:
def _normalize_image(image, label):
    """Return ``(image, label)`` with the image cast to float32 and divided by 255.

    Meant to be passed to ``tf.data.Dataset.map`` over ``(image, label)`` pairs;
    the label is returned unchanged. Assuming 8-bit pixel values (0-255), the
    scaled image lies in [0, 1].
    """
    pixels = tf.cast(image, tf.float32)
    return pixels / 255.0, label

### Shuffle, batch, and normalize the datasets.

In [None]:
# Number of examples held in the shuffle buffer, and examples per batch.
buffer_size, batch_size = 1024, 32

In [None]:
# Training pipeline: shuffle examples, group them into batches, then scale the
# pixels of each whole batch. map() with num_parallel_calls still yields batches
# in order; prefetch() prepares the next batch while the model trains on the
# current one.
train_dataset = (
    train_dataset
    .shuffle(buffer_size)
    .batch(batch_size)
    .map(_normalize_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: same batching and scaling, without shuffling, so the
# example order stays fixed.
test_dataset = (
    test_dataset
    .batch(batch_size)
    .map(_normalize_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Create the model.

In [None]:
# Small fully connected classifier, built layer by layer:
#   Flatten   -> each 28x28x1 image becomes a 784-element vector
#   Dense 128 -> hidden layer with ReLU
#   Dense 10  -> softmax probabilities over the 10 classes
model = models.Sequential()
model.add(layers.Flatten(input_shape=(28, 28, 1)))
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

### Show model summary.

In [None]:
# Print the layer stack with each layer's output shape and parameter count.
model.summary()

# Train the model.

### Compile the model.

In [None]:
# The sparse variant of categorical cross-entropy takes integer class ids
# rather than one-hot vectors. Adam uses a learning rate of 1e-3, and accuracy
# is tracked alongside the loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

### Train the model.

In [None]:
import datetime

# Give each run its own timestamped subdirectory under ./logs. Otherwise,
# re-running the notebook appends event files to one directory, and TensorBoard
# overlays curves from different runs. `%tensorboard --logdir logs` still finds
# every run, because TensorBoard discovers runs in subdirectories of the logdir.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=f"./logs/{datetime.datetime.now():%Y%m%d-%H%M%S}")

In [None]:
# Number of full passes over train_dataset performed by model.fit.
epochs = 10

In [None]:
# Train for `epochs` passes. The test split is scored after every epoch, and
# metrics are written out through the TensorBoard callback.
# NOTE(review): the test split doubles as validation data here, so it is not a
# fully held-out set if it is ever used to pick hyperparameters.
model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=epochs,
    callbacks=[tensorboard_callback],
)

# Evaluate the model.

### Evaluate the model's accuracy on the test set.

In [None]:
# evaluate() returns the loss followed by each compiled metric. The model was
# compiled with metrics=['accuracy'], so history[1] is the test-set accuracy.
# (The name `history` is kept for compatibility; it is not a Keras History object.)
history = model.evaluate(test_dataset)
print('model accuracy on test dataset -', history[1])

### Visualize the training curves in TensorBoard.

In [None]:
# Load the TensorBoard notebook extension (reloading it if already loaded), then
# launch TensorBoard inline on the `logs` directory that the TensorBoard
# callback writes to.
%reload_ext tensorboard
%tensorboard --logdir logs