In [1]:
import torch, numpy as np, faiss
from datasets.unified_dataset import UnifiedDataset
from models.rag_text_encoder import RAGTextEncoder
from models.gf_mv_encoder import GFMVEncoder
from train_two_stage import build_faiss_with_tok

# 1. 先随便把 ds、txt_enc、vis_enc 都 load 了
ds = UnifiedDataset("datasets/unified_data.jsonl", num_views=12)
ckpt = torch.load("ckpts_0530/enc1.pth", map_location="cpu")
txt_enc = RAGTextEncoder("datasets/unified_data.jsonl", top_k=4).cpu().eval()
vis_enc = GFMVEncoder(num_views=12).cpu().eval()
txt_enc.load_state_dict(ckpt["txt"])
vis_enc.load_state_dict(ckpt["vis"])

# 2. Build / load faiss_index + all_tok
class Args: pass
args = Args()
args.bs    = 16
args.out   = "ckpts_0530"
args.views = 12
index, all_tok = build_faiss_with_tok(ds, txt_enc, vis_enc, args)

# 3. 随机抽 3 个 obj_id，比对一下：
for i in [0, 10, 123]:
    oid = ds.items[i]["obj_id"]
    # 先找它在 obj2idx：
    idx0 = ds.obj2idx[oid]
    # 再直接 load all_tok[idx0]，让它 encode 一遍 imgs 看看和 vis_enc( imgs ) 结果是不是一模一样
    # 这里只用最简单的 L2 差距来瞄一眼
    #   imgs0 = 原始图片 → vis_enc(imgs0) = vis_tok0
    print(f"ds.items[{i}].obj_id = {oid}, idx0 = {idx0}")
    # 顺便看一下 all_tok[idx0] 的 shape
    print(" all_tok[idx0].shape =", all_tok[idx0].shape)

Loaded cached memory_vec (datasets/temp_corpus.jsonl.pt)
✓ Load cached FAISS index & tokens from ckpts_0530
ds.items[0].obj_id = 00026607d2024fb8888e85791310ed52, idx0 = 11
 all_tok[idx0].shape = (14, 512)
ds.items[10].obj_id = 00026607d2024fb8888e85791310ed52, idx0 = 11
 all_tok[idx0].shape = (14, 512)
ds.items[123].obj_id = 002823d4334b42cc9a5bb9be0884c71a, idx0 = 131
 all_tok[idx0].shape = (14, 512)


In [2]:
import os, torch, faiss, numpy as np
from datasets.unified_dataset      import UnifiedDataset
from models.rag_text_encoder       import RAGTextEncoder
from models.gf_mv_encoder          import GFMVEncoder
from train_two_stage               import build_faiss_with_tok

# 1. 准备
ds = UnifiedDataset("datasets/unified_data.jsonl", num_views=12)
ckpt = torch.load("ckpts_0530/enc1.pth", map_location="cpu")
txt_enc = RAGTextEncoder("datasets/unified_data.jsonl", top_k=4).cpu().eval()
vis_enc = GFMVEncoder(num_views=12).cpu().eval()
txt_enc.load_state_dict(ckpt["txt"])
vis_enc.load_state_dict(ckpt["vis"])

args = type("A", (), {})()
args.bs = 8
args.out = "ckpts_0530"
args.views = 12

index, all_tok = build_faiss_with_tok(ds, txt_enc, vis_enc, args)

# 2. 随机拿一个 batch，手动看 FAISS 返回的 top‐1 跟 obj_id[b] 有没有对上
loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False)
for cap_batch, img_batch, oid_batch in loader:
    # 只做第一个 batch 检查就够了
    q_vec, _, txt_tok = txt_enc(list(cap_batch), list(oid_batch))
    _, vis_tok = vis_enc(img_batch, q_vec)
    sims, idx = index.search(q_vec.detach().cpu().numpy(), 50)
    for b, oid in enumerate(oid_batch):
        faiss_top1_idx = idx[b][0]
        faiss_top1_obj = ds.items[faiss_top1_idx]["obj_id"]
        print(f" b={b}: obj_id[b]={oid}, FAISS top1 obj_id={faiss_top1_obj}")
    break
print("Dataset unique obj_id count:", len(set(item["obj_id"] for item in ds.items)))
print("前10筆 obj_id：", [item["obj_id"] for item in ds.items[:10]])

Loaded cached memory_vec (datasets/temp_corpus.jsonl.pt)
✓ Load cached FAISS index & tokens from ckpts_0530
 b=0: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=1: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=2: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=3: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=4: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=5: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=6: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
 b=7: obj_id[b]=00026607d2024fb8888e85791310ed52, FAISS top1 obj_id=00026607d2024fb8888e85791310ed52
Dataset unique obj_id count: 7369
前10筆 obj_id： ['00026607d2024fb8888e85791310ed52', 

In [5]:
import os, torch, faiss, numpy as np
from datasets.unified_dataset      import UnifiedDataset
from models.rag_text_encoder       import RAGTextEncoder
from models.gf_mv_encoder          import GFMVEncoder
from models.cross_modal_reranker   import CrossModalReranker
from train_two_stage               import build_faiss_with_tok

ds = UnifiedDataset("datasets/unified_data.jsonl", num_views=12)
ckpt = torch.load("ckpts_0530/enc1.pth", map_location="cpu")
txt_enc = RAGTextEncoder("datasets/unified_data.jsonl", top_k=4).cpu().eval()
vis_enc = GFMVEncoder(num_views=12).cpu().eval()
txt_enc.load_state_dict(ckpt["txt"])
vis_enc.load_state_dict(ckpt["vis"])

args = type("A", (), {})()
args.bs = 4
args.out = "ckpts_0530"
args.views = 12
args.L = 50

index, all_tok = build_faiss_with_tok(ds, txt_enc, vis_enc, args)

# 只看第一批
loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=False)
for cap_batch, img_batch, oid_batch in loader:
    q_vec, _, txt_tok = txt_enc(list(cap_batch), list(oid_batch))
    _, vis_tok     = vis_enc(img_batch, q_vec)
    vis_tok = torch.nn.functional.normalize(vis_tok, 2, -1)

    sims, idx = index.search(q_vec.detach().cpu().numpy(), args.L)

    # 下面开始做 s_pos / s_neg
    reranker = CrossModalReranker().cpu().eval()
    # 如果你已经有微调好 weights，就在这里 load
    # reranker.load_state_dict(torch.load("ckpts_0530/rerank.pth", map_location="cpu"))

    for b, oid in enumerate(oid_batch):
        faiss_top1_idx = idx[b][0]
        faiss_top1_obj = ds.items[faiss_top1_idx]["obj_id"]
        # 1) 先印出应该是正例的那一行 VAIS idx
        print(f"\n>>> b={b}:  batch 的 obj_id = {oid} ; FAISS top1 obj_id = {faiss_top1_obj}")

        # 2) 现在我们拿 out-of-sample 正例 token、和 N 个负例 token
        #    `vis_tok[b]` 应该是 “batch 中第 b 笔 query” 的视觉 token
        tok_pos = vis_tok[b : b+1]        # shape = (1, T_vis, 512)
        t_pos   = txt_tok[b : b+1]        # shape = (1, 1+top_k, 768)

        # 将 pos token 扩展成跟负例数一样多的行
        # 建一个 (1 + (L-1)) 的组合：第一行是正例，后面都是负例
        # 负例来自 all_tok[idx[b][1:]]  <-- idx[b][0] 是正例
        neg_idx_list = idx[b][1 : 1+3]  # 先拿 3 个负例来 demo
        tok_neg = torch.from_numpy(all_tok[neg_idx_list]).detach().cpu().float()  # shape = (3, T_vis, 512)
        # t_pos_expand: (3, 1+top_k, 768)
        t_pos_rep = t_pos.expand(len(neg_idx_list), -1, -1)

        # 3) 把这 1:3 的 pos vs neg 一起丢给 reranker 看分数
        x_pos = reranker(t_pos_rep, tok_pos.expand(len(neg_idx_list), -1, -1))  # shape = (3,)
        x_neg = reranker(t_pos_rep, tok_neg)  # shape = (3,)

        print("   reranker(pos scores) =", x_pos.detach().numpy())
        print("   reranker(neg scores) =", x_neg.detach().numpy())

        # 如果正例分数没比 3 个负例都高，就说明「正负对齐」或「loss 计算」有问题
        break
    break


Loaded cached memory_vec (datasets/temp_corpus.jsonl.pt)
✓ Load cached FAISS index & tokens from ckpts_0530

>>> b=0:  batch 的 obj_id = 00026607d2024fb8888e85791310ed52 ; FAISS top1 obj_id = 00026607d2024fb8888e85791310ed52
   reranker(pos scores) = [0.1545835 0.1545835 0.1545835]
   reranker(neg scores) = [0.15543418 0.15626474 0.15581958]
