# STEP 6

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; every project path used below
# (requirements, dataset archives, generated .npy files, checkpoints) lives
# under /content/drive/MyDrive/Project.
drive.mount('/content/drive')

Mounted at /content/drive


## Installing the libraries and uploading the datasets

In [None]:
# Install the project dependencies listed in requirements.txt.
# %pip (instead of !pip) installs into the environment of the running kernel
# rather than whatever `pip` the subshell resolves to.
%pip install -r /content/drive/MyDrive/Project/code/requirements.txt

In [None]:
import zipfile
from tqdm.notebook import tqdm

# Local directory every archive is unpacked into.
DST_PATH = '/content/datasets'

# Dataset archives stored on Drive.
SRC_PATHS = [
  '/content/drive/MyDrive/Project/code/datasets/gsv_xs.zip',
  '/content/drive/MyDrive/Project/code/datasets/sf_xs.zip',
  '/content/drive/MyDrive/Project/code/datasets/tokyo_xs.zip',
]

# Unpack each archive in turn, showing one progress tick per archive.
for path in tqdm(SRC_PATHS, 'Uploading the datasets from Drive to disk'):
  with zipfile.ZipFile(path, 'r') as archive:
    archive.extractall(DST_PATH)

Uploading the datasets from Drive to disk:   0%|          | 0/3 [00:00<?, ?it/s]

## Creating npy files from SanFranciscoXS-val and SanFranciscoXS-test datasets

In [None]:
import os
import numpy as np
import math

# destination path
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/SanFranciscoXS"
os.makedirs(DEST_PATH, exist_ok=True)  # np.save does not create missing folders

# A database image counts as a positive for a query when the distance between
# the coordinates embedded in their file names is at most THRESHOLD
# (presumably metres, if the coordinates are UTM -- TODO confirm).
THRESHOLD = 25


def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two image names.

  Each name is split on '@'; fields 1 and 2 are parsed as floats and used as
  the two coordinates of the image.

  Raises:
    IndexError: if a name has fewer than three '@'-separated fields.
    ValueError: if a coordinate field is not a valid float.
  """
  img1_c1, img1_c2 = float(image1.split('@')[1]), float(image1.split('@')[2])
  img2_c1, img2_c2 = float(image2.split('@')[1]), float(image2.split('@')[2])
  return math.sqrt((img1_c1 - img2_c1)**2 + (img1_c2 - img2_c2)**2)


def _parse_coords(image_names):
  """Return an (N, 2) float array holding fields 1 and 2 of every '@'-split name."""
  coords = []
  for name in image_names:
    fields = name.split('@')
    coords.append((float(fields[1]), float(fields[2])))
  # reshape keeps the result two-dimensional even when image_names is empty
  return np.array(coords, dtype=float).reshape(-1, 2)


datasets = ['val', 'test']
for ds in datasets:
  # creating sfxs_ds_dbImages.npy
  # os.listdir returns entries in arbitrary order; sorting makes the saved
  # index files reproducible from one run to the next.
  db_imagelist = sorted(os.listdir(f"/content/datasets/sf_xs/{ds}/database"))
  db_imagelist = np.array(["database/"+image for image in db_imagelist])
  np.save(DEST_PATH+f"/sfxs_{ds}_dbImages.npy", db_imagelist)
  print(f"sfxs_{ds}_dbImages.npy created with {np.size(db_imagelist)} elements")

  # creating sfxs_ds_qImages.npy
  q_imagelist = sorted(os.listdir(f"/content/datasets/sf_xs/{ds}/queries"))
  q_imagelist = np.array(["queries/"+image for image in q_imagelist])
  np.save(DEST_PATH+f"/sfxs_{ds}_qImages.npy", q_imagelist)
  print(f"sfxs_{ds}_qImages.npy created with {np.size(q_imagelist)} elements")

  # creating sfxs_ds_gt.npy
  # Coordinates are parsed once per image, then each query is compared with
  # the whole database in a single vectorised step (same distance formula as
  # imageDist) instead of re-parsing both names for every (query, db) pair.
  db_coords = _parse_coords(db_imagelist)
  q_coords = _parse_coords(q_imagelist)

  # Filled element by element so the result is always a 1-D object array with
  # one list per query; np.array(list_of_lists, dtype=object) would turn into
  # a 2-D array whenever every query happens to have the same number of hits.
  gt_list = np.empty(len(q_imagelist), dtype=object)
  for q_idx, (q_c1, q_c2) in enumerate(q_coords):
    dist = np.sqrt((q_c1 - db_coords[:, 0])**2 + (q_c2 - db_coords[:, 1])**2)
    # indices into db_imagelist of all database images within THRESHOLD
    gt_list[q_idx] = np.flatnonzero(dist <= THRESHOLD).tolist()
  np.save(DEST_PATH+f"/sfxs_{ds}_gt.npy", gt_list)
  print(f"sfxs_{ds}_gt.npy created with {np.size(gt_list)} elements")


sfxs_val_dbImages.npy created with 8015 elements
sfxs_val_qImages.npy created with 7993 elements
sfxs_val_gt.npy created with 7993 elements
sfxs_test_dbImages.npy created with 27191 elements
sfxs_test_qImages.npy created with 1000 elements
sfxs_test_gt.npy created with 1000 elements


## Training the model on GSV-CitiesXS dataset

In [None]:
# Launch training by executing main.py; the output below is its log (dataset and
# config summary, model summary, then one Recall@K table on sfxs_val per validation pass).
%run /content/drive/MyDrive/Project/code/main.py

  rank_zero_deprecation(
INFO:lightning_lite.utilities.seed:Global seed set to 1
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit native Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Loading gsv_cities.csv

+----------------------+
|   Training Dataset   |
+-------------+--------+
| # of cities | 23     |
| # of places | 62514  |
| # of images | 524701 |
+-------------+--------+

+-----------------------------+
|     Validation Datasets     |
+------------------+----------+
| Validation set 1 | sfxs_val |
+------------------+----------+

+-------------------------------+
|        Training config        |
+------------------+------------+
| Batch size (PxK) | 100x4      |
| # of iterations  | 625        |
| Image size       | (224, 224) |
+------------------+------------+


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type               | Params
--------------------------------------------------
0 | loss_fn    | TripletMarginLoss  | 0     
1 | miner      | TripletMarginMiner | 0     
2 | backbone   | ResNet             | 2.8 M 
3 | aggregator | MixVPR             | 454 K 
--------------------------------------------------
2.6 M     Trainable params
683 K     Non-trainable params
3.2 M     Total params
6.475     Total estimated model params size (MB)


Loading gsv_cities.csv


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 53.96 | 67.87 | 73.03 | 76.05 | 78.43 | 80.03 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 56.76 | 69.92 | 75.23 | 78.44 | 80.66 | 82.47 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 58.28 | 71.19 | 76.62 | 79.61 | 81.52 | 83.11 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 58.41 | 71.89 | 77.10 | 79.84 | 82.22 | 83.62 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 60.48 | 72.90 | 77.86 | 80.91 | 82.82 | 84.52 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 61.99 | 74.99 | 79.32 | 81.91 | 83.97 | 85.62 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 62.05 | 74.78 | 79.14 | 81.87 | 83.85 | 85.30 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.06 | 75.39 | 79.92 | 82.55 | 84.65 | 86.04 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 62.87 | 75.38 | 79.96 | 82.80 | 84.62 | 86.09 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.06 | 75.58 | 80.15 | 82.85 | 84.82 | 86.13 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.68 | 75.40 | 80.31 | 83.11 | 85.02 | 86.26 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.82 | 75.78 | 80.58 | 83.27 | 85.06 | 86.44 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.86 | 76.00 | 80.61 | 83.46 | 85.19 | 86.51 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.78 | 76.10 | 80.47 | 83.20 | 85.10 | 86.43 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.22 | 76.14 | 80.90 | 83.52 | 85.40 | 86.70 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.94 | 75.97 | 80.22 | 83.11 | 85.06 | 86.41 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.03 | 76.00 | 80.43 | 83.06 | 85.02 | 86.46 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.77 | 75.99 | 80.31 | 83.04 | 84.96 | 86.53 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.89 | 75.84 | 80.28 | 83.00 | 85.02 | 86.51 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.08 | 76.05 | 80.33 | 83.06 | 85.09 | 86.53 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.99 | 75.89 | 80.12 | 82.87 | 85.09 | 86.44 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.96 | 75.97 | 80.25 | 82.97 | 85.11 | 86.64 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.98 | 76.10 | 80.37 | 83.16 | 85.21 | 86.73 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.84 | 76.12 | 80.36 | 83.10 | 85.26 | 86.58 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.98 | 75.84 | 80.35 | 83.05 | 85.09 | 86.48 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 63.97 | 75.82 | 80.22 | 82.91 | 85.07 | 86.40 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.07 | 75.79 | 80.35 | 82.96 | 85.06 | 86.39 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.06 | 75.98 | 80.33 | 83.11 | 85.14 | 86.53 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]



+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.11 | 75.82 | 80.23 | 83.07 | 85.16 | 86.49 |
+----------+-------+-------+-------+-------+-------+-------+



Loading gsv_cities.csv


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.




+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.07 | 75.80 | 80.41 | 83.05 | 85.14 | 86.49 |
+----------+-------+-------+-------+-------+-------+-------+





## Doing inference on SanFranciscoXS-val and SanFranciscoXS-test datasets


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from drive.MyDrive.Project.code.utils.validation import get_validation_recalls
from drive.MyDrive.Project.code.dataloaders.val.SanFranciscoXSDataset import SanFranciscoXSDataset
from drive.MyDrive.Project.code.dataloaders.val.TokyoXSDataset import TokyoXSDataset
from drive.MyDrive.Project.code.main import VPRModel

# Per-channel mean / standard deviation passed to T.Normalize in input_transform.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Target size given to T.Resize for every image.
IM_SIZE = (224, 224)

def input_transform(image_size=IM_SIZE):
  """Build the preprocessing pipeline applied to every evaluation image.

  Args:
    image_size: size passed to T.Resize; defaults to IM_SIZE.

  Returns:
    A T.Compose that resizes with bilinear interpolation, converts the image
    to a tensor and normalises it with MEAN / STD.
  """
  resize = T.Resize(image_size,  interpolation=T.InterpolationMode.BILINEAR)
  to_tensor = T.ToTensor()
  normalize = T.Normalize(mean=MEAN, std=STD)
  return T.Compose([resize, to_tensor, normalize])

def get_val_dataset(dataset_name, input_transform=input_transform()):
  """Instantiate an evaluation dataset by name and return it with its metadata.

  The name is matched case-insensitively by substring:
    * 'sfxs_val'  -> SanFranciscoXSDataset(which_ds='val')
    * 'sfxs_test' -> SanFranciscoXSDataset() with its default `which_ds`
    * 'tokyoxs'   -> TokyoXSDataset()
  If a name contains several of these keys, the last matching one wins.

  Args:
    dataset_name: name of the dataset to load.
    input_transform: transform handed to the dataset. The default is built
      once, when this function is defined, by calling the module-level
      input_transform(); inside this function the name refers to the
      transform object, not to that factory.

  Returns:
    (dataset, num_references, num_queries, ground_truth), the last three
    being read from the dataset's attributes of the same name.

  Raises:
    ValueError: if dataset_name matches none of the supported keys.
  """
  dataset_name = dataset_name.lower()

  ds = None
  if 'sfxs_val' in dataset_name:
    ds = SanFranciscoXSDataset(which_ds = 'val', input_transform = input_transform)

  if 'sfxs_test' in dataset_name:
    ds = SanFranciscoXSDataset(input_transform = input_transform)

  if 'tokyoxs' in dataset_name:
    ds = TokyoXSDataset(input_transform = input_transform)

  # Without this guard an unknown name fell through to the attribute reads
  # below and failed with an UnboundLocalError that hid the real cause.
  if ds is None:
    raise ValueError(
      f"Unknown dataset name {dataset_name!r}; "
      "expected a name containing 'sfxs_val', 'sfxs_test' or 'tokyoxs'"
    )

  num_references = ds.num_references
  num_queries = ds.num_queries
  ground_truth = ds.ground_truth
  return ds, num_references, num_queries, ground_truth

def get_descriptors(model, dataloader, device):
  """Run `model` over every batch of `dataloader` and stack the outputs.

  Args:
    model: callable applied to each image batch after it is moved to `device`.
    dataloader: iterable yielding (images, labels) pairs; labels are ignored.
    device: device the image batches are moved to before the forward pass.

  Returns:
    A single tensor holding the per-batch outputs concatenated along
    dimension 0, in iteration order, already moved to the CPU.
  """
  # Gradients are not needed for inference.
  with torch.no_grad():
    progress = tqdm(dataloader, 'Calculating descritptors...')
    per_batch = [model(imgs.to(device)).cpu() for imgs, _ in progress]

  return torch.cat(per_batch)

# define which device you'd like run experiments on (cuda:0 if you only have one gpu)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same backbone / aggregator configuration as the one listed at the end of the
# notebook for training, so the checkpoint weights match the module layout.
model = VPRModel(backbone_arch='resnet18',
                 layers_to_crop=[4],
                 agg_arch='MixVPR',
                 agg_config={'in_channels' : 256,
                             'in_h' : 14,
                             'in_w' : 14,
                             'out_channels' : 256,
                             'mix_depth' : 5,
                             'mlp_ratio' : 1,
                             'out_rows' : 4},
                )

# map_location=device lets the checkpoint load when the notebook falls back to
# the CPU (see `device` above); without it torch.load tries to restore the
# tensors onto the device they were saved from.
state_dict = torch.load('/content/drive/MyDrive/Project/code/LOGS/resnet18/lightning_logs/version_0/checkpoints/resnet18_epoch(14)_step(9390)_R1[0.6422]_R5[0.7614].ckpt',
                        map_location=device)
# The model weights sit under the 'state_dict' key of the loaded checkpoint.
model.load_state_dict(state_dict['state_dict'])
model.eval()
model = model.to(device)

In [None]:
# Evaluate the loaded checkpoint on the SanFranciscoXS validation split.
batch_size = 40
val_dataset_name = 'sfxs_val'

(val_dataset,
 num_references,
 num_queries,
 ground_truth) = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first num_references rows are used as reference (database) descriptors,
# everything after them as query descriptors.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()

recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=val_dataset_name,
)

Calculating descritptors...:   0%|          | 0/401 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                 Performance on sfxs_val                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 64.22 | 76.15 | 80.90 | 83.52 | 85.40 | 86.70 |
+----------+-------+-------+-------+-------+-------+-------+


In [None]:
# Evaluate the loaded checkpoint on the SanFranciscoXS test split.
batch_size = 40
val_dataset_name = 'sfxs_test'

(val_dataset,
 num_references,
 num_queries,
 ground_truth) = get_val_dataset(val_dataset_name)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, val_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first num_references rows are used as reference (database) descriptors,
# everything after them as query descriptors.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()

recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=val_dataset_name,
)

Calculating descritptors...:   0%|          | 0/705 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                 Performance on sfxs_test                 |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 27.50 | 43.10 | 49.10 | 53.30 | 55.40 | 57.10 |
+----------+-------+-------+-------+-------+-------+-------+


## Creating npy files from TokyoXS dataset

In [None]:
import os
import numpy as np
import math

# destination path
DEST_PATH = "/content/drive/MyDrive/Project/code/datasets/TokyoXS"
os.makedirs(DEST_PATH, exist_ok=True)  # np.save does not create missing folders

# Local extraction of the TokyoXS test split. Both image folders are read from
# here; the original listed the queries from a copy on Drive while reading the
# database locally. Presumably the two copies are identical -- TODO confirm.
TOKYO_TEST_DIR = "/content/datasets/tokyo_xs/test"

# creating tokyoxs_dbImages.npy
# os.listdir returns entries in arbitrary order; sorting makes the saved index
# files reproducible from one run to the next.
db_imagelist = sorted(os.listdir(f"{TOKYO_TEST_DIR}/database"))
db_imagelist = np.array(["database/"+image for image in db_imagelist])
np.save(DEST_PATH+f"/tokyoxs_dbImages.npy", db_imagelist)
print(f"tokyoxs_dbImages.npy created with {np.size(db_imagelist)} elements")

# creating tokyoxs_qImages.npy
q_imagelist = sorted(os.listdir(f"{TOKYO_TEST_DIR}/queries"))
q_imagelist = np.array(["queries/"+image for image in q_imagelist])
np.save(DEST_PATH+f"/tokyoxs_qImages.npy", q_imagelist)
print(f"tokyoxs_qImages.npy created with {np.size(q_imagelist)} elements")

# creating tokyoxs_gt.npy
def imageDist(image1, image2):
  """Euclidean distance between the coordinates encoded in two image names.

  Each name is split on '@'; fields 1 and 2 are parsed as floats and used as
  the two coordinates of the image.
  """
  fields_1 = image1.split('@')
  first_c1, first_c2 = float(fields_1[1]), float(fields_1[2])
  fields_2 = image2.split('@')
  second_c1, second_c2 = float(fields_2[1]), float(fields_2[2])

  delta_c1 = first_c1 - second_c1
  delta_c2 = first_c2 - second_c2
  return math.sqrt(delta_c1**2 + delta_c2**2)

# A database image counts as a positive for a query when the distance between
# the coordinates embedded in their file names is at most THRESHOLD
# (presumably metres, if the coordinates are UTM -- TODO confirm).
THRESHOLD = 25


def _parse_coords(image_names):
  """Return an (N, 2) float array holding fields 1 and 2 of every '@'-split name."""
  coords = []
  for name in image_names:
    fields = name.split('@')
    coords.append((float(fields[1]), float(fields[2])))
  # reshape keeps the result two-dimensional even when image_names is empty
  return np.array(coords, dtype=float).reshape(-1, 2)


# Coordinates are parsed once per image, then each query is compared with the
# whole database in a single vectorised step (Euclidean distance on the two
# parsed coordinates) instead of re-parsing both names for every pair.
db_coords = _parse_coords(db_imagelist)
q_coords = _parse_coords(q_imagelist)

# Filled element by element so the result is always a 1-D object array with one
# list per query; np.array(list_of_lists, dtype=object) would turn into a 2-D
# array whenever every query happens to have the same number of hits.
gt_list = np.empty(len(q_imagelist), dtype=object)
for q_idx, (q_c1, q_c2) in enumerate(q_coords):
  dist = np.sqrt((q_c1 - db_coords[:, 0])**2 + (q_c2 - db_coords[:, 1])**2)
  # indices into db_imagelist of all database images within THRESHOLD
  gt_list[q_idx] = np.flatnonzero(dist <= THRESHOLD).tolist()
np.save(DEST_PATH+f"/tokyoxs_gt.npy", gt_list)
print(f"tokyoxs_gt.npy created with {np.size(gt_list)} elements")

tokyoxs_dbImages.npy created with 12771 elements
tokyoxs_qImages.npy created with 315 elements
tokyoxs_gt.npy created with 315 elements


## Doing inference on TokyoXS dataset

In [None]:
# Evaluate the loaded checkpoint on the TokyoXS dataset.
batch_size = 40
test_dataset_name = 'tokyoxs'

(test_dataset,
 num_references,
 num_queries,
 ground_truth) = get_val_dataset(test_dataset_name)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4)

descriptors = get_descriptors(model, test_loader, device)
print(f'Descriptor dimension {descriptors.shape[1]}')

# The first num_references rows are used as reference (database) descriptors,
# everything after them as query descriptors.
r_list, q_list = descriptors[:num_references].cpu(), descriptors[num_references:].cpu()

recalls_dict, preds = get_validation_recalls(
  r_list=r_list,
  q_list=q_list,
  k_values=[1, 5, 10, 15, 20, 25],
  gt=ground_truth,
  print_results=True,
  dataset_name=test_dataset_name,
)

Calculating descritptors...:   0%|          | 0/328 [00:00<?, ?it/s]

Descriptor dimension 1024


+----------------------------------------------------------+
|                  Performance on tokyoxs                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 37.78 | 60.32 | 71.75 | 77.78 | 79.68 | 81.27 |
+----------+-------+-------+-------+-------+-------+-------+


Model hyperparameters, loss, optimizer and miner used for the training run above:
model = VPRModel(
        backbone_arch='resnet18',
        pretrained=True,
        layers_to_freeze=2,
        layers_to_crop=[4],

        agg_arch='MixVPR',
        agg_config={'in_channels' : 256,
                    'in_h' : 14,
                    'in_w' : 14,
                    'out_channels' : 256,
                    'mix_depth' : 5,
                    'mlp_ratio' : 1,
                    'out_rows' : 4},

        lr=0.0002, # 0.03 for sgd
        optimizer='adam', # sgd, adam or adamw
        weight_decay=0, # 0.001 for sgd or 0.0 for adam
        momentum=0.9,
        warmpup_steps=600,
        milestones=[5, 10, 15, 25],
        lr_mult=0.3,
        
        #---------------------------------
        #---- Training loss function -----
        # see utils.losses.py for more losses
        # example: ContrastiveLoss, TripletMarginLoss, MultiSimilarityLoss,
        # FastAPLoss, CircleLoss, SupConLoss,
        #
        loss_name='TripletMarginLoss',
        miner_name='TripletMarginMiner', # example: TripletMarginMiner, MultiSimilarityMiner, PairMarginMiner
        miner_margin=0.1,
        faiss_gpu=False
        )             