### Boundary extension and contraction

This notebook contains code for exploring boundary extension and contraction in a VAE trained on the shapes3d dataset.

Tested with tensorflow 2.11.0 and Python 3.10.9.

#### Installation:

In [None]:
!pip install -r requirements.txt --upgrade

#### Imports:

In [None]:
import json
import zipfile
import os
import psutil
import numpy as np
import pandas as pd
import numpy as np
from tensorflow import keras
from PIL import Image, ImageOps
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Sequential, metrics, optimizers, layers
from tensorflow.python.framework.ops import disable_eager_execution
from utils import load_tfds_dataset
from generative_model import encoder_network_large, decoder_network_large, VAE
# Fixed global seed so that repeated runs of the notebook are reproducible.
tf.keras.utils.set_random_seed(123)

In [None]:
train_ds, test_ds, train_labels, test_labels = load_tfds_dataset(
    'shapes3d', labels=True, key_dict={'shapes3d': 'label_scale'}
)

# Rescale pixel intensities (presumably 0-255 as loaded) to [0, 1].
train_ds, test_ds = train_ds / 255, test_ds / 255

#### Filter to just objects of mean size

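We can't use the object-scale attribute (`label_scale`) directly to simulate zooming. Changing it is not equivalent to a 'close-up' or 'far-away' view, because the background stays at the same scale. Instead, we keep only objects of mean size (scale labels 4 and 5) and simulate zooming by cropping or padding the whole image.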

In [None]:
# Keep only test images whose scale label (`label_scale`) is 4 or 5.
# NOTE(review): the experiments below index `train_ds`, not this filtered
# `test_ds` -- confirm which split is intended.
MEAN_SIZE_LABELS = [4, 5]

# This cell overwrites test_ds, so running it a second time would pair the
# already-filtered images with the full label list and keep the wrong images.
if len(test_ds) != len(test_labels):
    raise RuntimeError(
        "test_ds and test_labels have different lengths; re-run the data-loading "
        "cell before filtering again."
    )
test_ds = np.asarray(test_ds)[np.isin(np.ravel(test_labels), MEAN_SIZE_LABELS)]
test_ds.shape

#### Load the trained VAE:

In [None]:
K.set_image_data_format('channels_last')

# Architecture hyper-parameters (presumably those the checkpoints were trained with).
input_shape, latent_dim = (64, 64, 3), 20

# Rebuild both networks, then restore the trained weights into them.
encoder, z_mean, z_log_var = encoder_network_large(input_shape, latent_dim)
decoder = decoder_network_large(latent_dim)
encoder.load_weights("model_weights/shapes3d_encoder.h5")
decoder.load_weights("model_weights/shapes3d_decoder.h5")

#### Test boundary extension / contraction:

In [None]:
def add_noise(im_as_array, sigma=30):
    """Return a copy of an image with additive Gaussian pixel noise.

    The image is first quantised to 8-bit: it is scaled by 255 and truncated
    with ``astype(np.uint8)``, as in the other helpers. Noise with standard
    deviation ``sigma`` (in 0-255 units) is then added, and the result is
    clipped to [0, 255] and rescaled to [0, 1].

    Args:
        im_as_array: image array with values assumed to lie in [0, 1]. Any
            shape is accepted; the noise matches the array's own shape, so
            non-square and single-channel images work.
        sigma: noise standard deviation on the 0-255 scale (default 30).

    Returns:
        float64 array of the same shape with values in [0, 1].
    """
    # Take the noise shape from the array itself. PIL's Image.size is
    # (width, height), which mismatches the (height, width, channels) array
    # layout for non-square images.
    pixels = (np.asarray(im_as_array) * 255).astype(np.uint8).astype(np.float64)
    noise = np.random.normal(0, sigma, pixels.shape)
    return np.clip(pixels + noise, 0, 255) / 255

def display_recalled(x_test_new, decoded_imgs, n=10):
    """Draw a 2 x n grid: inputs on the top row, their reconstructions below.

    Args:
        x_test_new: sequence of at least n images shown in the top row.
        decoded_imgs: sequence of at least n images shown in the bottom row,
            aligned index-by-index with x_test_new.
        n: number of columns (input/reconstruction pairs) to draw.
    """
    plt.figure(figsize=(n * 2, 4))
    for col in range(n):
        # Row 0 holds the input, row 1 the reconstruction, in the same column.
        for row, images in enumerate((x_test_new, decoded_imgs)):
            ax = plt.subplot(2, n, row * n + col + 1)
            plt.imshow(images[col])
            plt.gray()
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
    plt.show()

# Encoder truncated at the layer named 'mean' (presumably z_mean), so encoding is
# deterministic rather than sampled.
code = Model(inputs=encoder.input, outputs=encoder.get_layer('mean').output)

# Recall the first 20 training images after adding Gaussian noise.
x_test_new = np.stack([add_noise(img) for img in train_ds[:20]])
encoded_imgs = code.predict(x_test_new)
decoded_imgs = decoder.predict(encoded_imgs)

display_recalled(x_test_new, decoded_imgs)

In [None]:
def remove_border(im_as_array, border_width=5):
    """Crop ``border_width`` pixels from every side and resize back to 64x64.

    This is the 'zoomed in' (boundary contraction) stimulus.

    Args:
        im_as_array: image array, assumed to hold values in [0, 1].
        border_width: pixels removed from each side.

    Returns:
        Array of the resized crop, divided by 255 (i.e. back in [0, 1]).
    """
    pil_img = Image.fromarray((im_as_array * 255).astype(np.uint8))
    cropped = ImageOps.crop(pil_img, border=border_width)
    return np.array(cropped.resize((64, 64))) / 255

# Crop each of the first 20 training images (no noise) and recall them.
x_test_new = np.stack([remove_border(img) for img in train_ds[:20]])
encoded_imgs = code.predict(x_test_new)
decoded_imgs = decoder.predict(encoded_imgs)

display_recalled(x_test_new, decoded_imgs)

In [None]:
def add_border(img, border_width=5):
    """Pad every side with replicated edge pixels and resize back to 64x64.

    This is the 'zoomed out' (boundary extension) stimulus.

    Args:
        img: image array, assumed (H, W, 3) with values in [0, 1].
        border_width: pixels added to each side. Non-integer widths are
            rounded to the nearest pixel, because ``np.pad`` only accepts
            integer widths. This lets the same fractional margins that
            ``remove_border`` accepts be used here.

    Returns:
        Array of the resized padded image, divided by 255 (i.e. in [0, 1]).
    """
    width = int(round(border_width))
    padded = np.pad(img * 255,
                    pad_width=((width, width), (width, width), (0, 0)),
                    mode='edge')
    resized = Image.fromarray(padded.astype(np.uint8)).resize((64, 64))
    return np.array(resized) / 255

# Edge-pad each of the first 20 training images (no noise) and recall them.
x_test_new = np.stack([add_border(img) for img in train_ds[:20]])
encoded_imgs = code.predict(x_test_new)
decoded_imgs = decoder.predict(encoded_imgs)

display_recalled(x_test_new, decoded_imgs)

#### Plot boundary extension / contraction effects and prediction error

In [None]:
border_const = 5

def plot_zoom_rows(ind):
    """Plot VAE recall of ``train_ds[ind]`` across a sweep of margin changes.

    For margins 0, border_const, ..., 4 * border_const pixels, the image is
    cropped (``remove_border``) or edge-padded (``add_border``). It is then
    noised with ``add_noise`` and passed through ``code`` and ``decoder``.

    Two figures are drawn:
    - the noisy inputs above their reconstructions, ordered from most padded
      to most cropped;
    - a bar chart of each reconstruction's error: the mean absolute error over
      channels, summed over the image, against the noise-free transformed input.
    """
    image = train_ds[ind]
    margins = [border_const * step for step in range(5)]

    def _recall(transform):
        # Build the noise-free targets first. The transforms are deterministic,
        # so the add_noise draws below still happen once per margin, in order.
        clean = np.array([transform(image, border_width=m) for m in margins])
        noisy = np.array([add_noise(c) for c in clean])
        decoded = decoder.predict(code.predict(noisy))
        errors = tf.reduce_sum(keras.losses.mean_absolute_error(clean, decoded), axis=(1, 2))
        return noisy, decoded, errors.numpy().tolist()

    noisy_cut, decoded_cut, err_cut = _recall(remove_border)
    noisy_pad, decoded_pad, err_pad = _recall(add_border)

    # Most padding first, down to margin 0 (from the padding run). Then
    # increasing crops; the crop run's margin-0 entry is dropped.
    display_recalled(noisy_pad[::-1].tolist() + noisy_cut.tolist()[1:],
                     decoded_pad[::-1].tolist() + decoded_cut.tolist()[1:], n=9)

    bar_labels = [str(m) for m in margins[::-1]] + [str(-m) for m in margins[1:]]
    plt.figure(figsize=(18, 2))
    plt.bar(bar_labels, err_pad[::-1] + err_cut[1:])
    plt.xlabel('Margin change (in pixels)', size=12)
    plt.ylabel('Prediction error', size=12)
    plt.show()

In [None]:
# Margin sweep (reconstructions + error bars) for the first 20 images.
for i in range(20):
    plot_zoom_rows(i)

#### Plots for paper

We want to plot a 'zoomed-out' and a 'zoomed-in' view of each image. The zoomed-out view shrinks the central object to 0.8× its original width, and the zoomed-in view enlarges it to 1.2×.

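Consider an image of width $w = 64$ px. Add a margin of $m$ px to each side (a negative $m$ means cropping) and resize the result back to $w$. The object's size then changes by

$$r = \frac{\text{shape width after}}{\text{shape width before}} = \frac{w}{w + 2m}.$$

Solving for the margin gives $m = \frac{w}{2}\left(\frac{1}{r} - 1\right) = \frac{32}{r} - 32$.
So the margin for a ratio of 0.8 is 8, and for a ratio of 1.2 it is −5.33 (i.e. a 5.33 px crop).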

In [None]:
border_const = 5  # NOTE(review): not read by the function below; kept so notebook state is unchanged.

def plot_zoom_rows(ind, zoom_out_margin=8, zoom_in_margin=5.33):
    """Paper figure: zoomed-out, unchanged and zoomed-in views of one image, with VAE recall.

    NOTE: this redefines ``plot_zoom_rows`` from the earlier cell. Calls made
    after this cell runs use this version.

    Args:
        ind: index into ``train_ds``.
        zoom_out_margin: pixels of edge padding added per side for the
            zoomed-out view. The default 8 corresponds to an object-size
            ratio of 0.8.
        zoom_in_margin: pixels cropped per side for the zoomed-in view. The
            default 5.33 corresponds to an object-size ratio of about 1.2.
    """
    image = train_ds[ind]

    # Margins 0 and zoom_in_margin. The margin-0 crop is decoded but not shown.
    # It is still generated so the add_noise draws (and hence the figure) match
    # the earlier hard-coded version of this cell.
    x_test_new_remove = np.array(
        [add_noise(remove_border(image, border_width=zoom_in_margin * i)) for i in range(2)])
    decoded_imgs_remove_border = decoder.predict(code.predict(x_test_new_remove))

    # Margins 0 (the unchanged image) and zoom_out_margin.
    x_test_new_add = np.array(
        [add_noise(add_border(image, border_width=zoom_out_margin * i)) for i in range(2)])
    decoded_imgs_add_border = decoder.predict(code.predict(x_test_new_add))

    # Columns: zoomed out, unchanged, zoomed in.
    display_recalled(x_test_new_add[::-1].tolist() + x_test_new_remove.tolist()[1:],
                     decoded_imgs_add_border[::-1].tolist() + decoded_imgs_remove_border.tolist()[1:],
                     n=3)


In [None]:
# Paper figures (zoomed out / unchanged / zoomed in) for the first 50 images.
for i in range(50):
    plot_zoom_rows(i)