### Boundary extension and contraction

This notebook contains code for exploring boundary extension and contraction in a VAE trained on the shapes3d dataset.

Tested with tensorflow 2.11.0 and Python 3.10.9.

#### Installation:

In [None]:
!pip install -r requirements.txt --upgrade

#### Imports:

In [None]:
import json
import zipfile
import os
import psutil
import numpy as np
import pandas as pd
import numpy as np
import cv2
from tensorflow import keras
from PIL import Image, ImageOps
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Sequential, metrics, optimizers, layers
from tensorflow.python.framework.ops import disable_eager_execution
from utils import load_tfds_dataset
from generative_model import encoder_network_large, decoder_network_large, VAE
# Seed the Python, NumPy and TensorFlow random number generators in one call
# so that the noise added to images below is repeatable between runs.
# NOTE(review): OpenCV's own RNG (used by cv2.kmeans with random centres) is
# presumably not covered by this call - confirm if exact repeatability matters.
tf.keras.utils.set_random_seed(123)

#### Load data

In [None]:
# Load the shapes3d images together with one label per image; 'label_scale'
# is the label key requested for this dataset.
train_ds, test_ds, train_labels, test_labels = load_tfds_dataset(
    'shapes3d',
    labels=True,
    key_dict={'shapes3d': 'label_scale'},
)

# Divide by 255 (presumably mapping 8-bit pixel values into [0, 1]).
train_ds, test_ds = train_ds / 255, test_ds / 255

#### Filter to just objects of mean size

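We can't use the object scale attribute on its own to simulate zooming, because a larger or smaller object is not equivalent to a 'close-up' or 'far away' view - the background stays at the same scale. Instead, we keep only the test images whose scale label is 4 or 5 (the middle of the range), and create zoomed views later by cropping or padding the whole image.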

In [None]:
# Keep only the test images whose scale label is 4 or 5.
# Note: test_ds is overwritten here but test_labels is not, so re-running this
# cell without re-running the loading cell pairs images with the wrong labels.
test_ds = np.array([
    image
    for image, scale_label in zip(test_ds, test_labels)
    if scale_label in (4, 5)
])
test_ds.shape

#### Load the trained VAE:

In [None]:
# Tell Keras that image tensors have the channel axis last, which matches the
# (64, 64, 3) input_shape used below. This is set before the networks are built.
K.set_image_data_format('channels_last')

# Size of the latent code and of the input images passed to the network builders.
latent_dim = 20
input_shape = (64, 64, 3)

# Build the encoder and decoder with the project's builder functions.
# NOTE(review): judging by the names, the second and third return values are
# the latent mean and log-variance tensors - confirm in generative_model.
encoder, z_mean, z_log_var = encoder_network_large(input_shape, latent_dim)
decoder = decoder_network_large(latent_dim)

# Restore trained weights; the paths are relative to the notebook's working directory.
encoder.load_weights("model_weights/shapes3d_encoder.h5")
decoder.load_weights("model_weights/shapes3d_decoder.h5")

#### Test boundary extension / contraction:

In [None]:
def remove_border(im_as_array, border_width=5):
    """Zoom in on an image by trimming its border and rescaling to 64x64.

    Args:
        im_as_array: image array; it is multiplied by 255 and cast to uint8
            before being handed to PIL, so values are expected in [0, 1].
        border_width: number of pixels trimmed from each of the four edges.

    Returns:
        The cropped image resized to 64x64, as an array divided by 255.
    """
    as_uint8 = (im_as_array * 255).astype(np.uint8)
    cropped = ImageOps.crop(Image.fromarray(as_uint8), border=border_width)
    return np.array(cropped.resize((64, 64))) / 255

def add_border(img, border_width=5):
    """Zoom out of an image by padding its edges and rescaling to 64x64.

    The padding repeats the outermost pixels of the image (numpy 'edge' mode)
    on the two spatial axes; the channel axis is not padded.

    Args:
        img: image array of shape (height, width, channels); it is multiplied
            by 255 and cast to uint8, so values are expected in [0, 1].
        border_width: number of pixels added on each of the four edges.

    Returns:
        The padded image resized to 64x64, as an array divided by 255.
    """
    margin = (border_width, border_width)
    padded = np.pad(img * 255, pad_width=(margin, margin, (0, 0)), mode='edge')
    shrunk = Image.fromarray(padded.astype(np.uint8)).resize((64, 64))
    return np.array(shrunk) / 255

def add_noise(im_as_array, noise_std=30):
    """Add Gaussian pixel noise to an image.

    The image is first quantised to 8-bit levels (multiplied by 255 and cast
    to uint8), zero-mean Gaussian noise is added on that 0-255 scale, and the
    result is clipped to [0, 255] and divided by 255.

    The noise is drawn with the array's own shape. The previous version built
    the shape from PIL's (width, height) size with a fixed 3 channels, which
    only lined up with the (height, width, channels) array for square RGB
    images; for those inputs the random draws are unchanged.

    Args:
        im_as_array: image array, expected to hold values in [0, 1].
        noise_std: standard deviation of the noise on the 0-255 scale.

    Returns:
        Float array with the same shape as the input and values in [0, 1].
    """
    pixels = (np.asarray(im_as_array) * 255).astype(np.uint8)
    # Drawn from NumPy's global RNG, so it follows the seed set for the notebook.
    gaussian = np.random.normal(0, noise_std, pixels.shape)
    return np.clip(pixels + gaussian, 0, 255) / 255

def display_recalled(x_test_new, decoded_imgs, n=10):
    """Show the first n inputs above their reconstructions in a 2 x n grid.

    Args:
        x_test_new: sequence of input images (top row).
        decoded_imgs: sequence of reconstructed images (bottom row).
        n: number of columns; both sequences need at least n entries.
    """
    plt.figure(figsize=(n*2, 4))
    for col in range(n):
        # Row 0 holds the originals, row 1 the reconstructions.
        for row, images in enumerate((x_test_new, decoded_imgs)):
            ax = plt.subplot(2, n, row * n + col + 1)
            plt.imshow(images[col])
            plt.gray()
            # Hide both axes so only the image is visible.
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()

# Sub-model mapping an input image to the output of the encoder's 'mean' layer.
code = Model(inputs=encoder.input, outputs=encoder.get_layer('mean').output)

# Add noise to the first 20 training images, then encode and decode them.
x_test_new = np.array(list(map(add_noise, train_ds[:20])))
encoded_imgs = code.predict(x_test_new)
decoded_imgs = decoder.predict(encoded_imgs)

# Noisy inputs on the top row, reconstructions underneath.
display_recalled(x_test_new, decoded_imgs)

#### Plots

In [None]:
# NOTE(review): border_const is not referenced by any code visible in this
# notebook - confirm whether it is still needed.
border_const = 5

def plot_zoom_rows(ind):
    """Plot zoomed-out, unchanged and zoomed-in views of one training image.

    Each view has noise added and is passed through the encoder and decoder.
    The figure shows the three noisy inputs on the top row and their
    reconstructions on the bottom row.

    Args:
        ind: index into train_ds of the image to plot.
    """
    source = train_ds[ind]

    # i = 0 leaves the image unchanged; i = 1 trims a 5.33 px border (zoom in).
    zoomed_in = np.array([add_noise(remove_border(source, border_width=5.33*i)) for i in range(2)])
    recalled_zoomed_in = decoder.predict(code.predict(zoomed_in))

    # i = 0 leaves the image unchanged; i = 1 pads an 8 px border (zoom out).
    zoomed_out = np.array([add_noise(add_border(source, border_width=8*i)) for i in range(2)])
    recalled_zoomed_out = decoder.predict(code.predict(zoomed_out))

    # Order: padded, unchanged, cropped. The unchanged copy from the cropped
    # set is skipped so the original appears only once.
    inputs = zoomed_out[::-1].tolist() + zoomed_in.tolist()[1:]
    outputs = recalled_zoomed_out[::-1].tolist() + recalled_zoomed_in.tolist()[1:]
    display_recalled(inputs, outputs, n=3)

# Draw the zoom comparison for the first 50 training images.
for i in range(50):
    plot_zoom_rows(i)

#### Measure change in object size

In [None]:
def segment_image(image, k=4):
    """Quantise an image into k flat colour regions with k-means.

    Every pixel is replaced by the centre of the colour cluster it is
    assigned to, so the result contains at most k distinct colours.

    Args:
        image: array whose last axis holds 3 colour values per pixel, on a
            0-255 scale (the result is cast to uint8).
        k: number of colour clusters. Defaults to 4, the value that was
            previously hard-coded.

    Returns:
        uint8 array with the same shape as `image`.
    """
    # Flatten to a list of colour triples; cv2.kmeans requires float32 samples.
    pixels = image.reshape(-1, 3).astype(np.float32)

    # Stop after 100 iterations or once the requested accuracy (0.2) is reached.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    # 10 attempts from random initial centres; OpenCV returns the best one.
    _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    # Look up each pixel's cluster centre and restore the original image shape.
    segmented_image = centers[labels.flatten()].reshape(image.shape).astype(np.uint8)

    return segmented_image

In [None]:
def measure_object_height(image, verbose=False):
    """Estimate the fraction of the image height covered by the central object.

    Only the middle column of the image is examined. Its pixels are grouped by
    exact colour, in order of first appearance from top to bottom. The
    first-seen and last-seen colours are treated as background, and the
    fractions of all colours in between are summed.

    Pixels are grouped by colour, not by contiguous run, so a colour that
    reappears further down the column is counted with its first occurrence.
    A column with one or two distinct colours yields 0.

    Args:
        image: array of shape (height, width, channels), normally the output
            of a colour segmentation so that it has few distinct colours.
        verbose: if True, print the per-colour proportions. Off by default so
            that calling this in a loop does not flood the notebook output.

    Returns:
        Sum of the proportions of the in-between colours, as a fraction of
        the column height.
    """
    mid_col = image[:, image.shape[1] // 2]

    # A dict keeps insertion order, so the counts stay in order of first
    # appearance down the column. Colours become tuples to be usable as keys.
    counts = {}
    for color in mid_col:
        color_t = tuple(color)
        counts[color_t] = counts.get(color_t, 0) + 1

    proportions = [count / len(mid_col) for count in counts.values()]

    if verbose:
        print(proportions)
    return sum(proportions[1:-1])

In [None]:
# Zoom percentages to test: 80, 85, ..., 120.
zoom_levels = range(80, 120 + 1, 5)
# One list per zoom level, filled later with percentage size changes.
size_changes_dict = {level: [] for level in zoom_levels}

def _border_width_for_zoom(zoom):
    """Return the whole-pixel border to add or remove for a zoom percentage.

    For a 64-pixel-wide image:
        ratio = width_after / width_before = 64 / (64 + 2 * margin)
    so margin = 32 / ratio - 32 (negative when zooming in). For example, a
    zoom of 80% needs a margin of 8 pixels.

    The magnitude is rounded to the nearest pixel. Truncating instead makes
    the realised zoom miss the request: 110% needs 2.91 px, which truncates
    to 2 px (about 106.7%) but rounds to 3 px (about 110.3%).
    """
    return round(abs(32 / (zoom / 100) - 32))


def _apply_zoom(image, zoom):
    """Return `image` zoomed out (zoom < 100), zoomed in (zoom > 100) or unchanged."""
    border_width = _border_width_for_zoom(zoom)
    if zoom < 100:
        return add_border(image, border_width=border_width)
    if zoom > 100:
        return remove_border(image, border_width=border_width)
    return image


def get_size_change_and_zoom(ind):
    """Measure how the VAE changes apparent object size at each zoom level.

    The training image at `ind` is zoomed to every level in `zoom_levels`,
    passed through the encoder and decoder, and the object height is measured
    in the colour-segmented input and output.

    Args:
        ind: index into train_ds.

    Returns:
        (changes, zoom_levels): `changes` is a list with one percentage
        change in object height (output relative to input) per zoom level,
        in the same order as `zoom_levels`.
    """
    source = train_ds[ind]

    # Encode and decode all zoom levels as one batch rather than one predict
    # call per image; verbose=0 stops Keras printing a progress bar per call.
    zoomed = np.array([_apply_zoom(source, zoom) for zoom in zoom_levels])
    decoded = decoder.predict(code.predict(zoomed, verbose=0), verbose=0)

    changes = []
    for image, decoded_img in zip(zoomed, decoded):
        input_height = measure_object_height(segment_image(image*255))
        output_height = measure_object_height(segment_image(decoded_img*255))

        if input_height != 0:
            change = (output_height - input_height) / input_height * 100  # Percentage change
        else:
            # No object was detected in the input, so a relative change is
            # undefined; it is recorded as 0.
            # NOTE(review): these zeros still enter the per-zoom mean - confirm this is intended.
            change = 0
        changes.append(change)

    return changes, zoom_levels

# Collect the per-zoom size changes for the first 100 training images as two
# flat, parallel lists.
size_changes, zoom_changes = [], []

for i in range(100):
    changes, zoom_levels = get_size_change_and_zoom(i)
    size_changes += changes
    zoom_changes += list(zoom_levels)

# Group the size changes by the zoom level they were measured at.
for zoom_level, size_change in zip(zoom_changes, size_changes):
    size_changes_dict[zoom_level].append(size_change)

# Mean, standard deviation and standard error of the mean for each zoom level.
means = [np.mean(size_changes_dict[zoom]) for zoom in zoom_levels]
std_devs = [np.std(size_changes_dict[zoom]) for zoom in zoom_levels]
sems = [sd / np.sqrt(len(size_changes_dict[zoom])) for sd, zoom in zip(std_devs, zoom_levels)]


In [None]:
# # Plot line chart with error bars
# plt.figure(figsize=(8,6))
# plt.errorbar(zoom_levels, means, yerr=sems, fmt='-o', capsize=5)
# plt.xticks([80, 90, 100, 110, 120], fontsize=18)
# plt.yticks(fontsize=18)
# plt.xlabel('Zoom Level (%)', fontsize=18)
# plt.ylabel('Object size (output/input as %)', fontsize=18)
# plt.savefig('BE.png')
# plt.show()

In [None]:
# Create bar chart with error bars
# Bar chart of the mean size change per zoom level, with standard-error bars.
plt.figure(figsize=(8, 6))
ax = plt.gca()
ax.bar(zoom_levels, means, yerr=sems, capsize=5, width=3.5)
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
ax.set_xlabel('Zoom Level (%)', fontsize=18)
ax.set_ylabel('Change in object size (%)', fontsize=18)
# Reference line at zero change.
ax.axhline(y=0, color='black', linewidth=0.8)
# Flip the x-axis so the highest zoom level is on the left.
ax.invert_xaxis()
plt.savefig('BE.png')
plt.show()
