In [1]:
# Install third-party dependencies into the Colab runtime.
# NOTE(review): versions are not pinned, so a fresh runtime may pull newer,
# API-incompatible releases — consider pinning (e.g. `%pip install pkg==x.y.z`).
!pip install pytorch_lightning
!pip install pytorch_metric_learning
# NOTE(review): faiss-cpu and faiss-gpu are alternative builds of the same `faiss`
# module; installing both is likely redundant (the later install wins). The model
# below builds a CPU index (faiss_gpu=False), so one of them is probably enough — confirm.
!pip install faiss-cpu
!pip install faiss-gpu
# Provides lightning_lite.utilities.seed.seed_everything, imported in the next cells.
!pip install lightning_lite



In [2]:
# Mount Google Drive: datasets and checkpoints used below live under /content/drive/MyDrive.
from google.colab import drive
# Flush and unmount any existing mount first, presumably so that re-running this
# cell remounts cleanly.
drive.flush_and_unmount()
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from pathlib import Path
from PIL import Image
from prettytable import PrettyTable
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_metric_learning import losses, miners
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as T
from lightning_lite.utilities.seed import seed_everything
import csv
import faiss
import faiss.contrib.torch_utils
import numpy as np
import os
import pandas as pd
import pytorch_lightning as pl
import re
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import math
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity, DotProductSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer

In [4]:


##############################################################################################
##############################################################################################

def _recall_at_k(predictions, gt, k_values):
    """Return Recall@K for every cut-off in ``k_values`` as a numpy array.

    Args:
        predictions: ranked reference indices per query (row ``q`` lists the
            neighbours of query ``q``, best first), at least ``max(k_values)`` wide.
            Each row must be convertible with ``np.asarray`` (numpy row, CPU tensor).
        gt: indexable per query; ``gt[q]`` holds the reference indices that count
            as correct matches for query ``q``.
        k_values: cut-offs N, assumed sorted in increasing order.

    Returns:
        Array with one entry per cut-off: the fraction of queries with at least one
        correct reference among their top-N predictions. All zeros when there are
        no queries (instead of the NaNs a 0/0 division would produce).
    """
    correct_at_k = np.zeros(len(k_values))
    num_queries = len(predictions)
    if num_queries == 0:
        return correct_at_k
    for q_idx, pred in enumerate(predictions):
        pred = np.asarray(pred)
        for i, n in enumerate(k_values):
            # if in top N then also in top NN, where NN > N (requires sorted k_values)
            if np.any(np.isin(pred[:n], gt[q_idx])):
                correct_at_k[i:] += 1
                break
    return correct_at_k / num_queries


def get_validation_recalls(r_list, q_list, k_values, gt, print_results=True, faiss_gpu=False, dataset_name='dataset without name ?'):
    """Retrieve each query's nearest references with an exact L2 faiss index and report Recall@K.

    Args:
        r_list: reference descriptors, 2-D [num_references, embed_size].
        q_list: query descriptors, 2-D [num_queries, embed_size].
        k_values: increasing list of cut-offs, e.g. [1, 5, 10].
        gt: ``gt[q]`` holds the indices (into ``r_list``) of the correct references
            for query ``q``.
        print_results: when True, print the descriptor size and a Recall@K table;
            when False nothing is printed.
        faiss_gpu: build the index on GPU 0 with float16 storage instead of on CPU.
        dataset_name: label used in the printed table title.

    Returns:
        (recalls, predictions): ``recalls`` maps each k to the recall in [0, 1];
        ``predictions`` holds the top-``max(k_values)`` reference indices per query.
    """
    embed_size = r_list.shape[1]
    if print_results:
        print('------------')
        print('Embed size: ')
        print(embed_size)
        print('------------')

    # build index
    if faiss_gpu:
        res = faiss.StandardGpuResources()
        flat_config = faiss.GpuIndexFlatConfig()
        flat_config.useFloat16 = True
        flat_config.device = 0
        faiss_index = faiss.GpuIndexFlatL2(res, embed_size, flat_config)
    else:
        faiss_index = faiss.IndexFlatL2(embed_size)

    # add references, then search the top-max(k) references for every query
    faiss_index.add(r_list)
    _, predictions = faiss_index.search(q_list, max(k_values))

    correct_at_k = _recall_at_k(predictions, gt, k_values)
    d = {k: v for (k, v) in zip(k_values, correct_at_k)}

    if print_results:
        print('\n')  # print a new line
        table = PrettyTable()
        table.field_names = ['K'] + [str(k) for k in k_values]
        table.add_row(['Recall@K'] + [f'{100*v:.2f}' for v in correct_at_k])
        print(table.get_string(title=f"Performance on {dataset_name}"))

    return d, predictions





##############################################################################################
#####DATASETS
##############################################################################################


DATASET_ROOT = '/content/drive/MyDrive/Datasets/sf_xs/val/'
# Directory holding the ground-truth .npy files read by SFXSValDataset
# (val_dbImages.npy, val_qImages.npy, val_gt.npy).
# BE CAREFUL: this ground truth follows the format that comes with GSV-Cities.
GT_ROOT = '/content/drive/MyDrive/Datasets/sf_xs/val/'

path_obj = Path(DATASET_ROOT)
if not path_obj.exists():
    raise Exception(f'Please make sure the path {DATASET_ROOT} to SanFrancisco dataset is correct')

# A Path object is always truthy, so a test like `not path_obj.joinpath('query')`
# can never fire. Check on disk for the files the validation dataset actually loads,
# so a wrong GT_ROOT fails here with a clear message instead of inside trainer.fit().
_missing_gt_files = [name for name in ('val_dbImages.npy', 'val_qImages.npy', 'val_gt.npy')
                     if not Path(GT_ROOT).joinpath(name).is_file()]
if _missing_gt_files:
    raise Exception(f'Please make sure the ground-truth files {_missing_gt_files} are situated in the directory {GT_ROOT}')

class SFXSValDataset(Dataset):
    """San Francisco XS validation set: all reference images followed by all queries.

    Image names and ground truth come from three .npy files in ``gt_root``:
    ``val_dbImages.npy`` (reference names), ``val_qImages.npy`` (query names) and
    ``val_gt.npy`` (per-query ground truth). Indices ``[0, num_references)`` are
    references; the remaining ``num_queries`` indices are queries.

    Args:
        input_transform: optional callable applied to each PIL image.
        dataset_root: prefix prepended to every image name; defaults to the
            module-level ``DATASET_ROOT``.
        gt_root: prefix of the .npy files; defaults to the module-level ``GT_ROOT``.
    """

    def __init__(self, input_transform=None, dataset_root=None, gt_root=None):
        self.input_transform = input_transform
        # Fall back to the module constants only when no explicit root is given.
        self.dataset_root = DATASET_ROOT if dataset_root is None else dataset_root
        gt_root = GT_ROOT if gt_root is None else gt_root

        # reference images names
        self.dbImages = np.load(gt_root + 'val_dbImages.npy')

        print('path gt root:  ' + gt_root + ' val_dbImages.npy')

        # query images names
        self.qImages = np.load(gt_root + 'val_qImages.npy')

        # ground truth (object array, hence allow_pickle — only load trusted files)
        self.ground_truth = np.load(gt_root + 'val_gt.npy', allow_pickle=True)

        # reference images then query images
        self.images = np.concatenate((self.dbImages, self.qImages))

        self.num_references = len(self.dbImages)

        print('num queries')
        self.num_queries = len(self.qImages)
        print(self.num_queries)

    def __getitem__(self, index):
        # convert('RGB') guarantees 3 channels (grayscale/RGBA files would otherwise
        # break a 3-channel Normalize); GSVCitiesDataset loads its images the same way.
        img = Image.open(self.dataset_root + self.images[index]).convert('RGB')

        if self.input_transform:
            img = self.input_transform(img)

        return img, index

    def __len__(self):
        return len(self.images)


##############################################################################################


# Paths
# Root of the GSV-XS training set; train/<city>/ holds the images of one city.
dataset_path = '/content/drive/MyDrive/Datasets/gsv_xs/'

# Keep only sub-directories: a stray file in train/ would make GSVCitiesDataset
# crash when it calls os.listdir() on it. Sorting makes the order deterministic,
# since os.listdir() returns entries in arbitrary order.
_train_dir = os.path.join(dataset_path, 'train')
cities_list = sorted(
    name for name in os.listdir(_train_dir)
    if os.path.isdir(os.path.join(_train_dir, name))
)

# Fallback transform for GSVCitiesDataset: tensor conversion + ImageNet normalisation.
default_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class GSVCitiesDataset(Dataset):
    """Place-level training dataset built from '@'-separated image file names.

    Every file under ``<base_path>/train/<city>/`` encodes its metadata in its
    name (see ``_process_file_name``). Images are grouped by original place id;
    places with fewer than ``min_img_per_place`` images are dropped and the rest
    are renumbered 1..N (``progressive_number``), which is the label returned by
    ``__getitem__``.

    NOTE(review): the ``cities`` argument is ignored — every directory in the
    module-level ``cities_list`` is read. This is kept as-is because callers pass
    TRAIN_CITIES, and honouring it could change (or break) the training set if
    those names do not match the directories on disk.
    """

    def __init__(self,
                 cities,
                 img_per_place=8,
                 min_img_per_place=8,
                 random_sample_from_each_place=True,
                 transform=default_transform,
                 base_path=dataset_path):
        super(GSVCitiesDataset, self).__init__()
        self.base_path = base_path
        # NOTE(review): `cities` is ignored (see class docstring).
        self.cities = cities_list

        assert img_per_place <= min_img_per_place, \
            f"img_per_place should be less than or equal to {min_img_per_place}"
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.random_sample_from_each_place = random_sample_from_each_place
        self.transform = transform

        # Generate the dataframe containing images metadata (indexed by progressive_number)
        self.dataframe = self._getdataframes()

        # Original place ids of the retained places
        self.labels = self.dataframe['place_id_orig'].unique()

        # Get all unique place ids (the progressive numbers used as labels)
        self.places_ids = pd.unique(self.dataframe.index)
        self.total_nb_images = len(self.dataframe)

    @staticmethod
    def get_img_name(row):
        """Rebuild an image file name from a metadata row.

        Only the exact inverse of ``_process_file_name`` when the tokens it does not
        parse are empty (positions 8 and 10-12) and the name ends with '@.jpg'.
        ``__getitem__`` no longer relies on it and uses the stored ``filename`` instead.
        """
        easting = str(row['easting'])
        northing = str(row['northing'])
        zone = str(row['zone'])
        grid_zone = str(row['grid_zone'])
        place_id = str(row['place_id_orig'])
        pano_id = row['pano_id']
        year = str(row['year']).zfill(4)
        month = str(row['month']).zfill(2)
        north_degree = str(row['north_degree']).zfill(3)
        lat, lon = str(row['latitude']), str(row['longitude'])
        name = '@' + easting + '@' + northing + '@' + zone + '@' + grid_zone + '@' + lat + '@' + lon + '@' + pano_id + '@@' + \
            north_degree + '@@@@' + year + month + '@' + place_id + '@' + '.jpg'
        return name

    @staticmethod
    def image_loader(path):
        return Image.open(path).convert('RGB')

    def __len__(self):
        '''Denotes the total number of places (not images)'''
        return len(self.places_ids)

    def __getitem__(self, index):
        """Return (images, labels) for the place at position ``index``.

        images: ``torch.stack`` of ``img_per_place`` transformed images
        (assuming the transform returns tensors: [K, C, H, W]).
        labels: tensor of length ``img_per_place`` filled with the place's
        progressive number.
        """
        place_id = self.places_ids[index]

        # Double brackets always yield a DataFrame (one row per image); .loc[place_id]
        # would return a Series for a place that has a single image.
        place = self.dataframe.loc[[place_id]]

        # Sample K images (rows) from this place: either randomly, or always the
        # same most recent ones.
        if self.random_sample_from_each_place:
            place = place.sample(n=self.img_per_place)
        else:
            place = place.sort_values(
                by=['year', 'month', 'latitude'], ascending=False)
            place = place[:self.img_per_place]

        imgs = []
        for _, row in place.iterrows():
            # Open the file by the exact name found on disk instead of rebuilding it
            # with get_img_name(), which only round-trips for empty unparsed tokens.
            img_path = os.path.join(self.base_path, 'train', row['city_id'], row['filename'])

            img = self.image_loader(img_path)

            if self.transform is not None:
                img = self.transform(img)

            imgs.append(img)

        # NOTE: contrary to image classification where __getitem__ returns only one image
        # in GSVCities, we return a place, which is a tensor of K images (K=self.img_per_place)
        # This will return a tensor of shape [K, channels, height, width]. This needs to be taken into account
        # in the DataLoader (which will yield batches of shape [BS, K, channels, height, width])
        return torch.stack(imgs), torch.tensor(place_id).repeat(self.img_per_place)

    def _getdataframes(self):
        """Build the image-metadata dataframe for every city directory in ``self.cities``.

        Places with fewer than ``min_img_per_place`` images are dropped; the rest get a
        1-based ``progressive_number`` (groups numbered in sorted place-id order) that
        becomes the index and the training label. Also sets ``self.labels_used`` to the
        largest progressive number.
        """
        column_names = ['easting', 'northing', 'zone', 'grid_zone', 'latitude', 'longitude', 'pano_id',
                        'north_degree', 'year', 'month', 'city_id', 'place_id_orig', 'filename']

        list_img_metadata = []
        # self.cities is the module-level cities_list; to train on a single city,
        # restrict that list (e.g. ['phoenix']).
        for city in self.cities:
            city_path = os.path.join(self.base_path, 'train', city)
            for filename in os.listdir(city_path):
                list_img_metadata.append(self._process_file_name(filename))
        df = pd.DataFrame(list_img_metadata, columns=column_names)
        res = df[df.groupby('place_id_orig')['place_id_orig'].transform(
            'size') >= self.min_img_per_place].copy()

        res['progressive_number'] = res.groupby('place_id_orig').ngroup() + 1
        self.labels_used = res['progressive_number'].max()

        print('-----------------')
        print('self label used')
        print(self.labels_used)
        print('-----------------')
        return res.set_index('progressive_number')

    def _process_file_name(self, file_name):
        """Parse an '@'-separated image file name into a metadata dict.

        Token positions used: 1 easting, 2 northing, 3 zone, 4 grid zone,
        5 latitude, 6 longitude, 7 pano id, 9 north degree, 13 YYYYMM,
        14 place id. The city id is the second '_'-separated part of the place
        id, lower-cased. The raw file name is kept under 'filename'.
        """
        data_tokens = file_name.split('@')

        year_month = data_tokens[13]
        place_id = data_tokens[14]

        return {
            'easting': data_tokens[1],
            'northing': data_tokens[2],
            'zone': data_tokens[3],
            'grid_zone': data_tokens[4],
            'latitude': data_tokens[5],
            'longitude': data_tokens[6],
            'pano_id': data_tokens[7],
            'north_degree': data_tokens[9],
            'year': year_month[:4],
            'month': year_month[4:],
            'city_id': place_id.split('_')[1].lower(),
            'place_id_orig': place_id,
            'filename': file_name,
        }

    def _get_max_progressive_number(self):
        """Largest progressive number (label) assigned by _getdataframes()."""
        return self.labels_used


##############################################################################################
##############################################################################################
##############################################################################################


# Per-channel (RGB) normalisation statistics, fed to T.Normalize by the datamodule.
IMAGENET_MEAN_STD = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

VIT_MEAN_STD = dict(
    mean=[0.5] * 3,
    std=[0.5] * 3,
)

# Training cities (full list). For a quick single-city debug run use
# TRAIN_CITIES = ['phoenix'] instead.
TRAIN_CITIES = [
    'phoenix', 'prg', 'prs', 'rome', 'trt', 'washingtondc',
    'bangkok', 'barcelona', 'boston', 'brussels', 'buenosaires', 'chicago',
    'lisbon', 'london', 'losangeles', 'madrid', 'medellin', 'melbourne',
    'mexicocity', 'miami', 'minneapolis', 'osaka', 'osl',
]


class GSVCitiesDataModule(pl.LightningDataModule):
    """Lightning data module: GSVCitiesDataset for training, SF-XS for validation.

    Each training item is one place, so training batches are
    [batch_size, img_per_place, C, H, W]. Validation sets are selected by name via
    ``val_set_names``; only names containing 'sfxsval' are implemented.
    """

    def __init__(self,
                 batch_size=32,
                 #DEBUG: batch_size=4,
                 img_per_place=8,
                 min_img_per_place=8,
                 shuffle_all=True,
                 image_size=(224, 224),
                 num_workers=8,
                 show_data_stats=True,
                 cities=TRAIN_CITIES,
                 mean_std=IMAGENET_MEAN_STD,
                 batch_sampler=None,
                 random_sample_from_each_place=True,
                 val_set_names=['sfxsval']
                 ):
        super().__init__()
        self.batch_size = batch_size
        self.img_per_place = img_per_place
        self.min_img_per_place = min_img_per_place
        self.shuffle_all = shuffle_all
        self.image_size = image_size
        self.num_workers = num_workers
        self.batch_sampler = batch_sampler
        self.show_data_stats = show_data_stats
        self.cities = cities
        self.mean_dataset = mean_std['mean']
        self.std_dataset = mean_std['std']
        self.random_sample_from_each_place = random_sample_from_each_place
        self.save_hyperparameters()  # save hyperparameters with PyTorch Lightning
        self.val_set_names = val_set_names

        # Built by setup('fit') / reload(); defined here so accessors never hit AttributeError.
        self.train_dataset = None
        self.val_datasets = []

        self.train_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.RandAugment(num_ops=3, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset),
        ])

        self.valid_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(mean=self.mean_dataset, std=self.std_dataset)])

        self.train_loader_config = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
            'drop_last': False,
            'pin_memory': False,
            'shuffle': self.shuffle_all,
        }
        # prefetch_factor / persistent_workers only apply to worker processes;
        # DataLoader raises ValueError if they are set while num_workers == 0.
        if self.num_workers > 0:
            self.train_loader_config.update({
                'prefetch_factor': 2,  # batches pre-loaded by each worker
                'persistent_workers': True,  # keep workers alive across epochs
            })

        self.valid_loader_config = {
            'batch_size': 256,
            'num_workers': self.num_workers // 2,
            'drop_last': False,
            'pin_memory': True,
            'shuffle': False}

        self.numberofclasses = 0

    def setup(self, stage):
        """Build the training dataset and the validation datasets for the 'fit' stage."""
        if stage == 'fit':
            # load train dataset with reload routine
            self.reload()

            self.val_datasets = []
            for valid_set_name in self.val_set_names:
                if 'sfxsval' in valid_set_name.lower():
                    self.val_datasets.append(SFXSValDataset(input_transform=self.valid_transform))
                else:
                    print(f'Validation set {valid_set_name} does not exist or has not been implemented yet')
                    raise NotImplementedError

        # print_stats() reads the training dataset, which only exists after a 'fit' setup.
        if self.show_data_stats and self.train_dataset is not None:
            self.print_stats()

    def reload(self):
        """(Re)build the training dataset and refresh the class count."""
        self.train_dataset = GSVCitiesDataset(
            cities=self.cities,
            img_per_place=self.img_per_place,
            min_img_per_place=self.min_img_per_place,
            random_sample_from_each_place=self.random_sample_from_each_place,
            transform=self.train_transform)
        self.numberofclasses = self.train_dataset._get_max_progressive_number()

    def train_dataloader(self):
        # setup('fit') has normally built the dataset already; rebuilding it would
        # re-list every training image (slow on Google Drive) for no benefit, since
        # dataset construction involves no randomness.
        if self.train_dataset is None:
            self.reload()
        return DataLoader(dataset=self.train_dataset, **self.train_loader_config)

    def get_classes_count(self):
        """Largest training label (progressive number) of the current training dataset."""
        return self.numberofclasses

    def val_dataloader(self):
        """One DataLoader per validation dataset, in val_set_names order."""
        return [DataLoader(dataset=val_dataset, **self.valid_loader_config)
                for val_dataset in self.val_datasets]

    @staticmethod
    def _make_table():
        """Headerless, left-aligned two-column (Data, Value) PrettyTable."""
        table = PrettyTable()
        table.field_names = ['Data', 'Value']
        table.align['Data'] = "l"
        table.align['Value'] = "l"
        table.header = False
        return table

    def print_stats(self):
        """Print the training-dataset summary and the batch configuration."""
        print()  # print a new line
        table = self._make_table()
        # Report the city directories the dataset actually read.
        table.add_row(["# of cities", f"{len(self.train_dataset.cities)}"])
        table.add_row(["# of places", f'{len(self.train_dataset)}'])
        table.add_row(["# of images", f'{self.train_dataset.total_nb_images}'])
        table.add_row(["# of used class", f'{self.train_dataset._get_max_progressive_number()}'])
        print(table.get_string(title="Training Dataset"))
        print()

        table = self._make_table()
        table.add_row(
            ["Batch size (PxK)", f"{self.batch_size}x{self.img_per_place}"])
        table.add_row(["Image size", f"{self.image_size}"])
        print(table.get_string(title="Training config"))


##############################################################################################
##############################################################################################
##############################################################################################


class ResNet(nn.Module):
    """ResNet-18 feature extractor returning globally average-pooled descriptors.

    Args:
        model_name: backbone variant; only names containing '18' are supported.
        pretrained: load torchvision 'IMAGENET1K_V1' weights; layer freezing is
            applied only when True.
        layers_to_freeze: freeze layer1..layerN for N in 1..3 (0 = train everything).
        layers_to_crop: residual stages to remove, a subset of {3, 4}. Cropping
            layer3 requires cropping layer4 too, since layer4 expects layer3's
            output channels.

    forward() returns a float32 tensor [B, out_channels]: 512 for the full
    ResNet-18, 256 without layer4, 128 without layer3 and layer4.

    Raises:
        ValueError: if layer3 is cropped while layer4 is kept.
        NotImplementedError: for an unsupported ``model_name``.
    """

    def __init__(self,
                 model_name='18',
                 pretrained=True,
                 layers_to_freeze=0,
                 layers_to_crop=(4,),
                 ):

        super().__init__()
        self.model_name = model_name.lower()
        self.layers_to_freeze = layers_to_freeze

        # Fail fast instead of building a model that breaks at the first forward().
        if 3 in layers_to_crop and 4 not in layers_to_crop:
            raise ValueError('Cropping layer3 requires cropping layer4 as well '
                             '(layer4 expects the output channels of layer3).')

        # the new naming of pretrained weights, you can change to V2 if desired.
        weights = 'IMAGENET1K_V1' if pretrained else None

        if '18' in model_name:
            self.model = torchvision.models.resnet18(weights=weights)
        else:
            raise NotImplementedError(
                'Backbone architecture not recognized!')

        # freeze only if the model is pretrained
        if pretrained:
            if layers_to_freeze >= 1:
                self.model.layer1.requires_grad_(False)
            if layers_to_freeze >= 2:
                self.model.layer2.requires_grad_(False)
            if layers_to_freeze >= 3:
                self.model.layer3.requires_grad_(False)

        # The classification head is never used: forward() stops at the pooled features.
        self.model.fc = None

        if 4 in layers_to_crop:
            self.model.layer4 = None
        if 3 in layers_to_crop:
            self.model.layer3 = None

        out_channels = 2048
        if '18' in model_name:
            out_channels = 512

        # Each cropped stage halves the channel count of the final feature map.
        self.out_channels = out_channels // 2 if self.model.layer4 is None else out_channels
        self.out_channels = self.out_channels // 2 if self.model.layer3 is None else self.out_channels

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        if self.model.layer3 is not None:
            x = self.model.layer3(x)
        if self.model.layer4 is not None:
            x = self.model.layer4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        # Cast back to float32 (e.g. under 16-bit mixed-precision autocast).
        x = x.float()

        return x



##############################################################################################
##############################################################################################


from pytorch_metric_learning import losses

class VPRModel(pl.LightningModule):
    """This is the main model for Visual Place Recognition
    we use PyTorch Lightning for modularity purposes.

    Args:
        num_classes: number of CosFace classes. Training labels are
            GSVCitiesDataset progressive numbers 1..N, so this must be at least
            N + 1; the default 36004 covers the 36003 places of the logged run.
        embedding_size: descriptor size given to CosFace; defaults to the
            backbone's ``out_channels`` (256 for the default cropped ResNet-18).
    """

    def __init__(self, num_classes=36004, embedding_size=None):
        super().__init__()
        # Hyperparameters
        self.lr = 0.0002  # 0.03 for sgd
        self.optimizer = 'adam'  # sgd, adam or adamw
        self.weight_decay = 0  # 0.001 for sgd or 0.0 for adam
        self.momentum = 0.9
        self.warmpup_steps = 600
        self.milestones = [4, 8, 12, 16]
        self.lr_mult = 0.3

        self.backbone = ResNet()
        if embedding_size is None:
            embedding_size = self.backbone.out_channels
        self.loss_fn = losses.CosFaceLoss(num_classes, embedding_size, margin=0.6, scale=30)
        self.miner = miners.MultiSimilarityMiner(0.1, distance=CosineSimilarity())

        # Metrics (initialize empty containers)
        self.batch_acc = []
        self.faiss_gpu = False
        # Validation descriptors collected per dataloader index during an epoch.
        self.validation_step_outputs = {}

        # Save hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.backbone(x)
        return x

    def configure_optimizers(self):
        name = self.optimizer.lower()
        if name == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(),
                                        lr=self.lr,
                                        weight_decay=self.weight_decay,
                                        momentum=self.momentum)
        elif name in ('adam', 'adamw'):
            # 'adam' has always been mapped to AdamW; with weight_decay=0 the two coincide.
            optimizer = torch.optim.AdamW(self.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        else:
            raise ValueError(f'Optimizer {self.optimizer} has not been added to "configure_optimizers()"')

        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_mult)
        return [optimizer], [scheduler]

    def loss_function(self, descriptors, labels):
        # Mining (if applicable): the miner's tuples are passed to the loss.
        if self.miner is not None:
            miner_outputs = self.miner(descriptors, labels)
            loss = self.loss_fn(descriptors, labels, miner_outputs)
        else:
            loss = self.loss_fn(descriptors, labels)
        return loss

    def training_step(self, batch, batch_idx):
        places, labels = batch
        BS, N, ch, h, w = places.shape

        # Flatten [BS, N, ...] places into BS*N individual images and labels.
        images = places.view(BS * N, ch, h, w)
        labels = labels.view(-1)
        descriptors = self(images)

        # Calculate loss
        loss = self.loss_function(descriptors, labels)

        # Log loss and return dictionary
        self.log('train_loss', round(loss.item(), 2), prog_bar=True, logger=True)
        self.log('loss', round(loss.item(), 2), logger=True)
        return {'loss': loss}

    def on_train_epoch_end(self):
        # we empty the batch_acc list for next epoch
        self.batch_acc = []

    def on_training_epoch_end(self, training_step_outputs=None):
        # Not a Lightning hook name (never called by the trainer); kept for
        # backward compatibility and delegates to on_train_epoch_end().
        self.on_train_epoch_end()

    # For validation, we will also iterate step by step over the validation set
    # this is the way Pytorch Lightning is made. All about modularity, folks.
    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        places, _ = batch
        # calculate descriptors
        descriptors = self(places).detach().cpu()
        # dataloader_idx may be None when there is a single validation loader.
        loader_idx = 0 if dataloader_idx is None else dataloader_idx
        self.validation_step_outputs.setdefault(loader_idx, []).append(descriptors)
        return descriptors

    def on_validation_epoch_end(self):
        dm = self.trainer.datamodule

        for i, (val_set_name, val_dataset) in enumerate(zip(dm.val_set_names, dm.val_datasets)):
            feats = torch.concat(self.validation_step_outputs[i], dim=0)

            num_references = val_dataset.num_references
            ground_truth = val_dataset.ground_truth

            # split to ref and queries (the dataset lists references first)
            r_list = feats[: num_references]
            q_list = feats[num_references:]

            recalls_dict, predictions = get_validation_recalls(r_list=r_list,
                                                               q_list=q_list,
                                                               k_values=[1, 5, 10, 15, 20, 25],
                                                               gt=ground_truth,
                                                               print_results=True,
                                                               dataset_name=val_set_name,
                                                               faiss_gpu=self.faiss_gpu
                                                               )
            del r_list, q_list, feats

            # NOTE(review): these keys are shared by all validation sets; prefix them
            # with val_set_name if a second set is added. The checkpoint filename uses them.
            self.log('sfxsR1', recalls_dict[1], prog_bar=True, logger=True)
            self.log('sfxsR5', recalls_dict[5], prog_bar=True, logger=True)
            self.log('sfxsR10', recalls_dict[10], prog_bar=True, logger=True)

        # Reset only after every validation set has consumed its descriptors.
        self.validation_step_outputs = {}
        print('\n\n')

##############################################################################################
##############################################################################################




if __name__ == '__main__':

    seed_everything(seed=1, workers=True)

    # The datamodule (GSVCitiesDataModule, defined above in this notebook) provides
    # the train and SF-XS validation dataloaders.
    # NOTE(review): GSVCitiesDataset reads every directory in cities_list, so
    # editing TRAIN_CITIES does not currently change which cities are used.
    datamodule = GSVCitiesDataModule()

    model = VPRModel()

    # Keep a checkpoint for every epoch (save_top_k=-1 disables pruning). The file
    # name records the epoch, the training loss and the SF-XS Recall@1/@5.
    checkpoint_cb = ModelCheckpoint(
        dirpath='/content/drive/MyDrive/Datasets/checkpoint/',
        filename='resnet_nogem_cosimilarity_improved_{epoch}_{loss}_{sfxsR1}_{sfxsR5}',
        monitor='loss',  # training loss logged in training_step
        auto_insert_metric_name=True,
        save_weights_only=False,
        save_top_k=-1,
    )

    # Single-GPU, 16-bit mixed-precision trainer. SF-XS validation runs after each
    # training epoch; the sanity-check validation before training is disabled.
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        default_root_dir='/content/drive/MyDrive/Datasets/checkpoint/',
        num_sanity_val_steps=0,
        precision='16-mixed',
        max_epochs=30,
        callbacks=[checkpoint_cb],
    )

    # we call the trainer, and give it the model and the datamodule
    trainer.fit(model=model, datamodule=datamodule)


INFO:lightning_lite.utilities.seed:Global seed set to 1
INFO:lightning_lite.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning_lite.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:lightning_lite.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:lightning_lite.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:lightning_lite.utilities.rank_zero:HPU available: False, using: 0 HPUs


-----------------
self label used
36003
-----------------


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type                 | Params
--------------------------------------------------
0 | backbone | ResNet               | 2.8 M 
1 | loss_fn  | CosFaceLoss          | 9.2 M 
2 | miner    | MultiSimilarityMiner | 0     
--------------------------------------------------
12.0 M    Trainable params
0         Non-trainable params
12.0 M    Total params
47.999    Total estimated model params size (MB)


path gt root:  /content/drive/MyDrive/Datasets/sf_xs/val/ val_dbImages.npy
num queries
7993

+--------------------------+
|     Training Dataset     |
+-----------------+--------+
| # of cities     | 23     |
| # of places     | 36003  |
| # of images     | 373784 |
| # of used class | 36003  |
+-----------------+--------+

+-------------------------------+
|        Training config        |
+------------------+------------+
| Batch size (PxK) | 32x8       |
| Image size       | (224, 224) |
+------------------+------------+
-----------------
self label used
36003
-----------------


  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


  x.storage().data_ptr() + x.storage_offset() * 4)




+----------------------------------------------------+
|               Performance on sfxsval               |
+----------+------+------+------+------+------+------+
|    K     |  1   |  5   |  10  |  15  |  20  |  25  |
+----------+------+------+------+------+------+------+
| Recall@K | 1.05 | 2.64 | 3.67 | 4.92 | 6.03 | 6.89 |
+----------+------+------+------+------+------+------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+------------------------------------------------------+
|                Performance on sfxsval                |
+----------+------+------+------+------+-------+-------+
|    K     |  1   |  5   |  10  |  15  |   20  |   25  |
+----------+------+------+------+------+-------+-------+
| Recall@K | 2.56 | 5.53 | 7.38 | 9.13 | 10.52 | 11.92 |
+----------+------+------+------+------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+-------------------------------------------------------+
|                 Performance on sfxsval                |
+----------+------+------+------+-------+-------+-------+
|    K     |  1   |  5   |  10  |   15  |   20  |   25  |
+----------+------+------+------+-------+-------+-------+
| Recall@K | 2.69 | 6.12 | 8.70 | 10.88 | 12.49 | 14.07 |
+----------+------+------+------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+------------------------------------------------------+
|                Performance on sfxsval                |
+----------+------+------+------+------+-------+-------+
|    K     |  1   |  5   |  10  |  15  |   20  |   25  |
+----------+------+------+------+------+-------+-------+
| Recall@K | 2.31 | 4.95 | 7.22 | 9.06 | 10.50 | 11.90 |
+----------+------+------+------+------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+-------------------------------------------------------+
|                 Performance on sfxsval                |
+----------+------+------+------+-------+-------+-------+
|    K     |  1   |  5   |  10  |   15  |   20  |   25  |
+----------+------+------+------+-------+-------+-------+
| Recall@K | 2.68 | 6.23 | 8.66 | 10.90 | 12.80 | 14.17 |
+----------+------+------+------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+--------------------------------------------------------+
|                 Performance on sfxsval                 |
+----------+------+------+-------+-------+-------+-------+
|    K     |  1   |  5   |   10  |   15  |   20  |   25  |
+----------+------+------+-------+-------+-------+-------+
| Recall@K | 3.13 | 7.24 | 10.08 | 12.37 | 14.25 | 15.98 |
+----------+------+------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+--------------------------------------------------------+
|                 Performance on sfxsval                 |
+----------+------+------+-------+-------+-------+-------+
|    K     |  1   |  5   |   10  |   15  |   20  |   25  |
+----------+------+------+-------+-------+-------+-------+
| Recall@K | 3.23 | 7.34 | 10.57 | 12.75 | 14.85 | 16.58 |
+----------+------+------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+--------------------------------------------------------+
|                 Performance on sfxsval                 |
+----------+------+------+-------+-------+-------+-------+
|    K     |  1   |  5   |   10  |   15  |   20  |   25  |
+----------+------+------+-------+-------+-------+-------+
| Recall@K | 3.64 | 8.06 | 11.35 | 13.59 | 15.60 | 17.39 |
+----------+------+------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+--------------------------------------------------------+
|                 Performance on sfxsval                 |
+----------+------+------+-------+-------+-------+-------+
|    K     |  1   |  5   |   10  |   15  |   20  |   25  |
+----------+------+------+-------+-------+-------+-------+
| Recall@K | 4.03 | 8.61 | 12.22 | 14.50 | 16.30 | 18.07 |
+----------+------+------+-------+-------+-------+-------+





Validation: |          | 0/? [00:00<?, ?it/s]

------------
Embed size: 
256
------------


+--------------------------------------------------------+
|                 Performance on sfxsval                 |
+----------+------+------+-------+-------+-------+-------+
|    K     |  1   |  5   |   10  |   15  |   20  |   25  |
+----------+------+------+-------+-------+-------+-------+
| Recall@K | 3.47 | 7.66 | 10.67 | 13.05 | 14.89 | 16.30 |
+----------+------+------+-------+-------+-------+-------+





/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
