# MNIST with Decision Tree + Deep NN

In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
from sklearn.utils import shuffle

In [None]:
# Load MNIST: load_data() yields a (images, labels) pair for the training
# split followed by a (images, labels) pair for the test split.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

def digit_indices(digit, labels = train_labels):
    """Return the positions in *labels* whose value equals *digit*.

    The result is the tuple of index arrays produced by NumPy (one array
    per dimension of *labels*). *labels* defaults to the training labels
    as they were bound when this function was defined.
    """
    matches = labels == digit
    return np.nonzero(matches)

# Flatten every 28x28 image into a 784-long vector, then convert to float32
# and divide by 255 (presumably mapping 8-bit pixel values into [0, 1]).
# -1 lets NumPy infer the number of samples instead of hard-coding
# 60000 / 10000, so the cell also works on a subset of the data.
train_images = train_images.reshape((-1, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28))
test_images = test_images.astype("float32") / 255

In [None]:
def decision_labels(labels, d0, d1, other=None):
    """One-hot encode *labels* into three decision classes.

    Args:
        labels: 1-D sequence/array of class labels.
        d0: labels that belong to the first (left) branch -> column 0.
        d1: labels that belong to the second (right) branch -> column 1.
        other: accepted for call compatibility but not consulted; every
            label found in neither *d0* nor *d1* goes to column 2.

    Returns:
        Float array of shape ``(len(labels), 3)`` with exactly one ``1`` per
        row. A label present in both *d0* and *d1* is assigned to column 0.
    """
    labels = np.asarray(labels)
    # list(...) so sets/ranges work too: np.isin treats a raw set as a scalar.
    in_d0 = np.isin(labels, list(d0))
    in_d1 = np.isin(labels, list(d1)) & ~in_d0  # d0 wins on overlap

    new_labels = np.zeros(shape=(labels.shape[0], 3))
    new_labels[in_d0, 0] = 1
    new_labels[in_d1, 1] = 1
    new_labels[~(in_d0 | in_d1), 2] = 1
    return new_labels

## Building a NN with decision tree architecture

Build the full tree with the following structure: 

- [10] --> [6] + [4r]
- [6] --> [4l] + [2]
- [4r] --> [2] + [2]
- [4l] --> [2] + [2]

In [None]:
# Shared input for every node: one flattened 28x28 image.
# The shape must be a tuple -- `(28*28)` without the trailing comma is just
# the int 784, which some Keras versions refuse to convert to a shape.
model_input = keras.Input(shape=(28 * 28,))

def tree_node(node_input, node_name='classifier_node'):
    """Create one classifier node of the tree.

    The node is a 512-unit ReLU layer followed by a 3-way softmax applied
    to *node_input*. It is wrapped in a model that starts at the shared
    `model_input`, so the node can be trained directly on the images.

    Returns:
        ``(node_model, node_output)`` -- the Keras model for this node and
        the symbolic softmax output tensor (for wiring into child nodes).
    """
    hidden = layers.Dense(512, activation='relu')(node_input)
    node_output = layers.Dense(3, activation='softmax')(hidden)
    return (
        keras.Model(inputs=model_input, outputs=node_output, name=node_name),
        node_output,
    )


# Parent of each node, indexed by node number (None = fed by the raw input).
#   node 0 : [10] --> [6] + [4r]
#   node 1 : [6]  --> [4l] + [2c]
#   node 2 : [4r] --> [2r] + [2rr]
#   node 3 : [4l] --> [2ll] + [2l]
#   node 4 : [2c]   node 5 : [2r]   node 6 : [2rr]
#   node 7 : [2ll]  node 8 : [2l]
parent_of = [None, 0, 0, 1, 1, 2, 2, 3, 3]
num_internal = 4  # nodes 0-3 have children and therefore need a concat layer

nodes = []
concat_layers = []

for idx, parent in enumerate(parent_of):
    source = model_input if parent is None else concat_layers[parent]
    nodes.append(tree_node(source, node_name=f'node_{idx}'))
    if idx < num_internal:
        # Children see the raw image together with this node's decision.
        concat_layers.append(
            layers.Concatenate()([model_input, nodes[idx][1]])
        )

# Leaves, listed left to right across the tree.
leaf_nodes = [7, 8, 4, 5, 6]
concat_leaves = layers.Concatenate()([nodes[leaf][1] for leaf in leaf_nodes])

# Final layer maps the concatenated leaf decisions onto the 10 digits.
model_output = layers.Dense(10, activation='softmax')(concat_leaves)
model = keras.Model(inputs=model_input, outputs=model_output)

In [None]:
# Print the layer-by-layer summary of the assembled tree model.
model.summary()

# Model training

In [None]:
import itertools

def digits_comp(digits):
    """Return, in ascending order, the digits 0-9 that are not in *digits*."""
    remaining = set(range(10))
    remaining.difference_update(digits)
    return sorted(remaining)


def generate_cats(digits, size):
    """Enumerate every way to split *digits* into a left group of *size*.

    Returns a list of ``(left, right, outside)`` tuples where *left* is one
    combination of *size* digits, *right* is the sorted remainder of
    *digits*, and *outside* is the digits 0-9 not present in *digits*.
    """
    cats = []
    for left in list(itertools.combinations(digits, size)):
        right = sorted(set(digits).difference(left))
        cats.append((list(left), right, digits_comp(digits)))
    return cats


def train_node(node, categories, vs=0, num_epochs=10, report=0, fix_weights=False):
    """Compile and fit a single node model on its three-way decision task.

    Args:
        node: node model to train (compiled here on every call).
        categories: ``(left, right, other)`` digit groups for this node.
        vs: fraction of the training data held out for validation.
        num_epochs: number of training epochs.
        report: verbosity level forwarded to ``fit``.
        fix_weights: when true, mark the node non-trainable afterwards.

    Returns:
        The ``history`` dict of the fit (metric name -> per-epoch values).
    """
    node.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )

    left_digits, right_digits, other_digits = categories
    targets = decision_labels(train_labels, left_digits, right_digits, other_digits)

    fit_result = node.fit(
        train_images,
        targets,
        batch_size=128,
        epochs=num_epochs,
        validation_split=vs,
        verbose=report,
    )

    if fix_weights:
        # Freeze the node so later training stages leave it untouched.
        node.trainable = False
    return fit_result.history

In [None]:
# Quick check: the digits 0-9 that are missing from the given list.
digits_comp([1,2,5,7])

[0, 3, 4, 6, 8, 9]

In [None]:
# Quick check: every 2 + 2 split of [1, 2, 3, 4], each with the unused digits.
generate_cats([1,2,3,4], 2)

[([1, 2], [3, 4], [0, 5, 6, 7, 8, 9]),
 ([1, 3], [2, 4], [0, 5, 6, 7, 8, 9]),
 ([1, 4], [2, 3], [0, 5, 6, 7, 8, 9]),
 ([2, 3], [1, 4], [0, 5, 6, 7, 8, 9]),
 ([2, 4], [1, 3], [0, 5, 6, 7, 8, 9]),
 ([3, 4], [1, 2], [0, 5, 6, 7, 8, 9])]

## Decision Tree training

### Node 0 : [10] --> [6] + [4]

In [None]:
categories_node_0 = generate_cats(range(10), 6)
histories_node_0 = []

# Every candidate split must start from the same weights; otherwise each run
# continues from the previous candidate's training and the validation
# accuracies are not comparable across splits.
node_0_model = nodes[0][0]
node_0_start_weights = node_0_model.get_weights()

for cats in categories_node_0:
    node_0_model.set_weights(node_0_start_weights)
    histories_node_0.append(train_node(node_0_model, cats, vs=0.2, num_epochs=15))

In [None]:
# Best validation accuracy reached by each candidate split of node 0 ...
max_va_0 = [max(run['val_accuracy']) for run in histories_node_0]
# ... and the index of the (first) candidate that achieved the highest one.
max(range(len(max_va_0)), key=max_va_0.__getitem__)

81

In [None]:
# 81 is the best-candidate index printed by the previous cell; show its split.
categories_node_0[81]

([0, 2, 3, 5, 6, 8], [1, 4, 7, 9], [])

### Node 1 : [6] --> [4] + [2]

In [None]:
categories_node_1 = generate_cats([0, 2, 3, 5, 6, 8], 4)
histories_node_1 = []

# Every candidate split must start from the same weights; otherwise each run
# continues from the previous candidate's training and the validation
# accuracies are not comparable across splits.
node_1_model = nodes[1][0]
node_1_start_weights = node_1_model.get_weights()

for cats in categories_node_1:
    node_1_model.set_weights(node_1_start_weights)
    histories_node_1.append(train_node(node_1_model, cats, vs=0.2, num_epochs=15))

In [None]:
# Best validation accuracy reached by each candidate split of node 1 ...
max_va_1 = [max(run['val_accuracy']) for run in histories_node_1]
# ... and the index of the (first) candidate that achieved the highest one.
max(range(len(max_va_1)), key=max_va_1.__getitem__)

In [None]:
# Show the node-1 split with the highest validation accuracy. (This cell
# previously re-displayed the node-0 result, categories_node_0[81].)
categories_node_1[max_va_1.index(max(max_va_1))]

## NN training

Each node is a classifier that assigns an image to the left branch, to the right branch, or marks it as irrelevant to this node, using the categories obtained in the Decision Tree training phase.

For example, the output categories of node 1 are [2, 3, 5, 8] (left), [0, 6] (right) and [1, 4, 7, 9] (irrelevant).

In [None]:
def digits_comp(digits):
    """Return the sorted digits 0-9 not contained in *digits*.

    Same definition as the one given earlier in this notebook.
    """
    excluded = set(digits)
    return sorted(set(range(10)) - excluded)


# (left, right, irrelevant) digit groups for each node, indexed by node number.
node_categories = [
    # node 0 : [10] --> [6] + [4r]
    ([0, 2, 3, 5, 6, 8], [1, 4, 7, 9], []),
    # node 1 : [6] --> [4l] + [2c]
    ([2, 3, 5, 8], [0, 6], [1, 4, 7, 9]),
    # node 2 : [4r] --> [2r] + [2rr]
    ([1, 4], [7, 9], [0, 2, 3, 5, 6, 8]),
    # node 3 : [4l] --> [2ll] + [2l]
    ([2, 3], [5, 8], [0, 1, 4, 6, 7, 9]),
    # node 4 : [2c]
    ([0], [6], digits_comp([0, 6])),
    # node 5 : [2r]
    ([1], [4], digits_comp([1, 4])),
    # node 6 : [2rr]
    ([7], [9], digits_comp([7, 9])),
    # node 7 : [2ll]
    ([2], [3], digits_comp([2, 3])),
    # node 8 : [2l]
    ([5], [8], digits_comp([5, 8])),
]

In [None]:
# Train the nodes in index order (parents before children) and freeze each
# one once it has been fitted. Only the last node's history is kept.
for i, (node_model, _node_output) in enumerate(nodes):
    histories = train_node(node_model, node_categories[i], fix_weights=True)

In [None]:
# Render the model graph to the image file "MNIST_DT.png".
keras.utils.plot_model(model, "MNIST_DT.png")

In [None]:
# Summary again, now that the node models have been marked non-trainable.
model.summary()

### Combining the leaves

Since we have already fixed all the nodes to be non-trainable, instead of manually working out the final transformation that combines the leaves to the final output we can simply train the model to learn the transformations!

In [None]:
# The node models were frozen by train_node(..., fix_weights=True), so this
# fit is meant to adjust only the final Dense layer that combines the leaves.
# Integer digit labels are used directly, hence the sparse loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

model.fit(x=train_images, y=train_labels, batch_size=128, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f4e26c117d0>

In [None]:
# Score the combined model on the held-out test set: [loss, accuracy].
model.evaluate(x=test_images, y=test_labels)



[0.1784956455230713, 0.9761000275611877]