## 1. ENV

In [None]:
!pip install dotenv huggingface_hub datasets sentence_transformers langchain_google_genai chromadb neo4j langchain_community py2neo spacy --upgrade torch_geometric

In [None]:
import torch
TORCH = print(torch.__version__)          
CUDA = print(torch.version.cuda)

2.7.0+cpu
None


In [6]:
import os
from dotenv import load_dotenv
from huggingface_hub import login

SEED = 42
load_dotenv()
login(token=os.getenv('HF_TOKEN'))
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY not found in environment variables.")

import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## 2. Dataset loading

data loading

In [2]:
# load the Qasper dataset
from datasets import load_dataset, Dataset
dataset = load_dataset("allenai/qasper", cache_dir='./data/Qasper/qasper_cache')

# shuffle
trainset = dataset['train'] #.shuffle(seed=SEED)
validset = dataset['validation'] #.shuffle(seed=SEED)
testset  = dataset['test'] #.shuffle(seed=SEED)
# subset = trainset.select(range(5,10))  # select a subset of 1000 samples

data preprocessing: ensure data integrity

In [3]:
import re
from copy import deepcopy
from typing import Dict, List

FLOAT_TAG = "FLOAT SELECTED: "          # **대소문자 구분**
REF_TAG = re.compile(r" BIBREF\d+")
# FIG_TAG = re.compile(r" FIGREF\d+")

def _clear_tag(ev_list: List[str]) -> List[str]:
    """evidence 리스트에서 'FLOAT SELECTED: ' 부분 문자열을 삭제하고
       내용이 남은 evidence만 반환
    """
    cleaned = []
    for ev in ev_list:
        new_ev = ev.replace(FLOAT_TAG, "")
        new_ev = REF_TAG.sub("", new_ev)
        new_ev = new_ev.strip()
        if new_ev:                       # 치환 후 내용이 남아 있을 때만 보존
            cleaned.append(new_ev)
    return cleaned

def filter_sample(sample: Dict) -> Dict:
    """Qasper 샘플에서 'FLOAT SELECTED: ' 태그를 제거하고
       내용이 남지 않는 answer-detail / question 을 드롭
    """
    new_qas = {k: [] for k in sample["qas"]}

    for idx in range(len(sample["qas"]["question"])):
        blk = deepcopy(sample["qas"]["answers"][idx])

        kept_det, kept_ann, kept_wid = [], [], []
        for det, ann, wid in zip(blk["answer"],
                                 blk.get("annotation_id", []),
                                 blk.get("worker_id", [])):
            cleaned_ev = _clear_tag(det.get("evidence", []))
            if cleaned_ev:                               # evidence가 하나라도 남을 때만 keep
                det = deepcopy(det)
                det["evidence"] = cleaned_ev
                kept_det.append(det)
                kept_ann.append(ann)
                kept_wid.append(wid)

        if kept_det:                                    # 질문 유지 여부
            blk.update(
                answer        = kept_det,
                annotation_id = kept_ann,
                worker_id     = kept_wid,
            )
            new_qas["answers"].append(blk)
            for k in sample["qas"]:
                if k != "answers":
                    new_qas[k].append(sample["qas"][k][idx])

    # Update the sample full_text paragraph with cleaned QAs
    for j, paragraph in enumerate(sample['full_text']['paragraphs']):
        for sub_paragraph in paragraph:
            # Clean the sub-paragraph text
            cleaned_text = sub_paragraph.replace(FLOAT_TAG, "")
            cleaned_text = REF_TAG.sub("", cleaned_text).strip()
            # cleaned_text = FIG_TAG.sub("", cleaned_text).strip()

            sub_paragraph = cleaned_text

    sample["qas"] = new_qas
    return sample


In [4]:
# Apply the filtering function to every sample in the trainset.
trainset = trainset.map(filter_sample)
validset = validset.map(filter_sample)
testset = testset.map(filter_sample)

validate data pre processing

In [5]:
def validate_evidence_clean(sample: Dict) -> bool:
    """
    ① evidence가 비어 있지 않은지
    ② evidence 안에 placeholder 문자열이 남아 있지 않은지
    모두 만족하면 True
    """
    ok = True
    for q_idx, blk in enumerate(sample["qas"]["answers"]):
        for a_idx, det in enumerate(blk["answer"]):
            ev = det.get("evidence", [])
            if not ev:
                print(f"[오류] Q{q_idx}-A{a_idx}: evidence가 비어 있습니다.")
                ok = False
                continue
    return ok

In [6]:
def validate_no_float_selected(sample: Dict) -> bool:
    """모든 evidence에 'FLOAT SELECTED: ' 문자열이 남아 있지 않은지 확인"""
    ok = True
    for q_idx, blk in enumerate(sample["qas"]["answers"]):
        for a_idx, det in enumerate(blk["answer"]):
            for ev in det.get("evidence", []):
                if FLOAT_TAG in ev:
                    print(f"[오류] Q{q_idx}-A{a_idx}: 미제거 태그 발견 → “{ev}”")
                    ok = False
    return ok


In [7]:
is_not_clean = True
def validate_all_samples(dataset):
    """모든 샘플에 대해 evidence가 비어 있거나 미제거 태그가 있는지 확인"""
    for i, sample in enumerate(dataset):
        if not validate_evidence_clean(sample):
            print(f"[오류] 샘플 #{i}에 빈 evidence가 존재합니다.")
            return False
        if not validate_no_float_selected(sample):
            print(f"[오류] 샘플 #{i}에 미제거 태그가 존재합니다.")
            return False
    return True

if not validate_all_samples(trainset):
    raise ValueError("Trainset validation failed.")
if not validate_all_samples(validset):
    raise ValueError("Validset validation failed.")
if not validate_all_samples(testset):
    raise ValueError("Testset validation failed.")

***
## New Approach

In [25]:
import os, re, hashlib, itertools, shutil, json, time
from pathlib import Path
from typing import List, Tuple, Dict, Iterable

HF_CACHE_FOLDER     = "./data/.cache/scibert-nli"
# CHROMA_DIR          = "./demo_chroma_db"
CHROMA_DIR          = "./chroma_qasper_10"
COLLECTION_SENT_NAME     = "paper_sentences"
COLLECTION_RELA_NAME     = "paper_relations"


# Clean previous run
# if Path(CHROMA_DIR).exists():
#     shutil.rmtree(CHROMA_DIR)

In [9]:
from sentence_transformers import SentenceTransformer
print("▸ Loading SentenceTransformer …")
# os.environ['HF_HOME'] = './data/.cache'  # Set cache directory for HuggingFace models
ENC_MODEL = SentenceTransformer('gsarti/scibert-nli', cache_folder=HF_CACHE_FOLDER, device=DEVICE)
if not ENC_MODEL:
    raise RuntimeError("Failed to load the SentenceTransformer model.")

No sentence-transformers model found with name gsarti/scibert-nli. Creating a new one with mean pooling.


▸ Loading SentenceTransformer …


In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]
LLM = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=GOOGLE_API_KEY,
    temperature=0.0
)
if not LLM:
    raise RuntimeError("Failed to load the Gemini 1.5 Flash model.")
else: print("▸ Gemini 1.5 Flash loaded (LangChain wrapper).")

LABELS = ["Claim", "Evidence", "Background", "Method", "Result",
          "Interpretation", "Contrast", "Cause-Effect",
          "Temporal", "Condition", "Other"]
LABEL2IDX = {lab: i for i, lab in enumerate(LABELS)}

PROMPT = """\
[MUST FOLLOW Task]:
1. You are an academic discourse analyst.
2. Classify how Sentence_B is related to Sentence_A using EXACTLY ONE "Allowed labels".
3. If none apply, output "Other".

[Allowed labels]:
Claim, Evidence, Background, Method, Result, Interpretation,
Contrast, Cause-Effect, Temporal, Condition, Other

[Sentence pairs]:
{PAIRS}

Respond only with a list of labels, in order:
1: <label>
2: <label>
...
{N}: <label>
"""

def one_hot_encode_labels(labels: list[str]) -> list[list[float]]:
    index = {l: i for i, l in enumerate(LABELS)}
    vecs = []
    for l in labels:
        vec = [0.0] * len(LABELS)
        vec[index.get(l, -1)] = 1.0 if l in index else 0.0
        vecs.append(vec)
    return vecs

def format_sentence_pairs(pairs: list[tuple[str, str]]) -> str:
    return "\n".join([
        f"{i+1}.\nSentence_A: {a}\nSentence_B: {b}"
        for i, (a, b) in enumerate(pairs)
    ])

def parse_llm_labels(response: str, expected: int) -> List[str]:
    lines = response.strip().splitlines()
    labels = []
    
    for line in lines:
        if ":" not in line:
            continue                   # 형식 일치하지 않으면 스킵
        _, raw = line.split(":", 1)
        lbl = raw.strip()
        if not lbl or lbl not in LABELS:
            lbl = "Other"  # 유효하지 않은 라벨은 "Other"로 대체
        labels.append(lbl)

    # 부족한 개수만큼 패딩
    if len(labels) < expected:
        labels.extend(["Other"] * (expected - len(labels)))

    # 초과하면 앞 expected개만 사용
    return labels[:expected]

import re
SPLIT_PUNCS = re.compile(r"[.!?;:]+")

def split_sentences(paragraph: str) -> list[str]:
    """문장 구분 기호로 문단을 분할"""
    sentences = SPLIT_PUNCS.split(paragraph.strip())
    return [s.strip() for s in sentences if s.strip()]  # 빈 문자열 제거


▸ Gemini 1.5 Flash loaded (LangChain wrapper).


 Initialising ChromaDB

In [57]:
import chromadb

print("▸ Initialising ChromaDB …")
CHROMA_DIR = './demo_chroma_db'
print(CHROMA_DIR)
client = chromadb.PersistentClient(path=CHROMA_DIR)
col_sent  = client.get_or_create_collection(name=COLLECTION_SENT_NAME)
col_rel   = client.get_or_create_collection(name=COLLECTION_RELA_NAME)
col_sent.count(),   col_rel.count()


▸ Initialising ChromaDB …
./demo_chroma_db


(182, 126)

In [12]:
import numpy as np, hashlib
from tqdm import tqdm
from datasets import DatasetDict
from typing import Dict, Iterable, Tuple, Optional

def generate_sent_nodes(dataset: DatasetDict) -> Iterable[Dict]:
    for sample in tqdm(dataset, total=len(dataset), desc="Generating nodes"):
        paper_id = sample["id"]
    
        for sec_idx, (sec_name, sec_content) in enumerate(
                zip(sample["full_text"]["section_name"],
                    sample["full_text"]["paragraphs"])):
    
            for para_idx, para in enumerate(sec_content):
                sentences = split_sentences(para.strip())
                prev: Optional[Tuple[str, str]] = None

                for sent_idx, sent in enumerate(filter(None, sentences)):
                    para_id = f"{paper_id}/sec{sec_idx}/para{para_idx}"
                    path = f"{para_id}/sent{sent_idx}"
                    sid  = hashlib.sha1(path.encode()).hexdigest()

                    yield {
                        "sid":    sid,
                        "sent":   sent.strip(),
                        "prev":   prev,
                        "meta":   dict(paper_id=paper_id, sec_idx=sec_idx,
                                        para_idx=para_idx, sent_idx=sent_idx)
                    }
                    prev = dict(sid=sid, sent=sent.strip())  # skips first sentence in paragraph

In [13]:
import logging, sys
from typing import List, Tuple

def flush_sentence(ids, sents, metas):
    if not ids:
        return

    embs = ENC_MODEL.encode(
        sents,
        convert_to_numpy=True,
        normalize_embeddings=True,
        batch_size=len(sents),
        show_progress_bar=False,
        device=DEVICE
    )

    col_sent.upsert(
        ids=ids,
        embeddings=embs,
        documents=sents,
        metadatas=metas
    )

import logging
import time
import sys

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
async def flush_pairs_async(
    pairs: List[Tuple[str, str, str, str, str]],  # (paper_id, sid_src, sid_dst, sent_prev, sent_curr)
    batch_size: int = 50,
) -> None:
    if not pairs:
        logging.debug("flush_pairs_async: Received empty pair buffer. Skipping.")
        return

    ab_pairs = [(sent_prev, sent_curr) for (_, _, _, sent_prev, sent_curr) in pairs]
    logging.info(f"↪ [flush_pairs_async] Processing {len(pairs)}")
    
    prompt_txt = format_sentence_pairs(ab_pairs)
    prompt     = PROMPT.format(PAIRS=prompt_txt, N=len(pairs))

    # 비동기 LLM 호출
    t0 = time.time()
    try:
        resp = await LLM.ainvoke(prompt)
    except Exception as e:
        logging.error(f"[LLM Error] flush_pairs_async failed: {e}")
        return
    t1 = time.time()
    logging.info(f"LLM response received in {t1 - t0:.2f}s")

    labels = parse_llm_labels(resp.content, expected=len(pairs))
    onehot = one_hot_encode_labels(labels)

    # Chroma upsert
    t2 = time.time()
    success, failed = 0, 0

    for (paper_id, sid_src, sid_dst, _, _), lab, vec in zip(pairs, labels, onehot):
        rel_id = f"{sid_src}|{sid_dst}"

        try:
            col_rel.upsert(
                ids        = [rel_id],
                documents  = [f"{sid_src} <REL> {sid_dst}"],
                embeddings = [vec],
                metadatas  = [{
                    "relation_label": lab,
                    "paper_id": paper_id,
                    "sid_src": sid_src,
                    "sid_dst": sid_dst
                }],
            )
            logging.debug(f"[Upsert Success] ID: {rel_id} → {lab}, paper_id: {paper_id}")
            success += 1
        except Exception as e:
            logging.warning(f"[Upsert Failed] ID: {rel_id} — {e}")
            failed += 1

    t3 = time.time()
    logging.info(f"✓ Upsert complete: {success} succeeded, {failed} failed (⏱ {t3 - t2:.2f}s)")


In [14]:
import asyncio

async def process_dataset(nodes, batch_sent=32, batch_pairs=50):
    ids, sents, metas = [], [], []
    pair_buf          = []
    tasks             = []

    for node in nodes:
        ids.append(node["sid"])
        sents.append(node["sent"])
        metas.append(node["meta"])

        if node["prev"]:
            # print(f"Processing pair: {node['prev']['sid']} -> {node['sid']}")
            pair_buf.append((node['meta']['paper_id'],
                             node["prev"]["sid"],  node["sid"],
                             node["prev"]["sent"], node["sent"]))

        if len(ids) >= batch_sent:
            flush_sentence(ids, sents, metas)
            ids, sents, metas = [], [], []

        if len(pair_buf) >= batch_pairs:
            # 병렬 실행 예약
            tasks.append(asyncio.create_task(flush_pairs_async(pair_buf)))
            pair_buf = []

    if ids:
        flush_sentence(ids, sents, metas)
    if pair_buf:
        tasks.append(asyncio.create_task(flush_pairs_async(pair_buf)))

    # 모든 병렬 작업 종료 대기
    if tasks:
        await asyncio.gather(*tasks)


In [16]:

print("▸ Embedding sentences + relations …")
subset = trainset.select(range(1))  # for testing
await process_dataset(generate_sent_nodes(subset))
print("▸ Done embedding sentences + relations.")
print(f"▸ Stored sentences: {col_sent.count():,}")
print(f"▸ Stored vectors   : {col_rel.count():,}")

▸ Embedding sentences + relations …


Generating nodes: 100%|██████████| 1/1 [00:01<00:00,  1.78s/it]


2025-06-16 16:27:52,587 [INFO] ↪ [flush_pairs_async] Processing 50
2025-06-16 16:27:52,593 [INFO] ↪ [flush_pairs_async] Processing 50
2025-06-16 16:27:52,594 [INFO] ↪ [flush_pairs_async] Processing 26
2025-06-16 16:27:54,639 [INFO] LLM response received in 2.04s
2025-06-16 16:27:55,087 [INFO] ✓ Upsert complete: 26 succeeded, 0 failed (⏱ 0.45s)
2025-06-16 16:27:55,100 [INFO] LLM response received in 2.51s
2025-06-16 16:27:55,981 [INFO] ✓ Upsert complete: 50 succeeded, 0 failed (⏱ 0.88s)
2025-06-16 16:27:55,983 [INFO] LLM response received in 3.40s
2025-06-16 16:27:56,906 [INFO] ✓ Upsert complete: 50 succeeded, 0 failed (⏱ 0.92s)
▸ Done embedding sentences + relations.
▸ Stored sentences: 182
▸ Stored vectors   : 126


In [None]:
def semantic_query(query: str, paper_id: str, k: int = 5):
    """
    paper_id 에 해당하는 문서 내부에서만 최근접 문장 k개 검색
    """
    q_vec = ENC_MODEL.encode(query, normalize_embeddings=True)
    res   = col_sent.query(
        query_embeddings=[q_vec],
        n_results=k,
        include=["documents", "metadatas", "distances"],
        # --- 핵심 ---
        where={"paper_id": paper_id}
    )
    
    print(f"\n[Doc = {paper_id}]  {query!r}")
    for doc, dist, sid, meta in zip(res["documents"][0], res["distances"][0], res["ids"][0], res["metadatas"][0]):
        print(f" • doc = {doc}")# dist={dist:.4f}  sid={sid[:8]} text≈{meta.get('sent_idx')}  rel={meta.get('relation_label')}")


sample = trainset[4]
print(len(sample['qas']['question']))
query = sample['qas']['question'][0]
id = sample['id']
topk= 5
semantic_query(query, id, k=topk)

3

[Doc = 1811.00942]  'What aspects have been compared between various language models?'
 • doc = In truth, language models exist in a quality–performance tradeoff space
 • doc = In the present work, we describe and examine the tradeoff space between quality and performance for the task of language modeling
 • doc = In this paper, we examine the quality–performance tradeoff in the shift from non-neural to neural language models
 • doc = Quality–performance tradeoff
 • doc = Specifically focused on language modeling, this paper examines an issue that to our knowledge has not been explored


## Training

In [55]:
# load the collections 
from chromadb import PersistentClient
print(CHROMA_DIR)
client = PersistentClient(path='./demo_chroma_db')
col_sent =client.get_or_create_collection(name=COLLECTION_SENT_NAME)
col_rel = client.get_or_create_collection(name=COLLECTION_RELA_NAME)
col_rel.count(), col_sent.count()

./chroma_qasper_10


(126, 182)

In [None]:
# col_rel 에서 확인용 아무 relation 
r = col_rel.get(
    where={"relation_label": "Evidence"},
    include=["documents", "metadatas", "embeddings"],
    limit=1
)
print(r)
# col_sent 에서 확인
result = col_sent.get(ids="d4a889228628703d6f4d577c9c9ae6c44c2c5597")
print(result)

{'ids': ['d4a889228628703d6f4d577c9c9ae6c44c2c5597|11c8af6284e5a0fc753e2900950071f30c31eb64'], 'embeddings': array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]), 'documents': ['d4a889228628703d6f4d577c9c9ae6c44c2c5597 <REL> 11c8af6284e5a0fc753e2900950071f30c31eb64'], 'uris': None, 'included': ['documents', 'metadatas', 'embeddings'], 'data': None, 'metadatas': [{'paper_id': '1909.00694', 'sid_src': 'd4a889228628703d6f4d577c9c9ae6c44c2c5597', 'relation_label': 'Evidence', 'sid_dst': '11c8af6284e5a0fc753e2900950071f30c31eb64'}]}
{'ids': ['d4a889228628703d6f4d577c9c9ae6c44c2c5597'], 'embeddings': None, 'documents': ['$\\lambda _{\\rm CA}$ was about one-third of $\\lambda _{\\rm CO}$, and this indicated that the CA pairs were noisier than the CO pairs'], 'uris': None, 'included': ['metadatas', 'documents'], 'data': None, 'metadatas': [{'sent_idx': 1, 'sec_idx': 14, 'para_idx': 6, 'paper_id': '1909.00694'}]}


Training

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import negative_sampling
import chromadb
from chromadb.config import Settings
from collections import defaultdict

# 1) ChromaDB 에서 문장·관계 로드

# 1.1) 문장 노드 피처 및 메타
sent_data   = col_sent.get(include=["embeddings","metadatas"])
node_ids    = sent_data["ids"]                                 # List[str]
feats       = torch.tensor(sent_data["embeddings"], dtype=torch.float)  # [N,768]
sent_metas  = sent_data["metadatas"]                           # 각 dict에 "paper_id","sid","para_id","sent_idx" 포함

# 1.2) 관계 엣지 데이터
rel_data   = col_rel.get(include=["metadatas"])
rel_metas  = rel_data["metadatas"]
label_set  = ["Claim","Evidence","Background","Method","Result",
              "Interpretation","Contrast","Cause-Effect",
              "Temporal","Condition","Other"]
lab2idx    = {l:i for i,l in enumerate(label_set)}
NUM_REL    = len(label_set) + 1   # +1 for paragraph edge type
PARA_TYPE  = len(label_set)       # index for paragraph-level edges

# sid → global index
id2idx = {sid:i for i,sid in enumerate(node_ids)}

# 2) 문서별 서브그래프 + paragraph 엣지 구성
docs = defaultdict(lambda: {"nodes": set(), "rel_edges": [], "rel_types": []})

# 2.1) 문장 기여 관계(rel) 추가
for m in rel_metas:
    pid = m["paper_id"]
    u   = id2idx[m["sid_src"]]
    v   = id2idx[m["sid_dst"]]
    t   = lab2idx[m["relation_label"]]
    docs[pid]["nodes"].update({u,v})
    docs[pid]["rel_edges"].append((u,v))
    docs[pid]["rel_types"].append(t)

# 2.2) paragraph-level 이웃 추가 (sent_idx 차이 1)
#    para_id별로 노드 인덱스와 sent_idx를 그룹화
para2nodes = defaultdict(list)
for idx, meta in enumerate(sent_metas):
    para2nodes[(meta["paper_id"], meta["para_idx"])].append((meta["sent_idx"], idx))

for (pid, _), seq in para2nodes.items():
    # 같은 문서·문단 내에서 sent_idx 순으로 정렬
    seq_sorted = sorted(seq, key=lambda x: x[0])
    # 인접 페어만 이웃으로 연결
    for (_, u), (_, v) in zip(seq_sorted, seq_sorted[1:]):
        docs[pid]["nodes"].update({u,v})
        docs[pid]["rel_edges"].append((u,v))
        docs[pid]["rel_types"].append(PARA_TYPE)
        # undirected 처리하려면 역방향도 추가
        docs[pid]["rel_edges"].append((v,u))
        docs[pid]["rel_types"].append(PARA_TYPE)

# 2.3) Data 리스트 생성
data_list = []
for pid, info in docs.items():
    if not info["rel_edges"]:
        continue
    uniq_nodes = sorted(info["nodes"])
    g2l        = {g:i for i,g in enumerate(uniq_nodes)}
    x_sub      = feats[uniq_nodes]  # [n_sub,768]

    # 엣지 및 타입
    e_list = info["rel_edges"]
    t_list = info["rel_types"]
    edge_index = torch.tensor([
        [g2l[u] for u,_ in e_list],
        [g2l[v] for _,v in e_list]
    ], dtype=torch.long)
    edge_type  = torch.tensor(t_list, dtype=torch.long)

    data_list.append(Data(x=x_sub, edge_index=edge_index, edge_type=edge_type))

# 3) Relation-Weighted GraphSAGE 정의
class RelGraphSAGE(MessagePassing):
    def __init__(self, in_dim, hid_dim, num_rel):
        super().__init__(aggr="mean")
        self.lin_self  = torch.nn.Linear(in_dim, hid_dim)
        self.lin_neigh = torch.nn.Linear(in_dim, hid_dim)
        self.w_rel     = torch.nn.Parameter(torch.ones(num_rel), requires_grad=True)

    def forward(self, x, edge_index, edge_type):
        h_self = self.lin_self(x)
        h_neigh = self.propagate(edge_index, x=x, edge_type=edge_type)
        return F.relu(h_self + h_neigh)

    def message(self, x_j, edge_type):
        w = self.w_rel[edge_type].unsqueeze(-1)    # [E,1]
        return w * self.lin_neigh(x_j)

# 4) 학습 세팅
device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model     = RelGraphSAGE(in_dim=768, hid_dim=256, num_rel=NUM_REL).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs    = 30
loader    = DataLoader(data_list, batch_size=1, shuffle=True)

# 5) Contrastive-style 학습
for ep in range(1, epochs+1):
    model.train()
    total_loss = 0.0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # 노드 임베딩 계산
        h = model(batch.x, batch.edge_index, batch.edge_type)  # [n_sub,256]

        # positive edge similarity
        src, dst  = batch.edge_index
        pos_sim   = (h[src] * h[dst]).sum(dim=1)               # [E]
        w_pos     = model.w_rel[batch.edge_type]               # [E]

        # negative sampling
        neg_idx   = negative_sampling(batch.edge_index,
                                      num_nodes=h.size(0),
                                      num_neg_samples=src.size(0))
        ns, nd    = neg_idx
        neg_sim   = (h[ns] * h[nd]).sum(dim=1)

        # loss 계산
        loss_pos  = - (w_pos * F.logsigmoid(pos_sim)).sum() / w_pos.sum()
        loss_neg  = - F.logsigmoid(-neg_sim).mean()
        loss      = loss_pos + loss_neg

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {ep:02d}   Avg Loss: {total_loss/len(loader):.4f}")

# 6) 결과: 최종 노드 임베딩 & 관계 가중치
model.eval()
embeddings_by_doc = []
with torch.no_grad():
    for data in data_list:
        data = data.to(device)
        h_sub = model(data.x, data.edge_index, data.edge_type)
        embeddings_by_doc.append(h_sub.cpu())
learned_w = model.w_rel.detach().cpu()  # [num_rel]

print("Learned relation weights:", learned_w)




Epoch 01   Avg Loss: 1.3877
Epoch 02   Avg Loss: 1.3821
Epoch 03   Avg Loss: 1.3782
Epoch 04   Avg Loss: 1.3732
Epoch 05   Avg Loss: 1.3678
Epoch 06   Avg Loss: 1.3627
Epoch 07   Avg Loss: 1.3529
Epoch 08   Avg Loss: 1.3447
Epoch 09   Avg Loss: 1.3409
Epoch 10   Avg Loss: 1.3329
Epoch 11   Avg Loss: 1.3275
Epoch 12   Avg Loss: 1.3204
Epoch 13   Avg Loss: 1.3021
Epoch 14   Avg Loss: 1.3009
Epoch 15   Avg Loss: 1.2991
Epoch 16   Avg Loss: 1.2912
Epoch 17   Avg Loss: 1.2849
Epoch 18   Avg Loss: 1.2722
Epoch 19   Avg Loss: 1.2668
Epoch 20   Avg Loss: 1.2566
Epoch 21   Avg Loss: 1.2603
Epoch 22   Avg Loss: 1.2629
Epoch 23   Avg Loss: 1.2690
Epoch 24   Avg Loss: 1.2431
Epoch 25   Avg Loss: 1.2290
Epoch 26   Avg Loss: 1.2332
Epoch 27   Avg Loss: 1.2276
Epoch 28   Avg Loss: 1.2202
Epoch 29   Avg Loss: 1.2417
Epoch 30   Avg Loss: 1.1849
Learned relation weights: tensor([1.0000, 1.0234, 1.0312, 1.0248, 1.0197, 1.0083, 1.0001, 1.0317, 1.0000,
        1.0282, 1.0299, 1.0264])


## 4. LangGraph State and Node config

In [64]:
col_sent = client.get_or_create_collection(COLLECTION_SENT_NAME)

In [None]:
# %% 3) 벡터 인덱싱 → 문장 단위 DB 로딩으로 변경
import chromadb
from chromadb.config import Settings

# %% LangGraph & 노드 정의 (이하는 기존과 동일)
import os
from langgraph.graph import StateGraph, START, END
from langchain_google_genai import ChatGoogleGenerativeAI
from sentence_transformers import SentenceTransformer, CrossEncoder

from typing import TypedDict

# 모델 설정
encoder     = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', device=device)
ce_reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
llm         = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    google_api_key=os.getenv('GOOGLE_API_KEY'),
    temperature=0.0
)

# State Schema
class InputState(TypedDict):
    paper_id: str
    question: str

class OutputState(TypedDict):
    paper_id: str
    answer: str
    top_docs: list[str]
    top_metadatas: list[dict]

class OverallState(InputState, OutputState):
    retrieved_docs: list[str]
    retrieved_metadatas: list[dict]

# retrieve 노드: 문장 컬렉션에서 검색
def retrieve(state: OverallState) -> OverallState:
    assert col_sent is not None, "Sentence-level collection not found"
    q        = state["question"]
    paper_id = state["paper_id"]
    # print(f"Retrieving sentences for paper {paper_id!r} with query {q!r}")
    # 1) 질의 임베딩
    q_emb = encoder.encode(q)
    # 2) col_sent에서 논문별 필터링 후 top-20 문장 검색
    res = col_sent.query(
        query_embeddings=[q_emb],
        n_results=10,
        include=["documents", "metadatas"],
        where={"paper_id": paper_id}
    )
    # print(f"Retrived sents : {res['documents'][0]}")
    # print(f"Retrieved {len(res['documents'][0])} sentences for paper {paper_id!r} with query {q!r}")
    # Chromadb 문법상 [0]으로 추출
    state["retrieved_docs"]      = res["documents"][0]
    state["retrieved_metadatas"] = res["metadatas"][0]
    return state

# rerank / generate 노드는 기존과 동일
def rerank(state: OverallState) -> OverallState:
    docs      = state["retrieved_docs"]
    metadatas = state["retrieved_metadatas"]
    q         = state["question"]

    scores = ce_reranker.predict([(q, d) for d in docs])
    ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)

    state["top_docs"]      = [docs[i] for i, _ in ranked[:6]]
    state["top_metadatas"] = [metadatas[i] for i, _ in ranked[:6]]
    return state

def generate(state: OverallState) -> OverallState:
    ctx = "\n\n".join(state["top_docs"])
    prompt = (
        f"Prompt: Answer based on context below. If you don't know, say 'unanswerable'.\n"
        f"Context:\n{ctx}\n\nQuestion: {state['question']}\nAnswer:"
    )
    resp = llm.invoke(prompt)
    state["answer"] = resp.content
    return state

# 파이프라인 빌드
builder = StateGraph(state_schema=OverallState, input=InputState, output=OutputState)
builder.add_node("retrieve", retrieve)
builder.add_node("rerank",   rerank)
builder.add_node("generate", generate)
builder.set_entry_point("retrieve")
builder.add_edge("retrieve", "rerank")
builder.add_edge("rerank",   "generate")
builder.set_finish_point("generate")

rag_pipeline = builder.compile()


In [None]:
import os
import hashlib
import torch
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm
from langchain_google_genai import ChatGoogleGenerativeAI


# 1) 데이터 로드 (검증용 subset)
dataset = load_dataset('allenai/qasper', split='train', cache_dir='./data/Qasper/qasper_cache')
# subset  = dataset.shuffle(seed=42).select(range(int(len(dataset)*0.1)))  # 10%

# 2) helper functions
def normalize(text: str) -> str:
    return ''.join(c.lower() for c in text if c.isalnum() or c.isspace()).strip()

def find_gold_para_ids(sample, gold_evids):
    paper_id = sample["id"]
    gold_ids = set()
    # full_text['paragraphs'] 는 섹션별로 [문단1, 문단2, …] 리스트를 가집니다.
    sections = sample["full_text"]["paragraphs"]
    for ev in gold_evids:
        ev_norm = ev.strip()
        for sec_idx, paras in enumerate(sections):
            for para_idx, para_text in enumerate(paras):
                if ev_norm in para_text:
                    para_id = f"sec{sec_idx}/para{para_idx}"
                    gold_ids.add(f"{paper_id}_{para_id}")
                    # 하나의 evidence는 하나의 문단만 매핑하므로 찾으면 바로 빠져나감
                    break
            else:
                continue
            break
    return gold_ids

def compute_retrieval_metrics(pred: list[str], gold: set[str]):
    pred_set = set(pred)
    tp = len(pred_set & gold)
    prec = tp / len(pred_set) if pred_set else 0.0
    rec  = tp / len(gold)     if gold     else 0.0
    f1   = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
    return prec, rec, f1

def compute_answer_metrics(pred_ans: str, gold_texts: list[str], gold_yesno: list[bool], gold_unans: list[bool]):
    norm_pred = normalize(pred_ans)
    # Exact Match over all annotators
    em = any(norm_pred == normalize(gt) for gt in gold_texts)
    # yes/no
    if any(gold_yesno):
        yn_pred = norm_pred.startswith('yes') or norm_pred.startswith('no')
        yn_acc  = int(yn_pred and norm_pred.split()[0] == ('yes' if any(gold_yesno) else 'no'))
    else:
        yn_acc = None
    # unanswerable
    if any(gold_unans):
        un_pred = norm_pred.startswith('unanswerable')
        un_acc  = int(un_pred)
    else:
        un_acc = None
    return int(em), yn_acc, un_acc

# 3) evaluation loop
records = []
for sample in tqdm(subset, desc="Evaluating"):
    pid       = sample["id"]
    questions = sample["qas"]["question"]
    answers   = sample["qas"]["answers"]
    
    for q_text, ans_block in zip(questions, answers):
        # 3.1) RAG-invoke
        out = rag_pipeline.invoke({"paper_id": pid, "question": q_text})
        pred_ans        = out["answer"].strip()
        pred_metadatas  = out["top_metadatas"]
        # predicted evidence IDs
        pred_ids = []
        for md in pred_metadatas:
            pid     = md["paper_id"]
            sec_idx = md["sec_idx"]
            para_idx= md["para_idx"]
            para_id = f"sec{sec_idx}/para{para_idx}"
            pred_ids.append(f"{pid}_{para_id}")
        # in format of "paper_id_para_idx"
        
        # 3.2) gold aggregation
        gold_texts, gold_evids, gold_yesno, gold_unans = [], [], [], []
        # del gold_texts, gold_yesno, gold_unans  # clear previous
        for ann in ans_block["answer"]:
            if ann.get("answer"):
                gold_texts.append(ann["answer"])
            if ann.get("evidence"):
                gold_evids.extend(ann["evidence"])
            if ann.get("yes_no") is not None:
                gold_yesno.append(ann["yes_no"])
            if ann.get("unanswerable") is not None:
                gold_unans.append(ann["unanswerable"])
        
        gold_ids = find_gold_para_ids(sample, gold_evids)
        
        # 3.3) metrics
        prec, rec, f1 = compute_retrieval_metrics(pred_ids, gold_ids)
        em, yn_acc, un_acc = compute_answer_metrics(pred_ans, gold_texts, gold_yesno, gold_unans)
        
        records.append({
            "paper_id":           pid,
            "question":           q_text,
            "precision":          prec,
            "recall":             rec,
            "f1":                 f1,
            "exact_match":        em,
            "yes_no_acc":         yn_acc,
            "unans_acc":          un_acc
        })

# 4) DataFrame 생성 및 paper별 집계
df = pd.DataFrame(records)

# 질문 단위 전체 평균
overall = df.mean(numeric_only=True).round(4)
print("=== Overall ===")
print(overall.to_dict())

# paper_id 별 평균
by_paper = (
    df
    .groupby("paper_id")
    .agg({
        "precision":   "mean",
        "recall":      "mean",
        "f1":          "mean",
        "exact_match": "mean",
        "yes_no_acc":  "mean",
        "unans_acc":   "mean"
    })
    .round(4)
    .reset_index()
)
print("\n=== By Paper ===")
print(by_paper)

# 5) 결과 저장
df.to_csv("qasper_rag_eval_per_question.csv", index=False)
by_paper.to_csv("qasper_rag_eval_by_paper.csv", index=False)


Evaluating:   0%|          | 0/88 [00:00<?, ?it/s]


IndexError: list index out of range