# STEP 6

In [None]:
# Mount Google Drive so that the project files referenced below under
# /content/drive/MyDrive/Project (requirements, datasets, code, checkpoints)
# are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Installing the libraries and extracting the datasets from Drive to local disk

In [None]:
!pip install -r /content/drive/MyDrive/Project/code/requirements.txt

Collecting faiss-cpu==1.7.2 (from -r /content/drive/MyDrive/Project/code/requirements.txt (line 1))
  Downloading faiss_cpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 MB[0m [31m63.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting faiss-gpu==1.7.2 (from -r /content/drive/MyDrive/Project/code/requirements.txt (line 2))
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting matplotlib==3.5.3 (from -r /content/drive/MyDrive/Project/code/requirements.txt (line 3))
  Downloading matplotlib-3.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.9/11.9 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy==1.23

In [None]:
import zipfile
from tqdm.notebook import tqdm

# Zipped datasets stored on Google Drive.
SRC_PATHS = [
    '/content/drive/MyDrive/Project/code/datasets/gsv_xs.zip',
    '/content/drive/MyDrive/Project/code/datasets/sf_xs.zip',
    '/content/drive/MyDrive/Project/code/datasets/tokyo_xs.zip',
]

# Directory that receives the content of every archive.
DST_PATH = '/content/datasets'

# Unpack the archives one by one; the progress bar advances once per archive.
for archive_path in tqdm(SRC_PATHS, desc='Uploading the datasets from Drive to disk'):
    with zipfile.ZipFile(archive_path, mode='r') as archive:
        archive.extractall(path=DST_PATH)

Uploading the datasets from Drive to disk:   0%|          | 0/3 [00:00<?, ?it/s]

## Creating npy files from SanFranciscoXS-val and SanFranciscoXS-test datasets

In [None]:
import os
import numpy as np
import math

# destination path
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/SanFranciscoXS"

# A database image is a correct match for a query when the distance between
# their file-name coordinates is at most THRESHOLD (same units as those
# coordinates; presumably metres -- TODO confirm against the dataset docs).
THRESHOLD = 25


def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two file names.

  The 2nd and 3rd '@'-separated fields of each name are parsed as floats and
  used as the two coordinates. This is the scalar form of the distance that
  `_build_ground_truth` evaluates in bulk.
  """
  img1_c1, img1_c2 = float(image1.split('@')[1]), float(image1.split('@')[2])
  img2_c1, img2_c2 = float(image2.split('@')[1]), float(image2.split('@')[2])
  return math.sqrt((img1_c1 - img2_c1)**2 + (img1_c2 - img2_c2)**2)


def _image_coords(image_names):
  """Parse the two coordinate fields of every file name into an (N, 2) float array."""
  coords = np.empty((len(image_names), 2), dtype=float)
  for row, name in enumerate(image_names):
    fields = name.split('@')
    coords[row] = float(fields[1]), float(fields[2])
  return coords


def _build_ground_truth(q_names, db_names, threshold):
  """For every query, list the indices of the database images within `threshold`.

  Args:
    q_names: sequence of query file names.
    db_names: sequence of database file names; returned indices refer to it.
    threshold: maximum coordinate distance for a database image to be a match.

  Returns:
    A 1-D object array with one Python list of int indices per query.
  """
  # File names are parsed once instead of once per (query, database) pair.
  db_coords = _image_coords(db_names)
  # Pre-allocating a 1-D object array keeps one list per query even when all
  # lists happen to have the same length (np.array(..., dtype=object) would
  # silently build a 2-D array in that case).
  gt = np.empty(len(q_names), dtype=object)
  for row, (q_c1, q_c2) in enumerate(_image_coords(q_names)):
    dist = np.sqrt((db_coords[:, 0] - q_c1)**2 + (db_coords[:, 1] - q_c2)**2)
    gt[row] = np.flatnonzero(dist <= threshold).tolist()
  return gt


datasets = ['val', 'test']
for ds in datasets:
  # creating sfxs_ds_dbImages.npy
  # sorted(): os.listdir() returns entries in arbitrary order, which would make
  # the saved index files differ from one run to the next.
  db_imagelist = sorted(os.listdir(f"/content/datasets/sf_xs/{ds}/database"))
  db_imagelist = np.array(["database/"+image for image in db_imagelist])
  np.save(DEST_PATH+f"/sfxs_{ds}_dbImages.npy", db_imagelist)
  print(f"sfxs_{ds}_dbImages.npy created with {np.size(db_imagelist)} elements")

  # creating sfxs_ds_qImages.npy
  q_imagelist = sorted(os.listdir(f"/content/datasets/sf_xs/{ds}/queries"))
  q_imagelist = np.array(["queries/"+image for image in q_imagelist])
  np.save(DEST_PATH+f"/sfxs_{ds}_qImages.npy", q_imagelist)
  print(f"sfxs_{ds}_qImages.npy created with {np.size(q_imagelist)} elements")

  # creating sfxs_ds_gt.npy (indices refer to positions in db_imagelist)
  gt_list = _build_ground_truth(q_imagelist, db_imagelist, THRESHOLD)
  np.save(DEST_PATH+f"/sfxs_{ds}_gt.npy", gt_list)
  print(f"sfxs_{ds}_gt.npy created with {np.size(gt_list)} elements")


sfxs_val_dbImages.npy created with 8015 elements
sfxs_val_qImages.npy created with 7993 elements
sfxs_val_gt.npy created with 7993 elements
sfxs_test_dbImages.npy created with 27191 elements
sfxs_test_qImages.npy created with 1000 elements
sfxs_test_gt.npy created with 1000 elements


## Training the model on GSV-CitiesXS dataset

In [None]:
# Launch the training script; its console output (dataset and model summaries,
# followed by one recall table on sfxs_val per validation pass) is shown below.
%run /content/drive/MyDrive/Project/code/main.py

  rank_zero_deprecation(
INFO:lightning_lite.utilities.seed:Global seed set to 1
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit native Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Loading gsv_cities.csv

+----------------------+
|   Training Dataset   |
+-------------+--------+
| # of cities | 23     |
| # of places | 62514  |
| # of images | 524701 |
+-------------+--------+

+-----------------------------+
|     Validation Datasets     |
+------------------+----------+
| Validation set 1 | sfxs_val |
+------------------+----------+

+-------------------------------+
|        Training config        |
+------------------+------------+
| Batch size (PxK) | 100x4      |
| # of iterations  | 625        |
| Image size       | (224, 224) |
+------------------+------------+


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type                 | Params
----------------------------------------------------
0 | loss_fn    | MultiSimilarityLoss  | 0     
1 | miner      | MultiSimilarityMiner | 0     
2 | backbone   | ResNet               | 2.8 M 
3 | aggregator | MixVPR               | 454 K 
----------------------------------------------------
2.6 M     Trainable params
683 K     Non-trainable params
3.2 M     Total params
6.475     Total estimated model params size (MB)


Loading gsv_cities.csv


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 65.13 | 74.78 | 78.47 | 80.37 | 81.77 | 82.99 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 68.90 | 77.99 | 81.41 | 83.16 | 84.50 | 85.55 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 71.36 | 79.76 | 83.10 | 84.84 | 85.94 | 87.08 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 72.65 | 81.15 | 84.21 | 86.08 | 87.21 | 88.23 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 73.39 | 81.45 | 84.39 | 86.14 | 87.26 | 87.91 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 74.77 | 82.62 | 85.32 | 87.24 | 88.33 | 89.18 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 74.78 | 82.67 | 85.55 | 87.36 | 88.64 | 89.55 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.52 | 83.12 | 86.20 | 87.90 | 89.32 | 90.13 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.55 | 83.12 | 86.18 | 87.68 | 88.95 | 89.90 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.52 | 82.99 | 86.10 | 87.80 | 88.98 | 89.83 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.88 | 83.02 | 86.40 | 88.15 | 89.27 | 90.07 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.09 | 83.56 | 86.64 | 88.43 | 89.54 | 90.32 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.12 | 83.64 | 86.55 | 88.36 | 89.47 | 90.27 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.12 | 83.71 | 86.83 | 88.55 | 89.60 | 90.33 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.18 | 83.55 | 86.54 | 88.24 | 89.47 | 90.32 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.99 | 83.47 | 86.34 | 88.23 | 89.27 | 90.05 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.98 | 83.47 | 86.34 | 88.23 | 89.28 | 89.99 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.08 | 83.40 | 86.39 | 88.21 | 89.35 | 90.13 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.07 | 83.42 | 86.36 | 88.23 | 89.25 | 90.09 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.20 | 83.50 | 86.54 | 88.35 | 89.40 | 90.22 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.20 | 83.41 | 86.45 | 88.38 | 89.34 | 90.09 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.22 | 83.52 | 86.61 | 88.36 | 89.38 | 90.19 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.18 | 83.32 | 86.45 | 88.24 | 89.30 | 90.04 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.23 | 83.45 | 86.56 | 88.34 | 89.47 | 90.09 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.25 | 83.49 | 86.59 | 88.41 | 89.42 | 90.23 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.05 | 83.15 | 86.45 | 88.14 | 89.18 | 89.98 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.00 | 83.20 | 86.43 | 88.19 | 89.18 | 89.85 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.03 | 83.20 | 86.38 | 88.20 | 89.13 | 89.88 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.03 | 83.32 | 86.43 | 88.26 | 89.24 | 89.94 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.08 | 83.29 | 86.39 | 88.19 | 89.24 | 89.98 |
+----------+-------+-------+-------+-------+-------+-------+


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.







## Doing inference on SanFranciscoXS-val and SanFranciscoXS-test datasets


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from drive.MyDrive.Project.code.utils.validation import get_validation_recalls
from drive.MyDrive.Project.code.dataloaders.val.SanFranciscoXSDataset import SanFranciscoXSDataset
from drive.MyDrive.Project.code.dataloaders.val.TokyoXSDataset import TokyoXSDataset
from drive.MyDrive.Project.code.main import VPRModel

# Per-channel statistics used to normalise the evaluation images.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Size every image is resized to before being fed to the model.
IM_SIZE = (224, 224)

def input_transform(image_size=IM_SIZE):
  """Build the preprocessing pipeline applied to every evaluation image.

  Args:
    image_size: target size handed to ``T.Resize``.

  Returns:
    A ``T.Compose`` that resizes with bilinear interpolation, converts the
    image to a tensor and normalises it with MEAN / STD.
  """
  steps = [
    T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
  ]
  return T.Compose(steps)

def get_val_dataset(dataset_name, input_transform=input_transform()):
  """Instantiate one of the evaluation datasets by name.

  Args:
    dataset_name: case-insensitive name containing 'sfxs_val', 'sfxs_test'
      or 'tokyoxs'.
    input_transform: transform handed to the dataset constructor. The default
      is the pipeline built by ``input_transform()`` at definition time.

  Returns:
    Tuple ``(dataset, num_references, num_queries, ground_truth)``; the last
    three are read from the dataset's attributes of the same name.

  Raises:
    ValueError: if ``dataset_name`` matches none of the known datasets.
  """
  dataset_name = dataset_name.lower()

  # Precedence when a name contains several keys: tokyoxs, then sfxs_test,
  # then sfxs_val. Only the selected dataset is instantiated.
  if 'tokyoxs' in dataset_name:
    ds = TokyoXSDataset(input_transform = input_transform)
  elif 'sfxs_test' in dataset_name:
    # Relies on SanFranciscoXSDataset's default split
    # (presumably 'test' -- TODO confirm in the dataset class).
    ds = SanFranciscoXSDataset(input_transform = input_transform)
  elif 'sfxs_val' in dataset_name:
    ds = SanFranciscoXSDataset(which_ds = 'val', input_transform = input_transform)
  else:
    raise ValueError(
      f"Unknown dataset name {dataset_name!r}: "
      "expected one containing 'sfxs_val', 'sfxs_test' or 'tokyoxs'")

  num_references = ds.num_references
  num_queries = ds.num_queries
  ground_truth = ds.ground_truth
  return ds, num_references, num_queries, ground_truth

def get_descriptors(model, dataloader, device):
  """Run `model` over every batch of `dataloader` and stack the outputs.

  Args:
    model: callable mapping a batch of images to a batch of descriptors.
    dataloader: iterable of ``(images, labels)`` pairs; the labels are ignored.
    device: device the images are moved to before the forward pass.

  Returns:
    A CPU tensor with the outputs of all batches concatenated along dim 0.
  """
  # Gradients are not needed for inference; each output is moved back to the
  # CPU right away so only one batch at a time lives on `device`.
  with torch.no_grad():
    per_batch = [
      model(images.to(device)).cpu()
      for images, _ in tqdm(dataloader, 'Calculating descritptors...')
    ]

  return torch.cat(per_batch)

# define which device you'd like run experiments on (cuda:0 if you only have one gpu)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The architecture arguments must describe the same network the checkpoint
# below was saved from, otherwise load_state_dict() rejects the weights.
model = VPRModel(backbone_arch='resnet18',
                 layers_to_crop=[4],
                 agg_arch='MixVPR',
                 agg_config={'in_channels' : 256,
                             'in_h' : 14,
                             'in_w' : 14,
                             'out_channels' : 256,
                             'mix_depth' : 5,
                             'mlp_ratio' : 1,
                             'out_rows' : 4},
                )

# Checkpoint selected for evaluation.
CKPT_PATH = '/content/drive/MyDrive/Project/code/LOGS/resnet18/lightning_logs/version_0/checkpoints/resnet18_epoch(24)_step(15650)_R1[0.7625]_R5[0.8349].ckpt'

# map_location remaps the stored tensors onto `device`, so the CPU fallback
# chosen above also works for a checkpoint that was saved from a GPU.
state_dict = torch.load(CKPT_PATH, map_location=device)
# The model weights are stored under the 'state_dict' key of the checkpoint.
model.load_state_dict(state_dict['state_dict'])
model.eval()
model = model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [None]:
# Evaluate the loaded checkpoint on the dataset selected by name.
val_dataset_name = 'sfxs_val'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first `num_references` descriptors are used as references (database),
# everything after them as queries.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()
recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=val_dataset_name,
)

Calculating descritptors...:   0%|          | 0/401 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 76.27 | 83.50 | 86.60 | 88.43 | 89.42 | 90.23 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
# Evaluate the loaded checkpoint on the dataset selected by name.
val_dataset_name = 'sfxs_test'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first `num_references` descriptors are used as references (database),
# everything after them as queries.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()
recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=val_dataset_name,
)

Calculating descritptors...:   0%|          | 0/705 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                 Performance on sfxs_test                 |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 55.10 | 66.50 | 71.50 | 73.30 | 74.80 | 75.90 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from PIL import ImageOps
from torchvision import transforms

# Channel-wise statistics that `denormalize` is called with below.
IMAGENET_MEAN_STD = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}

def denormalize(img, mean, std):
    """Undo a channel-wise normalisation so the image can be visualised.

    The images coming out of the dataset have been normalised by the input
    transform; this applies the inverse operation, ``img * std + mean``.

    Args:
        img: image tensor with the channel dimension first.
        mean: per-channel means used for the normalisation.
        std: per-channel standard deviations used for the normalisation.

    Returns:
        The de-normalised image tensor.
    """
    # Shape (C, 1, 1) so the statistics broadcast over the spatial dimensions.
    channel_std = torch.tensor(std).view(-1, 1, 1)
    channel_mean = torch.tensor(mean).view(-1, 1, 1)
    rescaled = img * channel_std
    return rescaled + channel_mean

def frame(pred_img, p_index, query_gt):
    """Surround a retrieved image with a border that marks it right or wrong.

    Args:
        pred_img: image tensor of the retrieved (predicted) image.
        p_index: index of the retrieved image.
        query_gt: collection of indices that are correct for the query.

    Returns:
        The image as a tensor, enlarged by a 10-pixel border that is green
        when ``p_index`` is in ``query_gt`` and red otherwise.
    """
    pil_img = transforms.ToPILImage()(pred_img)

    # green: the prediction is in the ground truth, red: it is not
    border_color = (0, 255, 0) if p_index in query_gt else (255, 0, 0)

    framed = ImageOps.expand(pil_img, border=10, fill=border_color)
    return transforms.ToTensor()(framed)

logger = TensorBoardLogger('/content/drive/MyDrive/Project/code/LOGS/tb_logs')
queries_to_log = 5
predictions_to_log = 5

# Fixed seed so the same queries are drawn every time the cell is run.
np.random.seed(42)
query_indices = np.random.choice(num_queries, queries_to_log, replace=False)

# Retrieved database indices for each of the sampled queries.
prediction_indices = preds[query_indices]

for q_index, ranked_predictions in zip(query_indices, prediction_indices):
    # Queries are addressed after the `num_references` reference images.
    query_img = denormalize(val_dataset[q_index + num_references][0], **IMAGENET_MEAN_STD)
    logger.experiment.add_image(f'query_{q_index}', query_img)

    for p_index in ranked_predictions[:predictions_to_log]:
        predicted_img = denormalize(val_dataset[p_index][0], **IMAGENET_MEAN_STD)
        predicted_img = frame(predicted_img, p_index, ground_truth[q_index])
        logger.experiment.add_image(f'query_{q_index}/prediction_{p_index}', predicted_img)

## Creating npy files from TokyoXS dataset

In [None]:
import os
import numpy as np
import math

# destination path
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/TokyoXS"

# creating tokyoxs_dbImages.npy
# sorted(): os.listdir() returns entries in arbitrary order, which would make
# the saved index files differ from one run to the next.
db_imagelist = sorted(os.listdir("/content/datasets/tokyo_xs/test/database"))
db_imagelist = np.array(["database/"+image for image in db_imagelist])
np.save(DEST_PATH+"/tokyoxs_dbImages.npy", db_imagelist)
print(f"tokyoxs_dbImages.npy created with {np.size(db_imagelist)} elements")

# creating tokyoxs_qImages.npy
# Queries are read from the same locally extracted copy as the database images
# (the archives are only unpacked under /content/datasets, not on Drive).
q_imagelist = sorted(os.listdir("/content/datasets/tokyo_xs/test/queries"))
q_imagelist = np.array(["queries/"+image for image in q_imagelist])
np.save(DEST_PATH+"/tokyoxs_qImages.npy", q_imagelist)
print(f"tokyoxs_qImages.npy created with {np.size(q_imagelist)} elements")

# creating tokyoxs_gt.npy
def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two file names.

  The 2nd and 3rd '@'-separated fields of each name are parsed as floats and
  used as the two coordinates of the image.
  """
  first = image1.split('@')
  first_c1, first_c2 = float(first[1]), float(first[2])
  second = image2.split('@')
  second_c1, second_c2 = float(second[1]), float(second[2])

  delta_c1 = first_c1 - second_c1
  delta_c2 = first_c2 - second_c2
  return math.sqrt(delta_c1**2 + delta_c2**2)

# Maximum distance (in the units of the file-name coordinates) for a database
# image to count as a correct match for a query.
THRESHOLD = 25

# One list of matching database indices per query. The 1-D object array is
# allocated up front: np.array(list_of_lists, dtype=object) would silently
# turn into a 2-D array whenever all lists have the same length.
gt_list = np.empty(len(q_imagelist), dtype=object)
for row, q_img in enumerate(q_imagelist):
  gt_list[row] = [db_index
                  for db_index, db_img in enumerate(db_imagelist)
                  if imageDist(q_img, db_img) <= THRESHOLD]
np.save(DEST_PATH+"/tokyoxs_gt.npy", gt_list)
print(f"tokyoxs_gt.npy created with {np.size(gt_list)} elements")

tokyoxs_dbImages.npy created with 12771 elements
tokyoxs_qImages.npy created with 315 elements
tokyoxs_gt.npy created with 315 elements


## Doing inference on TokyoXS dataset

In [None]:
# Evaluate the loaded checkpoint on the dataset selected by name.
val_dataset_name = 'tokyoxs'
batch_size = 40

val_dataset, num_references, num_queries, ground_truth = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first `num_references` descriptors are used as references (database),
# everything after them as queries.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()
recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=val_dataset_name,
)

Calculating descritptors...:   0%|          | 0/328 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                  Performance on tokyoxs                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 73.65 | 85.40 | 89.52 | 91.43 | 93.65 | 93.97 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from PIL import ImageOps
from torchvision import transforms

# IMAGENET_MEAN_STD, denormalize() and frame() are defined once in the
# SanFranciscoXS visualisation cell earlier in this notebook and are reused
# here instead of being redefined with identical bodies.

logger = TensorBoardLogger('/content/drive/MyDrive/Project/code/LOGS/tb_logs')
queries_to_log = 5
predictions_to_log = 5

# Fixed seed so the same queries are drawn every time the cell is run.
np.random.seed(42)
query_indices = np.random.choice(num_queries, queries_to_log, replace=False)

# Retrieved database indices for each of the sampled queries.
prediction_indices = preds[query_indices]

for i, q_index in enumerate(query_indices):
    # Queries are addressed after the `num_references` reference images.
    query_img = val_dataset[q_index + num_references][0]
    query_img = denormalize(query_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
    logger.experiment.add_image(f'query_{q_index}', query_img)
    for p_index in prediction_indices[i][:predictions_to_log]:
        predicted_img = val_dataset[p_index][0]
        predicted_img = denormalize(predicted_img, IMAGENET_MEAN_STD['mean'], IMAGENET_MEAN_STD['std'])
        # green border: prediction is in the query's ground truth, red: it is not
        predicted_img = frame(predicted_img, p_index, ground_truth[q_index])
        logger.experiment.add_image(f'query_{q_index}/prediction_{p_index}', predicted_img)

Model hyperparameters, loss, optimizer and miner used for the training run above:
model = VPRModel(
        backbone_arch='resnet18',
        pretrained=True,
        layers_to_freeze=2,
        layers_to_crop=[4],

        agg_arch='MixVPR',
        agg_config={'in_channels' : 256,
                    'in_h' : 14,
                    'in_w' : 14,
                    'out_channels' : 256,
                    'mix_depth' : 5,
                    'mlp_ratio' : 1,
                    'out_rows' : 4},

        lr=0.0002, # 0.03 for sgd
        optimizer='adam', # sgd, adam or adamw
        weight_decay=0, # 0.001 for sgd or 0.0 for adam
        momentum=0.9,
        warmpup_steps=600,
        milestones=[5, 10, 15, 25],
        lr_mult=0.3,
        
        #---------------------------------
        #---- Training loss function -----
        # see utils.losses.py for more losses
        # example: ContrastiveLoss, TripletMarginLoss, MultiSimilarityLoss,
        # FastAPLoss, CircleLoss, SupConLoss,
        #
        loss_name='MultiSimilarityLoss',
        miner_name='MultiSimilarityMiner', # example: TripletMarginMiner, MultiSimilarityMiner, PairMarginMiner
        miner_margin=0.1,
        faiss_gpu=False
        )             