In [2]:
# Mount Google Drive into the Colab filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
import matplotlib.pyplot as plt
from dataloader_assg import  get_data_loader, get_query_image
# from dataloader_coco import get_data_loader_train, get_data_loader_val
from visual_embedding import get_visual_embedding
from image_captioning import get_caption
from similarity_check import compute_similarity
from PIL import Image
from caption_embedding import sentenceTransformerEmbeddings
import numpy as np


ModuleNotFoundError: No module named 'dataloader_assg'

### Matching with visual embeddings

In [None]:
# Folder of model images that queries are matched against.
# Forward slashes keep the path valid on Linux/Colab as well as Windows; the
# previous backslash form held invalid escape sequences (`\h`, `\m`) and was a
# single literal file name on POSIX systems.
model_images = 'dataset/hlcv_assg/model'
batch_size = 4

# Shuffled loader over the model images; each batch yields (images, paths),
# so the embedding loop below keeps paths paired with their embeddings.
model_dataLoader = get_data_loader(model_images, batch_size, shuffle=True)

In [None]:
# Compute a visual embedding for every model image.
# Invariant: model_images_embedding[i] belongs to model_images_path[i].
model_name = "resnet50"  # Choose the desired model name

model_images_embedding = []
model_images_path = []
for images, paths in model_dataLoader:
    embedding = get_visual_embedding(images, model_name)
    # A 1-D result is treated as a single image's embedding (the original code
    # assumed this for short batches); otherwise there is one row per image.
    # Branching on ndim (not on batch_size) also handles a short final batch
    # with several images, which used to be appended as one list element.
    if embedding.ndim == 1:
        model_images_embedding.append(embedding)
    else:
        model_images_embedding.extend(list(embedding))
    # Always extend: appending stored the whole path tuple as one entry,
    # breaking index alignment with the embeddings.
    model_images_path.extend(paths)

assert len(model_images_embedding) == len(model_images_path), "embeddings and paths are misaligned"



In [None]:
num_model_embeddings = len(model_images_embedding)
print(num_model_embeddings)

In [None]:
# Retrieve the model images whose visual embeddings are most similar (cosine)
# to the query image's visual embedding.
# Forward slashes: the backslash path was not a valid path on Linux/Colab.
query_image_path = 'dataset/hlcv_assg/query/obj6__40.png'
query_image = get_query_image(query_image_path)
query_embedding = get_visual_embedding(query_image, model_name)

top_k_similar = compute_similarity(query_embedding, model_images_embedding, similarity_metric='cosine', top_k=4)
result_paths = []
for embedding, similarity, idx in top_k_similar:
    print("--------{}---------".format(idx))
    print(f"Similarity score: {similarity}")
    # Show only the shape; printing the full vector floods the cell output.
    print(f"Similar embedding shape: {np.shape(embedding)}")
    result_paths.append(model_images_path[idx])





In [None]:
# Show the query image first, then every retrieved image, one figure each.
images_to_show = [(query_image_path, "Query Image")]
images_to_show += [(result_path, "Result Image") for result_path in result_paths]

for image_file, figure_title in images_to_show:
    plt.imshow(plt.imread(image_file))
    plt.title(figure_title)
    plt.show()

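### Matching with text + visual embeddings

Each model image gets a visual embedding, a generated caption, a caption embedding, and their concatenation; the retrieval below compares caption embeddings.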

In [None]:
# Folder of model images plus loader/CNN settings for the caption pipeline.
# Forward slashes: the backslash form had invalid escapes (`\h`, `\m`) and was
# not a directory path on Linux/Colab.
model_images = 'dataset/hlcv_assg/model'
batch_size = 4
cnn_model_name = "resnet50"

def load_image(image_path):
    """Open the image at ``image_path`` with PIL.

    Args:
        image_path: filesystem path of the image file.

    Returns:
        The ``PIL.Image.Image`` returned by ``Image.open``.
    """
    return Image.open(image_path)

In [None]:
shuffle_model_images = False  # fixed iteration order for the caption + embedding pass
model_dataLoader = get_data_loader(
    model_images,
    batch_size,
    shuffle=shuffle_model_images,
)


In [None]:
# For every model image collect its visual embedding, generated caption,
# caption embedding, the [visual | caption] concatenation, and its path.
# Invariant: entry i of every list describes the same image.
vis_embeddings_list = []
captions_list = []
caption_embeddings_list = []
merged_embeddings_list = []
img_paths_list = []

for images, paths in model_dataLoader:
    # Visual embeddings. A 1-D result is one image's embedding (the original
    # code wrapped it with np.array([...]) for that case); promote to 2-D so
    # there is always one row per image. This also fixes short final batches
    # with >1 image, where np.array([...]) produced a 3-D array and the
    # concatenation below raised.
    vis_embeddings = get_visual_embedding(images, cnn_model_name)
    vis_batch = np.atleast_2d(np.array(vis_embeddings))

    # Captions are generated from the PIL images loaded from disk.
    captions = get_caption([load_image(path) for path in paths])
    cap_embeddings = sentenceTransformerEmbeddings(captions)
    cap_batch = np.atleast_2d(np.array(cap_embeddings))

    # Guard the alignment that the retrieval cells rely on.
    assert len(vis_batch) == len(cap_batch) == len(paths), (
        f"batch mismatch: {len(vis_batch)} visual, {len(cap_batch)} caption, {len(paths)} paths"
    )

    vis_embeddings_list.extend(vis_batch)
    captions_list.extend(captions)
    caption_embeddings_list.extend(cap_batch)
    # Concatenate visual and caption embeddings row-wise per image.
    merged_embeddings_list.extend(np.concatenate((vis_batch, cap_batch), axis=1))
    img_paths_list.extend(paths)


In [None]:
# One line per collected list; all five counts should be equal.
for collected in (
    vis_embeddings_list,
    captions_list,
    caption_embeddings_list,
    merged_embeddings_list,
    img_paths_list,
):
    print(len(collected))

In [None]:
# Sanity-check the first few entries: the visual + caption concatenation
# should have the same shape as the stored merged embedding.
# min() avoids an IndexError when fewer than 4 model images exist.
num_to_inspect = min(4, len(img_paths_list))
for i in range(num_to_inspect):
    print("--------{}---------".format(i))
    print(f"Image path: {img_paths_list[i]}")
    print(f"Caption: {captions_list[i]}")
    print(f"Visual embedding: {vis_embeddings_list[i].shape}")
    print(f"Caption embedding: {caption_embeddings_list[i].shape}")
    print(np.concatenate((vis_embeddings_list[i], caption_embeddings_list[i]), axis=0).shape)
    print(f"Merged embedding: {merged_embeddings_list[i].shape}")
    print("")


In [None]:
# Embed the query the same way as the model images: visual embedding,
# caption, caption embedding, and the [visual | caption] concatenation.
# Forward slashes: the backslash path was not valid on Linux/Colab.
query_image_path = 'dataset/hlcv_assg/query/obj2__40.png'
query_image = get_query_image(query_image_path)
query_image_pil = load_image(query_image_path)

# get embeddings
query_vis_embedding = get_visual_embedding(query_image, cnn_model_name)
# Caption a one-element list, exactly as the model-image loop calls
# get_caption, so the caption embedding comes back with one row.
query_captions = get_caption([query_image_pil])
query_caption = query_captions[0]
print(query_caption)
# Bug fix: embed the *query* caption. The old code used `captions`,
# `vis_embeddings` and `cap_embeddings` leaked from the last batch of the
# model-image loop, so the "query" was actually a model image.
query_cap_embeddings = np.atleast_2d(np.array(sentenceTransformerEmbeddings(query_captions)))
query_merged_embedding = np.concatenate(
    (np.atleast_2d(np.array(query_vis_embedding)), query_cap_embeddings), axis=1
)




In [None]:

# Rank model images by cosine similarity between the query caption embedding
# and each model image's caption embedding; keep the top 4.
top_k_similar = compute_similarity(query_cap_embeddings[0], caption_embeddings_list, similarity_metric='cosine', top_k=4)
result_paths = []
for match_embedding, match_score, match_idx in top_k_similar:
    print(f"--------{match_idx}---------")
    print(f"Similarity score: {match_score}")
    print(f"Similar embedding: {match_embedding.shape}")
    print(captions_list[match_idx])
    matched_path = img_paths_list[match_idx]
    print(matched_path)
    result_paths.append(matched_path)


In [None]:
# Render the query image, then each caption-retrieved image, one figure each.
figures = [(query_image_path, "Query Image")]
figures.extend((retrieved_path, "Result Image") for retrieved_path in result_paths)

for figure_path, figure_title in figures:
    plt.imshow(plt.imread(figure_path))
    plt.title(figure_title)
    plt.show()