# 01. Data Preparation

데이터 다운로드, 전처리, DataLoader 구성

In [None]:
import sys
from pathlib import Path

# Treat the parent of the current working directory as the project root and
# put it first on sys.path (presumably so project-local modules resolve when
# the notebook runs from a subdirectory of the project -- confirm against the
# repository layout).
PROJECT_ROOT = Path(".").resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

import torch
from omegaconf import OmegaConf
import yaml

## 1. 데이터셋 다운로드

In [None]:
from datasets import load_dataset

# Dataset selection (Natural Questions) and the directory for raw downloads.
# The directory is created up front, including any missing parents.
DATASET_NAME = "natural_questions"  # or "hotpot_qa", "trivia_qa"
SAVE_DIR = PROJECT_ROOT / "data" / "raw"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Only announces the download; nothing is fetched by this statement itself.
print(f"Downloading {DATASET_NAME}...")

In [None]:
def _first_short_answer_span(annotations):
    """Return ``(start_token, end_token)`` for the first short answer, or ``None``.

    Two layouts of ``annotations["short_answers"][0]`` are accepted:

    * a record with scalar ``start_token`` / ``end_token`` values, and
    * a columnar record whose ``start_token`` / ``end_token`` are lists
      (empty when the annotation has no short answer).

    NOTE(review): the columnar layout is what the Hugging Face export of
    Natural Questions is expected to produce -- confirm against the dataset
    card for the version in use.

    ``None`` is returned when there is no short answer or the span is empty.
    """
    short_answers = annotations["short_answers"]
    if len(short_answers) == 0:
        return None

    first = short_answers[0]
    if not first:
        return None

    start, end = first["start_token"], first["end_token"]
    # Columnar layout: each field is a (possibly empty) sequence of values,
    # so take the first short answer of this annotation.
    if hasattr(start, "__len__"):
        if len(start) == 0 or len(end) == 0:
            return None
        start, end = start[0], end[0]

    start, end = int(start), int(end)
    # A negative or empty span would yield an empty answer string; skip it.
    if start < 0 or end <= start:
        return None
    return start, end


def download_natural_questions():
    """Download Natural Questions and keep the examples that have a short answer.

    Only the ``validation`` split is loaded because the full dataset is very
    large (this is meant for a proof of concept).

    Returns:
        A list of dicts with keys ``question`` (question text), ``answer``
        (document tokens of the first short-answer span joined by spaces) and
        ``document`` (the first 512 document tokens joined by spaces).
        Examples without a usable short answer are skipped.
    """
    dataset = load_dataset("natural_questions", "default", split="validation")

    processed = []
    for item in dataset:
        span = _first_short_answer_span(item["annotations"])
        if span is None:
            continue

        start, end = span
        doc_tokens = item["document"]["tokens"]["token"]
        processed.append({
            "question": item["question"]["text"],
            "answer": " ".join(doc_tokens[start:end]),
            "document": " ".join(doc_tokens[:512]),  # leading part only
        })

    return processed

# 다운로드 (시간이 오래 걸릴 수 있음)
# nq_data = download_natural_questions()
# print(f"Downloaded {len(nq_data)} samples")

In [None]:
def download_hotpot_qa():
    """Load the HotpotQA ``fullwiki`` validation split as a list of plain dicts.

    Returns:
        One dict per dataset record containing the ``question``, ``answer``,
        ``supporting_facts``, ``context``, ``type`` and ``level`` fields,
        copied unchanged from the record.
    """
    # "type" is either bridge or comparison.
    kept_fields = (
        "question",
        "answer",
        "supporting_facts",
        "context",
        "type",
        "level",
    )
    validation_split = load_dataset("hotpot_qa", "fullwiki", split="validation")
    return [
        {field: record[field] for field in kept_fields}
        for record in validation_split
    ]

# hotpot_data = download_hotpot_qa()
# print(f"Downloaded {len(hotpot_data)} samples")

## 2. 데이터 전처리

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer. If it defines no pad token, reuse the EOS token so the
# padding="max_length" calls in the Dataset classes have a token to pad with.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Vocab size: {tokenizer.vocab_size}")

In [None]:
def preprocess_for_parametric_qa(
    raw_data: list,
    tokenizer,
    max_doc_length: int = 512,
    max_query_length: int = 128,
    max_answer_length: int = 64,
    max_doc_chars: int = 2000,
):
    """Split raw QA records into a document corpus and matching QA pairs.

    Records whose ``"document"`` field is missing or empty are skipped
    entirely, so every QA pair points at exactly one stored document.

    Args:
        raw_data: Iterable of dicts with ``question`` and ``answer`` keys and
            an optional ``document`` key.
        tokenizer: Accepted for interface compatibility; not used here.
        max_doc_length: Not used by this function (token-level limits are
            applied later, when the datasets tokenize the text).
        max_query_length: Not used by this function.
        max_answer_length: Not used by this function.
        max_doc_chars: Maximum number of characters kept per document.
            Defaults to 2000, the limit that was previously hard-coded.

    Returns:
        ``(corpus, qa_pairs)`` where ``corpus`` maps consecutive integer
        doc ids (starting at 0) to truncated document text, and ``qa_pairs``
        is a list of ``{"question", "answer", "gold_doc_ids"}`` dicts with
        ``gold_doc_ids`` holding the single id of the record's own document.
    """
    corpus = {}
    qa_pairs = []

    for item in raw_data:
        doc_text = item.get("document", "")
        if not doc_text:
            continue

        # Ids are assigned in insertion order, so the next id is simply the
        # number of documents stored so far.
        doc_id = len(corpus)
        corpus[doc_id] = doc_text[:max_doc_chars]
        qa_pairs.append({
            "question": item["question"],
            "answer": item["answer"],
            "gold_doc_ids": [doc_id],
        })

    return corpus, qa_pairs

# Smoke-test the preprocessing on two hand-written samples
sample_data = [
    {
        "question": "What is the capital of France?",
        "answer": "Paris",
        "document": "France is a country in Western Europe. Paris is the capital and largest city of France.",
    },
    {
        "question": "Who wrote Romeo and Juliet?",
        "answer": "William Shakespeare",
        "document": "Romeo and Juliet is a tragedy written by William Shakespeare early in his career.",
    },
]

# corpus and qa_pairs are reused by the Dataset cells further down.
corpus, qa_pairs = preprocess_for_parametric_qa(sample_data, tokenizer)
print(f"Corpus size: {len(corpus)}")
print(f"QA pairs: {len(qa_pairs)}")
print(f"Sample: {qa_pairs[0]}")

## 3. DataLoader 구성

In [None]:
from torch.utils.data import Dataset, DataLoader

class WritePhaseDataset(Dataset):
    """Dataset for the write phase (document reconstruction).

    Each item is one corpus document run through the tokenizer with
    ``max_length`` truncation and ``"max_length"`` padding.
    """

    def __init__(self, corpus: dict, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.corpus = corpus
        # Snapshot the key order once so integer indices map to fixed doc ids.
        self.doc_ids = list(corpus)

    def __len__(self):
        return len(self.doc_ids)

    def __getitem__(self, idx):
        """Return ``doc_id``, ``input_ids`` and ``attention_mask`` for item ``idx``."""
        key = self.doc_ids[idx]
        tokens = self.tokenizer(
            self.corpus[key],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # squeeze(0) drops a leading size-1 axis (presumably the batch axis
        # added by return_tensors="pt" for a single text).
        sample = {"doc_id": key}
        for field in ("input_ids", "attention_mask"):
            sample[field] = tokens[field].squeeze(0)
        return sample

# Smoke test: build the write-phase dataset from the sample corpus and print one item
write_dataset = WritePhaseDataset(corpus, tokenizer)
print(f"Write dataset size: {len(write_dataset)}")
print(f"Sample: {write_dataset[0]}")

In [None]:
class ReadPhaseDataset(Dataset):
    """Dataset for the read phase (QA generation).

    Each item holds the tokenized question, the tokenized answer and the
    gold document ids of one QA pair.
    """

    def __init__(
        self,
        qa_pairs: list,
        tokenizer,
        max_query_length: int = 128,
        max_answer_length: int = 64,
    ):
        self.tokenizer = tokenizer
        self.qa_pairs = qa_pairs
        self.max_answer_length = max_answer_length
        self.max_query_length = max_query_length

    def __len__(self):
        return len(self.qa_pairs)

    def _encode(self, text, limit):
        """Tokenize ``text`` with ``max_length=limit``; return ``(ids, mask)``."""
        encoded = self.tokenizer(
            text,
            max_length=limit,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # squeeze(0) drops the leading size-1 axis of each returned tensor.
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        """Return query/answer ids and masks plus ``gold_doc_ids`` as a tensor."""
        pair = self.qa_pairs[idx]

        # The question is encoded before the answer, each with its own limit.
        query_ids, query_mask = self._encode(pair["question"], self.max_query_length)
        answer_ids, answer_mask = self._encode(pair["answer"], self.max_answer_length)

        return {
            "query_ids": query_ids,
            "query_mask": query_mask,
            "answer_ids": answer_ids,
            "answer_mask": answer_mask,
            "gold_doc_ids": torch.tensor(pair["gold_doc_ids"]),
        }

# Smoke test: build the read-phase dataset from the sample QA pairs and list one item's keys
read_dataset = ReadPhaseDataset(qa_pairs, tokenizer)
print(f"Read dataset size: {len(read_dataset)}")
print(f"Sample keys: {read_dataset[0].keys()}")

In [None]:
# Build the DataLoaders (shuffled, batches of 4, loading in the main process)
write_loader = DataLoader(
    write_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

read_loader = DataLoader(
    read_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# Inspect one batch from each loader; the break stops after the first batch
for batch in write_loader:
    print("Write batch:")
    print(f"  doc_id: {batch['doc_id']}")
    print(f"  input_ids shape: {batch['input_ids'].shape}")
    break

for batch in read_loader:
    print("\nRead batch:")
    print(f"  query_ids shape: {batch['query_ids'].shape}")
    print(f"  answer_ids shape: {batch['answer_ids'].shape}")
    break

## 4. 전처리된 데이터 저장

In [None]:
import json

def save_processed_data(corpus, qa_pairs, save_dir):
    """Write the corpus and QA pairs to ``save_dir`` as UTF-8 JSON files.

    Creates ``save_dir`` (and any missing parents), then writes
    ``corpus.json`` and ``qa_pairs.json`` inside it and prints the target
    directory.

    Args:
        corpus: Mapping of doc id to document text. JSON object keys are
            always strings, so integer ids are stored as strings.
        qa_pairs: List of QA-pair dicts.
        save_dir: Target directory (``str`` or ``Path``).
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding is pinned to UTF-8; relying on the platform default can raise
    # UnicodeEncodeError or produce locale-dependent files.
    outputs = (("corpus.json", corpus), ("qa_pairs.json", qa_pairs))
    for filename, payload in outputs:
        with open(save_dir / filename, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    print(f"Saved to {save_dir}")

# 저장
# save_processed_data(corpus, qa_pairs, PROJECT_ROOT / "data" / "processed")

In [None]:
def load_processed_data(load_dir):
    """Load ``corpus.json`` and ``qa_pairs.json`` from ``load_dir``.

    Both files are read as UTF-8 (the standard JSON encoding) rather than
    with the platform's default encoding, so non-ASCII text decodes the same
    way on every machine.

    Args:
        load_dir: Directory containing the two JSON files (``str`` or ``Path``).

    Returns:
        ``(corpus, qa_pairs)``. JSON object keys are strings, so the corpus
        keys are converted back to ``int`` doc ids; ``qa_pairs`` is returned
        exactly as stored.

    Raises:
        FileNotFoundError: If either file is missing.
        ValueError: If a corpus key is not an integer string.
    """
    load_dir = Path(load_dir)

    with open(load_dir / "corpus.json", encoding="utf-8") as f:
        raw_corpus = json.load(f)

    with open(load_dir / "qa_pairs.json", encoding="utf-8") as f:
        qa_pairs = json.load(f)

    # Restore integer doc ids from the string keys JSON forces on objects.
    corpus = {int(doc_id): text for doc_id, text in raw_corpus.items()}

    return corpus, qa_pairs

# 로드
# corpus, qa_pairs = load_processed_data(PROJECT_ROOT / "data" / "processed")