# RAG Evaluation (Notebook) — Taiwan Industry Value Chain QA

This notebook mirrors the execution order / flow of `evaluate_rag_models.py`:

1. Environment + imports
2. Load YAML config
3. Build Knowledge Base from individual chain JSON files
4. Initialize embeddings + FAISS vector store
5. Initialize LLM provider
6. Retrieve + generate per sample
7. Parse response → compute metrics
8. Save results + print summary


In [None]:
# 1) Environment + Imports (mirrors script top section)
import os, re, json, time, asyncio, random, warnings
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime

import yaml
from dotenv import load_dotenv

warnings.filterwarnings("ignore", category=UserWarning)
load_dotenv()

# Make the uploaded libs importable (config.py, metrics.py, providers.py).
# NOTE(review): LIB_DIR is a relative path, so Python resolves it against the
# kernel's current working directory — this assumes the notebook is started
# from a folder that is a sibling of ../src; confirm before running elsewhere.
import sys
LIB_DIR = "../src/utils"
if LIB_DIR not in sys.path:
    # Prepend so these generically named modules take precedence over any
    # same-named modules found later on sys.path.
    sys.path.insert(0, LIB_DIR)

# Uploaded evaluation libs
from config import ModelConfig, RAGConfig as LibRAGConfig
from metrics import calculate_metrics, calculate_average_precision
from providers import get_provider, list_available_providers

print("✓ Imports OK")


In [None]:
# 2) Load evaluation_config.yaml (optional defaults / overrides)
# yaml_cfg stays an empty dict when the file is missing or empty, so every
# later `yaml_cfg.get(key, default)` simply falls back to the notebook default.

YAML_PATH = Path("../config/evaluation_config.yaml")
yaml_cfg = {}
if not YAML_PATH.exists():
    print(f"(i) YAML not found at {YAML_PATH}, using defaults in notebook")
else:
    with YAML_PATH.open("r", encoding="utf-8") as cfg_file:
        loaded_cfg = yaml.safe_load(cfg_file)
    yaml_cfg = loaded_cfg or {}
    print(f"✓ Loaded YAML: {YAML_PATH}")


In [None]:
# 3) Notebook parameters (replace argparse)
# You can edit these directly. Each value falls back to the notebook default
# shown here when the key is absent from the YAML config loaded above.

# ---- Dataset ----
DATASET_PATH = yaml_cfg.get("dataset", "../datasets/demo/qa/firm_chains_qa_local.jsonl")

# ---- LLM provider/model ----
PROVIDER = yaml_cfg.get("provider", "ollama")  # openai / anthropic / google / ollama
MODEL_NAME = yaml_cfg.get("model", "gemma3:27b-it-q8_0")

# ---- Generation params ----
TEMPERATURE = float(yaml_cfg.get("temperature", 0.0))
MAX_TOKENS = int(yaml_cfg.get("max_tokens", 500))
TIMEOUT = int(yaml_cfg.get("timeout", 30))  # presumably seconds — confirm in providers.py

# ---- RAG params ----
EMBEDDING_PROVIDER = yaml_cfg.get("embedding_provider", "ollama")  # openai / huggingface / ollama / google
EMBEDDING_MODEL = yaml_cfg.get("embedding_model", None)  # if None, use provider default
TOP_K = int(yaml_cfg.get("top_k", 5))
SCORE_THRESHOLD = float(yaml_cfg.get("score_threshold", 0.0))
DATA_DIR = yaml_cfg.get("data_dir", "../datasets/demo/individual_chains")

# ---- Sampling ----
# Coerced like the other numeric params so a quoted YAML value ("100") does not
# break `dataset[:max_samples]` later. None means "evaluate every sample".
_max_samples_cfg = yaml_cfg.get("max_samples", None)  # e.g. 100
MAX_SAMPLES = None if _max_samples_cfg is None else int(_max_samples_cfg)
SAMPLE_RATE = float(yaml_cfg.get("sample_rate", 1.0))

# ---- Save results ----
SAVE_RESULTS = bool(yaml_cfg.get("save_results", True))
RESULTS_DIR = Path(yaml_cfg.get("results_dir", "../results"))

# ---- Ollama reasoning switch (not in uploaded LibRAGConfig, but kept to mirror your script) ----
# For Ollama only: if False, we will try to reduce/disable thinking via provider extra_params.
ENABLE_REASONING = bool(yaml_cfg.get("enable_reasoning", False))

print("✓ Parameters set")


In [None]:
# 4) LangChain imports for RAG components (FAISS + embeddings)
try:
    from langchain_core.documents import Document
    from langchain_community.vectorstores import FAISS
    from langchain_openai import OpenAIEmbeddings
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_ollama import OllamaEmbeddings
    from langchain_google_genai import GoogleGenerativeAIEmbeddings
except ImportError as e:
    raise ImportError(
        "Missing packages for RAG. Install: "
        "pip install langchain langchain-community faiss-cpu langchain-openai "
        "langchain-huggingface langchain-ollama langchain-google-genai"
    ) from e

print("✓ LangChain RAG imports OK")


In [None]:
# 5) Knowledge Base (mirrors TaiwanValueChainKnowledgeBase in your script)

class TaiwanValueChainKnowledgeBase:
    """In-memory knowledge base built from Taiwan value-chain JSON files.

    Every ``*.json`` file in ``data_dir`` is turned into LangChain ``Document``
    objects of three kinds (see ``metadata["type"]``):

    * ``value_chain`` - one per file: the chain title plus its introduction
      (truncated to 1000 characters).
    * ``category`` - one per entry of the file's ``chains`` list, naming the
      companies found in that section.
    * ``company`` - one per (company, chain, category) occurrence.

    It also maintains the ``company_to_chains`` / ``chain_to_companies`` lookup
    maps. Unreadable files and malformed sections are reported and skipped.
    """

    def __init__(self, data_dir: str = "../datasets/demo/individual_chains"):
        """Load and index every chain file under *data_dir*.

        Raises:
            FileNotFoundError: if *data_dir* does not exist or has no ``*.json`` files.
        """
        self.data_dir = Path(data_dir)
        self.documents: List[Document] = []
        self.company_to_chains: Dict[str, set] = {}
        self.chain_to_companies: Dict[str, set] = {}
        self._load_data()

    def _load_data(self):
        """Parse every JSON file in ``self.data_dir`` and print load statistics."""
        print(f"Loading Taiwan value chain data from {self.data_dir}...")
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Data directory not found: {self.data_dir}")

        # Sorted so the document order (and hence the FAISS index built from it)
        # is reproducible instead of following the filesystem's listing order.
        json_files = sorted(self.data_dir.glob("*.json"))
        if not json_files:
            raise FileNotFoundError(f"No JSON files found in {self.data_dir}")

        print(f"Found {len(json_files)} value chain files")
        for json_file in json_files:
            try:
                with open(json_file, "r", encoding="utf-8") as f:
                    chain_data = json.load(f)
            except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
                print(f"Warning: Error loading {json_file}: {e}")
                continue
            # _process_value_chain reports and swallows its own errors.
            self._process_value_chain(chain_data)

        print(f"Loaded {len(self.documents)} documents")
        print(f"Company-Chain mappings: {len(self.company_to_chains)}")
        print(f"Chain-Company mappings: {len(self.chain_to_companies)}")

        sample_companies = list(self.company_to_chains.keys())[:10]
        print(f"Sample companies in KB: {sample_companies}")
        # Spot-check a few known names so load problems are visible early.
        test_companies = ["91APP*-KY", "91APP", "ACpay", "IET-KY", "台積電"]
        found_companies = [c for c in test_companies if c in self.company_to_chains]
        print(f"Test companies found: {found_companies}")
        if found_companies:
            first_company = found_companies[0]
            chains_for_company = self.company_to_chains[first_company]
            print(f"Example: {first_company} belongs to {len(chains_for_company)} chains")

    def _process_value_chain(self, chain_data: Dict[str, Any]):
        """Add the chain-level document and process each section of one chain file.

        Files without a ``title`` are skipped; errors are printed, not raised.
        """
        try:
            chain_name = chain_data.get("title", "")
            introduction = chain_data.get("introduction", "")
            if not chain_name:
                return

            main_doc_content = f"產業鏈名稱: {chain_name}\n"
            if introduction:
                intro_text = introduction[:1000] + "..." if len(introduction) > 1000 else introduction
                main_doc_content += f"介紹: {intro_text}\n"
            self.documents.append(
                Document(
                    page_content=main_doc_content,
                    metadata={"type": "value_chain", "chain_name": chain_name},
                )
            )

            # Register the chain even if none of its sections list a company.
            self.chain_to_companies.setdefault(chain_name, set())

            for chain_section in chain_data.get("chains", []):
                self._process_chain_section(chain_section, chain_name)
        except Exception as e:
            print(f"Warning: Error processing chain data: {e}")

    def _process_chain_section(self, chain_section: Dict[str, Any], chain_name: str):
        """Add one ``category`` document plus company documents for a chain section.

        A section's ``companies`` entries are either company dicts with a
        ``name`` key, or wrappers whose ``detailed_data["companies"]`` holds
        such dicts; other entries are ignored. Errors are printed, not raised.
        """
        try:
            section_title = chain_section.get("title", "")
            section_content = f"產業鏈: {chain_name}\n類別: {section_title}\n"
            companies_in_section = set()

            for company_item in chain_section.get("companies", []):
                if not isinstance(company_item, dict):
                    continue
                if "detailed_data" in company_item:
                    candidates = company_item["detailed_data"].get("companies", [])
                elif "name" in company_item:
                    candidates = [company_item]
                else:
                    continue

                for company_info in candidates:
                    if not isinstance(company_info, dict):
                        continue
                    company_name = company_info.get("name", "")
                    if not company_name:
                        continue
                    is_foreign = company_info.get("is_foreign", False)
                    companies_in_section.add(company_name)
                    self._add_company_mapping(company_name, chain_name, section_title, is_foreign)

            if companies_in_section:
                section_content += f"包含公司: {', '.join(sorted(companies_in_section))}\n"

            self.documents.append(
                Document(
                    page_content=section_content,
                    metadata={
                        "type": "category",
                        "chain_name": chain_name,
                        "category": section_title,
                        "company_count": len(companies_in_section),
                    },
                )
            )
        except Exception as e:
            print(f"Warning: Error processing chain section: {e}")

    def _add_company_mapping(self, company_name: str, chain_name: str, category: str, is_foreign: bool):
        """Record company<->chain membership and append one ``company`` document."""
        try:
            self.company_to_chains.setdefault(company_name, set()).add(chain_name)
            self.chain_to_companies.setdefault(chain_name, set()).add(company_name)

            comp_content = (
                f"公司名稱: {company_name}\n"
                f"產業鏈: {chain_name}\n"
                f"類別: {category}\n"
                f"外國公司: {'是' if is_foreign else '否'}\n"
            )
            self.documents.append(
                Document(
                    page_content=comp_content,
                    metadata={
                        "type": "company",
                        "company_name": company_name,
                        "chain_name": chain_name,
                        "category": category,
                        "is_foreign": is_foreign,
                    },
                )
            )
        except Exception as e:
            print(f"Warning: Error adding company mapping: {e}")

print("✓ KnowledgeBase class ready")

In [None]:
# 6) Embeddings initializer (mirrors your script)

def initialize_embeddings(embedding_provider: str, embedding_model: Optional[str] = None, base_url: Optional[str] = None):
    """Create the LangChain embeddings object for *embedding_provider*.

    Args:
        embedding_provider: one of ``"openai"``, ``"huggingface"``, ``"ollama"``, ``"google"``.
        embedding_model: model name; when falsy, a per-provider default is used.
        base_url: Ollama server URL, used only for ``"ollama"``. When not given,
            ``$OLLAMA_BASE_URL`` is tried, then ``http://localhost:11434``.

    Returns:
        ``(embeddings, model_name)`` - the embeddings object and the model name used.

    Raises:
        ValueError: if a required API key is missing or the provider is unsupported.
    """
    default_embedding_models = {
        "openai": "text-embedding-3-small",
        "huggingface": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "ollama": "qwen3-embedding:4b",
        "google": "models/gemini-embedding-001",
    }
    model = embedding_model or default_embedding_models.get(embedding_provider, default_embedding_models["huggingface"])

    if embedding_provider == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OpenAI API key required for embeddings (OPENAI_API_KEY)")
        return OpenAIEmbeddings(model=model, openai_api_key=api_key), model

    if embedding_provider == "huggingface":
        return HuggingFaceEmbeddings(model_name=model), model

    if embedding_provider == "ollama":
        # Previously hard-wired to localhost, silently ignoring `base_url` and
        # OLLAMA_BASE_URL; honour them first and keep localhost as the fallback.
        base = base_url or os.getenv("OLLAMA_BASE_URL") or "http://localhost:11434"
        return OllamaEmbeddings(model=model, base_url=base), model

    if embedding_provider == "google":
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Google API key required for embeddings (GOOGLE_API_KEY)")
        return GoogleGenerativeAIEmbeddings(model=model, google_api_key=api_key), model

    raise ValueError(f"Unsupported embedding provider: {embedding_provider}")

print("✓ Embeddings initializer ready")


In [None]:
# 7) RAG Provider wrapper: uses uploaded providers.py for LLM generation
#    but keeps the same retrieve-and-generate logic as your script.

class RAGModelProvider:
    """Retrieve-then-generate wrapper: FAISS retrieval + an LLM from providers.py.

    The LLM is created through ``get_provider(ModelConfig(...))``; retrieval
    uses a FAISS store built over ``knowledge_base.documents`` with the
    embeddings returned by ``initialize_embeddings``.
    """

    def __init__(
        self,
        provider: str,
        model_name: str,
        temperature: float,
        max_tokens: int,
        timeout: int,
        base_url: Optional[str],
        enable_reasoning: bool,
        knowledge_base: "TaiwanValueChainKnowledgeBase",
        embedding_provider: str,
        embedding_model: Optional[str],
        top_k: int,
        score_threshold: float,
        extra_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the LLM client, the embeddings and the FAISS vector store.

        Args:
            base_url: forwarded to both ``ModelConfig`` and ``initialize_embeddings``.
            enable_reasoning: for Ollama only; when False a ``think`` override is
                added to ``extra_params``.
            top_k: number of unique documents kept per question.
            score_threshold: stored and reported only - it is NOT applied as a
                retrieval filter in this notebook.

        Raises:
            ValueError: if the knowledge base has no documents.
        """
        self.provider_name = provider
        self.model_name = model_name
        self.knowledge_base = knowledge_base
        self.top_k = top_k
        self.score_threshold = score_threshold
        self.base_url = base_url
        # Kept so evaluation results can be labelled from the provider itself.
        self.embedding_provider = embedding_provider
        self.enable_reasoning = enable_reasoning

        # ---- Initialize LLM using uploaded providers.py ----
        llm_extra = dict(extra_params or {})
        if provider == "ollama" and not enable_reasoning:
            # Best-effort override of whatever providers.py sets for `think`.
            # "gpt" models get the level "low" (presumably gpt-oss, which takes a
            # level rather than a boolean); every other model gets False.
            llm_extra["think"] = "low" if "gpt" in (model_name or "").lower() else False

        model_cfg = ModelConfig(
            provider=provider,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            timeout=timeout,
            base_url=base_url,
            extra_params=llm_extra,
        )
        self.llm_provider = get_provider(model_cfg)

        # ---- Initialize embeddings + FAISS ----
        self.embeddings, self.embedding_model = initialize_embeddings(
            embedding_provider=embedding_provider,
            embedding_model=embedding_model,
            base_url=base_url,
        )

        print(f"Creating vector store with {len(self.knowledge_base.documents)} documents...")
        print(f"Using {embedding_provider} embeddings ({self.embedding_model})")
        if not self.knowledge_base.documents:
            raise ValueError("No documents in knowledge base")

        self.vector_store = FAISS.from_documents(self.knowledge_base.documents, self.embeddings)
        print("✓ Vector store created successfully")

    @staticmethod
    def _extract_search_terms(
        question: str,
        dataset_type: str,
        company: Optional[str] = None,
        chains: Optional[List[str]] = None,
    ) -> List[str]:
        """Return the retrieval queries for one question.

        * ``firm_chains_qa``: the token after "公司 " in the question.
        * ``chain_firms_qa``: the token after "產業鏈 ", with a doubled
          "產業鏈產業鏈" suffix collapsed to a single "產業鏈".
        * otherwise (competitors): *company* followed by *chains*.

        Falls back to the whole question when nothing can be extracted.
        """
        if dataset_type == "firm_chains_qa":
            company_match = re.search(r"公司\s+([\w\*\-\+\(\)]+)", question)
            return [company_match.group(1) if company_match else question]

        if dataset_type == "chain_firms_qa":
            chain_match = re.search(r"產業鏈\s+([\w\*\-\+\(\)產業鏈]+)", question)
            search_term = chain_match.group(1) if chain_match else question
            if search_term.endswith("產業鏈產業鏈"):
                search_term = search_term[:-3]  # drop one duplicated "產業鏈" (3 chars)
            return [search_term]

        search_terms: List[str] = []
        if company:
            search_terms.append(company)
        if chains:
            search_terms.extend(chains)
        return search_terms or [question]

    async def retrieve_and_generate(
        self,
        question: str,
        dataset_type: str,
        company: Optional[str] = None,
        chains: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Retrieve context for *question* and ask the LLM to answer from it.

        Returns a dict with ``question``, ``answer``, ``context``,
        ``retrieved_docs_count`` and ``retrieved_docs`` (content truncated to
        200 chars). On any failure the error is printed and a dict with
        ``answer=None`` and an ``error`` message is returned instead of raising.
        """
        try:
            print(f"Retrieving documents for: {question[:50]}...")
            search_terms = self._extract_search_terms(question, dataset_type, company, chains)

            scored: List[Tuple[Any, float]] = []
            for term in search_terms:
                try:
                    scored.extend(self.vector_store.similarity_search_with_score(term, k=self.top_k))
                except Exception:
                    # Best-effort fallback without scores.
                    docs = self.vector_store.similarity_search(term, k=self.top_k)
                    scored.extend([(d, 0.0) for d in docs])

            # Ascending sort treats the score as a distance (lower = closer).
            scored.sort(key=lambda x: x[1])
            # Several search terms can hit the same document; keep the first
            # (best-scored) copy and stop once top_k unique docs are collected.
            seen = set()
            retrieved_docs = []
            for doc, _score in scored:
                key = (doc.page_content, tuple(sorted(doc.metadata.items())))
                if key in seen:
                    continue
                seen.add(key)
                retrieved_docs.append(doc)
                if len(retrieved_docs) >= self.top_k:
                    break

            print(f"Retrieved {len(retrieved_docs)} documents")
            if len(retrieved_docs) == 0:
                print(f"No documents retrieved for terms: {search_terms}")

            context = "\n\n".join([f"[文件 {i+1}] {doc.page_content}" for i, doc in enumerate(retrieved_docs)])

            if dataset_type == "firm_chains_qa":
                system_instruction = """你是一位熟悉台灣產業鏈的專家。請根據提供的參考資料回答問題。

參考資料：
{context}

要求：
1. 只列出產業鏈的名稱，每個產業鏈一行
2. 不要包含編號、項目符號或其他格式
3. 產業鏈名稱應該精確，例如「半導體產業鏈」、「電動車產業鏈」等
4. 如果參考資料中沒有相關資訊，請回答「根據提供的資料無法確定」
5. 不要編造或猜測不在參考資料中的產業鏈"""

            elif dataset_type == "chain_firms_qa":
                system_instruction = """你是一位熟悉台灣產業鏈的專家。請根據提供的參考資料回答問題。

參考資料：
{context}

要求：
1. 只列出公司名稱，每個公司一行
2. 不要包含編號、項目符號或其他格式
3. 公司名稱應該精確，包含台灣本地公司和外國公司
4. 如果參考資料中沒有相關資訊，請回答「根據提供的資料無法確定」
5. 不要編造或猜測不在參考資料中的公司名稱"""

            else:
                system_instruction = """你是一位熟悉台灣產業鏈的專家。請根據提供的參考資料回答問題。

參考資料：
{context}

要求：
1. 只列出「競爭對手公司名稱」，每個公司一行
2. 不要包含編號、項目符號或其他格式
3. 公司名稱應該精確，包含台灣本地公司和外國公司
4. 若參考資料不足以判斷，請回答「根據提供的資料無法確定」
5. 不要編造或猜測不在參考資料中的公司名稱"""

            system_text = system_instruction.format(context=context)
            messages = [
                {"role": "system", "content": system_text},
                {"role": "user", "content": question},
            ]

            print(f"Sending prompt to {self.provider_name} {self.model_name}...")
            answer = await self.llm_provider.generate(messages)
            if answer is None:
                raise RuntimeError("LLM returned None")

            answer = answer.strip()
            print(f"Received response: {answer[:100]}...")

            return {
                "question": question,
                "answer": answer,
                "context": context,
                "retrieved_docs_count": len(retrieved_docs),
                "retrieved_docs": [
                    {
                        "content": (doc.page_content[:200] + "...") if len(doc.page_content) > 200 else doc.page_content,
                        "metadata": doc.metadata,
                    }
                    for doc in retrieved_docs
                ],
            }

        except Exception as e:
            import traceback
            print(f"ERROR in retrieve_and_generate: {e}")
            traceback.print_exc()
            return {
                "question": question,
                "answer": None,
                "error": str(e),
                "context": "",
                "retrieved_docs_count": 0,
                "retrieved_docs": [],
            }

print("✓ RAGModelProvider ready")


In [None]:
# 8) Evaluator (mirrors your script, but uses metrics.py calculate_metrics)

# Phrases signalling that the model declined to answer; any hit => empty prediction.
_UNCERTAINTY_RE = re.compile(
    "|".join([
        r"不確定", r"無法確定", r"不知道", r"沒有資料", r"無法回答",
        r"資訊不足", r"不清楚", r"根據提供的資料無法確定",
    ]),
    re.IGNORECASE,
)
# Leading list markers: bullet characters, or numbering such as "1.", "2)",
# "3、" or "4 ". Bare leading digits are deliberately NOT stripped, so names
# like "91APP*-KY", "3M" or "5G產業鏈" survive intact.
_LIST_MARKER_RE = re.compile(r"^(?:[.\-*•→·]|\d+(?:\s*[.)、:：]|\s+))+\s*")


class RAGEvaluator:
    """Run a RAG provider over a JSONL QA dataset and score its answers.

    The dataset flavour (``firm_chains_qa`` / ``chain_firms_qa`` /
    ``competitors_qa``) is detected from the first record's keys and decides
    how responses are parsed and which per-sample fields are kept.
    """

    def __init__(self, data_dir: str):
        """Build the knowledge base from the chain JSON files in *data_dir*."""
        self.knowledge_base = TaiwanValueChainKnowledgeBase(data_dir=data_dir)
        self.dataset_type: Optional[str] = None

    def detect_dataset_type(self, sample: Dict[str, Any]) -> str:
        """Infer the dataset flavour from the keys of one record.

        Raises:
            ValueError: if the record has neither ``company`` nor ``chain``.
        """
        if "company" in sample and "chains" in sample:
            return "competitors_qa"
        if "company" in sample:
            return "firm_chains_qa"
        if "chain" in sample:
            return "chain_firms_qa"
        raise ValueError("Unknown dataset format. Expected 'company' or 'chain' field in data.")

    def parse_response(self, response: Optional[str], dataset_type: str) -> List[str]:
        """Turn a free-text LLM answer into an ordered list of predicted items.

        Returns ``[]`` for empty answers or answers containing an "I can't
        tell" phrase. Otherwise each non-empty line is stripped of list markers
        and trailing CJK punctuation. For ``firm_chains_qa`` only lines
        containing "產業鏈" are kept; for other types any line with a CJK or
        Latin letter is kept.
        """
        if not response:
            return []
        if _UNCERTAINTY_RE.search(response):
            return []

        items = []
        for raw_line in response.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            line = _LIST_MARKER_RE.sub("", line)
            line = re.sub(r"^\s*[-·]\s*", "", line)
            line = re.sub(r"[，。、；：]$", "", line)
            if len(line) < 2:
                continue

            if dataset_type == "firm_chains_qa":
                if "產業鏈" in line:
                    items.append(line)
            elif re.search(r"[\u4e00-\u9fff]", line) or re.search(r"[A-Za-z]", line):
                items.append(line)
        return items

    @staticmethod
    def _load_jsonl(dataset_path: str) -> List[Dict[str, Any]]:
        """Read one JSON object per line, skipping blank lines.

        Raises:
            ValueError: if the file contains no records.
        """
        with open(dataset_path, "r", encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        if not records:
            raise ValueError(f"Dataset is empty: {dataset_path}")
        return records

    async def evaluate_dataset(
        self,
        dataset_path: str,
        rag_provider: "RAGModelProvider",
        max_samples: Optional[int] = None,
        sample_rate: float = 1.0,
        save_results: bool = True,
        results_dir: Path = Path("results"),
    ) -> Dict[str, Any]:
        """Evaluate *rag_provider* on a JSONL dataset and return the results dict.

        Samples are optionally subsampled (``sample_rate`` < 1, fixed seed 42)
        and then truncated to ``max_samples``. Per-sample recall / precision /
        F1 / exact match come from ``calculate_metrics`` and average precision
        from ``calculate_average_precision``; their means are reported under
        ``average_metrics``. When ``save_results`` is true, the full results
        are also written as JSON into ``results_dir``.
        """
        dataset_path = str(dataset_path)
        dataset = self._load_jsonl(dataset_path)

        self.dataset_type = self.detect_dataset_type(dataset[0])

        print("\n" + "=" * 70)
        print(f"Evaluating RAG with {rag_provider.provider_name.upper()} {rag_provider.model_name}")
        print(f"Dataset Type: {self.dataset_type}")
        print("=" * 70 + "\n")

        total_samples = len(dataset)
        if sample_rate < 1.0:
            # A private RNG seeded with 42 draws the same subset as
            # random.seed(42); random.sample(...) without resetting the global
            # random state for the rest of the notebook.
            dataset = random.Random(42).sample(dataset, int(total_samples * sample_rate))
        if max_samples:
            dataset = dataset[:max_samples]

        print(f"Dataset: {dataset_path}")
        print(f"Total samples: {total_samples}")
        print(f"Evaluating: {len(dataset)} samples")
        print(f"RAG Config: top_k={rag_provider.top_k}, threshold={rag_provider.score_threshold}")
        print("\nStarting RAG evaluation...\n")

        metrics_list = []
        ap_values: List[float] = []
        total_exact = 0
        # NOTE: "retrieval_errors" is never incremented here; it is kept so the
        # output schema stays the same.
        error_analysis = {"api_errors": 0, "empty_responses": 0, "retrieval_errors": 0}
        detailed_results = []

        start_time = time.time()

        for idx, item in enumerate(dataset, 1):
            question = item["question"]
            actual_answer = item["answer"]
            answer_count = item.get("answer_count", len(actual_answer))

            if self.dataset_type == "firm_chains_qa":
                entity = item["company"]
                print(f"[{idx}/{len(dataset)}] {entity} ({answer_count} chains)...", end=" ")
            elif self.dataset_type == "chain_firms_qa":
                entity = item["chain"]
                print(f"[{idx}/{len(dataset)}] {entity} ({answer_count} companies)...", end=" ")
            else:
                entity = item["company"]
                chains = item.get("chains", [])
                print(f"[{idx}/{len(dataset)}] {entity} ({answer_count} competitors, {len(chains)} chains)...", end=" ")

            if self.dataset_type == "competitors_qa":
                rag_result = await rag_provider.retrieve_and_generate(
                    question,
                    self.dataset_type,
                    company=item.get("company"),
                    chains=item.get("chains", []),
                )
            else:
                rag_result = await rag_provider.retrieve_and_generate(question, self.dataset_type)

            if rag_result.get("error"):
                print("❌ RAG Error")
                error_analysis["api_errors"] += 1
                response = None
                predicted_answer = []
            else:
                response = rag_result.get("answer")
                predicted_answer = self.parse_response(response, self.dataset_type)
                if not predicted_answer:
                    error_analysis["empty_responses"] += 1
                print(f"✓ ({len(predicted_answer)} predicted, {rag_result.get('retrieved_docs_count', 0)} docs)")

            m = calculate_metrics(predicted_answer, actual_answer)
            # AP uses the predicted list in the order the model produced it.
            ap = calculate_average_precision(predicted_answer, actual_answer)
            try:
                m.average_precision = ap
            except (AttributeError, TypeError, ValueError):
                # Immutable metrics object; AP is tracked separately below, so
                # mAP no longer silently falls back to 0.0 in that case.
                pass

            metrics_list.append(m)
            ap_values.append(ap)
            if m.exact_match == 1.0:
                total_exact += 1

            if hasattr(m, "to_dict"):
                metrics_dict = dict(m.to_dict())
            else:
                metrics_dict = {
                    "recall": m.recall,
                    "precision": m.precision,
                    "f1": m.f1,
                    "exact_match": m.exact_match,
                }
            metrics_dict["average_precision"] = ap

            result = {
                "index": idx,
                "entity": entity,
                "question": question,
                "actual_answer": actual_answer,
                "predicted_answer": predicted_answer,
                "response": response,
                "rag_context": rag_result.get("context", ""),
                "retrieved_docs_count": rag_result.get("retrieved_docs_count", 0),
                "retrieved_docs": rag_result.get("retrieved_docs", []),
                "metrics": metrics_dict,
                "dataset_type": self.dataset_type,
            }

            # keep dataset-specific fields
            if self.dataset_type == "firm_chains_qa":
                result["company"] = item.get("company")
                result["is_foreign"] = item.get("is_foreign", False)
            elif self.dataset_type == "chain_firms_qa":
                result["chain"] = item.get("chain")
                result["local_count"] = item.get("local_count", 0)
                result["foreign_count"] = item.get("foreign_count", 0)
            else:
                result["company"] = item.get("company")
                result["chains"] = item.get("chains", [])
                result["local_count"] = item.get("local_count", 0)
                result["foreign_count"] = item.get("foreign_count", 0)
                result["is_foreign"] = item.get("is_foreign", False)

            detailed_results.append(result)

            # Small pause between calls to hosted APIs; local Ollama is not throttled.
            if rag_provider.provider_name != "ollama":
                await asyncio.sleep(0.1)

        elapsed = time.time() - start_time
        n = len(dataset)
        mean_recall = sum(m.recall for m in metrics_list) / n if n else 0.0
        mean_precision = sum(m.precision for m in metrics_list) / n if n else 0.0
        mean_f1 = sum(m.f1 for m in metrics_list) / n if n else 0.0
        mean_ap = sum(ap_values) / n if n else 0.0

        avg_metrics = {
            "recall": mean_recall,
            "precision": mean_precision,
            "f1": mean_f1,
            "map": mean_ap,
            "exact_match_rate": (total_exact / n) if n else 0.0,
            "evaluated_samples": n,
            "total_samples": total_samples,
            "elapsed_time": elapsed,
            "avg_time_per_sample": (elapsed / n) if n else 0.0,
        }

        full_results = {
            "provider": rag_provider.provider_name,
            "model": rag_provider.model_name,
            # Prefer the provider's own setting; fall back to the notebook global.
            "embedding_provider": getattr(rag_provider, "embedding_provider", EMBEDDING_PROVIDER),
            "embedding_model": rag_provider.embedding_model,
            "rag_config": {"top_k": rag_provider.top_k, "score_threshold": rag_provider.score_threshold},
            "dataset": str(dataset_path),
            "dataset_type": self.dataset_type,
            "timestamp": datetime.now().isoformat(),
            "average_metrics": avg_metrics,
            "error_analysis": error_analysis,
            "detailed_results": detailed_results,
        }

        if save_results:
            out_dir = Path(results_dir)  # also accept a plain str
            out_dir.mkdir(parents=True, exist_ok=True)
            dataset_name = Path(dataset_path).stem
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_safe = rag_provider.model_name.replace("/", "_").replace(":", "_")
            out = out_dir / f"rag_evaluation_results_{dataset_name}_{rag_provider.provider_name}_{model_safe}_{timestamp}.json"
            with open(out, "w", encoding="utf-8") as f:
                json.dump(full_results, f, ensure_ascii=False, indent=2)
            print(f"\n✓ Detailed results saved to: {out}")

        self.print_summary(full_results)
        return full_results

    def print_summary(self, results: Dict[str, Any]):
        """Print aggregate metrics, error counts, and a few example predictions."""
        metrics = results["average_metrics"]
        errors = results["error_analysis"]
        rag_cfg = results["rag_config"]

        print("\n" + "=" * 70)
        print("RAG EVALUATION SUMMARY")
        print("=" * 70 + "\n")

        print(f"Provider: {results['provider']}")
        print(f"Model: {results['model']}")
        print(f"Embedding: {results['embedding_provider']} ({results['embedding_model']})")
        print(f"RAG Config: top_k={rag_cfg['top_k']}, threshold={rag_cfg['score_threshold']}")
        print(f"Dataset Type: {results['dataset_type']}")
        print(f"Evaluated: {metrics['evaluated_samples']} samples")
        print(f"Time: {metrics['elapsed_time']:.1f}s ({metrics['avg_time_per_sample']:.2f}s per sample)")

        print("\n" + "=" * 70)
        print("RAG PERFORMANCE METRICS")
        print("=" * 70)
        print(f"  Recall:            {metrics['recall']:.4f} ({metrics['recall']*100:.2f}%)")
        print(f"  Precision:         {metrics['precision']:.4f} ({metrics['precision']*100:.2f}%)")
        print(f"  F1 Score:          {metrics['f1']:.4f} ({metrics['f1']*100:.2f}%)")
        print(f"  mAP:               {metrics['map']:.4f} ({metrics['map']*100:.2f}%)")
        print(f"  Exact Match Rate:  {metrics['exact_match_rate']:.4f} ({metrics['exact_match_rate']*100:.2f}%)")

        print("\n" + "=" * 70)
        print("ERROR ANALYSIS")
        print("=" * 70)
        print(f"  API Errors:        {errors['api_errors']}")
        print(f"  Empty Responses:   {errors['empty_responses']}")
        print(f"  Retrieval Errors:  {errors['retrieval_errors']}")

        detailed = results.get("detailed_results", [])
        perfect = [r for r in detailed if r.get("metrics", {}).get("exact_match", 0.0) == 1.0]
        if perfect:
            print("\n" + "=" * 70)
            print(f"PERFECT PREDICTIONS (showing first 3 of {len(perfect)})")
            print("=" * 70)
            for r in perfect[:3]:
                entity_name = r.get("company", r.get("chain", r.get("entity")))
                print(f"\n  Entity: {entity_name}")
                print(f"  Predicted: {r.get('predicted_answer', [])}")
                print(f"  Retrieved docs: {r.get('retrieved_docs_count', 0)}")

        partial = [r for r in detailed if 0 < r.get("metrics", {}).get("recall", 0.0) < 1.0]
        if partial:
            print("\n" + "=" * 70)
            print(f"PARTIAL MATCHES (showing first 3 of {len(partial)})")
            print("=" * 70)
            for r in partial[:3]:
                entity_name = r.get("company", r.get("chain", r.get("entity")))
                print(f"\n  Entity: {entity_name}")
                print(f"  Actual:    {r.get('actual_answer', [])}")
                print(f"  Predicted: {r.get('predicted_answer', [])}")
                rec = r.get("metrics", {}).get("recall", 0.0)
                prec = r.get("metrics", {}).get("precision", 0.0)
                print(f"  Recall: {rec:.2f}, Precision: {prec:.2f}")
                print(f"  Retrieved docs: {r.get('retrieved_docs_count', 0)}")

        print("\n" + "=" * 70 + "\n")

print("✓ RAGEvaluator ready")


In [None]:
# 9) Run (mirrors __main__ execution)

async def run():
    """Run the full evaluation with the notebook-level parameters.

    Validates the inputs, builds the knowledge base (inside ``RAGEvaluator``)
    and the RAG provider, then runs ``evaluate_dataset``.

    Returns:
        The results dict produced by ``RAGEvaluator.evaluate_dataset``.

    Raises:
        FileNotFoundError: if ``DATASET_PATH`` does not exist.
        ValueError: if ``SAMPLE_RATE`` is not in (0, 1].
    """
    # Validate inputs before the (potentially slow) knowledge-base build.
    if not Path(DATASET_PATH).exists():
        raise FileNotFoundError(f"Dataset file not found: {DATASET_PATH}")
    if not (0.0 < SAMPLE_RATE <= 1.0):
        raise ValueError("Sample rate must be between 0.0 and 1.0")

    evaluator = RAGEvaluator(data_dir=DATA_DIR)

    # OLLAMA_BASE_URL describes an Ollama server only. Forwarding it as the LLM
    # base_url for openai / anthropic / google may redirect their requests to
    # that server, so it is passed only when the LLM provider is Ollama.
    llm_base_url = os.getenv("OLLAMA_BASE_URL") if PROVIDER == "ollama" else None

    # Initialize RAG provider
    rag_provider = RAGModelProvider(
        provider=PROVIDER,
        model_name=MODEL_NAME,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        timeout=TIMEOUT,
        base_url=llm_base_url,
        enable_reasoning=ENABLE_REASONING,
        knowledge_base=evaluator.knowledge_base,
        embedding_provider=EMBEDDING_PROVIDER,
        embedding_model=EMBEDDING_MODEL,
        top_k=TOP_K,
        score_threshold=SCORE_THRESHOLD,
        extra_params={},
    )

    results = await evaluator.evaluate_dataset(
        dataset_path=DATASET_PATH,
        rag_provider=rag_provider,
        max_samples=MAX_SAMPLES,
        sample_rate=SAMPLE_RATE,
        save_results=SAVE_RESULTS,
        results_dir=RESULTS_DIR,
    )

    print("✓ RAG evaluation completed successfully!")
    return results

In [None]:
# Jupyter: run with await
results = await run()
results["average_metrics"]