In [1]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import tqdm, tqdm.notebook
tqdm.tqdm = tqdm.notebook.tqdm
from tqdm import tqdm
from torchvision import transforms
from PIL import Image
import pickle
import numpy as np
import pandas as pd
from hloc.utils.read_write_model import read_images_binary
import os
import json

# Base path configuration (relative path, resolved against the notebook's working directory)
BASE_PATH = '..'

# Derived paths: every data location hangs off BASE_PATH, so relocating the
# data only requires editing the single line above.
PATHS = dict(
    database=os.path.join(BASE_PATH, 'DataBase'),
    global_desc=os.path.join(BASE_PATH, 'GlobalDescriptors'),
    outputs=os.path.join(BASE_PATH, 'outputs'),
)

# Constants
# k values evaluated for recall@k, in the column order used by the report.
K_VALUES = [200, 150, 100, 50, 20, 5, 1]
# Model labels; each one prefixes that model's reference file names.
MODEL_NAMES = ['1st', '2nd', '3rd', '4th', '5th']


In [2]:
def load_data():
    """Load the ground-truth relation database and the reconstruction's image table.

    Returns:
        (database, images):
            database -- the object stored in ``DataBase_Norm_Tiny.pickle``; it is
                passed to ``calculate_recall`` as ``anc_pos_relations`` (presumably
                a list of dicts with 'Anchor' / 'Positive' keys -- confirm).
            images -- the mapping returned by hloc's ``read_images_binary`` for
                ``images.bin`` (image id -> record exposing ``.name``).
    """
    pickle_file = os.path.join(PATHS['database'], 'DataBase_Norm_Tiny.pickle')
    images_bin = os.path.join(PATHS['outputs'], 'aachen/sfm_superpoint+superglue/images.bin')

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(pickle_file, 'rb') as handle:
        relations_db = pickle.load(handle)

    image_table = read_images_binary(images_bin)
    return relations_db, image_table

def calculate_recall(reference_file, anc_pos_relations, images):
    """Compute recall for one retrieval reference file.

    Each non-blank line of ``reference_file`` is ``"<query image name> <retrieved image name>"``.
    All lines sharing a query name form that query's retrieved list. The file is
    presumably already truncated to the top-k results; note the ``_{k}`` file-name
    suffix. A query is a hit when at least one retrieved image is among its
    ground-truth positives.

    Args:
        reference_file: path of the whitespace-separated pair file.
        anc_pos_relations: iterable of records where ``record['Anchor'][0]`` is the
            query image id and ``record['Positive']`` is a collection of positive
            image ids. If several records share an anchor id, the first one is used.
        images: mapping image id -> object with a ``.name`` attribute.

    Returns:
        Mean hit rate over queries (0.0 when the file yields no queries). Queries
        with no ground-truth record count as misses.

    Raises:
        FileNotFoundError: if ``reference_file`` does not exist.
    """
    # One dict lookup per name instead of scanning every image for every line.
    # On duplicate names the later id wins, matching the previous linear scan.
    name_to_id = {img.name: img_id for img_id, img in images.items()}

    search_results = {}
    unresolved_queries = 0
    with open(reference_file, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue  # blank lines (e.g. a trailing newline) are not errors
            if len(fields) != 2:
                print(f"Error processing line in {reference_file}: {line.strip()}")
                continue
            anchor_img, result_img = fields

            anchor_key = name_to_id.get(anchor_img)
            if anchor_key is None:
                # Unknown queries must be skipped. Keying them under None merges
                # every such query into one fake "miss" and skews the mean.
                unresolved_queries += 1
                continue

            # Register the query even if this retrieved name is unknown, so the
            # query is still counted (as a miss if nothing resolvable matches).
            retrieved = search_results.setdefault(anchor_key, [])
            result_key = name_to_id.get(result_img)
            if result_key is not None:
                retrieved.append(result_key)

    if unresolved_queries:
        print(f"Skipped {unresolved_queries} line(s) with unknown query names in {reference_file}")

    # Anchor id -> positives, built once; the first record for an anchor wins,
    # as with the previous scan-and-break.
    positives_by_anchor = {}
    for item in anc_pos_relations:
        anchor_id = item['Anchor'][0]
        if anchor_id not in positives_by_anchor:
            positives_by_anchor[anchor_id] = item['Positive']

    recalls = []
    for anchor_key, results in search_results.items():
        positives = positives_by_anchor.get(anchor_key)
        hit = positives is not None and any(r in positives for r in results)
        recalls.append(int(hit))

    return np.mean(recalls) if recalls else 0.0

def process_all():
    """Evaluate every model in MODEL_NAMES at every k in K_VALUES.

    For each (model, k), reads ``<model>_StudentReference_<k>.txt`` from
    ``PATHS['global_desc']`` and scores it with ``calculate_recall``.

    Returns:
        dict mapping each k to a list with one entry per model, in MODEL_NAMES
        order. Each entry is the recall, or ``None`` when that reference file
        does not exist.
    """
    database, images = load_data()
    recall_table = {k: [] for k in K_VALUES}

    for name in tqdm(MODEL_NAMES):
        print(f"\nProcessing {name} model...")

        for k in K_VALUES:
            ref_path = os.path.join(PATHS['global_desc'], f'{name}_StudentReference_{k}.txt')
            try:
                score = calculate_recall(ref_path, database, images)
            except FileNotFoundError:
                print(f"File not found for {name} model, k={k}")
                score = None
            else:
                print(f"k={k}: {score:.4f}")
            recall_table[k].append(score)

    return recall_table

def display_results(results, k_values=None, model_names=None):
    """Print the per-model recall table, per-k averages, and a copy-friendly summary.

    Args:
        results: dict mapping each k to a list of recalls, one per model in
            ``model_names`` order. ``None`` marks a missing result and prints as N/A.
        k_values: k values to show, in column order. Defaults to ``K_VALUES``.
        model_names: model labels, in the same order as each ``results[k]`` list.
            Defaults to ``MODEL_NAMES``.
    """
    if k_values is None:
        k_values = K_VALUES
    if model_names is None:
        model_names = MODEL_NAMES

    def fmt(value):
        return "N/A" if value is None else f"{value:.4f}"

    # Average each k over the models that produced a result (None if none did).
    # Computed once and reused by both the table row and the summary.
    averages = {}
    for k in k_values:
        values = [v for v in results[k] if v is not None]
        averages[k] = np.mean(values) if values else None

    print("\n=== Average Recall Values ===")
    print("\nk values:", end="")
    for k in k_values:
        print(f"\t{k}", end="")
    print()

    # Print individual model results
    for i, model_name in enumerate(model_names):
        print(f"\n{model_name} model:", end="")
        for k in k_values:
            print(f"\t{fmt(results[k][i])}", end="")

    print("\n\nAverage:", end="")
    for k in k_values:
        print(f"\t{fmt(averages[k])}", end="")

    # Summary for easy copying. The header was previously printed as the literal
    # text "print('...')" instead of the header itself.
    print("\n\nAverage recall values for different k:")
    for k in k_values:
        if averages[k] is not None:
            print(f"k = {k}: {averages[k]:.4f}")

def save_results(results, path='recall@K_results.json'):
    """Write the recall table to a JSON file.

    Keys are converted to strings (JSON object keys must be strings) and each
    value to a plain ``float``; missing results (``None``) are written as ``null``.

    Args:
        results: dict mapping k -> list of recalls (or None), one per model.
        path: output file. Defaults to ``recall@K_results.json`` relative to the
            current working directory, the file this notebook has always written.
    """
    # Announce the file that is actually written. The message previously named
    # "recall_results.json" while the data went to "recall@K_results.json".
    print(f"\nSaving results to {path}")
    output_data = {
        str(k): [None if v is None else float(v) for v in vals]
        for k, vals in results.items()
    }

    with open(path, 'w') as f:
        json.dump(output_data, f, indent=4)

In [3]:
# Evaluate every model at every k; returns {k: [recall per model, in MODEL_NAMES order]},
# with None where a model's reference file for that k is missing.
results = process_all()

  0%|          | 0/5 [00:00<?, ?it/s]


Processing 1st model...
k=200: 0.9977
k=150: 0.9977
k=100: 0.9965
k=50: 0.9908
k=20: 0.9850
k=5: 0.9643
k=1: 0.9021

Processing 2nd model...
k=200: 0.9965
k=150: 0.9965
k=100: 0.9965
k=50: 0.9942
k=20: 0.9838
k=5: 0.9653
k=1: 0.8924

Processing 3rd model...
k=200: 0.9965
k=150: 0.9942
k=100: 0.9931
k=50: 0.9919
k=20: 0.9873
k=5: 0.9711
k=1: 0.9062

Processing 4th model...
k=200: 0.9919
k=150: 0.9896
k=100: 0.9884
k=50: 0.9850
k=20: 0.9745
k=5: 0.9583
k=1: 0.8843

Processing 5th model...
k=200: 0.9988
k=150: 0.9977
k=100: 0.9942
k=50: 0.9907
k=20: 0.9850
k=5: 0.9676
k=1: 0.9016


In [4]:
# Print the per-model recall table followed by the per-k averages across models.
display_results(results)


=== Average Recall Values ===

k values:	200	150	100	50	20	5	1

1st model:	0.9977	0.9977	0.9965	0.9908	0.9850	0.9643	0.9021
2nd model:	0.9965	0.9965	0.9965	0.9942	0.9838	0.9653	0.8924
3rd model:	0.9965	0.9942	0.9931	0.9919	0.9873	0.9711	0.9062
4th model:	0.9919	0.9896	0.9884	0.9850	0.9745	0.9583	0.8843
5th model:	0.9988	0.9977	0.9942	0.9907	0.9850	0.9676	0.9016

Average:	0.9963	0.9951	0.9938	0.9905	0.9831	0.9653	0.8973

print('Average recall values for different k:')
k = 200: 0.9963
k = 150: 0.9951
k = 100: 0.9938
k = 50: 0.9905
k = 20: 0.9831
k = 5: 0.9653
k = 1: 0.8973


In [5]:
# Persist the recall table (keys as strings, missing results as null) to a JSON file.
save_results(results)


Saving results to recall_results.json
