In [None]:
# use with /venvs/torch

In [None]:
import numpy as np
import sys
import pickle
import matplotlib.pyplot as plt
import matplotlib
import datetime
import imageio
import os
import torch

In [None]:
import sys
sys.path.append('/home/niaki/PycharmProjects/local-img-descr-ae')

import models.ae as ae
import models.vae as vae
import models.ae_ir as ae_ir
import models.vae_ir as vae_ir

In [None]:
# Checkpoints of the four trained descriptor models: one run directory per model, each holding
# a checkpoint file with the same name (`tar_filename`).
weights_base_dir = '/home/niaki/PycharmProjects/local-img-descr-ae/weights/the_final_four'
tar_filename = 'best.pth.tar'
(weights_path_ae,
 weights_path_vae,
 weights_path_ae_ir,
 weights_path_vae_ir) = (os.path.join(weights_base_dir, run_dir, tar_filename)
                         for run_dir in ('weights_20210318_124153_ae',
                                         'weights_20210311_083736_vae',
                                         'weights_20210427_121555_ae_ir',
                                         'weights_20210428_101801_vae_ir'))

# Output folder used by `generate_visualisation_for_4_descrs` for the saved figures.
visualisations_dir = '/home/niaki/Downloads/Visualisations_AE_VAE_noIR_IR/sweep/'

In [None]:
# Patch side length in pixels; also the step of the query grid used in the cells below.
# NOTE(review): the helper functions use their own `patch_size=65` default rather than this value.
patch_size = 65
# Neighbours retrieved/shown per query and descriptor. Must be even: both visualisation helpers
# assert `nr_similar_patches % 2 == 0`, and 6 matches their default.
nr_similar_patches = 6
# Descriptor distances below this are treated as duplicates of the query and skipped
# (same value as the retrieval helper's default `eps`).
eps = 0.0001

In [None]:
# Build the four descriptor networks (constructor argument 32 — presumably the latent size),
# load their checkpoints onto the CPU and switch them to inference mode.
# `models` is indexed by `which_desc` later on: 0 = AE, 1 = BetaVAE, 2 = AE_IR, 3 = BetaVAE (IR).
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
models = []
for model_class, checkpoint_path in ((ae.AE, weights_path_ae),
                                     (vae.BetaVAE, weights_path_vae),
                                     (ae_ir.AE_IR, weights_path_ae_ir),
                                     (vae_ir.BetaVAE, weights_path_vae_ir)):
    model = model_class(32)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    models.append(model)
model_ae, model_vae, model_ae_ir, model_vae_ir = models

for model in models:
    model.cpu()  # nn.Module.cpu() moves the module in place; no need to rebind the name
    model.eval()

In [None]:
def calculate_ssd(img1, img2):
    """Return the sum of squared differences (SSD) between two equally shaped arrays.

    Both inputs are converted to float32 before subtracting, so unsigned-integer inputs
    cannot wrap around.

    Args:
        img1, img2: array-likes exposing `.shape` (e.g. NumPy arrays); shapes must match.

    Returns:
        The SSD as a NumPy float32 scalar.

    Raises:
        ValueError: if the two shapes differ.
    """
    if img1.shape != img2.shape:
        raise ValueError("Images don't have the same shape: {} and {}".format(img1.shape, img2.shape))
    difference = np.asarray(img1, dtype=np.float32) - np.asarray(img2, dtype=np.float32)
    return np.sum(difference ** 2)

In [None]:
def compute_descriptor(descr, patch):
    """Encode one image patch into a flat descriptor vector with the given model.

    Args:
        descr: one of the loaded models. Instances of `vae.BetaVAE` / `vae_ir.BetaVAE` are
            treated as variational: their `encode` returns a 3-tuple and only the first element
            is used. Any other model's `encode` must return the encoding tensor directly.
        patch: array-like patch. It is divided by 255 (presumably 8-bit intensities) and given
            leading batch and channel axes of size 1.
            NOTE(review): this assumes a 2-D single-channel patch — confirm for the input image.

    Returns:
        NumPy array: the encoding of the single patch, flattened to 1-D.
    """
    variational = isinstance(descr, (vae.BetaVAE, vae_ir.BetaVAE))
    patch = np.array(patch) / 255.0
    # (H, W) -> (1, 1, H, W): a batch of one single-channel patch.
    patch = torch.from_numpy(patch.reshape((1, 1) + patch.shape)).float()
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        if variational:
            patch_encoding, _, _ = descr.encode(patch)
        else:
            patch_encoding = descr.encode(patch)
    patch_encoding = patch_encoding.detach().numpy()
    # np.prod rather than np.product: the latter is deprecated since NumPy 1.25 and removed in 2.0.
    patch_encoding = patch_encoding.reshape(patch_encoding.shape[0], np.prod(patch_encoding.shape[1:]))
    return patch_encoding[0]

In [None]:
def _grid_patch_descriptors(descriptor, image, patch_size, compare_stride):
    """Encode every patch on a regular grid of `image`; returns [(x, y, descriptor_vector), ...].

    Scan order is columns (y) outer, rows (x) inner; ties in the ranking keep this order.
    """
    image_height = image.shape[0]
    image_width = image.shape[1]
    encoded = []
    for y_compare in range(0, image_width - patch_size + 1, compare_stride):
        for x_compare in range(0, image_height - patch_size + 1, compare_stride):
            compare_patch = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]
            encoded.append((x_compare, y_compare, compute_descriptor(descriptor, compare_patch)))
    return encoded


def retrieve_patches_for_queries_and_descr(x_queries, y_queries, which_desc, image, random_seed=0, 
                                           noise_level=0, patch_size=65, compare_stride=8, eps=0.0001, 
                                           nr_similar_patches=6):
    """Find, for each query patch, the most similar grid patches of `image` under one descriptor.

    Candidates are the patches whose top-left corner lies on a grid with step `compare_stride`.
    The distance is the SSD between the descriptor vectors produced by `models[which_desc]`.

    Args:
        x_queries, y_queries: equal-length sequences with the top-left row (x) and column (y)
            of each query patch.
        which_desc: index into the global `models` list.
        image: array indexed as image[row, col] (patches are taken as 2-D slices of it).
        random_seed: seeds NumPy's global RNG (nothing here draws from it; kept for compatibility).
        noise_level: unused; kept for interface compatibility.
        patch_size: side length of the square patches.
        compare_stride: grid step between candidate patches.
        eps: candidates whose distance to the query is below this are skipped as duplicates.
        nr_similar_patches: maximum number of matches kept per query.

    Returns:
        (x_coords, y_coords): dicts mapping query index -> list of top-left rows / columns of the
        best matches, most similar first. A list is shorter than `nr_similar_patches` when fewer
        valid candidates exist (no placeholder entries are returned).

    Raises:
        ValueError: if `which_desc` is not a valid index into `models`, or the query lists
            differ in length.
    """
    if not 0 <= which_desc < len(models):
        raise ValueError("Wrong input for which_desc")
    if len(x_queries) != len(y_queries):
        raise ValueError("x_queries and y_queries must have the same length")

    np.random.seed(random_seed)
    descriptor = models[which_desc]
    total_nr_query_patches = len(x_queries)

    results_patches_x_coords = {}
    results_patches_y_coords = {}
    if total_nr_query_patches == 0:
        return results_patches_x_coords, results_patches_y_coords

    with torch.no_grad():
        # Candidate descriptors do not depend on the query: encode them once per call.
        candidates = _grid_patch_descriptors(descriptor, image, patch_size, compare_stride)

        for query_it in range(total_nr_query_patches):
            x_query = x_queries[query_it]
            y_query = y_queries[query_it]

            sys.stdout.write("\r" + str(query_it + 1) + "/" + str(total_nr_query_patches))

            query_patch = image[x_query: x_query + patch_size, y_query: y_query + patch_size]
            query_patch_descr = compute_descriptor(descriptor, query_patch)

            scored = []
            for x_compare, y_compare, compare_patch_descr in candidates:
                if x_compare == x_query and y_compare == y_query:
                    continue  # the query itself
                diff = calculate_ssd(query_patch_descr, compare_patch_descr)
                # Near-zero distance = duplicate of the query; NaN cannot be ranked.
                if diff < eps or np.isnan(diff):
                    continue
                scored.append((diff, x_compare, y_compare))

            # sorted() is stable, so equal distances keep scan order.
            best = sorted(scored, key=lambda item: item[0])[:nr_similar_patches]
            results_patches_x_coords[query_it] = [x for _, x, _ in best]
            results_patches_y_coords[query_it] = [y for _, _, y in best]

    sys.stdout.write("\n")
    return results_patches_x_coords, results_patches_y_coords

In [None]:
def generate_visualisation_for_3_descrs(x_queries, y_queries, results_patches_x_coords_0, results_patches_y_coords_0,
                                        results_patches_x_coords_1, results_patches_y_coords_1,
                                        image, random_seed=0, noise_level=0,
                                        patch_size=65, nr_similar_patches=6,
                                        save_dir='/home/niaki/Downloads'):
    """Plot each query patch beside its retrieved neighbours for two descriptors and save the figure.

    NOTE(review): despite its name this takes results for two descriptors (the `_0` and `_1`
    arguments); the name is kept so existing callers keep working.

    Layout per query: the query patch spans two grid rows on the left, followed by one row of
    `nr_similar_patches` retrieved patches per descriptor (`_0` first, then `_1`).

    Args:
        x_queries, y_queries: top-left row / column of each query patch in `image`.
        results_patches_{x,y}_coords_{0,1}: dicts mapping query index -> lists of top-left
            rows / columns of retrieved patches, best first.
        image: the image the patches are cut from; drawn with a gray colormap.
        random_seed: seeds NumPy's global RNG (nothing here draws from it).
        noise_level: only used in the output file names.
        patch_size: side length of the square patches.
        nr_similar_patches: retrieved patches shown per descriptor; must be even.
        save_dir: directory that receives the PDF and PNG files.

    Raises:
        AssertionError: if nr_similar_patches is odd.
        ValueError: if no query is given.
    """
    assert nr_similar_patches % 2 == 0, "If nr_similar_patches is odd, it will give an odd graph (I'll show myself out)"
    if len(x_queries) == 0:
        raise ValueError("At least one query patch is required")
    np.random.seed(random_seed)

    results_per_descr = [
        (results_patches_x_coords_0, results_patches_y_coords_0),
        (results_patches_x_coords_1, results_patches_y_coords_1),
    ]
    num_descrs = len(results_per_descr)
    total_nr_query_patches = len(x_queries)

    columns = nr_similar_patches + 2
    rows = total_nr_query_patches * num_descrs

    fig = plt.figure(figsize=(18, 4))

    for query_it in range(total_nr_query_patches):
        x_query = x_queries[query_it]
        y_query = y_queries[query_it]
        patch_query = image[x_query: x_query + patch_size, y_query: y_query + patch_size]

        # The query is drawn on a half-resolution grid so it covers a 2x2 block of the full grid;
        # its cell is row `query_it`, first column.
        ax = fig.add_subplot(rows // 2, columns // 2, query_it * (columns // 2) + 1)
        ax.axis('off')
        ax.imshow(patch_query, cmap='gray')

        for descr_it, (x_coords, y_coords) in enumerate(results_per_descr):
            grid_row = query_it * num_descrs + descr_it
            for i in range(nr_similar_patches):
                x_compare = x_coords[query_it][i]
                y_compare = y_coords[query_it][i]
                patch_compare = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

                # +3: the first two columns of each row are occupied by the query patch.
                ax = fig.add_subplot(rows, columns, grid_row * columns + 3 + i)
                ax.axis('off')
                ax.imshow(patch_compare, cmap='gray')

    os.makedirs(save_dir, exist_ok=True)
    # One timestamp for both files so the PDF and PNG names always match.
    # NOTE: the file name only encodes the coordinates of the last query.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base_path = os.path.join(save_dir, "Visualisation_AE_VAE_noIR_IR__" + str(x_query) + "_" + str(
        y_query) + "_noise" + str(noise_level) + "_" + timestamp)
    fig.savefig(base_path + ".pdf", bbox_inches='tight')
    fig.savefig(base_path + ".png", bbox_inches='tight')

    fig.show()

    plt.show(block=True)
    plt.interactive(False)
    # Release the figure so repeated calls (e.g. in loops) do not accumulate open figures.
    plt.close(fig)

In [None]:
def generate_visualisation_for_4_descrs(x_queries, y_queries, 
                                        results_patches_x_coords_0, results_patches_y_coords_0,
                                        results_patches_x_coords_1, results_patches_y_coords_1,
                                        results_patches_x_coords_2, results_patches_y_coords_2,
                                        results_patches_x_coords_3, results_patches_y_coords_3,
                                        image, random_seed=0, noise_level=0,
                                        patch_size=65, nr_similar_patches=6, save_dir=None):
    """Plot each query patch beside its retrieved neighbours for four descriptors and save the figure.

    Layout per query: the query patch is drawn large on the left, and to its right there is one
    row of `nr_similar_patches` retrieved patches per descriptor, in the order `_0` .. `_3`.

    Args:
        x_queries, y_queries: top-left row / column of each query patch in `image`.
        results_patches_{x,y}_coords_{0..3}: dicts mapping query index -> lists of top-left
            rows / columns of retrieved patches for that descriptor, best first.
        image: the image the patches are cut from; drawn with a gray colormap.
        random_seed: seeds NumPy's global RNG (nothing here draws from it).
        noise_level: only used in the output file names.
        patch_size: side length of the square patches.
        nr_similar_patches: retrieved patches shown per descriptor; must be even.
        save_dir: directory that receives the PDF and PNG files; defaults to the global
            `visualisations_dir`. It is created if missing.

    Raises:
        AssertionError: if nr_similar_patches is odd.
        ValueError: if no query is given.
    """
    assert nr_similar_patches % 2 == 0, \
        "If nr_similar_patches is odd, it will give an odd graph (I'll show myself out)"
    if len(x_queries) == 0:
        raise ValueError("At least one query patch is required")
    np.random.seed(random_seed)
    if save_dir is None:
        save_dir = visualisations_dir

    results_per_descr = [
        (results_patches_x_coords_0, results_patches_y_coords_0),
        (results_patches_x_coords_1, results_patches_y_coords_1),
        (results_patches_x_coords_2, results_patches_y_coords_2),
        (results_patches_x_coords_3, results_patches_y_coords_3),
    ]
    num_descrs = len(results_per_descr)
    total_nr_query_patches = len(x_queries)

    columns = nr_similar_patches + num_descrs
    rows = total_nr_query_patches * num_descrs

    fig = plt.figure(figsize=(12, 4))

    for query_it in range(total_nr_query_patches):
        x_query = x_queries[query_it]
        y_query = y_queries[query_it]
        patch_query = image[x_query: x_query + patch_size, y_query: y_query + patch_size]

        # The query is drawn on a grid num_descrs times coarser in both directions, so it spans
        # its query's whole block of rows; its cell is row `query_it`, first column.
        ax = fig.add_subplot(rows // num_descrs, columns // num_descrs,
                             query_it * (columns // num_descrs) + 1)
        ax.axis('off')
        ax.imshow(patch_query, cmap='gray')

        for descr_it, (x_coords, y_coords) in enumerate(results_per_descr):
            grid_row = query_it * num_descrs + descr_it
            for i in range(nr_similar_patches):
                x_compare = x_coords[query_it][i]
                y_compare = y_coords[query_it][i]
                patch_compare = image[x_compare: x_compare + patch_size, y_compare: y_compare + patch_size]

                # The first num_descrs columns of each row are left for the query patch.
                ax = fig.add_subplot(rows, columns, grid_row * columns + 1 + num_descrs + i)
                ax.axis('off')
                ax.imshow(patch_compare, cmap='gray')

    os.makedirs(save_dir, exist_ok=True)
    # One timestamp for both files so the PDF and PNG names always match.
    # NOTE: the file name only encodes the coordinates of the last query.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base_path = os.path.join(save_dir, "Visualisation_AE_VAE_noIR_IR__" + str(x_query) + "_" + str(
        y_query) + "_noise" + str(noise_level) + "_" + timestamp)
    fig.savefig(base_path + ".pdf", bbox_inches='tight')
    fig.savefig(base_path + ".png", bbox_inches='tight', facecolor='w')

    fig.show()

    plt.show(block=True)
    plt.interactive(False)
    # Release the figure so repeated calls (e.g. in loops) do not accumulate open figures.
    plt.close(fig)

In [None]:
# Image from which all query and candidate patches are cut.
# NOTE(review): compute_descriptor only adds batch/channel axes to each slice, so this presumably
# has to be a single-channel (grayscale) image — confirm.
image_path = '/home/niaki/Downloads/montage.png'
image = imageio.imread(image_path)
# Candidate grid step = patch size: candidates do not overlap and lie on the same grid as the
# query positions below (which are multiples of patch_size).
compare_stride = patch_size

In [None]:
# Query position as grid indices (multiples of patch_size).
# Other multipliers tried — x: 4, 7, 7, 7, 51, 51, 51, 15, 31, 47, 47; y: 6, 19, 17, 21, 19, 21, 30, 29, 28, 23, 22
x_queries = [patch_size * 47]
y_queries = [patch_size * 32]

# Best-match coordinates for each descriptor (index = position in `models`).
((results_patches_x_coords_0, results_patches_y_coords_0),
 (results_patches_x_coords_1, results_patches_y_coords_1),
 (results_patches_x_coords_2, results_patches_y_coords_2),
 (results_patches_x_coords_3, results_patches_y_coords_3)) = [
    retrieve_patches_for_queries_and_descr(x_queries, y_queries, which_desc, image,
                                           compare_stride=compare_stride)
    for which_desc in range(4)
]

In [None]:
# Figure for the query above: the query patch plus one row of matches per descriptor.
generate_visualisation_for_4_descrs(
    x_queries,
    y_queries,
    results_patches_x_coords_0=results_patches_x_coords_0,
    results_patches_y_coords_0=results_patches_y_coords_0,
    results_patches_x_coords_1=results_patches_x_coords_1,
    results_patches_y_coords_1=results_patches_y_coords_1,
    results_patches_x_coords_2=results_patches_x_coords_2,
    results_patches_y_coords_2=results_patches_y_coords_2,
    results_patches_x_coords_3=results_patches_x_coords_3,
    results_patches_y_coords_3=results_patches_y_coords_3,
    image=image,
)

In [None]:
# Number of query positions visited by the sweep below (58 x 40 grid).
40 * 58

In [None]:
# Sweep: use every position of a 58 x 40 grid with step `patch_size` as the query, retrieve its
# neighbours with all four descriptors and save one figure per query.
# NOTE(review): each retrieval call encodes every candidate patch of `image` again, so this cell
# is very slow; batching several queries per call would amortise that work.
for xquery_it in range(58):
    for yquery_it in range(40):
        print(xquery_it, yquery_it)
        x_queries, y_queries = [patch_size * xquery_it], [patch_size * yquery_it]

        ((results_patches_x_coords_0, results_patches_y_coords_0),
         (results_patches_x_coords_1, results_patches_y_coords_1),
         (results_patches_x_coords_2, results_patches_y_coords_2),
         (results_patches_x_coords_3, results_patches_y_coords_3)) = [
            retrieve_patches_for_queries_and_descr(x_queries, y_queries, which_desc, image,
                                                   compare_stride=compare_stride)
            for which_desc in range(4)
        ]

        generate_visualisation_for_4_descrs(x_queries, y_queries,
                                            results_patches_x_coords_0, results_patches_y_coords_0,
                                            results_patches_x_coords_1, results_patches_y_coords_1,
                                            results_patches_x_coords_2, results_patches_y_coords_2,
                                            results_patches_x_coords_3, results_patches_y_coords_3,
                                            image)