In [None]:
import os
import sys

# Expose the notebook's parent directory on the import path (only once, so
# re-running this cell does not keep appending duplicates).
module_path = os.path.abspath(os.path.join(os.curdir, os.pardir))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers

# Load MNIST
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten every image into a 784-long float32 vector and scale pixels by 1/255.
x_train, x_test = (
    split.reshape(-1, 784).astype("float32") / 255 for split in (x_train, x_test)
)

# One-hot encode the integer digit labels for the categorical loss/metric.
y_train, y_test = map(keras.utils.to_categorical, (y_train, y_test))

In [None]:
def _tf_ann(X, samples, p=2, soft=True):
    """Encode every row of ``X`` by its proximity to each row of ``samples``.

    The Minkowski distance of order ``p`` is computed between each input row
    and each sample row. This gives a ``(batch, num_samples)`` distance matrix,
    which is turned into a feature map with one column per sample.

    Args:
        X: 2-D tensor (or array) of inputs, ``(batch, dims)``.
        samples: 2-D tensor (or array) of reference points,
            ``(num_samples, dims)``; assumed to share ``X``'s dtype.
        p: order of the Minkowski distance (``p=2`` is Euclidean); assumed > 0.
        soft: if True, return ``softmax(-distance)`` taken across the samples,
            so each row is a probability vector weighted towards the nearest
            samples. If False, return a one-hot vector marking the nearest
            sample.

    Returns:
        Tensor of shape ``(batch, num_samples)``.

    Raises:
        ValueError: if ``samples`` has no rows.
    """
    num_samples = samples.shape[0]
    if num_samples == 0:
        raise ValueError("samples must contain at least one row")

    # Broadcast (batch, 1, dims) against (1, num_samples, dims) to get all
    # pairwise differences at once. This materialises a
    # (batch, num_samples, dims) intermediate, which is small for the
    # small sample sets used here (psi=16 in this notebook).
    diffs = tf.expand_dims(X, axis=1) - tf.expand_dims(samples, axis=0)
    # abs() is required for a true Minkowski distance: without it, odd p
    # (e.g. p=1) lets negative coordinate differences cancel positive ones.
    distances = tf.math.reduce_sum(tf.abs(diffs) ** p, axis=2) ** (1 / p)

    if soft:
        # Normalise across the samples (axis=1), independently per input row.
        # Normalising over axis 0 (the batch) would make a row's encoding
        # depend on whichever other rows share its batch.
        return tf.nn.softmax(-distances, axis=1)
    # Hard assignment: one-hot of the nearest sample for each input row.
    return tf.one_hot(tf.math.argmin(distances, axis=1), num_samples)

In [None]:
import os
import sys

# The import-path setup and the MNIST load/preprocessing are already done once
# in the setup cells at the top of the notebook. Instead of re-reading and
# re-encoding the dataset here, check the invariants the models below rely on.
assert x_train.ndim == 2 and x_train.shape[1] == 784
assert x_test.ndim == 2 and x_test.shape[1] == 784
assert y_train.shape == (x_train.shape[0], 10)  # one-hot, 10 digit classes
assert y_test.shape == (x_test.shape[0], 10)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0  # pixels scaled by 1/255

In [None]:
RandomFourierFeatures = keras.layers.experimental.RandomFourierFeatures

# Baseline: a 4096-dim random Fourier feature map followed by a linear
# 10-way output layer, trained with hinge loss on the one-hot labels.
model1 = keras.Sequential(
    layers=[
        keras.Input(shape=(784,)),
        RandomFourierFeatures(output_dim=4096, scale=10.0, kernel_initializer="gaussian"),
        layers.Dense(units=10),
    ]
)
model1.compile(
    loss=keras.losses.hinge,
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

model1.fit(x_train, y_train, batch_size=128, epochs=50, validation_split=0.2)

In [None]:
def gen_samples(X, psi, t=1000, seed=None):
    """Draw ``t`` independent random subsets of ``psi`` distinct rows of ``X``.

    Args:
        X: array or tensor whose first axis indexes data points.
        psi: number of rows per subset; must satisfy ``1 <= psi <= len(X)``.
        t: number of subsets to draw.
        seed: optional integer. When given, a fresh ``tf.random.Generator`` is
            seeded with it, so the same seed yields the same subsets on every
            call. ``None`` uses TensorFlow's global random state, as before.

    Returns:
        List of ``t`` tensors, each holding ``psi`` rows of ``X``.

    Raises:
        ValueError: if ``psi`` is not between 1 and the number of rows of ``X``.
    """
    # Convert once up front; passing a numpy array to every gather would
    # re-convert the whole dataset on each of the ``t`` iterations.
    X = tf.convert_to_tensor(X)
    n = X.shape[0]
    if not 1 <= psi <= n:
        raise ValueError(
            "psi must be between 1 and {} (number of rows), got {}".format(n, psi)
        )
    generator = None if seed is None else tf.random.Generator.from_seed(seed)

    def _subset_indices():
        # The first ``psi`` entries of a uniformly random permutation of
        # 0..n-1, i.e. a sample without replacement.
        if generator is None:
            return tf.random.shuffle(tf.range(n))[:psi]
        # argsort of i.i.d. uniform keys is also a uniformly random permutation.
        return tf.argsort(generator.uniform([n]))[:psi]

    return [tf.gather(X, _subset_indices()) for _ in range(t)]


# 500 random subsets of 16 training rows each; one encoding layer per subset.
t_samples = gen_samples(x_train, t=500, psi=16)

In [None]:
class IsolationEncodingLayer(tf.keras.layers.Layer):
    """Keras layer that encodes inputs against a fixed set of sample points.

    The samples are stored as plain attributes (no trainable weights). Each
    call delegates to ``_tf_ann``, which returns one output column per sample.

    Args:
        samples: 2-D reference points, one per output column.
        p: Minkowski distance order forwarded to ``_tf_ann``.
        soft: soft (softmax) vs hard (one-hot) encoding, forwarded to ``_tf_ann``.
        **kwargs: standard ``Layer`` arguments such as ``name``.
    """

    def __init__(self, samples, p=2, soft=True, **kwargs):
        super().__init__(**kwargs)
        self.samples, self.p, self.soft = samples, p, soft

    def call(self, inputs):
        """Encode ``inputs`` against ``self.samples`` via ``_tf_ann``."""
        return _tf_ann(inputs, self.samples, p=self.p, soft=self.soft)

In [None]:
def build_model(t_samples, p=2, soft=True, num_classes=10):
    """Build and compile a linear classifier over concatenated sample encodings.

    One ``IsolationEncodingLayer`` is created per entry of ``t_samples``. Their
    outputs are concatenated and fed to a single linear ``Dense`` layer. The
    optimizer, loss and metric match the RFF baseline ``model1``, so the two
    are compared like-for-like.

    Args:
        t_samples: non-empty sequence of 2-D sample sets. The input width is
            taken from the first set, so all sets are assumed to share it.
        p: Minkowski distance order forwarded to every encoding layer.
        soft: soft (softmax) vs hard (one-hot) encoding for every layer.
        num_classes: width of the output layer (10 for MNIST digits).

    Returns:
        A compiled ``keras.Model`` named ``"isolation_encoding"``.

    Raises:
        ValueError: if ``t_samples`` is empty.
    """
    t = len(t_samples)
    if t == 0:
        raise ValueError("t <= 0")
    _, dims = t_samples[0].shape

    inputs = keras.Input(name="inputs_x", shape=(dims,))
    encodings = [
        IsolationEncodingLayer(samples, p=p, soft=soft, name="ann_{}".format(i))(inputs)
        for i, samples in enumerate(t_samples)
    ]
    # A single encoding is used directly: there is nothing to concatenate, and
    # some Keras versions reject a Concatenate layer with only one input.
    if t == 1:
        features = encodings[0]
    else:
        features = layers.Concatenate(axis=1, name="concatenated")(encodings)
    outputs = layers.Dense(units=num_classes, name="outputs_y")(features)

    model = keras.Model(name="isolation_encoding", inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.hinge,
        metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
    )
    return model


# Hard variant: each sample set contributes a one-hot "nearest sample" code.
model2 = build_model(t_samples, p=2, soft=False)
model2.summary()

In [None]:
model2.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=50,
    validation_split=0.2,
)

In [None]:
# Soft variant: each sample set contributes softmax(-distance) weights.
model3 = build_model(t_samples, p=2, soft=True)
model3.summary()

In [None]:
model3.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=50,
    validation_split=0.2,
)