# 09. Multi-hop Extension (Phase 4)

HotpotQA에서의 Multi-hop QA 실험

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
from transformers import AutoTokenizer
from tqdm.notebook import tqdm
import pandas as pd
import yaml
from omegaconf import OmegaConf

## 1. Multi-hop QA 개요

### HotpotQA 특성
- **Bridge questions**: 여러 문서를 연결하는 bridge entity 필요
- **Comparison questions**: 두 entity를 비교하는 질문
- **Supporting facts**: 정답 추론에 필요한 문장들

### Multi-hop Strategies
1. **Concat**: 선택된 z_i들을 단순 concat
2. **Iterative**: 첫 번째 z로 bridge entity 식별 후 두 번째 retrieval (현재 노트북 구현은 단순화된 placeholder로, bridge entity 추출과 두 번째 retrieval 없이 첫 hop 결과만 사용)

In [None]:
# Config
# OmegaConf.load (instead of yaml.safe_load + OmegaConf.create) opens the file
# as UTF-8 regardless of the platform's default encoding, and parses it with
# OmegaConf's YAML loader, which reads exponent-only floats such as `1e-4` as
# floats (PyYAML's safe_load returns them as strings).
config_path = PROJECT_ROOT / "configs" / "phase4_multihop.yaml"
config = OmegaConf.load(config_path)

print("Multi-hop config:")
print(OmegaConf.to_yaml(config.multihop))

## 2. HotpotQA 데이터 로드

In [None]:
from datasets import load_dataset

def load_hotpotqa(split="validation", max_samples=1000, config="fullwiki", dataset=None):
    """
    Load HotpotQA and turn each example into a small per-question corpus.

    Args:
        split: Split name passed to ``load_dataset`` (unused when ``dataset``
            is given).
        max_samples: Maximum number of examples to return. ``None`` means no
            limit; values <= 0 return an empty list.
        config: HotpotQA configuration passed to ``load_dataset``
            ("fullwiki" by default, or "distractor"). NOTE: in the fullwiki
            setting the context paragraphs are retrieved, so a supporting-fact
            title may be missing from ``docs``. Such titles are skipped, which
            leaves ``gold_doc_ids`` with fewer entries than there are
            supporting titles.
        dataset: Optional pre-loaded iterable of raw HotpotQA records (same
            schema as the HF dataset). When given, ``load_dataset`` is not called.

    Returns:
        A list of dicts with keys:
            "question", "answer",
            "type"  (bridge or comparison),
            "level" (hard, medium, easy),
            "gold_doc_ids" (indices into "docs", de-duplicated, in the order the
                titles first appear in supporting_facts),
            "docs"  (dict: context index -> {"title": str, "text": str}).
    """
    if max_samples is not None and max_samples <= 0:
        return []
    if dataset is None:
        dataset = load_dataset("hotpot_qa", config, split=split)

    processed = []
    for item in dataset:
        # Build the per-question corpus from the context paragraphs.
        titles = item["context"]["title"]
        sentences_list = item["context"]["sentences"]
        docs = {
            i: {"title": title, "text": " ".join(sentences)}
            for i, (title, sentences) in enumerate(zip(titles, sentences_list))
        }

        # title -> first doc index carrying that title (duplicates keep the first).
        title_to_id = {}
        for doc_id, doc in docs.items():
            title_to_id.setdefault(doc["title"], doc_id)

        # Gold doc ids from supporting facts. dict.fromkeys de-duplicates while
        # keeping first-appearance order; iterating a set of strings is
        # hash-randomized, which made the order vary from run to run.
        gold_doc_ids = [
            title_to_id[sf_title]
            for sf_title in dict.fromkeys(item["supporting_facts"]["title"])
            if sf_title in title_to_id
        ]

        processed.append({
            "question": item["question"],
            "answer": item["answer"],
            "type": item["type"],  # bridge or comparison
            "level": item["level"],  # hard, medium, easy
            "gold_doc_ids": gold_doc_ids,
            "docs": docs,
        })

        if max_samples is not None and len(processed) >= max_samples:
            break

    return processed

# 로드 (실제로는 시간이 걸림)
# hotpot_data = load_hotpotqa(max_samples=1000)
# print(f"Loaded {len(hotpot_data)} samples")

# 샘플 확인
# print(hotpot_data[0])

## 3. Multi-hop Strategy: Concat

In [None]:
class ConcatStrategy:
    """
    Concat strategy: concatenate the z vectors of all selected documents.

    - Pros: simple and fast (a single selection step).
    - Cons: relations between documents are not modeled explicitly.
    """

    def __init__(self, model):
        # model must provide select_documents(query_ids, query_mask, k=...)
        # and an indexable doc_vectors tensor.
        self.model = model

    def forward(self, query_ids, query_mask, top_k=5, k_per_hop=None):
        """
        Select documents once and flatten their z vectors into one sequence.

        Args:
            query_ids: [B, query_len]
            query_mask: [B, query_len]
            top_k: Number of documents to select.
            k_per_hop: Optional alias for ``top_k``, accepted so this strategy
                can be called with the same keyword as
                ``IterativeStrategy.forward``. When given, it overrides ``top_k``.

        Returns:
            (z_combined, selected_ids, scores), where z_combined is
            [B, k*m_tokens, z_dim], assuming doc_vectors[selected_ids] is
            [B, k, m_tokens, z_dim] — TODO confirm against the model.
            selected_ids and scores are returned unchanged from
            model.select_documents.
        """
        k = top_k if k_per_hop is None else k_per_hop

        # Select top-k documents.
        selected_ids, scores = self.model.select_documents(query_ids, query_mask, k=k)

        # Gather z vectors; expected [B, k, m_tokens, z_dim].
        z_selected = self.model.doc_vectors[selected_ids]

        # Merge the k and m_tokens dims. reshape (unlike view) does not
        # require the tensor to be contiguous.
        z_combined = z_selected.reshape(query_ids.size(0), -1, z_selected.size(-1))

        return z_combined, selected_ids, scores

print("ConcatStrategy defined")

## 4. Multi-hop Strategy: Iterative

In [None]:
class IterativeStrategy:
    """
    Iterative multi-hop strategy.

    Intended design:
      1. retrieve documents for the original query,
      2. identify a bridge entity in those documents,
      3. retrieve again using the bridge entity.
    Pros: explicit multi-hop reasoning. Cons: two selections, higher latency.

    Current behavior (simplified): only step 1 is performed. forward() does a
    single selection and returns its z vectors; extract_bridge_entity is a
    placeholder and is not called by forward().
    """

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer

    def extract_bridge_entity(self, query_ids, z_first, max_new_tokens=32):
        """
        Placeholder for bridge-entity extraction from the first-hop documents.

        A real version would prompt an LLM (e.g. "Based on the context, what
        entity connects to the answer?") or run NER. For now it always returns
        the constant string "bridge_entity".
        """
        return "bridge_entity"

    def forward(self, query_ids, query_mask, k_per_hop=3):
        """
        Run the (currently single-hop) selection.

        Args:
            query_ids: [B, query_len]
            query_mask: [B, query_len]
            k_per_hop: Number of documents selected per hop.

        Returns:
            (z_combined, selected_ids, scores) from the first hop only.
            z_combined is doc_vectors[selected_ids] with every dim between the
            batch dim and the last dim flattened into one, i.e.
            [B, k*m, z] if doc_vectors[selected_ids] is [B, k, m, z].
        """
        hop1_ids, hop1_scores = self.model.select_documents(
            query_ids, query_mask, k=k_per_hop
        )
        hop1_z = self.model.doc_vectors[hop1_ids]

        # TODO: hop 2 — re-query with the bridge entity and append its z vectors.
        flattened = hop1_z.view(query_ids.size(0), -1, hop1_z.size(-1))
        return flattened, hop1_ids, hop1_scores

print("IterativeStrategy defined")

## 5. Multi-hop 학습

In [None]:
class MultihopTrainer:
    """
    Trainer for multi-hop QA.

    Target objective: Loss = L_gen + λ1 * L_retrieval + λ2 * L_supporting_facts
    (λ1 = retrieval_weight, λ2 = sf_weight).

    Currently train_step optimizes only L_gen; retrieval_weight and sf_weight
    are stored but not yet used.
    """

    def __init__(
        self,
        model,
        strategy,  # ConcatStrategy or IterativeStrategy
        tokenizer,
        lr: float = 1e-4,
        retrieval_weight: float = 0.1,
        sf_weight: float = 0.1,  # supporting-facts loss weight
        device: str = "cuda",
    ):
        self.model, self.strategy, self.tokenizer = model, strategy, tokenizer
        self.device = device
        self.retrieval_weight, self.sf_weight = retrieval_weight, sf_weight

        # AdamW over every parameter of the model.
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def train_step(self, batch):
        """
        Run one optimization step on ``batch`` and return the loss as a float.

        ``batch`` must contain tensors "query_ids", "query_mask", "answer_ids"
        and a "gold_doc_ids" entry (read but not yet used).
        """
        self.model.train()

        query_ids, query_mask, answer_ids = (
            batch[key].to(self.device)
            for key in ("query_ids", "query_mask", "answer_ids")
        )
        # Reserved for the retrieval loss (which must include the gold docs).
        _gold_doc_ids = batch["gold_doc_ids"]

        # Multi-hop document selection (strategy defaults for k).
        z_combined, _selected_ids, _scores = self.strategy.forward(query_ids, query_mask)

        # Only the generation loss is optimized for now.
        # TODO: + self.retrieval_weight * L_retrieval + self.sf_weight * L_sf
        loss = self.model.forward_with_z(z_combined, query_ids, answer_ids, query_mask)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

print("MultihopTrainer defined")

## 6. Multi-hop 평가

In [None]:
from evaluation.metrics import compute_em, compute_f1, compute_recall_at_k

def evaluate_multihop(model, strategy, qa_pairs, tokenizer, top_k=5, device="cuda"):
    """
    Evaluate multi-hop QA, reporting metrics separately per question type.

    Metrics:
    - EM, F1 (answer quality), via compute_em / compute_f1
    - Recall@K over the document ids returned by the strategy
    - Supporting Facts EM/F1: not implemented here

    Args:
        model: Object providing ``eval()`` and
            ``generate_with_z(z, query_ids, query_mask, max_new_tokens=...)``.
        strategy: Object whose ``forward(query_ids, query_mask, k)`` returns
            ``(z_combined, selected_ids, scores)``. k is passed positionally.
        qa_pairs: Iterable of dicts with keys "question", "answer", "type" and
            "gold_doc_ids" (the format produced by load_hotpotqa).
        tokenizer: HF-style tokenizer (callable + ``decode``). Questions are
            padded/truncated to 256 tokens.
        top_k: Retrieval budget handed to the strategy, and K for Recall@K.
        device: Device the tokenized query tensors are moved to.

    Returns:
        Dict mapping question type ("bridge", "comparison", then any other
        type encountered) to {"EM", "F1", f"Recall@{top_k}"} as percentages.
        Types with no examples are omitted.

    NOTE(review): load_hotpotqa's gold_doc_ids index each example's own
    context paragraphs. selected_ids come from the strategy, which in this
    notebook calls model.select_documents. Confirm both use the same id
    space, otherwise Recall@K is not meaningful.
    """
    model.eval()

    # Pre-seed the two HotpotQA types so they are reported first. Any other
    # type string is added on demand instead of raising KeyError.
    results = {
        "bridge": {"em": [], "f1": [], "recall": []},
        "comparison": {"em": [], "f1": [], "recall": []},
    }

    for item in tqdm(qa_pairs, desc="Evaluating"):
        question = item["question"]
        answer = item["answer"]
        q_type = item["type"]  # bridge or comparison
        gold_doc_ids = item["gold_doc_ids"]

        # Tokenize a single question (batch size 1).
        q_encoded = tokenizer(
            question, max_length=256, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        query_ids = q_encoded["input_ids"].to(device)
        query_mask = q_encoded["attention_mask"].to(device)

        with torch.no_grad():
            # Pass the budget positionally. It is the third forward() parameter
            # for both strategies in this notebook, but ConcatStrategy names it
            # `top_k` and IterativeStrategy names it `k_per_hop`.
            z_combined, selected_ids, _scores = strategy.forward(
                query_ids, query_mask, top_k
            )

            generated_ids = model.generate_with_z(
                z_combined, query_ids, query_mask, max_new_tokens=64
            )
            prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Row 0 only: each question was tokenized on its own.
        selected_list = selected_ids[0].cpu().tolist()
        bucket = results.setdefault(q_type, {"em": [], "f1": [], "recall": []})
        bucket["em"].append(compute_em(prediction, answer))
        bucket["f1"].append(compute_f1(prediction, answer))
        bucket["recall"].append(compute_recall_at_k(selected_list, gold_doc_ids, k=top_k))

    # Aggregate to percentages, skipping types with no examples.
    summary = {}
    for q_type, bucket in results.items():
        n = len(bucket["em"])
        if not n:
            continue
        summary[q_type] = {
            "EM": sum(bucket["em"]) / n * 100,
            "F1": sum(bucket["f1"]) / n * 100,
            f"Recall@{top_k}": sum(bucket["recall"]) / n * 100,
        }

    return summary


print("evaluate_multihop defined")

## 7. Strategy 비교

In [None]:
# Illustrative example numbers (hard-coded here, not computed in this notebook).
strategy_comparison = {
    "Concat": {
        "bridge": {"EM": 32.5, "F1": 41.2, "Recall@5": 58.3},
        "comparison": {"EM": 38.2, "F1": 46.5, "Recall@5": 72.1},
    },
    "Iterative": {
        "bridge": {"EM": 36.8, "F1": 45.3, "Recall@5": 65.2},
        "comparison": {"EM": 39.5, "F1": 47.8, "Recall@5": 73.5},
    },
}

# Print one section per strategy: a header line followed by one row per
# question type.
for strategy, results in strategy_comparison.items():
    section = [f"\n{strategy} Strategy:"]
    section.extend(
        f"  {q_type}: EM={m['EM']:.1f}, F1={m['F1']:.1f}, Recall@5={m['Recall@5']:.1f}"
        for q_type, m in results.items()
    )
    print("\n".join(section))

In [None]:
import matplotlib.pyplot as plt

# Visualization: one subplot per question type, with grouped bars (one bar
# per strategy) for each metric. A single loop replaces the two copy-pasted
# per-subplot sections; the rendered figure is the same.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

x = ["EM", "F1", "Recall@5"]
width = 0.35
strategy_colors = {"Concat": "#3498db", "Iterative": "#e74c3c"}
n_strategies = len(strategy_colors)

for ax, q_type in zip(axes, ["bridge", "comparison"]):
    for j, (name, color) in enumerate(strategy_colors.items()):
        # Center the group of bars on each tick: with 2 strategies the
        # offsets are -width/2 and +width/2.
        offset = (j - (n_strategies - 1) / 2) * width
        vals = [strategy_comparison[name][q_type][m] for m in x]
        ax.bar([i + offset for i in range(len(x))], vals, width, label=name, color=color)
    ax.set_xticks(range(len(x)))
    ax.set_xticklabels(x)
    ax.set_ylabel("Score (%)")
    ax.set_title(f"{q_type.capitalize()} Questions")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

## 8. 한계점 분석

In [None]:
# Known limitations of the multi-hop setup, each paired with a candidate fix.
limitations = {
    "Bridge entity recognition": {
        "issue": "z_i만으로는 bridge entity를 명시적으로 식별하기 어려움",
        "potential_solution": "Iterative strategy + NER/LLM 추출",
    },
    "Hop ordering": {
        "issue": "어떤 문서를 먼저 검색해야 하는지 판단 어려움",
        "potential_solution": "Query decomposition + ordering model",
    },
    "Information propagation": {
        "issue": "z_i 간의 정보 흐름이 implicit",
        "potential_solution": "Cross-attention between z vectors",
    },
    "Scalability": {
        "issue": "Iterative는 hop 수만큼 latency 증가",
        "potential_solution": "Parallel retrieval + merge",
    },
}

print("Multi-hop Limitations:")
for limitation, details in limitations.items():
    # One print call per entry; sep="\n" yields the same three lines.
    print(
        f"\n{limitation}:",
        f"  Issue: {details['issue']}",
        f"  Solution: {details['potential_solution']}",
        sep="\n",
    )