In [1]:
# 基础库
import os
import time
import numpy as np
from scipy import stats
from typing import List, Dict, Any
from tqdm.notebook import tqdm
import pandas as pd

# Google Cloud 认证
from google.auth import load_credentials_from_file

# Langchain 相关
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import ParentDocumentRetriever
from langchain_core.documents import Document
from langchain_google_vertexai import VertexAI, VertexAIEmbeddings
from langchain_chroma import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.storage import InMemoryStore
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

# Ragas 相关
from ragas.metrics import LLMContextRecall, Faithfulness, FactualCorrectness, SemanticSimilarity
from ragas import evaluate
from ragas.llms import LangchainLLMWrapper
from ragas.embeddings import LangchainEmbeddingsWrapper
from ragas.run_config import RunConfig
from ragas import EvaluationDataset # type: ignore

# 类型提示
from typing import cast as t
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.messages import BaseMessage

# PDF处理
import fitz  # PyMuPDF

In [2]:
# 加载认证信息（换成自己的）
credentials, project_id = load_credentials_from_file(
    "./unified-sensor-437622-t3-1c0bfcf1fd30.json"
)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "./unified-sensor-437622-t3-1c0bfcf1fd30.json"

In [3]:
import os 
os.environ["OPENAI_API_KEY"] = "" # 换成自己的

In [4]:
class CustomHybridRetriever:
    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
    
    def invoke(self, query: str) -> List[Document]:
        """
        Retrieve documents using both retrievers and combine results
        
        Args:
            query: Search query string
            
        Returns:
            List of unique documents, limited to top 4 results
        """
        if not query or pd.isna(query):
            return []
            
        try:
            # 使用新的 invoke 方法
            vector_docs = self.vector_retriever.invoke(query)
            bm25_docs = self.bm25_retriever.invoke(query)
            
            # 合并结果并去重
            seen = set()
            unique_docs = []
            for doc in vector_docs + bm25_docs:
                if doc.page_content not in seen:
                    seen.add(doc.page_content)
                    unique_docs.append(doc)
            
            return unique_docs[:4]  # 返回前4个文档
        except Exception as e:
            print(f"Error in hybrid retrieval: {str(e)}")
            return []

In [5]:
def extract_text_from_pdf(pdf_path: str) -> List[Document]:
    """从PDF文件中提取文本"""
    """
    Extract text content from PDF files
    
    Args:
        pdf_path: Path to the PDF file
        
    Returns:
        List of Document objects containing page content and metadata
    """
    documents = []
    try:
        pdf_document = fitz.open(pdf_path)
        for page_num in range(len(pdf_document)):
            page = pdf_document[page_num]
            text = page.get_text("text")
            if text.strip():
                metadata = {
                    "source": pdf_path,
                    "page": page_num + 1,
                    "total_pages": len(pdf_document)
                }
                documents.append(Document(page_content=text, metadata=metadata))
        pdf_document.close()
    except Exception as e:
        print(f"Error processing {pdf_path}: {str(e)}")
    return documents

In [6]:
def load_pdfs_from_directory(directory_path: str) -> List[Document]:
    """从目录加载PDF文档"""
    """
    Load all PDF documents from a specified directory
    
    Args:
        directory_path: Path to directory containing PDF files
        
    Returns:
        List of Document objects from all PDFs in directory
    """
    documents = []
    with tqdm(os.listdir(directory_path), desc="Loading PDFs") as pbar:
        for filename in pbar:
            if filename.lower().endswith('.pdf'):
                file_path = os.path.join(directory_path, filename)
                pbar.set_postfix(file=filename)
                documents.extend(extract_text_from_pdf(file_path))
    return documents

In [7]:
def custom_is_finished_parser(response: LLMResult):
    """自定义完成状态解析器"""
    is_finished_list = []
    for g in response.flatten():
        resp = g.generations[0][0]
        if resp.generation_info is not None:
            if resp.generation_info.get("finish_reason") is not None:
                is_finished_list.append(
                    resp.generation_info.get("finish_reason") == "STOP"
                )
        elif (
            isinstance(resp, ChatGeneration)
            and t.cast(ChatGeneration, resp).message is not None
        ):
            resp_message: BaseMessage = t.cast(ChatGeneration, resp).message
            if resp_message.response_metadata.get("finish_reason") is not None:
                is_finished_list.append(
                    resp_message.response_metadata.get("finish_reason") == "STOP"
                )
        else:
            is_finished_list.append(True)
    return all(is_finished_list)

In [8]:
class RetrievalEvaluator:
    def __init__(self, data_dir: str, persist_dir: str):
        self.data_dir = data_dir
        self.persist_dir = persist_dir
        
        # 初始化 embeddings
        self.embeddings = VertexAIEmbeddings(
            model_name="textembedding-gecko",
            project=project_id,
            credentials=credentials,
            location="us-central1"
        )
        
        # 初始化LLMs
        self.eval_llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0.1
        )
        
        self.exp_llm = VertexAI(
            model_name="gemini-1.0-pro",
            temperature=0.1,
            max_output_tokens=8192,
            project=project_id,
            credentials=credentials
        )
        
        print("Loading documents...")
        self.documents = load_pdfs_from_directory(data_dir)
        print(f"Loaded {len(self.documents)} documents")
        
        print("Initializing retrievers...")
        self.retrievers = {
            "vector": self._init_vector_retriever(),
            "bm25": self._init_bm25_retriever(),
            "hybrid": self._init_hybrid_retriever(),
            "parent_doc": self._init_parent_doc_retriever()
        }
        print("Initialization complete")

    def _init_vector_retriever(self):
        """初始化向量检索器"""
        print("Initializing vector retriever...")
        vectorstore = Chroma.from_documents(
            documents=self.documents,
            embedding=self.embeddings,
            persist_directory=f"{self.persist_dir}/vector"
        )
        return vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 4}
        )

    def _init_bm25_retriever(self):
        """初始化BM25检索器"""
        print("Initializing BM25 retriever...")
        return BM25Retriever.from_documents(
            self.documents,
            k=4
        )

    def _init_hybrid_retriever(self):
        """初始化混合检索器"""
        print("Initializing hybrid retriever...")
        vector_retriever = self._init_vector_retriever()
        bm25_retriever = self._init_bm25_retriever()
        return CustomHybridRetriever(vector_retriever, bm25_retriever)

    def _init_parent_doc_retriever(self):
        """初始化父文档检索器"""
        print("Initializing parent document retriever...")
        
        child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000,
            chunk_overlap=400
        )
        
        vectorstore = Chroma.from_documents(
            documents=self.documents,
            embedding=self.embeddings,
            persist_directory=f"{self.persist_dir}/parent_doc"
        )
        
        store = InMemoryStore()
        
        retriever = ParentDocumentRetriever(
            vectorstore=vectorstore,
            docstore=store,
            child_splitter=child_splitter,
            parent_splitter=parent_splitter,
            search_kwargs={"k": 4}
        )
        
        retriever.add_documents(self.documents)
        return retriever

    def _generate_response(self, query: str, retrieved_docs: List[Document]) -> str:
        """生成回答"""
        context = "\n\n".join([doc.page_content for doc in retrieved_docs])
        prompt = f"Based on the following context, please answer the question:\n\nContext: {context}\n\nQuestion: {query}"
        
        try:
            response = self.exp_llm.invoke(prompt)
            return response.content if hasattr(response, 'content') else str(response)
        except Exception as e:
            print(f"Error generating response: {str(e)}")
            return ""

    def _single_evaluation_run(self, retriever, test_dataset: pd.DataFrame):
        """单次评估运行"""
        try:
            # 创建评估数据集
            valid_results = pd.DataFrame()
            valid_results['user_input'] = test_dataset['question']
            valid_results['reference'] = test_dataset['ground_truth']
            valid_results['response'] = ''
            valid_results['retrieved_contexts'] = [[]] * len(test_dataset)

            print(f"Processing {len(valid_results)} valid queries")
            
            # 处理每个查询
            for idx, row in valid_results.iterrows():
                try:
                    query = row['user_input']
                    if pd.isna(query):
                        continue
                        
                    # 检索文档
                    retrieved_docs = retriever.invoke(query)
                    if not retrieved_docs:
                        print(f"Warning: No documents retrieved for query: {query[:50]}...")
                        continue
                        
                    # 生成回答
                    response = self._generate_response(query, retrieved_docs)
                    
                    # 更新结果
                    valid_results.at[idx, 'response'] = response
                    valid_results.at[idx, 'retrieved_contexts'] = [doc.page_content for doc in retrieved_docs]
                    
                    time.sleep(1)  # 避免速率限制
                    
                except Exception as e:
                    print(f"Error processing query: {query[:50]}...")
                    print(f"Error: {str(e)}")
                    continue

            print(f"Successfully processed {len(valid_results)} queries")

            # 创建评估数据集
            eval_dataset = EvaluationDataset.from_pandas(valid_results)

            # 设置评估器
            evaluator_llm = LangchainLLMWrapper(self.eval_llm)
            evaluator_embeddings = LangchainEmbeddingsWrapper(OpenAIEmbeddings())

            # 定义评估指标
            metrics = [
                LLMContextRecall(llm=evaluator_llm),
                FactualCorrectness(llm=evaluator_llm),
                Faithfulness(llm=evaluator_llm),
                SemanticSimilarity(embeddings=evaluator_embeddings)
            ]

            # 运行评估
            results = evaluate(
                dataset=eval_dataset,
                metrics=metrics
            )

            # 直接返回评估结果
            return {
                "scores": results.scores,
                "results": results,
                "metadata": results.metadata if hasattr(results, 'metadata') else {}
            }

        except Exception as e:
            print(f"Error during evaluation: {str(e)}")
            import traceback
            traceback.print_exc()
            return None

    def evaluate_retriever(self, retriever_name: str, test_dataset: pd.DataFrame, n_runs: int = 1):
        """评估特定检索器的性能"""
        results = []
        retriever = self.retrievers[retriever_name]
        
        try:
            for run_idx in tqdm(range(n_runs), desc=f"Evaluating {retriever_name}", leave=True):
                run_result = self._single_evaluation_run(retriever, test_dataset.copy())
                if run_result:
                    results.append(run_result)
            
            if not results:
                return {
                    "mean_scores": None,
                    "confidence_intervals": None,
                    "raw_results": []
                }
                
            # 直接返回第一次运行的结果
            return results[0]
                
        except Exception as e:
            print(f"Error evaluating {retriever_name}: {str(e)}")
            return {
                "mean_scores": None,
                "confidence_intervals": None,
                "raw_results": [],
                "error": str(e)
            }

In [9]:
def run_evaluation(data_dir: str = "./data", persist_dir: str = "./chroma_db", n_samples: int = 4):
    """运行评估流程"""
    print("Starting evaluation process...")
    
    # 初始化评估器
    evaluator = RetrievalEvaluator(data_dir, persist_dir)
    
    # 加载测试数据集
    print("Loading test dataset...")
    test_dataset = pd.read_json('./data/eval_dataset_1.json')
    if n_samples:
        test_dataset = test_dataset.head(n_samples)
    
    print("\nDataset Info:")
    print(test_dataset.info())
    print(test_dataset)
    print(f"\nLoaded {len(test_dataset)} valid test cases")
    
    # 评估所有检索器
    results = {}
    retrievers = ["vector", "bm25", "hybrid", "parent_doc"]
    
    for retriever in tqdm(retrievers, desc="Evaluating retrievers"):
        print(f"\nEvaluating {retriever}...")
        result = evaluator.evaluate_retriever(retriever, test_dataset)
        results[retriever] = result
    
    # 打印结果
    print("\nEvaluation Results:")
    print("=" * 50)
    
    for retriever in retrievers:
        print(f"\nResults for {retriever}:")
        if results[retriever] and results[retriever].get("scores"):
            print(f"Mean scores: {results[retriever]['scores']}")
        else:
            print("Mean scores: None")
        print("-" * 30)
    
    return results

In [10]:
# 运行评估
results = run_evaluation(n_samples=1)

# 查看结果
for retriever_name, result in results.items():
    if result and result.get("results"):
        print(f"\nResults for {retriever_name}:")
        df = result["results"].to_pandas()
        print(df)

Starting evaluation process...
Loading documents...


Loading PDFs:   0%|          | 0/11 [00:00<?, ?it/s]

Loaded 290 documents
Initializing retrievers...
Initializing vector retriever...
Initializing BM25 retriever...
Initializing hybrid retriever...
Initializing vector retriever...
Initializing BM25 retriever...
Initializing parent document retriever...
Initialization complete
Loading test dataset...

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
Index: 1 entries, 0 to 0
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   question      1 non-null      object
 1   ground_truth  1 non-null      object
 2   answer        1 non-null      object
 3   contexts      1 non-null      object
dtypes: object(4)
memory usage: 40.0+ bytes
None
                                            question  \
0  How does instruction tuning affect the zero-sh...   

                                        ground_truth  \
0  For larger models on the order of 100B paramet...   

                                              answer  \
0  For l

Evaluating retrievers:   0%|          | 0/4 [00:00<?, ?it/s]


Evaluating vector...


Evaluating vector:   0%|          | 0/1 [00:00<?, ?it/s]

Processing 1 valid queries
Successfully processed 1 queries


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


Evaluating bm25...


Evaluating bm25:   0%|          | 0/1 [00:00<?, ?it/s]

Processing 1 valid queries
Successfully processed 1 queries


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


Evaluating hybrid...


Evaluating hybrid:   0%|          | 0/1 [00:00<?, ?it/s]

Processing 1 valid queries
Successfully processed 1 queries


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


Evaluating parent_doc...


Evaluating parent_doc:   0%|          | 0/1 [00:00<?, ?it/s]

Processing 1 valid queries
Successfully processed 1 queries


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]


Evaluation Results:

Results for vector:
Mean scores: [{'context_recall': 0.0, 'factual_correctness': 0.0, 'faithfulness': 0.5384615384615384, 'semantic_similarity': 0.8600790061438462}]
------------------------------

Results for bm25:
Mean scores: [{'context_recall': 1.0, 'factual_correctness': 0.22, 'faithfulness': 0.8947368421052632, 'semantic_similarity': 0.8423422425167596}]
------------------------------

Results for hybrid:
Mean scores: [{'context_recall': 1.0, 'factual_correctness': 0.0, 'faithfulness': 0.8666666666666667, 'semantic_similarity': 0.8634133691045071}]
------------------------------

Results for parent_doc:
Mean scores: [{'context_recall': 0.5, 'factual_correctness': 0.17, 'faithfulness': 0.5, 'semantic_similarity': 0.8379020210931905}]
------------------------------

Results for vector:
                                          user_input  \
0  How does instruction tuning affect the zero-sh...   

                                  retrieved_contexts  \
0  [Publ