In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import faiss
import numpy as np
import os

# -----------------------------
# 1️⃣ Load pretrained model (ResNet50)
# -----------------------------
# Rebuild the network from all of its child modules except the last one
# (the fc layer), so it acts as a feature extractor rather than a classifier.
# NOTE(review): `pretrained=True` is reported as deprecated by newer
# torchvision releases in favour of `weights=` — confirm against the
# installed version before changing it.
model = torch.nn.Sequential(
    *list(models.resnet50(pretrained=True).children())[:-1]
)
model.eval()  # inference mode

# -----------------------------
# 2️⃣ Image transform
# -----------------------------
# Preprocessing applied to every image before it is fed to the model:
# resize to a fixed 224x224, convert to a tensor, then normalize each
# channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def extract_feature(img_path):
    """Extract an L2-normalized feature vector from an image file.

    Args:
        img_path: Path of the image file to open with PIL.

    Returns:
        NumPy array holding the squeezed model output for the image, scaled
        to unit L2 norm. If the norm is zero the vector is returned as-is,
        since dividing would produce NaNs.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it, and convert() loads the pixel data into a new image first.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    x = transform(rgb).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        feat = model(x).squeeze().numpy()
    norm = np.linalg.norm(feat)
    if norm == 0:
        return feat
    return feat / norm  # normalize the vector

# -----------------------------
# 3️⃣ Build feature database
# -----------------------------
image_dir = "dataset"
features, paths = [], []

# os.listdir returns entries in arbitrary order; sorting keeps the mapping
# from index position to image path reproducible between runs.
for fname in sorted(os.listdir(image_dir)):
    path = os.path.join(image_dir, fname)
    if not os.path.isfile(path):
        continue  # sub-directories cannot be opened as images
    try:
        feat = extract_feature(path)
    except OSError as err:
        # PIL reports unreadable / non-image files via OSError subclasses;
        # skip them instead of aborting the whole indexing run.
        print(f"Skipping {path}: {err}")
        continue
    features.append(feat)
    paths.append(path)

if not features:
    # np.vstack([]) would fail with a much less helpful message.
    raise ValueError(f"No readable images found in {image_dir!r}")

# Stack into one 2-D array (one row per image) in float32 for the index.
features = np.vstack(features).astype('float32')

# -----------------------------
# 4️⃣ Create FAISS index
# -----------------------------
# `features` is a 2-D array (one row per image), so its last axis is the
# dimensionality of a single feature vector.
d = features.shape[-1]

# Flat L2 index: stores the vectors as given and is searched by L2 distance.
index = faiss.IndexFlatL2(d)
index.add(features)

print(f"Indexed {len(paths)} images.")

# -----------------------------
# 5️⃣ Query search
# -----------------------------
query_img = "dataset/cat1.jpg"  # query image
# Shape the query as a single-row 2-D array and cast it to float32,
# the same dtype the indexed features were cast to.
query_feat = extract_feature(query_img).reshape(1, -1).astype('float32')

k = 3  # number of most-similar images to retrieve
distances, indices = index.search(query_feat, k)

print("\n🔍 Top similar images:")
for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]), start=1):
    if idx < 0:
        # FAISS pads the result with -1 labels when the index holds fewer
        # than k vectors; paths[-1] would otherwise silently print the last
        # image as a bogus match.
        continue
    print(f"{rank}. {paths[idx]} (distance={dist:.4f})")
