The following trains a single model on MNIST.

We'll need tensorflow, numpy and a model from custom_layers.

In [1]:
import tensorflow as tf
import numpy as np
from custom_layers import *

Create the model and inputs

In [2]:
fc = FullyConnectedModel(output_size=10, hidden_layer_size=128)
input_img = np.ones((1, 28, 28)).astype(np.float32)
input_mask = np.ones((1, 128)).astype(np.float32)
_ = fc([input_img, input_mask])
fc.summary()
tf.keras.utils.plot_model(fc, show_shapes=True)

Model: "fully_connected_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
fully_connected_model_flatte multiple                  0         
_________________________________________________________________
fully_connected_model_encode multiple                  100480    
_________________________________________________________________
fully_connected_model_Custom multiple                  0         
_________________________________________________________________
fully_connected_model_output multiple                  1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


Create loss, optimizer, and metrics.

In [3]:
mse_loss_fn = tf.keras.losses.MeanSquaredError()
#mse_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_metric = tf.keras.metrics.Mean()

Prepare the data

In [4]:
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255
masks_train = np.ones((x_train.shape[0], 128)).astype(np.float32)
y_train = tf.keras.backend.one_hot(y_train, 10)
print(x_train.shape)
print(y_train.shape)
print(masks_train.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
print(train_dataset)

(60000, 28, 28)
(60000, 10)
(60000, 128)
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>


Train the model

In [5]:
epochs = 20
accuracy = tf.keras.metrics.CategoricalAccuracy()

# Iterate over epochs.
for epoch in range(epochs):
  print('Start of epoch %d' % (epoch,))

  # Iterate over the batches of the dataset.
  for step, (x_train, y_train) in enumerate(train_dataset):
    with tf.GradientTape() as tape:
      masks = np.ones((x_train.shape[0], 128)).astype(np.float32)
      logits = fc([x_train, masks])
      # Compute reconstruction loss
      loss = mse_loss_fn(y_train, logits)

    accuracy.update_state(y_train, logits)

    grads = tape.gradient(loss, fc.trainable_weights)
    optimizer.apply_gradients(zip(grads, fc.trainable_weights))

    loss_metric(loss)

    if step % 100 == 0:
      print('step %s: mean loss = %s\n\tmean accuracy = %s' % (step, loss_metric.result(), accuracy.result()))


Start of epoch 0
step 0: mean loss = tf.Tensor(0.3581706, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.015625, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.06756151, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.6511448, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.054714236, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.71711755, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.04872201, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.74579525, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.044828206, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.7662095, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.04220481, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.77965945, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.040036697, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.7903754, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.038391083, shape=(), dt

step 100: mean loss = tf.Tensor(0.024306074, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8656894, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.024222491, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8661228, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.024158973, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86640966, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.024078164, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8668115, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.02399849, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86724615, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.02392162, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86765933, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.0238469, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8680915, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.02378415, shape=(), dtype=float32)
	

step 200: mean loss = tf.Tensor(0.021402385, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8812631, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.021371862, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8814334, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.02133618, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88163304, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.021301921, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8818425, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.021269327, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8820208, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.021233644, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.882234, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.021203939, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8823944, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.02118089, shape=(), dtype=float32)
	

step 300: mean loss = tf.Tensor(0.019899435, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8898918, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.019879194, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89002264, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.019856501, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8901664, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.019834295, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8903058, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.019812876, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8904552, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.0197935, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89056224, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.019777559, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8906443, shape=(), dtype=float32)
Start of epoch 19
step 0: mean loss = tf.Tensor(0.019769669, shape=()