# STEP 2

In [None]:
# Mount Google Drive so the project code, dataset archives and checkpoints under
# /content/drive/MyDrive/Project are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Installing the libraries and extracting the datasets

In [None]:
# Install the project's Python dependencies listed in requirements.txt on Drive.
!pip install -r /content/drive/MyDrive/Project/code/requirements.txt

In [None]:
import zipfile
from tqdm.notebook import tqdm

# Dataset archives kept on Drive: GSV-Cities XS (training) plus
# San Francisco XS and Tokyo XS (evaluation).
SRC_PATHS = [
    '/content/drive/MyDrive/Project/code/datasets/gsv_xs.zip',
    '/content/drive/MyDrive/Project/code/datasets/sf_xs.zip',
    '/content/drive/MyDrive/Project/code/datasets/tokyo_xs.zip',
]

# Local-disk target; later cells read e.g. /content/datasets/sf_xs/... from here.
DST_PATH = '/content/datasets'

for archive_path in tqdm(SRC_PATHS, desc='Uploading the datasets from Drive to disk'):
  # Unpack every member of the archive under DST_PATH.
  with zipfile.ZipFile(archive_path, mode='r') as archive:
    archive.extractall(DST_PATH)

Uploading the datasets from Drive to disk:   0%|          | 0/3 [00:00<?, ?it/s]

## Creating the .npy index files (database list, query list, ground truth) for the SanFranciscoXS val and test splits

In [None]:
import os
import numpy as np
import math

# Destination folder for the .npy index files read by SanFranciscoXSDataset.
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/SanFranciscoXS"
# Splits to index; each lives under /content/datasets/sf_xs/<split>/ after extraction.
datasets = ['val', 'test']
# A database image is a positive for a query when their coordinates are at most
# this far apart (presumably metres — confirm against the dataset convention).
THRESHOLD = 25


def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two image names.

  Names are '@'-separated; fields 1 and 2 are read as planar coordinates
  (presumably UTM easting/northing — confirm).
  """
  img1_c1, img1_c2 = float(image1.split('@')[1]), float(image1.split('@')[2])
  img2_c1, img2_c2 = float(image2.split('@')[1]), float(image2.split('@')[2])
  return math.sqrt((img1_c1 - img2_c1)**2 + (img1_c2 - img2_c2)**2)


def _list_images(folder, prefix):
  """Sorted file names of `folder`, each prefixed with `prefix + "/"`, as a numpy array.

  Sorting makes the saved order (and therefore the ground-truth indices)
  reproducible; os.listdir returns entries in arbitrary order.
  """
  return np.array([prefix + "/" + name for name in sorted(os.listdir(folder))])


def _coords(imagelist):
  """(N, 2) float array holding fields 1 and 2 of every '@'-separated image name."""
  coords = np.empty((len(imagelist), 2))
  for i, name in enumerate(imagelist):
    fields = name.split('@')
    coords[i] = float(fields[1]), float(fields[2])
  return coords


def _build_ground_truth(q_imagelist, db_imagelist, threshold=THRESHOLD):
  """For every query, the list of database indices within `threshold` of it.

  Same rule as imageDist(q, db) <= threshold over every pair, but coordinates
  are parsed once and distances are computed with numpy one query at a time
  instead of O(Q*D) Python-level calls.

  Returns a 1-D object array whose i-th element is a list of ints. It is
  filled element by element because np.array(list_of_lists, dtype=object)
  silently becomes a 2-D array whenever all the lists have the same length.
  """
  db_xy = _coords(db_imagelist)
  q_xy = _coords(q_imagelist)
  gt = np.empty(len(q_xy), dtype=object)
  for i, (qx, qy) in enumerate(q_xy):
    dist = np.sqrt((db_xy[:, 0] - qx)**2 + (db_xy[:, 1] - qy)**2)
    gt[i] = np.flatnonzero(dist <= threshold).tolist()
  return gt


for ds in datasets:
  split_root = f"/content/datasets/sf_xs/{ds}"

  # creating sfxs_ds_dbImages.npy
  db_imagelist = _list_images(f"{split_root}/database", "database")
  np.save(DEST_PATH+f"/sfxs_{ds}_dbImages.npy", db_imagelist)
  print(f"sfxs_{ds}_dbImages.npy created with {np.size(db_imagelist)} elements")

  # creating sfxs_ds_qImages.npy
  q_imagelist = _list_images(f"{split_root}/queries", "queries")
  np.save(DEST_PATH+f"/sfxs_{ds}_qImages.npy", q_imagelist)
  print(f"sfxs_{ds}_qImages.npy created with {np.size(q_imagelist)} elements")

  # creating sfxs_ds_gt.npy
  gt_list = _build_ground_truth(q_imagelist, db_imagelist)
  np.save(DEST_PATH+f"/sfxs_{ds}_gt.npy", gt_list)
  print(f"sfxs_{ds}_gt.npy created with {np.size(gt_list)} elements")


sfxs_val_dbImages.npy created with 8015 elements
sfxs_val_qImages.npy created with 7993 elements
sfxs_val_gt.npy created with 7993 elements
sfxs_test_dbImages.npy created with 27191 elements
sfxs_test_qImages.npy created with 1000 elements
sfxs_test_gt.npy created with 1000 elements


## Training the model on the GSV-CitiesXS dataset

In [None]:
# Train the model by running the project's main.py; the log below shows training
# up to max_epochs=30 with Recall@K on sfxs_val reported after each epoch.
%run /content/drive/MyDrive/Project/code/main.py

  rank_zero_deprecation(
INFO:lightning_lite.utilities.seed:Global seed set to 1
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit native Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Loading gsv_cities.csv

+----------------------+
|   Training Dataset   |
+-------------+--------+
| # of cities | 23     |
| # of places | 62514  |
| # of images | 524701 |
+-------------+--------+

+-----------------------------+
|     Validation Datasets     |
+------------------+----------+
| Validation set 1 | sfxs_val |
+------------------+----------+

+-------------------------------+
|        Training config        |
+------------------+------------+
| Batch size (PxK) | 100x4      |
| # of iterations  | 625        |
| Image size       | (224, 224) |
+------------------+------------+


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type                 | Params
--------------------------------------------------
0 | loss_fn  | MultiSimilarityLoss  | 0     
1 | miner    | MultiSimilarityMiner | 0     
2 | backbone | ResNet               | 2.8 M 
--------------------------------------------------
2.1 M     Trainable params
683 K     Non-trainable params
2.8 M     Total params
5.566     Total estimated model params size (MB)


Loading gsv_cities.csv


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 41.45 | 57.55 | 64.52 | 68.67 | 71.99 | 74.04 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 44.20 | 60.77 | 68.07 | 71.84 | 74.83 | 76.74 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 48.09 | 64.44 | 70.60 | 74.39 | 77.21 | 79.04 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 49.23 | 65.46 | 71.46 | 75.18 | 77.98 | 80.02 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 50.69 | 66.42 | 72.71 | 76.54 | 78.98 | 80.76 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 51.32 | 67.00 | 73.68 | 77.17 | 79.76 | 81.58 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 51.02 | 67.07 | 73.59 | 77.18 | 79.64 | 81.31 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 52.01 | 67.85 | 74.15 | 77.52 | 79.84 | 81.65 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 51.97 | 68.06 | 74.38 | 77.59 | 80.25 | 81.81 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 52.88 | 68.53 | 75.39 | 78.66 | 80.56 | 82.13 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.38 | 68.74 | 75.30 | 78.52 | 80.58 | 82.26 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.18 | 68.92 | 75.09 | 78.33 | 80.43 | 82.20 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.63 | 69.09 | 75.69 | 79.01 | 80.95 | 82.37 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.45 | 69.07 | 75.12 | 78.58 | 80.93 | 82.52 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.61 | 69.41 | 75.39 | 78.64 | 80.92 | 82.65 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.96 | 69.60 | 75.60 | 78.76 | 81.25 | 82.75 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.18 | 69.54 | 75.79 | 79.04 | 81.41 | 82.77 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.08 | 69.50 | 75.60 | 78.88 | 81.33 | 82.83 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.35 | 69.76 | 75.73 | 78.96 | 81.36 | 83.04 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.02 | 69.40 | 75.67 | 78.87 | 81.28 | 82.90 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.35 | 69.52 | 75.62 | 78.89 | 81.16 | 82.87 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.10 | 69.54 | 75.57 | 78.87 | 81.25 | 82.96 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.10 | 69.47 | 75.69 | 78.77 | 81.18 | 82.81 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.22 | 69.55 | 75.83 | 78.78 | 81.11 | 82.91 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.21 | 69.49 | 75.65 | 78.77 | 81.06 | 82.85 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.36 | 69.87 | 75.74 | 78.87 | 81.26 | 83.05 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.36 | 69.84 | 75.72 | 78.98 | 81.16 | 82.94 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.49 | 69.94 | 75.80 | 79.06 | 81.43 | 83.24 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.46 | 69.94 | 75.70 | 79.03 | 81.27 | 83.11 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.




+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 54.54 | 69.85 | 75.68 | 78.92 | 81.35 | 83.09 |
+----------+-------+-------+-------+-------+-------+-------+





## Running inference on the SanFranciscoXS val and test splits


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from drive.MyDrive.Project.code.utils.validation import get_validation_recalls
from drive.MyDrive.Project.code.dataloaders.val.SanFranciscoXSDataset import SanFranciscoXSDataset
from drive.MyDrive.Project.code.dataloaders.val.TokyoXSDataset import TokyoXSDataset
from drive.MyDrive.Project.code.main import VPRModel

# ImageNet channel statistics (the ResNet-18 backbone starts from torchvision's
# ImageNet weights, as the download log shows).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Evaluation resolution, matching the (224, 224) image size used for training.
IM_SIZE = (224, 224)

def input_transform(image_size=IM_SIZE):
  """Evaluation preprocessing: bilinear resize to `image_size`, conversion to a
  tensor, then per-channel normalisation with MEAN/STD."""
  resize = T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR)
  normalize = T.Normalize(mean=MEAN, std=STD)
  return T.Compose([resize, T.ToTensor(), normalize])

def get_val_dataset(dataset_name, input_transform=input_transform()):
  """Build one of the evaluation datasets and return it with its metadata.

  Args:
    dataset_name: case-insensitive name, matched by substring against
      'sfxs_val', 'sfxs_test' and 'tokyoxs'.
    input_transform: transform applied to every image (defaults to the
      pipeline built by input_transform()).

  Returns:
    (dataset, num_references, num_queries, ground_truth)

  Raises:
    ValueError: if dataset_name matches none of the known datasets.
  """
  dataset_name = dataset_name.lower()

  # Precedence when a name contains several keys: tokyoxs > sfxs_test > sfxs_val.
  if 'tokyoxs' in dataset_name:
    ds = TokyoXSDataset(input_transform=input_transform)
  elif 'sfxs_test' in dataset_name:
    # Relies on SanFranciscoXSDataset's default split (evidently 'test', given the
    # 1000 queries reported for sfxs_test — confirm in the dataset class).
    ds = SanFranciscoXSDataset(input_transform=input_transform)
  elif 'sfxs_val' in dataset_name:
    ds = SanFranciscoXSDataset(which_ds='val', input_transform=input_transform)
  else:
    raise ValueError(
        f"Unknown dataset '{dataset_name}': expected one of 'sfxs_val', 'sfxs_test', 'tokyoxs'")

  return ds, ds.num_references, ds.num_queries, ds.ground_truth

def get_descriptors(model, dataloader, device):
  """Run `model` over every batch of `dataloader` and stack the outputs.

  Gradients are disabled. Each batch's images are moved to `device` for the
  forward pass and the result is copied back to the CPU, so the returned tensor
  lives on the CPU and its rows follow the dataloader order.

  Args:
    model: network mapping a batch of images to a batch of descriptors.
    dataloader: iterable of (images, labels) pairs; labels are ignored.
    device: torch.device used for the forward pass.

  Returns:
    torch.Tensor: the per-batch outputs concatenated along dim 0.
  """
  descriptors = []
  with torch.no_grad():
    for imgs, _ in tqdm(dataloader, 'Calculating descriptors...'):
      descriptors.append(model(imgs.to(device)).cpu())

  return torch.cat(descriptors)

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same architecture as the training run. layers_to_crop=[4] presumably drops
# ResNet's last stage (consistent with the 256-d descriptors printed below);
# agg_arch='AVG' presumably selects average pooling — confirm in main.py.
model = VPRModel(backbone_arch='resnet18',
                 layers_to_crop=[4],
                 agg_arch='AVG',
        )

# Lightning checkpoint of the trained model; the weights are under its 'state_dict' key.
CKPT_PATH = '/content/drive/MyDrive/Project/code/LOGS/resnet18/lightning_logs/version_0/checkpoints/resnet18_epoch(27)_step(17528)_R1[0.5390]_R5[0.6941].ckpt'
# map_location lets a checkpoint saved from a GPU run be restored when the
# runtime fell back to the CPU above.
checkpoint = torch.load(CKPT_PATH, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [None]:
# Evaluate the loaded model on the SanFranciscoXS validation split.
val_dataset_name = 'sfxs_val'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, num_workers=4, batch_size=batch_size)

# One descriptor per image in dataset order (references first, then queries).
# get_descriptors already returns a CPU tensor, so no extra .cpu() is needed.
descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The reference/query split below is purely positional: a count mismatch would
# silently pair queries with the wrong ground truth, so fail loudly instead.
assert descriptors.shape[0] == num_references + num_queries, (
    f'{descriptors.shape[0]} descriptors for {num_references} references + {num_queries} queries')

r_list = descriptors[:num_references]
q_list = descriptors[num_references:]
recalls_dict, preds = get_validation_recalls(r_list=r_list,
                                    q_list=q_list,
                                    k_values=[1, 5, 10, 15, 20, 25],
                                    gt=ground_truth,
                                    print_results=True,
                                    dataset_name=val_dataset_name,
                                    )

Calculating descritptors...:   0%|          | 0/401 [00:00<?, ?it/s]

Descriptor dimension 256


+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.85 | 69.41 | 75.20 | 78.51 | 80.76 | 82.48 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
# Evaluate the loaded model on the SanFranciscoXS test split.
val_dataset_name = 'sfxs_test'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, num_workers=4, batch_size=batch_size)

# One descriptor per image in dataset order (references first, then queries).
# get_descriptors already returns a CPU tensor, so no extra .cpu() is needed.
descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The reference/query split below is purely positional: a count mismatch would
# silently pair queries with the wrong ground truth, so fail loudly instead.
assert descriptors.shape[0] == num_references + num_queries, (
    f'{descriptors.shape[0]} descriptors for {num_references} references + {num_queries} queries')

r_list = descriptors[:num_references]
q_list = descriptors[num_references:]
recalls_dict, preds = get_validation_recalls(r_list=r_list,
                                    q_list=q_list,
                                    k_values=[1, 5, 10, 15, 20, 25],
                                    gt=ground_truth,
                                    print_results=True,
                                    dataset_name=val_dataset_name,
                                    )

Calculating descritptors...:   0%|          | 0/705 [00:00<?, ?it/s]

Descriptor dimension 256


+----------------------------------------------------------+
|                 Performance on sfxs_test                 |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 21.30 | 34.30 | 41.30 | 44.70 | 47.20 | 49.40 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from PIL import ImageOps
from torchvision import transforms

# Channel statistics used by the evaluation transform's Normalize step.
IMAGENET_MEAN_STD = {'mean': [0.485, 0.456, 0.406],
                     'std': [0.229, 0.224, 0.225]}

def denormalize(img, mean, std):
    '''
    Undo the per-channel normalisation applied by the dataloader transform
    (x_norm = (x - mean) / std) so images display with their real colours.

    Args:
        img: normalised image tensor, channels first (C, H, W).
        mean, std: per-channel sequences that were used for normalisation.

    Returns:
        img * std + mean, with the same dtype and device as `img`.
    '''
    # Build the statistics on img's device/dtype so CUDA or half-precision
    # tensors work and the result is not silently promoted to float32.
    mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std + mean

def frame(pred_img, p_index, query_gt, border=10):
    '''
    Surround a predicted image with a coloured border for TensorBoard:
    green if `p_index` is one of the query's ground-truth references,
    red otherwise.

    Args:
        pred_img: image tensor, channels first; float values are expected
            in [0, 1] (e.g. the output of denormalize).
        p_index: index of the predicted reference image.
        query_gt: container of the query's ground-truth reference indices.
        border: border width in pixels (default 10).

    Returns:
        The framed image converted back with ToTensor.
    '''
    # ToPILImage turns float tensors into bytes via mul(255).byte(), which does
    # not saturate; clamp so rounding overshoots from denormalize cannot wrap.
    if pred_img.is_floating_point():
        pred_img = pred_img.clamp(0, 1)
    pil_img = transforms.ToPILImage()(pred_img)

    # the predicted image is in the ground truth -> green, otherwise red
    border_color = (0, 255, 0) if p_index in query_gt else (255, 0, 0)
    pil_img = ImageOps.expand(pil_img, border=border, fill=border_color)

    return transforms.ToTensor()(pil_img)

# TensorBoard run that receives a few qualitative retrieval examples.
logger = TensorBoardLogger('/content/drive/MyDrive/Project/code/LOGS/tb_logs')
queries_to_log = 5
predictions_to_log = 5

# A private RandomState(42) draws exactly the same indices as
# np.random.seed(42) + np.random.choice, without reseeding NumPy's global RNG.
rng = np.random.RandomState(42)
# choice(..., replace=False) raises if asked for more queries than exist.
query_indices = rng.choice(num_queries, min(queries_to_log, num_queries), replace=False)

# Retrieved reference indices for each sampled query (presumably best-first).
prediction_indices = preds[query_indices]

for i, q_index in enumerate(query_indices):
    # Queries come after the references in the dataset, hence the offset.
    query_img = val_dataset[q_index + num_references][0]
    query_img = denormalize(query_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
    logger.experiment.add_image(f'query_{q_index}', query_img)
    for p_index in prediction_indices[i][:predictions_to_log]:
        predicted_img = val_dataset[p_index][0]
        predicted_img = denormalize(predicted_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
        # Green border if the prediction is in this query's ground truth, red otherwise.
        predicted_img = frame(predicted_img, p_index, ground_truth[q_index])
        logger.experiment.add_image(f'query_{q_index}/prediction_{p_index}', predicted_img)

# Write pending events now so the TensorBoard cell below shows the images immediately.
logger.experiment.flush()

In [None]:
# Open TensorBoard inline on the log directory written by the logging cell above.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/Project/code/LOGS/tb_logs

# STEP 3

## Creating the .npy index files (database list, query list, ground truth) for the TokyoXS dataset

In [None]:
import os
import numpy as np
import math

# Destination folder for the .npy index files read by TokyoXSDataset.
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/TokyoXS"
# Extracted tokyo_xs test split (see the extraction cell at the top). Database
# AND queries are listed from this single local copy so both lists always come
# from the same version of the dataset.
# NOTE(review): queries used to be listed from
# /content/drive/MyDrive/Project/code/datasets/tokyo_xs/test/queries — confirm
# that copy matches the one inside tokyo_xs.zip.
SRC_ROOT = "/content/datasets/tokyo_xs/test"
# A database image is a positive for a query when their coordinates are at most
# this far apart (presumably metres — confirm against the dataset convention).
THRESHOLD = 25


def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two image names.

  Names are '@'-separated; fields 1 and 2 are read as planar coordinates
  (presumably UTM easting/northing — confirm).
  """
  img1_c1, img1_c2 = float(image1.split('@')[1]), float(image1.split('@')[2])
  img2_c1, img2_c2 = float(image2.split('@')[1]), float(image2.split('@')[2])
  return math.sqrt((img1_c1 - img2_c1)**2 + (img1_c2 - img2_c2)**2)


def _list_images(folder, prefix):
  """Sorted file names of `folder`, each prefixed with `prefix + "/"`, as a numpy array.

  Sorting makes the saved order (and therefore the ground-truth indices)
  reproducible; os.listdir returns entries in arbitrary order.
  """
  return np.array([prefix + "/" + name for name in sorted(os.listdir(folder))])


def _coords(imagelist):
  """(N, 2) float array holding fields 1 and 2 of every '@'-separated image name."""
  coords = np.empty((len(imagelist), 2))
  for i, name in enumerate(imagelist):
    fields = name.split('@')
    coords[i] = float(fields[1]), float(fields[2])
  return coords


def _build_ground_truth(q_imagelist, db_imagelist, threshold=THRESHOLD):
  """For every query, the list of database indices within `threshold` of it.

  Same rule as imageDist(q, db) <= threshold over every pair, but coordinates
  are parsed once and distances are computed with numpy one query at a time
  instead of O(Q*D) Python-level calls.

  Returns a 1-D object array whose i-th element is a list of ints. It is
  filled element by element because np.array(list_of_lists, dtype=object)
  silently becomes a 2-D array whenever all the lists have the same length.
  """
  db_xy = _coords(db_imagelist)
  q_xy = _coords(q_imagelist)
  gt = np.empty(len(q_xy), dtype=object)
  for i, (qx, qy) in enumerate(q_xy):
    dist = np.sqrt((db_xy[:, 0] - qx)**2 + (db_xy[:, 1] - qy)**2)
    gt[i] = np.flatnonzero(dist <= threshold).tolist()
  return gt


# creating tokyoxs_dbImages.npy
db_imagelist = _list_images(f"{SRC_ROOT}/database", "database")
np.save(DEST_PATH+"/tokyoxs_dbImages.npy", db_imagelist)
print(f"tokyoxs_dbImages.npy created with {np.size(db_imagelist)} elements")

# creating tokyoxs_qImages.npy
q_imagelist = _list_images(f"{SRC_ROOT}/queries", "queries")
np.save(DEST_PATH+"/tokyoxs_qImages.npy", q_imagelist)
print(f"tokyoxs_qImages.npy created with {np.size(q_imagelist)} elements")

# creating tokyoxs_gt.npy
gt_list = _build_ground_truth(q_imagelist, db_imagelist)
np.save(DEST_PATH+"/tokyoxs_gt.npy", gt_list)
print(f"tokyoxs_gt.npy created with {np.size(gt_list)} elements")

tokyoxs_dbImages.npy created with 12771 elements
tokyoxs_qImages.npy created with 315 elements
tokyoxs_gt.npy created with 315 elements


## Running inference on the TokyoXS dataset

In [None]:
# Evaluate the loaded model on the TokyoXS dataset.
val_dataset_name = 'tokyoxs'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, num_workers=4, batch_size=batch_size)

# One descriptor per image in dataset order (references first, then queries).
# get_descriptors already returns a CPU tensor, so no extra .cpu() is needed.
descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The reference/query split below is purely positional: a count mismatch would
# silently pair queries with the wrong ground truth, so fail loudly instead.
assert descriptors.shape[0] == num_references + num_queries, (
    f'{descriptors.shape[0]} descriptors for {num_references} references + {num_queries} queries')

r_list = descriptors[:num_references]
q_list = descriptors[num_references:]
recalls_dict, preds = get_validation_recalls(r_list=r_list,
                                    q_list=q_list,
                                    k_values=[1, 5, 10, 15, 20, 25],
                                    gt=ground_truth,
                                    print_results=True,
                                    dataset_name=val_dataset_name,
                                    )

Calculating descritptors...:   0%|          | 0/328 [00:00<?, ?it/s]

Descriptor dimension 256


+----------------------------------------------------------+
|                  Performance on tokyoxs                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 29.21 | 47.62 | 55.24 | 63.17 | 68.25 | 71.75 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from PIL import ImageOps
from torchvision import transforms

# Channel statistics used by the evaluation transform's Normalize step.
IMAGENET_MEAN_STD = {'mean': [0.485, 0.456, 0.406],
                     'std': [0.229, 0.224, 0.225]}

def denormalize(img, mean, std):
    '''
    Undo the per-channel normalisation applied by the dataloader transform
    (x_norm = (x - mean) / std) so images display with their real colours.

    Args:
        img: normalised image tensor, channels first (C, H, W).
        mean, std: per-channel sequences that were used for normalisation.

    Returns:
        img * std + mean, with the same dtype and device as `img`.
    '''
    # Build the statistics on img's device/dtype so CUDA or half-precision
    # tensors work and the result is not silently promoted to float32.
    mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img * std + mean

def frame(pred_img, p_index, query_gt, border=10):
    '''
    Surround a predicted image with a coloured border for TensorBoard:
    green if `p_index` is one of the query's ground-truth references,
    red otherwise.

    Args:
        pred_img: image tensor, channels first; float values are expected
            in [0, 1] (e.g. the output of denormalize).
        p_index: index of the predicted reference image.
        query_gt: container of the query's ground-truth reference indices.
        border: border width in pixels (default 10).

    Returns:
        The framed image converted back with ToTensor.
    '''
    # ToPILImage turns float tensors into bytes via mul(255).byte(), which does
    # not saturate; clamp so rounding overshoots from denormalize cannot wrap.
    if pred_img.is_floating_point():
        pred_img = pred_img.clamp(0, 1)
    pil_img = transforms.ToPILImage()(pred_img)

    # the predicted image is in the ground truth -> green, otherwise red
    border_color = (0, 255, 0) if p_index in query_gt else (255, 0, 0)
    pil_img = ImageOps.expand(pil_img, border=border, fill=border_color)

    return transforms.ToTensor()(pil_img)

# TensorBoard run that receives a few qualitative retrieval examples.
logger = TensorBoardLogger('/content/drive/MyDrive/Project/code/LOGS/tb_logs')
queries_to_log = 5
predictions_to_log = 5

# A private RandomState(42) draws exactly the same indices as
# np.random.seed(42) + np.random.choice, without reseeding NumPy's global RNG.
rng = np.random.RandomState(42)
# choice(..., replace=False) raises if asked for more queries than exist.
query_indices = rng.choice(num_queries, min(queries_to_log, num_queries), replace=False)

# Retrieved reference indices for each sampled query (presumably best-first).
prediction_indices = preds[query_indices]

for i, q_index in enumerate(query_indices):
    # Queries come after the references in the dataset, hence the offset.
    query_img = val_dataset[q_index + num_references][0]
    query_img = denormalize(query_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
    logger.experiment.add_image(f'query_{q_index}', query_img)
    for p_index in prediction_indices[i][:predictions_to_log]:
        predicted_img = val_dataset[p_index][0]
        predicted_img = denormalize(predicted_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
        # Green border if the prediction is in this query's ground truth, red otherwise.
        predicted_img = frame(predicted_img, p_index, ground_truth[q_index])
        logger.experiment.add_image(f'query_{q_index}/prediction_{p_index}', predicted_img)

# Write pending events now so the TensorBoard cell below shows the images immediately.
logger.experiment.flush()

In [None]:
# Open TensorBoard inline on the log directory written by the logging cell above.
# NOTE(review): the extension may already be loaded by the earlier TensorBoard
# cell, in which case %load_ext only prints a notice.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/Project/code/LOGS/tb_logs