In [1]:
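# Optional one-time setup: install exactly ONE FAISS build, then restart the kernel.
# Uncomment the line that matches your machine (%pip targets the kernel's environment).
# %pip install faiss-gpu
# %pip install faiss-cpu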

In [2]:
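# Optional one-time setup: pin NumPy to 1.26.4 (presumably for compatibility with the
# installed FAISS / PyTorch builds; confirm before relying on it), then restart the kernel.
# %pip install numpy==1.26.4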

In [3]:
import os
import glob
import numpy as np 
import faiss 
import torch
import torch.nn as nn
import torch.nn.functional as F 
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset 

In [4]:
# 1. Custom model import and preprocessing transforms (kept in sync with dataset.py)
# ==============================================================================
# The preprocessing pipeline does not depend on model.py, so it is defined
# unconditionally. Previously a failed import replaced it with a bare ToTensor(),
# which leaves images at their native sizes and therefore cannot be batched.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
IMAGE_SIZE = 224

INFERENCE_TRANSFORMS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

try:
    # Assumes model.py (next to this notebook) defines EmbeddingNet.
    from model import EmbeddingNet
except ImportError:
    print("❌ 錯誤: 無法導入 'model' 模組。請確保 'model.py' 在當前目錄中。")

    # Placeholder so the rest of the notebook can still run; it emits RANDOM
    # features, so any retrieval result obtained with it is meaningless.
    class EmbeddingNet(nn.Module):
        """Stand-in for model.EmbeddingNet that returns random embeddings."""

        def __init__(self, embedding_size=512):
            # The keyword is named `embedding_size` because that is how
            # load_feature_extractor constructs the network.
            super().__init__()
            self.embedding_size = embedding_size
            self.emb_size = embedding_size  # original attribute name, kept for compatibility

        def forward(self, x):
            # One random vector per input sample, created on the input's device.
            return torch.randn(x.size(0), self.embedding_size, device=x.device)

In [5]:
# 2. Environment settings and paths
# ==============================================================================
# These defaults assume the model and gallery images live at the locations below;
# adjust them to match your own layout.

# --- Model ---
# Best checkpoint produced by train_metric_advanced.py.
MODEL_PATH = 'advanced_model.pth'
# Must agree with the embedding size EmbeddingNet (model.py) was trained with.
EMBEDDING_SIZE = 512

# --- Gallery images ---
# Directory of images that get indexed.
DATABASE_IMAGES_DIR = 'deepfashion/images'
# Example query image, taken from the gallery itself.
QUERY_IMAGE_EXAMPLE = os.path.join(DATABASE_IMAGES_DIR, "1000_031.jpg")

# --- Cached artifacts ---
# Where the FAISS index is stored.
FAISS_INDEX_PATH = 'faiss_index_gallery.bin'
# Where the extracted embeddings (and their file names) are stored.
EMBEDDINGS_NPY_PATH = 'indexed_advanced.npz'


In [6]:
# 3. 核心函數
# ==============================================================================

def load_feature_extractor(model_path, embedding_size=512, device='cuda'):
    """Build an ``EmbeddingNet`` and load trained weights into it.

    Args:
        model_path: File readable by ``torch.load``. Two layouts are handled:
            a checkpoint dict holding the weights under ``'model_state'``, or a
            bare state_dict. In both layouts a uniform ``'module.'`` key prefix
            is removed before loading.
        embedding_size: Forwarded as ``EmbeddingNet(embedding_size=...)``.
        device: Requested device; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        The model in eval mode, or ``None`` if loading failed (the error is
        printed rather than raised).
    """
    device = device if torch.cuda.is_available() else 'cpu'
    print(f"Loading model to device: {device}")

    model = EmbeddingNet(embedding_size=embedding_size).to(device)

    def _strip_module_prefix(weights):
        # Weights saved from a model wrapped in nn.DataParallel carry a
        # 'module.' prefix on every key; drop it so the keys match the bare model.
        prefix = 'module.'
        if (isinstance(weights, dict) and weights
                and all(isinstance(k, str) and k.startswith(prefix) for k in weights)):
            return {k[len(prefix):]: v for k, v in weights.items()}
        return weights

    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state = torch.load(model_path, map_location=device)

        if isinstance(state, dict) and 'model_state' in state:
            # Previously the prefix was only stripped for bare state_dicts, so
            # DataParallel checkpoints in this layout failed to load.
            model.load_state_dict(_strip_module_prefix(state['model_state']))
            print(f"Model loaded from checkpoint. Epoch: {state.get('epoch', 'N/A')}")
        elif isinstance(state, dict):
            model.load_state_dict(_strip_module_prefix(state))
            print("Model loaded successfully from pure state_dict.")
        else:
            model.load_state_dict(state)
            print("Model loaded successfully from simple state.")

        model.eval()
        return model

    except Exception as e:
        # Deliberate best-effort: report the problem and let the caller decide.
        print(f"❌ 錯誤: 無法載入模型 {model_path}。原因: {e}")
        return None


def extract_all_embeddings(model, image_root, out_path, batch_size=64, num_workers=0, device='cuda'):
    """Embed every image in a folder and save the result as a compressed ``.npz``.

    Args:
        model: Network called as ``model(batch)``; its output is L2-normalised
            along dim 1 before being stored.
        image_root: Directory scanned (non-recursively) for .jpg/.jpeg/.png files.
        out_path: Destination passed to ``np.savez_compressed``; the archive holds
            ``embeddings`` and ``filenames``. Its parent directory is created.
        batch_size: DataLoader batch size.
        num_workers: Number of DataLoader worker processes.
        device: Requested device; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        ``(embeddings, filenames)`` where ``embeddings`` is the stacked array and
        ``filenames`` is a list in the same (sorted) order, or ``(None, None)``
        when ``image_root`` does not exist or contains no matching image.
    """
    # Dataset intended to be compatible with SimpleFolderDataset in extractor.py.
    class SimpleFolderDataset(Dataset):
        """Yields ``(transformed_image, file_name)`` for each image in ``root``."""

        def __init__(self, root, transform):
            self.root = root
            self.files = [f for f in sorted(os.listdir(root)) if f.lower().endswith(('.jpg','.jpeg','.png'))]
            self.transform = transform

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            fn = self.files[idx]
            path = os.path.join(self.root, fn)
            # convert() returns a fully loaded copy, so the file handle can be
            # closed right away instead of lingering until garbage collection.
            with Image.open(path) as img:
                rgb = img.convert('RGB')
            return self.transform(rgb), fn

    device = device if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    if not os.path.exists(image_root):
        print(f"❌ 錯誤: 圖片資料庫路徑不存在: {image_root}")
        return None, None

    ds = SimpleFolderDataset(image_root, transform=INFERENCE_TRANSFORMS)
    if len(ds) == 0:
        # Without this guard np.vstack([]) below raises ValueError.
        print(f"⚠️ 警告: 在 {image_root} 中找不到任何圖片 (.jpg/.jpeg/.png)。跳過特徵提取。")
        return None, None

    # shuffle=False keeps embeddings aligned with the sorted file list.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=(device!='cpu'))

    all_emb = []
    all_files = []

    print(f"開始提取 {len(ds)} 張圖片的特徵 (num_workers={num_workers})...")
    with torch.no_grad():
        for imgs, fns in tqdm(loader, desc="提取特徵"):
            imgs = imgs.to(device)
            emb = model(imgs)
            emb = F.normalize(emb, p=2, dim=1)
            all_emb.append(emb.cpu().numpy())
            all_files.extend(fns)

    all_emb = np.vstack(all_emb)

    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    np.savez_compressed(out_path, embeddings=all_emb, filenames=np.array(all_files))
    print(f"\n✅ 特徵提取完成。Saved embeddings to {out_path} with shape {all_emb.shape}")

    return all_emb, all_files


def build_faiss_index(embeddings, index_path):
    """Build a flat L2 FAISS index from the embeddings and write it to disk.

    Args:
        embeddings: 2-D array-like of shape ``(n, d)``; converted to a
            C-contiguous float32 array before being added.
        index_path: Destination file for ``faiss.write_index``; its parent
            directory is created if needed.

    Returns:
        The populated ``faiss.IndexFlatL2``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D.
    """
    # FAISS expects float32 rows laid out contiguously; astype() alone keeps a
    # transposed / strided layout, so normalise dtype and layout in one step.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"embeddings must be a 2-D array of shape (n, d), got shape {vectors.shape}")

    D = vectors.shape[1]
    # Exact (brute-force) index using L2 distance.
    index = faiss.IndexFlatL2(D)
    index.add(vectors)

    # Mirror extract_all_embeddings: make sure the target directory exists.
    os.makedirs(os.path.dirname(index_path) or '.', exist_ok=True)
    faiss.write_index(index, index_path)
    print(f"✅ FAISS 索引建立完成。Index type: IndexFlatL2, Size: {index.ntotal}, Dim: {index.d}")
    print(f"Index saved to {index_path}")
    return index

def get_query_embedding(query_image_path, model):
    """Embed a single query image.

    Args:
        query_image_path: Path to the image file.
        model: Network called as ``model(batch)`` on a batch of one image that
            has been preprocessed with ``INFERENCE_TRANSFORMS``.

    Returns:
        A float32 NumPy array holding the L2-normalised embedding, with a leading
        batch dimension of 1.

    Raises:
        FileNotFoundError: If ``query_image_path`` does not exist.
    """
    if not os.path.exists(query_image_path):
        raise FileNotFoundError(f"查詢圖片不存在: {query_image_path}")

    # convert() returns a loaded copy, so the file can be closed immediately.
    with Image.open(query_image_path) as img:
        rgb = img.convert('RGB')
    tensor = INFERENCE_TRANSFORMS(rgb)
    tensor = tensor.unsqueeze(0)  # add the batch dimension

    # Run on whatever device the model lives on. A model without parameters
    # (e.g. the placeholder EmbeddingNet) would make a bare next() raise
    # StopIteration, so fall back to the CPU in that case.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    tensor = tensor.to(device)

    with torch.no_grad():
        embedding = model(tensor)
        # L2-normalise, matching how the gallery embeddings are produced.
        embedding = F.normalize(embedding, p=2, dim=1)

    # .float() is a no-op for float32 output and keeps the result FAISS-compatible
    # if the model ever runs in reduced precision.
    return embedding.float().cpu().numpy()


def query_faiss_index(query_image_path, model, faiss_index, indexed_names, k=5):
    """Embed one query image and look up its nearest gallery entries.

    Args:
        query_image_path: Image to search with.
        model: Embedding network forwarded to ``get_query_embedding``.
        faiss_index: Object exposing ``search(vectors, k)``.
        indexed_names: Sequence mapping an index row to its file name.
        k: Number of neighbours requested from the index.

    Returns:
        A list of dicts (``rank``, ``file_name``, ``distance_L2``, ``db_index``)
        in the order returned by the index. ``rank`` starts at 1 and
        ``distance_L2`` is the value reported by ``faiss_index.search``.
    """
    query_vec = get_query_embedding(query_image_path, model)
    # Similarity search: one row of distances / ids per query vector.
    dists, ids = faiss_index.search(query_vec, k)

    hits = []
    for position, (db_idx, dist) in enumerate(zip(ids[0], dists[0]), start=1):
        # A -1 id is treated as the end of the valid results (FAISS pads the
        # row with -1 when fewer than k neighbours are available).
        if db_idx == -1:
            break

        hits.append(dict(
            rank=position,
            file_name=indexed_names[db_idx],  # resolve the row back to its file
            distance_L2=dist,
            db_index=db_idx,
        ))

    return hits


In [8]:
# 4. Run the pipeline
# ==============================================================================

# Step 1: load the model.
feature_extractor_model = load_feature_extractor(MODEL_PATH, EMBEDDING_SIZE)
if feature_extractor_model is None:
    print("無法繼續執行，模型載入失敗。")
    faiss_index = None
    indexed_file_names = None
else:
    # Step 2: reuse cached embeddings when present, otherwise extract them.
    if os.path.exists(EMBEDDINGS_NPY_PATH):
        print(f"載入已存特徵向量: {EMBEDDINGS_NPY_PATH}")
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
        # only point this at a cache file you created yourself.
        # The context manager closes the underlying .npz archive once both
        # arrays have been read into memory.
        with np.load(EMBEDDINGS_NPY_PATH, allow_pickle=True) as data:
            all_embeddings = data['embeddings']
            indexed_file_names = data['filenames']
        print(f"載入完成，Shape: {all_embeddings.shape}")

    else:
        if not os.path.exists(DATABASE_IMAGES_DIR):
            print(f"\n⚠️ 警告: 圖片資料庫目錄不存在: {DATABASE_IMAGES_DIR}。跳過特徵提取。")
            all_embeddings = None
            indexed_file_names = None
        else:
            # num_workers=0 keeps image loading in the main process; this was
            # chosen to avoid a NumPy-related RuntimeError seen with worker processes.
            print("\n正在將 DataLoader 的 num_workers 設置為 0，以避免多進程環境下的 Numpy 錯誤 (RuntimeError)。")
            all_embeddings, indexed_file_names = extract_all_embeddings(
                feature_extractor_model,
                DATABASE_IMAGES_DIR,
                EMBEDDINGS_NPY_PATH,
                batch_size=64,
                num_workers=0
            )

    # Step 3: load the FAISS index from disk, or build it from the embeddings.
    faiss_index = None
    if all_embeddings is not None and len(all_embeddings) > 0:
        if os.path.exists(FAISS_INDEX_PATH):
            print(f"\n載入已存 FAISS 索引: {FAISS_INDEX_PATH}")
            faiss_index = faiss.read_index(FAISS_INDEX_PATH)
            # The attribute is `ntotal`; the former `.nt` raised AttributeError
            # whenever an index file already existed.
            print(f"FAISS Index Loaded. Size: {faiss_index.ntotal}")
            # An index left over from a different embedding set would map hits
            # to the wrong file names, so rebuild it when size or dim disagree.
            if faiss_index.ntotal != len(all_embeddings) or faiss_index.d != all_embeddings.shape[1]:
                print("\n⚠️ 警告: 已存 FAISS 索引與特徵向量不一致，重新建立索引。")
                faiss_index = build_faiss_index(all_embeddings, FAISS_INDEX_PATH)
        else:
            faiss_index = build_faiss_index(all_embeddings, FAISS_INDEX_PATH)
    else:
        # The model is known to be loaded in this branch, so no extra check is needed.
        print("\n⚠️ 警告: 沒有特徵向量可供建立 FAISS 索引。")


Loading model to device: cpu
Model loaded successfully from pure state_dict.
載入已存特徵向量: indexed_advanced.npz
載入完成，Shape: (13752, 512)
✅ FAISS 索引建立完成。Index type: IndexFlatL2, Size: 13752, Dim: 512
Index saved to faiss_index_gallery.bin


In [9]:
# 5. Run the example query
print("\n" + "="*60)
print("開始執行範例查詢")
print("="*60)

# None is the "not available" sentinel for both objects, so test for it
# explicitly instead of relying on object truthiness (a container-like
# nn.Module with zero length would evaluate as False).
if faiss_index is not None and os.path.exists(QUERY_IMAGE_EXAMPLE):
    K_COUNT = 5

    results = query_faiss_index(
        query_image_path=QUERY_IMAGE_EXAMPLE,
        model=feature_extractor_model,
        faiss_index=faiss_index,
        indexed_names=indexed_file_names,
        k=K_COUNT
    )

    print(f"\n========== 查詢結果 (Top {K_COUNT}) ==========")
    print(f"查詢圖片: {os.path.basename(QUERY_IMAGE_EXAMPLE)}")
    # NOTE(review): FAISS documents IndexFlatL2 as returning squared L2
    # distances, so the printed value is presumably squared; confirm before
    # comparing it against a true-distance threshold.
    for r in results:
        print(f"Rank {r['rank']}: {r['file_name']} (L2 距離: {r['distance_L2']:.4f})")
    print("==========================================")
elif faiss_index is not None:
    print(f"\n❌ 錯誤: 範例查詢圖片路徑不存在: {QUERY_IMAGE_EXAMPLE}")
elif feature_extractor_model is not None:
    print("\n❌ 錯誤: FAISS 索引未成功載入或建立。")


開始執行範例查詢

查詢圖片: 1000_031.jpg
Rank 1: 1000_031.jpg (L2 距離: 0.0000)
Rank 2: 3989_022.jpg (L2 距離: 0.0229)
Rank 3: 2914_032.jpg (L2 距離: 0.0242)
Rank 4: 4229_011.jpg (L2 距離: 0.0257)
Rank 5: 1055_031.jpg (L2 距離: 0.0263)
