# 04. Read Phase Training

Stage 2: Query → z_i selection → Answer 학습

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm.notebook import tqdm
import yaml
from omegaconf import OmegaConf

## 1. Configuration

In [None]:
# Load the experiment config. Read it explicitly as UTF-8: without an
# `encoding`, open() uses the platform locale encoding (e.g. cp949 on Korean
# Windows), which can fail or mis-decode a UTF-8 YAML containing non-ASCII text.
config_path = PROJECT_ROOT / "configs" / "phase1_poc.yaml"
with open(config_path, encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))

# Read-phase sub-config
read_config = config.read_phase
print("Read Phase Configuration:")
print(OmegaConf.to_yaml(read_config))

In [None]:
# Read-phase hyperparameters, unpacked from the loaded config in one statement
# (right-hand side is evaluated left to right, same order as the names).
(
    BATCH_SIZE,
    LEARNING_RATE_Z,
    LEARNING_RATE_LORA,
    NUM_EPOCHS,
    TOP_K,
    RETRIEVAL_WEIGHT,
) = (
    read_config.batch_size,
    read_config.lr_z,
    read_config.lr_lora,
    read_config.epochs,
    read_config.top_k,
    read_config.retrieval_weight,
)

print(f"Batch size: {BATCH_SIZE}")
print(f"Top-k: {TOP_K}")
print(f"Retrieval weight: {RETRIEVAL_WEIGHT}")

## 2. 데이터 로드

In [None]:
from data.dataloader import ReadPhaseDataset

# Toy QA pairs for a smoke test. `gold_doc_ids` lists the document indices the
# selector should pick for each question (presumably indices into the
# Write-phase document store — confirm against that notebook).
sample_qa_pairs = [
    {"question": question, "answer": answer, "gold_doc_ids": gold_ids}
    for question, answer, gold_ids in (
        ("What is the capital of France?", "Paris", [0]),
        ("Who wrote Romeo and Juliet?", "William Shakespeare", [1]),
        ("What did Albert Einstein develop?", "the theory of relativity", [3]),
    )
]

# Tokenizer: reuse EOS as the padding token when the model defines none.
tokenizer = AutoTokenizer.from_pretrained(config.model.llm_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Read-phase dataset and shuffled mini-batch loader.
read_dataset = ReadPhaseDataset(sample_qa_pairs, tokenizer)
read_loader = DataLoader(
    read_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

print(f"QA pairs: {len(sample_qa_pairs)}")

## 3. Read Phase Trainer

In [None]:
class ReadPhaseTrainer:
    """
    Read Phase Training

    Objective: max log P(answer | q, z_selected; θ)

    Loss = L_gen + λ * L_retrieval
    - L_gen: answer generation loss (cross-entropy)
    - L_retrieval: loss that pushes gold docs to score highly so they land in
      the top-k selection
    """

    def __init__(
        self,
        model,
        tokenizer,
        lr_z: float = 1e-4,
        lr_lora: float = 5e-5,
        lr_selector: float = 1e-4,
        retrieval_weight: float = 0.1,
        top_k: int = 5,
        device: str = "cuda",
    ):
        """Store the training setup and build an AdamW optimizer.

        Args:
            model: model exposing ``doc_vectors``, ``z_to_embedding``,
                ``selector``, ``query_encoder`` and ``llm`` (all read below).
            tokenizer: kept on the instance; not used by the methods here.
            lr_z: learning rate for ``model.doc_vectors``.
            lr_lora: learning rate for the projection and the LoRA
                parameters; the query encoder uses ``0.1 * lr_lora``.
            lr_selector: learning rate for ``model.selector``.
            retrieval_weight: λ, the weight of the retrieval loss.
            top_k: number of documents selected per query.
            device: device each batch's tensors are moved to.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.retrieval_weight = retrieval_weight
        self.top_k = top_k

        # One parameter group per component so each gets its own learning rate.
        param_groups = [
            {"params": [model.doc_vectors], "lr": lr_z, "name": "z"},
            {"params": model.z_to_embedding.parameters(), "lr": lr_lora, "name": "projection"},
            {"params": model.selector.parameters(), "lr": lr_selector, "name": "selector"},
            {"params": model.query_encoder.parameters(), "lr": lr_lora * 0.1, "name": "query_encoder"},
        ]

        # LoRA adapter weights are picked out of the LLM by parameter name.
        lora_params = [p for n, p in model.llm.named_parameters() if "lora" in n.lower()]
        param_groups.append({"params": lora_params, "lr": lr_lora, "name": "lora"})

        self.optimizer = torch.optim.AdamW(param_groups)

    def compute_retrieval_loss(self, scores, gold_doc_ids):
        """
        Retrieval loss: push the gold documents toward high selection scores.

        For each example, the loss is the negative mean log-softmax
        probability of its gold documents. It is then averaged over the
        examples that have at least one valid gold document.

        Args:
            scores: [B, num_docs] - selection scores
            gold_doc_ids: [B, num_gold] tensor, or a length-B sequence of
                per-example id sequences (ints or 0-d tensors). Ids outside
                ``[0, num_docs)`` (e.g. ``-1`` padding) are ignored.

        Returns:
            Scalar tensor. It is 0 (with no gradient) when no example has a
            valid gold id.
        """
        num_docs = scores.size(1)

        # Collect (row, col) coordinates of the gold docs. The lower bound
        # matters: a negative id would otherwise index from the end and mark
        # the wrong document as gold.
        rows, cols = [], []
        for i, gold_ids in enumerate(gold_doc_ids):
            for gid in gold_ids:
                gid = int(gid)
                if 0 <= gid < num_docs:
                    rows.append(i)
                    cols.append(gid)

        gold_mask = torch.zeros_like(scores, dtype=torch.bool)  # [B, num_docs]
        if rows:
            gold_mask[rows, cols] = True

        log_probs = F.log_softmax(scores, dim=-1)

        # masked_fill instead of multiplying by the mask: a -inf score would
        # give -inf * 0 = NaN and poison the whole loss.
        gold_log_probs = log_probs.masked_fill(~gold_mask, 0.0).sum(dim=-1)
        num_gold = gold_mask.sum(dim=-1)

        # Only examples with a gold doc contribute. Otherwise gold-less rows
        # add zeros to the mean and shrink the loss.
        has_gold = num_gold > 0
        if not bool(has_gold.any()):
            return scores.new_zeros(())

        per_example = -gold_log_probs[has_gold] / num_gold[has_gold]
        return per_example.mean()

    def train_step(self, batch):
        """Run one optimization step on ``batch`` and return the scalar losses."""
        self.model.train()

        query_ids = batch["query_ids"].to(self.device)
        query_mask = batch["query_mask"].to(self.device)
        answer_ids = batch["answer_ids"].to(self.device)
        answer_mask = batch["answer_mask"].to(self.device)
        # NOTE(review): with torch's default collate, per-sample Python lists
        # are transposed (one tensor per list position, each of length B), not
        # one entry per example as compute_retrieval_loss expects. Confirm what
        # ReadPhaseDataset / its collate actually produce.
        gold_doc_ids = batch["gold_doc_ids"]

        # Top-k document selection
        selected_ids, scores = self.model.select_documents(
            query_ids, query_mask, k=self.top_k
        )

        # Generation loss (assumes forward returns a scalar loss when
        # answer_ids is given — confirm against the model)
        gen_loss = self.model(
            query_ids=query_ids,
            doc_indices=selected_ids,
            answer_ids=answer_ids,
            query_attention_mask=query_mask,
            answer_attention_mask=answer_mask,
        )

        # Retrieval loss
        retrieval_loss = self.compute_retrieval_loss(scores, gold_doc_ids)

        # Total loss
        loss = gen_loss + self.retrieval_weight * retrieval_loss

        # Backward + clipped update
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return {
            "total_loss": loss.item(),
            "gen_loss": gen_loss.item(),
            "retrieval_loss": retrieval_loss.item(),
        }

    def train_epoch(self, dataloader, epoch: int):
        """Train for one epoch and return the per-batch mean of each loss."""
        total_loss = 0
        total_gen_loss = 0
        total_ret_loss = 0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
        for batch in pbar:
            losses = self.train_step(batch)
            total_loss += losses["total_loss"]
            total_gen_loss += losses["gen_loss"]
            total_ret_loss += losses["retrieval_loss"]
            num_batches += 1
            pbar.set_postfix({
                "total": f"{losses['total_loss']:.4f}",
                "gen": f"{losses['gen_loss']:.4f}",
                "ret": f"{losses['retrieval_loss']:.4f}",
            })

        # max(..., 1) keeps an empty dataloader from dividing by zero.
        return {
            "total_loss": total_loss / max(num_batches, 1),
            "gen_loss": total_gen_loss / max(num_batches, 1),
            "retrieval_loss": total_ret_loss / max(num_batches, 1),
        }

print("ReadPhaseTrainer defined")

## 4. Training Loop

In [None]:
# Write phase checkpoint 로드 후 Read phase 학습
# write_checkpoint = PROJECT_ROOT / "checkpoints" / "phase1" / "write_final.pt"

# model = ParametricQA(...)
# load_checkpoint(model, write_checkpoint)

# trainer = ReadPhaseTrainer(
#     model=model,
#     tokenizer=tokenizer,
#     lr_z=read_config.lr_z,
#     lr_lora=read_config.lr_lora,
#     retrieval_weight=RETRIEVAL_WEIGHT,
#     top_k=TOP_K,
#     device=device,
# )

# history = []
# for epoch in range(NUM_EPOCHS):
#     metrics = trainer.train_epoch(read_loader, epoch + 1)
#     history.append(metrics)
#     print(f"Epoch {epoch + 1}: {metrics}")

## 5. QA 평가

In [None]:
from evaluation.metrics import compute_em, compute_f1, compute_recall_at_k

def evaluate_qa(
    model,
    qa_pairs,
    tokenizer,
    top_k=5,
    device="cuda",
    max_query_length=128,
    max_new_tokens=64,
):
    """
    Evaluate end-to-end QA quality (selection + generation).

    Metrics (all reported as percentages):
    - EM (Exact Match)
    - F1
    - Recall@K (retrieval)

    Args:
        model: model exposing ``eval``, ``select_documents`` and ``generate``.
        qa_pairs: sequence of dicts with "question", "answer", "gold_doc_ids".
        tokenizer: tokenizer used to encode each question.
        top_k: number of documents selected per question.
        device: device the query tensors are moved to.
        max_query_length: token length each question is truncated/padded to.
        max_new_tokens: generation budget per answer.

    Returns:
        Dict with keys "EM", "F1" and f"Recall@{top_k}".

    Raises:
        ValueError: if ``qa_pairs`` is empty (the averages would be undefined).
    """
    if not qa_pairs:
        raise ValueError("evaluate_qa() requires at least one QA pair")

    model.eval()

    all_em = []
    all_f1 = []
    all_recall = []

    for item in tqdm(qa_pairs, desc="Evaluating"):
        question = item["question"]
        answer = item["answer"]
        gold_doc_ids = item["gold_doc_ids"]

        # NOTE(review): padding a single query to max_length is unnecessary.
        # Whether trailing pad tokens affect generation depends on
        # model.generate and the tokenizer's padding side — confirm.
        q_encoded = tokenizer(
            question, max_length=max_query_length, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        query_ids = q_encoded["input_ids"].to(device)
        query_mask = q_encoded["attention_mask"].to(device)

        with torch.no_grad():
            # Selection (scores are not needed for the metrics)
            selected_ids, _ = model.select_documents(query_ids, query_mask, k=top_k)
            selected_list = selected_ids[0].cpu().tolist()

            # Generation. NOTE(review): if model.generate returns prompt +
            # continuation, the decoded prediction includes the question.
            generated_ids = model.generate(
                query_ids=query_ids,
                doc_indices=selected_ids,
                query_attention_mask=query_mask,
                max_new_tokens=max_new_tokens,
            )
            prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        all_em.append(compute_em(prediction, answer))
        all_f1.append(compute_f1(prediction, answer))
        all_recall.append(compute_recall_at_k(selected_list, gold_doc_ids, k=top_k))

    return {
        "EM": sum(all_em) / len(all_em) * 100,
        "F1": sum(all_f1) / len(all_f1) * 100,
        f"Recall@{top_k}": sum(all_recall) / len(all_recall) * 100,
    }

# 평가 실행
# metrics = evaluate_qa(model, sample_qa_pairs, tokenizer, top_k=TOP_K, device=device)
# print(f"Evaluation results: {metrics}")

## 6. 예측 결과 분석

In [None]:
def analyze_predictions(
    model,
    qa_pairs,
    tokenizer,
    top_k=5,
    device="cuda",
    num_samples=5,
    max_query_length=128,
    max_new_tokens=64,
):
    """
    Collect detailed per-example predictions for manual inspection.

    Runs selection and generation on the first ``num_samples`` QA pairs.

    Args:
        model: model exposing ``eval``, ``select_documents`` and ``generate``.
        qa_pairs: sequence of dicts with "question", "answer", "gold_doc_ids".
        tokenizer: tokenizer used to encode each question.
        top_k: number of documents selected per question.
        device: device the query tensors are moved to.
        num_samples: how many leading QA pairs to analyze.
        max_query_length: token length each question is truncated/padded to.
        max_new_tokens: generation budget per answer.

    Returns:
        List of dicts with keys "question", "gold_answer", "prediction",
        "gold_doc_ids", "selected_ids", "scores" and "correct" (EM > 0).
        "scores" holds the selection score of each entry in "selected_ids",
        in the same order.
    """
    model.eval()
    results = []

    for item in qa_pairs[:num_samples]:
        question = item["question"]
        answer = item["answer"]
        gold_doc_ids = item["gold_doc_ids"]

        q_encoded = tokenizer(
            question, max_length=max_query_length, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        query_ids = q_encoded["input_ids"].to(device)
        query_mask = q_encoded["attention_mask"].to(device)

        with torch.no_grad():
            selected_ids, scores = model.select_documents(query_ids, query_mask, k=top_k)
            selected_list = selected_ids[0].cpu().tolist()
            # Look up the score of each selected doc instead of re-running
            # topk. This keeps "scores" aligned 1:1 with "selected_ids" even
            # if select_documents does not return ids in descending-score
            # order. Assumes scores is [B, num_docs] with doc ids as column
            # indices — TODO confirm against select_documents.
            selected_scores = scores[0, selected_ids[0]].cpu().tolist()

            generated_ids = model.generate(
                query_ids=query_ids,
                doc_indices=selected_ids,
                query_attention_mask=query_mask,
                max_new_tokens=max_new_tokens,
            )
            prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        results.append({
            "question": question,
            "gold_answer": answer,
            "prediction": prediction,
            "gold_doc_ids": gold_doc_ids,
            "selected_ids": selected_list,
            "scores": selected_scores,
            "correct": compute_em(prediction, answer) > 0,
        })

    return results

# 분석 실행
# results = analyze_predictions(model, sample_qa_pairs, tokenizer, top_k=TOP_K, device=device)
# for r in results:
#     print(f"\nQ: {r['question']}")
#     print(f"A: {r['gold_answer']}")
#     print(f"P: {r['prediction']}")
#     print(f"Selected: {r['selected_ids']} (gold: {r['gold_doc_ids']})")
#     print(f"Correct: {r['correct']}")

## 7. Checkpoint 저장

In [None]:
def save_full_checkpoint(model, optimizer, epoch, metrics, save_path):
    """Persist the complete trainable state after the Read phase.

    Writes into the directory containing ``save_path`` (created if missing):
    - ``save_path``: epoch, metrics, doc vectors (moved to CPU), the
      projection / selector state dicts and the optimizer state.
    - ``lora_weights/``: written via ``model.llm.save_pretrained``.
    - ``query_encoder.pt``: the query encoder's state dict.

    NOTE(review): the two sidecar names do not depend on ``save_path``, so
    checkpoints saved to the same directory overwrite each other's sidecars.
    """
    target = Path(save_path)
    out_dir = target.parent
    out_dir.mkdir(parents=True, exist_ok=True)

    state = dict(
        epoch=epoch,
        metrics=metrics,
        doc_vectors=model.doc_vectors.data.cpu(),
        z_to_embedding_state_dict=model.z_to_embedding.state_dict(),
        selector_state_dict=model.selector.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )

    # Sidecar artifacts first, then the main checkpoint file.
    model.llm.save_pretrained(out_dir / "lora_weights")
    torch.save(model.query_encoder.state_dict(), out_dir / "query_encoder.pt")
    torch.save(state, target)

    print(f"Full checkpoint saved to {target}")

# 저장
# save_full_checkpoint(
#     model, trainer.optimizer, NUM_EPOCHS, metrics,
#     PROJECT_ROOT / "checkpoints" / "phase1" / "read_final.pt"
# )