# Importing libraries

In [22]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import timm
from torchvision import transforms, models, datasets
from sklearn.preprocessing import normalize
import os
import faiss
from map import evaluate
import pickle
from revisited_dataset import RevisitedDataset
from triplet_dataset import TripletData, TripletLoss


# Initial Parameters

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

In [4]:
dataset_name = 'rparis6k'

path = 'E:/Datasets/paris/'
root = f'E:/Datasets/paris/{dataset_name}'

In [5]:
checkpoint_names =  ['vit_small_patch16_224',
                     'deit3_small_patch16_224',
                     'swinv2_cr_small_224',
                     'resnet50']

In [6]:
import os
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models

# Loading data and model

In [7]:
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


# Datasets and Dataloaders
train_data = TripletData(root, train_transforms)

train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=16, shuffle=True,)

In [12]:
offline_dataset = RevisitedDataset(root=root, phase='database', transform=val_transforms)
offline_loader  = DataLoader(offline_dataset, batch_size=64)

In [9]:
selected_model = checkpoint_names[0]
model = timm.create_model(selected_model, pretrained=True).to(device)
model.head = nn.Identity()
model.train()
# embed_dim = 1000
# embed_dim = 768
embed_dim = 384

# Training

In [None]:
epochs = 1

optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()

In [24]:
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
    print("Train Loss: {}".format(epoch_loss.item()))

  0%|          | 0/396 [00:00<?, ?it/s]

Train Loss: 188.0474090576172


# Offline global feature DB generation

In [25]:
compute = True
model.eval()
if compute:
    index_flat = faiss.IndexFlatL2(embed_dim)   # build the index

    img_indeces   = []

    with torch.no_grad():
        prev = 1
        for i, img in enumerate(offline_loader):
            img = img.to(device)
        
            representation = model(img)
            representation = representation.cpu().detach().numpy()
            index_flat.add(representation) #add the representation to index
            img_indeces.extend(list(range(i*prev, i*prev + len(img))))   #store the image name to find it later on
            prev = len(img)
            
    index = index_flat

In [288]:
# Writing INDEX to and reading from a disk, Looks like that it works with CPU index only?
index_name = f'rparis_ft_{selected_model}.index'

faiss.write_index(index_flat, os.path.join('indeces', index_name))

# index = faiss.read_index(os.path.join(path, 'indeces', index_name))

# with open(os.path.join(path, 'indeces', 'rparis_vit_tiny_indeces.txt'), 'w') as f:
#     for i in img_indeces:
#         f.write(f'{i}\n')

# Query Evaluation

In [29]:
online_dataset  = RevisitedDataset(root=root, phase='query', setup='easy', transform=val_transforms)
online_loader   = DataLoader(online_dataset, batch_size=1)

In [30]:
Is = []
gnts = []
model.eval()
with torch.no_grad():
    for img, gndt in online_loader:
        img = img.to(device)

        test_embed = model(img).cpu().detach().numpy()
        test_embed = normalize(test_embed)
        _, I = index.search(test_embed, 10000)

        print(f"Retrieved Image is OK?: {I[0][0] in gndt['ok']}")

        Is.append(I[0])
        gnts.append(gndt)

Is = np.array(Is)
gnts = np.array(gnts)

Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved Image is OK?: False
Retrieved 

In [31]:
mAP, *_ = evaluate.compute_map(Is.T, gnts)
mAP

0.035975331208757164