# Multi-Model Evaluation (Notebook) — Taiwan Industry Value Chain QA

This notebook mirrors the **execution order / flow** of the provided `evaluate_langchain_models.py` script, but runs interactively in Jupyter.

It uses the uploaded evaluation libraries:
- `src\utils\config.py`
- `src\utils\providers.py`
- `src\utils\metrics.py`

In [1]:
# 1) Environment + Imports (same as script top section)

import os, re, json, time, asyncio, random, warnings
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from collections import defaultdict
from datetime import datetime

import yaml
from dotenv import load_dotenv

warnings.filterwarnings("ignore", category=UserWarning)
load_dotenv()

# Make the uploaded libs importable as top-level modules (path is relative to the working directory)
import sys
sys.path.insert(0, "../src/utils")

from config import ModelConfig
from metrics import calculate_metrics, calculate_average_precision
from providers import get_provider

print("✓ Imports OK")


  from .autonotebook import tqdm as notebook_tqdm


✓ Imports OK


## 2) Load Evaluation Config (YAML) + Dataset Metadata (.meta.json)

The original script:
- loads `evaluation_config.yaml`
- loads `dataset.meta.json` (if present) to get chain names, prompts, uncertainty patterns


In [2]:
@dataclass
class EvaluationConfig:
    """Evaluation settings read from the YAML config file (kept notebook-local).

    Every section falls back to its default when the key is missing *or* is
    present with a null/empty value, so later ``.get(...)`` calls on these
    attributes never run into ``None``.
    """
    chain_names_file: str = "datasets/chain_names.json"
    prompts: Dict[str, Dict[str, str]] = field(default_factory=dict)
    uncertainty_patterns: List[str] = field(default_factory=list)
    providers: Dict[str, Dict] = field(default_factory=dict)
    defaults: Dict[str, Any] = field(default_factory=dict)
    rate_limiting: Dict[str, float] = field(default_factory=dict)
    output: Dict[str, str] = field(default_factory=dict)

    @classmethod
    def _from_mapping(cls, data: Optional[Dict[str, Any]]) -> "EvaluationConfig":
        """Build a config from an already-parsed YAML document.

        Args:
            data: The parsed document. ``None`` (empty file) and non-mapping
                documents are treated as "no settings".

        Returns:
            A config whose sections are never ``None``.
        """
        if not isinstance(data, dict):
            if data is not None:
                print(f"⚠ Config is not a mapping (got {type(data).__name__}); using defaults")
            data = {}
        # `or` (rather than a .get default) also covers keys that exist but
        # hold null, e.g. a bare `prompts:` line in the YAML.
        return cls(
            chain_names_file=data.get("chain_names_file") or cls.chain_names_file,
            prompts=data.get("prompts") or {},
            uncertainty_patterns=data.get("uncertainty_patterns") or [],
            providers=data.get("providers") or {},
            defaults=data.get("defaults") or {},
            rate_limiting=data.get("rate_limiting") or {},
            output=data.get("output") or {},
        )

    @classmethod
    def load(cls, config_path: Path) -> "EvaluationConfig":
        """Load the config at *config_path*.

        Args:
            config_path: Location of the YAML file.

        Returns:
            The parsed config, or an all-defaults config (after printing a
            warning) when the file does not exist.
        """
        if not config_path.exists():
            print(f"⚠ Config not found: {config_path} (using defaults)")
            return cls()
        with open(config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return cls._from_mapping(data)


@dataclass
class DatasetMetadata:
    """Contents of the ``.meta.json`` sidecar file that sits next to a dataset."""
    # Free-text descriptors copied from the sidecar; "" when absent.
    dataset_type: str = ""
    dataset_name: str = ""
    description: str = ""
    task: str = ""
    generated_at: str = ""
    source: str = ""
    # Structured sections; empty containers when absent.
    chain_names: List[str] = field(default_factory=list)
    default_prompt: Dict[str, str] = field(default_factory=dict)
    uncertainty_patterns: List[str] = field(default_factory=list)
    statistics: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def load(cls, dataset_path: Path) -> Optional["DatasetMetadata"]:
        """Read the sidecar that belongs to *dataset_path*.

        The sidecar path is the dataset path with its last suffix replaced by
        ``.meta.json``.

        Returns:
            The populated metadata, or ``None`` (after printing a warning)
            when no sidecar file exists.
        """
        sidecar = dataset_path.with_suffix(".meta.json")
        if sidecar.exists():
            with open(sidecar, "r", encoding="utf-8") as handle:
                raw = json.load(handle)
            text_keys = (
                "dataset_type", "dataset_name", "description",
                "task", "generated_at", "source",
            )
            values = {key: raw.get(key, "") for key in text_keys}
            values["chain_names"] = raw.get("chain_names", [])
            values["default_prompt"] = raw.get("default_prompt", {})
            values["uncertainty_patterns"] = raw.get("uncertainty_patterns", [])
            values["statistics"] = raw.get("statistics", {})
            return cls(**values)

        print(f"⚠ Dataset metadata not found: {sidecar}")
        return None


# Load YAML config
EVAL_CONFIG_PATH = Path("../config/evaluation_config.yaml")
eval_config = EvaluationConfig.load(EVAL_CONFIG_PATH)

# EvaluationConfig.load() falls back to defaults (and warns) when the file is
# missing, so only report a successful load when the file is really there.
if EVAL_CONFIG_PATH.exists():
    print("✓ Loaded eval config from:", EVAL_CONFIG_PATH)
else:
    print("✓ Using default eval config (no file at:", EVAL_CONFIG_PATH, ")")
print("  providers:", list(eval_config.providers.keys()) or "(empty)")
print("  defaults:", eval_config.defaults or "(empty)")


✓ Loaded eval config from: ..\config\evaluation_config.yaml
  providers: ['openai', 'anthropic', 'google', 'ollama']
  defaults: {'temperature': 0.0, 'max_tokens': 500, 'timeout': 120, 'max_retries': 3}


## 3) Parameters

Set these variables as you would pass CLI flags to the script:
- dataset path
- provider / model
- sampling / max tokens / temperature / timeout


In [3]:
# === USER PARAMETERS (edit these) ===
# These module-level names are read directly by the later cells.

DATASET_PATH = "../datasets/demo/qa/firm_chains_qa_local.jsonl"  # JSONL dataset to evaluate <- change me
PROVIDER = "ollama"                                     # openai / anthropic / google / ollama
MODEL_NAME = "qwen3:14b"                                # model identifier passed to the provider

MAX_SAMPLES = None          # cap applied after sampling, e.g. 50; None/0 = no cap
SAMPLE_RATE = 1.0           # 0.1 = 10% random sample; 1.0 = use every row
SAVE_RESULTS = True         # set False if you don't want files
NO_REASONING = True        # only used when PROVIDER == "ollama" (sets think=False)

# Generation params: taken from the YAML 'defaults' section, with the
# literals below as fallback when a key is missing (same as the script)
temperature = eval_config.defaults.get("temperature", 0.0)
max_tokens   = eval_config.defaults.get("max_tokens", 500)
timeout      = eval_config.defaults.get("timeout", 120)

print("✓ Parameters set")


✓ Parameters set


## 4) Dataset Type Detection + Prompt / Chain Loading (same order as script)

- load metadata to get chain list + prompts + uncertainty patterns
- fallback: load chain names from `chain_names_file` if meta is absent


In [4]:
def detect_dataset_type(sample: Dict[str, Any]) -> str:
    """Classify a dataset row by the keys it carries (same rules as the script).

    ``company`` + ``chains`` wins over ``company`` + ``is_foreign``, which in
    turn wins over a bare ``chain`` key.

    Raises:
        ValueError: If none of the known key combinations is present.
    """
    has = sample.__contains__

    if has("company"):
        if has("chains"):
            return "competitors_qa"
        if has("is_foreign"):
            return "firm_chains_qa"
    if has("chain"):
        return "chain_firms_qa"

    raise ValueError("Unknown dataset format.")


def load_chain_names_fallback(chain_names_file: str) -> List[str]:
    """Best-effort lookup of the chain-names JSON file from a notebook.

    Tries the path as given, then relative to the current working directory,
    then relative to its parent. The first existing file is read and its
    ``chain_names`` entry returned (``[]`` when that key is absent). When no
    candidate exists a warning is printed and ``[]`` is returned.
    """
    here = Path.cwd()
    candidates = [
        Path(chain_names_file),
        here / chain_names_file,
        here.parent / chain_names_file,
    ]

    found = next((path for path in candidates if path.exists()), None)
    if found is None:
        print("⚠ chain_names_file not found in common locations:", candidates)
        return []

    with open(found, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload.get("chain_names", [])


def load_prompt_and_assets(dataset_path: Path, dataset_type: str, eval_config: EvaluationConfig):
    """Return (chain_names, system_prompt_template, uncertainty_patterns, dataset_metadata).

    With a metadata sidecar: chain names come from the sidecar, while the
    system prompt and uncertainty patterns prefer the YAML config and fall
    back to the sidecar. Without one: chain names come from the configured
    chain-names file, the prompt from YAML only, and the uncertainty patterns
    from YAML or a built-in list.

    Raises:
        ValueError: If no system prompt could be resolved for *dataset_type*.
    """
    def yaml_system_prompt() -> Optional[str]:
        # A dataset-type entry that exists but is null is treated as empty.
        return (eval_config.prompts.get(dataset_type, {}) or {}).get("system")

    metadata = DatasetMetadata.load(dataset_path)

    if not metadata:
        chain_names: List[str] = load_chain_names_fallback(eval_config.chain_names_file)
        system_prompt_template: Optional[str] = yaml_system_prompt()
        uncertainty_patterns: List[str] = eval_config.uncertainty_patterns or [
            "不確定", "無法確定", "不知道", "沒有資料",
            "無法回答", "資訊不足", "不清楚"
        ]
    else:
        chain_names = metadata.chain_names or []
        # override order for both: YAML > metadata
        system_prompt_template = (
            yaml_system_prompt() or metadata.default_prompt.get("system") or None
        )
        uncertainty_patterns = (
            eval_config.uncertainty_patterns or metadata.uncertainty_patterns or []
        )

    if system_prompt_template:
        return chain_names, system_prompt_template, uncertainty_patterns, metadata

    raise ValueError(f"No system prompt configured for dataset type: {dataset_type}")


# Load dataset
dataset_path = Path(DATASET_PATH)
if not dataset_path.exists():
    raise FileNotFoundError(f"Dataset not found: {dataset_path}")

# JSONL: one JSON object per line. Blank lines are skipped so that stray
# empty lines do not crash json.loads.
with open(dataset_path, "r", encoding="utf-8") as f:
    dataset = [json.loads(line) for line in f if line.strip()]

# The type is detected from the first row, so an empty file gets a clear
# error here instead of a bare IndexError.
if not dataset:
    raise ValueError(f"Dataset is empty: {dataset_path}")

dataset_type = detect_dataset_type(dataset[0])
chain_names, system_prompt_template, uncertainty_patterns, dataset_metadata = load_prompt_and_assets(
    dataset_path, dataset_type, eval_config
)

print("✓ Dataset type:", dataset_type)
print("✓ Chain names:", len(chain_names))
print("✓ Have system prompt:", bool(system_prompt_template))
print("✓ Uncertainty patterns:", len(uncertainty_patterns))


✓ Dataset type: firm_chains_qa
✓ Chain names: 47
✓ Have system prompt: True
✓ Uncertainty patterns: 7


## 5) Message Construction + Response Parsing 

- `create_messages(question, dataset_type)`
- `parse_response(response, dataset_type)` with JSON-first for chain_firms/competitors


In [5]:
def create_messages(question: str, dataset_type: str) -> List[Dict[str, str]]:
    """Build the two-message chat payload (system + user) for one question.

    For ``firm_chains_qa`` the module-level prompt template is formatted with
    ``chain_list`` set to a bulleted list of the module-level ``chain_names``;
    for every other dataset type the template is used verbatim.
    """
    system_prompt = system_prompt_template
    if dataset_type == "firm_chains_qa":
        bullets = [f"- {name}" for name in chain_names]
        system_prompt = system_prompt_template.format(chain_list="\n".join(bullets))

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": question}
    return [system_message, user_message]


def _clean_json_items(values: List[Any]) -> List[str]:
    """Keep string entries only, trimmed of whitespace, quotes and one trailing CJK punctuation mark."""
    cleaned: List[str] = []
    for value in values:
        if not isinstance(value, str):
            continue
        value = value.strip().strip(' \t\r\n"\'')
        value = re.sub(r"[，。、；：]$", "", value)
        # single-character leftovers are treated as noise
        if len(value) >= 2:
            cleaned.append(value)
    return cleaned


def _extract_json_items(response: str) -> Optional[List[str]]:
    """Try to read the answer list from a JSON object/array embedded in *response*.

    Returns:
        The cleaned items when a non-empty list was found in the JSON, or
        ``None`` when the caller should fall back to line-based parsing
        (no parseable JSON, or no non-empty list inside it).
    """
    raw = response.strip()
    # drop a surrounding markdown code fence, if any
    raw = re.sub(r"^```(?:json)?\s*", "", raw, flags=re.IGNORECASE)
    raw = re.sub(r"\s*```$", "", raw)

    # Narrow to the span from the first opening bracket to the last closing
    # one, so prose before/after the JSON does not break parsing.
    candidate = raw
    starts = [i for i in (raw.find("{"), raw.find("[")) if i != -1]
    if starts:
        start_idx = min(starts)
        end_idx = max(raw.rfind("}"), raw.rfind("]"))
        if end_idx > start_idx:
            candidate = raw[start_idx:end_idx + 1]

    try:
        obj = json.loads(candidate)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Either way: not usable JSON.
        return None

    extracted: List[Any] = []
    if isinstance(obj, dict):
        # well-known keys first ...
        for key in ("companies", "company", "competitors", "answer", "predicted", "predicted_answer"):
            if isinstance(obj.get(key), list):
                extracted = obj[key]
                break
        # ... otherwise the first list-valued entry of any name
        if not extracted:
            extracted = next((v for v in obj.values() if isinstance(v, list)), extracted)
    elif isinstance(obj, list):
        extracted = obj

    if not extracted:
        return None
    return _clean_json_items(extracted)


def _extract_line_items(response: str, dataset_type: str) -> List[str]:
    """Parse *response* as one answer per line, stripping list markers and trailing punctuation."""
    items: List[str] = []
    for line in response.split("\n"):
        line = line.strip()
        if not line:
            continue
        # NOTE(review): this also strips leading digits that are part of a
        # name (e.g. "91APP*-KY" becomes "APP*-KY"); left unchanged here to
        # stay aligned with the reference script — confirm whether intended.
        line = re.sub(r"^[\d\.\-\*\•\→]+\s*", "", line)
        line = re.sub(r"^\s*[-·]\s*", "", line)
        line = re.sub(r"[，。、；：]$", "", line)
        if len(line) < 2:
            continue

        if dataset_type == "firm_chains_qa":
            # only lines that mention "產業鏈" count as chain answers
            if "產業鏈" in line:
                items.append(line)
        elif re.search(r"[\u4e00-\u9fff]", line) or re.search(r"[A-Za-z]", line):
            # must contain at least one CJK ideograph or Latin letter
            items.append(line)
    return items


def parse_response(response: Optional[str], dataset_type: str) -> List[str]:
    """Turn a raw model response into a list of predicted answers.

    Order of checks:
      1. Empty/``None`` response -> ``[]``.
      2. Any module-level uncertainty pattern (regex, case-insensitive)
         found anywhere in the response -> ``[]``.
      3. For ``chain_firms_qa`` / ``competitors_qa``: JSON-first extraction.
      4. Line-based fallback.

    Args:
        response: Raw text returned by the model, or ``None``.
        dataset_type: One of the dataset type names used in this notebook.

    Returns:
        The predicted items, possibly empty.
    """
    if not response:
        return []

    # Uncertainty -> empty
    for pat in uncertainty_patterns or []:
        if re.search(pat, response, re.IGNORECASE):
            return []

    # JSON-first for chain_firms/competitors
    if dataset_type in ("chain_firms_qa", "competitors_qa"):
        json_items = _extract_json_items(response)
        if json_items is not None:
            return json_items

    return _extract_line_items(response, dataset_type)


print("✓ Helpers ready")


✓ Helpers ready


## 6) Provider Init + Query (async) with Retry + Rate Limiting

This follows the script's approach:
- initialize provider from (provider, model)
- per item: create messages -> provider.generate -> parse_response
- API delay for non-ollama providers


In [6]:
# Build ModelConfig (library version) + init provider
extra_params = {}

# Optional: mimic your script's 'no reasoning' behavior for Ollama.
# The uploaded providers.py always forces think low/False; we keep a toggle here.
if PROVIDER == "ollama" and NO_REASONING:
    extra_params["think"] = False

model_cfg = ModelConfig(
    provider=PROVIDER,
    model_name=MODEL_NAME,
    temperature=temperature,
    max_tokens=max_tokens,
    timeout=timeout,
    api_key=None,       # set here if you don't use env vars
    base_url="http://192.168.1.35:11434/",      # set here if needed for Ollama
    extra_params=extra_params,
)

provider = get_provider(model_cfg)
print("✓ Provider initialized:", type(provider).__name__)


async def query_model(question: str, dataset_type: str, max_retries: int = 3) -> Optional[str]:
    """Ask the module-level provider one question, retrying on failure.

    An attempt that raises is retried after an exponential back-off
    (1s, 2s, 4s, ...); an attempt that returns an empty result is retried
    immediately. No back-off is slept after the final attempt, since there is
    nothing left to wait for.

    Args:
        question: The user question.
        dataset_type: Dataset type, forwarded to ``create_messages``.
        max_retries: Total number of attempts.

    Returns:
        The first non-empty response, or ``None`` when every attempt failed or
        came back empty. The last exception, if any, is printed.
    """
    messages = create_messages(question, dataset_type)
    last_err = None
    for attempt in range(max_retries):
        try:
            out = await provider.generate(messages)
        except Exception as e:  # provider failures are retried, never fatal here
            last_err = e
            if attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)
            continue
        if out:
            return out
    if last_err:
        print("  ⚠ query_model failed:", last_err)
    return None


# Pause between requests for hosted providers. Only a missing/blank setting
# falls back to 0.1s; an explicit 0 is respected (the previous `or 0.1`
# silently turned 0 into 0.1).
_api_delay_setting = eval_config.rate_limiting.get("api_delay")
api_delay = 0.1 if _api_delay_setting in (None, "") else float(_api_delay_setting)
print("✓ api_delay:", api_delay)


✓ Provider initialized: OllamaProvider
✓ api_delay: 0.1


## 7) Evaluation Loop + Save Results + Summary Print


In [7]:
def save_results_json(full_results: Dict[str, Any], dataset_path: Path):
    """Write *full_results* as pretty-printed UTF-8 JSON and return the file path.

    The file goes to ``<cwd>/<results_dir>`` (``results_dir`` and
    ``timestamp_format`` come from the ``output`` section of the eval config,
    defaulting to ``results`` and ``%Y%m%d_%H%M%S``). The file name combines
    the dataset stem, provider, model name and a timestamp.

    Args:
        full_results: The dictionary produced by the evaluation loop.
        dataset_path: Dataset file; only its stem is used for naming.

    Returns:
        Path of the written file.
    """
    out_cfg = eval_config.output or {}
    results_dir = out_cfg.get("results_dir", "results")
    timestamp_format = out_cfg.get("timestamp_format", "%Y%m%d_%H%M%S")

    ts = datetime.now().strftime(timestamp_format)
    dataset_name = dataset_path.stem
    # "/" and ":" appear in model ids but are not safe in file names
    model_safe = MODEL_NAME.replace("/", "_").replace(":", "_")

    out_dir = Path.cwd() / results_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"evaluation_results_{dataset_name}_{PROVIDER}_{model_safe}_{ts}.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(full_results, f, ensure_ascii=False, indent=2)

    # Report the real location (configured directory + file name) rather than
    # a hand-written path that could drift from where the file was written.
    print("✓ Saved:", Path(results_dir) / out_file.name)
    return out_file


def print_summary(full_results: Dict[str, Any]):
    """Print a fixed-width text report for one evaluation run.

    Shows provider/model/dataset type, sample count and timing, the averaged
    metrics, a per-category F1 breakdown (by ``is_foreign`` for
    ``firm_chains_qa``, by composition bucket otherwise) and the error counters.
    """
    avg = full_results["average_metrics"]
    categories = full_results["by_category"]
    error_counts = full_results["error_analysis"]
    kind = full_results["dataset_type"]

    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule)
    print("EVALUATION SUMMARY")
    print(heavy_rule)
    for label, key in (("Provider:", "provider"), ("Model:", "model")):
        print(label, full_results[key])
    print("Dataset type:", kind)
    print(f"Evaluated: {avg['evaluated_samples']} samples")
    print(f"Time: {avg['elapsed_time']:.1f}s ({avg['avg_time_per_sample']:.2f}s/sample)")
    print(light_rule)
    print(f"Recall: {avg['recall']:.4f}  Precision: {avg['precision']:.4f}  F1: {avg['f1']:.4f}")
    print(f"mAP:   {avg['mAP']:.4f}  Exact:     {avg['exact_match_rate']:.4f}")
    print(light_rule)

    if kind != "firm_chains_qa":
        print("BY COMPOSITION")
        for bucket in ("high_local", "mixed", "high_foreign"):
            stats = categories[bucket]
            print(f"  {bucket:12s} n={stats['count']}  F1={stats['f1']:.4f}")
    else:
        print("BY COMPANY TYPE")
        rows = (
            ("  Local  (is_foreign=False) n=", False),
            ("  Foreign(is_foreign=True)  n=", True),
        )
        for prefix, flag in rows:
            stats = categories[flag]
            print(prefix, stats["count"], "F1=", f"{stats['f1']:.4f}")

    print(light_rule)
    print("ERRORS:", error_counts)
    print(heavy_rule)


def _new_metric_bucket() -> Dict[str, Any]:
    """Return a zeroed accumulator for count / recall / precision / f1."""
    return {"count": 0, "recall": 0.0, "precision": 0.0, "f1": 0.0}


def _add_to_bucket(bucket: Dict[str, Any], m: Dict[str, Any]) -> None:
    """Add one sample's recall / precision / f1 to *bucket* and bump its count."""
    bucket["count"] += 1
    for key in ("recall", "precision", "f1"):
        bucket[key] += m[key]


def _average_bucket(bucket: Dict[str, Any]) -> None:
    """Turn the sums in *bucket* into means, in place; empty buckets stay at zero."""
    if bucket["count"]:
        for key in ("recall", "precision", "f1"):
            bucket[key] /= bucket["count"]


def _composition_category(local_count: int, foreign_count: int) -> str:
    """Bucket an item by its share of local entries: >0.7 high_local, <0.3 high_foreign, else mixed."""
    total = local_count + foreign_count
    ratio = (local_count / total) if total else 0.0
    if ratio > 0.7:
        return "high_local"
    if ratio < 0.3:
        return "high_foreign"
    return "mixed"


async def evaluate_dataset_notebook(dataset: List[Dict[str, Any]]):
    """Run the model over *dataset* and collect per-sample and aggregate metrics.

    Reads the notebook-level settings (``SAMPLE_RATE``, ``MAX_SAMPLES``,
    ``PROVIDER``, ``MODEL_NAME``, ``temperature``, ``dataset_type``,
    ``dataset_path``, ``api_delay``) and prints progress for every sample.

    Sampling uses a private ``random.Random(42)``, which draws the same sample
    as seeding the global generator with 42 but leaves the global generator
    untouched.

    Args:
        dataset: Rows with at least ``question`` and ``answer`` plus the
            dataset-type specific fields.

    Returns:
        A dictionary with run info, ``average_metrics``, ``by_answer_count``,
        ``by_category``, ``error_analysis`` and ``detailed_results``.
        ``average_metrics["evaluated_samples"]`` is the true number of
        evaluated rows (0 when the sample is empty).
    """
    total_samples = len(dataset)

    # sampling (same order as script): random subset first, then the hard cap
    sampled = dataset
    if SAMPLE_RATE < 1.0:
        rng = random.Random(42)
        sampled = rng.sample(sampled, int(total_samples * SAMPLE_RATE))
    if MAX_SAMPLES:
        sampled = sampled[:MAX_SAMPLES]
    evaluated = len(sampled)

    # ==============================
    # PRINT EVALUATION HEADER
    # ==============================

    print("\n" + "=" * 70)

    if dataset_type == "firm_chains_qa":
        dataset_label = "Firm→Chains Dataset"
    elif dataset_type == "chain_firms_qa":
        dataset_label = "Chain→Firms Dataset"
    else:
        dataset_label = "Competitors Dataset"

    print(f"Evaluating {PROVIDER.upper()} {MODEL_NAME} on {dataset_label}")
    print("=" * 70 + "\n")

    print(f"Dataset: {dataset_path}")
    print(f"Dataset Type: {dataset_type}")
    print(f"Provider: {PROVIDER}")
    print(f"Model: {MODEL_NAME}")
    print(f"Temperature: {temperature}")
    print(f"Total samples: {total_samples}")
    print(f"Evaluating: {evaluated} samples")
    print("\nStarting evaluation...\n")

    # ==============================

    results = []
    error_analysis = {"api_errors": 0, "empty_responses": 0, "parse_errors": 0}

    overall = _new_metric_bucket()
    total_ap = 0.0
    total_exact_match = 0

    by_answer_count = defaultdict(_new_metric_bucket)

    if dataset_type == "firm_chains_qa":
        by_category = {True: _new_metric_bucket(), False: _new_metric_bucket()}
    else:
        by_category = {
            name: _new_metric_bucket()
            for name in ("high_local", "mixed", "high_foreign")
        }

    start = time.time()

    for idx, item in enumerate(sampled, 1):
        question = item["question"]
        actual = item["answer"]
        answer_count = item.get("answer_count", len(actual))
        progress = f"[{idx}/{evaluated}]"

        # category key logic (same as script); `extra` holds the
        # dataset-specific fields stored alongside each result
        if dataset_type == "firm_chains_qa":
            entity = item["company"]
            category_key = item["is_foreign"]
            noun = "chains"
            extra = {"company": entity, "is_foreign": item["is_foreign"]}
            print(f"{progress} {entity} ({answer_count} chains)...", end=" ")
        else:
            is_chain = dataset_type == "chain_firms_qa"
            entity = item["chain"] if is_chain else item.get("company", "UNKNOWN")
            local_count = item.get("local_count", 0)
            foreign_count = item.get("foreign_count", 0)
            category_key = _composition_category(local_count, foreign_count)
            noun = "companies" if is_chain else "competitors"
            extra = {
                ("chain" if is_chain else "company"): entity,
                "local_count": local_count,
                "foreign_count": foreign_count,
            }
            print(
                f"{progress} {entity} ({answer_count} {noun}, {local_count}L/{foreign_count}F)...",
                end=" ",
            )

        # query
        resp = await query_model(question, dataset_type)
        if resp is None:
            print("❌ API Error")
            error_analysis["api_errors"] += 1
            predicted = []
        else:
            predicted = parse_response(resp, dataset_type)
            if not predicted:
                error_analysis["empty_responses"] += 1
            print(f"✓ ({len(predicted)} predicted)")

        # metrics (library functions); failed queries are scored against an
        # empty prediction rather than skipped
        m_obj = calculate_metrics(predicted, actual)
        ap = calculate_average_precision(predicted, actual)
        m_obj.average_precision = ap
        m = m_obj.to_dict()

        # store result
        result = {
            "index": idx,
            "entity": entity,
            "question": question,
            "actual_answer": actual,
            "predicted_answer": predicted,
            "response": resp,
            "metrics": m,
            "average_precision": ap,
            "dataset_type": dataset_type,
        }
        extra[f"actual_{noun}"] = actual
        extra[f"predicted_{noun}"] = predicted
        result.update(extra)
        results.append(result)

        _add_to_bucket(overall, m)
        total_ap += ap
        total_exact_match += 1 if m["exact_match"] == 1.0 else 0

        _add_to_bucket(by_answer_count[answer_count], m)
        _add_to_bucket(by_category[category_key], m)

        # local Ollama needs no rate limiting
        if PROVIDER != "ollama":
            await asyncio.sleep(api_delay)

    elapsed = time.time() - start
    # Guard the divisions only; the reported sample count stays truthful.
    divisor = evaluated or 1

    avg_metrics = {
        "recall": overall["recall"] / divisor,
        "precision": overall["precision"] / divisor,
        "f1": overall["f1"] / divisor,
        "mAP": total_ap / divisor,
        "exact_match_rate": total_exact_match / divisor,
        "evaluated_samples": evaluated,
        "total_samples": total_samples,
        "elapsed_time": elapsed,
        "avg_time_per_sample": elapsed / divisor,
    }

    # average category stats
    for bucket in by_answer_count.values():
        _average_bucket(bucket)
    for bucket in by_category.values():
        _average_bucket(bucket)

    full_results = {
        "provider": PROVIDER,
        "model": MODEL_NAME,
        "temperature": temperature,
        "dataset": str(dataset_path),
        "dataset_type": dataset_type,
        "timestamp": datetime.now().isoformat(),
        "average_metrics": avg_metrics,
        "by_answer_count": dict(by_answer_count),
        "by_category": by_category,
        "error_analysis": error_analysis,
        "detailed_results": results,
    }

    return full_results


print("✓ Evaluation loop ready")


✓ Evaluation loop ready


## 8) Run Evaluation (async) and optionally save results


In [None]:
# Run the evaluation
# (top-level `await` relies on the notebook kernel; in a plain script this
# call would have to go through an event loop instead)
full_results = await evaluate_dataset_notebook(dataset)

# Human-readable report of the aggregated metrics
print_summary(full_results)

# Persist the full result dictionary as JSON unless disabled in the parameters cell
if SAVE_RESULTS:
    save_results_json(full_results, Path(DATASET_PATH))



Evaluating OLLAMA qwen3:14b on Firm→Chains Dataset

Dataset: ..\datasets\demo\qa\firm_chains_qa_local.jsonl
Dataset Type: firm_chains_qa
Provider: ollama
Model: qwen3:14b
Temperature: 0.0
Total samples: 2363
Evaluating: 2363 samples

Starting evaluation...

[1/2363] 91APP*-KY (5 chains)... ✓ (0 predicted)
[2/2363] ACpay (1 chains)... ✓ (1 predicted)
[3/2363] IET-KY (1 chains)... ✓ (1 predicted)
[4/2363] LINEPAY (2 chains)... ✓ (4 predicted)
[5/2363] M31 (1 chains)... ✓ (0 predicted)
[6/2363] Q Burger (1 chains)... ✓ (0 predicted)
[7/2363] TPK-KY (1 chains)... ✓ (1 predicted)
[8/2363] jpp-KY (2 chains)... ✓ (2 predicted)
[9/2363] 一等一科技 (1 chains)... ✓ (0 predicted)
[10/2363] 一詮 (2 chains)... ✓ (0 predicted)
[11/2363] 一零四 (2 chains)... ✓ (0 predicted)
[12/2363] 三一東林 (2 chains)... ✓ (0 predicted)
[13/2363] 三商 (1 chains)... ✓ (2 predicted)
[14/2363] 三商壽 (2 chains)... ✓ (2 predicted)
[15/2363] 三商家購 (1 chains)... ✓ (0 predicted)
[16/2363] 三商電 (2 chains)... ✓ (0 predicted)
[17/2363] 三商餐飲 (1 