In [1]:
from modelnet40_datasets import *
from MHSAN import *
import torch
from tqdm import tqdm
import wandb
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
from sklearn.preprocessing import StandardScaler, normalize


def test(model, test_loader, device="cuda", num_classes=40):
    model.to(device)
    model.eval()
    correct, total = 0, 0

    # 建立空的 list 來收集所有特徵 & 標籤
    all_features = []
    all_labels = []

    # 記錄每個類別的計數
    class_correct = np.zeros(num_classes)
    class_total = np.zeros(num_classes)

    loop = tqdm(test_loader, desc="Testing", leave=True)
    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            outputs, final_feature = model(images)
            

            predictions = outputs.argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            loop.set_postfix(acc=correct / total)

            # 計算每個類別的準確度
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_total[label] += 1
                if predictions[i] == label:
                    class_correct[label] += 1

            # 收集特徵和標籤
            all_features.append(final_feature.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    test_acc = correct / total

    # 計算每個類別的準確率
    class_accuracy = class_correct / (class_total + 1e-6)  # 避免除以 0
    class_accuracy_dict = {f"Class_{i}": class_accuracy[i] for i in range(num_classes)}

    # **記錄到 WandB**
    wandb.log({"Test Accuracy": test_acc})
    wandb.log({"Class Accuracy": class_accuracy_dict})

    print(f"Test Accuracy: {test_acc:.4f}")

    # **轉換特徵與標籤**
    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    return test_acc, all_features, all_labels, class_accuracy


def load_model(model, checkpoint_path="MHSAN_best.pth", device="cuda"):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    print(f" 成功載入模型權重: {checkpoint_path}")



In [2]:
# -----------------------------------------
# 以下為「檢索測量」所需函式
# -----------------------------------------
import faiss
import time
def l2_normalize(vecs):
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9
    return vecs / norms

def build_faiss_index(features, metric="euclidean"):
    """
    建立 Faiss index 方便做檢索。
    如果是 euclidean distance,就用 IndexFlatL2。
    如果是 cosine,就做 L2 normalize 後再用 IndexFlatIP。
    回傳:
        index: 建好的 Faiss 索引
        processed_features: 如果是 cosine,回傳已經 normalize 的特徵
    """
    if features.dtype != np.float32:
        features = features.astype(np.float32)

    N, d = features.shape

    if metric == "euclidean":
        index = faiss.IndexFlatL2(d)
        index.add(features)
        return index, features

    elif metric == "cosine":
        feats_norm = l2_normalize(features)
        index = faiss.IndexFlatIP(d)
        index.add(feats_norm.astype(np.float32))
        return index, feats_norm

    else:
        raise ValueError("metric 必須是 'euclidean' 或 'cosine'")

# -- mAP 相關 --
def compute_dcg(relevances):
    dcg = 0.0
    for i, rel in enumerate(relevances, start=1):
        dcg += rel / np.log2(i + 1)
    return dcg

def compute_ndcg(retrieved_labels, query_label, k=None):
    if k is not None:
        retrieved_labels = retrieved_labels[:k]

    relevances = [1 if lbl == query_label else 0 for lbl in retrieved_labels]
    dcg = compute_dcg(relevances)

    total_ones = sum(relevances)
    ideal_relevances = [1]*total_ones + [0]*(len(relevances) - total_ones)
    idcg = compute_dcg(ideal_relevances)

    if idcg == 0:
        return 0.0
    else:
        return dcg / idcg

def compute_bruteforce_retrieval_metrics(features, labels, metric="euclidean", top_k=5):
    """
    使用暴力檢索法 (Brute-Force) 計算:
    - Precision@K
    - Recall@K
    - mAP
    - NDCG@K (這裡示範用二元相關度)
    """
    print("\n[Brute-Force 檢索]")

    start_time = time.time()

    N = features.shape[0]
    if metric == "euclidean":
        dist_matrix = euclidean_distances(features, features)
        indices = np.argsort(dist_matrix, axis=1)  # 由小排到大
    elif metric == "cosine":
        sim_matrix = cosine_similarity(features, features)
        indices = np.argsort(-sim_matrix, axis=1) # 由大排到小
    else:
        raise ValueError("metric 必須是 'euclidean' 或 'cosine'")

    precision_list = []
    recall_list = []
    ndcg_list = []
    average_precisions = []

    for i in range(N):
        query_label = labels[i]
        retrieved_indices = indices[i]

        # 通常第0個是自己,先排除
        if retrieved_indices[0] == i:
            retrieved_indices = retrieved_indices[1:]
        else:
            retrieved_indices = retrieved_indices[:]

        # 只取前 K
        retrieved_indices_topk = retrieved_indices[:top_k]
        retrieved_labels_topk = labels[retrieved_indices_topk]

        # -- Precision@K --
        correct_count = np.sum(retrieved_labels_topk == query_label)
        precision_k = correct_count / top_k
        precision_list.append(precision_k)

        # -- Recall@K --
        total_relevant = np.sum(labels == query_label)
        recall_k = correct_count / total_relevant if total_relevant > 0 else 0.0
        recall_list.append(recall_k)

        # -- NDCG@K --
        ndcg_k = compute_ndcg(labels[retrieved_indices], query_label, k=top_k)
        ndcg_list.append(ndcg_k)

        # -- mAP --
        hit = 0
        p_list = []
        for rank, idx_ in enumerate(retrieved_indices, start=1):
            if labels[idx_] == query_label:
                hit += 1
                p_list.append(hit / rank)
        AP = np.mean(p_list) if len(p_list) > 0 else 0.0
        average_precisions.append(AP)

    mean_precision = np.mean(precision_list)
    mean_recall = np.mean(recall_list)
    mean_ndcg = np.mean(ndcg_list)
    mean_AP = np.mean(average_precisions)

    end_time = time.time()
    elapsed = end_time - start_time

    print(f"暴力檢索花費時間: {elapsed:.4f} 秒")
    print(f"Precision@{top_k}: {mean_precision:.4f}")
    print(f"Recall@{top_k}: {mean_recall:.4f}")
    print(f"mAP: {mean_AP:.4f}")
    print(f"NDCG@{top_k}: {mean_ndcg:.4f}")

    return {
        "Precision@K": mean_precision,
        "Recall@K": mean_recall,
        "mAP": mean_AP,
        "NDCG@K": mean_ndcg,
        "time": elapsed
    }

def compute_faiss_retrieval_metrics(features, labels, metric="euclidean", top_k=5):
    """
    使用 Faiss 索引計算檢索性能:
    - Precision@K
    - Recall@K
    - mAP
    - NDCG@K
    """
    print(f"\n[Faiss 檢索 - {metric}]")

    # 建立索引
    index, processed_features = build_faiss_index(features, metric=metric)

    start_time = time.time()
    N = processed_features.shape[0]

    # Faiss 的 search
    # k+1: 因為第 0 筆通常會檢索到自己
    D, I = index.search(processed_features, top_k+1)  # I.shape = (N, top_k+1)

    precision_list = []
    recall_list = []
    ndcg_list = []
    average_precisions = []

    for i in range(N):
        query_label = labels[i]
        retrieved_indices = I[i]

        # 排除自己 (如果檢索回來的第一筆就是自己)
        if retrieved_indices[0] == i:
            retrieved_indices = retrieved_indices[1:]
        else:
            retrieved_indices = retrieved_indices[:top_k]

        retrieved_labels_topk = labels[retrieved_indices[:top_k]]

        # -- Precision@K --
        correct_count = np.sum(retrieved_labels_topk == query_label)
        precision_k = correct_count / top_k
        precision_list.append(precision_k)

        # -- Recall@K --
        total_relevant = np.sum(labels == query_label)
        recall_k = correct_count / total_relevant if total_relevant > 0 else 0.0
        recall_list.append(recall_k)

        # -- NDCG@K --
        # 這裡示範用檢索到的前 top_k+1 (包含自己),但記得排除自己
        # 所以可以改成對 I[i, 1:] 做 NDCG
        # 為了簡單,這裡直接用 retrieved_indices[:top_k]
        ndcg_k = compute_ndcg(labels[retrieved_indices], query_label, k=top_k)
        ndcg_list.append(ndcg_k)

        # -- mAP --
        # 這裡也可以只算 top_k
        # 如果想跟暴力法比對,建議檢索列表整串都算 (但 Faiss 只給 k+1)
        # 為保持一致,這裡就只算前 (k+1) (排除自己後) 當 ranking
        hit = 0
        p_list = []
        # 先把檢索到的結果當 ranking
        full_rank = I[i]  # shape = (k+1,)
        if full_rank[0] == i:
            full_rank = full_rank[1:]
        for rank, idx_ in enumerate(full_rank, start=1):
            if labels[idx_] == query_label:
                hit += 1
                p_list.append(hit / rank)
        AP = np.mean(p_list) if len(p_list) > 0 else 0.0
        average_precisions.append(AP)

    mean_precision = np.mean(precision_list)
    mean_recall = np.mean(recall_list)
    mean_ndcg = np.mean(ndcg_list)
    mean_AP = np.mean(average_precisions)

    end_time = time.time()
    elapsed = end_time - start_time

    print(f"Faiss 檢索花費時間: {elapsed:.4f} 秒")
    print(f"Precision@{top_k}: {mean_precision:.4f}")
    print(f"Recall@{top_k}: {mean_recall:.4f}")
    print(f"mAP: {mean_AP:.4f}")
    print(f"NDCG@{top_k}: {mean_ndcg:.4f}")

    return {
        "Precision@K": mean_precision,
        "Recall@K": mean_recall,
        "mAP": mean_AP,
        "NDCG@K": mean_ndcg,
        "time": elapsed
    }

In [None]:

# 參數
batch_size = 8
num_views = 12
num_layers = 3
checkpoint_path = "MHSAN12_best.pth"

# 依照原本規劃: 12 views => top_k=6、20 views => top_k=10
top_k = 6 if num_views == 12 else 10 if num_views == 20 else 5

wandb.init(project="MHSAN-ModelNet40", name=f"MHSAN_{num_views}views_test")

# 載入 Test 資料
test_loader = get12_views_dataloader(split="test", batch_size=batch_size)
dataset = test_loader.dataset
class_names = dataset.categories
print("測試資料類別名稱:", class_names)
print("測試集樣本數:", len(test_loader.dataset))

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MHSAN(num_views=num_views, embed_dim=512, num_heads=8, num_layers=num_layers, top_k=top_k).to(device)

# 載入模型權重
load_model(model, checkpoint_path=checkpoint_path, device=device)

# 測試模型 + 收集特徵
test_acc, features, labels, class_accuracy = test(model, test_loader, device=device)


# ----------------------------------
# 加入檢索測量 (Brute-Force)
# ----------------------------------
brute_force_result = compute_bruteforce_retrieval_metrics(
    features, labels,
    metric="cosine",  # 或 "cosine"
    top_k=5
)
wandb.log({
    "BruteForce_Precision@K": brute_force_result["Precision@K"],
    "BruteForce_Recall@K": brute_force_result["Recall@K"],
    "BruteForce_mAP": brute_force_result["mAP"],
    "BruteForce_NDCG@K": brute_force_result["NDCG@K"],
    "BruteForce_time": brute_force_result["time"]
})

# ----------------------------------
# 加入 Faiss 檢索測量
# ----------------------------------
faiss_result = compute_faiss_retrieval_metrics(
    features, labels,
    metric="cosine",  # 或 "cosine"
    top_k=5
)
wandb.log({
    "Faiss_Precision@K": faiss_result["Precision@K"],
    "Faiss_Recall@K": faiss_result["Recall@K"],
    "Faiss_mAP": faiss_result["mAP"],
    "Faiss_NDCG@K": faiss_result["NDCG@K"],
    "Faiss_time": faiss_result["time"]
})

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhj6hki123[0m. Use [1m`wandb login --relogin`[0m to force relogin


載入 test 數據集: 2468 筆資料,類別數: 40
測試資料類別名稱: ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
測試集樣本數: 2468
 成功載入模型權重: MHSAN12_best.pth


Testing: 100%|██████████| 309/309 [01:08<00:00,  4.53it/s, acc=0.846]


Test Accuracy: 0.8460

[Brute-Force 檢索]
暴力檢索花費時間: 0.9766 秒
Precision@5: 0.8337
Recall@5: 0.0620
mAP: 0.6718
NDCG@5: 0.9069

[Faiss 檢索 - euclidean]
Faiss 檢索花費時間: 0.0628 秒
Precision@5: 0.8336
Recall@5: 0.0620
mAP: 0.8853
NDCG@5: 0.9067


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
BruteForce_NDCG@K,▁
BruteForce_Precision@K,▁
BruteForce_Recall@K,▁
BruteForce_mAP,▁
BruteForce_time,▁
Faiss_NDCG@K,▁
Faiss_Precision@K,▁
Faiss_Recall@K,▁
Faiss_mAP,▁
Faiss_time,▁

0,1
BruteForce_NDCG@K,0.90691
BruteForce_Precision@K,0.83371
BruteForce_Recall@K,0.06199
BruteForce_mAP,0.6718
BruteForce_time,0.97658
Faiss_NDCG@K,0.90666
Faiss_Precision@K,0.83363
Faiss_Recall@K,0.06196
Faiss_mAP,0.88533
Faiss_time,0.06283


In [14]:
brute_force_result = compute_bruteforce_retrieval_metrics(
    features, labels,
    metric="euclidean",  # 或 "cosine"
    top_k=20
)

faiss_result = compute_faiss_retrieval_metrics(
    features, labels,
    metric="euclidean",  # 或 "cosine"
    top_k=20
)


[Brute-Force 檢索]
暴力檢索花費時間: 0.9948 秒
Precision@20: 0.7708
Recall@20: 0.2141
mAP: 0.6718
NDCG@20: 0.9092

[Faiss 檢索 - euclidean]
Faiss 檢索花費時間: 0.2161 秒
Precision@20: 0.7708
Recall@20: 0.2141
mAP: 0.8468
NDCG@20: 0.9089
