## RAG Fusion 核心代码实现

In [None]:
import re
from typing import List, Dict

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel

class RAGFusion:
    """Multi-query retrieval with Reciprocal Rank Fusion (RRF) and reranking.

    Pipeline implemented by :meth:`retrieve`:

    1. Ask the LLM for several rewrites of the user query.
    2. Run the retriever once per query.
    3. Merge the ranked lists with Reciprocal Rank Fusion.
    4. Fetch the full documents of the best fused hits.
    5. Rerank those documents by embedding cosine similarity to the
       original query.

    Collaborator contracts, as used by this class:

    * ``retriever.search(query, top_k=...)`` returns a ranked list of dicts,
      each carrying an ``'id'`` key.
    * ``retriever.get_document(doc_id)`` returns a dict carrying a
      ``'content'`` key.
    * ``llm_model.generate(prompt)`` returns the completion; it is assumed
      to be text with one query per line -- TODO confirm against the LLM
      wrapper actually used.
    """

    # Leading list markers ("2.", "3)", "4、", "-", "*", "•") that LLMs
    # typically put in front of each generated query.
    _LIST_PREFIX = re.compile(r"^\s*(?:\d+\s*[.)、．]|[-*•])\s*")

    def __init__(self, retriever, llm_model, num_queries=5):
        """Store the collaborators and load the embedding model.

        Args:
            retriever: Search backend (see the class docstring).
            llm_model: Text generator exposing ``generate(prompt)``.
            num_queries: How many extra queries to request from the LLM.
        """
        self.retriever = retriever
        self.llm_model = llm_model
        self.num_queries = num_queries
        # NOTE: both calls load "bert-base-uncased" weights/vocabulary, which
        # may require a download on first use.
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.rerank_model = AutoModel.from_pretrained("bert-base-uncased")

    def generate_queries(self, original_query: str) -> List[str]:
        """Generate several search queries related to ``original_query``.

        The original query is always the first element. It is followed by at
        most ``self.num_queries`` distinct queries parsed from the LLM
        response (duplicates, including repeats of the original, are dropped).

        Args:
            original_query: The user's question.

        Returns:
            The list of queries to run against the retriever.
        """
        # Built from explicit pieces so the text sent to the LLM does not
        # depend on how this method happens to be indented. The prompt ends
        # with "1. " so the model continues the numbered list directly.
        prompt = (
            "\n"
            f"      基于以下问题，生成{self.num_queries}个相关的搜索查询。\n"
            "      确保查询从不同角度探讨原始问题。\n"
            "      \n"
            f"      原始问题: {original_query}\n"
            "      \n"
            "      生成的查询:\n"
            "      1. "
        )

        response = self.llm_model.generate(prompt)

        queries = [original_query]  # always include the original query
        seen = {original_query}
        for candidate in self._parse_queries(response):
            if len(queries) > self.num_queries:
                break  # original + num_queries generated ones is the cap
            if candidate not in seen:
                seen.add(candidate)
                queries.append(candidate)
        return queries

    def _parse_queries(self, response) -> List[str]:
        """Split an LLM response into clean, non-empty query strings.

        One query per line is expected; list numbering / bullets are removed.
        A ``None`` response yields an empty list; any other non-string
        response is converted with ``str()`` first.
        """
        if response is None:
            return []
        text = response if isinstance(response, str) else str(response)
        stripped = (
            self._LIST_PREFIX.sub("", line).strip()
            for line in text.splitlines()
        )
        return [line for line in stripped if line]

    def reciprocal_rank_fusion(self, all_results: List[List[Dict]], k=60) -> List[Dict]:
        """Merge several ranked result lists with Reciprocal Rank Fusion.

        Every occurrence of a document at 0-based position ``rank`` adds
        ``1 / (k + rank + 1)`` to that document's fused score.

        Args:
            all_results: One ranked list per query; each item has an ``'id'``.
            k: Smoothing constant; larger values flatten the advantage of
                top positions.

        Returns:
            ``{'id': ..., 'score': ...}`` dicts sorted by descending score.
            Ties keep first-seen order because the sort is stable.
        """
        fused_scores = {}

        for results in all_results:
            for rank, doc in enumerate(results):
                doc_id = doc['id']
                fused_scores[doc_id] = fused_scores.get(doc_id, 0) + 1 / (rank + k + 1)

        return sorted(
            ({'id': doc_id, 'score': score} for doc_id, score in fused_scores.items()),
            key=lambda item: item['score'],
            reverse=True,
        )

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` as the mean of the model's last hidden states.

        Input is truncated to 512 tokens, the maximum sequence length of
        ``bert-base-uncased``.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        )
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = self.rerank_model(**inputs)
        # A single un-padded sequence: average over the token axis.
        return outputs.last_hidden_state[0].mean(dim=0).cpu().numpy()

    @staticmethod
    def _cosine_similarity(vec_a, vec_b) -> float:
        """Return the cosine similarity of two vectors (0.0 if either is zero)."""
        a = np.asarray(vec_a, dtype=float)
        b = np.asarray(vec_b, dtype=float)
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0  # avoid dividing by zero for degenerate embeddings
        return float(np.dot(a, b) / denom)

    def rerank_documents(self, query: str, documents: List[Dict]) -> List[Dict]:
        """Reorder ``documents`` by semantic similarity to ``query``.

        Args:
            query: The query to compare against.
            documents: Dicts that each carry a ``'content'`` text field.

        Returns:
            Shallow copies of the documents with an added ``'similarity'``
            key, sorted by descending similarity. The inputs are not mutated.

        Raises:
            KeyError: If a document has no ``'content'`` key.
        """
        if not documents:
            return []  # nothing to score; skip the model forward pass

        query_embedding = self._get_embedding(query)
        reranked_docs = [
            {
                **doc,
                'similarity': self._cosine_similarity(
                    query_embedding, self._get_embedding(doc['content'])
                ),
            }
            for doc in documents
        ]
        reranked_docs.sort(key=lambda d: d['similarity'], reverse=True)
        return reranked_docs

    def retrieve(self, query: str, *, top_k: int = 10, fused_k: int = 20,
                 final_k: int = 10) -> List[Dict]:
        """Run the full RAG Fusion retrieval pipeline for ``query``.

        Args:
            query: The user's question.
            top_k: Number of hits requested from the retriever per query.
            fused_k: Number of best fused hits whose documents are fetched.
            final_k: Number of reranked documents returned.

        Returns:
            Up to ``final_k`` documents, each with a ``'similarity'`` key,
            most similar first.
        """
        # 1. Expand the query into several related queries.
        queries = self.generate_queries(query)

        # 2. Retrieve for every query (one after another, not in parallel).
        all_results = [self.retriever.search(q, top_k=top_k) for q in queries]

        # 3. Fuse the per-query rankings.
        fused_results = self.reciprocal_rank_fusion(all_results)

        # 4. Fetch full documents for the best fused hits; skip ids the
        #    retriever cannot resolve instead of crashing in the reranker.
        final_docs = []
        for result in fused_results[:fused_k]:
            doc = self.retriever.get_document(result['id'])
            if doc is not None:
                final_docs.append(doc)

        # 5. Rerank against the ORIGINAL query and keep the best ones.
        return self.rerank_documents(query, final_docs)[:final_k]