In [None]:
# Mount Google Drive so the project folder is reachable, switch the working directory to the
# project root, and list it as a quick sanity check.
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/SemEval_26_Task8_MTRAG
!pwd
!ls


Mounted at /content/drive
/content/drive/MyDrive/SemEval_26_Task8_MTRAG
/content/drive/MyDrive/SemEval_26_Task8_MTRAG
 beir   dataset   indexes  'queries data'   README.md   src


In [None]:
# Retrieval dependencies: FAISS (CUDA 11 GPU build), the sentence-transformers encoder, helpers.
# NOTE(review): the index is loaded with faiss.read_index and no GPU transfer appears in the code
# shown, so a CPU FAISS build may suffice; versions are unpinned — consider pinning them.
!pip install -U --no-cache-dir faiss-gpu-cu11
!pip install -q sentence-transformers
!pip install -q tqdm numpy sentencepiece

Collecting faiss-gpu-cu11
  Downloading faiss_gpu_cu11-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting nvidia-cuda-runtime-cu11>=11.8.89 (from faiss-gpu-cu11)
  Downloading nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cublas-cu11>=11.11.3.6 (from faiss-gpu-cu11)
  Downloading nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Downloading faiss_gpu_cu11-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (48.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.3/48.3 MB[0m [31m179.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_x86_64.whl (417.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.9/417.9 MB[0m [31m182.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl (875 kB)
[2K   

In [None]:
import json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

BASE = "/content/drive/MyDrive/SemEval_26_Task8_MTRAG"

# Corpus file (one JSON document per line) and the prebuilt BGE-base FAISS artifacts.
CORPUS_PATH = BASE + "/dataset/clapnq/corpus.jsonl"
INDEX_DIR = BASE + "/indexes/clapnq-bge-base-faiss"
INDEX_PATH = INDEX_DIR + "/index.faiss"
EMB_PATH = INDEX_DIR + "/emb.npy"

# ---- Load corpus & doc_ids ----
# dense_search_single maps FAISS row i to doc_ids[i], which assumes the index was built in
# corpus-file order; the size checks below guard the part of that assumption we can verify.
corpus_list = []
with open(CORPUS_PATH, "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():  # tolerate blank lines instead of crashing in json.loads
            corpus_list.append(json.loads(line))

doc_ids = [d["_id"] for d in corpus_list]

print("Corpus size =", len(corpus_list))

# ---- Load FAISS index ----
index = faiss.read_index(INDEX_PATH)
print("Index size =", index.ntotal)
# Results are mapped back to ids by row position, so the index must cover exactly the corpus.
assert index.ntotal == len(doc_ids), "FAISS index size does not match corpus size"

# Memory-map the stored document embeddings: only the shape is needed for the consistency
# check, and a full load (183408 x 768 in the logged run) would stay resident next to the LLM.
emb = np.load(EMB_PATH, mmap_mode="r")
print("emb shape =", emb.shape)
assert emb.shape[0] == len(corpus_list)

# ---- sentence-transformers encoder ----
bge_model = SentenceTransformer("BAAI/bge-base-en-v1.5")

def encode_queries(texts):
    """Embed query strings with the shared BGE encoder.

    Args:
        texts: list of query strings.

    Returns:
        numpy array of L2-normalised embeddings, one row per input text.
    """
    encode_options = dict(
        batch_size=32,
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return bge_model.encode(texts, **encode_options)

def dense_search_single(query: str, top_k: int = 50):
    """Retrieve the top_k corpus documents for one query from the FAISS index.

    Args:
        query: Query text.
        top_k: Number of neighbours to request from FAISS (``<= 0`` returns ``[]``).

    Returns:
        List of ``(doc_id, score)`` tuples in the order FAISS returns them. Padding slots
        (negative ids, emitted when fewer than ``top_k`` results exist) are dropped.
    """
    if top_k <= 0:
        return []
    # FAISS requires a C-contiguous float32 matrix; cast defensively in case the encoder
    # ever yields another dtype (e.g. a half-precision model).
    emb_q = np.ascontiguousarray(encode_queries([query]), dtype=np.float32)
    scores, idxs = index.search(emb_q, top_k)

    return [
        (str(doc_ids[i]), float(s))
        for i, s in zip(idxs[0], scores[0])
        if i >= 0
    ]


Corpus size = 183408
Index size = 183408
emb shape = (183408, 768)


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [None]:
import glob
from collections import defaultdict

def clean_user_line(line: str) -> str:
    """Remove every ``|user|:`` speaker tag from ``line`` and trim surrounding whitespace."""
    without_tag = "".join(line.split("|user|:"))
    return without_tag.strip()

def format_history(raw_text: str) -> str:
    """Turn raw ``|user|:``-tagged conversation text into ``User: ...`` lines.

    Blank lines are dropped; every remaining line is stripped, cleaned with
    ``clean_user_line`` and prefixed with ``"User: "``; the results are joined by newlines.
    NOTE(review): every non-blank line is labelled "User:", so this assumes the input holds
    only user turns — confirm against the questions file format.
    """
    stripped_lines = (raw.strip() for raw in raw_text.split("\n"))
    return "\n".join(
        "User: " + clean_user_line(text)
        for text in stripped_lines
        if text
    )

# ---- qrels ----
qrels_dir = f"{BASE}/dataset/clapnq/qrels"
# glob order is filesystem-dependent; sort so the selected file is reproducible across runs.
qrels_files = sorted(glob.glob(f"{qrels_dir}/*.tsv"))
# Fails when the qrels directory contains no .tsv file.
assert qrels_files, "qrels 目录下没有 .tsv 文件"
if len(qrels_files) > 1:
    print("WARNING: several qrels files found, using the first:", qrels_files)
QRELS_PATH = qrels_files[0]
print("Using qrels file:", QRELS_PATH)

qrels = defaultdict(dict)  # qid -> {doc_id: graded relevance}
with open(QRELS_PATH, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f):
        parts = line.strip().split()
        if len(parts) != 3:
            continue
        qid, docid, rel = parts
        # The first line is normally the "query-id corpus-id score" header. Skip it only when
        # its score column is not an integer, so a header-less file loses no judgement.
        if line_no == 0 and not rel.lstrip("-").isdigit():
            continue
        qrels[qid][docid] = int(rel)

print("qrels queries:", len(qrels))
print("sample qid from qrels:", next(iter(qrels.keys()), None))

# ---- Official rewrites: dataset/clapnq/queries.jsonl ----
# Builds `official`: qid -> {"qid", "official_rewrite"}, the record's "text" with the
# "|user|:" tag removed by clean_user_line.
OFFICIAL_PATH = f"{BASE}/dataset/clapnq/queries.jsonl"
official = {}

with open(OFFICIAL_PATH, "r") as f:
    for line in f:
        item = json.loads(line)
        qid = item["_id"]
        text = item["text"]
        official[qid] = {
            "qid": qid,
            "official_rewrite": clean_user_line(text),
        }
print("official rewrites:", len(official))

# ---- last turn: queries data/clapnq_lastturn.jsonl ----
# Builds `lastturn`: qid -> {"qid", "current_query"}, the original (un-rewritten) current question.
LASTTURN_PATH = f"{BASE}/queries data/clapnq_lastturn.jsonl"
lastturn = {}

with open(LASTTURN_PATH, "r") as f:
    for line in f:
        item = json.loads(line)
        qid = item["_id"]
        text = item["text"]
        # Take the last non-empty line as the original question of the current turn.
        # NOTE(review): raises IndexError if "text" has no non-empty line, and a multi-line
        # current turn would be reduced to its final line — confirm the file's layout.
        non_empty = [l for l in text.split("\n") if l.strip()]
        last_line = non_empty[-1]
        current_query = clean_user_line(last_line)
        lastturn[qid] = {
            "qid": qid,
            "current_query": current_query,
        }
print("lastturn entries:", len(lastturn))

# ---- all questions: queries data/clapnq_questions.jsonl ----
# Builds `allqs`: qid -> {"qid", "conversation_history"}, where the history is the record's
# "text" reformatted by format_history into "User: ..." lines.
ALLQS_PATH = f"{BASE}/queries data/clapnq_questions.jsonl"
allqs = {}

with open(ALLQS_PATH, "r") as f:
    for line in f:
        item = json.loads(line)
        qid = item["_id"]
        text = item["text"]
        history = format_history(text)
        allqs[qid] = {
            "qid": qid,
            "conversation_history": history,
        }
print("all-questions entries:", len(allqs))

# ---- Merge into a single queries dict ----
# qid -> {"qid"} plus whichever of "official_rewrite", "current_query" and
# "conversation_history" the three sources provide for that qid.
queries = {}
all_ids = set(official) | set(lastturn) | set(allqs)

# Iterate in sorted order: iteration order of a set of str keys changes between interpreter
# runs (hash randomisation), which would reorder every downstream loop and the sample below.
for qid in sorted(all_ids):
    q = {"qid": qid}
    for source in (official, lastturn, allqs):
        if qid in source:
            q.update(source[qid])
    queries[qid] = q

print("Final queries size:", len(queries))
sample_qid = next(iter(queries))
print("Sample query:", sample_qid, "->", queries[sample_qid])


Using qrels file: /content/drive/MyDrive/SemEval_26_Task8_MTRAG/dataset/clapnq/qrels/train.tsv
qrels queries: 208
sample qid from qrels: dd6b6ffd177f2b311abe676261279d2f<::>2
official rewrites: 208
lastturn entries: 208
all-questions entries: 208
Final queries size: 208
Sample query: a2698b2973ea7db1ee5adb5e70ec02e4<::>8 -> {'qid': 'a2698b2973ea7db1ee5adb5e70ec02e4<::>8', 'official_rewrite': 'How many people worldwide lack proper sewage treatment?', 'current_query': 'Chaging the subject a little a bit,  are there many people worldwide without  proper sewage treatment?', 'conversation_history': 'User: where does water go after it enters a storm drain\nUser: What is a catchbasin and how they are designed?\nUser: Does the basin really capture the litter  and debris from the water from the streets, roads, roofs?\nUser: Could be the water that enters a storm drain recycled and be used in the house chores like cleaning, gardening , etc?\nUser: And about on the level of the Federal government

In [None]:
# LLM dependencies.
# NOTE(review): transformers is presumably already imported earlier in this session (via
# sentence_transformers), so an upgrade here only takes effect after a runtime restart;
# versions are unpinned, and modelscope/einops are not referenced in the code shown.
!pip install -U transformers accelerate modelscope einops

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "Qwen/Qwen3-4B"

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Load in full precision on the CPU.
# NOTE(review): on a T4/L4 GPU, dtype=torch.float16 (T4) / torch.bfloat16 (L4) with
# device_map={"": 0} would be far faster; callers must then send inputs to model.device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float32,  # `torch_dtype` is deprecated in the installed transformers
    device_map={"": "cpu"},
    trust_remote_code=True
)

print("Model loaded OK!")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Model loaded OK!


In [None]:
def call_llm(prompt: str,
             max_tokens: int = 256,
             temperature: float = 0.0):
    """Generate a chat completion for ``prompt`` with the globally loaded ``model``/``tokenizer``.

    Args:
        prompt: Content of the single user message.
        max_tokens: Maximum number of new tokens to generate.
        temperature: ``<= 0`` means greedy decoding; ``> 0`` enables sampling at this temperature.

    Returns:
        The decoded completion (prompt tokens excluded), stripped, with any
        ``<think>...</think>`` reasoning removed.
    """
    messages = [
        {"role": "user", "content": prompt}
    ]
    # Qwen3's chat template enables a <think> reasoning block by default. Within a small token
    # budget the reply would then be mostly reasoning text, which the line-based parsers
    # downstream would take as the verdict / query, so ask for the direct answer instead.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    # Follow the model's device rather than hard-coding "cpu".
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    gen_kwargs = {"max_new_tokens": max_tokens}
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs["do_sample"] = False
    outputs = model.generate(**inputs, **gen_kwargs)

    prompt_len = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Defensive: keep only what follows a closed reasoning block, and drop everything from an
    # unterminated <think> (token budget exhausted mid-reasoning) onwards.
    if "</think>" in completion:
        completion = completion.rsplit("</think>", 1)[1]
    completion = completion.split("<think>", 1)[0]
    return completion.strip()


In [None]:
def _truncate_history(history: str, max_chars: int = 1500) -> str:
    history = history.strip()
    if len(history) <= max_chars:
        return history
    return history[-max_chars:]


# 1) Retrieval Necessity + Rewrite
# Prompt for the retrieval-necessity step. Placeholders: {history} (truncated conversation history)
# and {current} (the current user turn, optionally followed by a suggested standalone rewrite).
# The model must answer "yes"/"no" on the first line and, for "yes", a search query on the second;
# retrieval_necessity_and_rewrite parses that layout.
# NOTE(review): histories built by format_history label every line "User:", so situation 3
# ("previously answered") cannot see earlier answers.
RN_TEMPLATE = """In a multi-turn dialogue scenario, your task is to determine whether it is necessary to use a search engine to answer a user’s query and to provide a single clear search query when needed.

Consider the following situations:
1. Non-Informational Replies: Sometimes, users may respond with statements or expressions that do not require information retrieval, such as "thank you" or "okay."
2. Ambiguous or Unclear Queries: A user’s query might be unclear or lack specific details. Your role is to recognize the user’s intent and rewrite the query to make it clearer and more precise, facilitating an effective search engine query.
3. Previously Answered Queries: Check if the current query or a similar one has been previously asked and answered in the conversation history. If relevant information has already been provided, acknowledge this and avoid repeating the search.

Reply in the following STRICT format:
- First line: "yes" or "no" indicating whether retrieval is necessary.
- If you output "yes", the SECOND line MUST be a single, clear search query.
- Do NOT output anything else.

Conversation History:
{history}

Current User’s Query:
{current}

Response For Retrieval Necessity:
"""

def retrieval_necessity_and_rewrite(history: str,
                                    current_query: str,
                                    use_official_rewrite: str | None = None):
    """Ask the LLM whether retrieval is needed and, if so, for a standalone search query.

    Args:
        history: Conversation history text (truncated with ``_truncate_history``).
        current_query: The current user turn.
        use_official_rewrite: Optional rewrite shown to the LLM as a suggestion.

    Returns:
        ``(False, "")`` when the LLM's verdict is negative; otherwise ``(True, query)`` where
        ``query`` is the cleaned rewrite, falling back to ``current_query`` when the output
        is empty or carries no usable query.
    """
    import re

    hist = _truncate_history(history or "")
    if use_official_rewrite:
        current = f"{current_query}\n(Standalone rewrite suggestion: {use_official_rewrite})"
    else:
        current = current_query

    prompt = RN_TEMPLATE.format(history=hist, current=current)
    out = call_llm(prompt, max_tokens=128, temperature=0.0)
    lines = [l.strip() for l in out.split("\n") if l.strip()]
    if not lines:
        return True, current_query

    # Verdicts are often decorated ("**No**", "- yes", "No."); compare the first word only.
    verdict = re.match(r"[^A-Za-z]*([A-Za-z]+)", lines[0])
    word = verdict.group(1).lower() if verdict else ""
    if word in ("no", "not", "nope", "none", "non"):
        return False, ""

    if len(lines) >= 2:
        candidate = lines[1]
    elif word == "yes":
        # Single-line answer such as "Yes: <query>".
        candidate = lines[0][verdict.end():]
    else:
        candidate = ""
    # Strip bullets/punctuation, a "Query:"-style label, and wrapping quotes or emphasis.
    candidate = re.sub(r"^[\s\-*•>:.,–—]+", "", candidate)
    candidate = re.sub(
        r"^(?:(?:search|rewritten|standalone)\s+)?query\s*[:：]\s*",
        "",
        candidate,
        flags=re.IGNORECASE,
    )
    candidate = candidate.strip().strip("\"'“”`*").strip()
    return True, candidate or current_query


# 2) Decomposition
# Prompt asking for one sub-question per line. Placeholders: {context} (truncated history) and
# {question} (the main query); decompose_queries splits the reply on newlines.
DECOMP_TEMPLATE = """Your task is to effectively decompose complex, multihop questions into simpler, manageable sub-questions.
Break down the question into multiple direct questions that can be answered individually.
Each sub-question should:
- be self-contained,
- follow logically from the previous one,
- and together they should help answer the main question.

Format:
- Output multiple lines, each line is exactly ONE sub-question.
- Do NOT number the lines.
- Do NOT add explanations.

Provided Contexts:
{context}

Multihop Question:
{question}

Decomposed queries:
"""

def decompose_queries(history: str, main_question: str):
    """Ask the LLM to split ``main_question`` into simpler sub-questions.

    Args:
        history: Conversation history used as context (truncated with ``_truncate_history``).
        main_question: The question to decompose.

    Returns:
        Distinct sub-question strings in LLM output order, with list markers removed.
        Header-like lines and echoes of ``main_question`` are dropped.
    """
    import re

    hist = _truncate_history(history or "")
    prompt = DECOMP_TEMPLATE.format(context=hist, question=main_question)
    out = call_llm(prompt, max_tokens=256, temperature=0.0)

    # Callers already search with main_question, so an echo of it is just a repeated search.
    seen = {main_question.strip().lower()}
    subqs = []
    for raw in out.split("\n"):
        # Models often number or bullet lines despite the instruction ("1. ", "- ", "(2) ");
        # the lookahead keeps decimals such as "1.5 million" intact.
        l = re.sub(r"^\s*(?:[-*•]+\s*)?(?:\(?\d+[.)](?=\s))?\s*", "", raw, count=1).strip()
        if not l:
            continue
        lower = l.lower()
        if lower.startswith("decomposed") or lower.startswith("answer:") or "sub-question" in lower:
            continue
        if lower in seen:
            continue
        seen.add(lower)
        subqs.append(l)
    return subqs


# 3) Disambiguation
# Prompt asking for a single disambiguated query line. Placeholder: {question};
# disambiguate_query reads the first line of the reply.
DISAMBIG_TEMPLATE = """Your task is to identify and resolve ambiguity in a user question, ensuring it is clear and unambiguous.
- Read the question carefully and detect possible ambiguous parts.
- Reformulate the question to eliminate ambiguity.
- You may specify missing details, narrow broad terms, or add minimal context to make the intent clear.

Output:
- Return exactly ONE line: a single, disambiguated search query.
- Do NOT add explanations or extra sentences.

Original Question:
{question}

Disambiguated Query:
"""

def disambiguate_query(question: str) -> str:
    """Ask the LLM for a single, less ambiguous search query.

    Returns the first line of the LLM output that is non-empty after removing list markers,
    a leading "Disambiguated Query:" / "Query:" label and wrapping quotes; falls back to
    ``question`` when nothing usable remains.
    """
    import re

    prompt = DISAMBIG_TEMPLATE.format(question=question)
    out = call_llm(prompt, max_tokens=128, temperature=0.0)
    for raw in out.split("\n"):
        line = re.sub(r"^[\s\-*•>]+", "", raw)
        # Models sometimes echo the template's trailing label before (or instead of) the query.
        line = re.sub(
            r"^(?:disambiguated\s+)?(?:search\s+)?query\s*[:：]\s*",
            "",
            line,
            flags=re.IGNORECASE,
        )
        line = line.strip().strip("\"'“”`*").strip()
        if line:
            return line
    return question


# 4) simple_rewrite baseline
def simple_rewrite(history: str,
                   current_query: str,
                   official_rewrite: str | None = None) -> str:
    """Produce one retrieval query via the retrieval-necessity + rewrite step.

    Preference order: the LLM rewrite (when retrieval is judged necessary and the rewrite is
    non-empty), then ``official_rewrite`` if truthy, then ``current_query``.
    """
    need, rew = retrieval_necessity_and_rewrite(
        history=history,
        current_query=current_query,
        use_official_rewrite=official_rewrite,
    )
    for choice in ((rew if need else None), official_rewrite):
        if choice:
            return choice
    return current_query


In [None]:
from tqdm import tqdm

def evaluate_run(run, qrels, ks=(1,3,5,10)):
    """Compute mean Recall@k and nDCG@k of a retrieval run against graded qrels.

    Args:
        run: ``{qid: [(doc_id, score), ...]}``. Hits are ranked by descending score here,
            on a copy — the caller's lists are not modified.
        qrels: ``{qid: {doc_id: relevance}}``; docs with relevance > 0 count as relevant.
        ks: Cut-offs to evaluate.

    Returns:
        ``{"recall": {k: float}, "ndcg": {k: float}}`` averaged over qids present in both
        ``qrels`` and ``run``. Recall skips qids without a relevant doc; nDCG uses gain
        ``2**rel - 1`` with a ``log2(rank + 1)`` discount and scores such qids as 0.
        A metric with no contributing query is ``nan``.
    """
    # Rank a copy so the caller's run dict is left untouched.
    ranked = {
        qid: [doc_id for doc_id, _ in sorted(hits, key=lambda x: x[1], reverse=True)]
        for qid, hits in run.items()
    }

    # Ideal-DCG prefix sums per query: idcgs[qid][j] is the best achievable DCG@(j+1).
    idcgs = {}
    for qid, rel_docs in qrels.items():
        rels_sorted = sorted(rel_docs.values(), reverse=True)
        gains = [(2**rel - 1) / np.log2(i + 1) for i, rel in enumerate(rels_sorted, start=1)]
        idcgs[qid] = np.cumsum(gains)

    results = {"recall": {}, "ndcg": {}}
    for k in ks:
        recalls = []
        ndcgs = []
        for qid, rel_docs in qrels.items():
            if qid not in ranked:
                continue
            relevant = {d for d, r in rel_docs.items() if r > 0}
            retrieved = ranked[qid][:k]

            if relevant:
                hit = len(relevant.intersection(retrieved))
                recalls.append(hit / len(relevant))

            dcg = 0.0
            for rank, doc_id in enumerate(retrieved, start=1):
                rel = rel_docs.get(doc_id, 0)
                if rel > 0:
                    dcg += (2**rel - 1) / np.log2(rank + 1)

            ideal_curve = idcgs[qid]
            # Fewer judged docs than k: the ideal ranking is exhausted, use its full DCG.
            ideal = ideal_curve[min(k, len(ideal_curve)) - 1] if ideal_curve.size > 0 else 1.0
            ndcgs.append(dcg / ideal if ideal > 0 else 0.0)

        # Explicit nan instead of np.mean([]) (which also emits a RuntimeWarning).
        results["recall"][k] = float(np.mean(recalls)) if recalls else float("nan")
        results["ndcg"][k] = float(np.mean(ndcgs)) if ndcgs else float("nan")

    return results

def print_metrics(name, metrics):
    """Print Recall@k / NDCG@k for k in 1, 3, 5, 10 (cut-offs missing from metrics show 0.0)."""
    print(f"\n===== {name} =====")
    recall_by_k = metrics["recall"]
    ndcg_by_k = metrics["ndcg"]
    for cutoff in (1, 3, 5, 10):
        print(
            f"k={cutoff:2d} | Recall@{cutoff}: {recall_by_k.get(cutoff, 0.0):.4f}"
            f" | NDCG@{cutoff}: {ndcg_by_k.get(cutoff, 0.0):.4f}"
        )


In [None]:
# Baseline 1: dense retrieval directly with the official rewrite of each judged query.
run_official = {}

for qid, q in tqdm(queries.items(), desc="Baseline 1 - official rewrite"):
    # Only queries that have relevance judgements and an official rewrite are evaluated.
    if qid in qrels and "official_rewrite" in q:
        query_text = q["official_rewrite"]
        hits = dense_search_single(query_text, top_k=50)
        run_official[qid] = hits

metrics_official = evaluate_run(run_official, qrels)
print_metrics("Baseline 1: Official Rewrite", metrics_official)


Baseline 1 - official rewrite: 100%|██████████| 208/208 [00:17<00:00, 12.05it/s]


===== Baseline 1: Official Rewrite =====
k= 1 | Recall@1: 0.1742 | NDCG@1: 0.4712
k= 3 | Recall@3: 0.3753 | NDCG@3: 0.4124
k= 5 | Recall@5: 0.4619 | NDCG@5: 0.4376
k=10 | Recall@10: 0.6063 | NDCG@10: 0.4982





In [None]:
# Baseline 2: LLM retrieval-necessity + rewrite (official rewrite offered as a hint), then search.
run_simple = {}

for qid, q in tqdm(queries.items(), desc="Baseline 2 - simple rewrite"):
    if qid not in qrels or "current_query" not in q:
        continue

    history = q.get("conversation_history", "")
    current = q["current_query"]
    official_rew = q.get("official_rewrite")

    rew = simple_rewrite(history, current, official_rewrite=official_rew)
    hits = dense_search_single(rew, top_k=50)
    run_simple[qid] = hits

metrics_simple = evaluate_run(run_simple, qrels)
print_metrics("Baseline 2: Simple RN+Rewrite", metrics_simple)


Baseline 2 - simple rewrite:   0%|          | 0/208 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Baseline 2 - simple rewrite: 100%|██████████| 208/208 [29:28<00:00,  8.50s/it]


===== Baseline 2: Simple RN+Rewrite =====
k= 1 | Recall@1: 0.0608 | NDCG@1: 0.1731
k= 3 | Recall@3: 0.1554 | NDCG@3: 0.1636
k= 5 | Recall@5: 0.2335 | NDCG@5: 0.1984
k=10 | Recall@10: 0.3461 | NDCG@10: 0.2458





In [None]:
# RQ-RAG V1: official rewrite as the main query plus LLM sub-questions (history as context);
# hits from all queries are fused by keeping each document's best score.
run_rqrag_v1 = {}

for qid, q in tqdm(queries.items(), desc="RQ-RAG V1"):
    if qid not in qrels or "official_rewrite" not in q:
        continue

    history = q.get("conversation_history", "")
    main_query = q["official_rewrite"]
    sub_queries = decompose_queries(history, main_query)
    all_queries = [main_query, *sub_queries]

    # Top-20 hits per query, concatenated in query order.
    all_hits = []
    for qq in all_queries:
        hits = dense_search_single(qq, top_k=20)
        all_hits += hits

    merged = {}
    for doc_id, score in all_hits:
        merged[doc_id] = max(merged.get(doc_id, -1e9), score)
    merged_list = sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
    run_rqrag_v1[qid] = merged_list[:50]

metrics_rqrag_v1 = evaluate_run(run_rqrag_v1, qrels)
print_metrics("RQ-RAG V1: history + official rewrite → main+sub", metrics_rqrag_v1)


RQ-RAG V1: 100%|██████████| 208/208 [57:21<00:00, 16.55s/it]


===== RQ-RAG V1: history + official rewrite → main+sub =====
k= 1 | Recall@1: 0.1607 | NDCG@1: 0.4135
k= 3 | Recall@3: 0.3349 | NDCG@3: 0.3678
k= 5 | Recall@5: 0.4377 | NDCG@5: 0.4045
k=10 | Recall@10: 0.5531 | NDCG@10: 0.4522





In [None]:
# RQ-RAG V2: no official rewrite — raw current turn + history go through retrieval necessity /
# rewrite, disambiguation and decomposition; hits are fused by each document's best score.
run_rqrag_v2 = {}

for qid, q in tqdm(queries.items(), desc="RQ-RAG V2"):
    if qid not in qrels or "current_query" not in q:
        continue

    history = q.get("conversation_history", "")
    current = q["current_query"]

    # Step 1: retrieval necessity + rewrite; fall back to the raw turn when no query comes back.
    need, main_q = retrieval_necessity_and_rewrite(
        history=history,
        current_query=current,
        use_official_rewrite=None,
    )
    main_q = main_q if (need and main_q) else current

    # Steps 2-3: disambiguate, then decompose into sub-questions.
    main_q = disambiguate_query(main_q)
    sub_queries = decompose_queries(history, main_q)
    all_queries = [main_q, *sub_queries]

    all_hits = []
    for qq in all_queries:
        hits = dense_search_single(qq, top_k=20)
        all_hits += hits

    merged = {}
    for doc_id, score in all_hits:
        merged[doc_id] = max(merged.get(doc_id, -1e9), score)
    merged_list = sorted(merged.items(), key=lambda kv: kv[1], reverse=True)
    run_rqrag_v2[qid] = merged_list[:50]

metrics_rqrag_v2 = evaluate_run(run_rqrag_v2, qrels)
print_metrics("RQ-RAG V2: history only → RN+Disambig+Decomp", metrics_rqrag_v2)


RQ-RAG V2:  47%|████▋     | 98/208 [54:12<1:01:43, 33.67s/it]