In [17]:
import os
import torch
from torchvision.io.image import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image
from torchvision.models import resnet18

In [18]:
# Load the pretrained ResNet-18 and switch it to evaluation mode,
# then list its top-level submodules to see where the network can be cut.
model = resnet18(pretrained=True)
model.eval()
for child_name, _child in model.named_children():
    print(child_name)

conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc




In [19]:
# Feature extractor: every top-level child of the model except the last one
# (the listing printed above shows the last child is `fc`).
f = torch.nn.Sequential(*[layer for layer in model.children()][:-1])

In [20]:
# Smoke test of the embedding pipeline on a single query image.
img = read_image("dataset/img_retrieval/query/15881274738.jpg")
# Resize to 224x224, scale to [0, 1], then normalize per channel.
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Pure inference: disable autograd so no graph is built or retained.
with torch.no_grad():
    out = f(input_tensor.unsqueeze(0))

In [21]:
# Collapse the extractor output into a 1-D embedding vector.
output_tensor = torch.flatten(out)

In [22]:
def _embed_image(image_path: str) -> torch.Tensor:
    """Return the flattened output of the global extractor ``f`` for one image file."""
    img = read_image(image_path)

    # normalize() below is given 3-channel statistics, so bring every image to
    # exactly three channels first (grayscale and RGBA files would otherwise fail).
    channels = img.shape[0]
    if channels in (1, 2):
        # grayscale, optionally with an alpha channel: replicate the luminance plane
        img = img[:1].repeat(3, 1, 1)
    elif channels == 4:
        # RGBA: drop the alpha channel
        img = img[:3]

    input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Inference only: no autograd graph is needed, and the saved tensor
    # should not carry gradient bookkeeping.
    with torch.no_grad():
        return f(input_tensor.unsqueeze(0)).flatten()


# function to compute the embeddings for each image in a input path using model defined above and save them in a output path
def compute_and_save_embeddings(inp_path : str, out_path : str):
    """Compute an embedding for every query and gallery image and save it as a ``.pt`` file.

    Expected input layout::

        inp_path/query/<image files>
        inp_path/gallery/<subdirectory>/<image files>

    The same layout is mirrored under ``out_path``, with each image replaced by
    ``<image name without extension>.pt`` (written with ``torch.save``).

    Args:
        inp_path: Root directory containing the ``query`` and ``gallery`` folders.
        out_path: Root directory for the saved embeddings; created if missing.
    """
    query_in = os.path.join(inp_path, "query")
    gallery_in = os.path.join(inp_path, "gallery")
    query_out = os.path.join(out_path, "query")
    gallery_out = os.path.join(out_path, "gallery")

    # create output directories
    os.makedirs(query_out, exist_ok=True)
    os.makedirs(gallery_out, exist_ok=True)

    # compute embeddings for query images
    for file in os.listdir(query_in):
        # splitext handles extensions of any length (".jpg", ".jpeg", ...)
        stem = os.path.splitext(file)[0]
        embedding = _embed_image(os.path.join(query_in, file))
        torch.save(embedding, os.path.join(query_out, stem + ".pt"))

    # compute embeddings for gallery images (one subdirectory per group of images)
    for dirname in os.listdir(gallery_in):
        group_in = os.path.join(gallery_in, dirname)
        if not os.path.isdir(group_in):
            # stray files at the gallery root are not image groups
            continue

        group_out = os.path.join(gallery_out, dirname)
        os.makedirs(group_out, exist_ok=True)

        for file in os.listdir(group_in):
            stem = os.path.splitext(file)[0]
            embedding = _embed_image(os.path.join(group_in, file))
            torch.save(embedding, os.path.join(group_out, stem + ".pt"))

In [23]:
# Embed the whole retrieval dataset and store the results under ./embeddings
compute_and_save_embeddings(inp_path="dataset/img_retrieval", out_path="embeddings")

In [28]:
from image_ops import load_and_resize, preprocess_im, pil_bgr_to_rgb, combine_image_and_heatmap, combine_horz
from similarity_ops import compute_spatial_similarity
import numpy as np
import os
from PIL import Image

In [36]:
# Running index used by stylianou() to name the saved visualizations
# ("0.jpg", "1.jpg", ...); re-run this cell to restart the numbering.
counter = 0

In [40]:
def stylianou(img1_path, img2_path, save_path):
    """Save a side-by-side similarity-heatmap visualization for two images.

    Both images are passed through the convolutional part of the global
    ``model``; the spatial feature maps are compared with
    ``compute_spatial_similarity`` and the resulting heatmaps are overlaid on
    the two images, which are then combined horizontally and written to
    ``save_path/<counter>.jpg``. The global ``counter`` is incremented on
    every call.

    Args:
        img1_path: Path of the first image (the query, when called from
            ``retrieve_visualize``).
        img2_path: Path of the second image (the retrieved match).
        save_path: Directory for the output image; created if missing.
    """
    global counter

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Preprocess
    img1_norm = normalize(resize(read_image(img1_path), (224, 224)) / 255., mean, std)
    img2_norm = normalize(resize(read_image(img2_path), (224, 224)) / 255., mean, std)

    # Drop the last two children of the model (avgpool and fc in the listing
    # printed earlier) so the output keeps its spatial dimensions. A local name
    # is used so the global embedding extractor `f` is not shadowed.
    conv_backbone = torch.nn.Sequential(*list(model.children())[:-2])

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        features1 = conv_backbone(img1_norm.unsqueeze(0)).squeeze(0)
        features2 = conv_backbone(img2_norm.unsqueeze(0)).squeeze(0)

    c, h, w = features1.shape

    # Compute the similarity heatmap: one c-dimensional descriptor per spatial location
    conv1 = features1.permute(1, 2, 0).numpy().reshape(h*w, c)
    conv2 = features2.permute(1, 2, 0).numpy().reshape(h*w, c)
    similarity1, similarity2 = compute_spatial_similarity(conv1, conv2)

    img1_arr = load_and_resize(img1_path)
    img2_arr = load_and_resize(img2_path)

    img1_out = combine_image_and_heatmap(img1_arr, similarity1)  # overlay heatmap on image
    img2_out = combine_image_and_heatmap(img2_arr, similarity2)

    sim_final = combine_horz([img1_out, img2_out])  # combine both overlayed images side by side

    sim_final_pil = Image.fromarray(np.uint8(sim_final))
    # NOTE(review): original author was unsure whether this BGR -> RGB step is needed; confirm.
    sim_bgr2rgb = pil_bgr_to_rgb(sim_final_pil)

    # Make the function usable on its own, not only via retrieve_visualize().
    os.makedirs(save_path, exist_ok=True)
    sim_path = save_path + "/" + "{}.jpg".format(counter)
    counter += 1
    sim_bgr2rgb.save(sim_path)

In [41]:
# function to retrieve the query embeddings, compute the cosine similarity with all the gallery embeddings and return the top 1 results
def retrieve_visualize(img_path : str, emb_path : str, vis_path: str):
    """Find the top-1 gallery match of every query embedding and visualize the pair.

    For each ``.pt`` file in ``emb_path/query`` the gallery embedding with the
    highest cosine similarity is selected, the pair is printed, and
    ``stylianou`` is called on the corresponding ``.jpg`` images under
    ``img_path`` to write a visualization into ``vis_path``.

    Args:
        img_path: Root of the image dataset (``query`` and ``gallery`` subfolders).
        emb_path: Root of the saved embeddings, mirroring the image layout.
        vis_path: Directory for the visualizations; created if missing.

    Returns:
        None. If no gallery embeddings are found, a message is printed and
        nothing is visualized.
    """
    # create output directories
    os.makedirs(vis_path, exist_ok=True)

    query_path = os.path.join(emb_path, "query")
    gallery_path = os.path.join(emb_path, "gallery")

    # Load every gallery embedding once instead of re-reading all files from
    # disk for each query.
    # NOTE: torch.load unpickles its input; only point this at embedding files
    # produced by this notebook, never at untrusted data.
    gallery = []
    for dirname in os.listdir(gallery_path):
        group_dir = os.path.join(gallery_path, dirname)
        if not os.path.isdir(group_dir):
            continue
        for file in os.listdir(group_dir):
            gallery_emb = torch.load(os.path.join(group_dir, file))
            gallery.append((dirname + "/" + file, gallery_emb))

    if not gallery:
        # Without this guard stylianou() would be called with ".../gallery/.jpg".
        print("No gallery embeddings found in {}".format(gallery_path))
        return

    # retrieve and visualize query images
    for query_file in os.listdir(query_path):
        query_emb = torch.load(os.path.join(query_path, query_file))

        # -inf rather than -1 so that any similarity, including exactly -1, can win
        max_sim = float("-inf")
        max_file_path = ""
        for rel_path, gallery_emb in gallery:
            sim = torch.cosine_similarity(query_emb, gallery_emb, dim=0).item()
            if sim > max_sim:
                max_sim = sim
                max_file_path = rel_path
        print("Query : {} | Top reference : {}".format(query_file, max_file_path))

        # Map the embedding files back to their source images (saved as <name>.pt).
        query_stem = os.path.splitext(query_file)[0]
        match_stem = os.path.splitext(max_file_path)[0]
        stylianou(img_path + "/query/" + query_stem + ".jpg", img_path + "/gallery/" + match_stem + ".jpg", vis_path)

In [42]:
# Run top-1 retrieval for every query and write the heatmap pairs to ./visualizations
retrieve_visualize(img_path="dataset/img_retrieval", emb_path="embeddings", vis_path="visualizations")

Query : Indigo_Bunting_0016_13661.pt | Top reference : 014.Indigo_Bunting/Indigo_Bunting_0058_12207.pt




Query : Red_Legged_Kittiwake_0064_795422.pt | Top reference : 084.Red_legged_Kittiwake/Red_Legged_Kittiwake_0062_795434.pt
Query : Western_Grebe_0058_36403.pt | Top reference : 050.Eared_Grebe/Eared_Grebe_0004_34277.pt
Query : Laysan_Albatross_0104_630.pt | Top reference : 001.Black_footed_Albatross/Black_Footed_Albatross_0058_796074.pt
Query : Common_Yellowthroat_0069_190400.pt | Top reference : 020.Yellow_breasted_Chat/4802655907.pt
Query : Downy_Woodpecker_0042_184144.pt | Top reference : 112.Great_Grey_Shrike/Great_Grey_Shrike_0070_106547.pt
Query : Black_Throated_Sparrow_0094_107085.pt | Top reference : 132.White_crowned_Sparrow/White_Crowned_Sparrow_0040_127313.pt
Query : Green_Tailed_Towhee_0092_797397.pt | Top reference : 043.Yellow_bellied_Flycatcher/Yellow_Bellied_Flycatcher_0039_795471.pt
Query : Worm_Eating_Warbler_0014_176042.pt | Top reference : 154.Red_eyed_Vireo/Red_Eyed_Vireo_0034_157219.pt
Query : Northern_Waterthrush_0062_177364.pt | Top reference : 039.Least_Flycatc

In [1]:
dataset_path = "dataset/CUBV2"
imgs_path = dataset_path + "/images"
bbox_path = dataset_path + "/CUBV2_boxes.txt"

# TODO:
#     use these bbox coordinates as ground truth
#     define function to compute bboxes from img heatmap (for that, first save query image heatmaps separately, not paired with top-1 match as done earlier)
#     use these two to compute accuracy metrics
#
# NOTE(review): some rows have a 4th/5th value smaller than the 2nd/3rd
# (e.g. "11790 44 183 381 70"), so the last two fields may be width/height
# rather than a second corner -- confirm against the dataset description.

# The handle is deliberately not called `f`: that name is the global feature
# extractor used by compute_and_save_embeddings().
with open(bbox_path, 'r') as bbox_file:
    next(bbox_file, None)  # skip the first line (presumably a header row)
    for line in bbox_file:
        if not line.strip():
            # tolerate blank lines, e.g. a trailing newline at end of file
            continue
        # strip each field so the last one does not keep the line's newline
        img_id, x0, y0, x1, y1 = [field.strip() for field in line.split(',')]
        print(img_id, x0, y0, x1, y1)

11789 176 136 282 155

11790 44 183 381 70

11791 437 42 397 379

11792 291 180 246 275

11793 169 122 198 167

11794 109 179 292 362

11795 80 168 525 322

11796 290 192 187 137

11797 68 63 483 278

11798 121 121 308 171

11799 179 145 318 213

11800 68 163 506 147

11801 85 85 500 363

11802 69 110 751 237

11803 38 102 543 121

11804 248 118 389 403

11805 183 88 239 313

11806 269 52 161 284

11807 117 76 336 291

11808 394 103 337 471

11809 122 119 259 457

11810 114 65 536 633

11811 0 212 296 431

11812 34 236 250 387

11813 0 3 294 481

11814 0 7 249 468

11815 11 46 418 305

11816 44 39 592 668

11817 68 98 401 390

11818 117 15 384 420

11819 307 140 551 539

11820 58 193 482 533

11821 110 124 289 423

11822 77 71 635 366

11823 105 70 367 309

11824 189 70 586 274

11825 66 107 518 353

11826 145 32 617 279

11827 0 19 795 580

11828 333 116 269 148

11829 103 23 277 204

11830 329 70 318 239

11831 316 174 293 228

11832 36 135 340 293

11833 279 164 677 482

11834 186 1