The following trains a single model on MNIST.

We'll need TensorFlow, NumPy, and the `FullyConnectedModel` class provided by `custom_layers`.

In [1]:
import tensorflow as tf
import numpy as np
from custom_layers import *

In [None]:
# Create the model and inputs

In [22]:
# Instantiate the fully connected classifier: 10 outputs, 128 hidden units.
# NOTE(review): `1 - 0.2` evaluates to 0.8; whether FullyConnectedModel treats
# this as a drop or a keep probability is defined in custom_layers — confirm.
fc = FullyConnectedModel(
    output_size=10,
    hidden_layer_size=128,
    dropout_rate=1 - 0.2,
    batch_size=128,
    epochs=25,
)

# Call the model once on a single dummy image and an all-ones mask of width 128
# (presumably so the subclassed model is built before `summary()` is called).
input_img = np.ones((1, 28, 28), dtype=np.float32)
input_mask = np.ones((1, 128), dtype=np.float32)
_ = fc([input_img, input_mask])

fc.summary()
tf.keras.utils.plot_model(fc, show_shapes=True)

Model: "fully_connected_model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            multiple                  0         
_________________________________________________________________
encoder (Dense)              multiple                  100480    
_________________________________________________________________
custom_dropout_2 (CustomDrop multiple                  0         
_________________________________________________________________
output (Dense)               multiple                  1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


Create loss, optimizer, and metrics.

In [5]:
# Adam optimizer with a step size of 1e-3.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# Mean-squared-error loss, applied later to the model outputs and the targets.
mse_loss_fn = tf.keras.losses.MeanSquaredError()
# Running mean used to report the loss during training.
loss_metric = tf.keras.metrics.Mean()

Prepare the data

In [25]:
# Load only the MNIST training split; the test split is discarded.
train_split, _ = tf.keras.datasets.mnist.load_data()
x_train, y_train = train_split

# Scale pixel values from [0, 255] to [0, 1] as float32.
x_train = x_train.astype(np.float32) / 255
# One all-ones mask row of width 128 per training image.
masks_train = np.ones((len(x_train), 128))
# Turn the integer class labels into one-hot vectors over 10 classes.
y_train = tf.keras.backend.one_hot(y_train, 10)
print(x_train.shape, y_train.shape, masks_train.shape, sep='\n')

# Shuffle with a 1024-element buffer and group into batches of 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)
print(train_dataset)

(60000, 28, 28)
(60000, 10)
(60000, 128)
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>


Train the model

In [27]:
epochs = 3
# Create both metrics fresh on every run of this cell, so re-running it does
# not fold values from earlier runs into the reported running means.
accuracy = tf.keras.metrics.CategoricalAccuracy()
loss_metric = tf.keras.metrics.Mean()

# Iterate over epochs.
for epoch in range(epochs):
  print('Start of epoch %d' % (epoch,))

  # Iterate over the batches of the dataset. The batch variables get their own
  # names so the loop does not overwrite the full `x_train` / `y_train` arrays.
  for step, (x_batch, y_batch) in enumerate(train_dataset):
    # All-ones mask sized to the current batch (the final batch can be smaller
    # than 64). Built outside the tape because it takes no part in the
    # gradient, and as float32 to match the mask the model was first called with.
    masks = np.ones((x_batch.shape[0], 128), dtype=np.float32)

    with tf.GradientTape() as tape:
      logits = fc([x_batch, masks])
      # Mean-squared error between the one-hot labels and the model outputs.
      loss = mse_loss_fn(y_batch, logits)

    accuracy.update_state(y_batch, logits)

    # Backpropagate and apply one optimizer step to all trainable weights.
    grads = tape.gradient(loss, fc.trainable_weights)
    optimizer.apply_gradients(zip(grads, fc.trainable_weights))

    # Fold this batch's loss into the running mean.
    loss_metric(loss)

    if step % 100 == 0:
      print('step %s: mean loss = %s\n\tmean accuracy = %s' % (step, loss_metric.result(), accuracy.result()))


Start of epoch 0
step 0: mean loss = tf.Tensor(0.01055785, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.953125, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.010558846, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9631807, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.0105639845, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96144277, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.010569032, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9609116, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.0105700195, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9619311, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.010578792, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96107787, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.010582772, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96126246, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.010586391, shape=