In [6]:
# In[1]: 環境設定與套件載入
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # 關閉 Huggingface parallel 警告

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import ndcg_score, average_precision_score

sns.set(style="whitegrid", font="Arial")

# 根據專案結構調整 import
from datasets.unified_dataset import UnifiedDataset
from models.rag_text_encoder import RAGTextEncoder




In [7]:
# In[2]: 參數設定（可依需要調整）
data_jsonl = "datasets/selected_data.jsonl"
cache_dir  = "ckpts_110"         # Stage1 模型與快取目錄
val_size   = 1000            # 驗證集大小
batch_size = 8
topk       = 5
device     = "cuda"          # 或 "cpu"

# 建立目錄
os.makedirs(cache_dir, exist_ok=True)


In [8]:
# In[3]: 載入文字編碼器模型
dev = torch.device(device)
ckpt = torch.load(f"{cache_dir}/enc1.pth", map_location=dev)

txt_enc = RAGTextEncoder(
    unified_jsonl=data_jsonl,
    top_k=topk,
    device=device,
    cache_dir=cache_dir
).to(dev)
txt_enc.load_state_dict(ckpt["txt"])
txt_enc.eval()

retriever = txt_enc.retriever
fusion    = txt_enc.fusion

print("模型與檢索器載入完成。")


Loaded cached memory_vec (ckpts_110/temp_corpus.jsonl.pt)
模型與檢索器載入完成。


In [9]:
# In[4]: 準備驗證集 DataLoader
full_ds = UnifiedDataset(data_jsonl, num_views=12)
print(len(full_ds))
val_ds  = Subset(full_ds, list(range(val_size)))
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4
)
print(f"驗證集樣本數：{len(val_ds)}")


7370
驗證集樣本數：1000


In [None]:
# # In[5]: 批次化基線檢索指標計算
# @torch.no_grad()
# def rebuild_memory(retriever, batch_size=64, device="cuda"):
#     """
#     以當前 retriever.kenc 權重重新編碼 corpus caption，
#     並覆寫 retriever.memory_vec in-place。
#     """
#     retriever.kenc.eval().to(device)

#     new_vecs = []
#     for i in range(0, len(retriever.texts), batch_size):
#         chunk = retriever.texts[i:i+batch_size]
#         enc   = retriever.tok(chunk,
#                               return_tensors="pt",
#                               padding=True,
#                               truncation=True).to(device)
#         out   = retriever.kenc(**enc).last_hidden_state[:, 0]  # CLS
#         new_vecs.append(F.normalize(out, 2, -1).cpu())

#     new_memory = torch.cat(new_vecs, 0)             # (N,768)
#     retriever.memory_vec.data.copy_(new_memory)     # 就地覆寫
# rebuild_memory(txt_enc.retriever, batch_size=128, device=device)



@torch.no_grad()
def compute_baseline_batch(retriever, dataset, topk, batch_size, device):
    retriever.qenc.eval()
    retriever.kenc.eval()
    retriever.eval()

    mem_vec = F.normalize(retriever.memory_vec, 2, -1).to(device)  # (N,768)
    print(">> Corpus size:", len(retriever.obj_ids))

    recalls, ndcgs, aps, rrs = [], [], [], []
    loader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=False, num_workers=4)

    for caps, _, obj_ids, _ in tqdm(loader, desc="Baseline 指標"):
        # 一併正規化 query
        q_vec = F.normalize(retriever.query_encode(list(caps)), 2, -1).to(device)  # (b,768)
        sims  = q_vec @ mem_vec.T                                                   # (b,N)
        sims_np = sims.cpu().numpy()
        idx_topk = sims.topk(topk, dim=-1).indices.cpu().numpy()                   # (b,topk)

        all_ids = np.array(retriever.obj_ids)
        for i, oid in enumerate(obj_ids):
            rel = (all_ids == oid).astype(int)
            rank_idx = idx_topk[i]

            # 計算指標...
            recalls.append(int(rel[rank_idx].sum()>0))
            ndcgs.append(ndcg_score([rel],[sims_np[i]], k=topk))
            aps.append(average_precision_score(rel, sims_np[i]))
            # MRR
            rr=0.
            for pos, gid in enumerate(rank_idx, start=1):
                if rel[gid]:
                    rr=1./pos; break
            rrs.append(rr)

    return {
        "Recall@K": np.mean(recalls),
        "NDCG@K":   np.mean(ndcgs),
        "mAP":      np.mean(aps),
        "MRR@K":    np.mean(rrs)
    }
# rebuild_memory(txt_enc.retriever, batch_size=128, device=device)
metrics = compute_baseline_batch(
    retriever=txt_enc.retriever,
    dataset=val_ds,
    topk=topk,
    batch_size=batch_size,
    device=device
)
print(metrics)


>> Corpus size: 215498


Baseline 指標: 100%|██████████| 125/125 [00:56<00:00,  2.23it/s]

{'Recall@K': 0.312, 'NDCG@K': 0.11542608382945801, 'mAP': 0.02221536248240349, 'MRR@K': 0.23945}





In [11]:
def collect_attention(retriever, fusion, loader, device, topk):
    records = []
    for caps, _, obj_ids, samp_idxs in tqdm(loader, desc="蒐集注意力"):
        B = len(caps)
        # 1) 檢索 topk context
        q_vec, sims, idx_topk, ctx, _ = retriever(
            list(caps), list(obj_ids), topk=topk
        )
        # 2) 組 tok_seq
        flat = []
        for i in range(B):
            flat.append(caps[i])
            flat.extend(ctx[i])
        enc = retriever.tok(flat, return_tensors="pt",
                            padding=True, truncation=True).to(device)
        out = retriever.qenc(**enc).last_hidden_state[:, 0]
        tok_seq = out.view(B, topk+1, -1)

        # 3) CrossFusion attention (avg over heads)
        with torch.no_grad():
            _, all_attn = fusion(tok_seq)
        # all_attn[l] shape = (B, T, T)

        for b in range(B):
            for l, attn_mat in enumerate(all_attn):
                arr = attn_mat[b].cpu().numpy()  # shape = (T, T)
                for c in range(1, topk+1):
                    records.append({
                        "sample":  samp_idxs[b].item(),
                        "layer":   l,
                        "ctx_pos": c-1,
                        "attn":    arr[0, c],    # CLS idx = 0
                        "is_pos":  int(
                            retriever.obj_ids[idx_topk[b, c-1]] == obj_ids[b]
                        )
                    })
    return pd.DataFrame(records)

# In[6]: 蒐集注意力矩陣
df_attn = collect_attention(
    retriever, fusion, val_loader, dev, topk
)
print("共蒐集", len(df_attn), "筆 attention 紀錄")
df_attn.head()

蒐集注意力: 100%|██████████| 125/125 [00:15<00:00,  8.06it/s]

共蒐集 15000 筆 attention 紀錄





Unnamed: 0,sample,layer,ctx_pos,attn,is_pos
0,0,0,0,0.164109,0
1,0,0,1,0.241446,0
2,0,0,2,0.200201,0
3,0,0,3,0.131537,0
4,0,0,4,0.122967,0


In [12]:
# In[7]: 繪製注意力分布盒鬚圖
def plot_attention_box_avg(df, out_dir):
    # Map is_pos to human-readable labels
    df["label"] = df["is_pos"].map({1: "Positive Sample", 0: "Negative Sample"})
    
    plt.figure(figsize=(8, 6))
    sns.boxplot(
        data=df,
        x="layer",
        y="attn",
        hue="label"
    )
    plt.title("CrossFusion CLS→Contexts Attention Distribution (Average Heads)")
    plt.legend(title="Sample Type", loc="upper right")
    plt.tight_layout()
    
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, "attention_boxplot_avg.png"))
    plt.close()

# Execute the plot function
plot_attention_box_avg(df_attn, cache_dir)
print(f"Attention boxplot saved to {cache_dir}")
display(df_attn.head(20))

findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font f

Attention boxplot saved to ckpts_110


Unnamed: 0,sample,layer,ctx_pos,attn,is_pos,label
0,0,0,0,0.164109,0,Negative Sample
1,0,0,1,0.241446,0,Negative Sample
2,0,0,2,0.200201,0,Negative Sample
3,0,0,3,0.131537,0,Negative Sample
4,0,0,4,0.122967,0,Negative Sample
5,0,1,0,0.157402,0,Negative Sample
6,0,1,1,0.269589,0,Negative Sample
7,0,1,2,0.197085,0,Negative Sample
8,0,1,3,0.146392,0,Negative Sample
9,0,1,4,0.114293,0,Negative Sample


In [13]:
def compute_hit_rate_by_layer(df):
    """
    計算每一層中，CLS→context 的最強注意力是否命中正樣本，
    並對每層匯出整體命中率。
    """
    recs = []
    # 按 sample + layer 分組
    for (s, l), g in df.groupby(["sample", "layer"]):
        # 找出該組裡 attn 最大的那一筆索引
        idx_max = g["attn"].idxmax()
        hit     = int(g.loc[idx_max, "is_pos"])
        recs.append({"layer": l, "hit": hit})
    df2 = pd.DataFrame(recs)
    # 算每層的平均命中率
    return df2.groupby("layer")["hit"].mean().reset_index(name="hit_rate")

# 執行
hit_layer = compute_hit_rate_by_layer(df_attn)
hit_layer.to_csv(f"{cache_dir}/hit_rate_by_layer.csv", index=False)
print("各層命中率：\n", hit_layer)

各層命中率：
    layer  hit_rate
0      0     0.114
1      1     0.106
2      2     0.109


In [14]:
# In[9]: 隨機案例注意力熱圖（修正版，平均後注意力）
import os
from torch.utils.data import DataLoader

def plot_example_heatmaps(
    retriever, fusion, dataset,
    cache_dir, topk, device, num=5
):
    os.makedirs(f"{cache_dir}/examples", exist_ok=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    cnt = 0

    for caps, _, obj_ids, samp_idxs in loader:
        q = caps[0]
        with torch.no_grad():
            # 1) 搜 topk context
            _, _, idx_topk, ctx, _ = retriever(
                [q], [obj_ids[0]], topk=topk
            )
            # 2) 組 flat list 丟進 BERT
            flat = [q] + ctx[0]
            enc = retriever.tok(
                flat, return_tensors="pt",
                padding=True, truncation=True
            ).to(device)
            out = retriever.qenc(**enc).last_hidden_state[:,0]
            tok_seq = out.view(1, topk+1, -1).to(device)

            # 3) CrossFusion → all_attn（list of L, each shape=(1,T,T)）
            _, all_attn = fusion(tok_seq)

        # 4) 把所有層做平均
        #    all_attn[l][0] shape=(T,T)
        avg_attn = sum(attn[0].cpu().numpy() for attn in all_attn) / len(all_attn)
        # CLS→contexts 注意力
        cls2ctx = avg_attn[0, 1:]  # 跳過 CLS 位置

        # 5) 畫 barplot
        plt.figure(figsize=(6,4))
        sns.barplot(
            x=list(range(topk)), y=cls2ctx,
            palette="Blues_d"
        )
        plt.xticks(
            range(topk),
            [t[:15]+"…" if len(t)>15 else t for t in ctx[0]],
            rotation=30, ha="right"
        )
        plt.ylabel("Average Attention to Contexts")
        plt.title(f"Sample {samp_idxs.item()} CLS→Contexts")
        plt.tight_layout()
        plt.savefig(f"{cache_dir}/examples/heat_{cnt}.png")
        plt.close()

        cnt += 1
        if cnt >= num:
            break

# 呼叫方式（確保把 topk, device, cache_dir 都帶進去）
plot_example_heatmaps(
    retriever=txt_enc.retriever,
    fusion=txt_enc.fusion,
    dataset=val_ds,
    cache_dir=cache_dir,
    topk=topk,
    device=dev,
    num=5
)
print("案例熱圖已存於", f"{cache_dir}/examples")



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found.
findfont: Font family 'Arial' not found

案例熱圖已存於 ckpts_110/examples
