# Importing libraries

In [138]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import timm
from torchvision import transforms, models, datasets
import os
import faiss
from map import evaluate
from revisited_dataset import RevisitedDataset
from triplet_dataset import TripletData
from sklearn.preprocessing import normalize
from sklearn.decomposition import PCA

# Initial Parameters

In [139]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [140]:
# Per-channel (R, G, B) normalisation statistics consumed by `transform` below.
# These are the standard ImageNet values.
# NOTE(review): timm's pretrained config for some checkpoints (e.g. the default
# vit_small_patch16_224 weights) expects mean = std = 0.5 instead — confirm via
# `model.pretrained_cfg` for the selected backbone.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

In [141]:
# Benchmark to evaluate and where its data lives. The base directory can be
# overridden with the REVISITED_DATA_DIR environment variable instead of editing
# the hard-coded local path below.
dataset_name = 'roxford5k'

path = os.environ.get('REVISITED_DATA_DIR', 'E:/Datasets/paris/')
# Derive `root` from `path` rather than repeating the literal, so the two cannot
# drift apart; os.path.join also copes with a base dir lacking a trailing slash.
root = os.path.join(path, dataset_name)

In [142]:
# Candidate timm backbones; one of them is picked as `selected_model`.
checkpoint_names = [
    'vit_small_patch16_224',
    'deit3_small_patch16_224',
    'swinv2_cr_small_224',
    'resnet50',
]

# Loading data and model

In [143]:
# Preprocessing shared by database and query images: squash every image to a
# fixed 224x224 input (aspect ratio is not preserved), convert it to a float
# tensor, then standardise each channel with the `mean` / `std` defined above.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [144]:
# Database images to be indexed. The loader must not shuffle: faiss assigns ids
# to vectors in insertion order, and those ids are used as database positions.
offline_dataset = RevisitedDataset(
    root=root,
    phase='database',
    dataset_name=dataset_name,
    transform=transform,
)
offline_loader = DataLoader(offline_dataset, batch_size=64, shuffle=False)

In [145]:
# Load the pretrained backbone and strip its classifier so the forward pass
# returns a pooled feature vector instead of class logits.
selected_model = checkpoint_names[0]
model = timm.create_model(selected_model, pretrained=True).to(device)
# `reset_classifier(0)` is timm's architecture-agnostic way to drop the head.
# Assigning `model.head = nn.Identity()` only works when the classifier is
# literally called `head` (timm's ResNet names it `fc`, so it kept emitting
# 1000-d logits — hence the old hard-coded `embed_dim = 1000`).
model.reset_classifier(0)
model.eval()

# Take the embedding size from the model instead of hard-coding it per backbone.
embed_dim = model.num_features

# Sanity check: a dummy forward pass at the transform's input size must yield
# exactly one `embed_dim`-wide vector, otherwise the faiss index would be mis-sized.
with torch.no_grad():
    probe_shape = tuple(model(torch.zeros(1, 3, 224, 224, device=device)).shape)
assert probe_shape == (1, embed_dim), f'unexpected feature shape {probe_shape}'

# Offline global feature DB generation

In [146]:
# Embed every database image and store the L2-normalised vectors in an exact
# (brute-force) L2 index. On unit vectors L2 ranking equals cosine ranking.
# Set `compute = False` to skip this expensive pass.
compute = True
if compute:
    index_flat = faiss.IndexFlatL2(embed_dim)   # build the index

    # Dataset position of every vector added. faiss numbers vectors 0..ntotal-1
    # in insertion order, so with the unshuffled loader these are dataset indices.
    img_indeces = []

    with torch.no_grad():
        for img in offline_loader:
            img = img.to(device)

            # no_grad already yields grad-free tensors, so no .detach() needed.
            representation = normalize(model(img).cpu().numpy())

            # Offset by what is already stored, rather than i * previous_batch_len,
            # which only works while every batch except the last has the same size.
            start = len(img_indeces)
            index_flat.add(representation)
            img_indeces.extend(range(start, start + len(img)))

    assert index_flat.ntotal == len(img_indeces) == len(offline_dataset), \
        'index must hold exactly one vector per database image'
    index = index_flat

In [147]:
# Persist the database index so the embedding pass can be skipped next time.
# faiss.write_index serialises CPU indexes only; a GPU index would first need
# faiss.index_gpu_to_cpu. The IndexFlatL2 built above lives on the CPU, so no
# conversion is required.
index_dir = 'indeces'
index_name = f'{dataset_name}_{selected_model}.index'

os.makedirs(index_dir, exist_ok=True)  # write_index cannot create missing directories
faiss.write_index(index, os.path.join(index_dir, index_name))

# To reload later (e.g. with compute = False), read from the same location:
# index = faiss.read_index(os.path.join(index_dir, index_name))

# Query Evaluation

In [154]:
# Query images together with their ground truth, using the 'easy' setup.
# batch_size=1: the evaluation loop treats each batch as exactly one query.
online_dataset = RevisitedDataset(
    root=root,
    phase='query',
    setup='easy',
    dataset_name=dataset_name,
    transform=transform,
)
online_loader = DataLoader(online_dataset, batch_size=1, shuffle=False)

In [155]:
# Embed each query and rank the *entire* database against it.
# Requesting k = index.ntotal neighbours returns a full ranking. The previous
# hard-coded k=10000 exceeds the size of the standard roxford5k / rparis6k
# databases, which made faiss pad every result row with -1 ids.
n_db = index.ntotal

Is = []
gnts = []
with torch.no_grad():
    for img, gndt in online_loader:
        img = img.to(device)

        test_embed = normalize(model(img).cpu().numpy())
        _, I = index.search(test_embed, n_db)

        Is.append(I[0])    # batch_size=1, so row 0 is this query's ranking
        gnts.append(gndt)

Is = np.array(Is)      # (n_queries, n_db): database ids, nearest first
gnts = np.array(gnts)
assert Is.shape[1] == n_db and (Is >= 0).all(), 'every query must rank the full database'

In [156]:
# Score the rankings. `Is.T` has shape (n_ranked, n_queries): column q holds
# query q's database ids ordered from nearest to farthest. Only the first value
# returned by the project's evaluate.compute_map is kept; the rest are discarded.
# NOTE(review): assigning `_` here shadows IPython's last-output variable.
mAP, *_ = evaluate.compute_map(Is.T, gnts)
mAP

0.3488580073361315