## retrieval test

In [None]:
from langchain_community.vectorstores import Chroma
from langchain.retrievers import (
    BM25Retriever,
    EnsembleRetriever,
    ContextualCompressionRetriever,
)
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.evaluation import QAEvalChain
import pandas as pd
import numpy as np
from typing import List, Dict
import json
import time

class RetrievalExperiment:
    """Compare several retriever configurations over one Chroma collection.

    Each retriever is run against every test question and scored with a
    keyword-overlap heuristic: a retrieved document counts as relevant when
    it contains at least one whitespace-separated token of the reference
    answer (case-insensitive substring match).
    """

    def __init__(self, db: "Chroma", llm, embedding_model, test_questions: List[Dict]):
        """
        Args:
            db: Chroma vector store instance.
            llm: LLM used by the contextual-compression retriever.
            embedding_model: Embedding model (stored on the instance; not
                used by the methods of this class).
            test_questions: Test data shaped like
                [{"question": "...", "answer": "reference answer"}, ...].
        """
        self.db = db
        self.llm = llm
        self.embedding_model = embedding_model
        self.test_questions = test_questions
        self.results = []

    def setup_retrievers(self):
        """Build every retriever variant under test, keyed by display name."""
        # Build the BM25 index once and share it with the ensemble instead of
        # reading the whole collection twice.
        bm25_retriever = self._create_bm25_retriever()
        retrievers = {
            "Basic": self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3}
            ),
            "MMR": self.db.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 3, "fetch_k": 5}
            ),
            "BM25": bm25_retriever,
            "Ensemble": self._create_ensemble_retriever(bm25_retriever),
            "Contextual": self._create_contextual_retriever(),
        }
        return retrievers

    def _create_bm25_retriever(self, k: int = 3):
        """Build a BM25 retriever from the texts stored in the Chroma collection.

        ``Chroma.get()`` returns a dict of parallel lists ("documents",
        "metadatas", ...) rather than ``Document`` objects, so the retriever
        is built with ``from_texts`` from those lists.

        Raises:
            ValueError: if the collection contains no document text.
        """
        stored = self.db.get()
        texts = stored.get("documents") or []
        metadatas = stored.get("metadatas") or [None] * len(texts)
        # Skip empty entries and replace missing metadata with an empty dict.
        pairs = [(text, meta or {}) for text, meta in zip(texts, metadatas) if text]
        if not pairs:
            raise ValueError("Cannot build a BM25 retriever: the vector store has no documents.")
        corpus = [text for text, _ in pairs]
        corpus_metadatas = [meta for _, meta in pairs]
        return BM25Retriever.from_texts(corpus, metadatas=corpus_metadatas, k=k)

    def _create_ensemble_retriever(self, bm25_retriever=None):
        """Create an ensemble retriever combining BM25 and vector search.

        Args:
            bm25_retriever: Optional pre-built BM25 retriever to reuse; a new
                one is built when omitted.
        """
        if bm25_retriever is None:
            bm25_retriever = self._create_bm25_retriever()
        vector_retriever = self.db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 3}
        )

        return EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=[0.5, 0.5]
        )

    def _create_contextual_retriever(self):
        """Create a contextual-compression retriever on top of similarity search."""
        base_retriever = self.db.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5}
        )

        compressor = LLMChainExtractor.from_llm(self.llm)

        return ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever
        )

    def evaluate_retriever(self, retriever, retriever_name: str):
        """Score one retriever over all test questions.

        Args:
            retriever: Object exposing ``invoke(query)`` that returns documents
                with a ``page_content`` attribute.
            retriever_name: Label stored with the metrics.

        Returns:
            Dict with keys ``retriever_name``, ``relevant_docs`` (average
            relevant documents per question), ``precision`` (average share of
            retrieved documents that are relevant), ``recall`` (average share
            of answer keywords found in the retrieved documents) and
            ``response_time`` (total seconds spent inside the retriever).
            The dict is also appended to ``self.results``.
        """
        metrics = {
            "retriever_name": retriever_name,
            "relevant_docs": 0,
            "response_time": 0,
            "precision": 0,
            "recall": 0
        }

        total_questions = len(self.test_questions)
        retrieval_seconds = 0.0

        for test_case in self.test_questions:
            # Time only the retrieval call so scoring overhead does not
            # distort the latency comparison between retrievers.
            started = time.perf_counter()
            docs = retriever.invoke(test_case["question"])
            retrieval_seconds += time.perf_counter() - started

            keywords = set(test_case["answer"].lower().split())
            doc_texts = [doc.page_content.lower() for doc in docs]

            # A document is relevant when it contains any answer keyword.
            relevant_docs = sum(
                1 for text in doc_texts
                if any(keyword in text for keyword in keywords)
            )
            # Keywords covered by at least one retrieved document.
            matched_keywords = sum(
                1 for keyword in keywords
                if any(keyword in text for text in doc_texts)
            )

            metrics["relevant_docs"] += relevant_docs
            # Empty retrievals and empty answers score 0 instead of dividing by zero.
            if doc_texts:
                metrics["precision"] += relevant_docs / len(doc_texts)
            if keywords:
                metrics["recall"] += matched_keywords / len(keywords)

        # Average over the questions; with no questions the metrics stay at 0.
        if total_questions:
            metrics["relevant_docs"] /= total_questions
            metrics["precision"] /= total_questions
            metrics["recall"] /= total_questions
        metrics["response_time"] = retrieval_seconds

        self.results.append(metrics)
        return metrics

    def run_experiments(self):
        """Evaluate every configured retriever and print its metrics."""
        retrievers = self.setup_retrievers()

        for name, retriever in retrievers.items():
            print(f"\nTesting {name} retriever...")
            metrics = self.evaluate_retriever(retriever, name)
            print(f"Results for {name}:")
            print(f"Average relevant documents: {metrics['relevant_docs']:.2f}")
            print(f"Precision: {metrics['precision']:.2f}")
            print(f"Recall: {metrics['recall']:.2f}")
            print(f"Response time: {metrics['response_time']:.2f} seconds")

    @staticmethod
    def _min_max_normalize(series):
        """Scale a numeric Series to [0, 1]; a constant Series maps to all zeros."""
        low = series.min()
        span = series.max() - low
        if not span > 0:
            # Every retriever ties on this metric, so it contributes nothing.
            return series * 0.0
        return (series - low) / span

    def get_best_retriever(self) -> str:
        """Return the name of the retriever with the highest weighted score.

        Metrics are min-max normalised across the recorded results and
        combined with fixed weights; response time is inverted first so that
        faster retrievers score higher.

        Raises:
            ValueError: if no results have been recorded yet.
        """
        if not self.results:
            raise ValueError("No results recorded; run the experiments before ranking retrievers.")

        df = pd.DataFrame(self.results)

        # Weight assigned to each metric.
        weights = {
            'precision': 0.4,
            'recall': 0.3,
            'relevant_docs': 0.2,
            'response_time': 0.1
        }

        # Invert response time (smaller is better); clip to avoid dividing by zero.
        df['response_time_inv'] = 1 / df['response_time'].clip(lower=1e-9)

        # Normalisation.
        for col in ['precision', 'recall', 'relevant_docs', 'response_time_inv']:
            df[f'{col}_norm'] = self._min_max_normalize(df[col])

        # Total score.
        df['score'] = (
            weights['precision'] * df['precision_norm'] +
            weights['recall'] * df['recall_norm'] +
            weights['relevant_docs'] * df['relevant_docs_norm'] +
            weights['response_time'] * df['response_time_inv_norm']
        )

        best_retriever = df.loc[df['score'].idxmax(), 'retriever_name']
        return str(best_retriever)

    def save_results(self, filename: str = "retrieval_experiment_results.json"):
        """Write the recorded metrics to ``filename`` as UTF-8 JSON."""
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(self.results, f, ensure_ascii=False, indent=2)

# Usage example
if __name__ == "__main__":
    # Sample evaluation set: each entry pairs a question with its reference answer.
    test_questions = [
        dict(
            question="시각장애인 보조견 동반 출입 거부 시 처벌 규정이 있나요?",
            answer="장애인복지법에 따라 과태료 처분을 받을 수 있습니다.",
        ),
        # Append further test cases here.
    ]

    # NOTE(review): `db`, `llm` and `embedding_model` are not defined in this
    # section; presumably they are created earlier in the notebook - confirm.
    experiment = RetrievalExperiment(
        db=db,
        llm=llm,
        embedding_model=embedding_model,
        test_questions=test_questions,
    )
    experiment.run_experiments()

    # Report which retriever obtained the highest weighted score.
    best_retriever = experiment.get_best_retriever()
    print(f"\nBest performing retriever: {best_retriever}")

    # Persist the collected metrics to the default JSON file.
    experiment.save_results()