In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import random
import logging
from pathlib import Path
from tqdm import tqdm

from imageSearch.DatabaseManager import FAISSDatabaseManager
from imageSearch.ImageFeatureExtractor import Resnet50ImageFeatureExtractor, ONNXImageFeatureExtractor, DreamSimImageFeatureExtractor
from imageSearch.utils.display_notebook import display_images_grid_html
from imageSearch.utils.logger_util import configure_logger

logger = configure_logger()
# logger.setLevel(logging.INFO)
logger.setLevel(logging.ERROR)


def register_images(image_dir, db_manager, extractor):
    image_dir = Path(image_dir)
    image_paths = [
        str(p) for p in image_dir.glob("**/*")
        if p.suffix.lower() in [".jpg", ".jpeg", ".png"]
    ]
    rows = []
    for img_path in tqdm(image_paths, desc="Registering images"):
        try:
            feature = extractor.extract_feature(img_path)
            rows.append({"embedding": feature, "file_path": img_path})
        except Exception as e:
            logging.error(f"Error processing {img_path}: {e}")
    db_manager.insert_embeddings(rows)

def search_images(db_manager, extractor, query_image_path, k=5):
    # 特徴抽出
    query_feature = extractor.extract_feature(query_image_path)
    # FAISSによる検索
    distances, indices = db_manager.search(query_feature, k)
    results = []
    for d, idx in zip(distances[0], indices[0]):
        # インデックスが-1の場合は該当なしとする
        if idx != -1:
            file_path = db_manager.file_paths[idx]
            results.append((d, file_path))
    return results


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# FAISSのデータベースを読み込み
db_manager = FAISSDatabaseManager(index_file=Path("./localDB/FAISS/sampleDB/sampleDB.index"))


In [6]:
# 画像特徴抽出器のインスタンス生成
# extractor = Resnet50ImageFeatureExtractor()
# extractor = DreamSimImageFeatureExtractor(cache_dir="./model/DreamSim/")
extractor = ONNXImageFeatureExtractor(onnx_path="./model/ONNX/mobilenet_v2.onnx")

# クエリ画像を用いた検索
query_image = str(random.choice(list(Path("../../sample_data/").glob("**/*.jpg"))))
k = 5
results = search_images(db_manager, extractor, query_image, k=k)

# 検索結果の表示
if results:
    for distance, file_path in results:
        print(f"Distance: {distance:.4f}, File: {file_path}")
else:
    print("検索結果が見つかりませんでした。")
    

images = [query_image] + ["" for _ in range(k-1)] + [r[1] for r in results]
labels = [f"{Path(query_image).name}"] + ["" for _ in range(k-1)] + [f"{Path(r[1]).name}\n(Distance: {r[0]:.05})" for r in results]
display_images_grid_html(images, labels=labels, cols=k)

Distance: 1.0000, File: ..\..\sample_data\coco_sample_datasets\sample_coco_train2017\000000544261.jpg
Distance: 0.4118, File: ..\..\sample_data\coco_sample_datasets\sample_coco_train2017\000000544250.jpg
Distance: 0.1999, File: ..\..\sample_data\coco_sample_datasets\sample_coco_train2017\000000188087.jpg
Distance: 0.1425, File: ..\..\sample_data\coco_sample_datasets\sample_coco_train2017\000000188120.jpg
Distance: 0.1347, File: ..\..\sample_data\coco_sample_datasets\sample_coco_train2017\000000188130.jpg


0,1,2,3,4
000000544261.jpg,,,,
000000544261.jpg (Distance: 1.0),000000544250.jpg (Distance: 0.41182),000000188087.jpg (Distance: 0.19992),000000188120.jpg (Distance: 0.14253),000000188130.jpg (Distance: 0.13473)
