In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import Model, load_model
from keras.preprocessing.image import ImageDataGenerator, load_img

import imageio
import numpy as np
import sys
import pickle
import matplotlib.pyplot as plt

import datetime

from skimage.metrics import structural_similarity as ssim

Using Theano backend.


In [2]:
sys.path.append('/home/niaki/PycharmProjects/patch-desc-ae')
from ae_descriptor import init_descr_32, init_descr_128
import ae_descriptor
from other_descriptors.other_descriptors import compute_chen_rgb

In [3]:
base_dir = '/home/niaki/Code/sundry_segments/weights'
model_version = '0.0.0.9_lr0.0001_50moreepochs'
vae_patch_size = 56

patch_size = 65

In [4]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [5]:
latent_dim = 128

encoder_inputs = keras.Input(shape=(vae_patch_size, vae_patch_size, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 56, 56, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 64)   18496       conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 7, 7, 64)     36928       conv2d_1

In [6]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 128)]             0         
_________________________________________________________________
dense_1 (Dense)              (None, 3136)              404544    
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 64)        36928     
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 56, 56, 32)        18464     
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 56, 56, 1)         289 

In [7]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = encoder(data)
            reconstruction = decoder(z)
            reconstruction_loss = tf.reduce_mean(
                keras.losses.binary_crossentropy(data, reconstruction)
            )
            reconstruction_loss *= 28 * 28
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.reduce_mean(kl_loss)
            kl_loss *= -0.5
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

In [8]:
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())

In [9]:
vae.load_weights(base_dir + '/vae_' + model_version + '.h5')

In [10]:
ae32_encoder = init_descr_32(vae_patch_size, vae_patch_size)

In [122]:
def calculate_psnr(img1, img2, max_value=255):
    """"Calculating peak signal-to-noise ratio (PSNR) between two images."""
    mse = np.mean((np.array(img1, dtype=np.float32) - np.array(img2, dtype=np.float32)) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(max_value / (np.sqrt(mse)))

In [11]:
def calculate_ssd(img1, img2):
    """Computing the sum of squared differences (SSD) between two images."""
    if img1.shape != img2.shape:
        raise Exception("Images don't have the same shape: ", img1.shape, "and", img2.shape)
    return np.sum((np.array(img1, dtype=np.float32) - np.array(img2, dtype=np.float32))**2)


In [12]:
def compute_descriptor(descr, patch):
    assert patch.shape == (patch_size, patch_size), "Patch shape should be (65, 65), and not " + str(patch.shape)
    assert descr == ae32_encoder or descr == vae.encoder, "Type of descriptor not supported"
    
    patch = patch / 255.0
    patch_crop = patch[4: 60, 4: 60]
    
    if descr == ae32_encoder:
        patch_crop = np.repeat(patch_crop, 3, axis=1).reshape((patch_crop.shape[0], patch_crop.shape[1], 3))
    else:
        patch_crop = np.expand_dims(patch_crop, axis=-1)
    
    patch_crop_encoded = descr.predict(np.expand_dims(patch_crop, axis=0))
    
    if descr == vae.encoder:
        patch_crop_encoded = patch_crop_encoded[2]
    
    patch_crop_encoded_flat = patch_crop_encoded.flatten()
    return patch_crop_encoded_flat

In [28]:
image_lenna = imageio.imread('/home/niaki/Code/Lenna.png')
image_lenna = np.dot(image_lenna[ : ,  : , : 3], [0.299, 0.587, 0.114])
image_lenna = image_lenna.astype(np.uint8)
image_briefs = imageio.imread("/home/niaki/PycharmProjects/learned-brief/images/briefs_gray.bmp")
image_hpatches = imageio.imread("/home/niaki/Downloads/hpatches_patch.png")

In [43]:
def retrieve_patches_for_queries_and_descr(x_queries, y_queries, which_desc, image, random_seed=0, 
                                           noise_level=0, patch_size=65, compare_stride=8, eps=0.0001, nr_similar_patches=6):

    np.random.seed(random_seed)
#     image = imageio.imread(image_path)
#     image = np.dot(image[ : ,  : , : 3], [0.299, 0.587, 0.114])
#     image = image.astype(np.uint8)
    image_height = image.shape[0]
    image_width = image.shape[1]
    
    results_patches_diffs = {}
    results_patches_x_coords = {}
    results_patches_y_coords = {}
    results_patches_positions = {}

    counter_query_patches = 0

    total_nr_query_patches = len(x_queries)

    for query_it in range(total_nr_query_patches):

        x_query = x_queries[query_it]
        y_query = y_queries[query_it]

        sys.stdout.write("\r" + str(counter_query_patches + 1) + "/" + str(total_nr_query_patches))

        query_patch = image[x_query: x_query + patch_size, y_query: y_query + patch_size]

        if which_desc == 0:
            query_patch_descr = compute_descriptor(ae32_encoder, query_patch)
        elif which_desc == 1:
            query_patch_descr = compute_descriptor(vae.encoder, query_patch)
        else:
            raise Exception("Wrong input for which_desc")

        counter_compare_patches = 0

        patches_diffs = [1000000000]
        patches_x_coords = [-1]
        patches_y_coords = [-1]
        patches_positions = [-1]

        for y_compare in range(0, image_width - patch_size + 1, compare_stride):
            for x_compare in range(0, image_height - patch_size + 1, compare_stride):

                compare_patch = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

                if which_desc == 0:
                    compare_patch_descr = compute_descriptor(ae32_encoder, compare_patch)
                elif which_desc == 1:
                    compare_patch_descr = compute_descriptor(vae.encoder, compare_patch)
                else:
                    raise Exception("Wrong input for which_desc")

                diff = calculate_ssd(query_patch_descr, compare_patch_descr)

                if diff < eps:
                    counter_compare_patches += 1
                    continue

                # sorting
                for i in range(len(patches_diffs)):
                    if diff < patches_diffs[i]:
                        patches_diffs.insert(i, diff)
                        patches_x_coords.insert(i, x_compare)
                        patches_y_coords.insert(i, y_compare)
                        patches_positions.insert(i, counter_compare_patches)
                        break

                counter_compare_patches += 1

        results_patches_diffs[counter_query_patches] = patches_diffs[:nr_similar_patches]
        results_patches_x_coords[counter_query_patches] = patches_x_coords[:nr_similar_patches]
        results_patches_y_coords[counter_query_patches] = patches_y_coords[:nr_similar_patches]
        results_patches_positions[counter_query_patches] = patches_positions[:nr_similar_patches]

        counter_query_patches += 1

    return results_patches_x_coords, results_patches_y_coords

In [49]:
def generate_visualisation_for_3_descrs(x_queries, y_queries, results_patches_x_coords_0, results_patches_y_coords_0,
                                        results_patches_x_coords_1, results_patches_y_coords_1,
                                        image, random_seed=0, noise_level=0,
                                        patch_size=65, nr_similar_patches=6):

    assert nr_similar_patches % 2 == 0, "If nr_similar_patches is odd, it will give an odd graph (I'll show myself out)"
    np.random.seed(random_seed)
#     image = imageio.imread(image_path)
#     image = np.dot(image[ : ,  : , : 3], [0.299, 0.587, 0.114])
#     image = image.astype(np.uint8)
    

    y_offset_under = -0.2
    font_size = 18
    x_offset_left = -2.5
    y_offset_left = 15

    fig = plt.figure(figsize=(18, 4))

    total_nr_query_patches = len(x_queries)

    columns = nr_similar_patches + 2
    rows = total_nr_query_patches * 2
    

    counter_query_patches = 0 # TODO test it with multiple query patches

    for query_it in range(total_nr_query_patches):

        x_query = x_queries[query_it]
        y_query = y_queries[query_it]
        patch_query = image[x_query: x_query + patch_size, y_query: y_query + patch_size]

        ax = fig.add_subplot(rows // 2, columns // 2, (counter_query_patches * 2) * (nr_similar_patches + 2) + 1)
        ax.axis('off')
        # ax.set_title('query', y=y_offset_under, fontsize=font_size)  # + str(query_it + 1)
        ax.imshow(patch_query, cmap='gray')

        for i in range(nr_similar_patches):
            x_compare = results_patches_x_coords_0[counter_query_patches][i]
            y_compare = results_patches_y_coords_0[counter_query_patches][i]

            # psnr = calculate_psnr(image[x_query: x_query + patch_size, y_query: y_query + patch_size, :],
            #                       image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size, :],
            #                       max_value=psnr_max_value)

            patch_compare = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

            ax = fig.add_subplot(rows, columns, (counter_query_patches * 2) * (nr_similar_patches + 2) + 3 + i)
            ax.axis('off')
            # if i == 0:
                # ax.text(x_offset_left, 1, 'proposed v128', rotation=90, fontsize=font_size)
                # ax.text(x_offset_left, y_offset_left, 'proposed v128', rotation=90, fontsize=font_size)  # y_offset_left
            # ax.set_title("{:.2f} [dB]".format(psnr), y=y_offset_under, fontsize=font_size)
            ax.imshow(patch_compare, cmap='gray')

        for i in range(nr_similar_patches):
            x_compare = results_patches_x_coords_1[counter_query_patches][i]
            y_compare = results_patches_y_coords_1[counter_query_patches][i]

            # psnr = calculate_psnr(image[x_query: x_query + patch_size, y_query: y_query + patch_size, :],
            #                       image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size, :],
            #                       max_value=psnr_max_value)

            patch_compare = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

            ax = fig.add_subplot(rows, columns, ((counter_query_patches * 2) + 1) * (nr_similar_patches + 2) + 3 + i)
            ax.axis('off')
            # if i == 0:
                # ax.text(x_offset_left, y_offset_left - 2, 'Chen et al.', rotation=90, fontsize=font_size)
            # ax.set_title("{:.2f} [dB]".format(psnr), y=y_offset_under, fontsize=font_size)
            ax.imshow(patch_compare, cmap='gray')


        counter_query_patches += 1

    # fig.savefig("/home/niaki/PycharmProjects/patch-desc-ae/results/Visualisation_v128_chen_exhaustive_q_" + str(x_query) + "_" + str(y_query) + "_noise" + str(noise_level) + ".pdf", bbox_inches='tight')
    fig.savefig("/home/niaki/Downloads/Visualisation_AE32_vs_VAE0.0.0.9_" + str(x_query) + "_" + str(
        y_query) + "_noise" + str(noise_level) + "_" + datetime.datetime.now().strftime(
        "%Y%m%d_%H%M%S") + ".pdf", bbox_inches='tight')
    fig.savefig("/home/niaki/Downloads/Visualisation_AE32_vs_VAE0.0.0.9_" + str(x_query) + "_" + str(
        y_query) + "_noise" + str(noise_level) + "_" + datetime.datetime.now().strftime(
        "%Y%m%d_%H%M%S") + ".png", bbox_inches='tight')

    fig.show()

    plt.show(block=True)
    plt.interactive(False)

In [59]:
x_queries = [patch_size * 7]   # 1, 4, 7,  7,  7,  51, 51, 51, 15, 31
y_queries = [patch_size * 21]  # 1, 6, 19, 17, 21, 19, 21, 30, 29, 28

image_path = '/home/niaki/Downloads/montage.png'
image = imageio.imread(image_path)
# image = np.dot(image[ : ,  : , : 3], [0.299, 0.587, 0.114])
# image = image.astype(np.uint8)

compare_stride = 65

results_patches_x_coords_0, results_patches_y_coords_0 = retrieve_patches_for_queries_and_descr(x_queries, y_queries, 0, image, compare_stride=compare_stride)
results_patches_x_coords_1, results_patches_y_coords_1 = retrieve_patches_for_queries_and_descr(x_queries, y_queries, 1, image, compare_stride=compare_stride)

1/1

In [221]:
generate_visualisation_for_3_descrs(x_queries, y_queries, results_patches_x_coords_0, results_patches_y_coords_0,
                                        results_patches_x_coords_1, results_patches_y_coords_1,
                                        image)

NameError: name 'generate_visualisation_for_3_descrs' is not defined

In [158]:
def calculate_SSDs_for_descr(which_desc, image, patch_size=65, query_stride=65, compare_stride=65, nr_similar_patches=6, eps=0.0001):

    image_height = image.shape[0]
    image_width = image.shape[1]
    
    query_x_coords = []
    query_y_coords = []

    results_noisy_descr_patches_diffs = {}
    results_noisy_descr_patches_x_coords = {}
    results_noisy_descr_patches_y_coords = {}
    results_noisy_descr_patches_positions = {}

    counter_query_patches = 0

    # just for the sake of output
    total_nr_query_patches = len(range(0, image_width - patch_size + 1, query_stride)) * len(
        range(0, image_height - patch_size + 1, query_stride))

    for y_query in range(0, image_width - patch_size + 1, query_stride):
        for x_query in range(0, image_height - patch_size + 1, query_stride):
            sys.stdout.write("\r" + str(counter_query_patches + 1) + "/" + str(total_nr_query_patches))

            query_x_coords.append(x_query)
            query_y_coords.append(y_query)

            query_patch = image[x_query: x_query + patch_size, y_query: y_query + patch_size]
            
            if which_desc == 0:
                query_patch_descr = compute_descriptor(ae32_encoder, query_patch)
            elif which_desc == 1:
                query_patch_descr = compute_descriptor(vae.encoder, query_patch)
            else:
                raise Exception("Wrong input for which_desc")
            
            counter_compare_patches = 0

            patches_diffs = [1000000000]
            patches_x_coords = [-1]
            patches_y_coords = [-1]
            patches_positions = [-1]

            for y_compare in range(0, image_width - patch_size + 1, compare_stride):
                for x_compare in range(0, image_height - patch_size + 1, compare_stride):

                    compare_patch = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

                    if which_desc == 0:
                        compare_patch_descr = compute_descriptor(ae32_encoder, compare_patch)
                    elif which_desc == 1:
                        compare_patch_descr = compute_descriptor(vae.encoder, compare_patch)
                    else:
                        raise Exception("Wrong input for which_desc")

                    diff = calculate_ssd(query_patch_descr, compare_patch_descr)

                    if diff < eps:
                        counter_compare_patches += 1
                        continue

                    # sorting
                    for i in range(len(patches_diffs)):
                        if diff < patches_diffs[i]:
                            patches_diffs.insert(i, diff)
                            patches_x_coords.insert(i, x_compare)
                            patches_y_coords.insert(i, y_compare)
                            patches_positions.insert(i, counter_compare_patches)
                            break

                    counter_compare_patches += 1
            
                       
            results_noisy_descr_patches_diffs[counter_query_patches] = patches_diffs[:nr_similar_patches]
            results_noisy_descr_patches_x_coords[counter_query_patches] = patches_x_coords[:nr_similar_patches]
            results_noisy_descr_patches_y_coords[counter_query_patches] = patches_y_coords[:nr_similar_patches]
            results_noisy_descr_patches_positions[counter_query_patches] = patches_positions[:nr_similar_patches]

            counter_query_patches += 1

    ssds = []
    ssims = []
    psnrs=[]

    for q_it in range(total_nr_query_patches):
        for c_it in range(nr_similar_patches):

            # getting the query patch from the clean image
            x_query = query_x_coords[q_it]
            y_query = query_y_coords[q_it]
            query_patch = image[x_query: x_query + patch_size, y_query: y_query + patch_size]

            # getting the compare patch from the clean image
            x_compare = results_noisy_descr_patches_x_coords[q_it][c_it]
            y_compare = results_noisy_descr_patches_y_coords[q_it][c_it]
            compare_patch = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

            # calculating the difference in the clean image
            actual_diff = calculate_ssd(query_patch, compare_patch)
            ssds.append(actual_diff)
            
            dr_max = max(query_patch.max(), compare_patch.max())
            dr_min = min(query_patch.min(), compare_patch.min())
            diff_ssim = ssim(query_patch, compare_patch, data_range=dr_max - dr_min)
            ssims.append(diff_ssim)
            
            diff_psnr = calculate_psnr(query_patch, compare_patch, max_value=255)
            psnrs.append(diff_psnr)
            
    ssds = np.array(ssds)
    ssims = np.array(ssims)
    psnrs = np.array(psnrs)
    
    return ssds, ssims, psnrs

In [214]:
image_path = '/home/niaki/Downloads/montage.png'
image = imageio.imread(image_path)

ssds_by_model = {}
ssims_by_model = {}
psnrs_by_model = {}

which_descs = [0, 1]

query_stride = 65 * 2  # 5

for which_desc in which_descs:
    ssds_by_model[which_desc], ssims_by_model[which_desc], psnrs_by_model[which_desc] = calculate_SSDs_for_descr(which_desc, image, patch_size=65, query_stride=query_stride, compare_stride=65, nr_similar_patches=5)
    print('\n==============\n')
    
which_descs_string = "_".join(str(which_desc) for which_desc in which_descs)
ssds_by_model_file_path = '/home/niaki/Downloads/ssds__descrs_' + which_descs_string + '__querystride_' + str(query_stride) + '__' + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '.pkl'

with open(ssds_by_model_file_path, 'wb') as f:
    pickle.dump((ssds_by_model, ssims_by_model, psnrs_by_model), f)

580/580

580/580



In [None]:
# nr_similar_patches=5

In [185]:
np.mean(ssds_by_model[0])

3705725.0

In [186]:
np.mean(ssds_by_model[1])

10119623.0

In [187]:
np.mean(ssims_by_model[0])

0.32032504156651004

In [188]:
np.mean(ssims_by_model[1])

0.423992354215702

In [189]:
np.mean(psnrs_by_model[0])

20.707794206803428

In [190]:
np.mean(psnrs_by_model[1])

29.302910285213493

In [None]:
# nr_similar_patches=5, 580 query patches

In [215]:
np.mean(ssds_by_model[0])

3557544.8

In [216]:
np.mean(ssds_by_model[1])

10301963.0

In [217]:
np.mean(ssims_by_model[0])

0.31872559185122595

In [218]:
np.mean(ssims_by_model[1])

0.41613687748792816

In [219]:
np.mean(psnrs_by_model[0])

20.736298700335812

In [220]:
np.mean(psnrs_by_model[1])

28.198053697735865

In [176]:
print("test set is a matrix of patches of dimensions", np.array(image.shape) // patch_size, "so in total", np.prod(np.array(image.shape) // patch_size), "patches")

test set is a matrix of patches of dimensions [58 40] so in total 2320 patches
