In [1]:
import os
import sys

# Put the notebook's parent directory (resolved against the current working
# directory) on the import path so project-local modules can be imported here.
# The membership check keeps repeated runs of this cell from adding duplicates.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

from datetime import datetime
import tensorflow as tf
from tensorflow import keras
from keras import layers

# !rm -rf ./logs/**

In [3]:
# Load MNIST
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

NUM_CLASSES = 10  # digits 0-9

# Flatten each image to a single feature vector (28 * 28 = 784 values) and
# scale pixel intensities from [0, 255] to [0, 1].
x_train = x_train.reshape(len(x_train), -1).astype("float32") / 255
x_test = x_test.reshape(len(x_test), -1).astype("float32") / 255

# Categorical (one-hot) encoding of the labels. num_classes is passed
# explicitly: without it to_categorical infers the width as max(label) + 1,
# which would silently shrink if a split lacked the highest class.
y_train = keras.utils.to_categorical(y_train, num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test, num_classes=NUM_CLASSES)

assert x_train.shape[1] == x_test.shape[1]
assert y_train.shape[1] == y_test.shape[1] == NUM_CLASSES

In [4]:
def _tf_ann(X, samples, p=2, soft=True):
    """Encode each row of ``X`` by its proximity to a fixed set of sample points.

    Computes the Minkowski-``p`` distance from every row of ``X`` to every row
    of ``samples`` and converts it into a per-row assignment over the samples.

    Args:
        X: 2-D tensor of inputs, one row per example (presumably
            ``(batch, dims)`` — TODO confirm against callers).
        samples: 2-D array or tensor of reference points, ``(psi, dims)``.
        p: Order of the Minkowski distance (2 = Euclidean, 1 = Manhattan).
        soft: If True, return ``softmax(-distance)`` over the samples, i.e. a
            soft nearest-sample assignment. If False, return a one-hot vector
            marking the nearest sample.

    Returns:
        Tensor of shape ``(batch, psi)``; every row sums to 1 in both modes.
    """
    n_samples = samples.shape[0]

    # One distance column per sample point. Collect them in a list and
    # concatenate once rather than re-growing a tensor on every iteration.
    columns = []
    for i in range(n_samples):
        diff = X - samples[i : i + 1, :]
        # abs() is required for a Minkowski distance: without it odd p sums
        # signed terms (wrong, possibly negative totals) and a fractional root
        # of a negative total is NaN.
        columns.append(
            tf.math.reduce_sum(tf.math.abs(diff) ** p, axis=1, keepdims=True)
            ** (1 / p)
        )
    distances = tf.concat(columns, axis=1)  # (batch, psi)

    if soft:
        # Normalise over the samples axis (axis=1), matching the hard branch.
        # Normalising over axis=0 would mix examples within a batch, making
        # each example's encoding depend on its batch-mates.
        return tf.nn.softmax(-distances, axis=1)
    # Nearest sample per row; argmin(d) == argmax(-d), including tie-breaking.
    return tf.one_hot(tf.math.argmin(distances, axis=1), n_samples)

In [5]:
RandomFourierFeatures = keras.layers.experimental.RandomFourierFeatures

# Baseline ("linear" run): a random Fourier feature map (gaussian kernel
# initialiser, i.e. an approximate RBF-kernel feature space) followed by a
# linear 10-way head. Keras' hinge loss converts 0/1 one-hot targets to -1/+1.
model1 = keras.Sequential()
model1.add(keras.Input(shape=(784,)))
model1.add(
    RandomFourierFeatures(output_dim=8192, scale=10.0, kernel_initializer="gaussian")
)
model1.add(layers.Dense(units=10))

model1.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.hinge,
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

# Each run logs to its own timestamped directory so TensorBoard can compare runs.
model1.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=128,
    validation_split=0.2,
    callbacks=[
        keras.callbacks.TensorBoard(
            log_dir="./logs/fit/linear-" + datetime.now().strftime("%Y%m%d-%H%M%S"),
            histogram_freq=1,
        )
    ],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x19fb43ceda0>

In [6]:
def gen_samples(X, psi, t=1000):
    """Draw ``t`` independent random subsets of ``psi`` rows from ``X``.

    Each subset is sampled without replacement: the row indices are randomly
    permuted and the first ``psi`` are kept. Subsets are drawn independently,
    so different subsets may share rows.

    Args:
        X: 2-D array-like of shape ``(n, dims)``.
        psi: Number of rows per subset; must satisfy ``1 <= psi <= n``.
        t: Number of subsets to draw.

    Returns:
        List of ``t`` NumPy arrays, each of shape ``(psi, dims)``.

    Raises:
        ValueError: If ``psi`` is outside ``[1, n]``.
    """
    n, _ = X.shape
    if not 1 <= psi <= n:
        # Without this check psi > n silently yields subsets of only n rows,
        # and psi == 0 yields empty subsets that fail later in the encoder.
        raise ValueError("psi must be in [1, {}], got {}".format(n, psi))
    return [
        tf.gather(X, tf.random.shuffle(tf.range(n))[:psi]).numpy()
        for _ in range(t)
    ]


# 500 random subsets of 16 training images each; build_model turns every
# subset into the reference points of one IsolationEncodingLayer.
t_samples = gen_samples(x_train, t=500, psi=16)

In [7]:
class IsolationEncodingLayer(layers.Layer):
    """Keras layer that encodes inputs by proximity to a fixed set of samples.

    Thin wrapper around ``_tf_ann``: each input row is mapped to a ``psi``-wide
    assignment over the rows of ``samples`` (soft or one-hot). The layer
    creates no weights; ``samples`` is stored as a plain attribute.

    Args:
        samples: 2-D array of reference points, ``(psi, dims)``.
        p: Order of the Minkowski distance passed to ``_tf_ann``.
        soft: Soft (softmax) vs hard (one-hot) assignment, passed to ``_tf_ann``.
        **kwargs: Standard ``Layer`` arguments such as ``name``.
    """

    def __init__(self, samples, p=2, soft=True, **kwargs):
        super().__init__(**kwargs)
        self.samples = samples
        self.p = p
        self.soft = soft

    def call(self, inputs):
        return _tf_ann(inputs, self.samples, self.p, self.soft)

    def get_config(self):
        config = super().get_config()
        # Serialise the array as a nested list so the config is plain data.
        samples = self.samples
        if hasattr(samples, "tolist"):
            samples = samples.tolist()
        config.update({"samples": samples, "p": self.p, "soft": self.soft})
        return config

    @classmethod
    def from_config(cls, config):
        # After a save/load round-trip ``samples`` arrives as a nested list,
        # but ``_tf_ann`` needs ``.shape`` and 2-D slicing: restore an array.
        # Copy first so the caller's config dict is not mutated.
        config = dict(config)
        config["samples"] = tf.constant(config["samples"]).numpy()
        return super().from_config(config)

In [8]:
def build_model(t_samples, p=2, soft=True, num_classes=10):
    """Build and compile the isolation-encoding classifier.

    One ``IsolationEncodingLayer`` per sample set encodes the input. Their
    outputs are concatenated and fed to a linear ``Dense`` head, compiled with
    Adam and hinge loss.

    Args:
        t_samples: Non-empty sequence of 2-D arrays ``(psi, dims)``. The
            column count of the first one sets the model's input width, so
            the others must match it.
        p: Minkowski distance order used by every encoding layer.
        soft: Soft (softmax) vs hard (one-hot) encoding.
        num_classes: Width of the output layer (default 10, as for MNIST).

    Returns:
        A compiled ``keras.Model`` named ``"isolation_encoding"``.

    Raises:
        ValueError: If ``t_samples`` is empty.
    """
    t = len(t_samples)
    if t <= 0:
        raise ValueError("t <= 0")
    _, dims = t_samples[0].shape

    inputs = keras.Input(name="inputs_x", shape=(dims,))
    encodings = [
        IsolationEncodingLayer(samples, p=p, soft=soft, name="ann_{}".format(i))(
            inputs
        )
        for i, samples in enumerate(t_samples)
    ]
    # Some Keras versions reject Concatenate with a single input, so a lone
    # sample set is passed straight to the head.
    if len(encodings) == 1:
        features = encodings[0]
    else:
        features = layers.Concatenate(axis=1, name="concatenated")(encodings)
    outputs = layers.Dense(units=num_classes, name="outputs_y")(features)

    model = keras.Model(name="isolation_encoding", inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.hinge,
        metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
    )
    return model


# "hard" run: one-hot nearest-sample encoding (p=2 is the default distance order).
model2 = build_model(t_samples, p=2, soft=False)
# To inspect the architecture, run: model2.summary()

In [9]:
# Train the hard-encoding model; logs go to a timestamped "hard-" directory.
model2.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=128,
    validation_split=0.2,
    callbacks=[
        keras.callbacks.TensorBoard(
            log_dir="./logs/fit/hard-" + datetime.now().strftime("%Y%m%d-%H%M%S"),
            histogram_freq=1,
        )
    ],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x2427f915480>

In [10]:
# "soft" run: softmax-weighted encoding over the same sample sets as model2.
model3 = build_model(t_samples, p=2, soft=True)
# To inspect the architecture, run: model3.summary()

In [11]:
# Train the soft-encoding model; logs go to a timestamped "soft-" directory.
model3.fit(
    x_train,
    y_train,
    epochs=50,
    batch_size=128,
    validation_split=0.2,
    callbacks=[
        keras.callbacks.TensorBoard(
            log_dir="./logs/fit/soft-" + datetime.now().strftime("%Y%m%d-%H%M%S"),
            histogram_freq=1,
        )
    ],
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x248f50fc940>

In [6]:
# Open TensorBoard inline over every run logged under ./logs/fit
# (the linear-, hard- and soft- training runs write their logs there).
%tensorboard --logdir ./logs/fit