This is the latest version of the codebase that allows to fully test all possibile combination that have been created.

In [1]:
#Dependencies
#This cell includes all dependencies we need to import
#To avoid to reimport also pytorch and to speed up execution of the import, the latest version has been imported.
!pip install pytorch_lightning
!pip install pytorch_metric_learning
!pip install faiss-cpu
!pip install faiss-gpu
!pip install lightning_lite
!pip install optuna


Collecting pytorch_lightning
  Downloading pytorch_lightning-2.3.3-py3-none-any.whl (812 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m812.3/812.3 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.4.0.post0-py3-none-any.whl (868 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.11.3.post0-py3-none-any.whl (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.0.0->pytorch_lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.0.0->pytorch_lightning)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.0

In [2]:
#Google Drive mount
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
#Packages import
from pathlib import Path
from PIL import Image
from prettytable import PrettyTable
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_metric_learning import losses, miners
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as T
from lightning_lite.utilities.seed import seed_everything
import csv
import faiss
import faiss.contrib.torch_utils
import numpy as np
import os
import pandas as pd
import pytorch_lightning as pl
import re
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import math
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity
import torch.nn.functional
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
import optuna
from optuna.trial import TrialState
from pytorch_metric_learning import losses
import joblib

In [4]:
#This function is used to validate the result of training
def get_validation_recalls(r_list, q_list, k_values, gt, print_results=True, faiss_gpu=False, dataset_name='dataset without name ?'):

        embed_size = r_list.shape[1]
        print(embed_size)
        if faiss_gpu:
            res = faiss.StandardGpuResources()
            flat_config = faiss.GpuIndexFlatConfig()
            flat_config.useFloat16 = True
            flat_config.device = 0
            faiss_index = faiss.GpuIndexFlatL2(res, embed_size, flat_config)
        # build index
        else:
            faiss_index = faiss.IndexFlatL2(embed_size)

        # add references
        faiss_index.add(r_list)

        # search for queries in the index
        _, predictions = faiss_index.search(q_list, max(k_values))



        # start calculating recall_at_k
        correct_at_k = np.zeros(len(k_values))
        for q_idx, pred in enumerate(predictions):
            for i, n in enumerate(k_values):
                # if in top N then also in top NN, where NN > N
                if np.any(np.in1d(pred[:n], gt[q_idx])):
                    correct_at_k[i:] += 1
                    break

        correct_at_k = correct_at_k / len(predictions)
        d = {k:v for (k,v) in zip(k_values, correct_at_k)}

        #if print_results:
        print('\n') # print a new line
        table = PrettyTable()
        table.field_names = ['K']+[str(k) for k in k_values]
        table.add_row(['Recall@K']+ [f'{100*v:.2f}' for v in correct_at_k])
        print(table.get_string(title=f"Performance on {dataset_name}"))

        return d, predictions





In [5]:
#Training dataset

# Paths
dataset_path = '/content/drive/MyDrive/Datasets/gsv_xs/'

cities_list = []
df = pd.DataFrame()
cities_list = os.listdir(os.path.join(dataset_path, 'train'))

default_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class GSVCitiesDataset(Dataset):
    def __init__(self,
                 cities,
                 img_per_place=4,
                 min_img_per_place=4,
                 random_sample_from_each_place=True,
                 base_path=dataset_path,
                 transform = default_transform):
        super(GSVCitiesDataset, self).__init__()
        self.base_path = base_path
        self.cities = cities_list
        self.transform = default_transform

        assert img_per_place <= min_img_per_place, \
            f"img_per_place should be less than {min_img_per_place}"
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.random_sample_from_each_place = random_sample_from_each_place

        # Generate the dataframe containing images metadata
        self.dataframe = self._getdataframes()

        # Self labels
        self.labels = self.dataframe['place_id_orig'].unique()

        # Get all unique place ids
        self.places_ids = pd.unique(self.dataframe.index)
        self.total_nb_images = len(self.dataframe)


    @staticmethod
    def get_img_name(row):
        easting = str(row['easting'])
        northing = str(row['northing'])
        zone = str(row['zone'])
        grid_zone = str(row['grid_zone'])
        place_id = str(row['place_id_orig']) #MICHELE
        pano_id = row['pano_id']
        year = str(row['year']).zfill(4)
        month = str(row['month']).zfill(2)
        north_degree = str(row['north_degree']).zfill(3)
        lat, lon = str(row['latitude']), str(row['longitude'])
        name = '@' + easting + '@' + northing + '@' + zone + '@' + grid_zone + '@' + lat + '@' + lon + '@' + pano_id + '@@' + \
            north_degree + '@@@@' + year + month + '@' + place_id + '@' + '.jpg'
        return name

    @staticmethod
    def image_loader(path):
        return Image.open(path).convert('RGB')

    def __len__(self):
        '''Denotes the total number of places (not images)'''
        return len(self.places_ids)

    def __getitem__(self, index):
        place_id = self.places_ids[index]

        # Get the place in form of a dataframe (each row corresponds to one image)
        place = self.dataframe.loc[place_id]

        # Sample K images (rows) from this place
        # We can either sort and take the most recent k images
        # or randomly sample them
        if self.random_sample_from_each_place:
            place = place.sample(n=self.img_per_place)
        else:  # Always get the same most recent images
            place = place.sort_values(
                by=['year', 'month', 'latitude'], ascending=False)
            place = place[:self.img_per_place]

        imgs = []
        for i, row in place.iterrows():
            img_name = self.get_img_name(row)
            img_path = os.path.join(self.base_path, 'train', row['city_id'], img_name)
            img = self.image_loader(img_path)

            if self.transform is not None:
                    img = self.transform(img)

            imgs.append(img)

        # NOTE: contrary to image classification where __getitem__ returns only one image
        # in GSVCities, we return a place, which is a tensor of K images (K=self.img_per_place)
        # This will return a tensor of shape [K, channels, height, width]. This needs to be taken into account
        # in the DataLoader (which will yield batches of shape [BS, K, channels, height, width])
        return torch.stack(imgs), torch.tensor(place_id).repeat(self.img_per_place)

    def _getdataframes(self):
        column_names = ['easting', 'northing', 'zone', 'grid_zone', 'latitude', 'longitude', 'pano_id', 'north_degree', 'year', 'month','city_id','place_id_orig']

        list_img_metadata = []

        """
        custom_cities = [
        'phoenix',
        'prg'
        ]
        """

        for city in cities_list:
        #for city in custom_cities:
            city_path = os.path.join(self.base_path, 'train', city)
            for filename in os.listdir(city_path):
                # Process the file using your process_file function
                img_metadata = self._process_file_name(filename)
                list_img_metadata.append(img_metadata)
        df = pd.DataFrame(list_img_metadata, columns=column_names)
        res = df[df.groupby('place_id_orig')['place_id_orig'].transform(
            'size') >= self.min_img_per_place].copy()

        res['progressive_number'] = res.groupby('place_id_orig').ngroup() + 1
        self.labels_used = res['progressive_number'].max()


        return res.set_index('progressive_number')

    def _process_file_name(self, file_name):
        # Split the file content using '@' as delimiter
        data_tokens = file_name.split('@')

        # Extract latitude and longitude
        easting = data_tokens[1]
        northing = data_tokens[2]
        zone = data_tokens[3]
        grid_zone = data_tokens[4]
        latitude = data_tokens[5]
        longitude = data_tokens[6]

        # Extract pano ID
        pano_id = data_tokens[7]

        # Extract north degree
        north_degree = data_tokens[9]

        # Extract year and month
        year_month = data_tokens[13]
        year = year_month[:4]
        month = year_month[4:]

        # Extract place ID and city ID
        place_id = data_tokens[14]
        city_id = data_tokens[14].split('_')[1].lower()

        # Structure the extracted data as a dictionary
        img_metadata = {
            'easting': easting,
            'northing': northing,
            'zone': zone,
            'grid_zone': grid_zone,
            'latitude': latitude,
            'longitude': longitude,
            'pano_id': pano_id,
            'north_degree': north_degree,
            'year': year,
            'month': month,
            'city_id':city_id,
            'place_id_orig': place_id,

        }
        return img_metadata





In [6]:
#Training dataloader
IMAGENET_MEAN_STD = {'mean': [0.485, 0.456, 0.406],
                     'std': [0.229, 0.224, 0.225]}

VIT_MEAN_STD = {'mean': [0.5, 0.5, 0.5],
                'std': [0.5, 0.5, 0.5]}



TRAIN_CITIES = [
    'phoenix',
    'prg', # refers to Prague
    'prs',
    'rome',
    'trt', # refers to Toronto
    'washingtondc',
    'bangkok',
    'barcelona',
    'boston',
    'brussels',
    'buenosaires',
    'chicago',
    'lisbon',
    'london',
    'losangeles',
    'madrid',
    'medellin',
    'melbourne',
    'mexicocity',
    'miami',
    'minneapolis',
    'osaka',
    'osl' # refers to Oslo
]

"""
TRAIN_CITIES = [
    'phoenix',
    'prg'
]

"""


class GSVCitiesDataModule(pl.LightningDataModule):
    def __init__(self,
                 batch_size=128,
                 img_per_place=4,
                 min_img_per_place=4,
                 shuffle_all=True,
                 image_size=(224, 224),
                 num_workers=8,
                 show_data_stats=True,
                 cities=TRAIN_CITIES,
                 mean_std=IMAGENET_MEAN_STD,
                 batch_sampler=None,
                 random_sample_from_each_place=True,
                 val_set_names = ['sfxsval']
                 ):
        super().__init__()
        self.batch_size = batch_size
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.shuffle_all = shuffle_all
        self.image_size = image_size
        self.num_workers = num_workers
        self.batch_sampler = batch_sampler
        self.show_data_stats = show_data_stats
        self.cities = cities
        self.mean_dataset = mean_std['mean']
        self.std_dataset = mean_std['std']
        self.random_sample_from_each_place = random_sample_from_each_place
        self.save_hyperparameters() # save hyperparameter with Pytorch Lightening
        self.val_set_names=val_set_names

        self.train_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.RandAugment(num_ops=5, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset),
        ])

        self.valid_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset)])

        self.train_loader_config = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
            'drop_last': False,
            'pin_memory': False,
            'shuffle': self.shuffle_all,
            'prefetch_factor': 2,  # Numero di batch pre-caricati per worker
            'persistent_workers': True  # Mantiene i worker attivi tra i batch
        }

        self.valid_loader_config = {
            'batch_size': 256,
            'num_workers': self.num_workers,
            'drop_last': False,
            'pin_memory': True,
            'shuffle': False}



    def setup(self, stage):
        if stage == 'fit':
            # load train dataloader with reload routine
            self.reload()

            self.val_datasets = []
            for valid_set_name in self.val_set_names:
                if 'sfxsval' in valid_set_name.lower():
                    self.val_datasets.append(SFXSValDataset(input_transform=self.valid_transform))
                else:
                    print(f'Validation set {valid_set_name} does not exist or has not been implemented yet')
                    raise NotImplementedError

        if self.show_data_stats:
          self.print_stats()

    def reload(self):
        self.train_dataset = GSVCitiesDataset(
            cities=self.cities,
            img_per_place=self.img_per_place,
            min_img_per_place=self.min_img_per_place,
            random_sample_from_each_place=self.random_sample_from_each_place,
            )

    def train_dataloader(self):
        self.reload()
        return DataLoader(dataset=self.train_dataset, **self.train_loader_config)


    def val_dataloader(self):
      val_dataloaders = []
      for val_dataset in self.val_datasets:
          val_dataloaders.append(DataLoader(
              dataset=val_dataset, **self.valid_loader_config))
      return val_dataloaders


    def print_stats(self):
          print()  # print a new line
          table = PrettyTable()
          table.field_names = ['Data', 'Value']
          table.align['Data'] = "l"
          table.align['Value'] = "l"
          table.header = False
          table.add_row(["# of cities", f"{len(TRAIN_CITIES)}"])
          table.add_row(["# of places", f'{self.train_dataset.__len__()}'])
          table.add_row(["# of images", f'{self.train_dataset.total_nb_images}'])
          print(table.get_string(title="Training Dataset"))
          print()

          table = PrettyTable()
          table.field_names = ['Data', 'Value']
          table.align['Data'] = "l"
          table.align['Value'] = "l"
          table.header = False

          table = PrettyTable()
          table.field_names = ['Data', 'Value']
          table.align['Data'] = "l"
          table.align['Value'] = "l"
          table.header = False
          table.add_row(
              ["Batch size (PxK)", f"{self.batch_size}x{self.img_per_place}"])
          table.add_row(
              ["# of iterations", f"{self.train_dataset.__len__()//self.batch_size}"])
          table.add_row(["Image size", f"{self.image_size}"])
          print(table.get_string(title="Training config"))

In [7]:
#Validation dataset

default_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
DATASET_ROOT = '/content/drive/MyDrive/Datasets/sf_xs/val/'
GT_ROOT = '/content/drive/MyDrive/Datasets/sf_xs/val/' # BECAREFUL, this is the ground truth that comes with GSV-Cities

path_obj = Path(DATASET_ROOT)
if not path_obj.exists():
    raise Exception(f'Please make sure the path {DATASET_ROOT} to SanFrancisco dataset is correct')

if not path_obj.joinpath('ref') or not path_obj.joinpath('query'):
    raise Exception(f'Please make sure the directories query and ref are situated in the directory {DATASET_ROOT}')

class SFXSValDataset(Dataset):
    def __init__(self, input_transform = default_transform):


        self.input_transform = input_transform

        # reference images names
        self.dbImages = np.load(GT_ROOT+'val_dbImages.npy')
        #print(len(self.dbImages))

        print('path gt root:  '+GT_ROOT+' val_dbImages.npy')

        # query images names
        self.qImages = np.load(GT_ROOT+'val_qImages.npy')
        #print(len(self.qImages))

        # ground truth
        self.ground_truth = np.load(GT_ROOT+'val_gt.npy', allow_pickle=True)

        # reference images then query images
        self.images = np.concatenate((self.dbImages, self.qImages))

        self.num_references = len(self.dbImages)



        print('num queries')
        self.num_queries = len(self.qImages)
        print(self.num_queries)


    def __getitem__(self, index):
        img = Image.open(DATASET_ROOT+self.images[index])

        if self.input_transform:
            img = self.input_transform(img)

        return img, index

    def __len__(self):
        return len(self.images)







In [8]:
class ResNet(nn.Module):
    def __init__(self,layers_to_freeze=2
                 ):

        super().__init__()
        weights = 'IMAGENET1K_V1'
        self.model = torchvision.models.resnet18(weights=weights)
        #Avg Pooling is removed to let feature mixer aggregator to work
        self.model.avgpool = None
        #Final FC is not needed
        self.model.fc = None
        #4-th layer is cropped
        self.model.layer4 = None
        #out channels


        if layers_to_freeze >= 1:
            self.model.layer1.requires_grad_(False)
        if layers_to_freeze >= 2:
            self.model.layer2.requires_grad_(False)
        if layers_to_freeze >= 3:
            self.model.layer3.requires_grad_(False)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        return x




In [9]:

class FeatureMixerLayer(nn.Module):
    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, int(in_dim * mlp_ratio)),
            nn.ReLU(),
            nn.Linear(int(in_dim * mlp_ratio), in_dim),
        )

        for m in self.modules():
            if isinstance(m, (nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        return x + self.mix(x)


class MixVPR(nn.Module):
    def __init__(self,
                 in_channels=256,
                 in_h=14,
                 in_w=14,
                 out_channels=256,
                 mix_depth=5,
                 mlp_ratio=5,
                 out_rows=1,
                 ) -> None:
        super().__init__()

        self.in_h = in_h # height of input feature maps
        self.in_w = in_w # width of input feature maps
        self.in_channels = in_channels # depth of input feature maps

        self.out_channels = out_channels # depth wise projection dimension
        self.out_rows = out_rows # row wise projection dimesion

        self.mix_depth = mix_depth # L the number of stacked FeatureMixers
        self.mlp_ratio = mlp_ratio # ratio of the mid projection layer in the mixer block

        hw = in_h*in_w
        self.mix = nn.Sequential(*[
            FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio)
            for _ in range(self.mix_depth)
        ])
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(hw, out_rows)

    def forward(self, x):
        x = x.flatten(2)
        x = self.mix(x)
        x = x.permute(0, 2, 1)
        x = self.channel_proj(x)
        x = x.permute(0, 2, 1)
        x = self.row_proj(x)
        x = torch.nn.functional.normalize(x.flatten(1), p=2, dim=-1)
        return x




In [10]:
class GeMPool(nn.Module):
    """Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch
    we add flatten and norm so that we can use it as one aggregation layer.
    """
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        x = F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
        x = x.flatten(1)
        return F.normalize(x, p=2, dim=1)

In [11]:
class AVGPool(nn.Module):
    """Implementation of Average Pooling layer ."""
    def __init__(self):
        super(AVGPool, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # Pooling to a fixed output size (1, 1)

    def forward(self, x):
        x = self.pool(x)  # Perform average pooling
        x = x.flatten(1)  # Flatten the pooled output
        return x.float()


In [12]:
class VPRModel(pl.LightningModule):

    def __init__(self
                     ,optimizer='adamw'
                     ,alternative = 1
                     ,reducer=None
                     ,lroptimizer ='MultiStepLR'
                     ,freeze_backbone = 'false'
                     ,aggregator = 'mixvpr'
                     ,mixVprDepth = 5
                     ,mlp_ratio = 5
                     ,out_rows = 4
                     ,lr_profile='medium'
                     ):
        super().__init__()
        # Hyperparameters
        self.optimizer=optimizer
        self.lroptimizer = lroptimizer
        self.alternative = alternative
        self.freeze_backbone = freeze_backbone
        self.aggregator = aggregator
        self.mixVprDepth = mixVprDepth
        self.mlp_ratio = mlp_ratio
        self.out_rows = out_rows
        self.lr_profile = lr_profile


        if self.freeze_backbone == 'true':
          self.backbone = ResNet(layers_to_freeze = 0)
        else:
          self.backbone = ResNet(layers_to_freeze = 2)

        if self.aggregator == 'mixvpr':
          self.aggregator = MixVPR(mix_depth = self.mixVprDepth, mlp_ratio=self.mlp_ratio, out_rows=self.out_rows)
        elif self.aggregator == 'avgPool':
          self.aggregator = AVGPool()
        elif self.aggregator == 'GeMPool':
          self.aggregator = GeMPool()



        if alternative == 1:

          #Alternative 1:
          #Advantages: This combination is commonly used and effective in distinguishing between similar and dissimilar pairs, making it robust for various VPR tasks.
          self.loss_fn = losses.TripletMarginLoss(margin=0.1, swap=False, smooth_loss=False, triplets_per_anchor='all')
          self.miner = miners.TripletMarginMiner(margin=0.2, type_of_triplets="semihard")

        elif alternative ==2:
          #Alternative 2:
          #Advantages: NTXentLoss is particularly effective for self-supervised learning setups and can handle large batches, which may enhance the learning of robust embeddings.
          self.loss_fn = losses.NTXentLoss(temperature=0.07)
          self.miner = miners.MultiSimilarityMiner(epsilon=0.1)

        elif alternative == 3:
          #Alternative 3:
          #Advantages: CircleLoss explicitly optimizes the decision boundary, and the DistanceWeightedMiner ensures that hard samples are mined effectively without an overabundance of easy negatives.
          self.loss_fn = losses.CircleLoss(m=0.4,gamma=80)
          self.miner = miners.DistanceWeightedMiner(cutoff=0.5, nonzero_loss_cutoff=1.4)

        elif alternative == 4:
          #Alternative 4:
          self.loss_fn = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
          self.miner = miners.PairMarginMiner(pos_margin=0.2, neg_margin=0.8)

        elif alternative == 5:
          #Alternative 5:
          self.loss_fn = losses.MultiSimilarityLoss(alpha=1.0,beta=50,base=0.0,distance=DotProductSimilarity())
          self.miner = miners.MultiSimilarityMiner(0.1, distance=CosineSimilarity())

        elif alternative == 6:
          #Alternative 6:
          self.loss_fn = losses.CosFaceLoss(62515, 256, margin=0.35, scale=64).to(torch.device('cuda'))
          self.miner = miners.MultiSimilarityMiner(0.1, distance=CosineSimilarity())

        elif alternative == 7:
          #Alternative 7:
          self.loss_fn = losses.ArcFaceLoss(62515, 256, margin=28.6, scale=64).to(torch.device('cuda'))
          self.miner = miners.MultiSimilarityMiner(0.1, distance=CosineSimilarity())

        # Metrics (initialize empty lists)
        self.faiss_gpu = False
        self.validation_step_outputs = []

        # Save hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.backbone(x)
        x = self.aggregator(x)
        return x

    def configure_optimizers(self):

      if self.alternative in (6,7):

        if self.optimizer.lower() == 'sgd':
          optimizer = torch.optim.SGD(self.loss_fn.parameters(),lr=0.03,weight_decay=0.001,momentum=0.9 )
        elif self.optimizer.lower() == 'adam':
          optimizer = torch.optim.Adam(self.loss_fn.parameters(),lr=0.0002)
        elif self.optimizer.lower() == 'adamw':
          optimizer = torch.optim.AdamW(self.loss_fn.parameters(),lr=0.0002,weight_decay=0.001)
        elif self.optimizer.lower() == 'asgd':
          optimizer = torch.optim.ASGD(self.loss_fn.parameters(),lr=0.1)

      else:
        if self.optimizer.lower() == 'sgd':
          if self.lr_profile == 'medium':
            optimizer = torch.optim.SGD(self.parameters(),lr=0.03,weight_decay=0.001,momentum=0.9 )
          elif self.lr_profile == 'high':
            optimizer = torch.optim.SGD(self.parameters(),lr=0.1,weight_decay=0.01,momentum=0.9 )
          elif self.lr_profile == 'low':
            optimizer = torch.optim.SGD(self.parameters(),lr=0.005,weight_decay=0.0001,momentum=0.9 )

        elif self.optimizer.lower() == 'adam':
           if self.lr_profile == 'medium':
            optimizer = torch.optim.Adam(self.parameters(),lr=0.0002)
           elif self.lr_profile == 'high':
            optimizer = torch.optim.Adam(self.parameters(),lr=0.05)
           elif self.lr_profile == 'low':
            optimizer = torch.optim.Adam(self.parameters(),lr=0.00005)

        elif self.optimizer.lower() == 'adamw':
           if self.lr_profile == 'medium':
            optimizer = torch.optim.AdamW(self.parameters(),lr=0.0002,weight_decay=0.0001)
           elif self.lr_profile == 'high':
            optimizer = torch.optim.AdamW(self.parameters(),lr=0.005,weight_decay=0.001)
           elif self.lr_profile == 'low':
            optimizer = torch.optim.AdamW(self.parameters(),lr=0.00005,weight_decay=0.00001)

        elif self.optimizer.lower() == 'asgd':
           if self.lr_profile == 'medium':
            optimizer = torch.optim.ASGD(self.parameters(),lr=0.05)
           elif self.lr_profile == 'high':
            optimizer = torch.optim.ASGD(self.parameters(),lr=0.5)
           elif self.lr_profile == 'low':
            optimizer = torch.optim.ASGD(self.parameters(),lr=0.005)

      if self.lroptimizer == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[3,6,9,12,15,18,21,24,27], gamma=0.5)
      elif self.lroptimizer == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1)
      elif self.lroptimizer == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
      elif self.lroptimizer == 'OneCycle':
        scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=optimizer.param_groups[0]['lr'], total_steps=105)


      return {
          'optimizer': optimizer,
          'lr_scheduler': scheduler, # Changed scheduler to lr_scheduler
          'monitor': 'sfxsR1'
       }

    def loss_function(self, descriptors, labels):
        # Mining (if applicable)
        if self.miner is not None:
            miner_outputs = self.miner(descriptors, labels)
            loss = self.loss_fn(descriptors, labels, miner_outputs)
        else:
            loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        places, labels = batch
        BS, N, ch, h, w = places.shape
        # Reshape and forward pass
        images = places.view(BS * N, ch, h, w)
        labels = labels.view(-1)
        descriptors = self(images)
        # Calculate loss
        loss = self.loss_function(descriptors, labels)


        # Log loss and return dictionary
        self.log('train_loss', round(loss.item(),2), prog_bar=True, logger=True)
        self.log('loss', round(loss.item(),2), logger=True)
        return {'loss': loss}

    def on_training_epoch_end(self, training_step_outputs):
        # we empty the batch_acc list for next epoch
          self.batch_acc = []


    # For validation, we will also iterate step by step over the validation set
    # this is the way Pytorch Lghtning is made. All about modularity, folks.
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        places, _ = batch
        # calculate descriptors
        descriptors = self(places)
        self.validation_step_outputs.append(descriptors.detach().cpu())
        return descriptors.detach().cpu()

    def on_validation_epoch_end(self):
        dm = self.trainer.datamodule
        # The following line is a hack: if we have only one validation set, then
        # we need to put the outputs in a list (Pytorch Lightning does not do it presently)
        if len(dm.val_datasets)==1: # we need to put the outputs in a list
            val_step_outputs = [self.validation_step_outputs]

        for i, (val_set_name, val_dataset) in enumerate(zip(dm.val_set_names, dm.val_datasets)):
            feats = torch.concat(val_step_outputs[i], dim=0)

            num_references = val_dataset.num_references
            num_queries = val_dataset.num_queries
            ground_truth = val_dataset.ground_truth

            # split to ref and queries
            r_list = feats[ : num_references]
            q_list = feats[num_references : ]

            recalls_dict, predictions = get_validation_recalls(r_list=r_list,
                                                q_list=q_list,
                                                k_values=[1, 5, 10, 15, 20, 25],
                                                gt=ground_truth,
                                                print_results=True,
                                                dataset_name=val_set_name,
                                                faiss_gpu=self.faiss_gpu
                                                )
            del r_list, q_list, feats, num_references, ground_truth

            self.log(f'sfxsR1', recalls_dict[1], prog_bar=True, logger=True)
            self.log(f'sfxsR5', recalls_dict[5], prog_bar=True, logger=True)
            self.log(f'sfxsR10', recalls_dict[10], prog_bar=True, logger=True)
            del self.validation_step_outputs
            self.validation_step_outputs = []
        print('\n\n')


In [19]:

def objective(trial):
    # Hyperparameters to tune
    optimizer = trial.suggest_categorical('optimizer', ['adam'])
    batch_size = trial.suggest_categorical('batch_size', [64])
    lroptimizer = trial.suggest_categorical('lroptimizer', ['MultiStepLR'])
    alternative = trial.suggest_categorical('alternative', [5])
    aggregator = trial.suggest_categorical('aggregator', ['mixvpr'])
    freeze_backbone = trial.suggest_categorical('freeze_backbone', ['false'])
    lr_profile = trial.suggest_categorical('lr_profile', ['medium'])

     # Log the hyperparameters being tested
    print('-----------------')
    print(f'Trial {trial.number}: optimizer={optimizer}, '
          f'batch_size={batch_size}, lroptimizer={lroptimizer}, '
          f'aggregator={aggregator}, freeze_backbone={freeze_backbone}, '
          f'alternative={alternative},'
          f'lr_profile={lr_profile},freeze_backbone={freeze_backbone},'
          'mixVprDepth=2, mlp_ratio=8, out_rows=7')


    print('-----------------')


    if alternative == 1:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_triple_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==2:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_ntx_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==3:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_circle_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==4:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_contrastive_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==5:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_multisim_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==6:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_cosface_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'
    elif alternative ==7:
      fileNaming ='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/ablation_arcface_mixvpr_{epoch}_{loss}_{sfxsR1}_{sfxsR5}'



    checkpoint_cb = ModelCheckpoint(
      monitor='sfxsR1',
      filename=fileNaming,
      auto_insert_metric_name=True,
      save_weights_only=False,
      save_top_k=-1,

    )


    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/',
        num_sanity_val_steps=0,
        precision='16-mixed',
        max_epochs=30,
        callbacks=[checkpoint_cb],  # Add both callbacks here
    )




    datamodule = GSVCitiesDataModule(batch_size)
    model = VPRModel(optimizer = optimizer, lroptimizer = lroptimizer, alternative = alternative, aggregator=aggregator, freeze_backbone= freeze_backbone,mixVprDepth=2, mlp_ratio=8, out_rows=7,lr_profile=lr_profile)


    trainer.fit(model=model, datamodule=datamodule)

    best_score = trainer.callback_metrics.get('sfxsR1')



    return best_score.item()  # Convert the tensor to float

if __name__ == '__main__':

    seed_everything(seed=1, workers=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=1)

    print('Best trial:')
    trial = study.best_trial
    print('  Value: {}'.format(trial.value))
    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))

    study = optuna.create_study()
    joblib.dump(study, "/content/drive/MyDrive/Colab Notebooks/Latest Version/15 - FINAL TUNING/study.pkl")

INFO:lightning_lite.utilities.seed:Global seed set to 1
[I 2024-07-13 14:26:54,177] A new study created in memory with name: no-name-738ffc55-e874-4f3b-b558-639839d1477f
INFO:lightning_lite.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning_lite.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:lightning_lite.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:lightning_lite.utilities.rank_zero:HPU available: False, using: 0 HPUs


-----------------
Trial 0: optimizer=adam, batch_size=64, lroptimizer=MultiStepLR, aggregator=mixvpr, freeze_backbone=false, alternative=5,lr_profile=medium,freeze_backbone=false,mixVprDepth=2, mlp_ratio=8, out_rows=7
-----------------
path gt root:  /content/drive/MyDrive/Datasets/sf_xs/val/ val_dbImages.npy


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


num queries
7993

+----------------------+
|   Training Dataset   |
+-------------+--------+
| # of cities | 23     |
| # of places | 62514  |
| # of images | 524858 |
+-------------+--------+

+-------------------------------+
|        Training config        |
+------------------+------------+
| Batch size (PxK) | 64x4       |
| # of iterations  | 976        |
| Image size       | (224, 224) |
+------------------+------------+


INFO:pytorch_lightning.callbacks.model_summary:
  | Name       | Type                 | Params | Mode 
------------------------------------------------------------
0 | backbone   | ResNet               | 2.8 M  | train
1 | aggregator | MixVPR               | 1.3 M  | train
2 | loss_fn    | MultiSimilarityLoss  | 0      | train
3 | miner      | MultiSimilarityMiner | 0      | train
------------------------------------------------------------
3.4 M     Trainable params
673 K     Non-trainable params
4.1 M     Total params
16.334    Total estimated model params size (MB)
  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

1792


  x.storage().data_ptr() + x.storage_offset() * 4)




+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 72.44 | 81.00 | 83.90 | 85.73 | 87.00 | 88.10 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.19 | 83.52 | 86.41 | 88.25 | 89.29 | 90.28 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 75.99 | 83.99 | 86.98 | 88.79 | 90.07 | 90.82 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 77.29 | 85.59 | 88.20 | 89.99 | 91.29 | 91.98 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 77.81 | 85.61 | 88.33 | 90.22 | 91.19 | 92.12 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 78.42 | 85.99 | 88.90 | 90.52 | 91.43 | 92.32 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.08 | 86.76 | 89.72 | 91.34 | 92.16 | 93.02 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.06 | 86.70 | 89.75 | 91.34 | 92.14 | 92.82 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 78.94 | 86.56 | 89.59 | 91.13 | 92.08 | 92.79 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.79 | 87.45 | 90.33 | 91.67 | 92.64 | 93.22 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.98 | 87.38 | 90.25 | 91.57 | 92.56 | 93.07 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.74 | 87.55 | 90.27 | 91.59 | 92.51 | 93.12 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 79.89 | 87.48 | 90.20 | 91.53 | 92.48 | 93.02 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.05 | 87.65 | 90.40 | 91.66 | 92.58 | 93.21 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.01 | 87.54 | 90.32 | 91.54 | 92.42 | 93.24 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.06 | 87.53 | 90.17 | 91.52 | 92.49 | 93.16 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.07 | 87.60 | 90.25 | 91.51 | 92.48 | 93.13 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.18 | 87.55 | 90.29 | 91.56 | 92.47 | 93.16 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.25 | 87.68 | 90.40 | 91.56 | 92.51 | 93.18 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.16 | 87.66 | 90.43 | 91.69 | 92.59 | 93.28 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.13 | 87.81 | 90.42 | 91.64 | 92.66 | 93.36 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.12 | 87.65 | 90.47 | 91.67 | 92.54 | 93.31 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.15 | 87.68 | 90.35 | 91.67 | 92.69 | 93.28 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.21 | 87.65 | 90.33 | 91.71 | 92.62 | 93.26 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.12 | 87.65 | 90.29 | 91.64 | 92.61 | 93.31 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.15 | 87.69 | 90.35 | 91.66 | 92.64 | 93.26 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.13 | 87.63 | 90.29 | 91.59 | 92.56 | 93.26 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.15 | 87.70 | 90.40 | 91.72 | 92.57 | 93.28 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.07 | 87.64 | 90.27 | 91.72 | 92.56 | 93.24 |
+----------+-------+-------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

1792


+----------------------------------------------------------+
|                  Performance on sfxsval                  |
+----------+-------+-------+-------+-------+-------+-------+
|    K     |   1   |   5   |   10  |   15  |   20  |   25  |
+----------+-------+-------+-------+-------+-------+-------+
| Recall@K | 80.11 | 87.55 | 90.35 | 91.59 | 92.61 | 93.28 |
+----------+-------+-------+-------+-------+-------+-------+





INFO:lightning_lite.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.
[I 2024-07-14 06:00:54,617] Trial 0 finished with value: 0.8010759353637695 and parameters: {'optimizer': 'adam', 'batch_size': 64, 'lroptimizer': 'MultiStepLR', 'alternative': 5, 'aggregator': 'mixvpr', 'freeze_backbone': 'false', 'lr_profile': 'medium'}. Best is trial 0 with value: 0.8010759353637695.
[I 2024-07-14 06:00:54,619] A new study created in memory with name: no-name-323a7296-b597-4e36-bdf6-334bb2e0c0b0


Best trial:
  Value: 0.8010759353637695
  Params: 
    optimizer: adam
    batch_size: 64
    lroptimizer: MultiStepLR
    alternative: 5
    aggregator: mixvpr
    freeze_backbone: false
    lr_profile: medium
