In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from sklearn.utils import shuffle

TF_ENABLE_ONEDNN_OPTS=1

In [2]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

def digit_indices(digit, labels = train_labels):
    return np.where(labels==digit)

def average_digit_image(digit, images = train_images, labels = train_labels):
    return np.average(images[digit_indices(labels, digit)],axis=0)

average_digits = np.array([average_digit_image(i) for i in range(10)]).reshape((10,28*28))
average_digits = average_digits.astype("float32") / 255
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

In [33]:
def split_features(thresholds = [0.5]*10):
    masks = np.zeros((10, 3, 28*28, 1), dtype="float32")
    for digit in range(10):
        for i in range(28*28):
            if average_digits[digit, i] == 0:
                masks[digit, 0, i] += 1
            elif average_digits[digit, i] <= thresholds[digit]:
                masks[digit, 1, i] += 1
            else:
                masks[digit, 2, i] += 1
    return masks

In [12]:
model_input = layers.Input(shape=28*28)
mask_layers = []
concat_layers = []
hidden_layers = []
for digit in range(10):
    mask_layers.append([
        layers.Dense(1, trainable=False, name="{}_{}".format(digit,i))(model_input)
        for i in range(3)])
    concat_layers.append(
        layers.Concatenate(axis=1, name="concat{}".format(digit))(mask_layers[digit])
    )
    hidden_layers.append(
        layers.Dense(1, activation="sigmoid", name="sigmoid{}".format(digit))(concat_layers[digit])
    )
pre_output_layer = layers.Concatenate(axis=1, name="pre_output_concat")(hidden_layers)
model_output = layers.Softmax()(pre_output_layer)

In [34]:
masks = split_features(thresholds = [0.4, 0.8, 0.3, 0.3, 0.5, 0.4, 0.2, 0.1, 0.5, 0.5])
new_bias = np.zeros(shape=(1,), dtype="float32")

for digit in range(10):
    for i in range(3):
        model.get_layer('{}_{}'.format(digit, i)).set_weights([masks[digit, i], new_bias])

In [20]:
sub_models = [
    keras.Model(inputs = model_input,
               outputs = hidden_layers[digit],
               name = "is{}_model".format(digit))
    for digit in range(10)
]

for digit in range(10):
    sub_models[digit].compile(optimizer = "rmsprop",
                              loss = "binary_crossentropy",
                              metrics = ["binary_accuracy"])
 

In [32]:
train_labels_b = np.logical_not((train_labels - 1).astype("bool")).astype("float32")
sub_models[0].fit(train_images, train_labels_b, epochs=20, batch_size=128)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7fed4ccf5e20>

In [None]:
for digit in range(10):
    train_labels_b = np.logical_not((train_labels - digit).astype("bool")).astype("float32")
    sub_models[digit].fit(train_images, train_labels, epochs=20, batch_size=128)

In [13]:
model = keras.Model(inputs = model_input,
                    outputs = model_output,
                    name = "single_digit_model")

model.compile(optimizer = "rmsprop",
              loss = "sparse_categorical_crossentropy",
              metrics = ["accuracy"])

model.summary()

Model: "single_digit_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 784)]        0                                            
__________________________________________________________________________________________________
0_0 (Dense)                     (None, 1)            785         input_3[0][0]                    
__________________________________________________________________________________________________
0_1 (Dense)                     (None, 1)            785         input_3[0][0]                    
__________________________________________________________________________________________________
0_2 (Dense)                     (None, 1)            785         input_3[0][0]                    
_________________________________________________________________________________

In [9]:
model.fit(train_images, train_labels, validation_split=0.2, epochs=10, batch_size=128)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fedec7e5a90>

In [10]:
model.evaluate(test_images, test_labels)



[2.144726037979126, 0.18490000069141388]