# 06. Baseline Comparisons

7개 Baseline과의 공정 비교

In [None]:
import sys
from pathlib import Path

# Resolve the repository root as the parent of the current working directory
# (assumes the notebook is run from a sub-directory of the repo — TODO confirm)
# and put it first on sys.path so project packages can be imported.
_cwd = Path(".").resolve()
PROJECT_ROOT = _cwd.parent
sys.path.insert(0, f"{PROJECT_ROOT}")

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import yaml
from omegaconf import OmegaConf

## 1. Baseline 목록

| # | Baseline | Description |
|---|----------|-------------|
| 1 | No Retrieval | LLM only (parametric knowledge) |
| 2 | BM25 + RAG | Sparse retrieval |
| 3 | Standard RAG | Dense retrieval (Contriever) |
| 4 | Self-RAG | Retrieval-augmented with self-reflection |
| 5 | IRCoT | Chain-of-thought with interleaved retrieval |
| 6 | Adaptive-RAG | Query complexity based routing |
| 7 | CoRAG | Collaborative RAG |


In [None]:
# Load the experiment configuration and show its baseline-specific section.
config_path = PROJECT_ROOT / "configs" / "phase3_full.yaml"
# Decode explicitly as UTF-8: without an encoding, open() uses the platform
# locale (e.g. cp949 on Korean Windows) and can mis-decode non-ASCII text.
with open(config_path, encoding="utf-8") as f:
    # safe_load: the config is plain data and must not construct Python objects.
    config = OmegaConf.create(yaml.safe_load(f))

print("Baselines config:")
print(OmegaConf.to_yaml(config.baselines))

## 2. Baseline 1: No Retrieval

In [None]:
from baselines.no_retrieval import NoRetrievalBaseline

# NOTE(review): this definition shadows the ``NoRetrievalBaseline`` imported
# from ``baselines.no_retrieval`` just above; this local class is the one used.
class NoRetrievalBaseline:
    """Baseline 1: No Retrieval.

    Answers from the LLM's parametric knowledge only: no external documents
    are used, the question is fed directly to the model in a
    "Question: ... / Answer:" prompt and decoded greedily.
    """

    def __init__(self, llm_name: str, device: str = "cuda"):
        """Load the tokenizer and fp16 causal LM for ``llm_name``.

        Args:
            llm_name: Hugging Face model id or local path.
            device: Stored as ``self.device``. Inputs are placed on
                ``self.model.device`` because the model itself is placed by
                ``device_map="auto"``.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            llm_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Some tokenizers have no pad token; reuse EOS so generate() gets a
        # valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _extract_answer(continuation: str) -> str:
        """Return the first non-empty line of the generated continuation.

        The prompt holds a single Question/Answer pair, so text after the
        first line is the model inventing more content (e.g. a new
        "Question: ... Answer: ..." pair) and must not be scored.
        """
        return continuation.strip().split("\n", 1)[0].strip()

    def generate(self, question: str, max_new_tokens: int = 64):
        """Answer ``question`` without retrieval.

        Args:
            question: The question text.
            max_new_tokens: Generation budget in tokens.

        Returns:
            The predicted answer string.
        """
        prompt = f"Question: {question}\nAnswer:"

        # With device_map="auto" the model is not necessarily on self.device;
        # send inputs to the device the model reports.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # generate() returns prompt + continuation; decode only the new tokens
        # so nothing from the prompt can end up in the answer.
        continuation = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        return self._extract_answer(continuation)

print("NoRetrievalBaseline defined")

In [None]:
# 테스트
# no_retrieval = NoRetrievalBaseline(config.model.llm_name)
# answer = no_retrieval.generate("What is the capital of France?")
# print(f"Answer: {answer}")

## 3. Baseline 2-3: BM25 & Dense RAG

In [None]:
from baselines.standard_rag import StandardRAG

# NOTE(review): this definition shadows the ``StandardRAG`` imported from
# ``baselines.standard_rag`` just above; this local class is the one used.
class StandardRAG:
    """Baselines 2-3: Standard RAG (retrieve top-k documents, then answer).

    ``retriever_type="bm25"`` uses sparse BM25 (``rank_bm25.BM25Okapi``).
    Any other value uses dense retrieval with ``facebook/contriever``
    embeddings scored by dot product. Retrieved documents are joined in rank
    order into a "Context:" block placed before the question.
    """

    # Maximum prompt length in tokens.
    max_input_length = 2048
    # Tokens held back when budgeting the context. Decoding and re-encoding a
    # truncated context does not always round-trip to the same token count.
    _length_slack = 16

    def __init__(
        self,
        llm_name: str,
        corpus: dict,
        retriever_type: str = "dense",  # "bm25" or "dense"
        top_k: int = 5,
        device: str = "cuda",
    ):
        """Load the LLM and build the retriever over ``corpus``.

        Args:
            llm_name: Hugging Face model id or local path.
            corpus: Mapping of document id -> document text.
            retriever_type: "bm25" for sparse retrieval, anything else for dense.
            top_k: Default number of documents to retrieve.
            device: Stored as ``self.device``. Inputs are placed on
                ``self.model.device`` because the model itself is placed by
                ``device_map="auto"``.
        """
        self.device = device
        self.corpus = corpus
        self.top_k = top_k
        self.retriever_type = retriever_type

        # LLM
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            llm_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Build the retriever index.
        self._init_retriever()

    @staticmethod
    def _tokenize(text: str) -> list:
        """Lower-cased word tokens for BM25.

        Punctuation is dropped so that e.g. the query token "France?" matches
        the document token "France".
        """
        import re

        return re.findall(r"\w+", text.lower())

    @staticmethod
    def _top_k_indices(scores, k: int) -> list:
        """Return the indices of the ``k`` highest scores, best first."""
        import numpy as np

        # atleast_1d: a single-document corpus must still yield a 1-D array.
        scores = np.atleast_1d(np.asarray(scores))
        if k <= 0:
            return []
        return np.argsort(scores)[::-1][:k].tolist()

    @staticmethod
    def _extract_answer(continuation: str) -> str:
        """Return the first non-empty line of the generated continuation."""
        return continuation.strip().split("\n", 1)[0].strip()

    def _init_retriever(self):
        """Build either the BM25 index or the dense document-embedding matrix."""
        self.doc_ids = list(self.corpus.keys())
        doc_texts = list(self.corpus.values())
        if self.retriever_type == "bm25":
            from rank_bm25 import BM25Okapi

            self.bm25 = BM25Okapi([self._tokenize(doc) for doc in doc_texts])
        else:
            from sentence_transformers import SentenceTransformer

            self.encoder = SentenceTransformer("facebook/contriever")
            self.doc_embeddings = self.encoder.encode(
                doc_texts,
                show_progress_bar=True,
            )

    def retrieve(self, query: str, k: int = None):
        """Return the ids of the top-``k`` documents for ``query``, best first.

        A falsy ``k`` (None or 0) falls back to ``self.top_k``.
        """
        k = k or self.top_k

        if self.retriever_type == "bm25":
            scores = self.bm25.get_scores(self._tokenize(query))
        else:
            query_emb = self.encoder.encode([query])
            # (num_docs, dim) @ (dim,) -> (num_docs,) for a single query.
            scores = self.doc_embeddings @ query_emb[0]

        return [self.doc_ids[i] for i in self._top_k_indices(scores, k)]

    def _build_prompt(self, question: str, context: str) -> str:
        """Build the RAG prompt, trimming the *end* of the context if needed.

        The question and the trailing "Answer:" cue are always kept. Only the
        tail of the context (the lowest-ranked documents) is cut, so the prompt
        fits ``max_input_length`` tokens.
        """
        prefix = "Context: "
        suffix = f"\n\nQuestion: {question}\nAnswer:"
        # Tokens taken by everything except the context, including special tokens.
        reserved = len(self.tokenizer(prefix + suffix)["input_ids"])
        budget = max(self.max_input_length - reserved - self._length_slack, 0)
        context_ids = self.tokenizer(context, add_special_tokens=False)["input_ids"]
        if len(context_ids) > budget:
            context = self.tokenizer.decode(
                context_ids[:budget], skip_special_tokens=True
            )
        return prefix + context + suffix

    def generate(self, question: str, max_new_tokens: int = 64):
        """Retrieve documents and answer ``question`` from them.

        Returns:
            ``(answer, retrieved_ids)``.
        """
        retrieved_ids = self.retrieve(question)
        context = "\n".join(self.corpus[did] for did in retrieved_ids)
        prompt = self._build_prompt(question, context)

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens; the prompt (which may
        # itself contain "Answer:" inside documents) is excluded.
        continuation = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        return self._extract_answer(continuation), retrieved_ids

print("StandardRAG defined")

## 4. Baseline 4-7: Advanced RAG Methods

In [None]:
# Self-RAG, IRCoT, Adaptive-RAG and CoRAG need substantial implementations, so
# the papers' reference implementations are used for them. This cell only
# prints a short description of each method.
_advanced_rag_summary = (
    "Advanced RAG baselines:",
    "",
    "4. Self-RAG (Asai et al., 2023)",
    "   - Retrieval-augmented with self-reflection tokens",
    "   - [Retrieve], [IsRel], [IsSup], [IsUse] 토큰 사용",
    "",
    "5. IRCoT (Trivedi et al., 2022)",
    "   - Interleaved Retrieval Chain-of-Thought",
    "   - 추론 중간에 retrieval 수행",
    "",
    "6. Adaptive-RAG (Jeong et al., 2024)",
    "   - Query complexity 기반 routing",
    "   - Simple → No retrieval, Complex → Multi-step",
    "",
    "7. CoRAG (Collaborative RAG)",
    "   - Multiple retriever ensemble",
)
for _summary_line in _advanced_rag_summary:
    print(_summary_line)

In [None]:
# Simplified Self-RAG
class SelfRAGSimplified:
    """Simplified Self-RAG baseline.

    The real Self-RAG needs a fine-tuned model that emits reflection tokens.
    Here the retrieve / no-retrieve decision is a keyword rule instead.
    Questions that trigger the rule go through a dense ``StandardRAG``; the
    rest are answered by the same LLM without context.
    """

    def __init__(self, llm_name: str, corpus: dict, device: str = "cuda"):
        # Dense RAG pipeline; its tokenizer and model are reused for the
        # no-retrieval path.
        self.rag = StandardRAG(llm_name, corpus, retriever_type="dense", device=device)

    def should_retrieve(self, question: str) -> bool:
        """Decide (rule-based) whether retrieval is needed.

        Returns True if any factual keyword occurs in the lower-cased question.
        This is plain substring matching, so e.g. "who" also matches inside
        "whole".
        """
        factual_keywords = ["who", "what", "when", "where", "which", "how many"]
        q_lower = question.lower()
        return any(kw in q_lower for kw in factual_keywords)

    def generate(self, question: str, max_new_tokens: int = 64):
        """Answer ``question``.

        Returns:
            ``(answer, retrieved_ids)``; ``retrieved_ids`` is empty when the
            rule decides not to retrieve.
        """
        if self.should_retrieve(question):
            return self.rag.generate(question, max_new_tokens)

        # No-retrieval path. Decode like the other baselines (greedy, explicit
        # pad token), so results do not depend on the model's default
        # generation config, which may enable sampling.
        tokenizer, model = self.rag.tokenizer, self.rag.model
        prompt = f"Question: {question}\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the new tokens and keep the first non-empty line.
        continuation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return continuation.strip().split("\n", 1)[0].strip(), []

print("SelfRAGSimplified defined")

## 5. 통합 평가

In [None]:
from evaluation.metrics import compute_em, compute_f1, compute_recall_at_k

def evaluate_baseline(baseline, qa_pairs, top_k=5):
    """Evaluate one baseline on a collection of QA items.

    Args:
        baseline: Object whose ``generate(question)`` returns either the answer
            string or an ``(answer, retrieved_doc_ids)`` tuple.
        qa_pairs: Iterable of dicts with keys ``"question"`` and ``"answer"``,
            and optionally ``"gold_doc_ids"``.
        top_k: Cut-off ``k`` passed to ``compute_recall_at_k``.

    Returns:
        Dict with ``"EM"``, ``"F1"`` and ``f"Recall@{top_k}"``, each scaled by
        100. Recall is averaged only over items that have gold document ids
        (0.0 if none do), so unlabeled items do not count as recall 0.

    Raises:
        ValueError: If ``qa_pairs`` yields no items.
    """
    all_em = []
    all_f1 = []
    all_recall = []

    for item in tqdm(qa_pairs, desc="Evaluating"):
        question = item["question"]
        answer = item["answer"]
        gold_doc_ids = item.get("gold_doc_ids", [])

        # Retrieval baselines return (answer, retrieved_ids); others a string.
        result = baseline.generate(question)
        if isinstance(result, tuple):
            prediction, retrieved_ids = result
        else:
            prediction, retrieved_ids = result, []

        all_em.append(compute_em(prediction, answer))
        all_f1.append(compute_f1(prediction, answer))
        # Recall is only defined for items with retrieval labels.
        if gold_doc_ids:
            all_recall.append(compute_recall_at_k(retrieved_ids, gold_doc_ids, k=top_k))

    if not all_em:
        raise ValueError("qa_pairs is empty; nothing to evaluate")

    return {
        "EM": sum(all_em) / len(all_em) * 100,
        "F1": sum(all_f1) / len(all_f1) * 100,
        f"Recall@{top_k}": sum(all_recall) / len(all_recall) * 100 if all_recall else 0.0,
    }

print("evaluate_baseline defined")

In [None]:
# 전체 baseline 평가 실행
# baselines = {
#     "No Retrieval": no_retrieval_baseline,
#     "BM25-RAG": bm25_rag,
#     "Dense-RAG": dense_rag,
#     "Self-RAG": self_rag,
#     "Parametric-QA": parametric_qa_model,
# }

# results = {}
# for name, baseline in baselines.items():
#     print(f"\nEvaluating {name}...")
#     results[name] = evaluate_baseline(baseline, qa_pairs)
#     print(f"Results: {results[name]}")

## 6. 결과 비교

In [None]:
# Example (illustrative) results: one entry per method, scores in percent.
example_results = {
    "No Retrieval": {"EM": 22.1, "F1": 28.4, "Recall@5": 0.0},
    "BM25-RAG": {"EM": 35.6, "F1": 42.1, "Recall@5": 58.2},
    "Dense-RAG": {"EM": 40.3, "F1": 48.7, "Recall@5": 72.4},
    "Self-RAG": {"EM": 42.8, "F1": 51.2, "Recall@5": 74.1},
    "IRCoT": {"EM": 44.1, "F1": 53.5, "Recall@5": 76.3},
    "Adaptive-RAG": {"EM": 43.5, "F1": 52.8, "Recall@5": 75.2},
    "CoRAG": {"EM": 45.2, "F1": 54.1, "Recall@5": 78.5},
    "Parametric-QA (Ours)": {"EM": 47.3, "F1": 56.8, "Recall@5": 82.1},
}

# Rows = methods, columns = metrics; highest EM first.
df = (
    pd.DataFrame.from_dict(example_results, orient="index")
    .sort_values(by="EM", ascending=False)
)
print(df.to_string())

In [None]:
# Visualization: grouped bar chart, one group per method, one bar per metric.
metrics = ["EM", "F1", "Recall@5"]
# One colour per metric, so the legend swatches identify metrics.
metric_colors = {"EM": "#3498db", "F1": "#2ecc71", "Recall@5": "#9b59b6"}
highlight = "#e74c3c"

fig, ax = plt.subplots(figsize=(12, 6))

x = range(len(df))
width = 0.25
is_ours = ["Ours" in name for name in df.index]
# Highlight our method with a red outline (and red tick label) rather than
# recolouring its bars, which would clash with the per-metric colours.
edge_colors = [highlight if ours else "none" for ours in is_ours]

for offset, metric in zip((-width, 0.0, width), metrics):
    ax.bar(
        [i + offset for i in x],
        df[metric],
        width,
        label=metric,
        color=metric_colors[metric],
        edgecolor=edge_colors,
        linewidth=2,
    )

ax.set_ylabel("Score (%)")
ax.set_ylim(0, 100)
ax.set_title("Baseline Comparison")
ax.set_xticks(list(x))
ax.set_xticklabels(df.index, rotation=45, ha="right")
for tick_label, ours in zip(ax.get_xticklabels(), is_ours):
    if ours:
        tick_label.set_color(highlight)
        tick_label.set_fontweight("bold")
ax.legend()
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

## 7. LaTeX Table 생성

In [None]:
def generate_latex_table(df, caption="Baseline Comparison", columns=("EM", "F1", "Recall@5")):
    r"""Render ``df`` as a LaTeX (booktabs) table.

    Args:
        df: DataFrame indexed by method name that contains every column in
            ``columns`` as numbers; values are formatted with one decimal.
        caption: Table caption, inserted verbatim (it may contain LaTeX).
        columns: Metric columns to include, in order. Defaults to the
            original EM / F1 / Recall@5 layout.

    Returns:
        The LaTeX source without a trailing newline. Rows whose method name
        contains "Ours" are set in bold. Method names and column headers are
        escaped for LaTeX. The output uses \toprule / \midrule / \bottomrule,
        so the document needs ``\usepackage{booktabs}``.
    """
    special = {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }

    def _escape(text):
        # Character-wise mapping, so already-inserted escapes are never re-escaped.
        return "".join(special.get(ch, ch) for ch in str(text))

    columns = list(columns)
    lines = [
        "\\begin{table}[h]",
        "\\centering",
        f"\\caption{{{caption}}}",
        "\\begin{tabular}{l" + "c" * len(columns) + "}",
        "\\toprule",
        " & ".join(["Method", *map(_escape, columns)]) + " \\\\",
        "\\midrule",
    ]

    for method, row in df.iterrows():
        cells = [_escape(method)] + [f"{row[col]:.1f}" for col in columns]
        if "Ours" in str(method):
            cells = [f"\\textbf{{{cell}}}" for cell in cells]
        lines.append(" & ".join(cells) + " \\\\")

    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
    return "\n".join(lines)

# Emit the LaTeX source for the results table built above.
latex_source = generate_latex_table(df)
print(latex_source)