In [None]:
# 0) Uninstall any pre-installed versions so the pinned set below starts clean.
#    `|| true` keeps the cell going when a package is not installed.
# NOTE(review): `langchain-classic` is removed here and is not reinstalled below,
# yet later cells import `langchain_classic` — confirm it is available at runtime.
%pip -q uninstall -y \
  langchain langchain-core langchain-community langchain-openai \
  langchain-text-splitters langchain-chroma langchain-classic \
  langchain-graph-retriever chromadb tokenizers numpy || true

# 1) Install
# LangChain core and integration packages
# Everything pinned to the LangChain 0.3 series
%pip -q install \
  "langchain==0.3.12" \
  "langchain-community==0.3.12" \
  "langchain-openai==0.2.8" \
  "langchain-text-splitters==0.3.4" \
  "langchain-chroma==0.1.4"

# Chroma pinned to the 0.5 series (author's note: expected by GraphRetriever's Chroma adapter)
%pip -q install "chromadb==0.5.23"

# GraphRetriever itself (with the Chroma extras)
%pip -q install "langchain-graph-retriever[chroma]==0.8.0"

# tokenizers version chosen to satisfy the Transformers requirement (author's note)
%pip -q install "tokenizers==0.23.0"

# Remaining utilities are unpinned and upgraded to the latest release.
# NOTE(review): `-U numpy` may move numpy past what the pinned packages were tested with — confirm.
%pip -q install -U tiktoken pypdf python-docx bs4 chardet numpy
%pip -q install -U rank_bm25 sudachipy sudachidict_full sudachidict_core

# 2) Export API keys (read from Colab secrets) and LangSmith tracing settings as environment variables
import os
from google.colab import userdata
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = userdata.get("LANGCHAIN_API_KEY")
os.environ["LANGCHAIN_PROJECT"] = "agent-book"

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; DATA_DIR in a later cell points below this mount point.
drive.mount('/content/drive')

In [None]:
import os, glob, json
import pandas as pd
from typing import List, Dict, Any
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter

In [None]:
# Dataset locations on the mounted Drive. Everything is derived from DATA_DIR.
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/rag_week1"  # <- user-specified dataset folder
DOCS_DIR = f"{DATA_DIR}/docs"
META_CSV = f"{DATA_DIR}/metadata.csv"   # the loader cell also accepts the misspelling "medata.csv"
QA_JSONL = f"{DATA_DIR}/qa.jsonl"

In [None]:
# =========================
# 1) Load metadata.csv / qa.jsonl / docs/*.md
# =========================

# 1-1) metadata.csv (also accept the misspelled "medata.csv")
meta_path = META_CSV if os.path.exists(META_CSV) else f"{DATA_DIR}/medata.csv"
assert os.path.exists(meta_path), f"metadata.csv (or medata.csv) not found under {DATA_DIR}"
meta_df = pd.read_csv(meta_path)

# Strip whitespace from the header names BEFORE any column is accessed; otherwise a
# header such as " path" would make the lookup below fail with a KeyError.
meta_df.columns = [str(c).strip() for c in meta_df.columns]
assert "path" in meta_df.columns, f"'path' column not found in {meta_path}: {list(meta_df.columns)}"

# Join key: the trailing file name of `path` (so absolute paths from another machine still match).
meta_df["file_name"] = meta_df["path"].apply(lambda p: os.path.basename(str(p)))

# Expected columns: id/title/category/effective_date/confidentiality/department/product_type

# 1-2) qa.jsonl (one JSON object per non-blank line)
assert os.path.exists(QA_JSONL), f"qa.jsonl not found under {DATA_DIR}"
qa_items: List[Dict[str, Any]] = []
with open(QA_JSONL, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, 1):
        line = line.strip()
        if not line:
            continue
        try:
            qa_items.append(json.loads(line))
        except json.JSONDecodeError as exc:
            # Report which line is broken instead of a bare decoder error.
            raise ValueError(f"{QA_JSONL}:{line_no}: invalid JSON ({exc})") from exc
print(f"Loaded QAs: {len(qa_items)}")

# 1-3) docs/*.md
assert os.path.isdir(DOCS_DIR), f"docs dir not found: {DOCS_DIR}"
doc_paths = sorted(glob.glob(os.path.join(DOCS_DIR, "*.md")))
assert len(doc_paths) > 0, f"No .md files found in {DOCS_DIR}"
print(f"Found markdown docs: {len(doc_paths)}")

# 1-4) Pair each markdown text with its metadata row, matched on file name.
#      If several CSV rows share a file name, the last one wins.
meta_index = {row["file_name"]: row for _, row in meta_df.iterrows()}

md_texts: List[str] = []
md_metas: List[Dict[str, Any]] = []

for p in doc_paths:
    fname = os.path.basename(p)
    row = meta_index.get(fname)
    # Base metadata: always record where the text came from.
    meta = {"source": p}
    if row is not None:
        for col in ["id", "title", "category", "effective_date", "confidentiality", "department", "product_type"]:
            # Copy only columns that exist and hold a non-missing value.
            if col in row and pd.notna(row[col]):
                meta[col] = row[col]
    # Read the document body.
    with open(p, "r", encoding="utf-8") as f:
        text = f.read()
    md_texts.append(text)
    md_metas.append(meta)

print("Docs prepared:", len(md_texts))

In [None]:
# =========================
# 2) Chunking (recursive-character / token-based / Markdown headers)
# =========================

# 2-1) General purpose: character-based recursive splitting (records each chunk's start index)
rc_splitter = RecursiveCharacterTextSplitter(
    chunk_size=800,
    chunk_overlap=120,
    add_start_index=True,
)
rc_docs: List[Document] = rc_splitter.create_documents(md_texts, metadatas=md_metas)

# 2-2) Token-length based splitting (lengths measured with tiktoken)
tok_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=300,
    chunk_overlap=60,
)
tok_docs: List[Document] = tok_splitter.create_documents(md_texts, metadatas=md_metas)

# 2-3) Split on Markdown headings; heading text is stored under h1/h2/h3
md_header_splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "h1"), ("##", "h2"), ("###", "h3")]
)
md_docs: List[Document] = []
for body, doc_meta in zip(md_texts, md_metas):
    for section in md_header_splitter.split_text(body):
        # File-level metadata first, heading metadata second (heading keys win on clashes).
        section.metadata = {**doc_meta, **(section.metadata or {})}
        md_docs.append(section)

print("RC chunks:", len(rc_docs))
print("TOK chunks:", len(tok_docs))
print("MD chunks:", len(md_docs))
if md_docs:
    print("例：MDの1つ目のメタデータ:", md_docs[0].metadata)

In [None]:
# =========================
# 4) Baseline embeddings & retriever
# =========================
from langchain_openai import OpenAIEmbeddings # Import the class here
from langchain_chroma import Chroma # Ensure Chroma is also imported if not already

emb = OpenAIEmbeddings(model="text-embedding-3-small")  # "small" favours cost; switch to "-large" for quality
vs = Chroma(collection_name="wk1_rc_base", embedding_function=emb)
# NOTE(review): re-running this cell presumably adds rc_docs to the same collection again — confirm.
_ = vs.add_documents(rc_docs)
baseline_retriever = vs.as_retriever(search_kwargs={"k": 5})

# Smoke test: print a short preview of each hit for one probe question
probe_q = "サクラ短期国債ファンドの信託報酬"
for i, d in enumerate(baseline_retriever.invoke(probe_q), 1):
    print(f"[{i}] {d.page_content[:60]}...  id={d.metadata.get('id')} title={d.metadata.get('title')}")

In [None]:
# =========================
# 5) QA evaluation (driven by qa.jsonl)
#    - Source Hit@k: the required source id appears among the top-k documents
#    - Answer Hit@k: an expected answer string appears in a top-k document body
#    - Noise: mean share of top-k documents that satisfy neither condition
# =========================
def evaluate_with_qas(retriever, qas: List[Dict[str, Any]], k: int = 5):
    """Score a retriever against a list of QA items.

    Args:
        retriever: object exposing ``invoke(question)`` that returns documents with
            ``page_content`` and ``metadata`` attributes.
        qas: QA dicts with ``question`` (required), ``must_have_source_id`` (optional)
            and ``answers`` (optional; a list of strings or a single string).
        k: number of top documents inspected per question.

    Returns:
        dict with ``SourceHit@k``, ``AnswerHit@k``, ``Noise`` (each averaged over the
        questions; 0 when ``qas`` is empty) and ``N`` (number of questions).

    Raises:
        KeyError: if a QA item has no ``question`` key.
    """
    def _normalize_answers(raw) -> List[str]:
        # A bare string/number is one answer, not an iterable of characters.
        if raw is None:
            return []
        if isinstance(raw, (str, int, float)):
            raw = [raw]
        # Drop None/empty answers: "" is a substring of everything and would always "hit".
        return [str(a) for a in raw if a is not None and str(a) != ""]

    src_hits = 0
    ans_hits = 0
    total = 0
    noise_sum = 0.0

    for qa in qas:
        must_id = qa.get("must_have_source_id")
        answers = _normalize_answers(qa.get("answers"))

        docs = list(retriever.invoke(qa["question"]) or [])[:k]

        # Per-document flags, computed once and shared by all three metrics.
        # Without a required id there is nothing to match (avoids None == None hits).
        id_match = [bool(must_id) and (d.metadata or {}).get("id") == must_id for d in docs]
        ans_match = [any(a in (d.page_content or "") for a in answers) for d in docs]

        src_hits += int(any(id_match))
        ans_hits += int(any(ans_match))

        # Noise = share of retrieved documents matching neither the id nor any answer.
        noisy = sum(1 for by_id, by_ans in zip(id_match, ans_match) if not (by_id or by_ans))
        noise_sum += noisy / max(1, len(docs))
        total += 1

    denom = max(1, total)
    return {
        "SourceHit@k": src_hits / denom,
        "AnswerHit@k": ans_hits / denom,
        "Noise": noise_sum / denom,
        "N": total,
    }

# Evaluate the baseline retriever (built on the recursive-character chunks) at k=5.
print("Baseline (RC) metrics:", evaluate_with_qas(baseline_retriever, qa_items, k=5))

In [None]:
# =========================
# 6) 別分割（TokenSplit / MDHeader）で比較
# =========================
def build_retriever(docs, name):
    """Index `docs` in a Chroma collection called `name` and return a top-5 retriever over it."""
    embedder = OpenAIEmbeddings(model="text-embedding-3-small")
    store = Chroma(collection_name=name, embedding_function=embedder)
    store.add_documents(docs)
    retriever = store.as_retriever(search_kwargs={"k": 5})
    return retriever

# Build one retriever per alternative chunking; the Markdown-header one is skipped when md_docs is empty.
retr_tok = build_retriever(tok_docs, "wk1_tok")
retr_md  = build_retriever(md_docs,  "wk1_md") if md_docs else None

# Score both with the same evaluation function and k.
print("TokenSplit metrics:", evaluate_with_qas(retr_tok, qa_items, k=5))
if retr_md:
    print("MDHeader   metrics:", evaluate_with_qas(retr_md, qa_items, k=5))
else:
    print("MDHeader   metrics: (N/A)")

In [None]:
# =========================
# 7) 日本語トークナイズと同義語・ヘッダ連結
# =========================
from sudachipy import dictionary, tokenizer

# Initialise the Sudachi tokenizer with the default dictionary and split mode C.
# NOTE(review): sudachidict_full is installed in the setup cell, but Dictionary() is called
# without arguments, so presumably the default (core) dictionary is used — confirm.
_sudachi = dictionary.Dictionary().create()
_mode = tokenizer.Tokenizer.SplitMode.C

# Domain synonyms (extend as needed). Key = base term, value = alternative phrasings.
# Spaces inside a term mark word boundaries for the BM25 tokenization.
SYNONYMS_JA = {
    "信託報酬": ["運用 管理 費用", "マネジメント フィー"],
    "解約 申込 締切": ["カットオフ", "解約 締切", "申込 締切"],
    "販売 勧誘": ["勧誘", "電話 勧誘", "営業 電話"],
    "信託財産 留保額": ["留保額"],
}

# Top-level part-of-speech tags dropped during tokenization.
# Sudachi tags punctuation as "補助記号" and whitespace as "空白" (distinct from "記号"),
# so both must be listed or punctuation/space tokens end up in the BM25 terms and bigrams.
STOP_POS = {"助詞", "助動詞", "記号", "補助記号", "空白"}
# Function-like words dropped regardless of part of speech.
STOP_TOKENS = {"する", "ある", "なる", "こと", "ため", "および", "により", "について"}

def ja_tokenize_advanced(text: str, add_bigrams: bool = True):
    """Tokenize Japanese text with Sudachi into a space-separated term string.

    Tokens whose top-level part of speech is in STOP_POS are dropped, as are stop words
    and whitespace-only tokens. Kept tokens are emitted in Sudachi's normalized form.

    Args:
        text: input text; None or "" yields "".
        add_bigrams: when True and at least two terms remain, append "a_b" bigrams of
            adjacent terms after the unigrams.

    Returns:
        The terms (and optional bigrams) joined by single spaces.
    """
    terms = []
    for m in _sudachi.tokenize(text or "", _mode):
        pos = m.part_of_speech()[0]
        if pos in STOP_POS:
            continue
        w = m.normalized_form()
        # Skip empty and whitespace-only tokens so they cannot pollute the bigrams.
        if not w or not w.strip():
            continue
        # The normalized form can differ in script from the stop list (e.g. kana vs. kanji),
        # so test the surface and dictionary forms as well.
        if any(form in STOP_TOKENS for form in (w, m.dictionary_form(), m.surface())):
            continue
        terms.append(w)
    if add_bigrams and len(terms) >= 2:
        terms += [f"{a}_{b}" for a, b in zip(terms, terms[1:])]
    return " ".join(terms)

def expand_with_synonyms(text: str, synonyms=None) -> str:
    """Append synonym groups to `text` so exact-match retrieval (BM25) can bridge wordings.

    A group (base term + its synonyms) fires when ANY member occurs in `text`. Matching
    ignores whitespace, because table entries are written pre-segmented ("解約 申込 締切")
    while raw Japanese text is not. When a group fires, the base term and all synonyms are
    appended, one per line, so base and synonym wordings are both present afterwards.

    Args:
        text: raw text or query; None is treated as "".
        synonyms: optional ``{base: [synonym, ...]}`` table; defaults to SYNONYMS_JA.

    Returns:
        `text` followed by the members of every matching group, newline-separated.
    """
    table = SYNONYMS_JA if synonyms is None else synonyms
    text = text or ""
    # Whitespace-free copy used only for matching; the returned text keeps its spacing.
    haystack = "".join(text.split())
    buf = [text]
    for base, syns in table.items():
        group = [base, *syns]
        # Blank entries are ignored: "" would be a substring of every text.
        if any("".join(term.split()) in haystack for term in group if term.strip()):
            buf.extend(group)
    return "\n".join(buf)

def attach_header_weight(content: str, meta: dict, title_boost: int = 3) -> str:
    """Append selected metadata values to `content` several times as a crude term boost.

    The non-blank string values of title/category/product_type/department are joined with
    " / ", prefixed with a space, and appended `title_boost` times (never fewer than zero).
    """
    boosted_fields = ("title", "category", "product_type", "department")
    candidates = [meta.get(name) for name in boosted_fields]
    kept = [val for val in candidates if isinstance(val, str) and val.strip()]
    suffix = " " + " / ".join(kept)
    repeats = title_boost if title_boost > 0 else 0
    return content + suffix * repeats


In [None]:
# =========================
# 8) BM25 Retriverの構築とQuery用ラッパー作成
# =========================
from copy import deepcopy
from langchain_community.retrievers import BM25Retriever

def build_bm25_retriever_ja(docs, k=10, title_boost=3):
    """Build a BM25 retriever whose index uses enriched, Sudachi-tokenized Japanese text.

    For indexing, each document body is expanded with synonyms, extended with repeated
    header metadata and tokenized. The documents RETURNED by the retriever keep their
    original `page_content`, so previews and answer-substring checks operate on the real
    text rather than on the space-separated token string.

    Args:
        docs: source documents; they are deep-copied, never mutated.
        k: number of documents the retriever returns.
        title_boost: repetition count passed to `attach_header_weight`.

    Returns:
        A `BM25Retriever` with `k` set.
    """
    index_docs = []   # tokenized text, used only to build the BM25 statistics
    result_docs = []  # original text, returned to callers
    for d in docs:
        indexed = deepcopy(d)
        enriched = expand_with_synonyms(indexed.page_content or "")
        enriched = attach_header_weight(enriched, indexed.metadata or {}, title_boost=title_boost)
        indexed.page_content = ja_tokenize_advanced(enriched)
        index_docs.append(indexed)
        result_docs.append(deepcopy(d))
    retr = BM25Retriever.from_documents(index_docs)  # rank_bm25 based
    # BM25 selects hits by position, so swapping in the same-order originals keeps the
    # ranking while returning untokenized content.
    retr.docs = result_docs
    retr.k = k
    return retr

# Build the enhanced BM25 retriever from the token-based chunks (tok_docs).
bm25_retriever = build_bm25_retriever_ja(tok_docs, k=10, title_boost=3)

In [None]:
# =========================
# 8.5) BM25 用クエリ前処理ラッパ（BaseRetriever 継承）+ ID 付与ラッパ
# =========================
from typing import Any, List
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
import hashlib, os
from copy import deepcopy

def _ensure_id(doc):
    if getattr(doc, "metadata", None) is None:
        doc.metadata = {}
    if "id" not in doc.metadata or not doc.metadata["id"]:
        src = doc.metadata.get("source")
        if isinstance(src, str) and src:
            base = os.path.splitext(os.path.basename(src))[0]
            doc.metadata["id"] = base
        else:
            h = hashlib.sha1((doc.page_content or "").encode("utf-8")).hexdigest()[:12]
            doc.metadata["id"] = f"AUTO_{h}"
    return doc

def _retriever_call(ret, query):
    if hasattr(ret, "get_relevant_documents"):
        return ret.get_relevant_documents(query)
    # Runnable 準拠
    return ret.invoke(query)

class BM25QueryWrappedRetriever(BaseRetriever):
    """Retriever wrapper that preprocesses the query before handing it to a BM25 retriever.

    The query goes through `expand_with_synonyms` and then `ja_tokenize_advanced`; the
    resulting string is passed to `bm25` via `_retriever_call`.
    """
    bm25: Any = Field(...)  # langchain_community.retrievers.BM25Retriever

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        # Synonym expansion first, then Sudachi tokenization of the expanded query.
        qx = ja_tokenize_advanced(expand_with_synonyms(query))
        return _retriever_call(self.bm25, qx)

class IdSafeRetriever(BaseRetriever):
    """Retriever wrapper that guarantees every returned document has metadata['id'].

    Results of `inner` are deep-copied before `_ensure_id` fills in a missing id, so the
    documents held by the wrapped retriever are not modified.
    """
    inner: BaseRetriever

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        # A None result from the inner retriever is treated as "no documents".
        docs = _retriever_call(self.inner, query) or []
        out = []
        for d in docs:
            # Copy first: _ensure_id mutates the document it is given.
            d2 = deepcopy(d)
            out.append(_ensure_id(d2))
        return out

# BM25 with query preprocessing, then id-guaranteed variants of both BM25 and the vector baseline.
bm25_wrapped = BM25QueryWrappedRetriever(bm25=bm25_retriever)
bm25_idsafe  = IdSafeRetriever(inner=bm25_wrapped)
vec_idsafe   = IdSafeRetriever(inner=baseline_retriever)

In [None]:
# =========================
# 9) Build the hybrid retriever (Reciprocal Rank Fusion)
# =========================
# The setup cell pins LangChain 0.3 and uninstalls `langchain-classic`, so that package may
# be absent; in the 0.3 series EnsembleRetriever lives in `langchain.retrievers`.
try:
    from langchain_classic.retrievers import EnsembleRetriever
except ImportError:
    from langchain.retrievers import EnsembleRetriever

hybrid_retriever = EnsembleRetriever(
    retrievers=[vec_idsafe, bm25_idsafe],
    weights=[0.8, 0.2],   # vector results dominate, BM25 contributes a smaller share
    c=30,                 # rank-fusion constant
    id_key="id",          # documents are identified by metadata["id"] when fusing
)

# Debug helper: show the first hits of a retriever (call it with the id-safe wrappers).
def peek(ret, q, k=5):
    """Print the query, then id, title and a 60-character preview for the first `k` hits."""
    print("Q:", q)
    for rank, hit in enumerate(ret.invoke(q)[:k], start=1):
        meta = hit.metadata
        preview = hit.page_content[:60]
        print(f"[{rank}] id={meta.get('id')} title={meta.get('title')} | {preview}...")

# Compare the top hits of the vector, BM25 and hybrid retrievers for one sample question.
q1 = "サクラ短期国債ファンドの信託報酬は？"
print("— Vector —"); peek(vec_idsafe, q1)
print("\n— BM25  —"); peek(bm25_idsafe, q1)
print("\n— Hybrid—"); peek(hybrid_retriever, q1)


In [None]:
# =========================
# 10) Evaluate the improved retrievers
# =========================
# All three go through the id-safe wrappers, so every result carries metadata["id"].
print("BM25 only  :", evaluate_with_qas(bm25_idsafe, qa_items, k=5))
print("Vector only:", evaluate_with_qas(vec_idsafe,   qa_items, k=5))
print("Hybrid RRF :", evaluate_with_qas(hybrid_retriever, qa_items, k=5))

In [None]:
# Week3: LLM used by the query-expansion, HyDE and rewriting retrievers

from langchain_openai import ChatOpenAI
# temperature=0 keeps the generated rewrites stable (swap the model to suit your environment)
llm_rewrite = ChatOpenAI(model="gpt-4o-mini", temperature=0)


In [None]:
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
from typing import List, Any
from copy import deepcopy

# 共通: 行分割（箇条書きにも対応）
def _split_lines(text: str) -> List[str]:
    lines = []
    for raw in (text or "").splitlines():
        t = raw.strip()
        if not t:
            continue
        # 「- 」「・」「1. 」などを削る
        t = t.lstrip("-・*0123456789.　 ").strip()
        if t:
            lines.append(t)
    return list(dict.fromkeys(lines))  # 重複除去

# Custom multi-query retriever defined in this notebook (same name as LangChain's official
# MultiQueryRetriever, but not that class).
class MultiQueryRetriever(BaseRetriever):
    """Retrieve with the original question plus LLM-generated paraphrases and merge the hits.

    Fields:
        base: retriever queried for every query variant.
        llm: chat model whose `invoke(prompt).content` returns one paraphrase per line.
        n_queries: maximum number of paraphrases used.
    """
    base: BaseRetriever       = Field(...)
    llm:  Any                 = Field(...)
    n_queries: int            = Field(default=4)

    _prompt = (
        "あなたは検索クエリの言い換え生成器です。入力質問に対して、"
        "語彙・言い回し・キーワード展開を変えた検索向けクエリを{n}個、日本語で列挙してください。"
        "出力は各行1クエリのみ。\n\n質問: {q}\n"
    )

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        # Ask the LLM for paraphrased queries.
        prompt = self._prompt.format(q=query, n=self.n_queries)
        rewrite = self.llm.invoke(prompt).content
        variations = _split_lines(rewrite)[: self.n_queries]

        # Original question first, then the paraphrases. Ordered de-duplication avoids
        # sending the same query twice (an echoed question, or an empty rewrite).
        queries = list(dict.fromkeys([query] + variations))

        # Search with every query and keep the first occurrence of each document.
        bucket = []
        seen = set()
        for q in queries:
            docs = _retriever_call(self.base, q) or []
            for d in docs:
                # NOTE(review): the key is metadata["id"] when present, so hits that share an
                # id are collapsed to the first one — confirm this is intended for chunked docs.
                key = (d.metadata.get("id") or d.page_content[:80])
                if key in seen:
                    continue
                seen.add(key)
                bucket.append(d)
        return bucket

# Build the multi-query retriever
def build_multiquery_retriever(base_retriever, llm, n_queries=5):
    """
    Return the notebook's own `MultiQueryRetriever` wrapping `base_retriever`.

    Author's rationale: the official MultiQueryRetriever's `from_llm` requires a
    PromptTemplate/Runnable and tends to break across LangChain versions, so this
    notebook always uses its local compatible implementation instead.
    """
    return MultiQueryRetriever(base=base_retriever, llm=llm, n_queries=n_queries)

# Wrap with IdSafeRetriever so every returned document carries metadata["id"].
multiq_retriever = build_multiquery_retriever(baseline_retriever, llm_rewrite, n_queries=5)
multiq_idsafe    = IdSafeRetriever(inner=multiq_retriever)


In [None]:
class HyDERetriever(BaseRetriever):
    """Retrieve with an LLM-written hypothetical answer as well as the original question.

    Hits for the hypothetical passage come first, followed by hits for the question;
    repeated documents are dropped.
    """
    base: BaseRetriever = Field(...)
    llm:  Any           = Field(...)

    _prompt = (
        "次の質問に対する簡潔な参考回答（事実ベースの要約）を、検索向けの文章として日本語で作成してください。"
        "列挙や記号は避け、平叙文で150〜300字程度。\n\n質問: {q}\n"
    )

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        hypo = self.llm.invoke(self._prompt.format(q=query)).content
        # Search with both the hypothetical document and the original question, then merge.
        docs_h = _retriever_call(self.base, hypo) or []
        docs_q = _retriever_call(self.base, query) or []
        # Order-preserving de-duplication.
        # NOTE(review): the key is metadata["id"] when present, so hits sharing an id are
        # collapsed to the first one — confirm this is intended for chunked documents.
        out, seen = [], set()
        for d in list(docs_h) + list(docs_q):
            key = (d.metadata.get("id") or d.page_content[:80])
            if key in seen:
                continue
            seen.add(key)
            out.append(d)
        return out

# HyDE over the baseline retriever, wrapped so results always carry metadata["id"].
hyde_retriever = HyDERetriever(base=baseline_retriever, llm=llm_rewrite)
hyde_idsafe    = IdSafeRetriever(inner=hyde_retriever)

In [None]:
class RewritingRetriever(BaseRetriever):
    """Rewrite the question into a keyword-style search string with an LLM, then retrieve.

    Fields:
        base: retriever that receives the rewritten query.
        llm: chat model whose `invoke(prompt).content` returns the rewritten query.
    """
    base: BaseRetriever = Field(...)
    llm:  Any           = Field(...)

    _prompt = (
        "あなたは金融ドメインの検索クエリ最適化エージェントです。入力質問を、"
        "検索に強いキーワード列へ変換してください。重要語を保ちつつ、同義語や製品名、"
        "時間表現（例: 14:00/午後2時）などの表記ゆれもカバーしてください。"
        "出力は日本語、箇条書きや記号なし、単一行の検索文字列のみ。\n\n質問: {q}\n"
    )

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        rewritten = (self.llm.invoke(self._prompt.format(q=query)).content or "").strip()
        # An empty rewrite would send a blank query to the base retriever; search with the
        # original question instead.
        return _retriever_call(self.base, rewritten or query) or []

# Query rewriting over the baseline retriever, wrapped so results always carry metadata["id"].
rewriter_retriever = RewritingRetriever(base=baseline_retriever, llm=llm_rewrite)
rewriter_idsafe    = IdSafeRetriever(inner=rewriter_retriever)

In [None]:
def peek(ret, q, k=5):
    """Print the query, then id, title and a 60-character preview for the first `k` hits."""
    print("Q:", q)
    hits = ret.invoke(q)[:k]
    for rank, hit in enumerate(hits, start=1):
        doc_id = hit.metadata.get('id')
        title = hit.metadata.get('title')
        print(f"[{rank}] id={doc_id} title={title} | {hit.page_content[:60]}...")

# Compare the top hits of the baseline and the three LLM-assisted retrievers for one question.
q_demo = "TOPIX連動インデックスの信託財産留保額は？"
print("— Baseline —"); peek(baseline_retriever, q_demo)
print("\n— MultiQuery —"); peek(multiq_idsafe, q_demo)
print("\n— HyDE —"); peek(hyde_idsafe, q_demo)
print("\n— Rewriter —"); peek(rewriter_idsafe, q_demo)

In [None]:
# Score the baseline and the three LLM-assisted retrievers on the same QA set at k=5.
# Note: the baseline is evaluated without the id-safe wrapper, unlike the other three.
print("Baseline   :", evaluate_with_qas(baseline_retriever, qa_items, k=5))
print("MultiQuery :", evaluate_with_qas(multiq_idsafe,        qa_items, k=5))
print("HyDE       :", evaluate_with_qas(hyde_idsafe,          qa_items, k=5))
print("Rewriter   :", evaluate_with_qas(rewriter_idsafe,      qa_items, k=5))

In [None]:
# === 1) MMR retriever: diversify the top vector hits to reduce noise ===
# Chroma -> as_retriever(search_type="mmr", search_kwargs={...})
# Reuses the vector store behind baseline_retriever, so no re-indexing happens here.
mmr_retriever = baseline_retriever.vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 5,        # number of documents finally returned
        "fetch_k": 20, # fetch a wider candidate set first, then diversify
        "lambda_mult": 0.5,  # 0 = favour diversity, 1 = favour similarity (explore 0.3-0.7)
    }
)

# Guarantee metadata["id"] (author's note: for evaluation and stable rank fusion)
mmr_idsafe = IdSafeRetriever(inner=mmr_retriever)

print("MMR metrics:", evaluate_with_qas(mmr_idsafe, qa_items, k=5))

In [None]:
import datetime as dt
from typing import Callable, List

def _parse_date(s):
    try:
        return dt.datetime.fromisoformat(str(s)).date()
    except Exception:
        return None

class FilteredRetriever(BaseRetriever):
    """Post-filter the results of a base retriever on metadata and keep the top `k`.

    At most `fetch_k` results of `base` are considered. Note that `fetch_k` only truncates
    what `base` returns; it cannot make `base` return more documents than it is configured for.
    """
    base: BaseRetriever = Field(...)
    k: int              = Field(default=5)     # documents returned
    fetch_k: int        = Field(default=25)    # cap on base results considered
    filter_fn: Callable = Field(default=lambda d: True)  # keep a document when this returns truthy
    sort_key: Callable  = Field(default=None)  # optional key; results are sorted by it, descending

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        # Take the candidate pool from the base retriever.
        docs = _retriever_call(self.base, query) or []
        docs = list(docs)[: self.fetch_k]
        # Apply the metadata filter.
        docs = [d for d in docs if self.filter_fn(d)]
        # Optional sort, largest key first (e.g. newest effective date first).
        # This replaces the base retriever's ordering.
        if self.sort_key:
            docs.sort(key=self.sort_key, reverse=True)
        return docs[: self.k]

# Example filter: keep public investment-trust documents (sorting by date is done separately)
def filter_public_fund(d):
    """Return True when a document may be kept.

    A document is rejected when it has a product_type other than "投資信託". Otherwise it is
    kept only if its confidentiality (case-insensitive) is "public", "公開" or missing/empty.
    To keep internal documents as well, this function could simply return True.
    """
    meta = d.metadata if d.metadata else {}
    product = meta.get("product_type")
    is_other_product = bool(product) and product != "投資信託"
    if is_other_product:
        return False
    level = meta.get("confidentiality")
    level = level.lower() if level else ""
    return level in ("public", "公開", "")

def sort_by_effective_date_desc(d):
    """Sort key: the document's parsed effective_date, or 1970-01-01 when missing/unparseable.

    The key itself is ascending; the caller sorts with reverse=True to get newest first.
    """
    meta = d.metadata or {}
    parsed = _parse_date(meta.get("effective_date"))
    if parsed:
        return parsed
    return dt.date(1970, 1, 1)

# FilteredRetriever can only filter what its base returns. `mmr_idsafe` is limited to k=5,
# so fetch_k=25 could never take effect and filtering could only shrink those 5 hits.
# Use a wider MMR retriever over the same vector store as the candidate pool instead.
mmr_wide_retriever = baseline_retriever.vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 25,        # matches fetch_k below: the pool handed to the filter
        "fetch_k": 75,  # similarity candidates considered by MMR before diversification
        "lambda_mult": 0.5,
    },
)

filtered_mmr = FilteredRetriever(
    base=IdSafeRetriever(inner=mmr_wide_retriever),
    k=5,
    fetch_k=25,
    filter_fn=filter_public_fund,
    sort_key=sort_by_effective_date_desc,
)
filtered_mmr_idsafe = IdSafeRetriever(inner=filtered_mmr)

print("Filtered+MMR metrics:", evaluate_with_qas(filtered_mmr_idsafe, qa_items, k=5))

In [None]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
import json

# Chat model used for reranking (and reused by the compression cell); temperature=0 for stable output.
llm_rerank = ChatOpenAI(model="gpt-4o-mini", temperature=0)

class LLMRerankRetriever(BaseRetriever):
    """Rerank the results of a base retriever with LLM-assigned relevance scores.

    Fields:
        base: retriever providing the candidate pool.
        llm: chat model returning a JSON array of {"id", "score"} objects.
        fetch_k: maximum number of candidates sent to the LLM.
        k: number of documents returned.

    Candidates are sent to the LLM under per-candidate handles ("0", "1", ...) rather than
    metadata["id"], because several chunks can share one document id and would otherwise
    overwrite each other's score. If the reply cannot be parsed, the base order is kept.
    """
    base: BaseRetriever = Field(...)
    llm:  Any           = Field(...)
    fetch_k: int        = Field(default=20)
    k: int              = Field(default=5)

    # Literal braces are escaped as {{ }} (JSON example and field spec) so only
    # {question} and {candidates} are template variables.
    _prompt = ChatPromptTemplate.from_messages([
        ("system",
         "あなたは検索結果の再ランキングを行うアシスタントです。"
         "与えられたユーザ質問と候補ドキュメント（id, title, snippet）について、"
         "各候補が質問にどれだけ関連するかを 0.0〜1.0 のスコアで評価し、"
         "JSONで返してください。フィールドは {{\"id\": str, \"score\": float}} の配列のみ。"),
        ("human",
         "質問: {question}\n\n候補:\n{candidates}\n\n"
         "出力: JSON配列（例: [{{\"id\":\"DOC001\",\"score\":0.83}}, ...]）のみ。")
    ])

    def _parse_scores(self, raw, n_candidates: int):
        """Return {candidate_index: score} parsed from the LLM reply, or None if unparseable."""
        text = raw if isinstance(raw, str) else str(raw)
        # Models often wrap JSON in code fences or prose: parse the outermost [...] span only.
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end <= start:
            return None
        try:
            items = json.loads(text[start:end + 1])
        except ValueError:
            return None
        if not isinstance(items, list):
            return None
        scores = {}
        for item in items:
            if not isinstance(item, dict):
                continue
            try:
                idx = int(str(item.get("id")).strip())
                score = float(item.get("score", 0.0))
            except (TypeError, ValueError):
                continue  # skip malformed entries instead of failing the whole query
            if 0 <= idx < n_candidates:
                scores[idx] = score
        return scores

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Any]:
        pool = _retriever_call(self.base, query) or []
        pool = list(pool)[: self.fetch_k]
        if not pool:
            return []

        def short(i, d):
            return {
                "id": str(i),  # per-candidate handle, unique within this request
                "title": d.metadata.get("title") or "",
                "snippet": (d.page_content or "")[:200]
            }
        ctext = json.dumps([short(i, d) for i, d in enumerate(pool)], ensure_ascii=False)

        # Fill {question} and {candidates} into the prompt.
        msg = self._prompt.format_messages(question=query, candidates=ctext)
        out = self.llm.invoke(msg).content

        scores = self._parse_scores(out, len(pool))
        if scores is None:
            # Unparseable reply: keep the base retriever's order.
            return pool[: self.k]

        # Stable sort: unscored candidates count as 0.0 and ties keep the base order.
        order = sorted(range(len(pool)), key=lambda i: scores.get(i, 0.0), reverse=True)
        return [pool[i] for i in order[: self.k]]

# Rerank the filtered MMR results with the LLM, then guarantee metadata["id"] on the output.
# NOTE(review): the rerank pool is whatever filtered_mmr_idsafe returns; fetch_k=15 only caps
# that list, so it has no effect if the base returns fewer documents — confirm the intended pool size.
llm_reranked = LLMRerankRetriever(base=filtered_mmr_idsafe, llm=llm_rerank, fetch_k=15, k=5)
llm_reranked_idsafe = IdSafeRetriever(inner=llm_reranked)

print("LLM Rerank metrics:", evaluate_with_qas(llm_reranked_idsafe, qa_items, k=5))

In [None]:
# Compatibility import: try `langchain_classic` first and fall back to `langchain`.
# The setup cell pins LangChain 0.3 and uninstalls `langchain-classic`, so the first
# import may be unavailable; in 0.3 these classes live under `langchain.retrievers`.
try:
    from langchain_classic.retrievers import ContextualCompressionRetriever
    from langchain_classic.retrievers.document_compressors import LLMChainExtractor
except ImportError:
    from langchain.retrievers import ContextualCompressionRetriever
    from langchain.retrievers.document_compressors import LLMChainExtractor

compressor = LLMChainExtractor.from_llm(llm_rerank)  # reusing the temperature-0 rerank LLM
compressed_retriever = ContextualCompressionRetriever(
    base_compressor=compressor,
    base_retriever=llm_reranked_idsafe,  # compress the reranked results further
)

print("Compressed metrics:", evaluate_with_qas(compressed_retriever, qa_items, k=5))


In [None]:
# Side-by-side metrics for each stage of the pipeline, all on the same QA set at k=5.
print("Baseline          :", evaluate_with_qas(baseline_retriever,     qa_items, k=5))
print("MMR               :", evaluate_with_qas(mmr_idsafe,             qa_items, k=5))
print("Filtered+MMR      :", evaluate_with_qas(filtered_mmr_idsafe,    qa_items, k=5))
print("LLM Rerank        :", evaluate_with_qas(llm_reranked_idsafe,   qa_items, k=5))
print("Compressed (final):", evaluate_with_qas(compressed_retriever,   qa_items, k=5))

In [None]:
from langchain_graph_retriever.transformers import ShreddingTransformer
from langchain_graph_retriever import GraphRetriever
from graph_retriever.strategies import Eager  # 公式例の戦略
from typing import Set

# ---- Prerequisite: md_docs or rc_docs exists (produced by the Week1 splitting cell) ----
# Prefer the Markdown-header chunks; fall back to the recursive-character chunks.
source_docs: List[Document] = md_docs if len(md_docs) > 0 else rc_docs
assert len(source_docs) > 0, "md_docs/rc_docs が見つかりません。先にWeek1の分割を実行してください。"

# Work on copies and make sure each one has metadata["id"].
source_docs = [_ensure_id(deepcopy(d)) for d in source_docs]

# ShreddingTransformer converts documents into the form stored in the vector store below.
# NOTE(review): the original comment said it extracts entities/edges (triples); the library
# documentation describes it as flattening nested metadata for stores without collection
# support — confirm which behaviour this pipeline relies on.
shredder = ShreddingTransformer()

graph_ready_docs: List[Document] = list(shredder.transform_documents(source_docs))
print(f"[Shredding] produced {len(graph_ready_docs)} graph-ready docs")

# ---- 2) Build the vector store (Chroma) from the shredded documents ----
# Note: this rebinds `emb` (previously the "-small" model) to the "-large" model.
emb = OpenAIEmbeddings(model="text-embedding-3-large")

# Following the official example: store the shredded documents in Chroma as they are.
vector_store = Chroma.from_documents(
    documents=graph_ready_docs,
    embedding=emb,
    collection_name="wk5_graph_rag",
)

# ---- 3) エッジ（relation名）の自動検出 → GraphRetriever 構築 ----
# Shredding済みDocの metadata から relation（predicate）候補を拾っておく
def _collect_relations(docs: List[Document]) -> List[tuple]:
    rels: Set[str] = set()
    for d in docs:
        m = d.metadata or {}
        # ライブラリ実装に依存しますが、一般的に relation/predicate 相当のキーが入ります
        for key in ("relation", "predicate", "edge", "p"):
            v = m.get(key)
            if isinstance(v, str) and v.strip():
                rels.add(v.strip())
    # GraphRetriever は (edge_label, relation_key) のタプル列を受け取る仕様
    # 公式デモでは ("habitat", "habitat") のように同名で与えています。
    if not rels:
        # もし抽出できない/ゼロ件でも動くように保険で代表的な金融用語をいくつか置く
        rels = {"信託報酬", "カットオフタイム", "信託財産留保額", "販売勧誘時間"}
    return [(r, r) for r in sorted(rels)]

# Edge definitions for the graph traversal, as returned by _collect_relations.
edges = _collect_relations(graph_ready_docs)
print(f"[Edges] relations detected: {len(edges)} -> {edges[:8]}{' ...' if len(edges)>8 else ''}")

# GraphRetriever using the Eager traversal strategy from the official example.
traversal_retriever = GraphRetriever(
    store=vector_store,
    edges=edges,                          # edge definitions used to link documents
    strategy=Eager(k=5, start_k=1, max_depth=2),
)

# ---- 4) Smoke test ----
def peek(ret, q, k=5):
    """Print the query, then id, title and an 80-character preview for the first `k` hits."""
    print("Q:", q)
    for rank, hit in enumerate(ret.invoke(q)[:k], start=1):
        meta = hit.metadata
        preview = hit.page_content[:80]
        print(f"[{rank}] id={meta.get('id')} title={meta.get('title')} | {preview}...")

# Show the top hits of the graph retriever for one sample question (rebinds q_demo).
q_demo = "サクラ短期国債ファンドの信託報酬は？"
peek(traversal_retriever, q_demo)


In [None]:
# Reuse the existing evaluation function (SourceHit / AnswerHit / Noise) for the graph retriever.
print("GraphRAG (GraphRetriever):", evaluate_with_qas(traversal_retriever, qa_items, k=5))