The following trains a single fully connected model on Fashion-MNIST.

We'll need tensorflow, numpy and a model from custom_layers.

In [1]:
import tensorflow as tf
import numpy as np
from custom_layers import *

Create the model and inputs

In [2]:
# Build the fully connected model (output_size 10, hidden_layer_size 128).
fc = FullyConnectedModel(hidden_layer_size=128, output_size=10)

# Dummy all-ones inputs: a single 28x28 image and a single 128-wide mask.
input_img = np.ones((1, 28, 28), dtype=np.float32)
input_mask = np.ones((1, 128), dtype=np.float32)

# Run one forward pass before asking for the summary
# (presumably so the model's weights exist by then -- confirm in custom_layers).
_ = fc([input_img, input_mask])

fc.summary()
# Diagram rendering needs pydot and graphviz to be installed.
tf.keras.utils.plot_model(fc, show_shapes=True)

Model: "fully_connected_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
fully_connected_model_flatte multiple                  0         
_________________________________________________________________
fully_connected_model_encode multiple                  100480    
_________________________________________________________________
fully_connected_model_Custom multiple                  0         
_________________________________________________________________
fully_connected_model_output multiple                  1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.


Create loss, optimizer, and metrics.

In [3]:
# Adam optimizer with a fixed learning rate of 1e-3.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Mean squared error between the targets and the model output.
# Alternative that was tried and left disabled:
#   tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

# Running mean of the per-batch loss values fed to it.
loss_metric = tf.keras.metrics.Mean()

Prepare the data

In [4]:
# Load Fashion-MNIST; only the training split is kept.
(x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()

# Scale the pixel values by 1/255 and store them as float32.
x_train = x_train.astype(np.float32) / 255

# One all-ones mask row of width 128 per training example.
masks_train = np.ones((len(x_train), 128), dtype=np.float32)

# One-hot encode the integer labels into 10 classes.
y_train = tf.one_hot(y_train, depth=10)

print(x_train.shape, y_train.shape, masks_train.shape, sep='\n')

# Pair images with labels, shuffle through a 1024-element buffer, batch by 64.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(64)
)
print(train_dataset)

(60000, 28, 28)
(60000, 10)
(60000, 128)
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>


Train the model

In [5]:
epochs = 20
accuracy = tf.keras.metrics.CategoricalAccuracy()

# Iterate over epochs.
# NOTE: neither loss_metric nor accuracy is reset inside the loop, so the
# values printed below are running means over all batches seen so far,
# not per-epoch figures.
for epoch in range(epochs):
  print('Start of epoch %d' % (epoch,))

  # Iterate over the batches of the dataset.
  # The batch gets its own names (x_batch / y_batch) so the loop does not
  # overwrite the full x_train / y_train arrays created in the data cell.
  for step, (x_batch, y_batch) in enumerate(train_dataset):
    # All-ones mask sized to the actual batch (the last batch may be smaller).
    # It is a constant, so it is built outside the gradient tape.
    masks = np.ones((x_batch.shape[0], 128), dtype=np.float32)

    with tf.GradientTape() as tape:
      outputs = fc([x_batch, masks])
      # Mean squared error between the one-hot labels and the model outputs.
      loss = mse_loss_fn(y_batch, outputs)

    accuracy.update_state(y_batch, outputs)

    # Backpropagate and apply one optimizer step.
    grads = tape.gradient(loss, fc.trainable_weights)
    optimizer.apply_gradients(zip(grads, fc.trainable_weights))

    loss_metric(loss)

    if step % 100 == 0:
      # .numpy() prints the scalar value instead of the full tf.Tensor repr.
      print('step %s: mean loss = %s\n\tmean accuracy = %s' % (
          step, loss_metric.result().numpy(), accuracy.result().numpy()))


Start of epoch 0
step 0: mean loss = tf.Tensor(0.43917274, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.046875, shape=(), dtype=float32)
step 100: mean loss = tf.Tensor(0.066778004, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.6488243, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.054127038, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.71851677, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.04845472, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.74662584, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.044754155, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.76624846, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.042075355, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.78037673, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.040025935, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.7905574, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.038449544, shape=()

step 100: mean loss = tf.Tensor(0.024348035, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8665026, shape=(), dtype=float32)
step 200: mean loss = tf.Tensor(0.024264691, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8669434, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.024198236, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8672455, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.02412651, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86761016, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.024049101, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86806744, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.023969062, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.86849755, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.023894634, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8689041, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.023829112, shape=(), dtype=float3

step 200: mean loss = tf.Tensor(0.021489177, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8817584, shape=(), dtype=float32)
step 300: mean loss = tf.Tensor(0.02146177, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8819204, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.021432068, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88210773, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.021396216, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88232785, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.021360626, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88252836, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.021325257, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8827504, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.021289958, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.88297004, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.021261303, shape=(), dtype=float

step 300: mean loss = tf.Tensor(0.019992908, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89054495, shape=(), dtype=float32)
step 400: mean loss = tf.Tensor(0.01996923, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89071184, shape=(), dtype=float32)
step 500: mean loss = tf.Tensor(0.019951232, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89080304, shape=(), dtype=float32)
step 600: mean loss = tf.Tensor(0.019926285, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.8909576, shape=(), dtype=float32)
step 700: mean loss = tf.Tensor(0.01990258, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89109886, shape=(), dtype=float32)
step 800: mean loss = tf.Tensor(0.019881612, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89122784, shape=(), dtype=float32)
step 900: mean loss = tf.Tensor(0.019865332, shape=(), dtype=float32)
	mean accuracy = tf.Tensor(0.89132994, shape=(), dtype=float32)
Start of epoch 19
step 0: mean loss = tf.Tensor(0.019857768, shap