In [1]:
from datetime import datetime
import tensorflow as tf
from tensorflow import keras
from keras import layers
import tensorflow_datasets as tfds

import shutil

# Start every run from a clean TensorBoard/model directory. Done in pure
# Python instead of `!rm -rf` so the cell also works where no POSIX `rm` is
# on PATH; a missing directory is silently ignored. The directory itself is
# removed too; the TensorBoard callback and model.save() recreate it.
shutil.rmtree("./logs/mnist2", ignore_errors=True)

In [2]:
# Training schedule shared by every model in this notebook.
epochs, batch_size = 20, 64

# Image width/height in pixels; images are flattened to image_w * image_h
# features by the preprocessing step.
image_w = 28
image_h = 28

In [3]:
# Load the MNIST train/test splits as (image, label) pairs together with the
# dataset metadata.
splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = splits

# Number of label classes and number of training examples, read from the
# dataset metadata rather than hard-coded.
label_feature = ds_info.features['label']
n_classes = label_feature.num_classes
train_split = ds_info.splits['train']
n = train_split.num_examples
print("n: ", n, "n_classes:", n_classes)


n:  60000 n_classes: 10


In [4]:
def normalize_ds(ds):
    """Turn a dataset of (image, label) pairs into flat float features and one-hot labels.

    Each image is scaled to [0, 1], resized to ``image_h`` x ``image_w`` and
    flattened to a vector of ``image_w * image_h`` values; each label becomes
    a float32 one-hot vector of length ``n_classes``.

    Args:
        ds: an unbatched ``tf.data.Dataset`` yielding ``(image, label)``.

    Returns:
        The mapped dataset, cached after preprocessing.
    """

    def normalize_img(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        # tf.image.resize takes the target size as [height, width].
        image = tf.image.resize(image, [image_h, image_w])
        image = tf.reshape(image, [image_w * image_h])
        label = tf.one_hot(tf.cast(label, tf.int32), n_classes, dtype=tf.float32)
        return image, label

    # Cache *after* the deterministic map so the preprocessing runs once
    # instead of being recomputed on every epoch.
    return ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE).cache()


# Training pipeline: preprocess, shuffle with a buffer of n elements,
# batch, prefetch.
train_normalized = normalize_ds(ds_train)
train_shuffled = train_normalized.shuffle(n)
ds_train = train_shuffled.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Evaluation pipeline: same preprocessing, no shuffling.
test_normalized = normalize_ds(ds_test)
ds_test = test_normalized.batch(batch_size).prefetch(tf.data.AUTOTUNE)


In [13]:
def gen_samples(psi, t=1000, seed=None):
    """Draw ``t`` random subsets of ``psi`` training images each.

    The training images are read from ``ds_train`` once and kept in memory;
    every subset is then drawn without replacement from that tensor. This
    avoids rebuilding (and refilling) a full-size shuffle buffer for each of
    the ``t`` subsets.

    Args:
        psi: number of images per subset; must be positive.
        t: number of subsets to draw.
        seed: optional integer seed for reproducible subsets.

    Returns:
        A list of ``t`` numpy arrays; each holds ``psi`` rows (fewer only if
        the dataset has fewer than ``psi`` images).

    Raises:
        ValueError: if ``psi`` is not positive.
    """
    if psi <= 0:
        raise ValueError("psi must be positive")

    # ds_train yields (images, labels) batches; keep only the images.
    images = tf.concat([batch_x for batch_x, _ in ds_train], axis=0)
    total = int(images.shape[0])

    if seed is None:
        rng = tf.random.Generator.from_non_deterministic_state()
    else:
        rng = tf.random.Generator.from_seed(seed)

    samples = []
    for _ in range(t):
        # argsort of i.i.d. uniforms is a uniformly random permutation, so its
        # first psi entries are a sample without replacement.
        order = tf.argsort(rng.uniform([total]))
        samples.append(tf.gather(images, order[:psi]).numpy())
    return samples

In [6]:
RandomFourierFeatures = keras.layers.experimental.RandomFourierFeatures

# Baseline: random Fourier features followed by a linear layer trained with
# hinge loss.
model_svm = keras.Sequential(
    [
        layers.Input(shape=(image_w * image_h,)),
        RandomFourierFeatures(
            output_dim=8192, scale=10.0, kernel_initializer="gaussian"
        ),
        layers.Dense(units=n_classes),
    ]
)
model_svm.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.hinge,
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

modeldir = "./logs/mnist2/linear-8192-" + datetime.now().strftime("%Y%m%d-%H%M%S")
model_svm.fit(
    ds_train,
    epochs=epochs,
    # ds_train / ds_test are already batched tf.data pipelines, so no
    # batch_size is passed to fit().
    validation_data=ds_test,
    callbacks=[
        keras.callbacks.TensorBoard(
            log_dir=modeldir + "/log",
            histogram_freq=1,
        )
    ],
)
model_svm.save(modeldir + "/model")

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
INFO:tensorflow:Assets written to: ./logs/mnist2/linear-8192-20230307-225444/model\assets


INFO:tensorflow:Assets written to: ./logs/mnist2/linear-8192-20230307-225444/model\assets


In [7]:
def _tf_ann(X, samples, p=2, soft=True):
    """Encode each row of ``X`` by its nearest reference sample.

    Computes the Minkowski distance of order ``p`` between every row of ``X``
    and every row of ``samples``, then turns the distances into a per-row
    assignment over the samples.

    Args:
        X: 2-D tensor, one row per input.
        samples: 2-D array-like of reference rows with the same width as ``X``.
        p: order of the Minkowski distance.
        soft: if True, return a softmax over the negated distances (each row
            sums to 1); otherwise a one-hot vector marking the closest sample.

    Returns:
        Tensor with one row per input and one column per reference sample.
    """
    X = tf.convert_to_tensor(X)
    samples = tf.cast(samples, X.dtype)

    # Broadcast (rows, 1, width) against (1, samples, width) to get all
    # pairwise differences at once. The absolute value keeps the distance
    # well defined for odd p. Note this holds a rows x samples x width
    # temporary in memory.
    diff = tf.expand_dims(X, axis=1) - tf.expand_dims(samples, axis=0)
    m_dis = tf.math.reduce_sum(tf.abs(diff) ** p, axis=-1) ** (1 / p)

    if soft:
        # Normalise across the samples (axis 1), the same axis the hard
        # branch takes its argmax over, so a row's encoding does not depend
        # on which other rows share its batch.
        feature_map = tf.nn.softmax(-m_dis, axis=1)
    else:
        feature_map = tf.one_hot(tf.math.argmax(-m_dis, axis=1), samples.shape[0])
    return feature_map


class IsolationEncodingLayer(layers.Layer):
    """Layer that encodes inputs relative to a fixed set of reference samples.

    The layer has no trainable weights; ``call`` delegates to ``_tf_ann``.

    Args:
        samples: 2-D array-like of reference rows (numpy array, tensor or
            nested list).
        p: order of the distance passed to ``_tf_ann``.
        soft: whether ``_tf_ann`` returns a soft or a one-hot assignment.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, samples, p=2, soft=True, **kwargs):
        super().__init__(**kwargs)
        # Normalise to a tensor so the layer works the same whether it was
        # built from an array or from the nested list stored by get_config().
        self.samples = tf.convert_to_tensor(samples)
        self.p = p
        self.soft = soft

    def call(self, inputs):
        return _tf_ann(inputs, self.samples, self.p, self.soft)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Plain nested list: serialisable and accepted by __init__.
                "samples": self.samples.numpy().tolist(),
                "p": self.p,
                "soft": self.soft,
            }
        )
        return config

In [17]:
def build_model(t_samples, p=2, soft=True, n_outputs=10):
    """Build and compile the isolation-encoding classifier.

    One ``IsolationEncodingLayer`` is created per entry of ``t_samples``;
    their outputs are concatenated and fed to a single linear layer trained
    with hinge loss.

    Args:
        t_samples: non-empty sequence of 2-D arrays of reference samples; the
            input width is taken from the first entry.
        p: distance order forwarded to every encoding layer.
        soft: soft vs. one-hot encoding, forwarded to every encoding layer.
        n_outputs: number of units in the output layer (defaults to 10).

    Returns:
        The compiled ``keras.Model``.

    Raises:
        ValueError: if ``t_samples`` is empty.
    """
    t = len(t_samples)
    if t <= 0:
        raise ValueError("t <= 0")
    _, dims = t_samples[0].shape

    inputs = keras.Input(name="inputs_x", shape=(dims,))
    encodings = [
        IsolationEncodingLayer(samples, p=p, soft=soft, name="ann_{}".format(i))(
            inputs
        )
        for i, samples in enumerate(t_samples)
    ]
    concatenated = layers.Concatenate(axis=1, name="concatenated")(encodings)
    outputs = layers.Dense(units=n_outputs, name="outputs_y")(concatenated)

    model = keras.Model(name="isolation_encoding", inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.hinge,
        metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
    )
    return model

In [14]:
# Draw 500 reference subsets of 16 training images each; the two 16x500
# models below are both built from this same t_samples.
t_samples = gen_samples(psi=16, t=500)


In [18]:
# One-hot (hard) encoding on the current t_samples.
model_hard_16_500 = build_model(t_samples, soft=False)
modeldir = "./logs/mnist2/hard-16x500-" + datetime.now().strftime("%Y%m%d-%H%M%S")
model_hard_16_500.fit(
    ds_train,
    epochs=epochs,
    # ds_train / ds_test are already batched tf.data pipelines, so no
    # batch_size is passed to fit().
    validation_data=ds_test,
    callbacks=[
        keras.callbacks.TensorBoard(log_dir=modeldir + "/log", histogram_freq=1)
    ],
)
model_hard_16_500.save(modeldir + "/model")

Epoch 1/20

In [None]:
# Softmax (soft) encoding on the current t_samples.
model_soft_16_500 = build_model(t_samples, soft=True)
modeldir = "./logs/mnist2/soft-16x500-" + datetime.now().strftime("%Y%m%d-%H%M%S")
model_soft_16_500.fit(
    ds_train,
    epochs=epochs,
    # ds_train / ds_test are already batched tf.data pipelines, so no
    # batch_size is passed to fit().
    validation_data=ds_test,
    callbacks=[
        keras.callbacks.TensorBoard(log_dir=modeldir + "/log", histogram_freq=1)
    ],
)
model_soft_16_500.save(modeldir + "/model")

In [None]:
# Draw 50 reference subsets of 160 training images each. This rebinds
# t_samples, so the model cells below must be run after this cell.
t_samples = gen_samples(psi=160, t=50)

In [None]:
# One-hot (hard) encoding on the current t_samples.
model_hard_160_50 = build_model(t_samples, soft=False)
modeldir = "./logs/mnist2/hard-160x50-" + datetime.now().strftime("%Y%m%d-%H%M%S")
model_hard_160_50.fit(
    ds_train,
    epochs=epochs,
    # ds_train / ds_test are already batched tf.data pipelines, so no
    # batch_size is passed to fit().
    validation_data=ds_test,
    callbacks=[
        keras.callbacks.TensorBoard(log_dir=modeldir + "/log", histogram_freq=1)
    ],
)
model_hard_160_50.save(modeldir + "/model")

In [None]:
# Softmax (soft) encoding on the current t_samples.
model_soft_160_50 = build_model(t_samples, soft=True)
modeldir = "./logs/mnist2/soft-160x50-" + datetime.now().strftime("%Y%m%d-%H%M%S")
model_soft_160_50.fit(
    ds_train,
    epochs=epochs,
    # ds_train / ds_test are already batched tf.data pipelines, so no
    # batch_size is passed to fit().
    validation_data=ds_test,
    callbacks=[
        keras.callbacks.TensorBoard(log_dir=modeldir + "/log", histogram_freq=1)
    ],
)
model_soft_160_50.save(modeldir + "/model")