In [1]:
import tensorflow as tf 
import pandas as pd 
import numpy as np 
import os 
import matplotlib.pyplot as plt
%matplotlib inline 

In [2]:
tf.__version__

'2.3.0'

In [3]:
(train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()

In [4]:
train_image = train_image/255
test_image = test_image/255

# train model

In [5]:
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [6]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [7]:
optimizer = tf.keras.optimizers.Adam()

In [8]:
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [9]:
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')
test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')

In [10]:
def train_step(model, images, labels):
    with tf.GradientTape() as tape:
        pred = model(images)
        loss_step = loss_func(labels, pred)
    grads = tape.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(loss_step)
    train_accuracy(labels, pred)

In [11]:
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_label))

In [12]:
dataset = dataset.shuffle(10000).batch(32)

In [13]:
cp_dir = './customize_train_cp'
cp_prefix = os.path.join(cp_dir, 'ckpt')

In [14]:
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                model=model)

In [15]:
def train(epochs):
    for epoch in range(epochs):
        for (batch_index, (images, labels)) in enumerate(dataset):
            train_step(model, images, labels)
        print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
        print('Epoch{} Accuracy is {}'.format(epoch, train_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        if (epoch + 1) % 2 == 0:
            checkpoint.save(file_prefix = cp_prefix)

In [16]:
train(5)

Epoch0 loss is 1.7492755651474
Epoch0 Accuracy is 0.7183666825294495
Epoch1 loss is 1.7082139253616333
Epoch1 Accuracy is 0.7535833120346069
Epoch2 loss is 1.6982461214065552
Epoch2 Accuracy is 0.7616000175476074
Epoch3 loss is 1.6113578081130981
Epoch3 Accuracy is 0.8507333397865295
Epoch4 loss is 1.5954569578170776
Epoch4 Accuracy is 0.8668666481971741


# load model

In [17]:
load_model = tf.keras.Sequential()
load_model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
load_model.add(tf.keras.layers.Dense(128, activation='relu'))
load_model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [18]:
load_checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                model=load_model)

In [19]:
tf.train.latest_checkpoint(cp_dir)

'./customize_train_cp\\ckpt-2'

In [20]:
load_checkpoint.restore(tf.train.latest_checkpoint(cp_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1981306ef10>

In [21]:
tf.argmax(load_model(train_image, training=False), axis=-1).numpy()

array([9, 0, 0, ..., 3, 0, 5], dtype=int64)

In [22]:
train_label

array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)

In [23]:
(tf.argmax(load_model(train_image, training=False), axis=-1).numpy() == train_label).sum()/len(train_label)

0.8601333333333333