### Boundary extension and contraction

This notebook contains code for exploring boundary extension and contraction in a VAE trained on the shapes3d dataset.

Tested with tensorflow 2.11.0 and Python 3.10.9.

#### Installation:

In [None]:
!pip install -r requirements.txt --upgrade

#### Imports:

In [None]:
import json
import zipfile
import os
import psutil
import numpy as np
import pandas as pd
import numpy as np
import cv2
from tensorflow import keras
from PIL import Image, ImageOps
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Sequential, metrics, optimizers, layers
from tensorflow.python.framework.ops import disable_eager_execution
from utils import load_tfds_dataset
from generative_model import encoder_network_large, decoder_network_large, VAE
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
import numpy as np

# Seed the framework-level RNGs first (per the Keras docs this call seeds Python's
# `random`, NumPy and TensorFlow together).
tf.keras.utils.set_random_seed(123)
# Then re-seed NumPy's global RNG. Order matters: this call replaces whatever NumPy
# seed the line above installed, so NumPy-based randomness in this notebook runs from seed 0.
np.random.seed(0)

#### Load data

In [None]:
def load_tfds_dataset(dataset_name, max_examples=40000):
    """Load the leading examples of a TFDS train split into NumPy arrays.

    NOTE(review): this definition shadows the `load_tfds_dataset` imported from
    `utils` in the imports cell; after this cell runs, the name refers to this
    function.

    Args:
        dataset_name: Name handed to `tfds.load` (the notebook uses 'shapes3d').
            Every example must provide the keys 'image', 'label_scale',
            'label_shape', 'label_object_hue', 'label_floor_hue' and
            'label_wall_hue'.
        max_examples: Non-negative upper bound on the number of examples read,
            taken in dataset order. Defaults to 40000, the cap that used to be
            hard-coded. Pass None to read the whole split.

    Returns:
        A 6-tuple of arrays, index-aligned with each other:
        (images, label_scales, label_shapes, label_object_hues,
        label_floor_hues, label_wall_hues).
    """
    field_names = (
        'image',
        'label_scale',
        'label_shape',
        'label_object_hue',
        'label_floor_hue',
        'label_wall_hue',
    )

    # The dataset info object is returned by tfds but not needed here.
    ds, _ds_info = tfds.load(dataset_name, split='train', with_info=True, as_supervised=False)

    # Let tf.data stop the stream instead of counting examples by hand.
    if max_examples is not None:
        ds = ds.take(max_examples)

    columns = {name: [] for name in field_names}
    for item in tfds.as_numpy(ds):
        for name in field_names:
            columns[name].append(item[name])

    return tuple(np.array(columns[name]) for name in field_names)

# Fetch the images together with their per-image factor labels.
(
    images,
    label_scales,
    label_shapes,
    label_object_hues,
    label_floor_hues,
    label_wall_hues,
) = load_tfds_dataset('shapes3d')

# Rescale pixel values by 1/255 (floating-point result).
images = images / 255.0

# Hold out 10% of the examples; all six arrays are split with the same
# fixed-seed permutation so images and labels stay aligned.
(
    train_images, test_images,
    train_label_scales, test_label_scales,
    train_label_shapes, test_label_shapes,
    train_label_object_hues, test_label_object_hues,
    train_label_floor_hues, test_label_floor_hues,
    train_label_wall_hues, test_label_wall_hues,
) = train_test_split(
    images,
    label_scales,
    label_shapes,
    label_object_hues,
    label_floor_hues,
    label_wall_hues,
    test_size=0.1,
    random_state=42,
)

# Function to filter images based on shape and distinct colors
def filter_images(images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues):
    """Select images of shape class 0 whose object hue is distinct from floor and wall.

    All six sequences are walked in lockstep. `label_scales` takes part in that
    walk (so it bounds the iteration length) but is not used in the test itself.

    Returns:
        np.ndarray holding the kept images in their original order.
    """
    records = zip(images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues)
    kept = [
        img
        for img, _scale, shape, obj_hue, floor_hue, wall_hue in records
        # Drop anything that is not shape 0 or whose object blends into floor/wall.
        if shape == 0 and not (obj_hue == floor_hue or obj_hue == wall_hue)
    ]
    return np.array(kept)

# Run the same shape/hue filter over both splits.
filtered_train_images = filter_images(
    train_images,
    train_label_scales,
    train_label_shapes,
    train_label_object_hues,
    train_label_floor_hues,
    train_label_wall_hues,
)
filtered_test_images = filter_images(
    test_images,
    test_label_scales,
    test_label_shapes,
    test_label_object_hues,
    test_label_floor_hues,
    test_label_wall_hues,
)

# Report how many images survived the filter in each split.
print("Filtered Training Images Shape:", filtered_train_images.shape)
print("Filtered Testing Images Shape:", filtered_test_images.shape)

# The rest of the notebook works on the filtered test images under this name.
test_ds = filtered_test_images

In [None]:
# Shuffle a *copy* of the filtered test images after re-seeding NumPy with the
# notebook's seed (0). This makes the cell idempotent: re-running it reproduces
# the same order instead of shuffling already-shuffled data again, and
# `filtered_test_images` itself is no longer reordered through the alias.
np.random.seed(0)
test_ds = filtered_test_images.copy()
np.random.shuffle(test_ds)

#### Load the trained VAE:

In [None]:
# Tell Keras that image tensors keep the channel axis last.
K.set_image_data_format('channels_last')

# Size of the latent code and shape of one input image; both are passed to the
# network builders below. NOTE(review): presumably these must match the
# architecture the saved weights were trained with — confirm.
latent_dim = 20
input_shape = (64, 64, 3)

# Build the encoder and decoder architectures (imported from generative_model).
# The encoder builder also returns z_mean and z_log_var, which are not used
# elsewhere in the visible notebook.
encoder, z_mean, z_log_var = encoder_network_large(input_shape, latent_dim)
decoder = decoder_network_large(latent_dim)

# Restore pre-trained weights; paths are relative to the working directory.
encoder.load_weights("model_weights/shapes3d_encoder.h5")
decoder.load_weights("model_weights/shapes3d_decoder.h5")

#### Test boundary extension / contraction:

In [None]:
def remove_border(im_as_array, border_width=5):
    """Zoom in: crop `border_width` pixels off every edge and resize back to 64x64.

    Args:
        im_as_array: Image array with values in [0, 1].
        border_width: Number of pixels removed from each side; passed straight
            to `ImageOps.crop`.

    Returns:
        The cropped image resized to 64x64, as a float array scaled back by 1/255.
    """
    # Quantise to 8 bits by rounding, not truncation: a value such as 0.999999
    # caused by float error would otherwise lose a whole intensity level.
    # Clipping stops out-of-range inputs from wrapping around in uint8.
    pixels = np.clip(np.rint(np.asarray(im_as_array) * 255), 0, 255).astype(np.uint8)
    img = Image.fromarray(pixels)
    im_crop = ImageOps.crop(img, border=border_width)
    new_im = im_crop.resize((64,64))
    return np.array(new_im) / 255

def add_border(img, border_width=5):
    """Zoom out: pad every edge by replicating edge pixels, then resize back to 64x64.

    Args:
        img: Image array of shape (height, width, channels) with values in [0, 1].
        border_width: Number of pixels added on each side of the two spatial axes.

    Returns:
        The padded image resized to 64x64, as a float array scaled back by 1/255.
    """
    padded = np.pad(np.asarray(img) * 255,
                    pad_width=((border_width, border_width),
                               (border_width, border_width),
                               (0, 0)),
                    mode='edge')
    # Round before the uint8 cast (plain astype truncates, so float error can
    # cost a full intensity level) and clip so out-of-range values cannot wrap.
    padded = np.clip(np.rint(padded), 0, 255).astype(np.uint8)
    resized = Image.fromarray(padded).resize((64,64))
    return np.array(resized)/255

def add_noise(array, noise_factor=0.1, seed=None, gaussian=False, replacement_val=0):
    """Corrupt a random fraction of an array's elements.

    A fraction `noise_factor` of the individual array elements (single channel
    values, not whole pixels) is chosen without replacement and overwritten.
    The input array is not modified.

    Args:
        array: NumPy array to corrupt.
        noise_factor: Fraction of elements to overwrite.
        seed: Optional seed. When given, a private generator seeded with it is
            used, which yields the same draws as `np.random.seed(seed)` would,
            without resetting the global NumPy RNG that the rest of the
            notebook relies on. When None, the global RNG is used.
        gaussian: If truthy, overwrite with samples from N(0.5, 1.0); otherwise
            overwrite with `replacement_val`.
        replacement_val: Constant written when `gaussian` is falsy.

    Returns:
        A corrupted copy with the input's shape, clipped to [0, 1].
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    shape = array.shape
    flat = array.flatten()  # flatten() copies, so the caller's array stays intact
    n_replace = int(flat.size * noise_factor)
    indices = rng.choice(np.arange(flat.size), replace=False, size=n_replace)

    # Truthiness test so that np.True_ or 1 also enable gaussian noise.
    if gaussian:
        flat[indices] = rng.normal(loc=0.5, scale=1.0, size=indices.shape)
    else:
        flat[indices] = replacement_val

    return np.clip(flat.reshape(shape), 0.0, 1.0)

def display_recalled(x_test_new, decoded_imgs, n=10):
    """Show the first `n` inputs (top row) above their reconstructions (bottom row).

    Args:
        x_test_new: Indexable collection of input images.
        decoded_imgs: Indexable collection of reconstructed images, aligned
            with `x_test_new`.
        n: Number of columns; both collections need at least `n` entries.
    """
    plt.figure(figsize=(n*2, 4))
    for col in range(n):
        # Row 0 holds the input, row 1 the reconstruction of the same column.
        for row, batch in enumerate((x_test_new, decoded_imgs)):
            panel = plt.subplot(2, n, row * n + col + 1)
            plt.imshow(batch[col])
            plt.gray()
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)
    plt.show()

# Sub-model mapping an input image to the output of the encoder layer named 'mean'.
# NOTE(review): presumably this is the latent mean, making the encoding
# deterministic — confirm against generative_model.
code = Model(encoder.input, encoder.get_layer('mean').output)

# Corrupt the first 20 test images with add_noise's default settings, then
# encode and decode them.
x_test_new = np.array([add_noise(image) for image in test_ds[0:20]])
encoded_imgs = code.predict(x_test_new)
decoded_imgs = decoder.predict(encoded_imgs)

#### Plots

In [None]:
# NOTE(review): `border_const` is not referenced anywhere else in the visible
# notebook; plot_zoom_rows hard-codes its own border widths (5.33 and 8).
border_const = 5

def plot_zoom_rows(ind):
    """Plot zoomed-out, original and zoomed-in versions of one test image with reconstructions.

    Args:
        ind: Index into `test_ds` of the image to display.

    The columns shown are: border of 8 added, unmodified, border of 5.33
    removed. Each input gets default `add_noise` corruption before it is
    encoded and decoded.
    """
    source = test_ds[ind]

    def reconstruct(batch):
        # Encode to the latent code, then decode back to image space.
        return decoder.predict(code.predict(batch))

    # Step 0 leaves the image unchanged; step 1 applies the border change.
    zoomed_in = np.array([add_noise(remove_border(source, border_width=5.33*step)) for step in range(2)])
    recon_in = reconstruct(zoomed_in)

    zoomed_out = np.array([add_noise(add_border(source, border_width=8*step)) for step in range(2)])
    recon_out = reconstruct(zoomed_out)

    # Reverse the zoom-out pair so the plot reads zoomed-out -> original, then
    # append only the zoomed-in image (its step-0 duplicate is skipped).
    inputs = zoomed_out[::-1].tolist() + zoomed_in.tolist()[1:]
    outputs = recon_out[::-1].tolist() + recon_in.tolist()[1:]
    display_recalled(inputs, outputs, n=3)

# Draw the zoom comparison for the first ten test images.
for i in range(10):
    plot_zoom_rows(i)

#### Measure change in object size

In [None]:
def segment_image(image, k=5):
    """Quantise an image to `k` colours with OpenCV k-means.

    Args:
        image: Array whose elements can be regrouped into 3-value colour
            triples (it is reshaped to (-1, 3)).
        k: Number of colour clusters.

    Returns:
        uint8 array of the input's shape in which every pixel is replaced by
        its cluster centre.
    """
    # k-means operates on a flat float32 list of colour triples.
    samples = image.reshape(-1, 3).astype(np.float32)

    # Stop after 100 iterations or once the centres move by less than 0.2.
    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    attempts = 10
    _compactness, assignments, palette = cv2.kmeans(
        samples, k, None, stop_rule, attempts, cv2.KMEANS_RANDOM_CENTERS
    )

    # Paint each pixel with its centre colour and restore the image layout.
    quantised = palette[assignments.flatten()]
    return quantised.reshape(image.shape).astype(np.uint8)

In [None]:
import numpy as np

def measure_object_height(image, color_threshold=0.5):
    """Count the pixels in the image's middle column that match the central object's colour.

    The object colour is taken to be the most frequent colour in the middle
    column between 50% and 80% of the image height. The height is the number
    of pixels in the whole middle column whose Euclidean colour distance to
    that colour is below `color_threshold`.

    Args:
        image: Array of shape (height, width, 3). Integer images (such as the
            uint8 output of `segment_image`) are supported.
        color_threshold: Maximum colour distance still counted as the object.
            Defaults to 0.5, which amounts to an exact match for integer images.

    Returns:
        The number of matching pixels, or 0 if the detection window is empty.
    """
    # Compute in float: with uint8 inputs the subtraction and the squaring wrap
    # modulo 256, so a channel difference of 16 would square to 0 and the pixel
    # would wrongly count as the object.
    pixels = np.asarray(image, dtype=np.float64)

    mid_col_idx = pixels.shape[1] // 2
    start_row_idx = int(pixels.shape[0] * 0.5)
    end_row_idx = int(pixels.shape[0] * 0.8)

    # Colours sampled from the region where the object is expected to sit.
    detection_col = pixels[start_row_idx:end_row_idx, mid_col_idx].reshape(-1, 3)
    if detection_col.shape[0] == 0:
        # Image too short to have a detection window; nothing to measure.
        return 0

    colors, counts = np.unique(detection_col, axis=0, return_counts=True)
    object_color = colors[counts.argmax()]

    # Distance of every pixel in the full middle column to the object colour.
    full_mid_col = pixels[:, mid_col_idx]
    color_diff = np.sqrt(np.sum((full_mid_col - object_color) ** 2, axis=1))

    return np.sum(color_diff < color_threshold)


In [None]:
# # visualise some examples to check the functions above

# zoom_levels = [80, 100, 120]
# for zoom in zoom_levels:
#     image = test_ds[4]
#     print(zoom)
#     border_width = abs(int((32 / (zoom / 100)) - 32))
#     if zoom < 100:
#         image = add_border(image, border_width=border_width)
#         print(f"Adding border of {border_width}")
#     elif zoom > 100:
#         image = remove_border(image, border_width=border_width)
#         print(f"Removing border of {border_width}")
    
#     # Add noise and use autoencoder
#     encoded_img = code.predict(np.array([add_noise(image, noise_factor=0.0)]))
#     decoded_img = decoder.predict(encoded_img)[0]
    
#     # Segment images
#     segmented_input = segment_image(image * 255)
#     segmented_output = segment_image(decoded_img * 255)

#     # Measure object height
#     input_height = measure_object_height(segmented_input)
#     output_height = measure_object_height(segmented_output)
#     print(input_height, output_height)

#     # Plotting the images
#     fig, axs = plt.subplots(2, 2, figsize=(3, 3))
#     axs[0, 0].imshow(image)
#     axs[0, 0].set_title('Input Image')
#     axs[0, 1].imshow(segmented_input)
#     axs[0, 1].set_title('Segmented In')
#     axs[1, 0].imshow(decoded_img)
#     axs[1, 0].set_title('Output Image')
#     axs[1, 1].imshow(segmented_output)
#     axs[1, 1].set_title('Segmented Out')
    
#     for ax in axs.flat:
#         ax.axis('off')

#     plt.show()


In [None]:
# Zoom percentages to evaluate: 80, 85, ..., 120 (100 leaves the image unchanged).
zoom_levels = range(80, 121, 5)
# Maps each zoom percentage to the list of relative size changes measured at it.
size_changes_dict = {zoom: [] for zoom in zoom_levels}

def get_size_change_and_zoom(ind):
    """Measure how the VAE changes apparent object size at every zoom level for one image.

    For each zoom percentage in `zoom_levels` the test image is zoomed out
    (border added) or zoomed in (border removed), corrupted with 10% noise,
    passed through the encoder/decoder, and the object height in the middle
    column is compared between the clean zoomed input and the reconstruction.

    All zoom levels go through the network in one batched, silent forward pass,
    rather than one verbose `predict` call pair per zoom level.

    Args:
        ind: Index into `test_ds` of the image to analyse.

    Returns:
        (changes, zoom_levels), where `changes[i]` is the relative height change
        (output - input) / input for `zoom_levels[i]`. It is 0 when no object
        was detected in the input.
    """
    half_side = 32  # half of the 64-pixel image side

    # Build the zoomed (noise-free) inputs for every zoom level.
    zoomed_images = []
    for zoom in zoom_levels:
        # margin = (32 / ratio) - 32
        # E.g. the margin to add for a zoom percentage of 80% (i.e. ratio of 0.8) is 8 pixels
        border_width = abs(int((half_side / (zoom / 100)) - half_side))
        image = test_ds[ind]
        if zoom < 100:
            image = add_border(image, border_width=border_width)
        elif zoom > 100:
            image = remove_border(image, border_width=border_width)
        zoomed_images.append(image)

    # Corrupt in zoom order, then reconstruct the whole batch at once.
    noisy_batch = np.array([add_noise(image, noise_factor=0.1) for image in zoomed_images])
    encoded_batch = code.predict(noisy_batch, verbose=0)
    decoded_batch = decoder.predict(encoded_batch, verbose=0)

    changes = []
    for image, decoded_img in zip(zoomed_images, decoded_batch):
        input_height = measure_object_height(segment_image(image*255))
        output_height = measure_object_height(segment_image(decoded_img*255))

        if input_height != 0:
            change = (output_height - input_height) / input_height
        else:
            # No object found in the input: report "no change" instead of dividing by zero.
            change = 0
        changes.append(change)

    return changes, zoom_levels

size_changes = []
zoom_changes = []

# Analyse up to 500 test images, capped at the number actually available so a
# small filtered test set cannot raise an IndexError.
n_images = min(500, len(test_ds))
for i in range(n_images):
    changes, zoom_levels = get_size_change_and_zoom(i)
    size_changes.extend(changes)
    zoom_changes.extend(zoom_levels)

# Separate size changes by zoom level
for size_change, zoom_level in zip(size_changes, zoom_changes):
    size_changes_dict[zoom_level].append(size_change)

In [None]:
# Per-zoom summary statistics of the relative size change.
means = [np.mean(size_changes_dict[zoom]) for zoom in zoom_levels]
std_devs = [np.std(size_changes_dict[zoom]) for zoom in zoom_levels]
# Standard error of the mean: sample standard deviation (ddof=1) over sqrt(n).
sems = [np.std(size_changes_dict[zoom], ddof=1) / np.sqrt(len(size_changes_dict[zoom])) for zoom in zoom_levels]

# Bar chart of the mean change per zoom level, with SEM error bars
plt.figure(figsize=(8,6))
plt.bar(zoom_levels, means, yerr=sems, capsize=5, width=3.5)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel('Zoom Level (%)', fontsize=18)
plt.ylabel('Change in object size', fontsize=18)
plt.axhline(y=0, color='black', linewidth=0.8)
# Reverse the x axis so it runs from the highest zoom percentage to the lowest.
ax=plt.gca()
ax.invert_xaxis()
plt.savefig('BE.png')
plt.show()


In [None]:
# One list of size changes per zoom level, in zoom_levels order. (Slicing this
# outer list to 200 entries has no effect on its nine items, so it is omitted.)
data = [size_changes_dict[zoom] for zoom in zoom_levels]

plt.figure(figsize=(8, 6))

# Evenly spaced box positions, one per zoom level.
positions = np.arange(len(data))
width = 0.9

# Whiskers span the 10th to 90th percentile; outliers are hidden.
plt.boxplot(
    data,
    positions=positions,
    widths=width,
    patch_artist=True,
    boxprops=dict(facecolor='blue', alpha=0.5),
    medianprops=dict(color='black'),
    showfliers=False,
    whis=[10,90],
)

plt.xticks(positions, zoom_levels, fontsize=18)
plt.yticks(fontsize=18)
plt.xlabel('Zoom Level (%)', fontsize=18)
plt.ylabel('Change in object size', fontsize=18)
plt.axhline(y=0, color='black', linewidth=0.8)

# Reverse the x axis, matching the bar chart.
ax = plt.gca()
ax.invert_xaxis()

plt.savefig('BE_boxplot.pdf')
plt.show()
