In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import random
import logging
from pathlib import Path
from tqdm import tqdm

from imageSearch.DatabaseManager import FAISSDatabaseManager
from imageSearch.ImageFeatureExtractor import Resnet50ImageFeatureExtractor, ONNXImageFeatureExtractor, DreamSimImageFeatureExtractor
from imageSearch.utils.display_notebook import display_images_grid_thml

def register_images(image_dir, db_manager, extractor, extensions=(".jpg", ".jpeg", ".png")):
    """Extract an embedding for every image under ``image_dir`` and insert them into ``db_manager``.

    Files are discovered recursively. A file counts as an image when its suffix,
    compared case-insensitively, is in ``extensions``. Paths are sorted so the
    insertion order is identical on every run. The index's row ids, which
    ``search_images`` maps back through ``db_manager.file_paths``, therefore
    do not depend on filesystem enumeration order.

    Args:
        image_dir: Root directory to scan (str or Path).
        db_manager: Object exposing ``insert_embeddings(rows)``. It receives a
            list of ``{"embedding": <feature>, "file_path": <str>}`` dicts.
        extractor: Object exposing ``extract_feature(path)``.
        extensions: Iterable of accepted suffixes, each with a leading dot.
            Defaults to the original ``.jpg``/``.jpeg``/``.png`` set.

    Images whose feature extraction raises are logged and skipped. If no
    embedding was produced at all, nothing is inserted and a warning is logged.
    """
    allowed_suffixes = {ext.lower() for ext in extensions}
    image_paths = sorted(
        str(p)
        for p in Path(image_dir).glob("**/*")
        # is_file() keeps directories that happen to be named "*.jpg" out of the extractor.
        if p.is_file() and p.suffix.lower() in allowed_suffixes
    )
    rows = []
    for img_path in tqdm(image_paths, desc="Registering images"):
        try:
            feature = extractor.extract_feature(img_path)
        except Exception as e:
            # Best-effort: one unreadable/corrupt image must not abort the whole batch.
            logging.error(f"Error processing {img_path}: {e}")
            continue
        rows.append({"embedding": feature, "file_path": img_path})
    if not rows:
        # Avoid handing an empty batch to the index.
        logging.warning(f"No images registered from {image_dir}")
        return
    db_manager.insert_embeddings(rows)

def search_images(db_manager, extractor, query_image_path, k=5):
    """Look up the ``k`` nearest database entries for a query image.

    The query image is embedded with ``extractor.extract_feature`` and passed
    to ``db_manager.search``. Only the first row of the returned
    ``(distances, indices)`` pair is read. An index of -1 means "no hit for
    this slot" and is dropped.

    Returns:
        List of ``(distance, file_path)`` tuples in the order ``search`` returned them.
    """
    query_vector = extractor.extract_feature(query_image_path)
    scores, hit_ids = db_manager.search(query_vector, k)
    return [
        (score, db_manager.file_paths[hit_id])
        for score, hit_id in zip(scores[0], hit_ids[0])
        if hit_id != -1  # -1 marks an empty result slot
    ]


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
logging.basicConfig(level=logging.INFO)

# Image feature extractor. Uncomment one of the alternatives to switch backends.
# extractor = Resnet50ImageFeatureExtractor()
# extractor = DreamSimImageFeatureExtractor(cache_dir="./model/DreamSim/")
extractor = ONNXImageFeatureExtractor(onnx_path="./model/ONNX/mobilenet_v2.onnx")

# FAISS settings
index_file = "./localDB/image_feature.index"  # where the index file is stored
# Create the index directory up front so db_manager.save(index_file=...) can write there.
# NOTE(review): the save presumably does not create missing parent directories; confirm in FAISSDatabaseManager.
Path(index_file).parent.mkdir(parents=True, exist_ok=True)
recreate = True  # presumably True builds a fresh index rather than loading index_file — confirm
db_manager = FAISSDatabaseManager(extractor.dim, index_file=index_file, recreate=recreate)

# Register every image found under the image directory
image_directory = "../../sample_data/WIDER_OpenData/"
register_images(image_directory, db_manager, extractor)


Registering images: 100%|██████████| 100/100 [00:07<00:00, 13.76it/s]


In [6]:
# Search the index with a randomly chosen query image
QUERY_ROOT = Path("../../sample_data/")
QUERY_SEED = None  # set to an int to make the query choice reproducible across runs
# Sorted so that, with a fixed seed, the pick does not depend on filesystem enumeration order.
query_candidates = sorted(QUERY_ROOT.glob("**/*.jpg"))
if not query_candidates:
    raise FileNotFoundError(f"No .jpg query images found under {QUERY_ROOT}")
query_image = str(random.Random(QUERY_SEED).choice(query_candidates))
k = 5
results = search_images(db_manager, extractor, query_image, k=k)

# Print the hits.
# NOTE(review): in the saved output the values decrease down the list, which suggests a
# similarity score (e.g. inner product) rather than a distance — confirm the index metric.
if results:
    for distance, file_path in results:
        print(f"Distance: {distance:.4f}, File: {file_path}")
else:
    print("検索結果が見つかりませんでした。")

# Persist the FAISS index (this cell only reads it; kept here as in the original flow)
db_manager.save(index_file=index_file)

# Grid with k columns: the query plus k-1 blank cells fill the first row, so the hits start on row two.
display_images_grid_thml([query_image] + [""] * (k - 1) + [file_path for _, file_path in results], cols=k, width=200)

Distance: 0.3055, File: ..\..\sample_data\WIDER_OpenData\4--Dancing\4_Dancing_Dancing_4_97.jpg
Distance: 0.2894, File: ..\..\sample_data\WIDER_OpenData\3--Riot\3_Riot_Riot_3_123.jpg
Distance: 0.2849, File: ..\..\sample_data\WIDER_OpenData\2--Demonstration\2_Demonstration_Demonstration_Or_Protest_2_32.jpg
Distance: 0.2701, File: ..\..\sample_data\WIDER_OpenData\9--Press_Conference\9_Press_Conference_Press_Conference_9_31.jpg
Distance: 0.2522, File: ..\..\sample_data\WIDER_OpenData\9--Press_Conference\9_Press_Conference_Press_Conference_9_41.jpg
