In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import logging
from PIL import Image

from src.DatabaseManager.faiss_database_manager import FAISSDatabaseManager
from src.ImageFeatureExtractor.onnx_image_feature_extractor import ONNXImageFeatureExtractor

def register_images(image_dir, db_manager, extractor):
    """Extract a feature for every image under ``image_dir`` and register it.

    The directory is searched recursively for regular files whose suffix is
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive). Paths are processed in
    sorted order so that repeated runs over the same directory hand the rows
    to the database in the same order.

    Args:
        image_dir: Root directory to scan (``str`` or ``pathlib.Path``).
        db_manager: Object exposing ``insert_embeddings(rows)``. It receives a
            list of ``{"embedding": feature, "file_path": path}`` dicts.
        extractor: Object exposing ``extract_feature(path)``; it is called
            with each image path as a string.

    Returns:
        None. Images whose feature extraction raises are logged and skipped.
        ``insert_embeddings`` is not called when no image could be processed.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    # glob() yields entries in filesystem order, which is not guaranteed to be
    # stable; sorting keeps the insertion order reproducible. is_file() drops
    # directories that merely happen to be named like an image.
    image_paths = sorted(
        str(p) for p in Path(image_dir).glob("**/*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    rows = []
    for img_path in tqdm(image_paths, desc="Registering images"):
        try:
            feature = extractor.extract_feature(img_path)
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            logging.error(f"Error processing {img_path}: {e}")
            continue
        rows.append({"embedding": feature, "file_path": img_path})
    if not rows:
        # Avoid handing an empty batch to the database manager.
        logging.warning("No images were registered from %s", image_dir)
        return
    db_manager.insert_embeddings(rows)

def search_images(db_manager, extractor, query_image_path, k=5):
    """Look up the stored images closest to a query image.

    Args:
        db_manager: Object exposing ``search(feature, k)`` and a
            ``file_paths`` sequence indexed by the ids that ``search`` returns.
        extractor: Object exposing ``extract_feature(path)``.
        query_image_path: Path of the image to use as the query.
        k: Number of neighbours requested from ``db_manager.search``.

    Returns:
        A list of ``(distance, file_path)`` tuples in the order produced by
        ``db_manager.search``; slots whose index is ``-1`` are left out.
    """
    # Extract the query feature and run the FAISS search in one step.
    top_scores, top_ids = db_manager.search(
        extractor.extract_feature(query_image_path), k
    )
    # Only the first row is read (the row for this single query); an index of
    # -1 is treated as "no match" and skipped.
    return [
        (score, db_manager.file_paths[hit])
        for score, hit in zip(top_scores[0], top_ids[0])
        if hit != -1
    ]



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Setup cell: configure logging, build the feature extractor and the FAISS
# database manager, then register every image under `image_directory`.
# Later cells rely on `extractor`, `db_manager` and `index_file` defined here.
logging.basicConfig(level=logging.INFO)

# Create the image feature extractor instance (path is relative to the
# notebook's working directory).
extractor = ONNXImageFeatureExtractor(onnx_path="./model/ONNX/mobilenet_v2.onnx")

# FAISS configuration
index_file = "faiss_index.index"  # File the index is saved to
# NOTE(review): presumably True rebuilds the index instead of loading an
# existing `index_file` — confirm against FAISSDatabaseManager.
recreate = True
db_manager = FAISSDatabaseManager(extractor.dim, index_file=index_file, recreate=recreate)

# Register the images found in the image directory
image_directory = "../../sample_data/WIDER_OpenData/"  # Change to an appropriate path
register_images(image_directory, db_manager, extractor)


INFO:src.ImageFeatureExtractor.onnx_image_feature_extractor:ONNX model input: name=input.1, shape=[1, 3, 224, 224], type=tensor(float)
INFO:src.ImageFeatureExtractor.onnx_image_feature_extractor:モデルの出力次元: 1000
Registering images: 100%|██████████| 100/100 [00:02<00:00, 39.97it/s]


In [5]:
# Search using a query image
query_image = "../../sample_data/WIDER_OpenData/3--Riot/3_Riot_Riot_3_101.jpg"  # Change to an appropriate path
results = search_images(db_manager, extractor, query_image, k=5)

# Display the search results
# NOTE(review): the recorded output shows 1.0000 for the query image itself
# and smaller values for the other hits, so this value looks like a similarity
# score rather than a distance — confirm against FAISSDatabaseManager.search.
if results:
    for distance, file_path in results:
        print(f"Distance: {distance:.4f}, File: {file_path}")
else:
    print("検索結果が見つかりませんでした。")

# Save the FAISS index to `index_file`
db_manager.save(index_file=index_file)

Distance: 1.0000, File: ..\..\sample_data\WIDER_OpenData\3--Riot\3_Riot_Riot_3_101.jpg
Distance: 0.8943, File: ..\..\sample_data\WIDER_OpenData\3--Riot\3_Riot_Riot_3_199.jpg
Distance: 0.8058, File: ..\..\sample_data\WIDER_OpenData\3--Riot\3_Riot_Riot_3_186.jpg
Distance: 0.7021, File: ..\..\sample_data\WIDER_OpenData\3--Riot\3_Riot_Riot_3_184.jpg
Distance: 0.5565, File: ..\..\sample_data\WIDER_OpenData\6--Funeral\6_Funeral_Funeral_6_109.jpg
