In [1]:
from datetime import datetime
import tensorflow as tf
from tensorflow import keras
from keras import layers
import tensorflow_datasets as tfds

# Experiment configuration.
epochs = 30  # number of training epochs passed to autoencoder.fit
batch_size = 64  # batch size for the commented-out MNIST pipeline below
image_w = image_h = 32  # resize target for the commented-out MNIST pipeline below


In [2]:
def soft_greater(x, y=0, resolution=2**32):
    """Smooth indicator of ``x > y`` built from tanh.

    Returns ``(1 + tanh((x - y) * resolution)) / 2``: values tend to 1 when
    ``x > y``, to 0 when ``x < y``, and equal 0.5 when ``x == y``.

    NOTE(review): with the default ``resolution=2**32`` the tanh saturates for
    any |x - y| above roughly 1e-8. The result is then effectively a hard step
    whose gradient is ~0, which limits how well gradients flow through
    ``soft_encode``. Confirm this resolution is intended.
    """
    scaled_gap = (x - y) * resolution
    return 0.5 * (1.0 + tf.math.tanh(scaled_gap))

def hard_encode(
    x,  # n, dims
    l_lower,  # psi, dims
    l_upper,  # psi, dims
):
    """Box-membership indicator for every (point, box) pair.

    A point belongs to box k when ``l_lower[k] <= x < l_upper[k]`` holds in
    every dimension (boxes are half-open).

    Args:
        x: (n, dims) points.
        l_lower: (psi, dims) inclusive lower box corners.
        l_upper: (psi, dims) exclusive upper box corners.

    Returns:
        (n, psi) float32 tensor holding 1.0 for membership and 0.0 otherwise.
    """
    points = tf.expand_dims(x, axis=1)  # (n, 1, dims)
    at_or_above_lower = tf.math.greater_equal(points, tf.expand_dims(l_lower, axis=0))
    below_upper = tf.math.greater(tf.expand_dims(l_upper, axis=0), points)
    # A point is inside a box only if both conditions hold along every dimension.
    inside_all_dims = tf.math.reduce_all(
        tf.math.logical_and(at_or_above_lower, below_upper), axis=2
    )
    return tf.cast(inside_all_dims, dtype=tf.float32)

def hard_decode(
    x_encoded,  # n, psi
    l_lower,  # psi, dims
    l_upper,  # psi, dims
):
    """Decode box-membership codes back to points.

    For every sample and dimension, select the largest weighted lower-bound
    offset and the largest weighted upper-bound offset over the boxes, convert
    them back to coordinates, and return their midpoint. For 0/1 codes (as
    produced by ``hard_encode``), this picks the tightest lower bound (largest
    ``l_lower``) and tightest upper bound (smallest ``l_upper``) among the
    active boxes.

    Args:
        x_encoded: (n, psi) box activations.
        l_lower: (psi, dims) lower box corners.
        l_upper: (psi, dims) upper box corners.

    Returns:
        (n, dims) tensor of decoded points. A sample with no active box decodes
        to (approximately) the midpoint of [min(l_lower), max(l_upper)].
    """
    eps = tf.keras.backend.epsilon()
    g_lower = tf.math.reduce_min(l_lower, axis=0, keepdims=True)  # (1, dims)
    g_upper = tf.math.reduce_max(l_upper, axis=0, keepdims=True)  # (1, dims)
    activations = tf.expand_dims(x_encoded, axis=2)  # (n, psi, 1)

    # Offsets are shifted to be >= eps, so an active box always beats an
    # inactive one (whose weighted offset is exactly 0). reduce_max returns the
    # same value the former argmax/one_hot gather selected, without needing
    # the static psi dimension of x_encoded.
    lower_offsets = tf.expand_dims(tf.math.add(l_lower, eps) - g_lower, axis=0)
    x_lower = tf.math.reduce_max(activations * lower_offsets, axis=1) - eps + g_lower

    upper_offsets = tf.expand_dims(tf.math.subtract(eps, l_upper) + g_upper, axis=0)
    x_upper = eps + g_upper - tf.math.reduce_max(activations * upper_offsets, axis=1)

    return (x_lower + x_upper) / 2

def soft_encode(
    x,  # n, dims
    l_lower,  # psi, dims
    l_upper,  # psi, dims
):
    """Smooth box-membership scores, the ``soft_greater`` analogue of ``hard_encode``.

    Args:
        x: (n, dims) points.
        l_lower: (psi, dims) lower box corners.
        l_upper: (psi, dims) upper box corners.

    Returns:
        (n, psi) tensor: the product over dimensions of
        ``soft_greater(x, l_lower)`` times the product over dimensions of
        ``soft_greater(l_upper, x)``.
    """
    points = tf.expand_dims(x, axis=1)  # (n, 1, dims)
    lows = tf.expand_dims(l_lower, axis=0)  # (1, psi, dims)
    highs = tf.expand_dims(l_upper, axis=0)  # (1, psi, dims)

    above_lower = tf.math.reduce_prod(soft_greater(points, lows), axis=2)
    below_upper = tf.math.reduce_prod(soft_greater(highs, points), axis=2)
    return above_lower * below_upper

def soft_decode(
    x_encoded,  # n, psi
    l_lower,  # psi, dims
    l_upper,  # psi, dims
    temperature=1.0,
):
    """Differentiable counterpart of ``hard_decode``.

    The weighted lower/upper offsets are built exactly as in ``hard_decode``.
    Instead of taking the maximum over the box axis, each candidate is weighted
    by a softmax over that axis and the weighted candidates are summed. The
    result is the midpoint of the soft lower and upper bounds.

    Args:
        x_encoded: (n, psi) box activations, e.g. the output of ``soft_encode``.
        l_lower: (psi, dims) lower box corners.
        l_upper: (psi, dims) upper box corners.
        temperature: softmax temperature, must be > 0. The default 1.0
            reproduces the original behaviour. The softmax logits are raw
            offsets within the global range, and inactive boxes contribute
            logits of 0, so at 1.0 the weighting can be far from the argmax
            used by ``hard_decode``. Smaller temperatures sharpen it towards
            that argmax.

    Returns:
        (n, dims) tensor of decoded points.

    Raises:
        ValueError: if ``temperature`` is a Python number <= 0.
    """
    if isinstance(temperature, (int, float)) and temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    eps = tf.keras.backend.epsilon()
    g_lower = tf.math.reduce_min(l_lower, axis=0, keepdims=True)  # (1, dims)
    g_upper = tf.math.reduce_max(l_upper, axis=0, keepdims=True)  # (1, dims)
    activations = tf.expand_dims(x_encoded, axis=2)  # (n, psi, 1)

    # Candidate offsets are >= eps for every box, scaled by the box activation.
    x_lower = activations * tf.expand_dims(tf.math.add(l_lower, eps) - g_lower, axis=0)
    x_lower = (
        tf.math.reduce_sum(
            tf.math.softmax(x_lower / temperature, axis=1) * x_lower,
            axis=1,
        )
        - eps
        + g_lower
    )

    x_upper = activations * tf.expand_dims(tf.math.subtract(eps, l_upper) + g_upper, axis=0)
    x_upper = (
        eps
        + g_upper
        - tf.math.reduce_sum(
            tf.math.softmax(x_upper / temperature, axis=1) * x_upper,
            axis=1,
        )
    )
    return (x_lower + x_upper) / 2

# Four 1-D points and three disjoint half-open boxes:
# [-0.3, 0.0), [-0.6, -0.4) and [-0.9, -0.7).
x = [[-0.1], [-0.2], [-0.5], [-0.8]]
l_lower = [[-0.3], [-0.6], [-0.9]]
l_upper = [[0.0], [-0.4], [-0.7]]

print(x)

# Round-trip the points through the hard and then the soft codec.
for encode, decode in ((hard_encode, hard_decode), (soft_encode, soft_decode)):
    x_encoded = encode(x, l_lower, l_upper)
    print(x_encoded)
    print(decode(x_encoded, l_lower, l_upper))


[[-0.1], [-0.2], [-0.5], [-0.8]]
tf.Tensor(
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[-0.15000004]
 [-0.15000004]
 [-0.5       ]
 [-0.79999995]], shape=(4, 1), dtype=float32)
tf.Tensor(
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]], shape=(4, 3), dtype=float32)
tf.Tensor(
[[-0.306981 ]
 [-0.306981 ]
 [-0.4750027]
 [-0.6255996]], shape=(4, 1), dtype=float32)


In [3]:
class FloatingBoxesEncoder(layers.Layer):
    """Encode points by their membership in ``psi`` trainable axis-aligned boxes.

    Box corners are expressed relative to the global range
    [``lower_boundary``, ``upper_boundary``] through two trainable
    (psi, dims) weights clipped to [0, 1] (see ``get_box_boundaries``).

    Args:
        psi: number of boxes, i.e. the width of the encoding.
        lower_boundary: rank-2 global lower bound whose second axis is the input
            dimensionality (the notebook passes shape (1, dims)).
        upper_boundary: rank-2 global upper bound with the same second axis.
        soft: use ``soft_encode`` if True, else ``hard_encode``.
        **kwargs: forwarded to ``layers.Layer``. A ``dims`` entry, as written by
            ``get_config``, is accepted and checked against the boundaries.

    Raises:
        ValueError: if a boundary is not rank 2, the boundaries disagree on the
            number of dimensions, or a ``dims`` kwarg does not match them.
    """

    def __init__(self, psi, lower_boundary, upper_boundary, soft=True, **kwargs):
        # get_config() writes "dims". Remove it before Layer.__init__ sees it,
        # because Layer rejects unknown keyword arguments and from_config()
        # would otherwise fail.
        config_dims = kwargs.pop("dims", None)
        super(FloatingBoxesEncoder, self).__init__(**kwargs)

        # A config that went through JSON may hand the boundaries back as lists.
        if isinstance(lower_boundary, (list, tuple)):
            lower_boundary = tf.constant(lower_boundary, dtype=tf.float32)
        if isinstance(upper_boundary, (list, tuple)):
            upper_boundary = tf.constant(upper_boundary, dtype=tf.float32)

        # Check rank first so bad input raises ValueError, not IndexError on shape[1].
        if len(lower_boundary.shape) != 2 or len(upper_boundary.shape) != 2:
            raise ValueError(
                "lower_boundary and upper_boundary must be rank 2, got shapes "
                f"{lower_boundary.shape} and {upper_boundary.shape}"
            )
        dims = lower_boundary.shape[1]
        if upper_boundary.shape[1] != dims:
            raise ValueError(
                f"lower_boundary has {dims} dims but upper_boundary has "
                f"{upper_boundary.shape[1]}"
            )
        if config_dims is not None and config_dims != dims:
            raise ValueError(
                f"dims={config_dims} does not match the boundaries' {dims} dims"
            )

        self.psi = psi
        self.dims = dims
        self.lower_boundary = lower_boundary
        self.upper_boundary = upper_boundary
        self.soft = soft

    def build(self, input_shape):
        """Create the trainable (psi, dims) ratio weights, initialised uniformly in [0, 1)."""
        self.lower_ratio = self.add_weight(
            name="lower_ratio",
            shape=(self.psi, self.dims),
            initializer=tf.random_uniform_initializer(minval=0, maxval=1),
            constraint=lambda x: tf.clip_by_value(x, 0, 1),
            trainable=True,
            dtype=tf.float32,
        )
        self.size_ratio = self.add_weight(
            name="size_ratio",
            shape=(self.psi, self.dims),
            initializer=tf.random_uniform_initializer(minval=0, maxval=1),
            constraint=lambda x: tf.clip_by_value(x, 0, 1),
            trainable=True,
            dtype=tf.float32,
        )
        super(FloatingBoxesEncoder, self).build(input_shape)

    def get_box_boundaries(self):
        """Return the current (lower_bounds, upper_bounds) box corners.

        ``lower_ratio`` places each lower corner between the global bounds, and
        ``size_ratio`` places each upper corner between that lower corner and
        the global upper bound.
        """
        lower_bounds = (
            self.upper_boundary - self.lower_boundary
        ) * self.lower_ratio + self.lower_boundary
        upper_bounds = (
            self.upper_boundary - lower_bounds
        ) * self.size_ratio + lower_bounds
        return lower_bounds, upper_bounds

    def call(self, inputs):
        """Encode ``inputs`` against the current boxes (soft or hard)."""
        lower_bounds, upper_bounds = self.get_box_boundaries()
        if self.soft:
            return soft_encode(inputs, lower_bounds, upper_bounds)
        else:
            return hard_encode(inputs, lower_bounds, upper_bounds)

    def get_config(self):
        """Layer config; boundaries are exported as NumPy arrays."""
        config = super().get_config()
        config.update(
            {
                "psi": self.psi,
                "dims": self.dims,
                "soft": self.soft,
                # convert_to_tensor also handles boundaries stored as NumPy
                # arrays (e.g. after from_config), which have no .numpy().
                "lower_boundary": tf.convert_to_tensor(self.lower_boundary).numpy(),
                "upper_boundary": tf.convert_to_tensor(self.upper_boundary).numpy(),
            }
        )
        return config


class FloatingBoxesDecoder(layers.Layer):
    """Decode box-membership codes back to points using fixed box corners.

    The corners are stored as given (not as weights), so this layer has no
    trainable parameters.

    Args:
        box_lower_boundaries: (psi, dims) lower box corners.
        box_upper_boundaries: (psi, dims) upper box corners; must be
            element-wise >= the lower corners.
        soft: use ``soft_decode`` if True, else ``hard_decode``.
        **kwargs: forwarded to ``layers.Layer``. ``psi``/``dims`` entries, as
            written by ``get_config``, are accepted and checked.

    Raises:
        ValueError: if a boundary is not rank 2, the two shapes differ, any
            lower corner exceeds its upper corner, or a ``psi``/``dims`` kwarg
            does not match the boundaries.
    """

    def __init__(self, box_lower_boundaries, box_upper_boundaries, soft=True, **kwargs):
        # get_config() writes "psi"/"dims". Remove them before Layer.__init__
        # sees them, because Layer rejects unknown keyword arguments and
        # from_config() would otherwise fail.
        config_psi = kwargs.pop("psi", None)
        config_dims = kwargs.pop("dims", None)
        super(FloatingBoxesDecoder, self).__init__(**kwargs)

        # A config that went through JSON may hand the boundaries back as lists.
        if isinstance(box_lower_boundaries, (list, tuple)):
            box_lower_boundaries = tf.constant(box_lower_boundaries, dtype=tf.float32)
        if isinstance(box_upper_boundaries, (list, tuple)):
            box_upper_boundaries = tf.constant(box_upper_boundaries, dtype=tf.float32)

        # Check rank before unpacking/indexing the shapes; otherwise a wrong
        # rank fails with an unpacking error or IndexError instead of this.
        if len(box_lower_boundaries.shape) != 2 or len(box_upper_boundaries.shape) != 2:
            raise ValueError(
                "box boundaries must be rank 2 (psi, dims), got shapes "
                f"{box_lower_boundaries.shape} and {box_upper_boundaries.shape}"
            )

        psi, dims = box_lower_boundaries.shape
        if (
            psi != box_upper_boundaries.shape[0]
            or dims != box_upper_boundaries.shape[1]
        ):
            raise ValueError(
                f"box_lower_boundaries shape {box_lower_boundaries.shape} does not "
                f"match box_upper_boundaries shape {box_upper_boundaries.shape}"
            )

        if tf.reduce_any(box_lower_boundaries > box_upper_boundaries):
            raise ValueError("every box lower boundary must be <= its upper boundary")

        if config_psi is not None and config_psi != psi:
            raise ValueError(f"psi={config_psi} does not match the boundaries' {psi} boxes")
        if config_dims is not None and config_dims != dims:
            raise ValueError(f"dims={config_dims} does not match the boundaries' {dims} dims")

        self.psi = psi
        self.dims = dims
        self.box_lower_boundaries = box_lower_boundaries
        self.box_upper_boundaries = box_upper_boundaries
        self.soft = soft

    def call(self, inputs):
        """Decode (n, psi) activations to (n, dims) points (soft or hard)."""
        if self.soft:
            return soft_decode(inputs, self.box_lower_boundaries, self.box_upper_boundaries)
        else:
            return hard_decode(inputs, self.box_lower_boundaries, self.box_upper_boundaries)

    def get_config(self):
        """Layer config; boundaries are exported as NumPy arrays."""
        config = super().get_config()
        config.update(
            {
                "psi": self.psi,
                "dims": self.dims,
                "soft": self.soft,
                # convert_to_tensor also handles boundaries stored as NumPy
                # arrays (e.g. after from_config), which have no .numpy().
                "box_lower_boundaries": tf.convert_to_tensor(self.box_lower_boundaries).numpy(),
                "box_upper_boundaries": tf.convert_to_tensor(self.box_upper_boundaries).numpy(),
            }
        )
        return config


In [14]:
def build_autoencoder(psi, lower, upper, soft=True, learning_rate=1e-10):
    """Build and compile a floating-boxes autoencoder (MSE loss, Adam).

    Args:
        psi: number of boxes in the encoder.
        lower: rank-2 global lower bound; ``lower.shape[1]`` is the input size.
        upper: rank-2 global upper bound with the same second axis as ``lower``.
        soft: use the soft encode/decode path when True.
        learning_rate: Adam learning rate. The default 1e-10 matches the value
            previously hard-coded here.

    Returns:
        (model, encoder, decoder): the compiled ``keras.Model`` and its layers.

    NOTE(review): the decoder is constructed from ``encoder.get_box_boundaries()``
    evaluated once, right after the encoder is built. It therefore appears to
    hold a snapshot of the encoder's initial box corners, and later updates to
    the encoder's weights are not reflected in decoding. Confirm this is
    intended.
    """
    inputs = keras.Input(name="inputs_x", shape=(lower.shape[1],))
    encoder = FloatingBoxesEncoder(psi, lower, upper, soft=soft)
    encoded = encoder(inputs)  # builds the encoder's weights
    box_lower_bounds, box_upper_bounds = encoder.get_box_boundaries()
    decoder = FloatingBoxesDecoder(box_lower_bounds, box_upper_bounds, soft=soft)
    outputs = decoder(encoded)

    model = keras.Model(
        name="floating_boxes_autoencoder", inputs=inputs, outputs=outputs
    )
    model.compile(
        loss="mean_squared_error",
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    )
    return model, encoder, decoder

# Toy 1-D problem: four points inside the global range [0, 1].
x = tf.constant([[0.1], [0.2], [0.5], [0.8]])
lower, upper = tf.constant([[0.0]]), tf.constant([[1.0]])
print(x)

autoencoder, encoder, decoder = build_autoencoder(100, lower, upper, soft=True)
print(autoencoder(x))  # reconstruction before training
autoencoder.fit(x, x, epochs=epochs, batch_size=2)
print(autoencoder(x))  # reconstruction after training

tf.Tensor(
[[0.1]
 [0.2]
 [0.5]
 [0.8]], shape=(4, 1), dtype=float32)
tf.Tensor(
[[0.48147973]
 [0.45885518]
 [0.5158565 ]
 [0.6249186 ]], shape=(4, 1), dtype=float32)
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
tf.Tensor(
[[0.48147973]
 [0.45885518]
 [0.5158565 ]
 [0.6249186 ]], shape=(4, 1), dtype=float32)


In [10]:
# Print the layer / parameter table of the autoencoder built in the previous cell.
autoencoder.summary()
# Optional (disabled): render the model graph to model_plot.png and display it.
# from keras.utils.vis_utils import plot_model
# plot_model(autoencoder, to_file='model_plot.png', show_shapes=True, show_layer_names=True)

# from matplotlib import pyplot as plt
# import matplotlib.image as mpimg
# plt.axis("off")
# plt.imshow(mpimg.imread('model_plot.png'))
# plt.show()


Model: "floating_boxes_autoencoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 inputs_x (InputLayer)       [(None, 1)]               0         
                                                                 
 floating_boxes_encoder_1 (F  (None, 100)              200       
 loatingBoxesEncoder)                                            
                                                                 
 floating_boxes_decoder_1 (F  (None, 1)                0         
 loatingBoxesDecoder)                                            
                                                                 
Total params: 200
Trainable params: 200
Non-trainable params: 0
_________________________________________________________________


In [7]:
# (ds_train_raw, ds_test_raw), ds_info = tfds.load(
#     "mnist",
#     split=["train", "test"],
#     shuffle_files=False,
#     as_supervised=True,
#     with_info=True,
# )

# n_classes = ds_info.features["label"].num_classes
# n = ds_info.splits["train"].num_examples


# def normalize_img(image, label):
#     image = tf.cast(image, tf.float32) / 255.0
#     image = layers.Resizing(image_h, image_w)(image)
#     image = tf.reshape(image, [-1])
#     label = tf.one_hot(tf.cast(label, tf.int32), n_classes)
#     label = tf.cast(label, tf.float32)
#     return image, label


# ds_train_normalized = ds_train_raw.map(
#     normalize_img, num_parallel_calls=tf.data.AUTOTUNE
# ).cache()

# ds_test_normalized = ds_test_raw.map(
#     normalize_img, num_parallel_calls=tf.data.AUTOTUNE
# ).cache()


# def prepare(ds, batch_size=batch_size):
#     return ds.shuffle(n).batch(batch_size).prefetch(tf.data.AUTOTUNE)


# dims = list(ds_train_normalized.take(1))[0][0].shape[0]

# print("n: ", n, "n_classes: ", n_classes, "dims: ", dims)


# def minmax_reducer(current, input):
#     X, _ = input
#     return (
#         tf.reduce_min([current[0], X], axis=0),
#         tf.reduce_max([current[1], X], axis=0),
#     )


# x0, _ = list(ds_train_normalized.take(1))[0]
# min_train, max_train = ds_train_normalized.reduce((x0, x0), minmax_reducer)


# lower = tf.expand_dims(min_train, axis=0)
# upper = tf.expand_dims(max_train, axis=0)

# autoencoder, lower_bounds, upper_bounds = build_autoencoder(10, lower, upper)
# ds = (prepare(ds_train_normalized),)
# autoencoder.fit(
#     ds,
#     ds,
#     epochs=epochs,
#     batch_size=batch_size,
# )
