In [1]:
from google.colab import drive

drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
!git clone https://github.com/ItWasAllYellow/public_cs224n_gpt.git

%cd /content/public_cs224n_gpt

Cloning into 'public_cs224n_gpt'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 69 (delta 21), reused 13 (delta 13), pack-reused 32 (from 1)[K
Receiving objects: 100% (69/69), 30.87 MiB | 29.66 MiB/s, done.
Resolving deltas: 100% (22/22), done.
/content/public_cs224n_gpt


In [3]:
"""
train_expl_all_curr.py
GPT-2 해설 생성 (모든 선택지 포함) + Curriculum Learning (grade 2 ➜ grade 3)
"""

import argparse, os, random, numpy as np, pandas as pd, torch
import torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model   # 반드시 여러분의 커스텀 GPT2Model

TQDM_DISABLE = False  # True면 TQDM bar 숨김

# ───────────────────────────── utils
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ───────────────────────────── 모델
class GPT2Explainer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.backbone = GPT2Model.from_pretrained(
            model=args.model_size, d=args.d, l=args.l, num_heads=args.num_heads
        )
        vocab = GPT2Tokenizer.from_pretrained(args.model_size).vocab_size
        self.lm_head = nn.Linear(args.d, vocab, bias=False)
        for p in self.parameters(): p.requires_grad = True

    def forward(self, ids, mask, labels=None):
        h = self.backbone(ids, attention_mask=mask)["last_hidden_state"]
        logits = self.lm_head(h)                       # (B, L, V)
        if labels is None:
            return logits
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        return loss, logits

# ───────────────────────────── 데이터셋
class ExplDataset(Dataset):
    def __init__(self, df, tok, max_len=512):
        self.df, self.tok, self.max_len = df.reset_index(drop=True), tok, max_len
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        return r["input"], r["target"]
    def collate_fn(self, batch):
        ins, tgs = zip(*batch)
        enc_in  = self.tok(list(ins), padding="max_length", truncation=True,
                           max_length=self.max_len, return_tensors="pt")
        enc_out = self.tok(list(tgs), padding="max_length", truncation=True,
                           max_length=self.max_len, return_tensors="pt")
        labels = enc_out.input_ids.clone()
        labels[enc_out.attention_mask == 0] = -100
        return {"input_ids": enc_in.input_ids,
                "attention_mask": enc_in.attention_mask,
                "labels": labels}

# ───────────────────────────── 평가
def evaluate(loader, model, dev):
    model.eval(); tot_loss = tok_cnt = 0
    with torch.no_grad():
        for bt in loader:
            ids  = bt["input_ids"].to(dev)
            mask = bt["attention_mask"].to(dev)
            lbls = bt["labels"].to(dev)
            loss, _ = model(ids, mask, lbls)
            num_tok = (lbls != -100).sum().item()
            tot_loss += loss.item() * num_tok
            tok_cnt  += num_tok
    ppl = np.exp(tot_loss / max(tok_cnt, 1))
    return ppl

# ───────────────────────────── 학습
def train(args):
    seed_everything(args.seed)
    dev = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print("Device:", dev)

    # 전체 CSV 로드
    df = pd.read_csv(args.data_path, encoding="utf-8-sig")

    tok = GPT2Tokenizer.from_pretrained(args.model_size)
    tok.pad_token = tok.eos_token

    model = GPT2Explainer(args).to(dev)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)

    best_ppl = 1e9  # 낮을수록 좋음

    # ----------- Curriculum: grade 2 다음 grade 3 ----------
    for grade in [2, 3]:
        print(f"\n=====  Curriculum Stage: Grade {grade}  =====")
        train_df = df[(df.split == "train") & (df.grade == grade)]
        valid_df = df[(df.split == "valid") & (df.grade == grade)]
        if len(train_df) == 0:
            print(f"  (Grade {grade} 데이터가 없습니다, 스킵)")
            continue

        tr_ds = ExplDataset(train_df, tok, args.max_length)
        va_ds = ExplDataset(valid_df, tok, args.max_length)
        tr_ld = DataLoader(tr_ds, batch_size=args.batch_size, shuffle=True,
                           collate_fn=tr_ds.collate_fn)
        va_ld = DataLoader(va_ds, batch_size=args.batch_size, shuffle=False,
                           collate_fn=va_ds.collate_fn)

        # -------- epoch loop --------
        for ep in range(1, args.epochs + 1):
            model.train(); ep_loss = 0
            for bt in tqdm(tr_ld, disable=TQDM_DISABLE,
                           desc=f"(G{grade}) Epoch {ep}"):
                optim.zero_grad()
                loss, _ = model(bt["input_ids"].to(dev),
                                bt["attention_mask"].to(dev),
                                bt["labels"].to(dev))
                loss.backward(); optim.step()
                ep_loss += loss.item()
            print(f"   >> train_loss={ep_loss/len(tr_ld):.4f}")

            ppl = evaluate(va_ld, model, dev)
            print(f"   >> valid PPL={ppl:.3f}")

            if ppl < best_ppl:
                best_ppl = ppl
                save(model, optim, args)

def save(m, o, args):
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    torch.save({"model": m.state_dict(),
                "optim": o.state_dict(),
                "args":  args}, args.save_path)
    print("  ** Best model saved →", args.save_path)

# ───────────────────────────── 인자
def get_args():
    p = argparse.ArgumentParser()
    p.add_argument("--data_path", type=str,
                   default="/content/drive/MyDrive/CSEG321/dataset/explanation_all_options.csv")
    p.add_argument("--model_size", type=str, default="gpt2",
                   choices=["gpt2", "gpt2-medium", "gpt2-large"])
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_length", type=int, default=512)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--use_gpu", action="store_true")
    args = p.parse_args([])
    if torch.cuda.is_available(): args.use_gpu = True
    # hidden dim 등
    if args.model_size == "gpt2":
        args.d, args.l, args.num_heads = 768, 12, 12
    elif args.model_size == "gpt2-medium":
        args.d, args.l, args.num_heads = 1024, 24, 16
    else:
        args.d, args.l, args.num_heads = 1280, 36, 20

    args.save_path = "/content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt"
    return args

# ───────────────────────────── main
if __name__ == "__main__":
    seed_everything()
    train(get_args())

Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]


=====  Curriculum Stage: Grade 2  =====


(G2) Epoch 1: 100%|██████████| 7/7 [00:01<00:00,  3.78it/s]


   >> train_loss=14.2450
   >> valid PPL=215893.871
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 2: 100%|██████████| 7/7 [00:01<00:00,  6.75it/s]


   >> train_loss=11.3365
   >> valid PPL=72623.975
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 3: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=10.7929
   >> valid PPL=46061.068
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 4: 100%|██████████| 7/7 [00:01<00:00,  6.70it/s]


   >> train_loss=10.5871
   >> valid PPL=41802.794
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 5: 100%|██████████| 7/7 [00:01<00:00,  6.71it/s]


   >> train_loss=10.3656
   >> valid PPL=31949.963
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 6: 100%|██████████| 7/7 [00:01<00:00,  6.65it/s]


   >> train_loss=10.1062
   >> valid PPL=25920.819
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 7: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=9.7356
   >> valid PPL=17192.143
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 8: 100%|██████████| 7/7 [00:01<00:00,  6.55it/s]


   >> train_loss=9.0949
   >> valid PPL=7996.317
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 9: 100%|██████████| 7/7 [00:01<00:00,  6.59it/s]


   >> train_loss=8.0308
   >> valid PPL=4143.812
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 10: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=7.3270
   >> valid PPL=2973.207
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 11: 100%|██████████| 7/7 [00:01<00:00,  6.56it/s]


   >> train_loss=6.7936
   >> valid PPL=2368.592
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 12: 100%|██████████| 7/7 [00:01<00:00,  6.74it/s]


   >> train_loss=6.4777
   >> valid PPL=2126.204
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 13: 100%|██████████| 7/7 [00:01<00:00,  6.63it/s]


   >> train_loss=6.2448
   >> valid PPL=2069.050
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G2) Epoch 14: 100%|██████████| 7/7 [00:01<00:00,  6.73it/s]


   >> train_loss=6.1128
   >> valid PPL=2182.328


(G2) Epoch 15: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=6.0222
   >> valid PPL=2139.145


(G2) Epoch 16: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.9414
   >> valid PPL=2215.921


(G2) Epoch 17: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=5.8929
   >> valid PPL=2251.196


(G2) Epoch 18: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=5.8402
   >> valid PPL=2330.860


(G2) Epoch 19: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=5.7962
   >> valid PPL=2409.446


(G2) Epoch 20: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.7808
   >> valid PPL=2513.057


(G2) Epoch 21: 100%|██████████| 7/7 [00:01<00:00,  6.76it/s]


   >> train_loss=5.7300
   >> valid PPL=2610.108


(G2) Epoch 22: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.7309
   >> valid PPL=2616.445


(G2) Epoch 23: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=5.6534
   >> valid PPL=2708.656


(G2) Epoch 24: 100%|██████████| 7/7 [00:01<00:00,  6.84it/s]


   >> train_loss=5.6229
   >> valid PPL=2802.580


(G2) Epoch 25: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=5.6023
   >> valid PPL=2859.755


(G2) Epoch 26: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=5.5545
   >> valid PPL=2906.227


(G2) Epoch 27: 100%|██████████| 7/7 [00:01<00:00,  6.84it/s]


   >> train_loss=5.4866
   >> valid PPL=2947.921


(G2) Epoch 28: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.4460
   >> valid PPL=3131.394


(G2) Epoch 29: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=5.3882
   >> valid PPL=3297.911


(G2) Epoch 30: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=5.3723
   >> valid PPL=3653.691


(G2) Epoch 31: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=5.2754
   >> valid PPL=3948.228


(G2) Epoch 32: 100%|██████████| 7/7 [00:01<00:00,  6.85it/s]


   >> train_loss=5.1981
   >> valid PPL=3248.637


(G2) Epoch 33: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=5.1467
   >> valid PPL=3348.846


(G2) Epoch 34: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=5.0274
   >> valid PPL=3477.152


(G2) Epoch 35: 100%|██████████| 7/7 [00:01<00:00,  6.86it/s]


   >> train_loss=4.9349
   >> valid PPL=3655.329


(G2) Epoch 36: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=4.9229
   >> valid PPL=3902.609


(G2) Epoch 37: 100%|██████████| 7/7 [00:01<00:00,  6.84it/s]


   >> train_loss=4.7897
   >> valid PPL=4415.543


(G2) Epoch 38: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=4.6985
   >> valid PPL=4261.157


(G2) Epoch 39: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=4.5625
   >> valid PPL=4439.188


(G2) Epoch 40: 100%|██████████| 7/7 [00:01<00:00,  6.87it/s]


   >> train_loss=4.4900
   >> valid PPL=4872.073


(G2) Epoch 41: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=4.3726
   >> valid PPL=6013.093


(G2) Epoch 42: 100%|██████████| 7/7 [00:01<00:00,  6.86it/s]


   >> train_loss=4.2685
   >> valid PPL=4704.362


(G2) Epoch 43: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=4.0954
   >> valid PPL=3982.550


(G2) Epoch 44: 100%|██████████| 7/7 [00:01<00:00,  6.81it/s]


   >> train_loss=3.9634
   >> valid PPL=4003.006


(G2) Epoch 45: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=3.7932
   >> valid PPL=4516.528


(G2) Epoch 46: 100%|██████████| 7/7 [00:01<00:00,  6.86it/s]


   >> train_loss=3.6733
   >> valid PPL=4644.865


(G2) Epoch 47: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=3.4952
   >> valid PPL=5457.426


(G2) Epoch 48: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=3.4048
   >> valid PPL=5152.569


(G2) Epoch 49: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=3.2324
   >> valid PPL=4774.637


(G2) Epoch 50: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=3.1250
   >> valid PPL=4862.497

=====  Curriculum Stage: Grade 3  =====


(G3) Epoch 1: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


   >> train_loss=8.2722
   >> valid PPL=2296.407


(G3) Epoch 2: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=7.2279
   >> valid PPL=1573.924
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G3) Epoch 3: 100%|██████████| 29/29 [00:04<00:00,  6.81it/s]


   >> train_loss=6.5680
   >> valid PPL=1479.607
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_all_curr_50.pt


(G3) Epoch 4: 100%|██████████| 29/29 [00:04<00:00,  6.83it/s]


   >> train_loss=6.2316
   >> valid PPL=1568.379


(G3) Epoch 5: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=6.0860
   >> valid PPL=1649.176


(G3) Epoch 6: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=5.9693
   >> valid PPL=1755.558


(G3) Epoch 7: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=5.8907
   >> valid PPL=1823.791


(G3) Epoch 8: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=5.7877
   >> valid PPL=1977.392


(G3) Epoch 9: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=5.6912
   >> valid PPL=1973.633


(G3) Epoch 10: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=5.5518
   >> valid PPL=2096.851


(G3) Epoch 11: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=5.4099
   >> valid PPL=2140.718


(G3) Epoch 12: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=5.2416
   >> valid PPL=2248.288


(G3) Epoch 13: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=5.0387
   >> valid PPL=2342.461


(G3) Epoch 14: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=4.8572
   >> valid PPL=2539.207


(G3) Epoch 15: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=4.6460
   >> valid PPL=2541.562


(G3) Epoch 16: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=4.4211
   >> valid PPL=2731.810


(G3) Epoch 17: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=4.1925
   >> valid PPL=2899.192


(G3) Epoch 18: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.9574
   >> valid PPL=2982.103


(G3) Epoch 19: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.7353
   >> valid PPL=3314.253


(G3) Epoch 20: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=3.4974
   >> valid PPL=3406.384


(G3) Epoch 21: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.2844
   >> valid PPL=3688.971


(G3) Epoch 22: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.0514
   >> valid PPL=3725.699


(G3) Epoch 23: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=2.8360
   >> valid PPL=4146.245


(G3) Epoch 24: 100%|██████████| 29/29 [00:04<00:00,  6.86it/s]


   >> train_loss=2.6124
   >> valid PPL=4379.768


(G3) Epoch 25: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=2.3926
   >> valid PPL=4578.512


(G3) Epoch 26: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=2.2038
   >> valid PPL=4959.414


(G3) Epoch 27: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=2.0056
   >> valid PPL=5370.939


(G3) Epoch 28: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=1.8333
   >> valid PPL=5868.061


(G3) Epoch 29: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=1.6721
   >> valid PPL=5925.996


(G3) Epoch 30: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=1.5090
   >> valid PPL=6397.624


(G3) Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=1.3873
   >> valid PPL=6611.548


(G3) Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=1.2337
   >> valid PPL=7435.743


(G3) Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=1.1306
   >> valid PPL=7512.455


(G3) Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=1.0233
   >> valid PPL=7939.770


(G3) Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.9408
   >> valid PPL=8713.817


(G3) Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.8566
   >> valid PPL=9150.666


(G3) Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.7822
   >> valid PPL=9856.042


(G3) Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.6983
   >> valid PPL=10203.493


(G3) Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.6566
   >> valid PPL=10419.407


(G3) Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=0.5954
   >> valid PPL=10960.493


(G3) Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.5567
   >> valid PPL=11334.954


(G3) Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.5263
   >> valid PPL=11664.377


(G3) Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.4836
   >> valid PPL=12167.523


(G3) Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=0.4552
   >> valid PPL=12268.513


(G3) Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.4226
   >> valid PPL=13104.186


(G3) Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.87it/s]


   >> train_loss=0.4095
   >> valid PPL=13236.810


(G3) Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.3940
   >> valid PPL=14363.894


(G3) Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=0.3758
   >> valid PPL=14588.724


(G3) Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=0.3530
   >> valid PPL=14753.851


(G3) Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.3396
   >> valid PPL=15492.765


In [4]:
"""
train_expl_all_curr.py
GPT-2 해설 생성 (모든 선택지 포함) + Curriculum Learning (grade 2 ➜ grade 3)
"""

import argparse, os, random, numpy as np, pandas as pd, torch
import torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model   # 반드시 여러분의 커스텀 GPT2Model

TQDM_DISABLE = False  # True면 TQDM bar 숨김

# ───────────────────────────── utils
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ───────────────────────────── 모델
class GPT2Explainer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.backbone = GPT2Model.from_pretrained(
            model=args.model_size, d=args.d, l=args.l, num_heads=args.num_heads
        )
        vocab = GPT2Tokenizer.from_pretrained(args.model_size).vocab_size
        self.lm_head = nn.Linear(args.d, vocab, bias=False)
        for p in self.parameters(): p.requires_grad = True

    def forward(self, ids, mask, labels=None):
        h = self.backbone(ids, attention_mask=mask)["last_hidden_state"]
        logits = self.lm_head(h)                       # (B, L, V)
        if labels is None:
            return logits
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )
        return loss, logits

# ───────────────────────────── 데이터셋
class ExplDataset(Dataset):
    def __init__(self, df, tok, max_len=512):
        self.df, self.tok, self.max_len = df.reset_index(drop=True), tok, max_len
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        return r["input"], r["target"]
    def collate_fn(self, batch):
        ins, tgs = zip(*batch)
        enc_in  = self.tok(list(ins), padding="max_length", truncation=True,
                           max_length=self.max_len, return_tensors="pt")
        enc_out = self.tok(list(tgs), padding="max_length", truncation=True,
                           max_length=self.max_len, return_tensors="pt")
        labels = enc_out.input_ids.clone()
        labels[enc_out.attention_mask == 0] = -100
        return {"input_ids": enc_in.input_ids,
                "attention_mask": enc_in.attention_mask,
                "labels": labels}

# ───────────────────────────── 평가
def evaluate(loader, model, dev):
    model.eval(); tot_loss = tok_cnt = 0
    with torch.no_grad():
        for bt in loader:
            ids  = bt["input_ids"].to(dev)
            mask = bt["attention_mask"].to(dev)
            lbls = bt["labels"].to(dev)
            loss, _ = model(ids, mask, lbls)
            num_tok = (lbls != -100).sum().item()
            tot_loss += loss.item() * num_tok
            tok_cnt  += num_tok
    ppl = np.exp(tot_loss / max(tok_cnt, 1))
    return ppl

# ───────────────────────────── 학습
def train(args):
    seed_everything(args.seed)
    dev = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print("Device:", dev)

    # 전체 CSV 로드
    df = pd.read_csv(args.data_path, encoding="utf-8-sig")

    tok = GPT2Tokenizer.from_pretrained(args.model_size)
    tok.pad_token = tok.eos_token

    model = GPT2Explainer(args).to(dev)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)

    best_ppl = 1e9  # 낮을수록 좋음

    # ----------- Curriculum: grade 2 다음 grade 3 ----------
    for grade in [2, 3]:
        print(f"\n=====  Curriculum Stage: Grade {grade}  =====")
        train_df = df[(df.split == "train") & (df.grade == grade)]
        valid_df = df[(df.split == "valid") & (df.grade == grade)]
        if len(train_df) == 0:
            print(f"  (Grade {grade} 데이터가 없습니다, 스킵)")
            continue

        tr_ds = ExplDataset(train_df, tok, args.max_length)
        va_ds = ExplDataset(valid_df, tok, args.max_length)
        tr_ld = DataLoader(tr_ds, batch_size=args.batch_size, shuffle=True,
                           collate_fn=tr_ds.collate_fn)
        va_ld = DataLoader(va_ds, batch_size=args.batch_size, shuffle=False,
                           collate_fn=va_ds.collate_fn)

        # -------- epoch loop --------
        for ep in range(1, args.epochs + 1):
            model.train(); ep_loss = 0
            for bt in tqdm(tr_ld, disable=TQDM_DISABLE,
                           desc=f"(G{grade}) Epoch {ep}"):
                optim.zero_grad()
                loss, _ = model(bt["input_ids"].to(dev),
                                bt["attention_mask"].to(dev),
                                bt["labels"].to(dev))
                loss.backward(); optim.step()
                ep_loss += loss.item()
            print(f"   >> train_loss={ep_loss/len(tr_ld):.4f}")

            ppl = evaluate(va_ld, model, dev)
            print(f"   >> valid PPL={ppl:.3f}")

            if ppl < best_ppl:
                best_ppl = ppl
                save(model, optim, args)

def save(m, o, args):
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    torch.save({"model": m.state_dict(),
                "optim": o.state_dict(),
                "args":  args}, args.save_path)
    print("  ** Best model saved →", args.save_path)

# ───────────────────────────── 인자
def get_args():
    p = argparse.ArgumentParser()
    p.add_argument("--data_path", type=str,
                   default="/content/drive/MyDrive/CSEG321/dataset/explanation_only_answer.csv")
    p.add_argument("--model_size", type=str, default="gpt2",
                   choices=["gpt2", "gpt2-medium", "gpt2-large"])
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_length", type=int, default=512)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--epochs", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--use_gpu", action="store_true")
    args = p.parse_args([])
    if torch.cuda.is_available(): args.use_gpu = True
    # hidden dim 등
    if args.model_size == "gpt2":
        args.d, args.l, args.num_heads = 768, 12, 12
    elif args.model_size == "gpt2-medium":
        args.d, args.l, args.num_heads = 1024, 24, 16
    else:
        args.d, args.l, args.num_heads = 1280, 36, 20

    args.save_path = "/content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt"
    return args

# ───────────────────────────── main
if __name__ == "__main__":
    seed_everything()
    train(get_args())

Device: cuda

=====  Curriculum Stage: Grade 2  =====


(G2) Epoch 1: 100%|██████████| 7/7 [00:01<00:00,  6.34it/s]


   >> train_loss=14.2505
   >> valid PPL=187714.827
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 2: 100%|██████████| 7/7 [00:01<00:00,  6.68it/s]


   >> train_loss=11.2944
   >> valid PPL=63569.182
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 3: 100%|██████████| 7/7 [00:01<00:00,  6.78it/s]


   >> train_loss=10.7374
   >> valid PPL=43382.468
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 4: 100%|██████████| 7/7 [00:01<00:00,  6.59it/s]


   >> train_loss=10.5452
   >> valid PPL=38394.987
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 5: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=10.3247
   >> valid PPL=30383.045
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 6: 100%|██████████| 7/7 [00:01<00:00,  6.62it/s]


   >> train_loss=10.0524
   >> valid PPL=23910.164
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 7: 100%|██████████| 7/7 [00:01<00:00,  6.72it/s]


   >> train_loss=9.6498
   >> valid PPL=15453.411
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 8: 100%|██████████| 7/7 [00:01<00:00,  6.61it/s]


   >> train_loss=8.9297
   >> valid PPL=6734.717
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 9: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=7.8620
   >> valid PPL=4032.736
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 10: 100%|██████████| 7/7 [00:01<00:00,  6.63it/s]


   >> train_loss=7.2219
   >> valid PPL=2747.150
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 11: 100%|██████████| 7/7 [00:01<00:00,  6.69it/s]


   >> train_loss=6.7341
   >> valid PPL=2371.478
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 12: 100%|██████████| 7/7 [00:01<00:00,  6.75it/s]


   >> train_loss=6.4367
   >> valid PPL=2206.533
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G2) Epoch 13: 100%|██████████| 7/7 [00:01<00:00,  6.66it/s]


   >> train_loss=6.2370
   >> valid PPL=2230.320


(G2) Epoch 14: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=6.1249
   >> valid PPL=2373.411


(G2) Epoch 15: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=6.0275
   >> valid PPL=2309.389


(G2) Epoch 16: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=5.9730
   >> valid PPL=2414.553


(G2) Epoch 17: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=5.9135
   >> valid PPL=2418.300


(G2) Epoch 18: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=5.8701
   >> valid PPL=2489.806


(G2) Epoch 19: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=5.8273
   >> valid PPL=2562.633


(G2) Epoch 20: 100%|██████████| 7/7 [00:01<00:00,  6.83it/s]


   >> train_loss=5.8132
   >> valid PPL=2637.154


(G2) Epoch 21: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=5.7475
   >> valid PPL=2782.748


(G2) Epoch 22: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=5.7536
   >> valid PPL=2797.806


(G2) Epoch 23: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=5.6785
   >> valid PPL=2833.065


(G2) Epoch 24: 100%|██████████| 7/7 [00:01<00:00,  6.85it/s]


   >> train_loss=5.6559
   >> valid PPL=2921.338


(G2) Epoch 25: 100%|██████████| 7/7 [00:01<00:00,  6.80it/s]


   >> train_loss=5.6326
   >> valid PPL=2974.896


(G2) Epoch 26: 100%|██████████| 7/7 [00:01<00:00,  6.82it/s]


   >> train_loss=5.5892
   >> valid PPL=2979.884


(G2) Epoch 27: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=5.5367
   >> valid PPL=3096.798


(G2) Epoch 28: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=5.4924
   >> valid PPL=3167.644


(G2) Epoch 29: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=5.4298
   >> valid PPL=3244.207


(G2) Epoch 30: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=5.4104
   >> valid PPL=3504.373


(G2) Epoch 31: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=5.3047
   >> valid PPL=3526.477


(G2) Epoch 32: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.2489
   >> valid PPL=3451.488


(G2) Epoch 33: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=5.2101
   >> valid PPL=3409.553


(G2) Epoch 34: 100%|██████████| 7/7 [00:01<00:00,  6.86it/s]


   >> train_loss=5.0890
   >> valid PPL=3555.966


(G2) Epoch 35: 100%|██████████| 7/7 [00:01<00:00,  6.86it/s]


   >> train_loss=4.9990
   >> valid PPL=3854.928


(G2) Epoch 36: 100%|██████████| 7/7 [00:01<00:00,  6.85it/s]


   >> train_loss=4.9840
   >> valid PPL=3898.487


(G2) Epoch 37: 100%|██████████| 7/7 [00:01<00:00,  6.89it/s]


   >> train_loss=4.8473
   >> valid PPL=3936.411


(G2) Epoch 38: 100%|██████████| 7/7 [00:01<00:00,  6.94it/s]


   >> train_loss=4.7402
   >> valid PPL=3857.565


(G2) Epoch 39: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=4.7118
   >> valid PPL=4600.261


(G2) Epoch 40: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=4.7072
   >> valid PPL=5120.342


(G2) Epoch 41: 100%|██████████| 7/7 [00:01<00:00,  6.85it/s]


   >> train_loss=4.5077
   >> valid PPL=5139.775


(G2) Epoch 42: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=4.3706
   >> valid PPL=4041.075


(G2) Epoch 43: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=4.1926
   >> valid PPL=4241.129


(G2) Epoch 44: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=4.0897
   >> valid PPL=4732.558


(G2) Epoch 45: 100%|██████████| 7/7 [00:01<00:00,  6.88it/s]


   >> train_loss=3.9419
   >> valid PPL=6137.851


(G2) Epoch 46: 100%|██████████| 7/7 [00:01<00:00,  6.90it/s]


   >> train_loss=3.8336
   >> valid PPL=4983.337


(G2) Epoch 47: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=3.6456
   >> valid PPL=4756.444


(G2) Epoch 48: 100%|██████████| 7/7 [00:01<00:00,  6.92it/s]


   >> train_loss=3.5486
   >> valid PPL=5617.544


(G2) Epoch 49: 100%|██████████| 7/7 [00:01<00:00,  6.93it/s]


   >> train_loss=3.3570
   >> valid PPL=5615.112


(G2) Epoch 50: 100%|██████████| 7/7 [00:01<00:00,  6.91it/s]


   >> train_loss=3.2073
   >> valid PPL=5539.415

=====  Curriculum Stage: Grade 3  =====


(G3) Epoch 1: 100%|██████████| 29/29 [00:04<00:00,  6.71it/s]


   >> train_loss=8.3685
   >> valid PPL=2480.342


(G3) Epoch 2: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=7.2899
   >> valid PPL=1703.435
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G3) Epoch 3: 100%|██████████| 29/29 [00:04<00:00,  6.85it/s]


   >> train_loss=6.6627
   >> valid PPL=1485.457
  ** Best model saved → /content/drive/MyDrive/CSEG321/models/expl_only_answer_curr_50.pt


(G3) Epoch 4: 100%|██████████| 29/29 [00:04<00:00,  6.85it/s]


   >> train_loss=6.2956
   >> valid PPL=1593.696


(G3) Epoch 5: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=6.1116
   >> valid PPL=1661.158


(G3) Epoch 6: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=5.9658
   >> valid PPL=1746.589


(G3) Epoch 7: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=5.8689
   >> valid PPL=1832.338


(G3) Epoch 8: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=5.7428
   >> valid PPL=1945.319


(G3) Epoch 9: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=5.6223
   >> valid PPL=2013.125


(G3) Epoch 10: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=5.4556
   >> valid PPL=2084.043


(G3) Epoch 11: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=5.2859
   >> valid PPL=2150.654


(G3) Epoch 12: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=5.0898
   >> valid PPL=2205.927


(G3) Epoch 13: 100%|██████████| 29/29 [00:04<00:00,  6.90it/s]


   >> train_loss=4.8784
   >> valid PPL=2273.280


(G3) Epoch 14: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=4.6667
   >> valid PPL=2434.461


(G3) Epoch 15: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=4.4734
   >> valid PPL=2565.312


(G3) Epoch 16: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=4.2160
   >> valid PPL=2597.599


(G3) Epoch 17: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=3.9704
   >> valid PPL=2725.817


(G3) Epoch 18: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=3.7425
   >> valid PPL=2899.217


(G3) Epoch 19: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=3.5012
   >> valid PPL=3095.402


(G3) Epoch 20: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.2721
   >> valid PPL=3225.176


(G3) Epoch 21: 100%|██████████| 29/29 [00:04<00:00,  6.88it/s]


   >> train_loss=3.0340
   >> valid PPL=3417.025


(G3) Epoch 22: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=2.8169
   >> valid PPL=3637.617


(G3) Epoch 23: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=2.6090
   >> valid PPL=4046.213


(G3) Epoch 24: 100%|██████████| 29/29 [00:04<00:00,  6.89it/s]


   >> train_loss=2.4028
   >> valid PPL=4206.552


(G3) Epoch 25: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=2.1904
   >> valid PPL=4225.218


(G3) Epoch 26: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=2.0030
   >> valid PPL=4766.559


(G3) Epoch 27: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=1.8242
   >> valid PPL=5372.209


(G3) Epoch 28: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=1.6459
   >> valid PPL=5508.155


(G3) Epoch 29: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=1.5129
   >> valid PPL=5601.458


(G3) Epoch 30: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=1.3647
   >> valid PPL=6178.275


(G3) Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=1.2542
   >> valid PPL=6192.265


(G3) Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=1.1141
   >> valid PPL=6568.828


(G3) Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=1.0030
   >> valid PPL=6936.989


(G3) Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=0.9100
   >> valid PPL=7298.500


(G3) Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.8475
   >> valid PPL=8061.563


(G3) Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.7797
   >> valid PPL=8277.941


(G3) Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.7072
   >> valid PPL=8649.886


(G3) Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=0.6386
   >> valid PPL=9214.008


(G3) Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=0.6034
   >> valid PPL=9079.639


(G3) Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.5421
   >> valid PPL=9726.868


(G3) Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.5112
   >> valid PPL=10579.643


(G3) Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.4795
   >> valid PPL=10746.008


(G3) Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=0.4480
   >> valid PPL=11017.642


(G3) Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=0.4175
   >> valid PPL=11634.153


(G3) Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.3957
   >> valid PPL=12163.804


(G3) Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


   >> train_loss=0.3816
   >> valid PPL=12030.992


(G3) Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.94it/s]


   >> train_loss=0.3660
   >> valid PPL=12937.415


(G3) Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.93it/s]


   >> train_loss=0.3512
   >> valid PPL=12619.583


(G3) Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.91it/s]


   >> train_loss=0.3332
   >> valid PPL=12994.642


(G3) Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  6.95it/s]


   >> train_loss=0.3170
   >> valid PPL=13259.361
