## 1. ENV

In [None]:
print("Installing dependencies...")
%pip install -U huggingface_hub chromadb langchain_google_genai sentence_transformers torch_geometric numpy tqdm datasets protobuf==3.20.3

In [None]:
import torch

# Keep the version strings themselves. `print(...)` returns None, so assigning
# its result would leave TORCH and CUDA as None.
TORCH = torch.__version__
CUDA = torch.version.cuda  # None on CPU-only PyTorch builds
print(TORCH)
print(CUDA)

In [None]:
import os
from dotenv import load_dotenv
from huggingface_hub import login

SEED = 42
load_dotenv()  # populate os.environ from a local .env file, if one exists

# HF_TOKEN may legitimately be unset (e.g. when only public Hub assets are used).
# huggingface_hub.login(token=None) falls back to an interactive prompt
# (notebook widget / terminal), which stalls unattended runs, so only log in
# when a token is actually configured.
HF_TOKEN = os.getenv('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("HF_TOKEN not set; skipping Hugging Face Hub login.")

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY not found in environment variables.")

import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


## 2. Dataset loading

data loading

In [None]:
# load the Qasper dataset
from datasets import load_dataset, Dataset
# All Qasper splits, cached on disk under ./data/Qasper/qasper_cache.
dataset = load_dataset("allenai/qasper", cache_dir='./data/Qasper/qasper_cache')

# shuffle
# Only the train split is used below. Shuffling and the validation/test
# splits are currently disabled (commented out).
trainset = dataset['train'] #.shuffle(seed=SEED)
# validset = dataset['validation'] #.shuffle(seed=SEED)
# testset  = dataset['test'] #.shuffle(seed=SEED)

data preprocessing: ensure data integrity

In [None]:
import re
from copy import deepcopy
from typing import Dict, List

FLOAT_TAG = "FLOAT SELECTED: "  # evidence에서 제거할 태그
REF_TAG = re.compile(r" BIBREF\d+")
# FIG_TAG = re.compile(r" FIGREF\d+")

def _clear_tag(ev_list: List[str]) -> List[str]:
    """evidence 리스트에서 'FLOAT SELECTED: ' 부분 문자열을 삭제하고
       내용이 남은 evidence만 반환
    """
    cleaned = []
    for ev in ev_list:
        new_ev = ev.replace(FLOAT_TAG, "")
        new_ev = REF_TAG.sub("", new_ev)
        new_ev = new_ev.strip()
        if new_ev:                       # 치환 후 내용이 남아 있을 때만 보존
            cleaned.append(new_ev)
    return cleaned

def _clean_paragraph(text: str) -> str:
    """Apply the evidence tag removal (FLOAT_TAG, " BIBREF<n>", strip) to one paragraph."""
    return REF_TAG.sub("", text.replace(FLOAT_TAG, "")).strip()


def filter_sample(sample: Dict) -> Dict:
    """Clean one Qasper sample (modified in place) and return it.

    - Every answer's evidence list is passed through ``_clear_tag``. An answer
      whose evidence becomes empty is dropped together with its
      ``annotation_id`` / ``worker_id``.
    - A question left with no answers is dropped together with all of its
      parallel ``qas`` fields.
    - Every paragraph in ``full_text["paragraphs"]`` is cleaned with the same
      rules, so cleaned evidence strings can still be found in the paper text
      by substring search. Paragraph positions are preserved (a paragraph may
      become ""), so section/paragraph indices stay aligned with the paper.
    """
    qas = sample["qas"]
    new_qas = {k: [] for k in qas}

    for idx in range(len(qas["question"])):
        # Deep copy so the caller's nested answer dicts are never mutated.
        blk = deepcopy(qas["answers"][idx])

        kept_det, kept_ann, kept_wid = [], [], []
        for det, ann, wid in zip(blk["answer"],
                                 blk.get("annotation_id", []),
                                 blk.get("worker_id", [])):
            cleaned_ev = _clear_tag(det.get("evidence", []))
            if cleaned_ev:  # keep the answer only if some evidence survives
                det["evidence"] = cleaned_ev  # `det` already belongs to the copied `blk`
                kept_det.append(det)
                kept_ann.append(ann)
                kept_wid.append(wid)

        if kept_det:  # keep the question only if at least one answer survives
            blk.update(
                answer        = kept_det,
                annotation_id = kept_ann,
                worker_id     = kept_wid,
            )
            new_qas["answers"].append(blk)
            for k, values in qas.items():
                if k != "answers":
                    new_qas[k].append(values[idx])

    # Rebuild the paragraph lists. Rebinding a loop variable would not change
    # them, and uncleaned paragraphs still contain " BIBREF<n>" markers that
    # the cleaned evidence no longer has.
    sample["full_text"]["paragraphs"] = [
        [_clean_paragraph(para) for para in section]
        for section in sample["full_text"]["paragraphs"]
    ]

    sample["qas"] = new_qas
    return sample


In [None]:
# Apply the filtering function to every sample in the trainset.
# Dataset.map builds a new Dataset from filter_sample's returned samples.
trainset = trainset.map(filter_sample)
# validset = validset.map(filter_sample)
# testset = testset.map(filter_sample)

Validate data preprocessing

In [None]:
def validate_evidence_clean(sample: Dict) -> bool:
    """Return True iff every answer of every question has a non-empty evidence list.

    Each offending answer is reported by question/answer index. Leftover tags
    inside evidence strings are NOT checked here; validate_no_float_selected
    does that.
    """
    all_present = True
    for q_idx, blk in enumerate(sample["qas"]["answers"]):
        for a_idx, det in enumerate(blk["answer"]):
            if det.get("evidence", []):
                continue
            print(f"[오류] Q{q_idx}-A{a_idx}: evidence가 비어 있습니다.")
            all_present = False
    return all_present

In [None]:
def validate_no_float_selected(sample: Dict) -> bool:
    """Return True iff no evidence string in *sample* still contains FLOAT_TAG.

    Every offending evidence string is printed with its question/answer index.
    """
    tag_free = True
    for q_idx, blk in enumerate(sample["qas"]["answers"]):
        for a_idx, det in enumerate(blk["answer"]):
            leftovers = [ev for ev in det.get("evidence", []) if FLOAT_TAG in ev]
            for ev in leftovers:
                print(f"[오류] Q{q_idx}-A{a_idx}: 미제거 태그 발견 → “{ev}”")
                tag_free = False
    return tag_free


In [None]:
is_not_clean = True  # NOTE(review): not read anywhere in this notebook

def validate_all_samples(dataset):
    """Check samples in order and stop at the first failing one.

    Runs validate_evidence_clean, then validate_no_float_selected, on each
    sample. Prints which sample failed and returns False on the first failure;
    returns True if every sample passes.
    """
    for i, sample in enumerate(dataset):
        failure = None
        if not validate_evidence_clean(sample):
            failure = f"[오류] 샘플 #{i}에 빈 evidence가 존재합니다."
        elif not validate_no_float_selected(sample):
            failure = f"[오류] 샘플 #{i}에 미제거 태그가 존재합니다."
        if failure is not None:
            print(failure)
            return False
    return True

if not validate_all_samples(trainset):
    raise ValueError("Trainset validation failed.")
# if not validate_all_samples(validset):
#     raise ValueError("Validset validation failed.")
# if not validate_all_samples(testset):
#     raise ValueError("Testset validation failed.")

***
## New Approach

In [None]:
import os, re, hashlib, itertools, shutil, json, time
from pathlib import Path
from typing import List, Tuple, Dict, Iterable

HF_CACHE_FOLDER     = "./data/.cache/scibert-nli"   # local cache folder for the sentence-encoder weights
CHROMA_DIR          = "./demo_chroma_db"            # persistent Chroma store; reassigned in the ChromaDB cell
# CHROMA_DIR          = "./chroma_qasper_10"
COLLECTION_SENT_NAME     = "paper_sentences"   # one record per sentence (embedding + position metadata)
COLLECTION_RELA_NAME     = "paper_relations"   # sentence-pair relations (writer is currently commented out)


# Clean previous run
# if Path(CHROMA_DIR).exists():
#     shutil.rmtree(CHROMA_DIR)

In [None]:
from sentence_transformers import SentenceTransformer
print("▸ Loading SentenceTransformer …")
# os.environ['HF_HOME'] = './data/.cache'  # Set cache directory for HuggingFace models
# Sentence encoder used by flush_sentence to embed stored sentences and by
# semantic_query to embed queries; loaded onto DEVICE.
ENC_MODEL = SentenceTransformer('gsarti/scibert-nli', cache_folder=HF_CACHE_FOLDER, device=DEVICE)
# NOTE(review): presumably the constructor already raises on failure, which
# would make this check defensive only.
if not ENC_MODEL:
    raise RuntimeError("Failed to load the SentenceTransformer model.")

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

print("▸ Loading Gemini 1.5 Flash model …")
# Raises KeyError if unset (the ENV cell already checks this via os.getenv).
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
LLM = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0.0  # minimise sampling randomness for label outputs
)
if not LLM:
    raise RuntimeError("Failed to load the Gemini 1.5 Flash model.")
else: print("▸ Gemini 1.5 Flash loaded (LangChain wrapper).")

# Discourse-relation labels for sentence pairs. List order defines the index
# used by LABEL2IDX and one_hot_encode_labels; "Other" is the fallback label.
LABELS = ["Claim", "Evidence", "Background", "Method", "Result",
          "Interpretation", "Contrast", "Cause-Effect",
          "Temporal", "Condition", "Other"]
LABEL2IDX = {lab: i for i, lab in enumerate(LABELS)}

# Not used in this script, but kept for further study.
# Placeholders: {PAIRS} (see format_sentence_pairs) and {N} (number of pairs).
# The opening `"""\` must be followed immediately by the newline. A trailing
# space turns it into the invalid escape "\ " instead of a line continuation.
PROMPT = """\
[MUST FOLLOW Task]:
1. You are an academic discourse analyst.
2. Classify how Sentence_B is related to Sentence_A using EXACTLY ONE "Allowed labels".
3. If none apply, output "Other".

[Allowed labels]:
Claim, Evidence, Background, Method, Result, Interpretation,
Contrast, Cause-Effect, Temporal, Condition, Other

[Sentence pairs]:
{PAIRS}

Respond only with a list of labels, in order:
1: <label>
2: <label>
...
{N}: <label>
"""

def one_hot_encode_labels(labels: list[str]) -> list[list[float]]:
    """One-hot encode *labels* over LABELS (index = position in LABELS).

    A label that is not in LABELS yields an all-zero vector.
    """
    position = {name: pos for pos, name in enumerate(LABELS)}
    width = len(LABELS)
    encoded = []
    for label in labels:
        row = [0.0] * width
        if label in position:
            row[position[label]] = 1.0
        encoded.append(row)
    return encoded

def format_sentence_pairs(pairs: list[tuple[str, str]]) -> str:
    """Render (A, B) sentence pairs as a 1-based numbered list for PROMPT's {PAIRS} slot."""
    entries = []
    for number, (first, second) in enumerate(pairs, start=1):
        entries.append(f"{number}.\nSentence_A: {first}\nSentence_B: {second}")
    return "\n".join(entries)

def parse_llm_labels(response: str, expected: int) -> List[str]:
    """Turn an LLM reply of "<n>: <label>" lines into exactly *expected* labels.

    Each line containing ":" contributes one label: the stripped text after
    the first ":", or "Other" when that text is not in LABELS. Lines without
    ":" are ignored. The list is then padded with "Other" or truncated to
    *expected* entries.
    """
    parsed = []
    for row in response.strip().splitlines():
        _, sep, tail = row.partition(":")
        if not sep:
            continue  # not a "<n>: <label>" line
        candidate = tail.strip()
        parsed.append(candidate if candidate in LABELS else "Other")

    parsed += ["Other"] * max(0, expected - len(parsed))
    return parsed[:expected]

import re
# Sentence delimiters: runs of . ! ? ; :  -- except when the run sits between
# two digits (decimals like 92.5, versions like 2.0.1, ratios like 3:1).
SPLIT_PUNCS = re.compile(r"(?<!\d)[.!?;:]+|[.!?;:]+(?!\d)")

def split_sentences(paragraph: str) -> list[str]:
    """Split *paragraph* into sentences on punctuation runs.

    Punctuation between two digits is not a boundary, so numbers stay inside
    their sentence. Delimiters are dropped, pieces are stripped, and empty
    pieces are removed.

    NOTE: this changes segmentation compared with a plain split on every
    punctuation mark. A Chroma store built with the old splitter should be
    rebuilt, not upserted into, because sentence ids are derived from
    sentence positions.
    """
    sentences = SPLIT_PUNCS.split(paragraph.strip())
    return [s.strip() for s in sentences if s.strip()]  # drop empty pieces


Initialising ChromaDB

In [None]:
import chromadb

print("▸ Initialising ChromaDB …")
# NOTE(review): reassigns the CHROMA_DIR from the config cell
# ("./demo_chroma_db") to a scratch store for this test run.
CHROMA_DIR = './chroma_test_db'
print(CHROMA_DIR)
client = chromadb.PersistentClient(path=CHROMA_DIR)
# get_or_create: reuse the collections if this store already has them.
col_sent  = client.get_or_create_collection(name=COLLECTION_SENT_NAME)
col_rel   = client.get_or_create_collection(name=COLLECTION_RELA_NAME)
col_sent.count(), col_rel.count()  # shown as cell output: current record counts


In [None]:
import numpy as np, hashlib
from tqdm import tqdm
from datasets import DatasetDict
from typing import Dict, Iterable, Tuple, Optional

def generate_sent_nodes(dataset: Dataset) -> Iterable[Dict]:
    """Yield one node dict per sentence of every paper in *dataset*.

    *dataset* is iterated sample by sample (e.g. ``trainset.select(...)``).
    ``len(dataset)`` is only used for the progress bar. Each yielded dict has:
      - ``sid``:  SHA-1 hex digest of "{paper_id}/sec{i}/para{j}/sent{k}"
      - ``sent``: the stripped sentence text
      - ``prev``: ``{"sid": ..., "sent": ...}`` of the preceding sentence in
                  the same paragraph, or None for a paragraph's first sentence
      - ``meta``: paper_id, sec_idx, para_idx, sent_idx
    """
    for sample in tqdm(dataset, total=len(dataset), desc="Generating nodes"):
        paper_id = sample["id"]
        # zip() with section names keeps the original pairing/truncation semantics.
        sections = zip(sample["full_text"]["section_name"],
                       sample["full_text"]["paragraphs"])

        for sec_idx, (_sec_name, sec_content) in enumerate(sections):
            for para_idx, para in enumerate(sec_content):
                para_id = f"{paper_id}/sec{sec_idx}/para{para_idx}"
                prev: Optional[Dict[str, str]] = None  # reset at every paragraph start

                for sent_idx, sent in enumerate(filter(None, split_sentences(para.strip()))):
                    text = sent.strip()
                    sid = hashlib.sha1(f"{para_id}/sent{sent_idx}".encode()).hexdigest()

                    yield {
                        "sid":    sid,
                        "sent":   text,
                        "prev":   prev,
                        "meta":   dict(paper_id=paper_id, sec_idx=sec_idx,
                                       para_idx=para_idx, sent_idx=sent_idx),
                    }
                    prev = dict(sid=sid, sent=text)

In [None]:
import logging, sys
from typing import List, Tuple

def flush_sentence(ids, sents, metas):
    """Embed *sents* with ENC_MODEL (L2-normalised, one batch) and upsert into col_sent.

    Records are keyed by *ids* and carry *metas* as metadata. An empty *ids*
    makes this a no-op.
    """
    if ids:
        vectors = ENC_MODEL.encode(
            sents,
            batch_size=len(sents),
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
            device=DEVICE,
        )
        col_sent.upsert(
            ids=ids,
            documents=sents,
            metadatas=metas,
            embeddings=vectors,
        )


# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# if not logger.handlers:
#     handler = logging.StreamHandler(sys.stdout)
#     formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
#     handler.setFormatter(formatter)
#     logger.addHandler(handler)
# async def flush_pairs_async(pairs, batch_size: int = 50) -> None:
#     if not pairs:
#         logging.debug("flush_pairs_async: Received empty pair buffer. Skipping.")
#         return

#     ab_pairs = [(pair['prev_sen'], pair['curr_sen']) for pair in pairs]
#     logging.info(f"↪ [flush_pairs_async] Processing {len(pairs)}")

#     # TODO : LLM labeling

#     for pair in pairs:
#         prev_sid = pair['prev_sid']
#         curr_sid = pair['curr_sid']
#         rel_id = f"{prev_sid}|{curr_sid}"
#         try:
#             col_rel.upsert(
#                 ids        = [rel_id],
#                 documents  = [f"{prev_sid} <REL> {curr_sid}"],
#                 embeddings = [list(np.ndarray(arange(11)))], # [vec],
#                 metadatas  = [{
#                     # "relation_label": lab,
#                     "paper_id": pair['meta']['paper_id'],
#                     "sec_idx": pair['meta']['sec_idx'],
#                     "para_idx": pair['meta']['para_idx'],
#                     "sid_src": prev_sid,
#                     "sid_dst": curr_sid,
#                 }],
#             )
#         except Exception as e:
#             logging.warning(f"[Upsert Failed] ID: {rel_id} — {e}")


In [None]:
import asyncio

async def process_dataset(nodes, batch_sent=32, batch_pairs=50):
    """Stream sentence nodes into Chroma, flushing every *batch_sent* sentences.

    Sentences are buffered and written with ``flush_sentence`` whenever
    *batch_sent* of them have accumulated, plus once at the end.

    Adjacent-sentence pairs (nodes with a ``prev``) are collected for a
    relation-labelling step (``flush_pairs_async``) that is currently
    disabled. Until it is enabled, the pair buffer is discarded every
    *batch_pairs* pairs so memory does not grow with the dataset size.
    """
    ids, sents, metas = [], [], []
    pair_buf          = []
    tasks             = []

    for node in nodes:
        ids.append(node["sid"])
        sents.append(node["sent"])
        metas.append(node["meta"])

        if node["prev"]:
            pair_buf.append(dict(meta=node['meta'],
                                 prev_sid=node["prev"]["sid"],  curr_sid=node["sid"],
                                 prev_sen=node["prev"]["sent"], curr_sen=node["sent"]))

        if len(ids) >= batch_sent:
            flush_sentence(ids, sents, metas)
            ids, sents, metas = [], [], []

        if len(pair_buf) >= batch_pairs:
            # TODO: tasks.append(asyncio.create_task(flush_pairs_async(pair_buf)))
            pair_buf = []

    if ids:
        flush_sentence(ids, sents, metas)
    # TODO: flush the remaining pair_buf once relation labelling is enabled.

    # Wait for any scheduled background work.
    if tasks:
        await asyncio.gather(*tasks)


In [None]:
print("▸ Embedding sentences + relations …")
subset = trainset.select(range(2))  # for testing
# Top-level `await` relies on the notebook's running event loop.
# NOTE(review): only sentences are stored; pair/relation flushing is disabled
# inside process_dataset, despite the messages.
await process_dataset(generate_sent_nodes(subset))
print("▸ Done embedding sentences + relations.")
print(f"▸ Stored sentences: {col_sent.count():,}")
# print(f"▸ Stored vectors   : {col_rel.count():,}")

In [None]:
# For temporary testing
def semantic_query(query: str, paper_id: str, k: int = 5):
    """Print the *k* sentences of *paper_id* whose embeddings are nearest to *query*.

    The query is embedded with ENC_MODEL (L2-normalised), and the search is
    restricted to records whose metadata ``paper_id`` matches.
    """
    query_vec = ENC_MODEL.encode(query, normalize_embeddings=True)
    hits = col_sent.query(
        query_embeddings=[query_vec],
        n_results=k,
        include=["documents", "metadatas", "distances"],
        where={"paper_id": paper_id},
    )

    print(f"\n[Doc = {paper_id}]  {query!r}")
    per_hit = zip(hits["documents"][0], hits["distances"][0], hits["ids"][0], hits["metadatas"][0])
    for doc, _dist, _sid, _meta in per_hit:
        print(f" • doc = {doc}")


# sample = trainset[4]
# print(len(sample['qas']['question']))
# query = sample['qas']['question'][0]
# id = sample['id']
# topk= 5
# semantic_query(query, id, k=topk)

## Training

In [None]:
# Load the collections used for training.
from chromadb import PersistentClient
# This cell reads ./demo_chroma_db (presumably built with the original
# CHROMA_DIR setting), not the scratch CHROMA_DIR of the test cell, so print
# the path that is actually opened.
TRAIN_CHROMA_DIR = './demo_chroma_db'
print(TRAIN_CHROMA_DIR)
client = PersistentClient(path=TRAIN_CHROMA_DIR)
col_sent = client.get_or_create_collection(name=COLLECTION_SENT_NAME)
col_rel = client.get_or_create_collection(name=COLLECTION_RELA_NAME)
col_rel.count(), col_sent.count()

Training

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import negative_sampling
from collections import defaultdict

# 1) Load sentence nodes from ChromaDB

# 1.1) Sentence node features and metadata
sent_data   = col_sent.get(include=["embeddings", "metadatas"])
# Metadata written by generate_sent_nodes: paper_id, sec_idx, para_idx, sent_idx
sent_metas  = sent_data["metadatas"]
node_ids    = sent_data["ids"]                                          # List[str]
feats       = torch.tensor(sent_data["embeddings"], dtype=torch.float)  # [N, D] (D = 768 in this setup)

# sid -> global row index into `feats`
id2idx = {sid: i for i, sid in enumerate(node_ids)}

# paper -> section -> paragraph -> [sid, ...]
tmp_tree = defaultdict(            # paper
    lambda: defaultdict(           # section
        lambda: defaultdict(list)  # paragraph -> sentence ids
    )
)

for idx, meta in enumerate(sent_metas):
    paper_idx = meta.get("paper_id", None)  # paper id
    sec_idx   = meta.get("sec_idx", None)   # section position in Qasper, not its title
    para_idx  = meta.get("para_idx", None)
    tmp_tree[paper_idx][sec_idx][para_idx].append(node_ids[idx])

# Put every paragraph's sentences in reading order (sent_idx). Do not rely on
# the order in which Chroma's get() returns records: the section/paper edges
# below treat sid_list[0] as the paragraph's first sentence.
for secs in tmp_tree.values():
    for paras in secs.values():
        for sid_list in paras.values():
            sid_list.sort(key=lambda sid: sent_metas[id2idx[sid]]["sent_idx"])

# Besides paragraph links, also connect sentences at section and paper level.
HIER_REL = ["paragraph", "section", "paper"]
hier2idx = {name: i for i, name in enumerate(HIER_REL)}

data_list = []  # one Data subgraph per paper


def _add_undirected_edge(nodes, edges, types, u, v, rel):
    """Record the u<->v link as two directed edges of hierarchy type *rel*."""
    nodes.update((u, v))
    edges.append((u, v))
    types.append(hier2idx[rel])
    edges.append((v, u))
    types.append(hier2idx[rel])


for pid, secs in tmp_tree.items():
    nodes, edges, types = set(), [], []

    # 2.1) Paragraph level: consecutive sentences of the same paragraph
    for sec_idx, paras in secs.items():
        for para_idx, sid_list in paras.items():
            idxs = [id2idx[sid] for sid in sid_list]  # already sorted by sent_idx
            for u, v in zip(idxs, idxs[1:]):
                _add_undirected_edge(nodes, edges, types, u, v, "paragraph")

    # 2.2) Section level: first sentences of consecutive paragraphs in a section
    for sec_idx, paras in secs.items():
        heads = [id2idx[paras[p][0]] for p in sorted(paras) if paras[p]]
        for u, v in zip(heads, heads[1:]):
            _add_undirected_edge(nodes, edges, types, u, v, "section")

    # 2.3) Paper level: first sentence of each section's first paragraph
    section_heads = []
    for sec_idx in sorted(secs):
        paras = secs[sec_idx]
        if not paras:
            continue
        first_para = min(paras)
        if paras[first_para]:
            section_heads.append(id2idx[paras[first_para][0]])
    for u, v in zip(section_heads, section_heads[1:]):
        _add_undirected_edge(nodes, edges, types, u, v, "paper")

    # 2.4) Build the Data object (papers without any edge are skipped)
    if not edges:
        continue
    uniq_nodes = sorted(nodes)
    g2l = {g: i for i, g in enumerate(uniq_nodes)}  # global -> local node index
    x_sub = feats[uniq_nodes]                        # [n_sub, D]
    edge_index = torch.tensor([
        [g2l[u] for u, _ in edges],
        [g2l[v] for _, v in edges],
    ], dtype=torch.long)
    edge_type = torch.tensor(types, dtype=torch.long)  # [E]; not consumed by GraphSAGE below

    data_list.append(Data(x=x_sub,
                          edge_index=edge_index,
                          edge_type=edge_type))

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.data import DataLoader
from torch_geometric.utils import negative_sampling

# 1) GraphSAGE definition (no per-relation weights: edge types are ignored)
class GraphSAGE(MessagePassing):
    """Single GraphSAGE-style layer with mean neighbour aggregation.

    h_i = ReLU(lin_self(x_i) + mean over incoming edges (j -> i) of lin_neigh(x_j))

    All edges share the same ``lin_neigh`` weights, whatever their edge_type.
    """
    def __init__(self, in_dim, hid_dim):
        super().__init__(aggr="mean")  # average incoming messages per target node
        self.lin_self  = torch.nn.Linear(in_dim, hid_dim)
        self.lin_neigh = torch.nn.Linear(in_dim, hid_dim)

    def forward(self, x, edge_index):
        # Self representation
        h_self  = self.lin_self(x)
        # Neighbour aggregation: message() per edge, then mean per target node
        h_neigh = self.propagate(edge_index, x=x)
        return F.relu(h_self + h_neigh)

    def message(self, x_j):
        # Linearly projected neighbour embedding. Keep the name `x_j`: PyG fills
        # `*_j` arguments with the source-node features of each edge.
        return self.lin_neigh(x_j)


# 2) Training setup
# The current home of PyG's DataLoader; torch_geometric.data.DataLoader is a
# deprecated alias.
from torch_geometric.loader import DataLoader

if not data_list:
    raise RuntimeError(
        "No paper subgraphs were built (empty collection or no multi-sentence "
        "paragraphs); nothing to train on."
    )

torch.manual_seed(SEED)  # reproducible init, shuffling and negative sampling

device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_dim    = data_list[0].x.size(-1)  # width of the stored sentence embeddings
model     = GraphSAGE(in_dim=in_dim, hid_dim=256).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs    = 1000
log_every = 50  # print the first, every 50th and the last epoch
loader    = DataLoader(data_list, batch_size=1, shuffle=True)

# 3) Contrastive (link-prediction style) training loop
for ep in range(1, epochs + 1):
    model.train()
    total_loss = 0.0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # 3.1) Node embeddings
        h = model(batch.x, batch.edge_index)             # [n_sub, 256]

        # 3.2) Positive (existing) edge similarity
        src, dst = batch.edge_index                      # [2, E]
        pos_sim  = (h[src] * h[dst]).sum(dim=1)          # [E]

        # 3.3) Negative sampling: node pairs that are not edges
        neg_idx  = negative_sampling(
            edge_index      = batch.edge_index,
            num_nodes       = h.size(0),
            num_neg_samples = src.size(0),
        )
        ns, nd   = neg_idx
        neg_sim  = (h[ns] * h[nd]).sum(dim=1)            # [<= E]

        # 3.4) Loss: push positive pairs up, negative pairs down
        loss_pos = -F.logsigmoid(pos_sim).mean()
        loss_neg = -F.logsigmoid(-neg_sim).mean()
        loss     = loss_pos + loss_neg

        # 3.5) Backward pass and update
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(loader)  # len(loader) > 0, guaranteed by the guard above
    if ep == 1 or ep % log_every == 0 or ep == epochs:
        print(f"[Epoch {ep:02d}] Avg Loss: {avg_loss:.4f}")


In [None]:
import pandas as pd

# One row per paragraph: how many sentence nodes it contributed to tmp_tree.
rows = [
    {
        "paper_id": pid,
        "section": sec_idx,
        "paragraph": para_idx,
        "num_sentences": len(sids),
    }
    for pid, secs in tmp_tree.items()
    for sec_idx, paras in secs.items()
    for para_idx, sids in paras.items()
]

df = pd.DataFrame(rows)
print(df.head(10))

## 4. LangGraph State and Node config

In [None]:
# Re-bind the sentence collection from the most recently opened `client`.
col_sent = client.get_or_create_collection(COLLECTION_SENT_NAME)

In [None]:
# %% LangGraph & 노드 정의 (이하는 기존과 동일)
import os
from langgraph.graph import StateGraph, START, END
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer, CrossEncoder

from typing import TypedDict

# Model setup
# Queries must be embedded with the model that embedded the stored sentences.
# flush_sentence indexes col_sent with ENC_MODEL (gsarti/scibert-nli). A
# different encoder (e.g. all-mpnet-base-v2, also 768-d, so Chroma would not
# complain) would place queries in an unrelated vector space.
encoder     = ENC_MODEL
ce_reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
llm         = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=os.getenv('GOOGLE_API_KEY'),
    temperature=0.0
)

# State Schema
class InputState(TypedDict):
    """Graph input: the paper to search and the question to answer."""
    paper_id: str
    question: str

class OutputState(TypedDict):
    """Graph output: generated answer plus the reranked context it was based on."""
    paper_id: str
    answer: str
    top_docs: list[str]
    top_metadatas: list[dict]

class OverallState(InputState, OutputState):
    """Internal state: input/output keys plus the raw retrieval results."""
    retrieved_docs: list[str]
    retrieved_metadatas: list[dict]

# retrieve node: search the sentence collection
def retrieve(state: OverallState) -> OverallState:
    """Fetch the 10 sentences of ``state["paper_id"]`` nearest to the question.

    Writes ``retrieved_docs`` and ``retrieved_metadatas`` into *state* and
    returns it.

    Raises:
        RuntimeError: if the sentence collection is not available.
    """
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if col_sent is None:
        raise RuntimeError("Sentence-level collection not found")
    q        = state["question"]
    paper_id = state["paper_id"]
    # 1) Embed the query with the same model and normalisation that
    #    flush_sentence used to index col_sent (cf. semantic_query).
    q_emb = ENC_MODEL.encode(q, normalize_embeddings=True)
    # 2) Top-10 sentences, restricted to this paper via metadata filter
    res = col_sent.query(
        query_embeddings=[q_emb],
        n_results=10,
        include=["documents", "metadatas"],
        where={"paper_id": paper_id},
    )
    # Chroma returns one result list per query embedding; we sent one.
    state["retrieved_docs"]      = res["documents"][0]
    state["retrieved_metadatas"] = res["metadatas"][0]
    return state

# rerank / generate nodes
def rerank(state: OverallState) -> OverallState:
    """Keep the 6 retrieved sentences the cross-encoder scores highest for the question.

    Ties keep retrieval order. Writes ``top_docs`` / ``top_metadatas`` and
    returns *state*.
    """
    candidates      = state["retrieved_docs"]
    candidate_metas = state["retrieved_metadatas"]
    question        = state["question"]

    scores = ce_reranker.predict([(question, doc) for doc in candidates])
    best_first = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:6]

    state["top_docs"]      = [candidates[i] for i in best_first]
    state["top_metadatas"] = [candidate_metas[i] for i in best_first]
    return state

def generate(state: OverallState) -> OverallState:
    """Answer the question with the LLM, using the reranked sentences as context."""
    context = "\n\n".join(state["top_docs"])
    prompt = "".join([
        "Prompt: Answer based on context below. If you don't know, say 'unanswerable'.\n",
        f"Context:\n{context}\n\nQuestion: {state['question']}\nAnswer:",
    ])
    state["answer"] = llm.invoke(prompt).content
    return state

# Build the pipeline: START -> retrieve -> rerank -> generate -> END
builder = StateGraph(state_schema=OverallState, input=InputState, output=OutputState)
for node_name, node_fn in (("retrieve", retrieve), ("rerank", rerank), ("generate", generate)):
    builder.add_node(node_name, node_fn)
builder.add_edge(START, "retrieve")
builder.add_edge("retrieve", "rerank")
builder.add_edge("rerank", "generate")
builder.add_edge("generate", END)

rag_pipeline = builder.compile()


In [None]:
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm


# 1) Data loading (evaluation subset)
# NOTE(review): this reloads the raw (unfiltered) train split into `dataset`,
# but nothing in this cell reads it. The loop below evaluates `subset` (the
# filtered trainset selection from the embedding cell).
dataset = load_dataset('allenai/qasper', split='train', cache_dir='./data/Qasper/qasper_cache')
# subset  = dataset.shuffle(seed=42).select(range(int(len(dataset)*0.1)))  # 10%

# 2) helper functions
def normalize(text: str) -> str:
    """Keep only alphanumerics and whitespace, lower-cased, with outer whitespace stripped."""
    return ''.join(c.lower() for c in text if c.isalnum() or c.isspace()).strip()

def find_gold_para_ids(sample, gold_evids):
    """Map gold evidence strings to paragraph ids "<paper_id>_sec<i>/para<j>".

    Each evidence string is looked up (substring test) in the paper's
    paragraphs in reading order, and only the first containing paragraph is
    used. Empty evidence is skipped, since it would trivially match the first
    paragraph. Evidence found in no paragraph (presumably table/figure
    evidence) contributes nothing.
    """
    paper_id = sample["id"]
    sections = sample["full_text"]["paragraphs"]  # per section: [para_1, para_2, ...]
    gold_ids = set()
    for ev in gold_evids:
        needle = ev.strip()
        if not needle:
            continue
        hit = next(
            (f"{paper_id}_sec{sec_idx}/para{para_idx}"
             for sec_idx, paras in enumerate(sections)
             for para_idx, para_text in enumerate(paras)
             if needle in para_text),
            None,
        )
        if hit is not None:
            gold_ids.add(hit)
    return gold_ids

def compute_retrieval_metrics(pred: list[str], gold: set[str]):
    """Set-based precision, recall and F1 of predicted vs. gold paragraph ids (0.0 when undefined)."""
    pred_set = set(pred)
    tp = len(pred_set & gold)
    prec = tp / len(pred_set) if pred_set else 0.0
    rec  = tp / len(gold)     if gold     else 0.0
    f1   = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return prec, rec, f1

def compute_answer_metrics(pred_ans: str, gold_texts: list[str], gold_yesno: list[bool], gold_unans: list[bool]):
    """Score one predicted answer against all annotators.

    Returns ``(em, yn_acc, un_acc)``:
      - em:     1 if the normalized prediction equals any normalized gold text.
      - yn_acc: only for yes/no questions (``gold_yesno`` non-empty). 1 if the
                prediction's first word is "yes"/"no" and agrees with at least
                one annotator's label, else 0. None for other questions.
      - un_acc: only when some annotator marked the question unanswerable.
                1 if the prediction starts with "unanswerable". Otherwise None.
    """
    norm_pred = normalize(pred_ans)
    # Exact match over all annotators
    em = any(norm_pred == normalize(gt) for gt in gold_texts)
    # yes/no: applicability is the presence of labels, not any(labels).
    # Otherwise gold "no" (False) answers would never be scored.
    if gold_yesno:
        words = norm_pred.split()
        first = words[0] if words else ""
        expected = {"yes" if g else "no" for g in gold_yesno}
        yn_acc = int(first in expected)
    else:
        yn_acc = None
    # unanswerable
    if any(gold_unans):
        un_acc = int(norm_pred.startswith('unanswerable'))
    else:
        un_acc = None
    return int(em), yn_acc, un_acc

# 3) evaluation loop
METRIC_COLS = ["precision", "recall", "f1", "exact_match", "yes_no_acc", "unans_acc"]


def _pred_para_ids(metadatas):
    """Turn retrieved sentence metadata into "<paper_id>_sec<i>/para<j>" ids (find_gold_para_ids format)."""
    return [f"{md['paper_id']}_sec{md['sec_idx']}/para{md['para_idx']}" for md in metadatas]


def _collect_gold(ans_block):
    """Aggregate gold texts, evidence, yes/no and unanswerable labels over all annotators.

    Qasper answer dicts use the keys free_form_answer, extractive_spans,
    yes_no, unanswerable and evidence. There is no "answer" key, so gold
    texts are built from those fields, following Qasper's answer types.
    """
    gold_texts, gold_evids, gold_yesno, gold_unans = [], [], [], []
    for ann in ans_block["answer"]:
        if ann.get("free_form_answer"):
            gold_texts.append(ann["free_form_answer"])
        spans = [s for s in (ann.get("extractive_spans") or []) if s]
        if spans:
            gold_texts.append(", ".join(spans))
        if ann.get("evidence"):
            gold_evids.extend(ann["evidence"])
        if ann.get("yes_no") is not None:
            gold_yesno.append(ann["yes_no"])
            gold_texts.append("Yes" if ann["yes_no"] else "No")
        if ann.get("unanswerable") is not None:
            gold_unans.append(ann["unanswerable"])
            if ann["unanswerable"]:
                gold_texts.append("Unanswerable")
    return gold_texts, gold_evids, gold_yesno, gold_unans


records = []
for sample in tqdm(subset, desc="Evaluating"):
    pid       = sample["id"]
    questions = sample["qas"]["question"]
    answers   = sample["qas"]["answers"]

    for q_text, ans_block in zip(questions, answers):
        # 3.1) Run the RAG pipeline
        out      = rag_pipeline.invoke({"paper_id": pid, "question": q_text})
        pred_ans = out["answer"].strip()
        # Predicted evidence paragraphs. Computed in a helper so `pid`
        # (the paper under evaluation) is never overwritten.
        pred_ids = _pred_para_ids(out["top_metadatas"])

        # 3.2) Gold aggregation over all annotators
        gold_texts, gold_evids, gold_yesno, gold_unans = _collect_gold(ans_block)
        gold_ids = find_gold_para_ids(sample, gold_evids)

        # 3.3) Metrics
        prec, rec, f1 = compute_retrieval_metrics(pred_ids, gold_ids)
        em, yn_acc, un_acc = compute_answer_metrics(pred_ans, gold_texts, gold_yesno, gold_unans)

        records.append({
            "paper_id":    pid,
            "question":    q_text,
            "precision":   prec,
            "recall":      rec,
            "f1":          f1,
            "exact_match": em,
            "yes_no_acc":  yn_acc,
            "unans_acc":   un_acc,
        })

# 4) Build the DataFrame and aggregate per paper
df = pd.DataFrame(records, columns=["paper_id", "question", *METRIC_COLS])
# yes_no_acc / unans_acc are None where a metric does not apply. An all-None
# column has object dtype, and groupby().mean() can raise TypeError on it.
# Casting to float maps None -> NaN, which mean() skips.
df[METRIC_COLS] = df[METRIC_COLS].astype(float)

# Question-level overall averages
overall = df.mean(numeric_only=True).round(4)
print("=== Overall ===")
print(overall.to_dict())

# Per-paper averages
by_paper = (
    df
    .groupby("paper_id")
    .agg({col: "mean" for col in METRIC_COLS})
    .round(4)
    .reset_index()
)
print("\n=== By Paper ===")
print(by_paper)

# 5) Save results
df.to_csv("qasper_rag_eval_per_question.csv", index=False)
by_paper.to_csv("qasper_rag_eval_by_paper.csv", index=False)


In [None]:
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm


# 1) Data loading (evaluation subset)
# NOTE(review): this reloads the raw (unfiltered) train split into `dataset`,
# but nothing in this cell reads it. The loop below evaluates `subset` (the
# filtered trainset selection from the embedding cell).
dataset = load_dataset('allenai/qasper', split='train', cache_dir='./data/Qasper/qasper_cache')
# subset  = dataset.shuffle(seed=42).select(range(int(len(dataset)*0.1)))  # 10%

# 2) helper functions
def normalize(text: str) -> str:
    """Keep only alphanumerics and whitespace, lower-cased, with outer whitespace stripped."""
    return ''.join(c.lower() for c in text if c.isalnum() or c.isspace()).strip()

def find_gold_para_ids(sample, gold_evids):
    """Map gold evidence strings to paragraph ids "<paper_id>_sec<i>/para<j>".

    Each evidence string is looked up (substring test) in the paper's
    paragraphs in reading order, and only the first containing paragraph is
    used. Empty evidence is skipped, since it would trivially match the first
    paragraph. Evidence found in no paragraph (presumably table/figure
    evidence) contributes nothing.
    """
    paper_id = sample["id"]
    sections = sample["full_text"]["paragraphs"]  # per section: [para_1, para_2, ...]
    gold_ids = set()
    for ev in gold_evids:
        needle = ev.strip()
        if not needle:
            continue
        hit = next(
            (f"{paper_id}_sec{sec_idx}/para{para_idx}"
             for sec_idx, paras in enumerate(sections)
             for para_idx, para_text in enumerate(paras)
             if needle in para_text),
            None,
        )
        if hit is not None:
            gold_ids.add(hit)
    return gold_ids

def compute_retrieval_metrics(pred: list[str], gold: set[str]):
    """Set-based precision, recall and F1 of predicted vs. gold paragraph ids (0.0 when undefined)."""
    pred_set = set(pred)
    tp = len(pred_set & gold)
    prec = tp / len(pred_set) if pred_set else 0.0
    rec  = tp / len(gold)     if gold     else 0.0
    f1   = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return prec, rec, f1

def compute_answer_metrics(pred_ans: str, gold_texts: list[str], gold_yesno: list[bool], gold_unans: list[bool]):
    """Score one predicted answer against all annotators.

    Returns ``(em, yn_acc, un_acc)``:
      - em:     1 if the normalized prediction equals any normalized gold text.
      - yn_acc: only for yes/no questions (``gold_yesno`` non-empty). 1 if the
                prediction's first word is "yes"/"no" and agrees with at least
                one annotator's label, else 0. None for other questions.
      - un_acc: only when some annotator marked the question unanswerable.
                1 if the prediction starts with "unanswerable". Otherwise None.
    """
    norm_pred = normalize(pred_ans)
    # Exact match over all annotators
    em = any(norm_pred == normalize(gt) for gt in gold_texts)
    # yes/no: applicability is the presence of labels, not any(labels).
    # Otherwise gold "no" (False) answers would never be scored.
    if gold_yesno:
        words = norm_pred.split()
        first = words[0] if words else ""
        expected = {"yes" if g else "no" for g in gold_yesno}
        yn_acc = int(first in expected)
    else:
        yn_acc = None
    # unanswerable
    if any(gold_unans):
        un_acc = int(norm_pred.startswith('unanswerable'))
    else:
        un_acc = None
    return int(em), yn_acc, un_acc

# 3) evaluation loop
METRIC_COLS = ["precision", "recall", "f1", "exact_match", "yes_no_acc", "unans_acc"]


def _pred_para_ids(metadatas):
    """Turn retrieved sentence metadata into "<paper_id>_sec<i>/para<j>" ids (find_gold_para_ids format)."""
    return [f"{md['paper_id']}_sec{md['sec_idx']}/para{md['para_idx']}" for md in metadatas]


def _collect_gold(ans_block):
    """Aggregate gold texts, evidence, yes/no and unanswerable labels over all annotators.

    Qasper answer dicts use the keys free_form_answer, extractive_spans,
    yes_no, unanswerable and evidence. There is no "answer" key, so gold
    texts are built from those fields, following Qasper's answer types.
    """
    gold_texts, gold_evids, gold_yesno, gold_unans = [], [], [], []
    for ann in ans_block["answer"]:
        if ann.get("free_form_answer"):
            gold_texts.append(ann["free_form_answer"])
        spans = [s for s in (ann.get("extractive_spans") or []) if s]
        if spans:
            gold_texts.append(", ".join(spans))
        if ann.get("evidence"):
            gold_evids.extend(ann["evidence"])
        if ann.get("yes_no") is not None:
            gold_yesno.append(ann["yes_no"])
            gold_texts.append("Yes" if ann["yes_no"] else "No")
        if ann.get("unanswerable") is not None:
            gold_unans.append(ann["unanswerable"])
            if ann["unanswerable"]:
                gold_texts.append("Unanswerable")
    return gold_texts, gold_evids, gold_yesno, gold_unans


records = []
for sample in tqdm(subset, desc="Evaluating"):
    pid       = sample["id"]
    questions = sample["qas"]["question"]
    answers   = sample["qas"]["answers"]

    for q_text, ans_block in zip(questions, answers):
        # 3.1) Run the RAG pipeline
        out      = rag_pipeline.invoke({"paper_id": pid, "question": q_text})
        pred_ans = out["answer"].strip()
        # Predicted evidence paragraphs. Computed in a helper so `pid`
        # (the paper under evaluation) is never overwritten.
        pred_ids = _pred_para_ids(out["top_metadatas"])

        # 3.2) Gold aggregation over all annotators
        gold_texts, gold_evids, gold_yesno, gold_unans = _collect_gold(ans_block)
        gold_ids = find_gold_para_ids(sample, gold_evids)

        # 3.3) Metrics
        prec, rec, f1 = compute_retrieval_metrics(pred_ids, gold_ids)
        em, yn_acc, un_acc = compute_answer_metrics(pred_ans, gold_texts, gold_yesno, gold_unans)

        records.append({
            "paper_id":    pid,
            "question":    q_text,
            "precision":   prec,
            "recall":      rec,
            "f1":          f1,
            "exact_match": em,
            "yes_no_acc":  yn_acc,
            "unans_acc":   un_acc,
        })

# 4) Build the DataFrame and aggregate per paper
df = pd.DataFrame(records, columns=["paper_id", "question", *METRIC_COLS])
# yes_no_acc / unans_acc are None where a metric does not apply. An all-None
# column has object dtype, and groupby().mean() can raise TypeError on it.
# Casting to float maps None -> NaN, which mean() skips.
df[METRIC_COLS] = df[METRIC_COLS].astype(float)

# Question-level overall averages
overall = df.mean(numeric_only=True).round(4)
print("=== Overall ===")
print(overall.to_dict())

# Per-paper averages
by_paper = (
    df
    .groupby("paper_id")
    .agg({col: "mean" for col in METRIC_COLS})
    .round(4)
    .reset_index()
)
print("\n=== By Paper ===")
print(by_paper)

# 5) Save results
df.to_csv("qasper_rag_eval_per_question.csv", index=False)
by_paper.to_csv("qasper_rag_eval_by_paper.csv", index=False)
