In [1]:
import argparse

import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import tensorflow_datasets as tfds

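## Load dataset

MNIST digits are resized to 32×32 and expanded from grayscale to three channels so they fit the ImageNet-pretrained MobileNetV2 backbone. Labels are one-hot encoded over the 10 digit classes.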

In [4]:
batch_size = 32
num_classes = 10         # MNIST digits 0-9; must match the width of the model's output layer
shuffle_buffer = 10_000  # examples held in the shuffle buffer (larger => closer to a uniform shuffle)

# Raw (image, label) splits, kept under separate names from the processed pipelines below.
train_raw, test_raw = tfds.load("mnist",
                                split=["train", "test"],
                                as_supervised=True)

AUTOTUNE = tf.data.experimental.AUTOTUNE
size = (32, 32)


def preprocess(image, label):
    """Prepare one MNIST example for the MobileNetV2 classifier.

    Resizes the image to `size`, replicates its single grayscale channel to RGB,
    and one-hot encodes the label over `num_classes` classes.

    ``tf.image.resize`` returns float32 but does not rescale pixel values, so they
    stay in the range tfds provides (uint8 0-255 for MNIST). That is the range
    ``mobilenet_v2.preprocess_input`` expects in the model.
    """
    image = tf.image.resize(image, size)
    image = tf.image.grayscale_to_rgb(image)
    return image, tf.one_hot(label, depth=num_classes)


# cache() keeps every preprocessed 32x32x3 float32 image in memory (~12 KB each).
# Move it before map() if RAM is tight.
# shuffle() reshuffles on every iteration by default, so each epoch sees a new order.
# It must come after cache(); otherwise the first epoch's order would be cached and replayed.
train_dataset = (
    train_raw
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(shuffle_buffer)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

# Evaluation order does not matter, so the test pipeline is not shuffled.
test_dataset = (
    test_raw
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .cache()
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

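## Build model

A frozen MobileNetV2 feature extractor (ImageNet weights) is followed by global average pooling, dropout, and a 10-way linear layer that outputs class logits.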

In [5]:
lr_rate = 0.0001

# ImageNet-pretrained MobileNetV2 without its classification head, used as a frozen
# feature extractor. 32x32 is the smallest spatial size Keras accepts for this model.
base_model = MobileNetV2(input_shape=(32, 32, 3),
                         include_top=False,
                         weights='imagenet')
base_model.trainable = False

# Maps 0-255 pixel values to [-1, 1], the input range MobileNetV2 expects.
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# One unit per class with no activation, so the model outputs logits.
prediction_layer = tf.keras.layers.Dense(10)

inputs = tf.keras.Input(shape=(32, 32, 3))
x = preprocess_input(inputs)
# training=False runs the backbone in inference mode (BatchNorm uses its moving
# statistics). This also holds if base_model is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

# The 10 digit classes are mutually exclusive and the labels are one-hot encoded,
# so the objective is softmax (categorical) cross-entropy on the logits.
# BinaryCrossentropy would score each output as an independent yes/no problem.
# `learning_rate` replaces the deprecated `lr` keyword of Keras optimizers.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_rate),
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.CategoricalAccuracy(name='accuracy')])



## Train model

In [8]:
# Report loss/accuracy on the held-out test split after every epoch. Here the test
# split doubles as the validation set; no separate validation split is carved out.
# Binding the returned History (per-epoch metrics) keeps its repr out of the cell output.
num_epochs = 2
history = model.fit(train_dataset, validation_data=test_dataset, epochs=num_epochs)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7ff4d9d54790>