In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install -U pip
!pip install "pydantic>=2,<3" \
            "langchain>=0.3,<0.4" \
            "langchain-core>=0.3,<0.4" \
            "langchain-community>=0.3,<0.4" \
            "langchain-openai>=0.3,<0.4" \
            "langchain-huggingface>=0.3,<0.4" \
            faiss-cpu flask flask-cors pandas pyyaml rank_bm25

Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.3
Collecting langchain-community<0.4,>=0.3
  Downloading langchain_community-0.3.31-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-openai<0.4,>=0.3
  Downloading langchain_openai-0.3.35-py3-none-any.whl.metadata (2.4 kB)
Collecting langchain-huggingface<0.4,>=0.3
  Downloading langchain_huggingface-0.3.1-py3-none-any.whl.metadata (996 bytes)
Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Collecting flask-cors
  Downloading flask_cors-6.0.1-py3-none-any.wh

In [3]:
!pip install rank_bm25
!pip install langchain_google_genai

Collecting langchain_google_genai
  Downloading langchain_google_genai-3.0.3-py3-none-any.whl.metadata (2.7 kB)
Collecting filetype<2.0.0,>=1.2.0 (from langchain_google_genai)
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Collecting google-ai-generativelanguage<1.0.0,>=0.7.0 (from langchain_google_genai)
  Downloading google_ai_generativelanguage-0.9.0-py3-none-any.whl.metadata (10 kB)
Collecting langchain-core<2.0.0,>=1.0.0 (from langchain_google_genai)
  Downloading langchain_core-1.0.5-py3-none-any.whl.metadata (3.6 kB)
Downloading langchain_google_genai-3.0.3-py3-none-any.whl (56 kB)
Downloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)
Downloading google_ai_generativelanguage-0.9.0-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m13.0 MB/s[0m  [33m0:00:00[0m
[?25hDownloading langchain_core-1.0.5-py3-none-any.whl (471 kB)
Installing collected packages: filetype, langchain-core, google-ai-generati

In [6]:
!pip install ragas

Collecting ragas
  Downloading ragas-0.3.9-py3-none-any.whl.metadata (22 kB)
Collecting appdirs (from ragas)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting diskcache>=5.6.3 (from ragas)
  Downloading diskcache-5.6.3-py3-none-any.whl.metadata (20 kB)
Collecting instructor (from ragas)
  Downloading instructor-1.13.0-py3-none-any.whl.metadata (11 kB)
Collecting scikit-network (from ragas)
  Downloading scikit_network-0.33.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.5 kB)
Collecting jiter<1,>=0.4.0 (from openai>=1.0.0->ragas)
  Downloading jiter-0.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.2 kB)
Collecting openai>=1.0.0 (from ragas)
  Downloading openai-2.8.0-py3-none-any.whl.metadata (29 kB)
Collecting pre-commit>=4.3.0 (from instructor->ragas)
  Downloading pre_commit-4.4.0-py2.py3-none-any.whl.metadata (1.2 kB)
Collecting ty>=0.0.1a23 (from instructor->ragas)
  Downloading ty-0.0.1a26-py3-none-

In [4]:
# ============================================================
# 0) IMPORTS & BASIC CONFIG
# ============================================================
from pathlib import Path
import os, pickle, re, json
from typing import List, Dict, Any, Optional
from collections import defaultdict

import numpy as np
import pandas as pd

# LangChain / docs / retrievers
try:
    from langchain.schema import Document
except Exception:
    from langchain_core.documents import Document

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.text_splitter import RecursiveCharacterTextSplitter


# Paths & constants
PHASES = ["I", "II", "III"]
SPLITS = ["train", "valid", "test"]

ROOT_PATH = Path("/content/drive/MyDrive/MyModel(hybrid)")
CSV_BASE = Path(os.getenv("RAG_CSV_BASE", str(ROOT_PATH / "rag")))
FAISS_PATH = ROOT_PATH / "faiss_clinical_trials_store_en_v1_allMiniL6"
BM25_CACHE = ROOT_PATH / "bm25_cache.pkl"

REQ_COLS = [
    "NCT Number","Enrollment","criteria","label",
    "smiles","icdcodes","fused_pred","diseases","drugs"
]

class RetrievalCfg:
    K_INITIAL = 12       # initial candidates from each retriever
    K_FINAL   = 8        # final number of chunks in context
    SIM_THRESHOLD = 0.05 # slightly relaxed similarity filter
    USE_RERANKER  = False

K_INITIAL     = RetrievalCfg.K_INITIAL
K_FINAL       = RetrievalCfg.K_FINAL
SIM_THRESHOLD = RetrievalCfg.SIM_THRESHOLD


# ============================================================
# 1) LOAD DATA
# ============================================================
def _clean_text(s):
    if isinstance(s, str):
        return " ".join(s.replace("\r"," ").replace("\n"," ").split())
    return s

def load_all_data():
    frames = []
    for p in PHASES:
        for sp in SPLITS:
            path = CSV_BASE / f"filtered_phase_{p}_{sp}.csv"
            try:
                df = pd.read_csv(path)
            except FileNotFoundError:
                print(f"[WARN] Missing {path}")
                continue

            for c in REQ_COLS:
                if c not in df.columns:
                    df[c] = "N/A"

            df["criteria"]   = df["criteria"].apply(_clean_text)
            df["phase_meta"] = p
            frames.append(df)

    if not frames:
        raise RuntimeError("No CSVs loaded. Check CSV_BASE and filenames.")

    df = pd.concat(frames, ignore_index=True)
    df["enrollment_num"] = pd.to_numeric(df["Enrollment"], errors="coerce")
    df["fused_pred_num"] = pd.to_numeric(df["fused_pred"], errors="coerce")
    df["is_lung"] = df["diseases"].astype(str).str.lower().str.contains("lung")
    return df

print("[INFO] Loading CSVs...")
df = load_all_data()
print("Rows loaded:", len(df))


# ============================================================
# 2) DOCUMENT CREATION (with FUSED_PRED visible in TEXT)
# ============================================================
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1200,
    chunk_overlap=150
)

def df_to_documents(df: pd.DataFrame) -> List[Document]:
    docs = []
    for _, row in df.iterrows():
        fused_num = row.get("fused_pred_num")
        if pd.isna(fused_num):
            fused_str = "N/A"
            fused_meta = None
        else:
            fused_str = f"{float(fused_num):.3f}"
            fused_meta = float(fused_num)

        label = row.get("label", "N/A")

        header = (
            f"NCT: {row['NCT Number']} | "
            f"PHASE: {row['phase_meta']} | "
            f"ENROLLMENT: {row['Enrollment']} | "
            f"DISEASES: {row['diseases']} | "
            f"DRUGS: {row['drugs']} | "
            f"FUSED_PRED: {fused_str} | "
            f"HISTORICAL_OUTCOME_LABEL: {label}"
        )

        body = (
            f"ELIGIBILITY CRITERIA: {row['criteria']}\n"
            f"ICD: {row['icdcodes']}\n"
            f"SMILES: {row['smiles']}"
        )

        full = f"{header}\n{body}"

        meta = {
            "nct_id": row["NCT Number"],
            "phase": row["phase_meta"],
            "enrollment": row["enrollment_num"],
            "diseases": row["diseases"],
            "drugs": row["drugs"],
            "fused_pred": fused_meta,
            "is_lung": bool(row["is_lung"]),
            "label": row["label"],
        }

        for chunk in splitter.split_text(full):
            docs.append(Document(page_content=chunk, metadata=meta))

    return docs

print("[INFO] Building documents...")
all_docs = df_to_documents(df)
print("Document chunks:", len(all_docs))


# ============================================================
# 3) EMBEDDINGS + FAISS
# ============================================================
emb = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"batch_size": 64}
)

print("[INFO] Building FAISS index...")
faiss_store = FAISS.from_documents(all_docs, emb)
FAISS_PATH.mkdir(parents=True, exist_ok=True)
faiss_store.save_local(str(FAISS_PATH))
print("FAISS vectors:", faiss_store.index.ntotal)


# ============================================================
# 4) BM25 RETRIEVER
# ============================================================
def build_bm25(docs):
    retr = BM25Retriever.from_documents(docs)
    retr.k = K_INITIAL
    return retr

bm25_ret = build_bm25(all_docs)
print("BM25 ready.")


# ============================================================
# 5) NCT HELPERS + LOOKUP
# ============================================================
NCT_REGEX = re.compile(r"\bNCT\d{8}\b", re.I)

def extract_ncts(text: str) -> List[str]:
    return list({m.upper() for m in NCT_REGEX.findall(text or "")})

def normalize_nct(n: str) -> str:
    n = (n or "").upper().strip()
    if n.startswith("NCT") and len(n) == 11 and n[3:].isdigit():
        return n
    digits = "".join(ch for ch in n if ch.isdigit())
    return f"NCT{digits:0>8}" if digits else ""

def _doc_key(d: Document) -> str:
    md = d.metadata or {}
    nid = (md.get("nct_id") or "").strip()
    nid = normalize_nct(nid) if nid else ""
    if nid:
        return f"id::{nid}"
    return "hash::" + str(hash((d.page_content[:512], tuple(sorted((md or {}).items())))))

def build_nct_lookup(vs) -> Dict[str, List[Document]]:
    lookup: Dict[str, List[Document]] = {}
    store = vs.docstore._dict
    for _, doc in store.items():
        md = doc.metadata or {}
        nid = md.get("nct_id")
        if nid:
            nid = normalize_nct(nid)
            lookup.setdefault(nid, []).append(doc)
    print("[OK] NCT lookup with", len(lookup), "distinct trials.")
    return lookup

nct_lookup = build_nct_lookup(faiss_store)


# ============================================================
# 6) COSINE + COMPRESSION + METADATA FILTER
# ============================================================
def _l2norm(x, axis=-1, eps=1e-12):
    n = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(n, eps)

def metadata_filter(docs: List[Document], *, lung_only: Optional[bool]=None) -> List[Document]:
    out = []
    for d in docs:
        md = d.metadata or {}
        if lung_only and not md.get("is_lung", False):
            continue
        out.append(d)
    return out

def compress_with_embeddings(
    query: str,
    docs: List[Document],
    embeddings_model,
    sim_threshold: float = SIM_THRESHOLD,
    topn: int = K_FINAL,
    *,
    pinned: Optional[set] = None,
    key_fn=None
) -> List[Document]:
    if not docs:
        return []
    if key_fn is None:
        key_fn = _doc_key
    pinned = pinned or set()

    qv = embeddings_model.embed_query(query)
    qv = np.array(qv, dtype=float) if isinstance(qv, list) else qv

    doc_vecs, keys = [], []
    for d in docs:
        dv = embeddings_model.embed_documents([d.page_content])[0]
        dv = np.array(dv, dtype=float) if isinstance(dv, list) else dv
        doc_vecs.append(dv)
        keys.append(key_fn(d))

    D = _l2norm(np.vstack(doc_vecs))
    q = _l2norm(qv.reshape(1, -1))
    sims = (D @ q.T).ravel()

    keep = []
    for d, s, k in zip(docs, sims, keys):
        if (s >= sim_threshold) or (k in pinned):
            keep.append((d, s, k))
    if not keep:
        keep = list(zip(docs, sims, keys))
    keep.sort(key=lambda x: x[1], reverse=True)

    pinned_docs = [d for d, _, k in keep if k in pinned]
    non_pinned  = [d for d, _, k in keep if k not in pinned]
    result = pinned_docs + [d for d in non_pinned if d not in pinned_docs]
    return result[:topn]

USE_RERANKER_FLAG = False
reranker = None

def rerank(query: str, docs: List[Document], topn: int = K_FINAL) -> List[Document]:
    if not USE_RERANKER_FLAG or not docs:
        return docs[:topn]
    return docs[:topn]


# ============================================================
# 7) HYBRID RETRIEVER (FAISS + BM25) + retrieve()
# ============================================================
def _call_retriever(ret, q):
    if hasattr(ret, "get_relevant_documents"):
        return ret.get_relevant_documents(q)
    if hasattr(ret, "invoke"):
        out = ret.invoke(q)
        if isinstance(out, list):
            return out
        if isinstance(out, dict) and "documents" in out:
            return out["documents"]
    return []

class MinimalEnsembleRetriever:
    def __init__(self, retrievers, weights=[0.6,0.4], k=K_INITIAL):
        self.retrievers = retrievers
        self.weights = weights
        self.k = k

    def get_relevant_documents(self, query: str) -> List[Document]:
        scored = defaultdict(float)
        picks  = {}
        for w, ret in zip(self.weights, self.retrievers):
            docs = _call_retriever(ret, query)
            for rank, d in enumerate(docs[:self.k], start=1):
                key = _doc_key(d)
                picks[key] = d
                scored[key] += w * (1.0 / rank)
        ranked = sorted(scored.items(), key=lambda x: x[1], reverse=True)
        return [picks[k] for k,_ in ranked[:self.k]]

faiss_ret = faiss_store.as_retriever(search_kwargs={"k": K_INITIAL})
bm25_ret.k = K_INITIAL
hybrid = MinimalEnsembleRetriever([faiss_ret, bm25_ret], [0.6,0.4], K_INITIAL)
print("[OK] Hybrid retriever ready.")

def retrieve(question: str, *, lung_only: Optional[bool]=None) -> List[Document]:
    # 1) Hybrid retrieval
    docs = hybrid.get_relevant_documents(question)

    # 2) Force-add any NCT mentioned in the question
    forced = [normalize_nct(n) for n in extract_ncts(question)]
    forced = [n for n in forced if n in nct_lookup]

    extras = []
    for n in forced:
        extras.extend(nct_lookup[n])

    docs = extras + docs  # extras get pinned in compression

    # 3) Metadata filter
    docs = metadata_filter(docs, lung_only=lung_only)

    # 4) Compress + rerank
    pinned_keys = {_doc_key(d) for d in extras}
    docs = compress_with_embeddings(
        question,
        docs,
        emb,
        SIM_THRESHOLD,
        K_FINAL,
        pinned=pinned_keys,
        key_fn=_doc_key
    )
    docs = rerank(question, docs, topn=K_FINAL)
    return docs

print("[READY] Retrieval stack initialised.")


# ============================================================
# 8) LLMs (Gemini Mini) + caching
# ============================================================
os.environ.pop("GOOGLE_APPLICATION_CREDENTIALS", None)

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import SystemMessage, HumanMessage



os.environ["GOOGLE_API_KEY"] = "AIzaSyCIskdnnHmnfiiyp8-SH7D5GmHyNa8NdT4"  # better: export in your shell

llm_judge = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
    convert_system_message=True,
)
# --- Semantic Embeddings for RAGAS ---
embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

llm_qa = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=os.environ.get("GOOGLE_API_KEY"),
    convert_system_message_to_human=True,
)

llm_judge = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0,
    max_output_tokens=2048,
    google_api_key=os.environ.get("GOOGLE_API_KEY"),
    convert_system_message_to_human=True,
)

print("[OK] LLMs for QA and judge ready.")

def _serialize_messages(messages: List[Any]) -> str:
    return json.dumps(
        [
            {
                "type": getattr(m, "type", "unknown"),
                "content": getattr(m, "content", str(m)),
            }
            for m in messages
        ],
        sort_keys=True,
        ensure_ascii=False,
    )

qa_cache: Dict[str, Any] = {}

def cached_llm_invoke(llm, messages: List[Any]):
    key = _serialize_messages(messages)
    if key in qa_cache:
        return qa_cache[key]
    resp = llm.invoke(messages)
    qa_cache[key] = resp
    return resp


# ============================================================
# 9) SYSTEM PROMPT + PROMPT BUILDER
# ============================================================
SYSTEM_PROMPT_QA = """
You are a clinical-trial analysis assistant.

You MUST answer ONLY using the information explicitly provided in the CONTEXT.
You are not allowed to use outside medical knowledge or to guess.

RULES:
1. If the answer is completely missing from the CONTEXT, say:
   "Not specified in the provided context."
2. If the CONTEXT contains only partial information, clearly separate:
   - what is explicitly specified, and
   - what is unknown.
   Do not invent or guess unknown parts.
3. Do NOT hallucinate trial details, drugs, phases, or endpoints.
4. Prefer structured, concise answers (tables, bullet lists, short paragraphs).
5. When referring to a specific trial, always mention its NCT number.
6. When asked for design improvements or recommendations, base them ONLY on
   the patterns and details visible in the CONTEXT.
"""

def build_user_prompt(question: str, contexts: List[str]) -> str:
    if not contexts:
        context_text = "NO CONTEXT FOUND."
    else:
        blocks = [f"[CONTEXT {i+1}]\n{c}" for i, c in enumerate(contexts)]
        context_text = "\n\n".join(blocks)

    return f"""
You will be given a RESEARCH QUERY and a CONTEXT.

CONSTRAINTS:
- Answer ONLY using the CONTEXT.
- If the answer is completely missing, say "Not specified in the provided context."
- If the answer is partial, clearly state what is known vs unknown.
- Do not use external knowledge.

RESEARCH QUERY:
{question}

CONTEXT:
{context_text}

Now provide your answer:
""".strip()


# ============================================================
# 10) QUERY SETS (NEW STRATEGY vs BASELINE)
# ============================================================
NEW_QUERIES: List[Dict[str, Any]] = [
    {
        "id": "Q1_lung_low_pred",
        "question": (
            "You are assisting in the design of a new lung cancer trial. "
            "Using only the trials in the CONTEXT, identify all lung-cancer-related "
            "trials with a predicted success score (fused_pred) of 0.4 or below. "
            "For each such trial, report: NCT number, phase, enrollment, "
            "disease indication(s), and main intervention(s). "
            "If the CONTEXT contains no lung trials with fused_pred <= 0.4, "
            "state this explicitly and instead list the lung cancer trials that "
            "have the lowest available fused_pred values in the CONTEXT. "
            "Then, in a short 'Recommendations' section, propose 2–3 concrete "
            "design changes (based only on CONTEXT) that might improve the "
            "predicted success or operational robustness of these low- or "
            "lower-scoring lung trials."
        ),
        "filters": {"lung_only": True},
    },
    {
        "id": "Q2_disease_comparison",
        "question": (
            "Compare all lung cancer trials in the CONTEXT. For each trial, provide: "
            "NCT number, phase, enrollment, and the drug(s) tested. Then compare their "
            "design choices along at least three axes (for example: sample size, phase, "
            "strictness of inclusion/exclusion criteria, single-arm vs multi-arm, or "
            "disease subtype). Based only on the CONTEXT, identify which lung trial "
            "appears most at risk of operational failure (e.g., poor recruitment or "
            "feasibility issues) and explain why, explicitly referencing NCT numbers. "
            "Finish with a brief 'Recommendations' section suggesting design adjustments "
            "that could reduce this operational risk."
        ),
        "filters": {"lung_only": True},
    },
    {
        "id": "Q3_failure_drivers",
        "question": (
            "From the trials in the CONTEXT, list any explicit design characteristics "
            "that could increase the risk of trial failure. Classify each into one of "
            "the following categories: efficacy-related, safety-related, "
            "recruitment-related, or endpoint-related. For each category, provide "
            "specific examples citing NCT numbers. Then, in a 'Mitigation Strategies' "
            "section, propose 1–2 possible design adjustments per category that could "
            "reduce the risk of failure, staying strictly within what is implied by "
            "the CONTEXT."
        ),
        "filters": {"lung_only": None},
    },
    {
        "id": "Q4_trial_table",
        "question": (
            "Reconstruct a table of all distinct trials present in the CONTEXT with "
            "columns: NCT number, phase, enrollment, diseases, drugs, and any available "
            "predicted success score (fused_pred). Sort the table by fused_pred in "
            "ascending order (from lowest predicted success to highest), placing trials "
            "with missing fused_pred at the bottom. After the table, briefly comment on "
            "any observed patterns in phase, enrollment, and disease areas across the "
            "range of fused_pred values."
        ),
        "filters": {"lung_only": None},
    },
    {
        "id": "Q5_high_vs_low_pred",
        "question": (
            "Within the CONTEXT, contrast trials with high predicted success scores "
            "(fused_pred > 0.7) against those with low scores (fused_pred < 0.3). "
            "For each group, summarise the typical phase, enrollment ranges, and "
            "disease areas that appear in the CONTEXT. Then describe any design "
            "patterns that seem to distinguish high-scoring from low-scoring trials "
            "(for example: phase, sample size, disease category, or other features "
            "visible in the CONTEXT). Conclude with a short 'Design Heuristics' "
            "section outlining 2–3 principles that might help move a trial from the "
            "low-scoring group toward the high-scoring group, based only on what you "
            "see in the CONTEXT."
        ),
        "filters": {"lung_only": None},
    },
]

BASELINE_QUERIES: List[Dict[str, Any]] = [
    {
        "id": "B1_lung_low_pred",
        "question": (
            "List all lung cancer trials with fused_pred <= 0.4 from the CONTEXT. "
            "Give their NCT numbers, phases, and enrollment. If no such trials can be "
            "identified from the CONTEXT, state that explicitly."
        ),
        "filters": {"lung_only": True},
    },
    {
        "id": "B2_all_lung",
        "question": (
            "Using the CONTEXT, describe all lung cancer trials: their NCT numbers, "
            "phase, and enrollment."
        ),
        "filters": {"lung_only": True},
    },
    {
        "id": "B3_general_design",
        "question": (
            "Summarise the design of the trials in the CONTEXT, including diseases, "
            "phases, enrollment, and interventions."
        ),
        "filters": {"lung_only": None},
    },
]


# ============================================================
# 11) RAG ANSWER GENERATION
# ============================================================
def run_query_batch(
    queries: List[Dict[str, Any]],
    llm,
    *,
    name_prefix: str = "run",
) -> List[Dict[str, Any]]:
    results: List[Dict[str, Any]] = []

    for q in queries:
        qid = f"{name_prefix}_{q['id']}"
        question = q["question"]
        filters = q.get("filters", {}) or {}

        docs = retrieve(
            question,
            lung_only=filters.get("lung_only", None),
        )
        contexts = [d.page_content for d in docs]

        user_prompt = build_user_prompt(question, contexts)
        messages = [
            SystemMessage(content=SYSTEM_PROMPT_QA),
            HumanMessage(content=user_prompt),
        ]

        try:
            resp = cached_llm_invoke(llm, messages)
            answer = (resp.content or "").strip()
        except Exception as e:
            answer = f"[ERROR during LLM call: {e}]"

        results.append(
            {
                "id": qid,
                "question": question,
                "answer": answer,
                "contexts": contexts,
            }
        )
        print(f"[OK] Generated answer for {qid} (context chunks: {len(contexts)})")

    return results

print("\n=== Generating NEW STRATEGY answers ===")
new_results = run_query_batch(NEW_QUERIES, llm_qa, name_prefix="strategy")

print("\n=== Generating BASELINE answers ===")
baseline_results = run_query_batch(BASELINE_QUERIES, llm_qa, name_prefix="baseline")


# ============================================================
# 12) RAGAS EVALUATION (faithfulness + answer_relevancy)
# ============================================================
!pip install -q ragas datasets

from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy

def results_to_dataset_no_ref(results: List[Dict[str, Any]]) -> Dataset:
    return Dataset.from_dict(
        {
            "question": [r["question"] for r in results],
            "answer":   [r["answer"]   for r in results],
            "contexts": [r["contexts"] for r in results],
        }
    )

metrics_no_ref = [faithfulness, answer_relevancy]

# NEW STRATEGY
ds_new = results_to_dataset_no_ref(new_results)
scores_new = evaluate(
    ds_new,
    metrics=metrics_no_ref,
    llm=llm_judge,
    embeddings=emb,
)

print("\n=== RAGAS – NEW STRATEGY ===")
print(scores_new)

# BASELINE
ds_base = results_to_dataset_no_ref(baseline_results)
scores_base = evaluate(
    ds_base,
    metrics=metrics_no_ref,
    llm=llm_judge,
    embeddings=emb,
)

print("\n=== RAGAS – BASELINE ===")
print(scores_base)


# ============================================================
# 13) SUMMARY + DATAFRAME (context, answer, metrics)
# ============================================================
def summarize_scores(result, label: str):
    print(f"\n--- {label} ---")
    df_m = result.to_pandas()
    means = df_m.mean(numeric_only=True)
    for metric_name, value in means.items():
        print(f"{metric_name}: {value:.4f}")

summarize_scores(scores_new,  "NEW STRATEGY")
summarize_scores(scores_base, "BASELINE")

def ragas_to_df(result):
    df_m = result.to_pandas()
    if "faithfulness" in df_m.columns:
        df_m["faithfulness"] = df_m["faithfulness"].astype(float)
    if "answer_relevancy" in df_m.columns:
        df_m["answer_relevancy"] = df_m["answer_relevancy"].astype(float)
    return df_m

def build_results_df(results: List[Dict[str, Any]], ragas_result, label: str) -> pd.DataFrame:
    df_rag = pd.DataFrame(results)
    df_rag["context_text"] = df_rag["contexts"].apply(lambda xs: "\n---\n".join(xs))
    df_metrics = ragas_to_df(ragas_result)
    df_full = pd.concat([df_rag, df_metrics], axis=1)
    df_full["model"] = label
    return df_full

df_new  = build_results_df(new_results, scores_new,  "NEW_STRATEGY")
df_base = build_results_df(baseline_results, scores_base, "BASELINE")

df_all = pd.concat([df_new, df_base], ignore_index=True)

print("\n=== Combined RAG + RAGAS DataFrame ===")
df_all

# Optionally save:
df_all.to_csv("ragas_results_full.csv", index=False)


[INFO] Loading CSVs...
Rows loaded: 3666
[INFO] Building documents...
Document chunks: 16587


  emb = HuggingFaceEmbeddings(


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[INFO] Building FAISS index...
FAISS vectors: 16587
BM25 ready.
[OK] NCT lookup with 3574 distinct trials.
[OK] Hybrid retriever ready.
[READY] Retrieval stack initialised.


                convert_system_message was transferred to model_kwargs.
                Please confirm that convert_system_message is what you intended.
  exec(code_obj, self.user_global_ns, self.user_ns)


[OK] LLMs for QA and judge ready.

=== Generating NEW STRATEGY answers ===
[OK] Generated answer for strategy_Q1_lung_low_pred (context chunks: 8)
[OK] Generated answer for strategy_Q2_disease_comparison (context chunks: 7)
[OK] Generated answer for strategy_Q3_failure_drivers (context chunks: 8)
[OK] Generated answer for strategy_Q4_trial_table (context chunks: 8)
[OK] Generated answer for strategy_Q5_high_vs_low_pred (context chunks: 8)

=== Generating BASELINE answers ===
[OK] Generated answer for baseline_B1_lung_low_pred (context chunks: 8)
[OK] Generated answer for baseline_B2_all_lung (context chunks: 8)
[OK] Generated answer for baseline_B3_general_design (context chunks: 8)
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
langchain-google-genai 3.0.3 requires langchain-core<2.0.0,>=1.0.0, but you have langchain-core 0.3.79 which is incompatible.[0

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 15, model: gemini-2.0-flash
Please retry in 48.831685436s. [links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerMinutePerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-2.0-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 15
}
, retry_delay {
  seconds: 48
}
].
* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 15, model: gemini-2.0-flash
Please retry in 48.831428058s. [links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, violations {
  quota_metric: "generativelanguage.googleapis.


=== RAGAS – NEW STRATEGY ===
{'faithfulness': 0.9444, 'answer_relevancy': 0.7093}


Evaluating:   0%|          | 0/6 [00:00<?, ?it/s]

* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 15, model: gemini-2.0-flash
Please retry in 45.987008743s. [links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, violations {
  quota_metric: "generativelanguage.googleapis.com/generate_content_free_tier_requests"
  quota_id: "GenerateRequestsPerMinutePerProjectPerModel-FreeTier"
  quota_dimensions {
    key: "model"
    value: "gemini-2.0-flash"
  }
  quota_dimensions {
    key: "location"
    value: "global"
  }
  quota_value: 15
}
, retry_delay {
  seconds: 45
}
].
* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 15, model: gemini-2.0-flash
Please retry in 43.913146218s. [links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, violations {
  quota_metric: "generativelanguage.googleapis.


=== RAGAS – BASELINE ===
{'faithfulness': 1.0000, 'answer_relevancy': 0.8251}

--- NEW STRATEGY ---
faithfulness: 0.9444
answer_relevancy: 0.7093

--- BASELINE ---
faithfulness: 1.0000
answer_relevancy: 0.8251

=== Combined RAG + RAGAS DataFrame ===


In [5]:
df_all

Unnamed: 0,id,question,answer,contexts,context_text,user_input,retrieved_contexts,response,faithfulness,answer_relevancy,model
0,strategy_Q1_lung_low_pred,You are assisting in the design of a new lung ...,"Based on the provided context, the lung cancer...",[NCT: NCT00162318 | PHASE: I | ENROLLMENT: 30....,NCT: NCT00162318 | PHASE: I | ENROLLMENT: 30.0...,You are assisting in the design of a new lung ...,[NCT: NCT00162318 | PHASE: I | ENROLLMENT: 30....,"Based on the provided context, the lung cancer...",1.0,0.801968,NEW_STRATEGY
1,strategy_Q2_disease_comparison,Compare all lung cancer trials in the CONTEXT....,Here's a comparison of the lung cancer trials ...,[NCT: NCT00600821 | PHASE: II | ENROLLMENT: 11...,NCT: NCT00600821 | PHASE: II | ENROLLMENT: 118...,Compare all lung cancer trials in the CONTEXT....,[NCT: NCT00600821 | PHASE: II | ENROLLMENT: 11...,Here's a comparison of the lung cancer trials ...,0.833333,0.718272,NEW_STRATEGY
2,strategy_Q3_failure_drivers,"From the trials in the CONTEXT, list any expli...","Based on the provided context, here are potent...",[hampering compliance with the study protocol ...,hampering compliance with the study protocol r...,"From the trials in the CONTEXT, list any expli...",[hampering compliance with the study protocol ...,"Based on the provided context, here are potent...",,0.796983,NEW_STRATEGY
3,strategy_Q4_trial_table,Reconstruct a table of all distinct trials pre...,| NCT Number | Phase | Enrollment | Diseases |...,[NCT: NCT00396864 | PHASE: I | ENROLLMENT: 51....,NCT: NCT00396864 | PHASE: I | ENROLLMENT: 51.0...,Reconstruct a table of all distinct trials pre...,[NCT: NCT00396864 | PHASE: I | ENROLLMENT: 51....,| NCT Number | Phase | Enrollment | Diseases |...,1.0,0.502983,NEW_STRATEGY
4,strategy_Q5_high_vs_low_pred,"Within the CONTEXT, contrast trials with high ...",**Trials with High Predicted Success (fused_pr...,[NCT: NCT02871440 | PHASE: III | ENROLLMENT: 4...,NCT: NCT02871440 | PHASE: III | ENROLLMENT: 40...,"Within the CONTEXT, contrast trials with high ...",[NCT: NCT02871440 | PHASE: III | ENROLLMENT: 4...,**Trials with High Predicted Success (fused_pr...,,0.726461,NEW_STRATEGY
5,baseline_B1_lung_low_pred,List all lung cancer trials with fused_pred <=...,There are no lung cancer trials with fused_pre...,[NCT: NCT00864721 | PHASE: II | ENROLLMENT: 63...,NCT: NCT00864721 | PHASE: II | ENROLLMENT: 63....,List all lung cancer trials with fused_pred <=...,[NCT: NCT00864721 | PHASE: II | ENROLLMENT: 63...,There are no lung cancer trials with fused_pre...,1.0,0.756768,BASELINE
6,baseline_B2_all_lung,"Using the CONTEXT, describe all lung cancer tr...","Here is a summary of the lung cancer trials, b...",[NCT: NCT01441752 | PHASE: III | ENROLLMENT: 3...,NCT: NCT01441752 | PHASE: III | ENROLLMENT: 34...,"Using the CONTEXT, describe all lung cancer tr...",[NCT: NCT01441752 | PHASE: III | ENROLLMENT: 3...,"Here is a summary of the lung cancer trials, b...",1.0,0.864497,BASELINE
7,baseline_B3_general_design,Summarise the design of the trials in the CONT...,"Based on the provided context, here's a summar...","[the trial. Psychological, familial, sociologi...","the trial. Psychological, familial, sociologic...",Summarise the design of the trials in the CONT...,"[the trial. Psychological, familial, sociologi...","Based on the provided context, here's a summar...",,0.853964,BASELINE


In [6]:
df_all["answer"][1]

"Based on the provided context, the lung cancer trial with a fused_pred of 0.4 or below is:\n\n*   **NCT00404924**\n    *   Phase: III\n    *   Enrollment: 1140.0\n    *   Diseases: \\['Non-Small-Cell Lung Carcinoma']\n    *   Drugs: \\['zd6474 (vandetanib)']\n    *   Fused_pred: 0.296\n\n**Recommendations**\n\nBased on the context, here are a few potential design changes that might improve the predicted success or operational robustness of the lower-scoring lung cancer trial (NCT00404924):\n\n1.  **Eligibility Criteria Review:** The eligibility criteria for NCT00404924 exclude patients who have had standard cancer treatments within 4 weeks before the study or three or more prior chemotherapy regimens. Consider refining these criteria to potentially broaden the eligible patient population while maintaining safety and scientific rigor. For example, the criteria could be revised to allow enrollment of patients who have received a specific number of prior lines of therapy or to specify ac

In [9]:
df_all["question"][1]

"Compare all lung cancer trials in the CONTEXT. For each trial, provide: NCT number, phase, enrollment, and the drug(s) tested. Then compare their design choices along at least three axes (for example: sample size, phase, strictness of inclusion/exclusion criteria, single-arm vs multi-arm, or disease subtype). Based only on the CONTEXT, identify which lung trial appears most at risk of operational failure (e.g., poor recruitment or feasibility issues) and explain why, explicitly referencing NCT numbers. Finish with a brief 'Recommendations' section suggesting design adjustments that could reduce this operational risk."

In [8]:
df_all.to_csv("ragas_results_full.csv", index=False)