In [None]:
# Mount Google Drive at /content/drive; the KorQuAD JSON paths used later in this
# notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer

# Path settings: KorQuAD v1 train / validation JSON files on the mounted Google Drive.
train_path = "/content/drive/MyDrive/nlpbook/downstream/korquad-v1/train.json"
val_path   = "/content/drive/MyDrive/nlpbook/downstream/korquad-v1/val.json"
model_ckpt = "klue/bert-base"      # or a KR-BERT family checkpoint, etc.
# Load the fast tokenizer and the BERT question-answering model from the same checkpoint.
# The QA head (qa_outputs.weight / qa_outputs.bias) is not part of the checkpoint and is
# newly initialised, so the model has to be fine-tuned before it is useful.
tokenizer  = BertTokenizerFast.from_pretrained(model_ckpt)
model      = BertForQuestionAnswering.from_pretrained(model_ckpt)
# KLUE BERT has max_position_embeddings = 512

# 1. Convert the KorQuAD JSON into a flat list of QA records
def flatten_korquad(json_path):
    """Read a KorQuAD-style JSON file and return one record per answer.

    The file is expected to hold ``{"data": [article, ...]}`` where each article has a
    ``title`` and ``paragraphs``, each paragraph a ``context`` and ``qas``, and each
    QA entry an ``id``, a ``question`` and a list of ``answers`` (``text`` /
    ``answer_start``).

    Args:
        json_path: Path of the JSON file (opened as UTF-8).

    Returns:
        A list of dicts with keys ``id``, ``title``, ``context``, ``question`` and
        ``answers``; ``answers`` wraps the single answer's text and start offset in
        one-element lists. A question with several answers yields several records,
        a question with no answers yields none.
    """

    def _iter_records(articles):
        # Walk article -> paragraph -> question -> answer, emitting one record per answer.
        for article in articles:
            article_title = article["title"]
            for paragraph in article["paragraphs"]:
                passage = paragraph["context"]
                for qa in paragraph["qas"]:
                    for answer in qa["answers"]:
                        yield {
                            "id": qa["id"],
                            "title": article_title,
                            "context": passage,
                            "question": qa["question"],
                            "answers": {
                                "text": [answer["text"]],
                                "answer_start": [answer["answer_start"]],
                            },
                        }

    with open(json_path, encoding="utf-8") as handle:
        articles = json.load(handle)["data"]
    return list(_iter_records(articles))

# Flatten the train split first, then the validation split, with the same loader.
train_samples, val_samples = (
    flatten_korquad(split_path) for split_path in (train_path, val_path)
)

# 2. Dataset class definition
class KorQuADDataset(Dataset):
    """Map-style dataset turning flattened KorQuAD records into BERT QA features.

    Each item is tokenized on access as a (question, context) pair, padded to
    ``max_length`` with only the context truncated, and returned as a dict of
    tensors with the batch dimension removed plus two scalar label tensors,
    ``start_positions`` and ``end_positions``.

    Args:
        samples: Sequence of records with ``question``, ``context`` and
            ``answers`` (``{"text": [...], "answer_start": [...]}``); only the
            first answer of each record is used.
        tokenizer: Fast tokenizer that supports ``return_offsets_mapping`` and
            whose output exposes ``sequence_ids(batch_index)``.
        max_length: Total token length every feature is padded/truncated to.
    """

    def __init__(self, samples, tokenizer, max_length=512):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and attach answer start/end token labels.

        Returns:
            Dict of tensors from the tokenizer (``offset_mapping`` removed) with
            ``start_positions`` / ``end_positions`` added. Both labels are 0 when
            the answer is not fully contained in the (possibly truncated) context.
        """
        sample = self.samples[idx]
        enc = self.tokenizer(
            sample["question"],
            sample["context"],
            max_length=self.max_length,
            truncation="only_second",
            padding="max_length",
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # 1 marks context tokens, 0 question tokens, None special/padding tokens.
        sequence_ids = enc.sequence_ids(0)
        # Plain Python pairs: avoids a tensor comparison per token in the loop below.
        offsets = enc.pop("offset_mapping")[0].tolist()

        answer = sample["answers"]["text"][0]
        start_char = sample["answers"]["answer_start"][0]
        end_char = start_char + len(answer)

        # Find the start/end token indices. Offsets of question tokens are relative
        # to the question string, so only context tokens may be matched; otherwise an
        # answer that was truncated away could be "found" inside the question.
        start_token, end_token = 0, 0
        for i, (tok_start, tok_end) in enumerate(offsets):
            if sequence_ids[i] != 1:
                continue
            if tok_start <= start_char < tok_end:
                start_token = i
            if tok_start < end_char <= tok_end:
                end_token = i

        # Index 0 is never a context token, so 0 means "boundary not found". If
        # either boundary is missing (answer cut off by truncation), label the whole
        # example with position 0 instead of a half-valid span with end < start.
        if start_token == 0 or end_token == 0:
            start_token, end_token = 0, 0

        enc = {k: v.squeeze(0) for k, v in enc.items()}
        enc["start_positions"] = torch.tensor(start_token)
        enc["end_positions"] = torch.tensor(end_token)
        return enc

# 3. Build the datasets
# `tokenizer` was already loaded from `model_ckpt` at the top of this cell; loading it a
# second time here was redundant, so the existing instance is reused.
train_dataset = KorQuADDataset(train_samples, tokenizer)
val_dataset   = KorQuADDataset(val_samples, tokenizer)

# 4. Model
# `model` (BertForQuestionAnswering) was likewise already created from `model_ckpt` at
# the top of this cell. Re-instantiating it only repeated the checkpoint load and the
# "newly initialized qa_outputs" warning, so it is not loaded again here.

# 5. TrainingArguments 및 Trainer 정의
from transformers import TrainingArguments
# Evaluate and checkpoint once per epoch; external experiment reporting is disabled.
training_args = TrainingArguments(
    output_dir="/content/qa-out",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir="/content/qa-logs",
    report_to="none",
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `tokenizer=` is deprecated for Trainer.__init__ in recent transformers releases
    # (this notebook runs 4.52.4); `processing_class=` is the replacement argument.
    processing_class=tokenizer,
)

# 6. Run training
trainer.train()


tokenizer_config.json:   0%|          | 0.00/289 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/248k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/495k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/425 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/445M [00:00<?, ?B/s]

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,1.7092,0.986757
2,0.6628,1.335749


TrainOutput(global_step=1218, training_loss=1.0531780512266362, metrics={'train_runtime': 629.9069, 'train_samples_per_second': 7.734, 'train_steps_per_second': 1.934, 'total_flos': 1273037798817792.0, 'train_loss': 1.0531780512266362, 'epoch': 2.0})

In [None]:
import transformers, inspect, textwrap, importlib
from transformers import TrainingArguments

# Report which transformers build is active and where TrainingArguments is defined.
print("transformers ver:", transformers.__version__)
args_source_file = inspect.getfile(TrainingArguments)
print("TrainingArguments 위치:", args_source_file)

# Inspect the TrainingArguments constructor signature, indented by two spaces.
signature_text = str(inspect.signature(TrainingArguments.__init__))
print("\n-- TrainingArguments.__init__ signature --")
print(textwrap.indent(signature_text, "  "))


transformers ver: 4.52.4
TrainingArguments 위치: /usr/local/lib/python3.11/dist-packages/transformers/training_args.py

-- TrainingArguments.__init__ signature --


In [None]:
# ── 1. One inference pass over the validation set (logits + loss) ─────────────
# Trainer.predict returns the logits and, because the dataset carries
# start_positions / end_positions labels, the loss as well. Using the "eval" metric
# prefix stores it under "eval_loss", so a separate trainer.evaluate() pass (a second
# full forward pass over the same data) is not needed.
prediction_output = trainer.predict(val_dataset, metric_key_prefix="eval")
metrics = prediction_output.metrics
start_logits, end_logits = prediction_output.predictions

# ── 2. Decode the predicted answer strings ────────────────────────────────────
pred_texts = []
for s_log, e_log, sample in zip(start_logits, end_logits, val_dataset):
    s = int(np.argmax(s_log))
    e = int(np.argmax(e_log))
    # An end index before the start index is collapsed to a single-token span.
    if e < s:
        e = s
    pred_texts.append(
        tokenizer.decode(sample["input_ids"][s:e + 1], skip_special_tokens=True).strip()
    )

# ── 3. Gold answers and metrics (EM / F1) ─────────────────────────────────────
# NOTE(review): exact_match and f1_squad are not defined anywhere in the visible
# notebook; they must be defined in an earlier cell for a fresh-kernel run — confirm.
gold_texts = [ex["answers"]["text"][0] for ex in val_samples]
EM  = np.mean([exact_match(p, g) for p, g in zip(pred_texts, gold_texts)])
F1  = np.mean([f1_squad (p, g) for p, g in zip(pred_texts, gold_texts)])

loss_val = metrics.get("eval_loss", float("nan"))
print(f"📊 Validation | Loss={loss_val:.4f}  EM={EM:.4f}  F1={F1:.4f}")


📊 Validation | Loss=1.3357  EM=0.6429  F1=0.7192


In [None]:
# ╔═════════════════╗
# ║  🟢  Cell 8     ║  사용자 질문 인터랙티브 예측
# ╚═════════════════╝
import textwrap, torch

# ── 1) Passage fixed up front (swap in another story if desired) ───────────
_raw_passage = """
    옛날 어느 마을에 한 할머니가 살았어요.
    할머니는 알록달록 예쁜 꽃들을 팔아서 살림을 꾸려 나갔지요.
    그러던 어느 날 할머니가 꽃밭에 물을 주고 있는데,
    갑자기 하늘에서 주먹만 한 우박이 떨어졌어요.
    “에구머니!” 깜짝 놀란 할머니는 우박을 피해 집 안으로 뛰어 들어갔어요.
"""
# Remove the common indentation, then the leading/trailing blank lines.
fixed_context = textwrap.dedent(_raw_passage).strip()

# Show the passage framed by two 40-character rules, one item per line.
rule = "-" * 40
print("\n📖  지문(context)", rule, fixed_context, rule, sep="\n")

# ── 2) Prediction helper ─────────────────────────────────────────────
def qa_infer(question: str, context: str, max_len: int = 512,
             max_answer_tokens: int = 30) -> str:
    """Return the answer span the fine-tuned model extracts from ``context``.

    The span is chosen jointly: among all (start, end) pairs that lie inside the
    context segment, satisfy start <= end and cover at most ``max_answer_tokens``
    tokens, the pair with the highest start-logit + end-logit is taken. Picking the
    two argmaxes independently over every position could select question or special
    tokens and produce an "answer" that contains the question itself.

    Args:
        question: Question text (first segment).
        context: Passage to extract the answer from (second segment; truncated if
            the pair exceeds ``max_len`` tokens).
        max_len: Maximum total token length passed to the tokenizer.
        max_answer_tokens: Upper bound on the answer length in tokens (values
            below 1 are treated as 1).

    Returns:
        The answer as a substring of ``context`` (stripped), or an empty string
        when the encoded input contains no context tokens.
    """
    inputs = tokenizer(
        question, context,
        truncation="only_second", max_length=max_len,
        return_offsets_mapping=True, return_tensors="pt",
    )
    # Character offsets are only needed to cut the answer out of `context`; the
    # model itself must not receive them.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    # sequence id 1 marks context tokens (0 = question, None = special tokens).
    in_context = torch.tensor([seq_id == 1 for seq_id in inputs.sequence_ids(0)])
    if not bool(in_context.any()):
        return ""

    inputs = inputs.to(model.device)
    model.eval()  # make sure dropout is off even right after training
    with torch.no_grad():
        out = model(**inputs)

    neg_inf = float("-inf")
    blocked = ~in_context.to(out.start_logits.device)
    start_scores = out.start_logits[0].float().masked_fill(blocked, neg_inf)
    end_scores = out.end_logits[0].float().masked_fill(blocked, neg_inf)

    # span_scores[i, j] = score of the span starting at token i and ending at token j.
    span_scores = start_scores[:, None] + end_scores[None, :]
    positions = torch.arange(span_scores.size(0), device=span_scores.device)
    span_extent = positions[None, :] - positions[:, None]  # end index - start index
    allowed = (span_extent >= 0) & (span_extent < max(1, max_answer_tokens))
    span_scores = span_scores.masked_fill(~allowed, neg_inf)

    s, e = divmod(int(span_scores.argmax()), span_scores.size(1))
    # Slice the original text instead of detokenizing, so spacing is preserved.
    return context[offsets[s][0]:offsets[e][1]].strip()


# ── 3) Read the user's question and print the prediction ─────────────
user_q = input("❓  질문을 입력하세요: ").strip()
if not user_q:
    print("➡️  질문이 비어 있습니다.")
else:
    pred_ans = qa_infer(user_q, fixed_context)
    print("\n📝  예측 답변:", pred_ans)



📖  지문(context)
----------------------------------------
옛날 어느 마을에 한 할머니가 살았어요.
할머니는 알록달록 예쁜 꽃들을 팔아서 살림을 꾸려 나갔지요.
그러던 어느 날 할머니가 꽃밭에 물을 주고 있는데,
갑자기 하늘에서 주먹만 한 우박이 떨어졌어요.
“에구머니!” 깜짝 놀란 할머니는 우박을 피해 집 안으로 뛰어 들어갔어요.
----------------------------------------
❓  질문을 입력하세요: 할머니는 왜 바다로 갔나요?

📝  예측 답변: 할머니는 왜 바다로 갔나요? 옛날 어느 마을에 한 할머니가 살았어요. 할머니는 알록달록 예쁜 꽃들을 팔아서 살림을 꾸려 나갔지요.
