In [None]:
# ✅ 빠르게 동작하도록 조정한 GPT-style Chatbot training script (PyTorch / fallback bigram)
# (이미지의 [QUEST 평가기준]을 코드에 반영하여 1~5 항목별 주석 및 점검 출력을 추가)
# -------------------------------------------------------------------------------------------
# [QUEST 1] Transformer와 비교해 변경/차별된 부분을 텍스트 블록으로 서술
# [QUEST 2] 입력 형태(디코더-온리 LM)에 맞춘 전처리 수행 및 검증 출력
# [QUEST 3] GPT 논문(https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
#           기준의 입력 블록(token + position, causal mask, pre-LN, residual) 구성 및 점검
# [QUEST 4] GPT 모델 정상 구성 확인(model summary, 학습 루프/fit 로그)
# [QUEST 5] 주어진 입력에 따른 생성 출력 확인(generation 데모)
# -------------------------------------------------------------------------------------------

# [PLUS] GPU 메모리 파편화 완화(파이토치 import '전'에 설정)
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import re
import math
import random
import urllib.request
import pandas as pd
import numpy as np
import sentencepiece as spm

# ========================= [QUEST 1] Description of changes / differences =========================
# Human-readable changelog describing how this script departs from a baseline
# encoder-decoder Transformer chatbot (QUEST 1). It is printed once when this cell
# runs so the notebook output records the design choices. The text is runtime
# output and is intentionally kept verbatim.
CHANGELOG_QUEST1 = """
[QUEST 1] Baseline Transformer 대비 변경·차별 포인트(요약)
- 인코더/디코더 완전 Transformer가 아닌, GPT 스타일 '디코더-온리(Decoder-only)' 언어모델로 단순화.
- [개선] 질문 뒤에 전용 구분 토큰 <ANS>(SEP)를 도입하여 Q/ A 경계를 명확히 함.
  · 학습: [BOS] Q [SEP] A [EOS]  (Q구간은 loss에서 무시, A만 예측)
  · 추론: [BOS] Q [SEP] … (A만 생성) → Q 복창/에코링 감소
- [개선] 디코딩 제어: temperature + nucleus(top_p) + repetition_penalty + no_repeat_ngram + min_length
  → 의미불명 반복과 횡설수설 억제, “말이 되는” 답변 유도
- [개선] label smoothing(0.1) + cosine 스케줄(+warmup) → 안정적 수렴/일반화 향상.
  ※ 검증은 smoothing=0.0으로 순수 NLL을 측정해 ppl 왜곡을 방지.
- SentencePiece BPE로 한국어 토큰화, 실제 vocab size 동기화.
- 속도 최적화: 작은 모델(d_model=256, n_layers=6, n_heads=8), AMP, 자동 배치, 샘플/길이 제한.
"""
# strip() drops the leading/trailing newlines of the triple-quoted literal.
print(CHANGELOG_QUEST1.strip())
# ================================================================================

# ---------------------------- Try import PyTorch ----------------------------
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from torch.utils.data import DataLoader, Dataset
except Exception:
    # PyTorch is missing (or failed to import): bind inert placeholders so the rest
    # of the script still defines cleanly and the NumPy bigram fallback can run.
    torch = F = None
    DataLoader = AdamW = CosineAnnealingLR = None
    nn = Dataset = object
    HAS_TORCH = False
    print("[경고] PyTorch 미탑재 환경. Numpy Bigram LM 모드로 실행")
else:
    # Every PyTorch import above succeeded.
    HAS_TORCH = True

# ---------------------------- [PLUS] GPU 캐시 정리 유틸 ----------------------------
def _free_cuda():
    try:
        import gc
        gc.collect()
        if HAS_TORCH and torch.cuda.is_available():
            torch.cuda.empty_cache()
            if hasattr(torch.cuda, "ipc_collect"):
                torch.cuda.ipc_collect()
    except Exception:
        pass

# ---------------------------- Config ----------------------------
# Reproducibility: seed Python's and NumPy's global RNGs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Data / tokenizer limits.
MAX_SAMPLES = 100_000          # max number of Q/A rows taken from the CSV
MAX_LEN = 64                   # fixed token length of every training sequence
VOCAB_SIZE_FOR_SPM = 8_000     # requested SentencePiece vocabulary size

# Default GPT model dimensions.
DEFAULT_D_MODEL, DEFAULT_N_LAYERS, DEFAULT_N_HEADS = 256, 6, 8

# Filesystem layout (created on first run).
DATA_DIR = os.path.expanduser("~/work/transformer_chatbot/data")
SPM_PREFIX = os.path.expanduser("~/work/transformer_chatbot/data/spm_kr")
DATA_PATH = os.path.join(DATA_DIR, "ChatbotData.csv")
os.makedirs(DATA_DIR, exist_ok=True)

# Label smoothing: applied while training; disabled for validation so the reported
# perplexity comes from plain NLL.
TRAIN_SMOOTH, EVAL_SMOOTH = 0.10, 0.00

# ---- Generation hyper-parameters (naturalness / repetition control) ----
GEN_TEMPERATURE = 0.7      # lowered from 0.8 to reduce off-topic samples
GEN_TOP_P = 0.88           # nucleus mass (was 0.90)
GEN_TOP_K = 60             # additionally cuts the long tail
GEN_REP_PENALTY = 1.22     # repetition penalty (1.15-1.25 suggested)
GEN_NGRAM_NO_REPEAT = 3    # forbid re-creating an existing 3-gram
GEN_MIN_NEW_TOKENS = 12    # EOS is forbidden before this many generated tokens
BAN_UNK_AT_GEN = True      # never sample UNK
ECHO_PENALTY = 0.6         # logit reduction for tokens that appear in the question

# ---------------------------- Device ----------------------------
# Defaults for the NumPy-only fallback; overwritten below when PyTorch is present.
device = None
USE_CUDA = False
if HAS_TORCH:
    torch.manual_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    USE_CUDA = device.type == "cuda"
    if USE_CUDA:
        # Let cuDNN autotune kernels for the fixed input shapes used here.
        torch.backends.cudnn.benchmark = True

# ---------------------------- Download data if needed ----------------------------
if not os.path.exists(DATA_PATH):
    print("Downloading ChatbotData.csv ...")
    # Download into a temporary sibling file and atomically move it into place.
    # A failed or interrupted download then never leaves a truncated CSV at
    # DATA_PATH, which the exists() check above would otherwise accept on every
    # later run.
    _partial_path = DATA_PATH + ".part"
    try:
        urllib.request.urlretrieve(
            "https://github.com/songys/Chatbot_data/raw/master/ChatbotData.csv",
            _partial_path,
        )
        os.replace(_partial_path, DATA_PATH)
    finally:
        # Only present if the download or the rename failed.
        if os.path.exists(_partial_path):
            os.remove(_partial_path)

# ---------------------------- (NEW) 간단 텍스트 정규화 ----------------------------
_url_pat = re.compile(r"https?://\S+")
_space_pat = re.compile(r"\s+")
_lol_pat = re.compile(r"[ㅋㅎ]{2,}")
_rep_punct_pat = re.compile(r"([!?~.])\1{1,}")


def normalize_text(s: str) -> str:
    """Light normalisation applied to every question/answer before tokenisation.

    Steps, in order:
      1. convert to ``str`` and drop zero-width spaces (U+200B);
      2. remove http/https URLs;
      3. replace any run of 2+ 'ㅋ'/'ㅎ' characters with "ㅎㅎ";
      4. shorten a run of 2+ identical ``! ? ~ .`` marks to exactly two;
      5. collapse whitespace runs to a single space and strip both ends.

    Zero-width spaces are removed first. Python's ``\\s`` does not match U+200B, so
    otherwise they would split laughter/punctuation runs and hide them from steps 3-4
    (e.g. "ㅋ\\u200bㅋㅋ" used to become "ㅋㅎㅎ" instead of "ㅎㅎ").
    """
    s = str(s).replace("\u200b", "")
    s = _url_pat.sub("", s)                 # remove URLs
    s = _lol_pat.sub("ㅎㅎ", s)              # shorten long ㅋ/ㅎ runs
    s = _rep_punct_pat.sub(r"\1\1", s)      # shorten repeated punctuation
    s = _space_pat.sub(" ", s).strip()      # tidy whitespace
    return s

# ---------------------------- Load data & build corpus ----------------------------
# Load Q/A pairs: drop rows missing Q or A, keep at most MAX_SAMPLES rows
# (positionally), then normalise both text columns.
raw = pd.read_csv(DATA_PATH).dropna(subset=["Q", "A"]).iloc[:MAX_SAMPLES]
raw["Q"] = raw["Q"].astype(str).map(normalize_text)
raw["A"] = raw["A"].astype(str).map(normalize_text)

# Plain-text corpus for SentencePiece training: one line per question and one per
# answer. Pairs where either side is empty after normalisation are skipped.
corpus_path = os.path.join(DATA_DIR, "corpus.txt")
with open(corpus_path, "w", encoding="utf-8") as f:
    f.writelines(f"{q}\n{a}\n" for q, a in zip(raw["Q"], raw["A"]) if q and a)

# ---------------------------- Train SentencePiece (if needed) ----------------------------
# Train the BPE SentencePiece model only when its .model/.vocab files are missing;
# later runs reuse the saved files.
if not all(os.path.exists(SPM_PREFIX + ext) for ext in (".model", ".vocab")):
    spm.SentencePieceTrainer.Train(
        input=corpus_path,
        model_prefix=SPM_PREFIX,
        model_type="bpe",
        vocab_size=VOCAB_SIZE_FOR_SPM,
        character_coverage=1.0,
        # Fixed special-token ids; the *_ID constants below must agree with these.
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
    )

sp = spm.SentencePieceProcessor()
sp.Load(SPM_PREFIX + ".model")
SPM_VOCAB_SIZE = sp.GetPieceSize()
PAD_ID, BOS_ID, EOS_ID, UNK_ID = 0, 1, 2, 3  # same ids as passed to the trainer above

# === (NEW) 전용 구분 토큰: 모델 임베딩에서만 추가 ===
# Dedicated answer separator <ANS>. It is not a SentencePiece piece: it takes the
# first id past the SentencePiece vocabulary and exists only in the model's
# embedding table and LM head, which are therefore one row larger.
SEP_ID = SPM_VOCAB_SIZE
VOCAB_PLUS = SEP_ID + 1

def _clip_and_pad(ids, max_len):
    """Return a new list of exactly ``max_len`` ids.

    Longer inputs are truncated at the end; shorter ones are right-padded with PAD_ID.
    """
    shortfall = max_len - len(ids)
    if shortfall <= 0:
        return ids[:max_len]
    return ids + [PAD_ID] * shortfall

# ========================= [QUEST 2] 전처리(디코더-온리용) 및 점검 =========================
def encode_pair(q_text, a_text, max_len=MAX_LEN):
    """
    [QUEST 2] Build one decoder-only LM training example.

    seq    = [BOS] + Q + [SEP] + A + [EOS], truncated / right-padded to ``max_len``.
    labels = next-token targets: ``labels[i] = seq[i + 1]`` (the last position has none).

    Masking (-100 = ignored by cross-entropy):
      * every position strictly before SEP, i.e. the BOS + question context;
      * every position whose target is PAD (the step after EOS and all padding).

    The SEP position itself is kept. Its target is the first answer token, which is
    exactly what generation must produce right after ``[BOS] Q [SEP]``. If SEP was
    truncated away, every position is masked.

    Returns:
        (seq, labels, has_target): two lists of length ``max_len`` and a flag that is
        True when at least one position carries a real target.
    """
    q_ids = sp.EncodeAsIds(normalize_text(q_text))
    a_ids = sp.EncodeAsIds(normalize_text(a_text))
    seq = _clip_and_pad([BOS_ID] + q_ids + [SEP_ID] + a_ids + [EOS_ID], max_len)

    try:
        sep_idx = seq.index(SEP_ID)
    except ValueError:
        # The question filled the whole window: no answer token survives truncation.
        sep_idx = max_len

    shifted = seq[1:] + [PAD_ID]
    labels = [
        -100 if (i < sep_idx or target == PAD_ID) else target
        for i, target in enumerate(shifted)
    ]
    has_target = any(t >= 0 for t in labels)
    return seq, labels, has_target

def preview_preprocessing(df, n=2):
    """[QUEST 2] Print the encoded ids/labels of the first ``n`` Q/A rows of ``df``."""
    print("\n[QUEST 2] 전처리 점검 예시:")
    questions = df["Q"].astype(str)
    answers = df["A"].astype(str)
    for idx, (question, answer) in enumerate(zip(questions, answers), start=1):
        if idx > n:
            break
        seq, labels, has_target = encode_pair(question, answer)
        print(f"- 샘플 {idx}")
        print("  Q:", question)
        print("  A:", answer)
        print("  seq(ids)[:24]:", seq[:24])
        print("  labels[:24]:  ", labels[:24])
        print("  유효타겟존재?:", has_target)
preview_preprocessing(raw)

# ---------------------------- Dataset ----------------------------
class ChatbotQADataset(Dataset if HAS_TORCH else object):
    """Q/A rows pre-encoded with ``encode_pair``.

    Rows with no training target (``has_target`` False) are dropped at construction.
    Items are ``(input_ids, labels)``: LongTensors when PyTorch is available,
    otherwise int64 NumPy arrays.
    """

    def __init__(self, df, max_len=MAX_LEN):
        questions = df["Q"].astype(str).tolist()
        answers = df["A"].astype(str).tolist()
        encoded = (encode_pair(q, a, max_len) for q, a in zip(questions, answers))
        self.samples = [(seq, labels) for seq, labels, ok in encoded if ok]
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x_ids, y_ids = self.samples[idx]
        if not HAS_TORCH:
            return np.array(x_ids, dtype=np.int64), np.array(y_ids, dtype=np.int64)
        return torch.tensor(x_ids, dtype=torch.long), torch.tensor(y_ids, dtype=torch.long)

# ---------------------------- Model (GPT-style, Pre-LN) ----------------------------
if HAS_TORCH:
    class GELU(nn.Module):
        """Parameter-free module wrapper around ``F.gelu``."""

        def forward(self, x):
            return F.gelu(x)

    class GPTBlock(nn.Module):
        """
        [QUEST 3] One Pre-LN decoder block:

            x = x + Dropout(MaskedSelfAttention(LayerNorm(x)))
            x = x + MLP(LayerNorm(x))      # MLP = Linear -> GELU -> Linear -> Dropout

        Masks are forwarded unchanged to ``nn.MultiheadAttention`` (batch_first=True,
        so inputs are batch-major).
        NOTE(review): GPT-1 as published appears to place LayerNorm after each residual
        (post-LN); the pre-LN arrangement used here is the GPT-2 variant. Confirm
        against the paper before citing it as the GPT-1 spec.
        Reference: Radford et al., "Improving Language Understanding by Generative Pre-Training" (GPT-1)
              https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf
        """

        def __init__(self, d_model, n_heads, mlp_ratio=4.0, dropout=0.1):
            super().__init__()
            hidden = int(d_model * mlp_ratio)
            self.ln1 = nn.LayerNorm(d_model)
            self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
            self.ln2 = nn.LayerNorm(d_model)
            self.mlp = nn.Sequential(
                nn.Linear(d_model, hidden),
                GELU(),
                nn.Linear(hidden, d_model),
                nn.Dropout(dropout),
            )
            self.dropout = nn.Dropout(dropout)

        def _self_attention(self, x, attn_mask, key_padding_mask):
            """Masked self-attention on LayerNorm(x), followed by residual dropout."""
            normed = self.ln1(x)
            attended, _ = self.attn(
                normed, normed, normed,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )
            return self.dropout(attended)

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            x = x + self._self_attention(x, attn_mask, key_padding_mask)
            return x + self.mlp(self.ln2(x))

    class GPTSmall(nn.Module):
        """
        [QUEST 3] GPT-style decoder-only language model.

        token embedding + learned absolute position embedding -> dropout -> stack of
        GPTBlocks -> final LayerNorm -> linear LM head (optionally tied to the token
        embedding). forward() builds a boolean upper-triangular causal mask
        (True = blocked) and masks PAD positions as attention keys.

        Every nn.Linear / nn.Embedding weight is drawn from N(0, init_std) and Linear
        biases are zeroed. The GPT-1 paper reports N(0, 0.02). PyTorch's default
        N(0, 1) embedding init, tied to the LM head, gives very large initial logits.
        Pass ``init_std=None`` (or 0) to keep PyTorch's default initialisation.
        """

        def __init__(self, vocab_size, d_model=DEFAULT_D_MODEL, n_layers=DEFAULT_N_LAYERS,
                     n_heads=DEFAULT_N_HEADS, max_len=MAX_LEN, dropout=0.1, tie_weights=True,
                     init_std=0.02):
            super().__init__()
            self.max_len = max_len
            self.init_std = init_std
            self.tok_emb = nn.Embedding(vocab_size, d_model)  # vocab_size = VOCAB_PLUS (includes SEP)
            self.pos_emb = nn.Embedding(max_len, d_model)      # learned absolute positions
            self.drop = nn.Dropout(dropout)
            self.blocks = nn.ModuleList(
                [GPTBlock(d_model, n_heads, dropout=dropout) for _ in range(n_layers)]
            )
            self.ln_f = nn.LayerNorm(d_model)
            self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
            if tie_weights:
                self.lm_head.weight = self.tok_emb.weight
            if init_std:
                self.apply(self._init_weights)

        def _init_weights(self, module):
            """Re-initialise one submodule in place (used via ``nn.Module.apply``)."""
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=self.init_std)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=self.init_std)

        def forward(self, idx, labels=None, label_smoothing=0.1):
            """
            Args:
                idx: (B, T) LongTensor of token ids with T <= max_len.
                labels: optional (B, T) next-token targets; -100 entries are ignored.
                label_smoothing: forwarded to ``F.cross_entropy``.
            Returns:
                (logits of shape (B, T, vocab_size), mean cross-entropy or None).
            Raises:
                ValueError: if T exceeds ``max_len``.
            """
            _, T = idx.size()
            if T > self.max_len:
                raise ValueError("시퀀스 길이가 max_len을 초과했습니다.")
            pos = torch.arange(T, device=idx.device).unsqueeze(0)  # [1, T]
            x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))   # token + position

            # [PLUS] causal mask: True above the diagonal = may not attend (bool mask).
            attn_mask = torch.triu(torch.ones((T, T), device=idx.device, dtype=torch.bool), diagonal=1)
            key_padding_mask = idx == PAD_ID

            for blk in self.blocks:
                x = blk(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
            logits = self.lm_head(self.ln_f(x))

            loss = None
            if labels is not None:
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels.reshape(-1),
                    ignore_index=-100,
                    label_smoothing=label_smoothing,
                )
            return logits, loss

    def model_summary(model):
        """[QUEST 4] Print the module tree plus total and trainable parameter counts."""
        params = list(model.parameters())
        total_params = sum(p.numel() for p in params)
        trainable_params = sum(p.numel() for p in params if p.requires_grad)
        print("\n=== [QUEST 4] Model Summary ===")
        print(model)
        print(f"Total params: {total_params:,}")
        print(f"Trainable params: {trainable_params:,}")
        print("===============================\n")

# ---------------------------- Decode helper (SEP 제외) ----------------------------
# [PLUS] 후처리 강화: 짧은 조각 반복 접기, 공백/구두점 정리, 빈 문자열 방지
def _clean_text(s: str) -> str:
    s = s.replace("\u200b", "")
    s = re.sub(r"\s+", " ", s)

    # 구두점 과반복 줄이기 (???, !!! → ??, !!)
    s = re.sub(r"([!?~.])\1{2,}", r"\1\1", s)

    # 2~6글자 짧은 조각 반복을 1회로 접기 (예: "하세요하세요" → "하세요", "세요.세요" → "세요.")
    s = re.sub(r'(\S{2,6})(\s?\1){1,}', r'\1', s)

    # 공백·구두점 간격 정리
    s = re.sub(r"\s+([,.!?])", r"\1", s).strip()

    # 너무 길면 2문장까지만
    parts = re.split(r'(?<=[.!?])\s+', s)
    if len(parts) > 2:
        s = " ".join(parts[:2]).strip()
    return s

def decode_model_output(ids):
    """
    Turn model token ids into answer text.

    Only ids after the first SEP are used (all ids if there is no SEP). Ids outside
    the SentencePiece vocabulary (e.g. SEP itself) and BOS/PAD/EOS/UNK are dropped.
    The rest is decoded with SentencePiece and tidied by ``_clean_text``.
    Returns "" when nothing decodable remains.
    """
    tail = ids[ids.index(SEP_ID) + 1:] if SEP_ID in ids else ids
    skipped = {BOS_ID, PAD_ID, EOS_ID, UNK_ID}
    kept = [tok for tok in tail if 0 <= tok < SPM_VOCAB_SIZE and tok not in skipped]
    if not kept:
        return ""
    return _clean_text(sp.DecodeIds(kept))

# ---------------------------- Generate & chatbot test ----------------------------
if HAS_TORCH:
    @torch.no_grad()
    def _violates_no_repeat(out_ids, next_id, n=3):
        """Return True if appending ``next_id`` to ``out_ids`` would recreate an n-gram.

        The whole of ``out_ids`` (prompt included) is searched. ``n <= 1`` disables the
        check, as does a history shorter than ``n - 1`` tokens.
        """
        if n <= 1 or len(out_ids) < n - 1:
            return False
        candidate = out_ids[-(n - 1):] + [next_id]
        return any(
            out_ids[start:start + n] == candidate
            for start in range(len(out_ids) - n + 1)
        )

    # [PLUS] Reject a first answer token that is punctuation or renders as blank.
    def _is_bad_start(tok_id):
        """Return True if ``tok_id`` should not open an answer.

        That is the case for ids outside the SentencePiece vocabulary, punctuation-only
        pieces, and pieces that render as nothing. SentencePiece marks word
        boundaries with U+2581 ("▁"). That marker is stripped first, so "▁." counts as
        "." and a bare "▁" counts as blank (str.strip() alone does not remove it).
        """
        if not 0 <= tok_id < SPM_VOCAB_SIZE:
            return True
        piece = sp.IdToPiece(tok_id).replace("\u2581", "").strip()
        return piece in {".", ",", "!", "?", "…"} or piece == ""

    def _apply_top_k(logits, top_k):
        if top_k and top_k > 0:
            kth = torch.topk(logits, k=min(top_k, logits.size(-1))).values[:, -1]
            logits[logits < kth] = -float("inf")

    def _apply_top_p(logits, top_p):
        if top_p and 0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs_sorted = F.softmax(sorted_logits, dim=-1)
            cumprobs = torch.cumsum(probs_sorted, dim=-1)
            cutoff = (cumprobs > top_p).float().argmax(dim=-1)
            for b in range(logits.size(0)):
                logits[b, sorted_idx[b, cutoff[b]+1:]] = -float("inf")

    @torch.inference_mode()  # [PLUS] no autograd bookkeeping while sampling
    def generate(
        model, prompt, max_new_tokens=64,
        temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, repetition_penalty=GEN_REP_PENALTY,
        no_repeat_ngram_size=GEN_NGRAM_NO_REPEAT, min_new_tokens=GEN_MIN_NEW_TOKENS, max_len=MAX_LEN,
        top_k=GEN_TOP_K
    ):
        """
        Sample an answer to ``prompt`` from the decoder-only model.

        The context is ``[BOS] Q [SEP]``. Tokens are appended one at a time until EOS,
        ``max_new_tokens`` steps, or ``max_len`` total tokens. Each step applies, in order:
          1. ban of special tokens;
          2. question-echo penalty;
          3. sign-aware repetition penalty;
          4. temperature;
          5. top-k and top-p filtering, then sampling;
          6. fix-ups for a bad first token, a repeated n-gram, and EOS before
             ``min_new_tokens``.

        The model runs on its own device in eval mode; its previous train/eval mode
        is restored before returning.

        Returns:
            The cleaned answer, or a fixed fallback sentence if it is shorter than
            4 characters.
        """
        q_ids = sp.EncodeAsIds(normalize_text(prompt))
        # The model rejects sequences longer than its max_len. Keep BOS + Q + SEP at
        # most max_len - 1 tokens so at least one token can be generated.
        ctx_q_ids = q_ids[:max(0, max_len - 3)]
        run_device = next(model.parameters()).device
        out = torch.tensor(
            [BOS_ID] + ctx_q_ids + [SEP_ID], dtype=torch.long, device=run_device
        ).unsqueeze(0)

        banned = {PAD_ID, BOS_ID, SEP_ID}
        if BAN_UNK_AT_GEN:
            banned.add(UNK_ID)

        was_training = model.training
        model.eval()  # dropout off while sampling
        try:
            for step in range(max_new_tokens):
                logits, _ = model(out)
                logits = logits[:, -1, :]

                # Never emit special tokens (EOS stays allowed).
                for b in banned:
                    logits[:, b] = -float("inf")

                # [PLUS] Echo suppression: lower logits of tokens that appear in the question.
                if ECHO_PENALTY > 0:
                    for tok in set(q_ids):
                        if 0 <= tok < logits.size(-1):
                            logits[0, tok] -= ECHO_PENALTY

                # Repetition penalty (CTRL / Hugging Face convention): divide positive
                # logits and multiply negative ones, so the penalty always lowers the
                # score. Dividing a negative logit would raise it and favour repeats.
                if repetition_penalty and repetition_penalty != 1.0:
                    seen = torch.unique(out[0])
                    seen_logits = logits[0, seen]
                    logits[0, seen] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty,
                    )

                if temperature and temperature > 0:
                    logits = logits / temperature

                # top-k / top-p filtering (in place)
                _apply_top_k(logits, top_k)
                _apply_top_p(logits, top_p)

                probs = F.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1).item()

                # [PLUS] A punctuation/blank first token is replaced by the most probable
                # acceptable alternative.
                if step == 0 and _is_bad_start(next_id):
                    for alt_id in torch.argsort(probs[0], descending=True).tolist():
                        if alt_id in banned or alt_id == EOS_ID:
                            continue
                        if not _is_bad_start(alt_id):
                            next_id = alt_id
                            break

                # No-repeat n-gram: swap in the most probable token that does not
                # recreate an n-gram already present in the sequence.
                out_list = out[0].tolist()
                if no_repeat_ngram_size and _violates_no_repeat(out_list, next_id, n=no_repeat_ngram_size):
                    for alt_id in torch.argsort(probs[0], descending=True).tolist():
                        if alt_id == next_id or alt_id in banned or alt_id == EOS_ID:
                            continue
                        if not _violates_no_repeat(out_list, alt_id, n=no_repeat_ngram_size):
                            next_id = alt_id
                            break

                # EOS is not allowed before min_new_tokens steps.
                if next_id == EOS_ID and step < min_new_tokens:
                    for alt_id in torch.argsort(probs[0], descending=True).tolist():
                        if alt_id not in banned and alt_id != EOS_ID:
                            next_id = alt_id
                            break

                out = torch.cat([out, torch.tensor([[next_id]], device=run_device)], dim=1)
                if next_id == EOS_ID or out.size(1) >= max_len:
                    break
        finally:
            model.train(was_training)

        text = decode_model_output(out[0].tolist())

        # [PLUS] Empty/too-short replies get a safe default answer.
        if not text or len(text) < 4:
            text = "음... 조금 더 자세히 말해줄래?"
        return text

    def chatbot_test(model):
        """[QUEST 5] Print a generated answer for each of a few fixed demo prompts."""
        demo_prompts = (
            "안녕!",
            "오늘 날씨 어때?",
            "내일 힘이 날 말 한마디 해줘",
            "주말에 뭐하면 좋아?",
            "파이토치 설치가 안돼. 어떻게 해야 해?",
        )
        print("\n[QUEST 5] Chatbot Test 시작")
        for prompt in demo_prompts:
            reply = generate(model, prompt, max_new_tokens=64)
            print(f"[Q] {prompt}\n[A] {reply}\n")
        print("[QUEST 5] Chatbot Test 완료\n")

# ---------------------------- Data split / loaders ----------------------------
from sklearn.model_selection import train_test_split
train_df, valid_df = train_test_split(raw, test_size=0.02, random_state=SEED)

if HAS_TORCH:
    train_ds = ChatbotQADataset(train_df)
    valid_ds = ChatbotQADataset(valid_df)

    BATCH_SIZE = 32 if USE_CUDA else 8
    # os.cpu_count() returns None when the CPU count cannot be determined. Treat
    # that as a single core (-> 0 workers) instead of crashing on ``None - 1``.
    NUM_WORKERS = min(4, max(0, (os.cpu_count() or 1) - 1))
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=USE_CUDA, num_workers=NUM_WORKERS)
    valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=USE_CUDA, num_workers=NUM_WORKERS)
else:
    # Fallback: add-one-smoothed bigram model over [BOS] Q [SEP] A [EOS] sequences.
    # It needs no training and supports minimal sampling.
    counts = np.zeros((VOCAB_PLUS, VOCAB_PLUS), dtype=np.float64)
    for q, a in zip(raw["Q"].astype(str), raw["A"].astype(str)):
        seq = [BOS_ID] + sp.EncodeAsIds(normalize_text(q)) + [SEP_ID] + sp.EncodeAsIds(normalize_text(a)) + [EOS_ID]
        for x, y in zip(seq[:-1], seq[1:]):
            if 0 <= x < VOCAB_PLUS and 0 <= y < VOCAB_PLUS:
                counts[x, y] += 1
    # Normalise in place. The matrix is VOCAB_PLUS x VOCAB_PLUS float64 (roughly
    # 0.5 GB at the default 8000-piece vocab), so ``counts + 1.0`` would briefly
    # double peak memory. ``probs`` and ``counts`` now share one buffer.
    probs = counts
    probs += 1.0
    probs /= probs.sum(axis=1, keepdims=True)

# ---------------------------- [QUEST 3] 입력 블록/마스크 검증 유틸 ----------------------------
if HAS_TORCH:
    def verify_input_block(model):
        """
        [QUEST 3] Smoke-test the input block and attention masks of ``model``.

        Feeds a random (2, 16) batch whose last two positions are PAD and checks:
          - logits have shape (B, T, vocab);
          - the loss on dummy targets (the inputs themselves, PAD ignored) is finite;
          - causality: changing the token at position T // 2 leaves the logits of all
            earlier positions unchanged (within float tolerance).
        Runs without autograd and with dropout disabled. The model's previous
        train/eval mode is restored afterwards.

        Raises:
            AssertionError: if any check fails.
        """
        print("[QUEST 3] 입력 블록/마스크 점검:")
        run_device = next(model.parameters()).device
        was_training = model.training
        model.eval()  # dropout off, so two forward passes are directly comparable
        try:
            with torch.no_grad():
                x = torch.randint(4, VOCAB_PLUS, (2, 16), device=run_device)  # [B=2, T=16]
                x[:, -2:] = PAD_ID  # insert some PAD
                # Dummy targets with at least one non-ignored entry. All -100 labels
                # would make the mean loss NaN and check nothing.
                labels = x.masked_fill(x == PAD_ID, -100)
                logits, loss = model(x, labels=labels)
                B, T = x.size()
                assert logits.shape[:2] == (B, T), "logits shape 불일치"
                assert torch.isfinite(loss).item(), "loss가 유한하지 않습니다"

                # Causality probe: a change at position `probe` must not reach earlier positions.
                probe = T // 2
                x_alt = x.clone()
                x_alt[:, probe] = (x[:, probe] - 3) % (VOCAB_PLUS - 4) + 4  # different non-special id
                logits_alt, _ = model(x_alt)
                assert torch.allclose(logits[:, :probe], logits_alt[:, :probe], atol=1e-4, rtol=1e-4), \
                    "causal mask 누수: 미래 토큰이 이전 위치 logits에 영향을 줌"
        finally:
            model.train(was_training)

        print(" - logits shape:", tuple(logits.shape))
        print(" - token/position embedding 합성 및 Pre-LN/Residual 적용 확인(코드 상 구현)")
        print(" - causal mask 상삼각 차단(bool) 적용(코드 상 구현)")
        print(f" - causal mask 동작 검증: 위치 {probe} 토큰 변경 시 이전 위치 logits 불변")
        print(" - PAD key_padding_mask 적용(코드 상 구현)\n")

# ---------------------------- Training loop with AMP + Cosine LR ----------------------------
if HAS_TORCH:
    # LambdaLR lets one closure express "linear warmup, then cosine decay".
    from torch.optim.lr_scheduler import LambdaLR

    print(f"Train samples: {len(train_df)} → usable: {len(train_ds)}")
    print(f"Valid samples: {len(valid_df)} → usable: {len(valid_ds)}")
    print(f"SPM vocab size (original): {SPM_VOCAB_SIZE}, +SEP → {VOCAB_PLUS}")

    # [PLUS] Free cached memory before building the model; fall back to CPU on CUDA OOM.
    _free_cuda()
    try:
        model = GPTSmall(vocab_size=VOCAB_PLUS, d_model=DEFAULT_D_MODEL,
                         n_layers=DEFAULT_N_LAYERS, n_heads=DEFAULT_N_HEADS).to(device)
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print("[경고] CUDA OOM 감지 → CPU로 자동 전환합니다.")
            _free_cuda()
            device = torch.device("cpu")
            USE_CUDA = False
            train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, pin_memory=False, num_workers=0)
            valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=False, pin_memory=False, num_workers=0)
            model = GPTSmall(vocab_size=VOCAB_PLUS, d_model=DEFAULT_D_MODEL,
                             n_layers=DEFAULT_N_LAYERS, n_heads=DEFAULT_N_HEADS).to(device)
        else:
            raise

    # [QUEST 3] input block / mask checks
    verify_input_block(model)

    # [QUEST 4] model summary
    model_summary(model)

    BASE_LR, MIN_LR = 3e-4, 3e-5
    # Weight decay only on matrices (linear/embedding weights), not on biases or
    # LayerNorm gains/offsets, as in GPT-1's "non bias or gain weights" rule.
    param_groups = [
        {"params": [p for p in model.parameters() if p.dim() >= 2], "weight_decay": 0.1},
        {"params": [p for p in model.parameters() if p.dim() < 2], "weight_decay": 0.0},
    ]
    optimizer = AdamW([g for g in param_groups if g["params"]], lr=BASE_LR)
    EPOCHS = 80

    # [PLUS] current AMP API
    scaler = torch.amp.GradScaler('cuda', enabled=USE_CUDA)

    # Linear warmup from BASE_LR / warmup_steps up to BASE_LR, then cosine decay to
    # MIN_LR. Stepped once per optimizer step. Previously the "warmup" just held
    # the LR constant.
    steps_per_epoch = max(1, len(train_loader))
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = min(max(100, len(train_loader)), total_steps)
    decay_steps = max(1, total_steps - warmup_steps)
    min_lr_ratio = MIN_LR / BASE_LR

    def _lr_factor(step):
        """Multiplier applied to BASE_LR at 0-based optimizer step ``step``."""
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = LambdaLR(optimizer, lr_lambda=_lr_factor)

    best_val = float("inf")
    patience, bad = 3, 0
    best_path = os.path.join(DATA_DIR, "gptsmall_chatbot_best.pth")
    last_path = os.path.join(DATA_DIR, "gptsmall_chatbot_last.pth")
    saved_best = False  # guards against loading a stale best checkpoint from an earlier run

    global_step = 0
    for epoch in range(1, EPOCHS + 1):
        model.train()
        loss_sum, n_tokens = 0.0, 0
        for x, labels in train_loader:
            x, labels = x.to(device), labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            if USE_CUDA:
                with torch.amp.autocast('cuda'):
                    _, loss = model(x, labels, label_smoothing=TRAIN_SMOOTH)
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                _, loss = model(x, labels, label_smoothing=TRAIN_SMOOTH)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            scheduler.step()
            global_step += 1

            # The loss is a mean over target tokens. Weight by token count so the
            # epoch average is a true per-token value.
            batch_tokens = int((labels != -100).sum())
            if batch_tokens:
                loss_sum += float(loss.item()) * batch_tokens
                n_tokens += batch_tokens
        train_loss = loss_sum / max(1, n_tokens)

        model.eval()
        loss_sum, n_tokens = 0.0, 0
        with torch.no_grad():
            for x, labels in valid_loader:
                x, labels = x.to(device), labels.to(device)
                # Validation uses label_smoothing=0.0, i.e. plain NLL, for an honest ppl.
                if USE_CUDA:
                    with torch.amp.autocast('cuda'):
                        _, loss = model(x, labels, label_smoothing=EVAL_SMOOTH)
                else:
                    _, loss = model(x, labels, label_smoothing=EVAL_SMOOTH)
                batch_tokens = int((labels != -100).sum())
                if batch_tokens:
                    loss_sum += float(loss.item()) * batch_tokens
                    n_tokens += batch_tokens
        # Token-level mean NLL, so exp() below is a per-token perplexity.
        val_loss = loss_sum / max(1, n_tokens)
        ppl = math.exp(min(20, val_loss))
        print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | ppl={ppl:.2f}")

        if val_loss < best_val - 1e-3:
            best_val = val_loss
            bad = 0
            torch.save(model.state_dict(), best_path)
            saved_best = True
        else:
            bad += 1
            if bad >= patience:
                print(f"Early stopping triggered (patience={patience}).")
                break

        if epoch % 2 == 0:
            sample_q = random.choice(valid_df["Q"].tolist())
            print("[Sample Q]", sample_q)
            print("[Sample A]", generate(model, sample_q, max_new_tokens=64))

    torch.save(model.state_dict(), last_path)
    print("Model saved to:", last_path)

    # Early stopping runs `patience` epochs past the best one, so run the demo with
    # the best checkpoint saved during this run.
    if saved_best:
        model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))
        print("Loaded best checkpoint for the demo:", best_path)

    # [QUEST 5] final generation demo
    chatbot_test(model)

# ---------------------------- Fallback path ----------------------------
else:
    print("\n[주의] PyTorch 미탑재 환경에서는 Bigram 확률로만 문장을 샘플링할 수 있습니다. (데모 전용)")
    # Minimal bigram demo: start right after SEP and sample until EOS or MAX_LEN
    # tokens. The bigram model conditions only on the previous token, so the
    # sample ignores any question; it only shows that the fallback produces text.
    _rng = np.random.default_rng(SEED)
    _blocked_ids = [PAD_ID, BOS_ID, SEP_ID, UNK_ID]
    _prev_id, _sampled_ids = SEP_ID, []
    for _ in range(MAX_LEN):
        _row = probs[_prev_id].copy()
        _row[_blocked_ids] = 0.0
        _row /= _row.sum()
        _prev_id = int(_rng.choice(VOCAB_PLUS, p=_row))
        if _prev_id == EOS_ID:
            break
        _sampled_ids.append(_prev_id)
    print("[Bigram 샘플 A]", decode_model_output(_sampled_ids) or "(빈 출력)")


[QUEST 1] Baseline Transformer 대비 변경·차별 포인트(요약)
- 인코더/디코더 완전 Transformer가 아닌, GPT 스타일 '디코더-온리(Decoder-only)' 언어모델로 단순화.
- [개선] 질문 뒤에 전용 구분 토큰 <ANS>(SEP)를 도입하여 Q/ A 경계를 명확히 함.
  · 학습: [BOS] Q [SEP] A [EOS]  (Q구간은 loss에서 무시, A만 예측)
  · 추론: [BOS] Q [SEP] … (A만 생성) → Q 복창/에코링 감소
- [개선] 디코딩 제어: temperature + nucleus(top_p) + repetition_penalty + no_repeat_ngram + min_length
  → 의미불명 반복과 횡설수설 억제, “말이 되는” 답변 유도
- [개선] label smoothing(0.1) + cosine 스케줄(+warmup) → 안정적 수렴/일반화 향상.
  ※ 검증은 smoothing=0.0으로 순수 NLL을 측정해 ppl 왜곡을 방지.
- SentencePiece BPE로 한국어 토큰화, 실제 vocab size 동기화.
- 속도 최적화: 작은 모델(d_model=256, n_layers=6, n_heads=8), AMP, 자동 배치, 샘플/길이 제한.

[QUEST 2] 전처리 점검 예시:
- 샘플 1
  Q: 12시 땡!
  A: 하루가 또 가네요.
  seq(ids)[:24]: [1, 5566, 6801, 3207, 6907, 8000, 4489, 211, 5936, 6760, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  labels[:24]:   [-100, -100, -100, -100, -100, -100, 211, 5936, 6760, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]
  유효타겟존재?: True
- 샘플 2
  