# Configuration

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; BASE_DIR (defined in the configuration
# cell) lives under this mount point.
drive.mount('/content/drive')

In [None]:
!pip install -q --no-deps datasets sentence-transformers faiss-cpu rank_bm25 evaluate

: 

In [None]:
import nest_asyncio, logging
from pathlib import Path
from tqdm import tqdm, trange
import torch, random, faiss
import numpy as np
import pandas as pd

# Patch asyncio so event loops can be nested inside the notebook's own loop.
nest_asyncio.apply()
# force=True replaces any handlers that were already installed on the root logger.
logging.basicConfig(level = logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s ", force=True)
SEED = 42

# Seed the NumPy, stdlib and PyTorch random generators with the same value.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# Root folder for every artifact written by this notebook (created if missing).
BASE_DIR = Path("/content/drive/MyDrive/hotpot_qa")
BASE_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): TOTAL_SAMPLE and EVAL_RATIO are not read anywhere in the visible
# cells — confirm whether they are still needed.
TOTAL_SAMPLE = 10000
EVAL_RATIO = 0.2

# Output locations for the corpus table, the evaluation table and the FAISS indexes.
CORPUS_PATH   = BASE_DIR/"corpus.parquet"
EVAL_PATH     = BASE_DIR/"eval.parquet"
FAISS_TXT     = BASE_DIR/"index_text_flatip.faiss"
FAISS_IMG     = BASE_DIR/"index_img_flatip.faiss"

# Checkpoint locations (presumably for LoRA SFT and DPO fine-tuning — not used
# in the visible cells).
LORA_CKPT     = BASE_DIR/"lora_sft_ckpt"
DPO_CKPT      = BASE_DIR/"dpo_finetuned_ckpt"

# Load Data

In [None]:
from datasets import load_dataset
import pandas as pd, re, itertools

# Build the retrieval corpus: one row per unique paragraph title found in the
# contexts of the first 200 validation examples.
hp_raw = load_dataset('hotpot_qa', 'fullwiki', split='validation[:200]')

docs = []
for ex in hp_raw:
    ctx = ex['context']
    # The Hugging Face release stores `context` column-wise as
    # {'title': [...], 'sentences': [[...], ...]}, whereas the raw HotpotQA JSON
    # stores it row-wise as [[title, sentences], ...]. Iterating the dict
    # directly would yield its keys, so normalise both layouts to pairs first.
    if isinstance(ctx, dict):
        paragraphs = zip(ctx['title'], ctx['sentences'])
    else:
        paragraphs = ctx
    for title, sents in paragraphs:
        docs.append({
            'title': title,
            'text': ' '.join(sents).strip()
        })
# Explicit columns keep drop_duplicates('title') valid even when `docs` is empty.
docs_df = pd.DataFrame(docs, columns=['title', 'text']).drop_duplicates('title').reset_index(drop=True)
docs_df['row_id'] = docs_df.index
docs_df.to_parquet(CORPUS_PATH, index=False)
print('Corpus saved: ', CORPUS_PATH, '| size=', len(docs_df))
# All (title, paragraph text) pairs are flattened into one DataFrame,
# de-duplicated by title and saved as corpus.parquet.
# The question/answer/gold_docs/keywords records are saved the same way as eval.parquet.
# Lookup table from paragraph title to its integer row_id in docs_df.
title2id = {
    title: int(rid)
    for title, rid in zip(docs_df['title'], docs_df['row_id'])
}

def gold_ids(supp):
    """Map an example's supporting facts to sorted, unique corpus row ids.

    Args:
        supp: Supporting facts in either layout: the Hugging Face column-wise
            dict ``{'title': [...], 'sent_id': [...]}`` or the raw HotpotQA
            row-wise list ``[[title, sent_id], ...]``.

    Returns:
        Sorted list of row ids looked up in the module-level ``title2id``;
        titles missing from that table are skipped.
    """
    if not supp:
        return []
    if isinstance(supp, dict):
        # Iterating the dict itself would yield the keys, not the facts.
        titles = supp.get('title', [])
    else:
        titles = [fact[0] for fact in supp]
    return sorted({title2id[t] for t in titles if t in title2id})

def keywords(ans):
    """Return up to five distinct lowercase word tokens of ``ans``.

    Tokens are ``\\w+`` runs of the lowercased text; only tokens longer than
    two characters are kept, in first-seen order.
    """
    picked = []
    for token in re.findall(r'\w+', ans.lower()):
        if len(token) > 2 and token not in picked:
            picked.append(token)
    return picked[:5]

# One evaluation record per example; row_id is the example's running position.
eval_rows = [
    {
        'row_id': position,
        'question': ex['question'],
        'answer': ex['answer'],
        'gold_docs': gold_ids(ex['supporting_facts']),
        'keywords': keywords(ex['question']),
    }
    for position, ex in enumerate(hp_raw)
]

eval_df = pd.DataFrame(eval_rows)
eval_df.to_parquet(EVAL_PATH, index=False)
print("Eval saved:", EVAL_PATH, "| size =", len(eval_df))


: 

In [None]:
import faiss
import numpy as np
from sentence_transformers import  SentenceTransformer

# Embed every corpus paragraph and store the vectors in an exact inner-product index.
embed_model = 'sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(embed_model, device=DEVICE)
embeddings = model.encode(
    docs_df['text'].tolist(),
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
).astype('float32')

# The FAISS class is spelled IndexFlatIP (capital "IP"); `IndexFlatIp` does not
# exist and raised AttributeError. The embeddings are requested normalised, so
# inner product ranks the same as cosine similarity.
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, str(FAISS_TXT))
print('FAISS saved: ', FAISS_TXT, '| vectors = ', index.ntotal)

In [None]:
# Optional: reload the saved index from disk and run one sanity query against it.
chk_index = faiss.read_index(str(FAISS_TXT))
qv = model.encode(["Albert Einstein"], normalize_embeddings=True).astype("float32")
# Search for the 5 nearest vectors; `ids` is used below as docs_df index labels.
sims, ids = chk_index.search(qv.reshape(1,-1), 5)
print("Top titles:", docs_df.loc[ids[0], "title"].tolist())
print("✅ Hotpot data ready — 以上三文件已生成")


# Retriever

In [None]:
from abc import ABC, abstractmethod
from sentence_transformers import SentenceTransformer, CrossEncoder
from tqdm import tqdm
from rank_bm25 import BM25Okapi
import re


class Retriever(ABC):
    """Interface for components that look up candidate documents for a query."""

    @abstractmethod
    def retrieve(self, query, top_k):
        """Return the retrieval result for ``query``, limited to ``top_k`` entries."""

## BM25Retriever


In [None]:
class BM25Retriever(Retriever):
    """BM25 lexical retriever over one text column of a DataFrame corpus."""

    def __init__(self, corpus: pd.DataFrame, text_col="text"):
        """Tokenise ``corpus[text_col]`` once and fit a BM25Okapi model on it."""
        self.corpus = corpus
        self.texts = corpus[text_col].tolist()
        # Lowercased \w+ tokens per document, computed a single time up front.
        self.tokenized = []
        for document in self.texts:
            self.tokenized.append(re.findall(r'\w+', document.lower()))
        self.bm25 = BM25Okapi(self.tokenized)
        # Index labels of the corpus, used to map score positions back to rows.
        self.row_ids = corpus.index.to_list()

    def retrieve(self, query, top_k):
        """Score every document against ``query`` and return the best ``top_k``.

        Returns:
            ``{"ids": [...], "scores": [...]}`` ordered by descending score.
        """
        query_tokens = re.findall(r'\w+', query.lower())
        scores = self.bm25.get_scores(query_tokens)
        # Positions into the tokenised corpus, highest score first.
        best = np.argsort(scores)[::-1][:top_k]
        return {
            "ids": [self.row_ids[pos] for pos in best],
            "scores": scores[best].tolist(),
        }

    def get_content(self, ids, field="text"):
        """Return the ``field`` value for each index label in ``ids`` (an int or an iterable)."""
        wanted = [ids] if isinstance(ids, int) else ids
        return [self.corpus.loc[label, field] for label in wanted]

# CrossEncoder

In [None]:
class CrossEncoderReranker:
    """Re-score (query, document) pairs with a cross-encoder and keep the best ones."""

    def __init__(self, model_name, device="cuda"):
        """Load the cross-encoder ``model_name`` onto ``device``."""
        self.model = CrossEncoder(model_name, device=device)

    def rerank(self, query: str, docs: list[str], row_ids: list[int], top_k: int):
        """Return the ``top_k`` highest-scoring documents for ``query``.

        Args:
            query: The user query paired with every document.
            docs: Candidate document texts.
            row_ids: Identifier of each entry in ``docs`` (same order, same length).
            top_k: Maximum number of results to return.

        Returns:
            List of ``(row_id, score)`` tuples in descending score order; empty
            when there are no candidates or ``top_k`` is not positive.

        Raises:
            ValueError: If ``docs`` and ``row_ids`` differ in length, or the model
                returns a different number of scores than candidates.
        """
        # Validate with a real exception: `assert` is stripped under `python -O`.
        if len(docs) != len(row_ids):
            raise ValueError("Mismatch between docs and row_ids!")

        keep = min(top_k, len(row_ids))
        if keep <= 0:
            # Nothing to score; avoid calling the model on an empty batch.
            return []

        pairs = [[query, d] for d in docs]
        scores = np.asarray(self.model.predict(pairs, batch_size=32)).flatten()
        if len(scores) != len(row_ids):
            raise ValueError("Mismatch between scores and row_ids!")

        order = np.argsort(scores)[::-1][:keep]
        return [(row_ids[int(i)], float(scores[i])) for i in order]

# Query_Rewriter & PromptBuilder

## Query_Rewriter

In [None]:
# Keyword -> alternative phrasings.
# NOTE(review): nothing in the visible cells reads this constant; QueryRewriter
# carries its own default map with a different key order — confirm whether the
# two are meant to be unified.
SYNONYM_MAP = {
    "connect": ["plug", "attach", "link"],
    "setup": ["configure", "install"],
    "how": ["how to", "how do I"],
    "price": ["cost", "charge", "fee"],
}

In [None]:
import re
from typing import List
from transformers import pipeline

class QueryRewriter:
    """
    Produce alternative phrasings of a search query for multi-query retrieval.

    Two layers:
      1) Rule layer: the first synonym-map key found in the query (whole word,
         case-insensitive) is expanded to "key/synonym", one variant per synonym.
      2) Model layer: queries shorter than five whitespace-separated words are
         also sent to a flan-t5-base text2text pipeline for expansion.
    """
    def __init__(self, synonym_map= None, stopwords=None, device=0):
        """
        Args:
            synonym_map: Mapping ``keyword -> list of synonyms``; a built-in
                default is used when omitted or empty.
            stopwords: Stored on the instance; not consulted by ``rewrite``.
            device: Device argument forwarded to the transformers pipeline.
        """
        self.synmap = synonym_map or {
            "connect": ["plug", "attach", "link"],
            "setup": ["configure", "install"],
            "price": ["cost", "charge", "fee"],
            "how": ["how to", "how do I"],
        }
        self.expander = pipeline("text2text-generation",
                                 model="google/flan-t5-base", device=device)
        self.stopwords = stopwords or {"setup", "login", "price"}

    def rewrite(self, query: str):
        """Return the original query followed by its de-duplicated rewrites."""
        # 1. Rule expansion: only the first matching key is expanded, and only
        #    its first occurrence, so variants never stack on one another.
        rewrites = [query]
        for k, vs in self.synmap.items():
            # Escape the key so regex metacharacters in it match literally
            # (an unescaped "u.s" would also match "uxs", "c++" would not compile).
            pattern = re.compile(rf"\b{re.escape(k)}\b", flags=re.IGNORECASE)
            if pattern.search(query):
                for v in vs:
                    replacement = f"{k}/{v}"
                    # A callable replacement is inserted verbatim; a plain string
                    # would have its backslashes read as group references.
                    q_new = pattern.sub(lambda _m, r=replacement: r, query, count=1)
                    if q_new != query:
                        rewrites.append(q_new)
                break

        # 2. Model expansion, applied to short queries only.
        if len(query.split()) < 5:
            prompt = f"Please rewrite or expand this search query for better clarity and detail: {query}"
            out = self.expander(prompt, max_length=64, clean_up_tokenization_spaces=True)[0]["generated_text"]
            rewrites.append(out.strip())

        # Several candidate queries are returned so the caller can retrieve with each.
        return list(dict.fromkeys(rewrites))

## PromptBuilder

In [None]:
from typing import List, Dict, Optional

class PromptBuilder:
    """
    Assemble a prompt from three pluggable layers:
      1) system_prompt: role and style instruction
      2) few_shot: example question/answer pairs
      3) context: retrieved passages, optionally labelled with image tags
    """
    def __init__(
        self,
        system_prompt: str = "You are a helpful Amazon QA assistant.",
        few_shot: Optional[List[Dict[str, str]]] = None,
        max_context: int = 3
    ):
        self.system = system_prompt
        self.few_shot = few_shot or []
        self.max_ctx = max_context

    def build(
        self,
        query: str,
        contexts: List[str],
        image_tags: Optional[List[str]] = None
    ) -> str:
        """Return the full prompt text for ``query``.

        Section order: system prompt, few-shot examples, user query, at most
        ``max_context`` passages, closing instruction. A passage gets an
        ``[Image: ...]`` label when ``image_tags`` has an entry at its position.
        """
        header = f"SYSTEM:\n{self.system}\n"
        examples = [
            f"EXAMPLE:\nQ: {shot['q']}\nA: {shot['a']}\n" for shot in self.few_shot
        ]
        question = f"USER QUERY:\n{query}\n"

        tag_count = len(image_tags) if image_tags else 0
        passages = []
        for number, passage in enumerate(contexts[:self.max_ctx], start=1):
            if number <= tag_count:
                passages.append(
                    f"CONTEXT {number} [Image: {image_tags[number - 1]}]:\n{passage}\n"
                )
            else:
                passages.append(f"CONTEXT {number}:\n{passage}\n")

        sections = [header, *examples, question, *passages,
                    "PLEASE ANSWER BASED ON ABOVE."]
        return "\n".join(sections)

# Generator

In [None]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline

class Generator(ABC):
    """Interface for components that turn a prompt into generated output."""

    @abstractmethod
    def generate(self, prompt):
        """Produce the generation result for ``prompt``."""

In [None]:
class TextGenerator(Generator):
    """Causal-LM text generator built on a transformers ``text-generation`` pipeline."""

    def __init__(self, model_id="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", device="cuda",
                     batch_size=128, torch_dtype=torch.float16, max_new_tokens=128, temperature=0.2, do_sample=False):
        """
        Args:
            model_id: Checkpoint loaded through ``AutoModelForCausalLM``.
            device: Target device string, e.g. ``"cuda"`` or ``"cpu"``.
            batch_size: Batch size handed to the pipeline.
            torch_dtype: Weight dtype used when loading the model.
            max_new_tokens: Generation length limit.
            temperature: Sampling temperature; applied only when ``do_sample`` is true.
            do_sample: When true, sampling is switched on with ``temperature``.
                When false (default), no sampling arguments are passed and the
                model's own generation defaults apply.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        # Pipeline batching needs a pad token; tokenizers that ship without one
        # (e.g. GPT-2 style) would otherwise fail on list inputs.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map=None,
            trust_remote_code=True,
            torch_dtype=torch_dtype
        ).to(device)

        # Forward the sampling arguments only when sampling was requested, so the
        # default call behaves as before while `temperature`/`do_sample` are no
        # longer silently ignored.
        sampling_kwargs = {"do_sample": True, "temperature": temperature} if do_sample else {}

        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            device=0 if device.startswith("cuda") else -1,
            batch_size=batch_size,
            **sampling_kwargs,
        )

    @staticmethod
    def _clean(raw):
        """Extract the answer text from one pipeline output dict.

        Keeps what follows the last ``"Answer:"`` marker when one is present,
        otherwise the whole text; the result is whitespace-stripped.
        """
        txt = raw.get("generated_text", raw.get("text", ""))
        if "Answer:" in txt:
            return txt.split("Answer:")[-1].strip()
        return txt.strip()

    def generate(self, prompt):
        """Generate for one prompt (returns ``str``) or a list of prompts (returns ``list[str]``)."""
        outs = self.pipe(prompt)

        if isinstance(prompt, str):
            return self._clean(outs[0])

        # For list input each item may itself be a list of candidate dicts.
        return [
            self._clean(item[0] if isinstance(item, list) else item)
            for item in outs
        ]

# QASystem

In [None]:
class QASystem:
    """
    Three-stage pipeline
      1. retrieve_k  : 初筛候选数（交给 Retriever）        –– default 200
      2. rerank_k    : 精排后保留数（交给 Reranker）       –– default 20
      3. final_k     : 取前 k 条生成或返回（answer_top_k） –– default 5
    """
    def __init__(
        self,
        retriever: Retriever,
        text_gen= None,
        reranker= None,
        vis_gen=None,
        classifier=None,
        query_rewriter=None,
        prompt_builder=None,
        retrieve_k: int = 200,
        rerank_k: int = 20,
        answer_top_k: int = 5,
    ):
        self.retriever   = retriever
        self.reranker    = reranker
        self.text_gen    = text_gen
        self.vis_gen     = vis_gen
        self.classifier  = classifier
        self.query_rewriter = query_rewriter
        self.retrieve_k  = retrieve_k
        self.rerank_k    = rerank_k
        self.prompt_builder = prompt_builder
        self.answer_k    = answer_top_k

        self.last_cand_ids: list[int] = []

    # ---------- public ----------
    def answer(self, query: str | list[str], generate: bool = True):
        if isinstance(query, list):
            return [self._answer_single(q, generate) for q in query]
        return self._answer_single(query, generate)

    # ---------- private ----------
    def _answer_single(self, query: str, generate: bool):
        # ===  Query Rewrite（支持多路扩展） ===
        if self.query_rewriter is not None:
            rewrites = self.query_rewriter.rewrite(query)
            if isinstance(rewrites, str):
                rewrites = [rewrites]
        else:
            rewrites = [query]

        # 1) 检索 retrieve_k
        # === 2) 检索（多路查询合并 top-N） ===
        all_cand_ids = []
        for q in rewrites:
            cand = self.retriever.retrieve(q, self.retrieve_k)
            all_cand_ids.extend(cand["ids"])
        # 去重 & 保持顺序
        seen = set()
        all_cand_ids = [x for x in all_cand_ids if not (x in seen or seen.add(x))]
        cand_ids = all_cand_ids[:self.retrieve_k]
        self.last_cand_ids = cand_ids

        # 2) 可选精排
        if self.reranker:
            docs = self.retriever.get_content(cand_ids, field="text")
            top_pairs = self.reranker.rerank(
                query, docs, cand_ids, self.rerank_k
            )
            row_ids = [rid for rid, _ in top_pairs][: self.answer_k]
        else:
            row_ids = cand_ids[: self.answer_k]

        # 3) 只想评检索就直接返回 ids
        if not generate or self.text_gen is None:
            return row_ids

        # 4) 构造 prompt + 生成
        contexts = self.retriever.get_content(row_ids, "text")
        image_tags = None
        # 如果你支持多模态，可以从 get_content(row_ids, "img") 里取 image_tags（如文件名或 URL）

        if self.prompt_builder is not None:
            prompt = self.prompt_builder.build(query, contexts, image_tags)
        else:
            # 兼容老 prompt
            ctx = contexts[0]
            prompt = f"Question: {query}\nContext: {ctx}\nAnswer:"

        if (
            self.classifier
            and self.classifier.is_visual(query)
            and self.vis_gen
            and (img_path := self.retriever.get_content(row_ids[0], "img")[0])
        ):
            return self.vis_gen.generate(prompt, img_path)
        return self.text_gen.generate(prompt)


# Process

In [None]:
# BM25Retriever, TextGenerator and QASystem are defined in earlier cells of this
# notebook. The placeholder import is kept for setups that move them into a
# module, but it must not abort the cell when that module does not exist.
try:
    from your_module import TextGenerator, QASystem, BM25Retriever
except ImportError:
    pass  # fall back to the in-notebook definitions

# 1) Build the retriever (docs_df must already be loaded)
bm25 = BM25Retriever(corpus=docs_df, text_col="text")

# 2) Build the generator with a small model for a quick smoke test.
#    TextGenerator loads through AutoModelForCausalLM, so the checkpoint must be
#    a decoder-only model; "t5-small" is encoder-decoder and cannot be loaded
#    that way.
text_gen = TextGenerator(
    model_id="gpt2",
    device="cpu",                # or "cuda"
    torch_dtype=torch.float32,   # the float16 default is meant for GPU use
)

# 3) Assemble the QASystem without classifier, vision generator or rewriter
qa = QASystem(
    retriever=bm25,
    text_gen=text_gen,
    reranker=None,
    prompt_builder=None,    # use the built-in default prompt for now
    retrieve_k=10,
    rerank_k=5,
    answer_top_k=3
)

# 4) One end-to-end test
print( qa.answer("Who wrote Pride and Prejudice?", generate=True) )


# Evaluation

In [None]:
# 放在 Notebook 某个 cell 里，确保 `!pip install evaluate` 已经装好
import pandas as pd
import numpy as np
from tqdm import tqdm
import evaluate

class Evaluator:
    """Retrieval and generation evaluation harness around a QASystem."""

    def __init__(self, qa_system, eval_df: pd.DataFrame):
        """
        Args:
            qa_system: An instantiated QASystem.
            eval_df: DataFrame with 'question', 'answer' and 'gold_docs' columns.
        """
        self.qa = qa_system
        self.df = eval_df.reset_index(drop=True)

        # `evaluate`'s "f1" metric is classification F1 over integer labels and
        # rejects free-text answers, so answer F1 is computed by `_token_f1`.
        self.metric_em    = evaluate.load("exact_match")
        self.metric_bleu  = evaluate.load("bleu")
        self.metric_rouge = evaluate.load("rouge")

    @staticmethod
    def _token_f1(prediction: str, reference: str) -> float:
        """Token-overlap F1 between two strings (lowercased ``\\w+`` tokens).

        Returns 1.0 when both sides have no tokens and 0.0 when only one does.
        """
        import re
        from collections import Counter

        pred_toks = re.findall(r"\w+", prediction.lower())
        ref_toks = re.findall(r"\w+", reference.lower())
        if not pred_toks or not ref_toks:
            return float(pred_toks == ref_toks)
        overlap = sum((Counter(pred_toks) & Counter(ref_toks)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_toks)
        recall = overlap / len(ref_toks)
        return 2 * precision * recall / (precision + recall)

    def eval_retrieval(self, ks=(5,10,20)):
        """Return ``{"Recall@k": value}`` for every k in ``ks``.

        A question counts as a hit at k when at least one gold document is among
        its first k retrieved ids. Retrieval runs once per question. During that
        run ``answer_k`` and ``rerank_k`` on the QA system are raised to
        ``max(ks)`` (and restored afterwards), because the system would
        otherwise cut the result at ``answer_k`` and make larger k meaningless.
        """
        ks = tuple(ks)
        if not ks or len(self.df) == 0:
            return {f"Recall@{k}": 0.0 for k in ks}

        depth = max(ks)
        saved = {
            name: getattr(self.qa, name)
            for name in ("answer_k", "rerank_k")
            if hasattr(self.qa, name)
        }
        try:
            for name, value in saved.items():
                setattr(self.qa, name, max(value, depth))
            ranked = [
                list(self.qa.answer(q, generate=False)) for q in self.df["question"]
            ]
        finally:
            for name, value in saved.items():
                setattr(self.qa, name, value)

        results = {}
        for k in ks:
            hits = sum(
                1
                for pred, gold in zip(ranked, self.df["gold_docs"])
                if any(g in pred[:k] for g in gold)
            )
            results[f"Recall@{k}"] = hits / len(self.df)
        return results

    def eval_generation(self):
        """Generate an answer for every question and score all of them at once.

        Returns:
            Dict with "EM", "F1" (mean token-overlap F1), "BLEU" and "ROUGE" (rouge1).
        """
        preds, refs = [], []
        for q, gold in tqdm(zip(self.df["question"], self.df["answer"]),
                            total=len(self.df), desc="Gen Eval"):
            out = self.qa.answer(q, generate=True)
            # The QA system may hand back a list; score its first entry.
            if isinstance(out, (list, tuple)):
                out = out[0] if out else ""
            preds.append(str(out).strip())
            refs.append(str(gold).strip())

        em_res = self.metric_em.compute(predictions=preds, references=refs)
        f1_scores = [self._token_f1(p, r) for p, r in zip(preds, refs)]
        f1 = sum(f1_scores) / len(f1_scores) if f1_scores else 0.0

        # `evaluate`'s BLEU takes raw strings (it tokenises internally) with one
        # list of references per prediction. All-empty input is scored 0.0
        # rather than passed on, since a zero-length corpus makes BLEU undefined.
        if any(preds) and any(refs):
            bleu = self.metric_bleu.compute(
                predictions=preds,
                references=[[r] for r in refs]
            )["bleu"]
        else:
            bleu = 0.0
        rouge_res = self.metric_rouge.compute(predictions=preds, references=refs)

        return {
            "EM":    em_res["exact_match"],
            "F1":    f1,
            "BLEU":  bleu,
            "ROUGE": rouge_res["rouge1"]  # rouge2 / rougeL are also available
        }

    def grid_search(self, param_grid: dict, metric="F1"):
        """
        Exhaustive search over QA-system attributes.

        Example: ``param_grid = {'retrieve_k': [50, 100], 'rerank_k': [5, 10]}``.
        Every combination is set on the QA system with ``setattr`` and scored
        with ``eval_generation``; the QA system keeps the last combination tried.

        Returns:
            ``{"score": best_score, "params": best_params}``.
        """
        from itertools import product
        best = {"score": -1, "params": None}
        for vals in product(*param_grid.values()):
            params = dict(zip(param_grid.keys(), vals))
            # Apply this combination to the live QA system.
            for k, v in params.items():
                setattr(self.qa, k, v)

            gen_res = self.eval_generation()
            score = gen_res.get(metric, 0)
            print(f"Params={params} → {metric}={score:.4f}")

            if score > best["score"]:
                best = {"score": score, "params": params.copy()}

        print("Best:", best)
        return best

    def ablation(self):
        """
        Compare generation metrics for three setups: the QA system as configured,
        with a default QueryRewriter, and with a custom PromptBuilder.

        Each override is undone before the next run. Returns a DataFrame with one
        row per setup.
        """
        results = {}
        # baseline
        results["baseline"] = self.eval_generation()
        # with rewrite
        orig_rw = self.qa.query_rewriter
        self.qa.query_rewriter = QueryRewriter()
        results["with_rewrite"] = self.eval_generation()
        self.qa.query_rewriter = orig_rw
        # with custom prompt
        orig_pb = self.qa.prompt_builder
        self.qa.prompt_builder = PromptBuilder(
            system_prompt="You are an expert QA assistant.",
            few_shot=[{"q":"Q1","a":"A1"}, {"q":"Q2","a":"A2"}]
        )
        results["with_prompt"] = self.eval_generation()
        self.qa.prompt_builder = orig_pb

        return pd.DataFrame(results).T

# ===== 使用示例 =====
# eval_df = pd.read_parquet(EVAL_PATH)
# evaluator = Evaluator(qa, eval_df)
# print("Retrieval:", evaluator.eval_retrieval())
# print("Generation:", evaluator.eval_generation())
# best = evaluator.grid_search({'retrieve_k':[50,100], 'rerank_k':[5,10]}, metric="F1")
# print("Ablation:\n", evaluator.ablation())


In [None]:
# rag_env.py
import gym
from gym import spaces
import numpy as np

class RAGEnv(gym.Env):
    """
    Single-step environment for tuning RAG settings.

    Action: discrete index into the grid top_k x temperature x rerank on/off.
    State: 128-dimensional float vector (currently an all-zero placeholder).
    Reward: exact-match score of one question/answer round.
    """
    def __init__(self, qa_system, eval_df, topk_choices, temp_choices, rerank_choices):
        super().__init__()
        self.qa = qa_system
        self.data = eval_df.to_dict("records")          # list[dict]
        self.ptr = 0                                    # index of the current sample

        # Keep the configured reranker so a "rerank off" action can be undone by
        # a later "rerank on" action instead of dropping it for good.
        self._reranker = getattr(qa_system, "reranker", None)

        # === Discretised action space ===
        self.topk_choices  = topk_choices      # e.g. [10,20,50]
        self.temp_choices  = temp_choices      # e.g. [0.7,1.0,1.3]
        self.rerank_choices= rerank_choices    # e.g. [0,1] (off/on)

        self.actions = [(tk,tp,rr)
                        for tk in topk_choices
                        for tp in temp_choices
                        for rr in rerank_choices]
        self.action_space = spaces.Discrete(len(self.actions))

        # === Observation space (placeholder vector for now) ===
        self.observation_space = spaces.Box(-np.inf, np.inf, (128,), dtype=np.float32)

    def _get_state(self):
        """Return the observation; a zero vector stands in until real features
        (e.g. retrieval similarity, reward history, query length) are added."""
        return np.zeros(128, dtype=np.float32)

    def _set_temperature(self, temp):
        """Apply the sampling temperature to the text generator's pipeline.

        NOTE(review): transformers pipelines expose no public ``kwargs`` dict;
        the default generate arguments live in the private ``_forward_params``
        mapping — confirm against the installed transformers version. Sampling
        is switched on as well, since temperature is ignored under greedy decoding.
        """
        pipe = self.qa.text_gen.pipe
        params = getattr(pipe, "_forward_params", None)
        if params is None:
            params = {}
            pipe._forward_params = params
        params["temperature"] = temp
        params["do_sample"] = True

    def reset(self):
        """Pick a random sample for the next episode and return the observation."""
        self.ptr = np.random.randint(0, len(self.data))
        return self._get_state()

    def step(self, action_id):
        """Apply the chosen settings, answer the current question and score it.

        Returns:
            ``(next_state, reward, done, info)`` with ``done`` always True,
            because every episode is exactly one question.
        """
        query, answer = self.data[self.ptr]["question"], self.data[self.ptr]["answer"]
        # 1) Decode the action (the policy may pass a NumPy integer/array).
        top_k, temp, rerank_on = self.actions[int(action_id)]

        # 2) Push the settings into the QA system.
        self.qa.retrieve_k = top_k
        self._set_temperature(temp)
        self.qa.reranker   = self._reranker if rerank_on else None

        # 3) Run one QA round.
        pred = self.qa.answer(query, generate=True)
        if isinstance(pred, list):
            pred = pred[0]

        # 4) Reward: case-insensitive exact match (could be swapped for F1 or a blend).
        reward = 1.0 if pred.strip().lower() == answer.strip().lower() else 0.0

        # 5) One question per episode.
        done  = True
        info  = {"pred": pred, "gold": answer}
        next_state = self._get_state()
        return next_state, reward, done, info


In [None]:
# train_rl.py
from stable_baselines3 import PPO
from rag_env import RAGEnv

# Environment whose actions index the 3 x 3 x 2 grid of settings listed below.
env = RAGEnv(
    qa_system = qa,          # the QASystem instantiated earlier
    eval_df   = eval_df,     # evaluation-set DataFrame
    topk_choices  = [10,20,50],
    temp_choices  = [0.7,1.0,1.3],
    rerank_choices= [0,1]
)

# NOTE(review): this rebinds the notebook-global `model`, which an earlier cell
# uses for the SentenceTransformer embedder — re-running that cell's
# `model.encode(...)` after this one would hit the PPO object instead.
model = PPO(
    "MlpPolicy",
    env,
    learning_rate = 3e-4,
    n_steps       = 1024,
    batch_size    = 128,
    gamma         = 0.95,
    verbose       = 1
)

# Train, then save the policy under the name "ppo_rag".
model.learn(total_timesteps = 10_000)
model.save("ppo_rag")


In [None]:
# Rebuild the environment with the same action grid used for training: the saved
# policy emits discrete action ids that index into this grid. The former
# `RAGEnv(...)` placeholder passed a literal Ellipsis and raised TypeError.
env = RAGEnv(
    qa_system = qa,
    eval_df   = eval_df,
    topk_choices  = [10,20,50],
    temp_choices  = [0.7,1.0,1.3],
    rerank_choices= [0,1]
)

model = PPO.load("ppo_rag", env=env)
obs = env.reset()
rewards = []
# One single-step episode per evaluation row; reset() draws the question at random.
for _ in range(len(eval_df)):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    rewards.append(reward)
    if done:
        obs = env.reset()

print("平均 EM：", np.mean(rewards))
