In [None]:
import math

import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
def draw_images(images):
    """Display a batch of images side by side in a single row.

    Args:
        images: Batch indexed along its first axis. ``images.shape[0]`` is the
            number of images and every ``images[i]`` is handed to
            ``Axes.imshow``.

    Returns:
        None. The figure is shown with ``plt.show()``. An empty batch draws
        nothing.
    """
    images_count = images.shape[0]
    if images_count == 0:
        # plt.subplots cannot build a grid with zero columns; nothing to show.
        return
    # squeeze=False makes plt.subplots return a 2-D array of Axes even for a
    # single image; by default it returns a bare Axes, which has no `.flat`.
    _, axs = plt.subplots(
        1, images_count, figsize=(images_count * 2, 2), squeeze=False
    )
    for i, ax in enumerate(axs.flat):
        ax.axis("off")
        ax.imshow(images[i])
    plt.tight_layout()
    plt.show()

In [None]:
def generate_random_circles(count, scale=1.0, offset=(0.0, 0.0)):
    """Sample random circle descriptions.

    Args:
        count: Number of circles to draw.
        scale: Factor applied to the centre coordinates and to the radius.
        offset: ``(offset_x, offset_y)`` added to the centre coordinates
            before scaling.

    Returns:
        Tensor of shape ``(count, 6)`` whose last axis is
        ``[x, y, radius, r, g, b]``. Before scaling/offsetting, the centre is
        uniform in [-1, 1) and the radius and colour channels are uniform in
        [0, 1).
    """

    def uniform(low):
        # One random vector with `count` entries in [low, 1).
        return tf.random.uniform([count], minval=low, maxval=1.0)

    shift_x, shift_y = offset
    # The sampling order (x, y, radius, r, g, b) is kept fixed so that seeded
    # runs produce the same circles.
    center_x = (uniform(-1.0) + shift_x) * scale
    center_y = (uniform(-1.0) + shift_y) * scale
    circle_radius = uniform(0.0) * scale
    color_channels = [uniform(0.0) for _ in range(3)]
    return tf.stack([center_x, center_y, circle_radius] + color_channels, axis=-1)

In [None]:
def generate_circle_images(circles, image_size, scale=1.0, offset=(0.0, 0.0)):
    """Rasterize one filled circle per image.

    The rendered view spans ``(-1 + offset) * scale`` to ``(1 + offset) * scale``
    on each axis, sampled at ``image_size`` evenly spaced points.

    Args:
        circles: Rank-2 tensor with one circle per row. Columns 0, 1 and 2 are
            read as centre x, centre y and radius; the last three columns are
            read as the colour.
        image_size: Number of pixels along each image side.
        scale: Zoom factor of the view.
        offset: ``(offset_x, offset_y)`` shift of the view, applied before
            scaling.

    Returns:
        Tensor of shape ``(num_circles, image_size, image_size, 3)`` holding the
        circle colour for pixels within the radius and zero elsewhere.
    """
    shift_x, shift_y = offset
    axis_x = tf.linspace((-1.0 + shift_x) * scale, (1.0 + shift_x) * scale, image_size)
    axis_y = tf.linspace((-1.0 + shift_y) * scale, (1.0 + shift_y) * scale, image_size)
    grid_x, grid_y = tf.meshgrid(axis_x, axis_y)

    # Per-circle parameters are reshaped to (n, 1, 1) so they broadcast against
    # the (image_size, image_size) pixel grid.
    center_x = circles[:, 0][:, tf.newaxis, tf.newaxis]
    center_y = circles[:, 1][:, tf.newaxis, tf.newaxis]
    radius = circles[:, 2][:, tf.newaxis, tf.newaxis]
    rgb = circles[:, tf.newaxis, tf.newaxis, -3:]

    distance = tf.sqrt(tf.square(grid_x - center_x) + tf.square(grid_y - center_y))
    inside = tf.cast(distance <= radius, dtype=tf.float32)
    return inside[..., tf.newaxis] * rgb


# Quick visual check: sample 10 random circles, render each at 100x100 pixels
# and show them in one row.
draw_images(generate_circle_images(generate_random_circles(10), 100))

In [None]:
def generate_data(count, image_size, scale=1.0, offset=(0.0, 0.0)):
    """Create a per-pixel dataset of random circle images.

    Args:
        count: Number of circles (and therefore images) to generate.
        image_size: Number of pixels along each image side.
        scale: Zoom factor shared by the circle sampling and the rendered view.
        offset: ``(offset_x, offset_y)`` shift shared by the circle sampling and
            the rendered view.

    Returns:
        A tuple ``(inputs, outputs)``:

        * ``inputs`` has shape ``(count, image_size, image_size, 8)``; for every
          pixel it holds the pixel's x and y coordinate followed by the six
          circle parameters ``[x, y, radius, r, g, b]``.
        * ``outputs`` is the result of ``generate_circle_images`` for the same
          circles, view scale and offset.
    """
    offset_x, offset_y = offset
    circles = generate_random_circles(count, scale=scale, offset=offset)
    images = generate_circle_images(
        circles,
        image_size,
        scale=scale,
        offset=offset,
    )
    # Repeat each circle's parameters for every pixel of its image.
    latent = tf.tile(
        circles[:, tf.newaxis, tf.newaxis, :], [1, image_size, image_size, 1]
    )
    # Pixel coordinate grid. tf.linspace takes (start, stop, num); the previous
    # code passed image_size a second time, which landed in the `name` slot.
    x, y = tf.meshgrid(
        tf.linspace((-1.0 + offset_x) * scale, (1.0 + offset_x) * scale, image_size),
        tf.linspace((-1.0 + offset_y) * scale, (1.0 + offset_y) * scale, image_size),
    )
    x = tf.expand_dims(tf.tile(x[tf.newaxis, ...], [circles.shape[0], 1, 1]), axis=-1)
    y = tf.expand_dims(tf.tile(y[tf.newaxis, ...], [circles.shape[0], 1, 1]), axis=-1)

    inputs = tf.concat([x, y, latent], axis=-1)
    outputs = images

    return inputs, outputs


# Sanity check of generate_data: build 10 samples at 200x200 pixels, print the
# tensor shapes and show the target images.
inputs, outputs = generate_data(count=10, image_size=200)
print(inputs.shape)
print(outputs.shape)
draw_images(outputs)

In [None]:
def create_positional_decoder_model(hidden_layer_count=10, hidden_units=10):
    """Build and compile the per-pixel decoder network.

    The model takes inputs with 8 channels on the last axis and two leading
    axes of arbitrary size, passes them through a stack of dense blocks
    (``Dense`` followed by a learnable clipping activation), and ends with a
    3-unit ReLU ``Dense`` layer.

    Args:
        hidden_layer_count: Number of hidden dense blocks. Defaults to 10.
        hidden_units: Number of units in every hidden dense block. Defaults
            to 10.

    Returns:
        A ``tf.keras.Sequential`` model compiled with the Adam optimizer and
        mean squared error loss.
    """

    class CustomActivationLayer(tf.keras.layers.Layer):
        """Clipping activation with learnable per-channel bounds.

        Positive inputs are capped from above at ``upper_treshold``; all other
        inputs are capped from below at ``lower_treshold``.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def build(self, input_shape):
            # One trainable bound per channel of the last input axis.
            self.lower_treshold = self.add_weight(
                name="lower_treshold",
                shape=(input_shape[-1],),
                initializer=tf.keras.initializers.RandomNormal(mean=-1.0, stddev=0.5),
                trainable=True,
            )
            # Disabled experiment: learnable slope below the lower bound.
            # self.lower_slope = self.add_weight(
            #     name="lower_slope",
            #     shape=(input_shape[-1],),
            #     initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.1),
            #     trainable=True,
            # )
            self.upper_treshold = self.add_weight(
                name="upper_treshold",
                shape=(input_shape[-1],),
                initializer=tf.keras.initializers.RandomNormal(mean=1.0, stddev=0.5),
                trainable=True,
            )
            # Disabled experiment: learnable slope above the upper bound.
            # self.upper_slope = self.add_weight(
            #     name="upper_slope",
            #     shape=(input_shape[-1],),
            #     initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.1),
            #     trainable=True,
            # )
            super().build(input_shape)

        def call(self, inputs):
            return tf.where(
                inputs > 0,
                tf.minimum(inputs, self.upper_treshold),
                tf.maximum(inputs, self.lower_treshold),
            )
            # Disabled experiment: instead of hard clipping, continue past each
            # bound with a learnable slope.
            # return tf.where(
            #     inputs > 0,
            #     tf.where(
            #         inputs > self.upper_treshold,
            #         self.upper_treshold
            #         + (inputs - self.upper_treshold) * self.upper_slope,
            #         inputs,
            #     ),
            #     tf.where(
            #         inputs < self.lower_treshold,
            #         self.lower_treshold
            #         + (inputs - self.lower_treshold) * self.lower_slope,
            #         inputs,
            #     ),
            # )

    class CustomDenseLayer(tf.keras.layers.Layer):
        """A linear ``Dense`` layer followed by ``CustomActivationLayer``."""

        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.dense = tf.keras.layers.Dense(
                units,
                # kernel_regularizer=tf.keras.regularizers.L1L2(l1=0.001, l2=0.001),
            )
            self.activation = CustomActivationLayer()
            # self.dropout = tf.keras.layers.Dropout(0.2)

        def call(self, inputs):
            output = self.dense(inputs)
            output = self.activation(output)
            # output = self.dropout(output)
            return output

    # The hidden stack is generated instead of being written out entry by
    # entry; the defaults reproduce the original ten blocks of ten units.
    hidden_blocks = [CustomDenseLayer(hidden_units) for _ in range(hidden_layer_count)]

    model = tf.keras.Sequential(
        [
            tf.keras.layers.Input(
                shape=(
                    None,
                    None,
                    8,
                )
            ),
            *hidden_blocks,
            tf.keras.layers.Dense(3, activation="relu"),
        ]
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.mean_squared_error
    )
    return model


# Build a throwaway model instance just to print its layer/parameter summary.
create_positional_decoder_model().summary()

In [None]:
def train_until_improvement_treshold(model, fit, threshold=0.8, max_epochs=10):
    """Repeat training rounds while each round still improves the loss enough.

    Args:
        model: Object passed unchanged to ``fit`` on every round.
        fit: Callable ``fit(model)`` returning an object whose
            ``history["loss"]`` is a sequence; its last entry is taken as the
            loss of that round.
        threshold: A round counts as an improvement only if its loss is below
            ``threshold`` times the previously accepted loss.
        max_epochs: Upper bound on the number of ``fit`` calls.

    Returns:
        None.
    """
    rounds_left = max_epochs
    previous_loss = float("inf")
    while rounds_left > 0:
        rounds_left -= 1
        current_loss = fit(model).history["loss"][-1]
        # Written as `not (a < b)` on purpose: a NaN loss also stops training.
        if not (current_loss < previous_loss * threshold):
            break
        previous_loss = current_loss

In [None]:
def select_starting_model(model_factory, fit, starting_model_count=4):
    """Create several candidate models, train each once and keep the best.

    Args:
        model_factory: Zero-argument callable returning a fresh model.
        fit: Callable ``fit(model)`` returning an object whose
            ``history["loss"]`` is a sequence; its last entry is used as the
            candidate's loss.
        starting_model_count: Number of candidates to create and train.

    Returns:
        The candidate with the lowest final loss. Ties go to the earliest
        candidate. A NaN loss is ranked worst, so a diverged candidate is only
        returned when every candidate diverged.

    Raises:
        ValueError: If ``starting_model_count`` is smaller than 1.
    """
    if starting_model_count < 1:
        raise ValueError("starting_model_count must be at least 1")

    models = [model_factory() for _ in range(starting_model_count)]
    loss_by_model = [fit(model).history["loss"][-1] for model in models]

    def ranking_key(index):
        loss = loss_by_model[index]
        # NaN compares False against everything, so a plain min() would keep a
        # NaN in first position; rank it as +inf instead.
        return float("inf") if math.isnan(loss) else loss

    best_index = min(range(len(models)), key=ranking_key)
    return models[best_index]

In [None]:
# Training set: 1000 random circles rendered at 100x100 pixels with the default
# scale and offset. These globals are read by `fit` and the visual-test cells.
train_inputs, train_outputs = generate_data(count=1000, image_size=100)


def fit(model):
    """Train ``model`` for one epoch on the global training set.

    Uses ``train_inputs`` / ``train_outputs`` with a batch size of 1 and
    returns whatever ``model.fit`` returns.
    """
    training_history = model.fit(
        train_inputs,
        train_outputs,
        epochs=1,
        batch_size=1,
    )
    return training_history


# Stage 1: train several freshly created models for one epoch each and keep
# the one with the lowest loss.
print("find starting model")
model = select_starting_model(create_positional_decoder_model, fit)
# Stage 2: keep training the selected model one epoch at a time until an epoch
# no longer improves the loss by the default threshold (or the epoch cap hits).
print("train until improvement")
train_until_improvement_treshold(model, fit)

In [None]:
# visual test: train data
# First row: target images of the first `page_size` training samples.
# Second row: the model's predictions for the same samples.
page_size = 10  # number of images per row; reused by the following test cells
draw_images(train_outputs[0:page_size])
draw_images(model.predict(train_inputs[0:page_size]))

In [None]:
# visual test: test data
# Freshly sampled circles at the training resolution (100x100):
# targets in the first row, predictions in the second.
test_inputs, test_outputs = generate_data(count=page_size, image_size=100)
draw_images(test_outputs)
draw_images(model.predict(test_inputs))

In [None]:
# visual test: test data, higher resolution
# Same view as training (scale=1.0) but sampled at 400x400 instead of 100x100:
# targets in the first row, predictions in the second.
test_inputs, test_outputs = generate_data(count=page_size, image_size=400, scale=1.0)
draw_images(test_outputs)
draw_images(model.predict(test_inputs))

In [None]:
# visual test: test data, higher scale
# Circles and view are both generated with scale=4.0 (training used the default
# scale of 1.0): targets in the first row, predictions in the second.
test_inputs, test_outputs = generate_data(count=page_size, image_size=100, scale=4.0)
draw_images(test_outputs)
draw_images(model.predict(test_inputs))

In [None]:
# visual test: test data, lower scale
# Circles and view are both generated with scale=0.2 (training used the default
# scale of 1.0): targets in the first row, predictions in the second.
test_inputs, test_outputs = generate_data(count=page_size, image_size=100, scale=0.2)
draw_images(test_outputs)
draw_images(model.predict(test_inputs))

In [None]:
# visual test: test data, different offsets
# For each offset, circles and view are shifted together (training used the
# default offset of (0.0, 0.0)). The list covers all four diagonal directions
# at three magnitudes: 0.3, 2 and 10.
for offset in [
    (0.3, 0.3),
    (-0.3, -0.3),
    (0.3, -0.3),
    (-0.3, 0.3),
    (2, 2),
    (-2, -2),
    (2, -2),
    (-2, 2),
    (10, 10),
    (-10, -10),
    (10, -10),
    (-10, 10),
]:
    offset_x, offset_y = offset
    print(f"offset: {offset_x}, {offset_y}")
    test_inputs, test_outputs = generate_data(
        count=page_size, image_size=100, scale=1, offset=offset
    )
    # Targets first, then the model's predictions for the same inputs.
    draw_images(test_outputs)
    draw_images(model.predict(test_inputs))