In [31]:
import os
import random
import torch
from torchvision import transforms
from torchvision.io.image import read_image, ImageReadMode
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

In [2]:
def sample_triplets(image_dir, num_triplets):
    dirnames = os.listdir(image_dir)
    triplets = set()

    while len(triplets) < num_triplets:
        # Randomly sample an anchor image
        anchor_dirname = random.choice(dirnames)
        anchor_filename = random.choice(os.listdir(os.path.join(image_dir, anchor_dirname)))
        anchor_path = os.path.join(image_dir, anchor_dirname, anchor_filename)

        # Randomly sample a positive image (same class as anchor)
        positive_dirname = anchor_dirname
        positive_dir_files = os.listdir(os.path.join(image_dir, positive_dirname))
        positive_dir_files.remove(anchor_filename)
        positive_filename = random.choice(positive_dir_files)
        positive_path = os.path.join(image_dir, positive_dirname, positive_filename)

        # Randomly sample a negative image (different class from anchor)
        dirnames_ = dirnames.copy()
        dirnames_.remove(anchor_dirname)
        negative_dirname = random.choice(dirnames_)
        negative_filename = random.choice(os.listdir(os.path.join(image_dir, negative_dirname)))
        negative_path = os.path.join(image_dir, negative_dirname, negative_filename)

        triplets.add((anchor_path, positive_path, negative_path))

    triplets = list(triplets)
    return triplets

In [6]:
class ImageTripletDataset(Dataset):
    def __init__(self, triplets, transform=None):
        self.triplets = triplets
        self.transform = transform

    def __getitem__(self, index):
        anchor_path, positive_path, negative_path = self.triplets[index]

        anchor_image = read_image(anchor_path, mode=ImageReadMode.RGB) / 255.0
        positive_image = read_image(positive_path, mode=ImageReadMode.RGB) / 255.0
        negative_image = read_image(negative_path, mode=ImageReadMode.RGB) / 255.0

        if self.transform is not None:
            anchor_image = self.transform(anchor_image)
            positive_image = self.transform(positive_image)
            negative_image = self.transform(negative_image)

        return anchor_image, positive_image, negative_image

    def __len__(self):
        return len(self.triplets)

In [34]:
# Set the image directory and other parameters
image_dir = "dataset/CUB_200_2011/images"
num_triplets = 10
batch_size = 32

# Sample image triplets
triplets = sample_triplets(image_dir, num_triplets)

# Define the data transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create the image triplet dataset
dataset = ImageTripletDataset(triplets, transform=transform)

# Create the data loader
dataloader = DataLoader(dataset, batch_size=None, shuffle=True)

# # Iterate through the data loader
# for i, (anchor, positive, negative) in enumerate(dataloader):
#     # Process the anchor, positive, and negative images
#     # Perform any necessary operations or training steps
#     if i == 0:
#         print(anchor.shape)
#         print(positive.shape)
#         print(negative.shape)
#         break

In [22]:
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18

# Define your dataset and data loaders for anchor, positive, and negative examples

# Define the triplet loss function
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = torch.dist(anchor, positive, p=2)
        distance_negative = torch.dist(anchor, negative, p=2)
        loss = torch.relu(distance_positive - distance_negative + self.margin)
        return loss

In [83]:
# Load the pretrained ResNet18 model
pretrained_model = resnet18(pretrained=True)

# Create the finetuning model with the pretrained backbone
finetuned_model = nn.Sequential(*list(pretrained_model.children())[:-1])

In [84]:
# # Freeze the model parameters
# for name, param in finetuned_model.named_parameters():
#     if name == '6' or name == '7' or name == '8':
#         param.requires_grad = True
#     else:
#         param.requires_grad = False

In [85]:
# Define the triplet loss
triplet_loss = TripletLoss()

# Define the optimizer and learning rate scheduler
optimizer = optim.Adam(finetuned_model.parameters(), lr=0.001)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Training loop
num_epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
finetuned_model.to(device)
finetuned_model.train()

for epoch in tqdm(range(num_epochs)):
    total_loss = 0.0

    for anchor, positive, negative in dataloader:
        anchor = anchor.to(device)              # [3, 224, 224]
        positive = positive.to(device)
        negative = negative.to(device)

        optimizer.zero_grad()
        anchor_embedding = finetuned_model(anchor.unsqueeze(0)).flatten()       # [512, ]
        positive_embedding = finetuned_model(positive.unsqueeze(0)).flatten()
        negative_embedding = finetuned_model(negative.unsqueeze(0)).flatten()

        loss = triplet_loss(anchor_embedding, positive_embedding, negative_embedding)
        '''
            # In case the previous cell to freeze model params is run, 
            the next line must be uncommented to solve the error : 
            RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

            TODO: 
            On freezing model params for layers 6, 7 (both blocks of conv, batchnorm layers), 8 (Avg Pool), 
                the loss doesn't decrease at all!
        '''
        # loss.requires_grad = True #### (Only when freezing some model params)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    lr_scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(dataloader)}")

# Evaluate or save the finetuned model
model_path = "../models/resnet18_triplet_finetuned.pth"
torch.save(finetuned_model, model_path)

  5%|▌         | 1/20 [00:00<00:13,  1.38it/s]

Epoch 1/20, Loss: 0.9815043807029724


 10%|█         | 2/20 [00:01<00:13,  1.38it/s]

Epoch 2/20, Loss: 1.0720288634300232


 15%|█▌        | 3/20 [00:02<00:12,  1.38it/s]

Epoch 3/20, Loss: 0.5543064594268798


 20%|██        | 4/20 [00:02<00:11,  1.38it/s]

Epoch 4/20, Loss: 0.42964686155319215


 25%|██▌       | 5/20 [00:03<00:10,  1.39it/s]

Epoch 5/20, Loss: 0.7255280315876007


 30%|███       | 6/20 [00:04<00:10,  1.39it/s]

Epoch 6/20, Loss: 0.3984716713428497


 35%|███▌      | 7/20 [00:05<00:09,  1.39it/s]

Epoch 7/20, Loss: 0.1897673487663269


 40%|████      | 8/20 [00:05<00:08,  1.39it/s]

Epoch 8/20, Loss: 0.0863064169883728


 45%|████▌     | 9/20 [00:06<00:07,  1.38it/s]

Epoch 9/20, Loss: 0.21846996545791625


 50%|█████     | 10/20 [00:07<00:07,  1.38it/s]

Epoch 10/20, Loss: 0.4011864960193634


 55%|█████▌    | 11/20 [00:07<00:06,  1.39it/s]

Epoch 11/20, Loss: 0.04901754856109619


 60%|██████    | 12/20 [00:08<00:05,  1.39it/s]

Epoch 12/20, Loss: 0.0


 65%|██████▌   | 13/20 [00:09<00:05,  1.39it/s]

Epoch 13/20, Loss: 0.0


 70%|███████   | 14/20 [00:10<00:04,  1.39it/s]

Epoch 14/20, Loss: 0.0


 75%|███████▌  | 15/20 [00:10<00:03,  1.38it/s]

Epoch 15/20, Loss: 0.0


 80%|████████  | 16/20 [00:11<00:02,  1.38it/s]

Epoch 16/20, Loss: 0.0


 85%|████████▌ | 17/20 [00:12<00:02,  1.38it/s]

Epoch 17/20, Loss: 0.0


 90%|█████████ | 18/20 [00:12<00:01,  1.39it/s]

Epoch 18/20, Loss: 0.0


 95%|█████████▌| 19/20 [00:13<00:00,  1.39it/s]

Epoch 19/20, Loss: 0.0


100%|██████████| 20/20 [00:14<00:00,  1.39it/s]

Epoch 20/20, Loss: 0.0





In [86]:
model_path = "../models/resnet18_triplet_finetuned.pth"
finetuned_model = torch.load(model_path)

Create a class dictionary to store mappings of class names with class ids

In [87]:
img_path = "dataset/CUB_200_2011/images/"
dir_list = os.listdir(img_path)
class_dict = {}
for dirname in dir_list:
    tokens = dirname.split(".")
    class_dict[tokens[1].lower()] = tokens[0]

class_dict

{'prothonotary_warbler': '177',
 'yellow_throated_vireo': '157',
 'prairie_warbler': '176',
 'cardinal': '017',
 'sooty_albatross': '003',
 'florida_jay': '074',
 'olive_sided_flycatcher': '040',
 'swainson_warbler': '178',
 'cape_glossy_starling': '134',
 'chestnut_sided_warbler': '165',
 'yellow_bellied_flycatcher': '043',
 'northern_waterthrush': '183',
 'ruby_throated_hummingbird': '068',
 'ringed_kingfisher': '082',
 'great_grey_shrike': '112',
 'parakeet_auklet': '007',
 'red_winged_blackbird': '010',
 'geococcyx': '110',
 'pine_warbler': '175',
 'white_eyed_vireo': '156',
 'field_sparrow': '119',
 'tropical_kingbird': '077',
 'nelson_sharp_tailed_sparrow': '126',
 'western_grebe': '053',
 'bewick_wren': '193',
 'mockingbird': '091',
 'brewer_sparrow': '115',
 'ring_billed_gull': '064',
 'house_sparrow': '118',
 'cedar_waxwing': '186',
 'indigo_bunting': '014',
 'herring_gull': '062',
 'grasshopper_sparrow': '121',
 'vesper_sparrow': '131',
 'seaside_sparrow': '128',
 'heermann_g

Next, compute embeddings for query and gallery images using the finetuned `resnet18` encoder.

In [88]:
from torchvision.transforms.functional import normalize, resize
import numpy as np

In [89]:
# function to compute the embeddings for each image in a input path using model defined above and save them in a output path
def compute_and_save_embeddings(inp_path : str, out_path : str):
    # create output directories
    os.makedirs(out_path + "/query", exist_ok=True)
    os.makedirs(out_path + "/gallery", exist_ok=True)

    query_path = inp_path + "/query"
    gallery_path = inp_path + "/gallery"

    # compute embeddings for query images
    query_files = os.listdir(query_path)
    for file in query_files:
        img = read_image(query_path + "/" + file)
        input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        out = finetuned_model(input_tensor.unsqueeze(0))
        output_tensor = out.flatten()
        torch.save(output_tensor, out_path + "/query/" + file[:-4] + ".pt")

    # compute embeddings for gallery images
    gallery_dirnames = os.listdir(gallery_path)
    for dirname in gallery_dirnames:
        os.makedirs(out_path + "/gallery" + "/" + dirname, exist_ok=True)
        gallery_files = os.listdir(gallery_path + "/" + dirname)
        
        for file in gallery_files:
            img = read_image(gallery_path + "/" + dirname + "/" + file)
            input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            out = finetuned_model(input_tensor.unsqueeze(0))
            output_tensor = out.flatten() 
            torch.save(output_tensor, out_path + "/gallery/" + dirname + "/" + file[:-4] + ".pt")

In [97]:
compute_and_save_embeddings("dataset/img_retrieval_CUB_200_2011", "embeddings_CUB_200_2011_ft")



RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Finally, retrieve the saved embeddings, find the top-1 matched image using cosine similarity and save the heatmaps.

In [95]:
from image_ops import load_and_resize, pil_bgr_to_rgb, combine_image_and_heatmap
from similarity_ops import compute_spatial_similarity
import numpy as np
from PIL import Image
import cv2
import pandas as pd
import matplotlib.pyplot as plt

In [91]:
for name, param in finetuned_model.named_children():
    print(name)
    if name == '8':
        print(param)

0
1
2
3
4
5
6
7
8
AdaptiveAvgPool2d(output_size=(1, 1))


In [92]:
stylianou_model = nn.Sequential(*list(finetuned_model.children())[:-1])

In [93]:
for name, param in stylianou_model.named_children():
    print(name)
    if name == '7':
        print(param)

0
1
2
3
4
5
6
7
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): BasicBlock(
    (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchN

In [94]:
def stylianou(img1_path, img2_path, save_path): 
    '''
        Separate definition for CUB dataset, with only query-heatmap overlay as output.
    '''
    img1_filename = img1_path.split("/")[-1][:-4]
    img1 = read_image(img1_path)
    img2 = read_image(img2_path)

    # Preprocess
    img1_norm = normalize(resize(img1, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    img2_norm = normalize(resize(img2, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    features1 = stylianou_model(img1_norm.unsqueeze(0))
    features2 = stylianou_model(img2_norm.unsqueeze(0))

    c, h, w = features1.squeeze(0).shape

    # Compute the similarity heatmap
    conv1 = features1.squeeze(0).permute(1, 2, 0).detach().numpy().reshape(h*w, c)
    conv2 = features2.squeeze(0).permute(1, 2, 0).detach().numpy().reshape(h*w, c)
    similarity = compute_spatial_similarity(conv1, conv2)

    similarity1, _ = similarity

    dummy_arr = np.zeros((224, 224, 3))

    img1_out = combine_image_and_heatmap(dummy_arr, similarity1)  # overlay heatmap on image

    overlay_img = img1_out[:, :, :3]
    sim_path = save_path + "/" + "{}.jpg".format(img1_filename)
    cv2.imwrite(sim_path, overlay_img)

In [96]:
# function to retrieve the query embeddings, compute the cosine similarity with all the gallery embeddings, return the top 1 results and save whether top-1 class matches or not
def retrieve_visualize(img_path : str, emb_path : str, vis_path: str, csv_path: str):
    # create output directories
    os.makedirs(vis_path, exist_ok=True)

    query_path = emb_path + "/query"
    gallery_path = emb_path + "/gallery"

    # retrieve and visualize query images
    df = pd.read_csv(csv_path, sep='\t', encoding='utf-8')
    
    query_files = os.listdir(query_path)
    for query_file in query_files:
        query_emb = torch.load(query_path + "/" + query_file)
        gallery_dirnames = os.listdir(gallery_path)
        max_sim = -1
        max_file_path = ""
        for dirname in gallery_dirnames:
            file_names = os.listdir(gallery_path + "/" + dirname)
            for file in file_names:
                gallery_emb = torch.load(gallery_path + "/" + dirname + "/" + file)
                sim = torch.cosine_similarity(query_emb, gallery_emb, dim=0)
                if sim > max_sim:
                    max_sim = sim
                    max_file_path = dirname + "/" + file

        print("Query : {} | Top reference : {}".format(query_file, max_file_path))

        correct = 0

        # save whether the retrieved image is of the correct class or not
        try:
            query_class = '_'.join(query_file[:-3].split('_')[:-2])
            query_class_id = int(class_dict[query_class.lower()])
            max_class = max_file_path.split('/')[0].split('.')[1]
            max_class_id = int(class_dict[max_class.lower()])

            
            if query_class_id == max_class_id:
                correct = 1

        except:
            pass

        query_imgname = query_file[:-3] + ".jpg"
        df.loc[df['img_name'] == query_imgname, 'correct'] = correct
        stylianou(img_path + "/query/" + query_file[:-3] + ".jpg", img_path + "/gallery/" + max_file_path[:-3] + ".jpg", vis_path)
    
    df.to_csv(csv_path, sep='\t', encoding='utf-8')

In [None]:
img_path = "dataset/img_retrieval_CUB_200_2011_ft"
emb_path = "embeddings_CUB_200_2011_ft"
vis_path = "visualizations_CUB_200_2011_ft/heatmaps"
csv_path = "dataset/CUB_200_2011/annotations.csv"
retrieve_visualize(img_path, emb_path, vis_path, csv_path)