The following trains a single model on MNIST.

We'll need tensorflow, numpy and a model from custom_layers.

In [1]:
import tensorflow as tf
import numpy as np
from custom_layers import *

Create the model and call it once on dummy all-ones inputs so that its summary can be printed.

In [2]:
# Instantiate the model from custom_layers with 10 outputs and a hidden layer of 128 units.
fc = FullyConnectedModel(output_size=10, hidden_layer_size=128)

# Dummy all-ones inputs with a batch dimension of 1: one 28x28 image and one
# mask of width 128.
# NOTE(review): the mask width presumably has to match hidden_layer_size -- confirm
# against custom_layers.
input_img = np.ones((1, 28, 28), dtype=np.float32)
input_mask = np.ones((1, 128), dtype=np.float32)

# Run a single forward pass before asking for the summary; the result itself
# is not needed.
_ = fc([input_img, input_mask])

fc.summary()
# Rendering the graph requires pydot and graphviz to be installed.
tf.keras.utils.plot_model(fc, show_shapes=True)

Model: "fully_connected_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
fully_connected_model_flatte multiple                  0         
_________________________________________________________________
fully_connected_model_encode multiple                  100480    
_________________________________________________________________
fully_connected_model_Custom multiple                  0         
_________________________________________________________________
fully_connected_model_output multiple                  1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


Create loss, optimizer, and metrics.

In [3]:
# Adam optimizer with a learning rate of 1e-3.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Running mean used to report the loss during training.
loss_metric = tf.keras.metrics.Mean()

# Mean squared error between the model output and the target.
# Alternative kept for reference:
#   tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# That loss takes integer class ids as targets, so the one-hot encoding of the
# labels would have to be dropped before switching to it.
mse_loss_fn = tf.keras.losses.MeanSquaredError()

Prepare the data

In [4]:
# Load MNIST and keep only the training split; the second split is discarded.
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()

# Convert the pixels to float32 and divide by 255.
x_train = x_train.astype('float32') / 255

# One all-ones mask row of width 128 per training example.
# Note: this array is only printed below; it is not part of train_dataset.
masks_train = np.ones((x_train.shape[0], 128), dtype=np.float32)

# One-hot encode the labels over 10 classes.
y_train = tf.keras.backend.one_hot(y_train, 10)

# Show the shapes of the three arrays, one per line.
print(x_train.shape, y_train.shape, masks_train.shape, sep='\n')

# Slice into (image, label) pairs, shuffle with a 1024-element buffer, and
# group into batches of 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)
print(train_dataset)

(60000, 28, 28)
(60000, 10)
(60000, 128)
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>


Train the model

In [5]:
epochs = 20
accuracy = tf.keras.metrics.CategoricalAccuracy()

# Iterate over epochs.
# NOTE: `loss_metric` and `accuracy` are never reset inside this loop, so the
# values printed below are running means over every step seen so far (across
# epochs), not per-epoch figures.
for epoch in range(epochs):
  print('Start of epoch %d' % (epoch,))

  # Iterate over the batches of the dataset. The batch variables have their
  # own names so that the full `x_train` / `y_train` arrays created in the
  # data-preparation cell are not overwritten by the last mini-batch.
  for step, (x_batch, y_batch) in enumerate(train_dataset):
    # All-ones mask with one row per example in this batch; sized from the
    # batch itself because the final batch is smaller than 64.
    # It is built outside the tape since it takes no part in differentiation.
    masks = np.ones((x_batch.shape[0], 128)).astype(np.float32)

    with tf.GradientTape() as tape:
      # Forward pass is recorded on the tape so gradients can be taken below.
      logits = fc([x_batch, masks])
      # Mean squared error between the one-hot labels and the model output.
      loss = mse_loss_fn(y_batch, logits)

    accuracy.update_state(y_batch, logits)

    # Backpropagate and apply one optimizer step to the model weights.
    grads = tape.gradient(loss, fc.trainable_weights)
    optimizer.apply_gradients(zip(grads, fc.trainable_weights))

    # Fold this batch's loss into the running mean.
    loss_metric(loss)

    if step % 100 == 0:
      print('step %s: mean loss = %s\n\tmean accuracy = %s' % (step, loss_metric.result(), accuracy.result()))


Start of epoch 0
step 0: mean loss = tf.Tensor(0.41519332, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.09375, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.06341139, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.7295792, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.047829702, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.81071204, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.040677916, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8454111, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.03597984, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8668563, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.0329194, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8799588, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.030511735, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88963705, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.028672831, shape=(), dtyp

step 100: mean loss = tf.Tensor(0.012302667, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9625748, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.012223821, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9629007, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.0121477805, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9631866, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.0120666865, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9634993, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.011998772, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96375847, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.011922746, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96408457, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.011850891, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.96436334, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.011793408, shape=(), dtype=flo

step 200: mean loss = tf.Tensor(0.009510466, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.97370183, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.009480682, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.973828, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.009449981, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9739454, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.009421431, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9740647, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.009393487, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9741649, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.009365023, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9742869, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.009341755, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9743875, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.009311741, shape=(), dtype=float32)

step 300: mean loss = tf.Tensor(0.008228724, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9788477, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.008211643, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9789131, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.008196073, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.97897243, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.008179574, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9790427, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.008163282, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.97911036, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.008149876, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.9791623, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.008132835, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.97923464, shape=(), dtype=float32)
Start of epoch 19
step 0: mean loss = tf.Tensor(0.008125408, shape