In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D

In [3]:
init_lr = 1e-3
batch_size = 64

In [4]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)
x_train, x_test = x_train / 255.0, x_test / 255.0
print('raw x_train:', x_train.shape)

# # Add a channels dimension
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(batch_size)
print('train_ds:', train_ds._input_dataset)
print('x_train:', x_train.mean(), x_train.max(), x_train.min())

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

raw x_train: (60000, 28, 28)
train_ds: <ShuffleDataset shapes: ((28, 28, 1), (10,)), types: (tf.float32, tf.float32)>
x_train: 0.13066062 1.0 0.0


In [5]:
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.model = Sequential()
        self.model.add(Conv2D(filters=32, kernel_size=(5,5), padding='same', activation='relu', input_shape=(28, 28, 1)))
        self.model.add(MaxPool2D(strides=2))
        self.model.add(Conv2D(filters=48, kernel_size=(5,5), padding='valid', activation='relu'))
        self.model.add(MaxPool2D(strides=2))
        self.model.add(Flatten())
        self.model.add(Dense(256, activation='relu'))
        self.model.add(Dense(84, activation='relu'))
        self.model.add(Dense(10, activation=None))

    def call(self, x):
        return self.model(x)

def get_cross_entropy_loss(labels, logits):
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                   logits=logits)
    return tf.reduce_mean(loss)

@tf.function
def test_step(images, labels):
    logits = model(images, training=False)
    test_acc(tf.nn.softmax(logits), labels)

@tf.function
def train_step(src_images, src_labels):
    with tf.GradientTape() as tape:
        src_logits = model(src_images, training=True)
        batch_cross_entropy_loss = get_cross_entropy_loss(labels=src_labels,
                                                          logits=src_logits)

    gradients = tape.gradient(batch_cross_entropy_loss,
                              model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    source_train_acc(src_labels, tf.nn.softmax(src_logits))


learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(
    init_lr,
    decay_steps=(x_train.shape[0] // batch_size) * 2,
    end_learning_rate=init_lr * 1e-2,
    cycle=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
test_acc = tf.keras.metrics.CategoricalAccuracy()
source_train_acc = tf.keras.metrics.CategoricalAccuracy()
model = MyModel()

for epoch in range(20):
    for images, labels in train_ds:
        train_step(images, labels)
    for images, labels in test_ds:
        test_step(images, labels)
    print('epoch:',epoch, 'train:', source_train_acc.result().numpy(), 'test:', test_acc.result().numpy())

train: 0.95451665 test: 0.9855
train: 0.971875 test: 0.98795
train: 0.97796667 test: 0.9881667
train: 0.9822208 test: 0.989175
train: 0.98463 test: 0.98948
train: 0.9867028 test: 0.9898667
train: 0.9881167 test: 0.9901429
train: 0.9894229 test: 0.99035
train: 0.9903537 test: 0.99035555
train: 0.991245 test: 0.9905
train: 0.9919121 test: 0.9905818


KeyboardInterrupt: 