In [1]:
import tensorflow as tf
from tensorflow import keras


In [2]:
(x_train, y_train), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:

def preprocess(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)

  return x, y

def create_dataset(xs, ys, n_classes=10):
  ys = tf.one_hot(ys, depth=n_classes)
  return tf.data.Dataset.from_tensor_slices((xs, ys)) \
    .map(preprocess) \
    .shuffle(len(ys)) \
    .batch(128)

In [4]:
train_dataset = create_dataset(x_train, y_train)
val_dataset = create_dataset(x_val, y_val)

In [5]:
model = keras.Sequential([
    keras.layers.Reshape(target_shape=(6 * 7,), input_shape=(6, 7)),
    keras.layers.Dense(units=256, activation='relu'),
    keras.layers.Dense(units=192, activation='relu'),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dense(units=10, activation='softmax')
])

In [6]:
model.compile(optimizer='adam', 
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(
    train_dataset.repeat(), 
    epochs=10, 
    steps_per_epoch=500,
    validation_data=val_dataset.repeat(), 
    validation_steps=2
)

Epoch 1/10


  return dispatch_target(*args, **kwargs)


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [7]:

predictions = model.predict(val_dataset)

In [8]:
predictions

array([[9.8200053e-01, 4.0382235e-09, 3.5198528e-04, ..., 3.1753938e-12,
        5.0770630e-08, 8.4940666e-12],
       [1.5023162e-04, 3.0635935e-08, 9.8947185e-01, ..., 8.2212831e-10,
        1.5115847e-07, 5.7085519e-09],
       [1.7115613e-10, 5.4191795e-10, 2.0911375e-10, ..., 9.9807203e-01,
        1.6862913e-08, 1.9252666e-03],
       ...,
       [4.7859450e-09, 9.9999130e-01, 6.5481145e-09, ..., 4.8029471e-12,
        2.0374792e-10, 9.1174160e-12],
       [2.7590397e-12, 1.0000000e+00, 6.1664497e-13, ..., 5.3241045e-16,
        8.9430477e-14, 1.1509628e-14],
       [7.4676210e-11, 3.5629150e-10, 1.5878882e-10, ..., 9.9997711e-01,
        5.5583471e-09, 6.1524270e-06]], dtype=float32)