In [21]:
import json
import numpy as np
import h5py
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchvision import models
from torchvision import transforms
import torch
from torch.utils.data import Dataset, DataLoader
import os
import math
from sklearn import preprocessing

In [22]:
GT_PATH = "data02/london_lite_gt.h5"
with h5py.File(GT_PATH, "r") as gt_file:
    # Ground-truth match matrix; run_all_queries indexes its rows by query id
    # and treats entries equal to 1 as relevant map images.
    sim = gt_file["sim"][:].astype(np.uint8)
    # Loaded alongside `sim`; not used by the cells visible here.
    fovs = gt_file["fov"][:]

In [23]:
class CustomDataset(Dataset):
    """Image dataset whose file list is read from a JSON index.

    The JSON file must contain an ``"im_paths"`` list. Each entry is joined
    onto ``root_dir`` and loaded with ``plt.imread`` when the item is accessed.

    Args:
        root_dir: directory the index paths are relative to.
        json_path: path of the JSON index file.
        transform: optional callable applied to each loaded image.
        n: stored as ``self.n``; not used by this class.
    """

    def __init__(self, root_dir, json_path, transform=None, n=8):
        self.transform = transform
        self.n = n
        self.root_dir = root_dir
        with open(json_path, "r") as index_file:
            index = json.load(index_file)
        self.m_imgs = np.array(index["im_paths"])

    def __len__(self):
        """Number of image paths listed in the index."""
        return len(self.m_imgs)

    def __getitem__(self, idx):
        """Load image ``idx`` from disk, applying ``transform`` if one was given."""
        image_path = os.path.join(self.root_dir, self.m_imgs[idx])
        image = plt.imread(image_path)
        return self.transform(image) if self.transform else image

In [24]:
def find_global_features(loader, device, model):
    """Embed every image yielded by ``loader`` into one global descriptor.

    Args:
        loader: iterable of image batches (tensors with a leading batch axis).
        device: torch device the model lives on.
        model: ``torch.nn.Sequential`` feature extractor. A single-module
            Sequential (the truncated DenseNet in this notebook) returns
            spatial feature maps, which are max-pooled over the two spatial
            dims here. Other models are assumed to pool internally (TODO:
            confirm when adding a new backbone).

    Returns:
        ``np.ndarray`` of shape ``(n_images, n_features)`` in loader order, or
        ``None`` if the loader yields no batches.
    """
    batch_features = []
    with torch.no_grad():
        for img_batch in tqdm(loader):
            output = model(img_batch.to(device))

            # Pooling if model is densenet: global max over the spatial dims.
            if len(model) == 1:
                output, _ = output.max(dim=2)
                output, _ = output.max(dim=2)

            # reshape (not squeeze) keeps the batch axis even for a 1-image batch.
            batch_features.append(output.cpu().numpy().reshape(output.shape[0], -1))

    if not batch_features:
        return None
    # A single concatenate instead of a growing vstack, which re-copies every batch.
    return np.concatenate(batch_features, axis=0)

In [25]:
def manhatten(matrix, vector):
    """L1 (city-block) distance from ``vector`` to every row of ``matrix``."""
    offsets = np.subtract(matrix, vector)
    return np.linalg.norm(offsets, ord=1, axis=1)

def infinity(matrix, vector):
    """Chebyshev (L-infinity) distance from ``vector`` to every row of ``matrix``."""
    offsets = np.subtract(matrix, vector)
    return np.linalg.norm(offsets, ord=np.inf, axis=1)

def eucledian(matrix, vector):
    """Euclidean (L2) distance from ``vector`` to every row of ``matrix``."""
    offsets = np.subtract(matrix, vector)
    return np.linalg.norm(offsets, ord=2, axis=1)

def cosine(matrix, vector):
    """Negated cosine similarity between ``vector`` and every row of ``matrix``.

    The sign is flipped so that, like the other distance functions here,
    smaller values mean "more similar" and ``np.argsort`` ranks the best
    matches first.

    Args:
        matrix: 2-D array, one descriptor per row.
        vector: 1-D descriptor with ``matrix.shape[1]`` entries.

    Returns:
        1-D array with one entry per row of ``matrix``. A row (or query) with
        zero norm yields NaN.
    """
    matrix = np.asarray(matrix, dtype=float)
    vector = np.asarray(vector, dtype=float)
    # One value per database row, so the result length is matrix.shape[0],
    # not the feature dimension len(vector).
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(vector)
    sims = (matrix @ vector) / denom
    return -sims

In [26]:
def run_all_queries(device, model, distance_metric, bow_map_images, scaler):
    """Rank the map (database) images for every query image.

    Each query listed in ``data02/query/query_lite.json`` goes through four steps:
    it is loaded with the notebook-global ``transform``, embedded with
    ``model``, z-scored with ``scaler``, and compared with every row of
    ``bow_map_images`` via ``distance_metric``. Smaller distances rank first.
    Ground-truth matches come from the notebook-global ``sim`` matrix.

    Args:
        device: torch device the model lives on.
        model: ``torch.nn.Sequential`` feature extractor. A single-module
            Sequential (the truncated DenseNet) is max-pooled over the spatial
            dims here.
        distance_metric: callable ``(matrix, vector) -> distances``.
        bow_map_images: map descriptors, one row per map image. These should be
            scaled with the same ``scaler`` as the queries.
        scaler: fitted scikit-learn scaler applied to each query descriptor.

    Returns:
        A tuple ``(all_retrieved_images, all_relevant_images)``:
        - ``all_retrieved_images``: per query, map indices sorted by
          increasing distance.
        - ``all_relevant_images``: per query, the indices where ``sim == 1``.
    """
    q_database = CustomDataset(root_dir="data02", json_path="data02/query/query_lite.json", transform=transform)
    all_relevant_images = []
    all_retrieved_images = []
    for query_idx in tqdm(range(len(q_database))):
        img = q_database[query_idx]

        with torch.no_grad():
            # Add a batch axis of size 1 for the model.
            output = model(img.to(device).unsqueeze(0))
            # Pooling if model is densenet: global max over the spatial dims.
            if len(model) == 1:
                output, _ = output.max(dim=2)
                output, _ = output.max(dim=2)
            query_features = output.cpu().numpy().reshape(-1)

        # The scaler expects a 2-D (n_samples, n_features) array.
        scaled_query = scaler.transform(query_features.reshape(1, -1))[0]

        dists = distance_metric(bow_map_images, scaled_query)
        all_retrieved_images.append(np.argsort(dists))
        all_relevant_images.append(np.where(sim[query_idx, :] == 1)[0])

    return all_retrieved_images, all_relevant_images

In [27]:
def recall_at_k(relevant, retrieved, k):
    """Fraction of the relevant items that appear in the top-``k`` retrieved.

    Args:
        relevant: 1-D array-like of relevant item ids (ground truth).
        retrieved: 1-D array-like of item ids ranked best-first.
        k: cut-off rank.

    Returns:
        Recall@k as a float. Returns 0.0 when ``relevant`` is empty: recall is
        undefined there, and a NaN would poison any average over queries.
    """
    relevant = np.asarray(relevant)
    if relevant.size == 0:
        return 0.0
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    hits = np.count_nonzero(np.isin(relevant, retrieved[:k]))
    return float(hits / relevant.size)

def precision_at_k(relevant, retrieved, k):
    """Fraction of the top-``k`` retrieved items that are relevant.

    Args:
        relevant: 1-D array-like of relevant item ids (ground truth).
        retrieved: 1-D array-like of item ids ranked best-first.
        k: cut-off rank. The denominator is always ``k``, even if fewer
            than ``k`` items were retrieved.

    Returns:
        Precision@k as a float.
    """
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    tp = np.count_nonzero(np.isin(relevant, retrieved[:k]))
    return tp / k

def average_precision(relevant, retrieved):
    """Average precision of one ranking.

    This is the mean, over all relevant items, of the precision at the rank
    where each relevant item is retrieved. Relevant items missing from
    ``retrieved`` contribute 0, which is the standard AP definition.

    Args:
        relevant: 1-D array-like of unique relevant item ids.
        retrieved: 1-D array-like of unique item ids ranked best-first.

    Returns:
        AP as a float. Returns 0.0 when ``relevant`` is empty.
    """
    relevant = np.asarray(relevant)
    if relevant.size == 0:
        return 0.0
    retrieved = np.asarray(retrieved)
    # 1-based ranks, in increasing order, at which relevant items are retrieved.
    hit_ranks = np.flatnonzero(np.isin(retrieved, relevant)) + 1
    # At the i-th hit (1-based), exactly i relevant items lie in the top hit_ranks[i].
    precisions = np.arange(1, hit_ranks.size + 1) / hit_ranks
    return float(precisions.sum() / relevant.size)

def mean_average_precision(all_relevant, all_retrieved):
    """Mean of the per-query average precision (mAP).

    Args:
        all_relevant: per-query collections of relevant item ids, indexed in
            parallel with ``all_retrieved``.
        all_retrieved: per-query rankings of item ids, best first.

    Returns:
        The mean of ``average_precision`` over every query in
        ``all_retrieved``. Raises ``ZeroDivisionError`` when there are no
        queries.
    """
    n_queries = len(all_retrieved)
    ap_total = sum(
        average_precision(all_relevant[q], all_retrieved[q])
        for q in range(n_queries)
    )
    return ap_total / n_queries

def average_recall_at_k(all_relevant, all_retrieved, k):
    """Recall@k averaged over queries.

    Args:
        all_relevant: per-query collections of relevant item ids.
        all_retrieved: per-query rankings of item ids, best first, paired
            with ``all_relevant`` by position.
        k: cut-off rank passed to ``recall_at_k``.

    Returns:
        The sum of per-query Recall@k divided by ``len(all_relevant)``.
    """
    per_query_recall = (
        recall_at_k(relevant_ids, ranking, k)
        for relevant_ids, ranking in zip(all_relevant, all_retrieved)
    )
    return sum(per_query_recall) / len(all_relevant)

In [28]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Loaded device: {device}")

Loaded device: cpu


In [29]:
# DenseNet-201 backbone with its last child (the classifier) dropped. What
# remains is a single-module Sequential; the pooling code recognises DenseNet
# by len(model) == 1 and max-pools its spatial feature maps.
# The `weights=` API matches the ResNet cell (`pretrained=` is deprecated).
densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
densenet = torch.nn.Sequential(*list(densenet.children())[:-1])
# eval() makes BatchNorm use its stored running statistics. In train mode,
# features would depend on batch composition and the stats would be updated.
densenet = densenet.to(device).eval()

In [30]:
# ResNet-101 backbone with its last child (the fc classifier) dropped. The
# remaining modules end in the network's own pooling layer.
resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1]))
# eval() makes BatchNorm use its stored running statistics. In train mode,
# features would depend on batch composition and the stats would be updated.
resnet = resnet.to(device).eval()

In [31]:
# Preprocessing shared by map and query images (run_all_queries reads this
# global). Normalize applies the ImageNet mean/std that the pretrained
# torchvision backbones were trained with.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512, 512), antialias=False),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

database = CustomDataset(root_dir="data02", json_path="data02/database/database_lite.json", transform=transform)
loader = DataLoader(database, batch_size=64)

backbones = [("DenseNet201", densenet), ("ResNet101", resnet)]
distance_metrics = [manhatten, infinity, eucledian, cosine]

for model_name, model in backbones:
    # Map features depend only on the backbone, so extract them once per
    # model rather than once per distance metric.
    global_features = find_global_features(loader, device, model)

    # z-score the map collection (0 mean, unit std per feature). The same
    # fitted scaler is applied to every query inside run_all_queries.
    scaler = preprocessing.StandardScaler()
    scaled_features = scaler.fit_transform(global_features)

    for distance_metric in distance_metrics:
        # Queries are scaled, so compare them against the *scaled* map features.
        all_retrieved_images, all_relevant_images = run_all_queries(
            device, model, distance_metric, scaled_features, scaler
        )

        print("-------------")
        print(f"Model: {model_name}")
        print(f"Distance metric: {distance_metric.__name__}")
        print("")
        mAP = mean_average_precision(all_relevant_images, all_retrieved_images)
        print(f"mAP: {mAP}")
        for k in [1, 5, 10]:
            r_k = average_recall_at_k(all_relevant_images, all_retrieved_images, k)
            print(f"Recall@{k}: {r_k}")
        print("-------------")

  0%|          | 0/16 [00:00<?, ?it/s]

  img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()

KeyboardInterrupt

