In [None]:
# cell 1 패키지 설치
!pip install -U pyarrow==15.0.2 fastparquet pandas torch optuna

In [None]:
# cell 2 경로/시드/하이퍼 파라미터
from pathlib import Path
import json, pandas as pd, numpy as np
import torch

# ==== Paths ====
# Output directory for every artefact written by this notebook; created on
# first run if it does not exist yet.
DATA_DIR = Path("/home/jovyan/datasets/next_state_pre")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# Input parquet file read by the data-loading cell.
RAW_PARQUET = Path("/home/jovyan/datasets/processed_user_behavior.sorted.parquet")  # adjust if needed
# JSON file holding PAD_ID, UNK_ID and the state -> index mapping.
MAP_JSON    = DATA_DIR / "state_mapping.json"
# JSON file holding the session-level train/validation split.
SPLIT_JSON  = DATA_DIR / "train_val_sessions.json"
# NumPy file holding the unique (prev2, prev1) index pairs.
PAIR_NPY    = DATA_DIR / "observed_prev_pairs.npy"
# Checkpoint of the best model found by the training loops.
BEST_PATH   = DATA_DIR / "best_model.pt"

# ==== Seed / split ====
SEED = 17
VAL_FRAC = 0.2
np.random.seed(SEED)
# Seed PyTorch too. Without this, weight initialisation, dropout masks and
# DataLoader shuffling differ on every run even though SEED is fixed.
torch.manual_seed(SEED)

# ==== Hyperparameters (manually chosen defaults) ====
HYPERPARAMS = {
    "emb_dim": 192,
    "pair_dim": 16,
    "hid": 384,
    "num_layers": 2,
    "dropout": 0.2,
    "lr": 3e-4,
    "weight_decay": 5e-4,
    "batch_size": 64,
    "epochs": 50,
    "patience": 3,
    "use_amp": True,           # enable AMP when running on CUDA
    "grad_clip_norm": 1.0,
}

# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("DATA_DIR:", DATA_DIR, "| device:", device)


In [None]:
# cell 3: load the raw data and verify that the required columns exist
df = pd.read_parquet(RAW_PARQUET)
need_cols = ["session_id", "current_state", "next_state"]

# Collect the required columns that the file does not provide, keeping the
# order of need_cols for the error message.
_present_cols = set(df.columns)
miss = [name for name in need_cols if name not in _present_cols]
assert not miss, f"필수 컬럼 없음: {miss}"

print(df.loc[:, need_cols].head(3))
print("[INFO] rows:", df.shape[0])


In [None]:
# cell 4: session-length statistics and removal of short sessions (length <= 2)
MIN_LEN = 3  # recommended: >= 3

# Number of rows per session and its summary statistics.
sess_len = df.groupby("session_id").size().rename("sess_len")
desc = sess_len.describe(percentiles=[0.25,0.5,0.75,0.9,0.95]).to_dict()

print("[INFO] session length stats:")
for k in ("count","mean","std","min","25%","50%","75%","90%","95%","max"):
    if k not in desc:
        continue
    print(f"  {k:>4}: {desc[k]}")

# Boolean flag per session: True when the session is shorter than MIN_LEN.
is_short = sess_len < MIN_LEN
short_ratio = is_short.mean()
print(f"[INFO] short sessions (<{MIN_LEN}) ratio: {short_ratio:.2%}  ({is_short.sum()} / {len(sess_len)})")

# Keep only the rows that belong to sessions of at least MIN_LEN rows.
valid_sessions = sess_len.loc[~is_short].index
n_before = len(df)
keep_rows = df["session_id"].isin(valid_sessions)
df = df.loc[keep_rows].reset_index(drop=True)
n_after = len(df)
print(f"[FILTER] kept sessions >= {MIN_LEN}: {len(valid_sessions)} / {len(sess_len)} rows: {n_before} -> {n_after} (-{n_before - n_after})")


In [None]:
# cell 5: build the state vocabulary (with PAD/UNK) and save it
PAD_ID, UNK_ID = 0, 1

# Every distinct state string that appears as a current or a next state.
_cur_states = df["current_state"].astype(str).unique()
_nxt_states = df["next_state"].astype(str).unique()
states_all = pd.Index(_cur_states).union(_nxt_states)
states_sorted = pd.Index(sorted(states_all))

# Real states are numbered from 2; ids 0 and 1 are reserved for PAD and UNK.
state2idx = dict(zip(states_sorted, range(2, 2 + len(states_sorted))))
idx2state = dict(zip(state2idx.values(), state2idx.keys()))
num_states = 2 + len(state2idx)  # includes PAD and UNK

_mapping_payload = {"PAD_ID": PAD_ID, "UNK_ID": UNK_ID, "state2idx": state2idx}
with open(MAP_JSON, "w", encoding="utf-8") as f:
    f.write(json.dumps(_mapping_payload, ensure_ascii=False, indent=2))

print("[SAVE] state mapping ->", MAP_JSON)
print("[INFO] num_states (PAD/UNK 포함):", num_states)


In [None]:
# cell 6: apply the index mapping
# Reload the mapping from disk so this cell works from the saved artefact.
with open(MAP_JSON, "r", encoding="utf-8") as f:
    mp = json.loads(f.read())
state2idx, PAD_ID, UNK_ID = mp["state2idx"], mp["PAD_ID"], mp["UNK_ID"]

def map_state(s):
    """Return the vocabulary index of state ``s``, or ``UNK_ID`` if unknown.

    ``s`` is converted with ``str`` before the lookup in ``state2idx``.
    """
    key = str(s)
    if key in state2idx:
        return state2idx[key]
    return UNK_ID

# Encode both state columns as int32 vocabulary indices.
for _src_col, _dst_col in (("current_state", "current_state_idx"),
                           ("next_state", "next_state_idx")):
    df[_dst_col] = df[_src_col].astype(str).map(map_state).astype("int32")
print(df[["current_state","current_state_idx","next_state","next_state_idx"]].head(5))


In [None]:
# cell 7: derive prev1/prev2 features and save the observed (prev2, prev1) pairs
g = df.groupby("session_id")["current_state_idx"]
# Shift within each session; positions with no predecessor are filled
# with PAD_ID.
for _lag, _col in ((1, "prev1_idx"), (2, "prev2_idx")):
    df[_col] = g.shift(_lag).fillna(PAD_ID).astype("int32")

# Unique (prev2, prev1) combinations seen in the data.
_unique_pairs = df[["prev2_idx","prev1_idx"]].drop_duplicates()
pairs = _unique_pairs.to_numpy(dtype=np.int32)
np.save(PAIR_NPY, pairs)
print("[SAVE] observed prev pairs ->", PAIR_NPY, pairs.shape)


In [None]:
# cell 8: session-level train/validation split
# copy=True guarantees a writable array for the in-place shuffle.
sess = df["session_id"].drop_duplicates().to_numpy(copy=True)
# Shuffle with a dedicated generator seeded from SEED so that re-running this
# cell always yields the same split, regardless of earlier global RNG use.
np.random.RandomState(SEED).shuffle(sess)
n_val = int(len(sess) * VAL_FRAC)

# .tolist() converts NumPy scalars (e.g. np.int64) to native Python values;
# json.dump cannot serialise NumPy scalar types.
_val_list = sess[:n_val].tolist()
_train_list = sess[n_val:].tolist()
val_sessions = set(_val_list); train_sessions = set(_train_list)

# The lists are written in shuffled order, which is deterministic, unlike
# the iteration order of a set.
with open(SPLIT_JSON, "w", encoding="utf-8") as f:
    json.dump({"seed": SEED, "val_frac": VAL_FRAC,
               "train_sessions": _train_list,
               "val_sessions": _val_list}, f, indent=2)
print("[SAVE] split ->", SPLIT_JSON, "| #train:", len(train_sessions), " #val:", len(val_sessions))


In [None]:
# cell 9: save the tabular features
TAB_FEATS = ["current_state_idx","prev1_idx","prev2_idx"]
LABEL = "next_state_idx"

# Row masks for the two session groups, computed once and reused.
_in_train = df["session_id"].isin(train_sessions)
_in_val = df["session_id"].isin(val_sessions)

X_train = df.loc[_in_train, TAB_FEATS].reset_index(drop=True)
y_train = df.loc[_in_train, LABEL].reset_index(drop=True)
X_val   = df.loc[_in_val, TAB_FEATS].reset_index(drop=True)
y_val   = df.loc[_in_val, LABEL].reset_index(drop=True)

# int32 frames that are written to disk; labels go into a single "y" column.
_X_train_out = X_train.astype("int32")
_y_train_out = y_train.to_frame(name="y").astype("int32")
_X_val_out = X_val.astype("int32")
_y_val_out = y_val.to_frame(name="y").astype("int32")

# Parquet
for _frame, _fname in ((_X_train_out, "X_train.parquet"),
                       (_y_train_out, "y_train.parquet"),
                       (_X_val_out, "X_val.parquet"),
                       (_y_val_out, "y_val.parquet")):
    _frame.to_parquet(DATA_DIR/_fname, index=False)

# CSV (optional)
for _frame, _fname in ((_X_train_out, "X_train.csv"),
                       (_y_train_out, "y_train.csv"),
                       (_X_val_out, "X_val.csv"),
                       (_y_val_out, "y_val.csv")):
    _frame.to_csv(DATA_DIR/_fname, index=False)

print("[SAVE] parquet & csv ->", DATA_DIR)
print("  X_train:", X_train.shape, " | y_train:", y_train.shape)
print("  X_val  :", X_val.shape,   " | y_val  :", y_val.shape)


In [None]:
# cell 10 시퀀스 빌드
def build_session_sequences(frame):
    """Turn a row-per-step frame into per-session index sequences.

    Args:
        frame: DataFrame with the columns ``session_id``,
            ``current_state_idx``, ``next_state_idx``, ``prev1_idx`` and
            ``prev2_idx``. Rows of one session must already be in step
            order; that order is preserved.

    Returns:
        A tuple ``(seqs_in, seqs_prevpair, seqs_y)`` of equally long lists,
        one entry per session with at least two rows, ordered by session id.
        For a session of ``T`` rows the entries are the first ``T-1``
        current-state indices, a ``[T-1, 2]`` array of ``(prev2, prev1)``
        indices, and the first ``T-1`` next-state indices. All arrays are
        ``int32``.
    """
    seqs_in, seqs_prevpair, seqs_y = [], [], []
    # groupby(sort=True) already visits sessions in key order and keeps the
    # original row order inside every group. The earlier explicit
    # sort_values(["session_id"]) used pandas' default non-stable sort, which
    # may reorder the rows of a session and corrupt the sequences.
    for _, sub in frame.groupby("session_id", sort=True):
        if len(sub) < 2:
            continue
        x = sub["current_state_idx"].to_numpy(dtype=np.int32)
        y = sub["next_state_idx"].to_numpy(dtype=np.int32)
        p1 = sub["prev1_idx"].to_numpy(dtype=np.int32)
        p2 = sub["prev2_idx"].to_numpy(dtype=np.int32)
        # The final row of every session is dropped from inputs and targets.
        seqs_in.append(x[:-1])
        seqs_y.append(y[:-1])
        seqs_prevpair.append(np.stack([p2[:-1], p1[:-1]], axis=1))  # [T-1,2]
    return seqs_in, seqs_prevpair, seqs_y

# Build the sequences separately for the training and validation sessions.
_train_frame = df[df["session_id"].isin(train_sessions)]
_val_frame = df[df["session_id"].isin(val_sessions)]
train_in, train_prevpair, train_y = build_session_sequences(_train_frame)
val_in,   val_prevpair,   val_y   = build_session_sequences(_val_frame)

print("[INFO] #train seq:", len(train_in), " | #val seq:", len(val_in))
# Lengths of (up to) the first three training sequences.
print("[INFO] sample lens:", [len(seq) for seq in train_in[:3]])


In [None]:
# cell 11.1 Dataset/DataLoader & 배치 확인
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch

class NextStateSeqDataset(Dataset):
    """Dataset over parallel per-session sequences.

    Item ``i`` is the tuple ``(xs[i], pairs[i], ys[i])``; the dataset length
    is ``len(xs)``.
    """

    def __init__(self, xs, pairs, ys):
        self.xs, self.pairs, self.ys = xs, pairs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        sample = (self.xs[i], self.pairs[i], self.ys[i])
        return sample

def pad_collate(batch, pad_id=PAD_ID):
    """Collate variable-length samples into right-padded tensors.

    Args:
        batch: sequence of ``(x, pair, y)`` samples, where ``x`` and ``y``
            have length ``l`` and ``pair`` has shape ``[l, 2]``.
        pad_id: fill value for every padded position.

    Returns:
        ``(x_pad, pair_pad, y_pad, mask)`` as tensors of shape ``[B, L]``,
        ``[B, L, 2]``, ``[B, L]`` and ``[B, L]``, with ``L`` the longest
        ``x`` in the batch. The first three are int64; ``mask`` is boolean
        and True on real (non-padded) positions.
    """
    xs, ps, ys = zip(*batch)
    n_seq = len(xs)
    max_len = max(map(len, xs))

    x_pad = np.full((n_seq, max_len), pad_id, dtype=np.int64)
    y_pad = np.full((n_seq, max_len), pad_id, dtype=np.int64)
    mask = np.zeros((n_seq, max_len), dtype=np.bool_)
    pair_pad = np.full((n_seq, max_len, 2), pad_id, dtype=np.int64)

    for row, (x, p, y) in enumerate(zip(xs, ps, ys)):
        seq_len = len(x)
        x_pad[row, :seq_len] = x
        y_pad[row, :seq_len] = y
        mask[row, :seq_len] = True
        pair_pad[row, :seq_len, :] = p

    return tuple(torch.from_numpy(arr) for arr in (x_pad, pair_pad, y_pad, mask))

train_ds = NextStateSeqDataset(train_in, train_prevpair, train_y)
val_ds   = NextStateSeqDataset(val_in,   val_prevpair,   val_y)

# Both loaders share batch size and collate function; only the training
# loader shuffles.
_loader_kwargs = {"batch_size": HYPERPARAMS["batch_size"], "collate_fn": pad_collate}
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader   = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

NUM_CLASSES = num_states  # important: identical to the vocabulary size
# Pull one batch to check the tensor shapes.
_first_batch = next(iter(train_loader))
xb, pb, yb, mb = _first_batch
print("[INFO] batch shapes:", xb.shape, pb.shape, yb.shape, mb.shape)


In [None]:
# cell 11.2 prev1/prev2 csv로 저장
# === CSV로 prev2/prev1 확인 저장 ===
import pandas as pd

# Convert the batch tensors to NumPy. .cpu() makes this work even after a
# later cell has moved the batch tensors to the GPU.
pb_np = pb.cpu().numpy()    # (B, L, 2)
yb_np = yb.cpu().numpy()    # (B, L)
mask_np = mb.cpu().numpy()  # (B, L)

# Coordinates of the valid (non-padded) positions, in row-major order:
# sample index first, then step index.
_sample_ids, _steps = np.nonzero(mask_np)

# NOTE: this frame deliberately gets its own name. Assigning it to `df`
# would overwrite the main DataFrame that the earlier cells depend on.
pairs_sample_df = pd.DataFrame({
    "sample_id": _sample_ids,
    "step": _steps,
    "prev2": pb_np[_sample_ids, _steps, 0],
    "prev1": pb_np[_sample_ids, _steps, 1],
    "target": yb_np[_sample_ids, _steps],
})
# Same location as before, now derived from DATA_DIR instead of hard-coded.
out_path = DATA_DIR / "pre_pairs_sample.csv"
pairs_sample_df.to_csv(out_path, index=False, encoding="utf-8")
print(f"[SAVE] prev2/prev1/target 샘플 -> {out_path} (rows={len(pairs_sample_df)})")


In [None]:
# cell 12 간단 LSTMClassifier 구현
import torch.nn as nn
import torch

class LSTMClassifier(nn.Module):
    """Unidirectional LSTM that emits per-step logits over the vocabulary.

    Each step's input is the embedding of the current state concatenated
    with a linear projection of the (prev2, prev1) index pair cast to float.

    Args:
        vocab_size: size of the embedding table and of the output layer.
        emb_dim: embedding width.
        pair_dim: output width of the pair projection.
        hid: LSTM hidden size.
        num_layers: number of stacked LSTM layers.
        dropout: dropout on the LSTM output; also used between LSTM layers
            when ``num_layers > 1``.
    """

    def __init__(self, vocab_size, emb_dim=192, pair_dim=16, hid=384, num_layers=2, dropout=0.2):
        super().__init__()
        self.num_classes = vocab_size
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD_ID)
        self.pair_proj = nn.Linear(2, pair_dim)
        # Inter-layer dropout is only requested for stacked LSTMs.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=emb_dim + pair_dim,
            hidden_size=hid,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
            bidirectional=False,
        )
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(hid, vocab_size)

    def forward(self, x, pair, mask):
        """Return logits of shape [B,T,C]; ``mask`` is accepted but unused."""
        token_feat = self.emb(x)                      # [B,T,E]
        pair_feat = self.pair_proj(pair.float())      # [B,T,P]
        lstm_in = torch.cat([token_feat, pair_feat], dim=-1)  # [B,T,E+P]
        hidden, _ = self.lstm(lstm_in)                # [B,T,H]
        return self.head(self.drop(hidden))           # [B,T,C]


In [None]:
# cell 13 클래스 가중치 계산
from collections import Counter
import numpy as np
import torch

# y_train was built for the tabular export, but it serves equally well for
# inspecting the label distribution.
label_counts = Counter(y_train.tolist())
total = sum(label_counts.values())
_n_seen = len(label_counts)

# Inverse-frequency weights; classes absent from y_train keep weight 1.
weights = np.ones(NUM_CLASSES, dtype=np.float32)
for _label, _count in label_counts.items():
    weights[_label] = total / (_n_seen * _count)

class_weights = torch.tensor(weights, device=device)
_preview = {k: float(class_weights[k]) for k in list(label_counts)[:8]}
print("[INFO] class_weights sample:", _preview)


In [None]:
# cell 14 디바이스/모델/옵티마이저 준비 & 디버그 포워드
from torch.optim import AdamW
from torch.amp import autocast, GradScaler

# Build the model from the architecture entries of HYPERPARAMS.
_model_kwargs = {name: HYPERPARAMS[name]
                 for name in ("emb_dim", "pair_dim", "hid", "num_layers", "dropout")}
model = LSTMClassifier(vocab_size=NUM_CLASSES, **_model_kwargs).to(device)

optimizer = AdamW(model.parameters(), lr=HYPERPARAMS["lr"], weight_decay=HYPERPARAMS["weight_decay"])
# The plateau scheduler is driven by a metric that should increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)
_amp_enabled = (device=="cuda" and HYPERPARAMS["use_amp"])
scaler    = GradScaler("cuda", enabled=_amp_enabled)

# Debug forward pass on a single batch
xb, pb, yb, mb = (t.to(device) for t in next(iter(train_loader)))
model.eval()
with torch.no_grad():
    logits = model(xb, pb, mb)
print("[DEBUG] logits shape:", tuple(logits.shape))


In [None]:
# cell 15 학습/평가 함수
import torch
import torch.nn as nn
from sklearn.metrics import f1_score

def train_one_epoch(model, loader, optimizer, device, num_classes, class_weights=None, use_amp=True, grad_clip_norm=1.0, scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: module called as ``model(x, pair, mask)`` returning logits
            whose last dimension is ``num_classes``.
        loader: iterable of ``(x, pair, y, mask)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device the batch tensors are moved to.
        num_classes: size of the logits' last dimension.
        class_weights: optional per-class weights for the cross-entropy loss.
        use_amp: use CUDA autocast and the gradient scaler; only takes
            effect when ``scaler`` is also given.
        grad_clip_norm: max gradient norm; a falsy value disables clipping.
        scaler: gradient scaler used on the AMP path.

    Returns:
        Dict with ``"loss"`` (mean of the per-batch losses, matching how
        ``eval_epoch`` reports its loss), ``"acc"`` (top-1 accuracy over
        non-padded positions) and ``"last_loss"`` (loss of the final batch).
        Previously ``"loss"`` held only the final batch's loss, which made
        the logged train loss noisy and not comparable to the validation
        loss.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID, weight=class_weights)
    correct = 0; total = 0
    sum_loss = 0.0; n_batches = 0; last_loss = 0.0
    amp_on = use_amp and scaler is not None

    for x, pair, y, mask in loader:
        x, pair, y, mask = x.to(device), pair.to(device), y.to(device), mask.to(device)

        optimizer.zero_grad(set_to_none=True)
        if amp_on:
            with torch.amp.autocast(device_type="cuda"):
                logits = model(x, pair, mask)
                loss = loss_fn(logits.view(-1, num_classes), y.view(-1))
            scaler.scale(loss).backward()
            if grad_clip_norm:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x, pair, mask)
            loss = loss_fn(logits.view(-1, num_classes), y.view(-1))
            loss.backward()
            if grad_clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)
            optimizer.step()

        last_loss = float(loss.item())
        sum_loss += last_loss
        n_batches += 1
        with torch.no_grad():
            pred  = logits.argmax(-1)
            # Count only real positions whose target is not the PAD id.
            valid = (mask & (y != PAD_ID))
            correct += (pred[valid] == y[valid]).sum().item()
            total   += valid.sum().item()

    return {
        "loss": sum_loss / max(n_batches, 1),
        "acc": correct / max(total, 1),
        "last_loss": last_loss,
    }

@torch.no_grad()
def eval_epoch(model, loader, device, num_classes, class_weights=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: module called as ``model(x, pair, mask)`` returning logits
            whose last dimension is ``num_classes``.
        loader: iterable of ``(x, pair, y, mask)`` batches.
        device: device the batch tensors are moved to.
        num_classes: size of the logits' last dimension.
        class_weights: optional per-class weights for the cross-entropy loss.

    Returns:
        Dict with ``"loss"`` (mean of the per-batch losses), ``"top1"`` and
        ``"top3"`` (accuracies over non-padded positions) and ``"macro_f1"``
        (macro-averaged F1 of the top-1 predictions, 0.0 when there are no
        valid positions).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, weight=class_weights)
    n_valid, n_top1, n_top3 = 0, 0, 0
    loss_total, batch_count = 0.0, 0
    preds_flat, trues_flat = [], []

    for batch in loader:
        x, pair, y, mask = (t.to(device) for t in batch)
        logits = model(x, pair, mask)
        batch_loss = criterion(logits.view(-1, num_classes), y.view(-1))
        loss_total += float(batch_loss.item())
        batch_count += 1

        prob = logits.softmax(-1)
        pred = prob.argmax(-1)
        # Score only real positions whose target is not the PAD id.
        valid = mask & (y != PAD_ID)
        y_valid = y[valid]
        pred_valid = pred[valid]
        n_top1 += (pred_valid == y_valid).sum().item()
        n_valid += valid.sum().item()

        # A hit when the target is among the three most probable classes.
        _, top3_idx = prob.topk(3, dim=-1)
        hits = top3_idx[valid].eq(y_valid.unsqueeze(-1)).any(dim=-1)
        n_top3 += hits.sum().item()

        preds_flat.extend(pred_valid.detach().cpu().tolist())
        trues_flat.extend(y_valid.detach().cpu().tolist())

    if trues_flat:
        macro_f1 = f1_score(trues_flat, preds_flat, average="macro")
    else:
        macro_f1 = 0.0

    return {
        "loss": loss_total / max(batch_count, 1),
        "top1": n_top1 / max(n_valid, 1),
        "top3": n_top3 / max(n_valid, 1),
        "macro_f1": macro_f1,
    }


In [None]:
# cell 16: training loop with early stopping and best-checkpoint saving
best, patience, bad = None, HYPERPARAMS["patience"], 0

for epoch in range(HYPERPARAMS["epochs"]):
    tr = train_one_epoch(
        model, train_loader, optimizer, device,
        num_classes=NUM_CLASSES,
        class_weights=class_weights,
        use_amp=(device=="cuda" and HYPERPARAMS["use_amp"]),
        grad_clip_norm=HYPERPARAMS["grad_clip_norm"],
        scaler=scaler,
    )
    ev = eval_epoch(
        model, val_loader, device,
        num_classes=NUM_CLASSES,
        class_weights=class_weights,
    )

    # Plateau scheduler, maximised metric: validation top-1
    scheduler.step(ev["top1"])

    print(f"[E{epoch:02d}] "
          f"train loss={tr['loss']:.4f} acc={tr['acc']:.4f} | "
          f"val loss={ev['loss']:.4f} top1={ev['top1']:.4f} "
          f"top3={ev['top3']:.4f} macroF1={ev['macro_f1']:.4f}")

    improved = best is None or ev["top1"] > best["top1"]
    if not improved:
        # No improvement: count the epoch and stop once patience runs out.
        bad += 1
        if bad >= patience:
            print("Early stopping at epoch", epoch)
            break
        continue

    # New best validation top-1: reset the counter and save a checkpoint.
    best, bad = {"epoch": epoch, **ev}, 0
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "metrics": ev,
        "num_classes": NUM_CLASSES,
        "pad_id": PAD_ID,
        "hparams": HYPERPARAMS,
    }
    torch.save(checkpoint, str(BEST_PATH))
    print(f"[SAVE] best -> {BEST_PATH}  (epoch={epoch}, top1={ev['top1']:.4f})")

print("[BEST]", best)


In [None]:
# cell 17 Optuna 하이퍼 파라미터 탐색
import optuna

def objective(trial):
    """Optuna objective: validation top-1 after three training epochs.

    Samples architecture and optimizer hyperparameters from ``trial``,
    trains a fresh ``LSTMClassifier`` for three epochs on ``train_loader``
    and returns the top-1 accuracy measured on ``val_loader``.
    """
    # Search space (the suggest calls are made in this order)
    arch = {
        "emb_dim": trial.suggest_int("emb_dim", 96, 256, step=32),
        "pair_dim": trial.suggest_int("pair_dim", 8, 32, step=8),
        "hid": trial.suggest_int("hid", 128, 512, step=64),
        "num_layers": trial.suggest_int("num_layers", 1, 3),
        "dropout": trial.suggest_float("dropout", 0.1, 0.5),
    }
    lr = trial.suggest_float("lr", 1e-4, 3e-3, log=True)
    wd = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)

    m = LSTMClassifier(vocab_size=NUM_CLASSES, **arch).to(device)
    opt = AdamW(m.parameters(), lr=lr, weight_decay=wd)

    amp_enabled = (device=="cuda" and HYPERPARAMS["use_amp"])
    scaler_tmp = GradScaler("cuda", enabled=amp_enabled)

    # Short warm-up training
    for _ in range(3):
        train_one_epoch(
            m, train_loader, opt, device, NUM_CLASSES,
            class_weights=class_weights,
            use_amp=(device=="cuda" and HYPERPARAMS["use_amp"]),
            grad_clip_norm=HYPERPARAMS["grad_clip_norm"],
            scaler=scaler_tmp,
        )
    metrics = eval_epoch(m, val_loader, device, NUM_CLASSES, class_weights=class_weights)
    return metrics["top1"]

# Seed the sampler from SEED so the sequence of suggested hyperparameters is
# repeatable across runs; an unseeded sampler draws different trials each time.
study = optuna.create_study(
    direction="maximize",
    study_name="next_state_lstm_top1",
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=15, show_progress_bar=True)

print("[BEST VALUE]", study.best_value)
print("[BEST PARAMS]")
for k, v in study.best_trial.params.items():
    print(f"  {k}: {v}")


In [None]:
# cell 18 최종 모델 학습
# === Cell 18: Apply Optuna best -> update HYPERPARAMS -> retrain final model ===
import json, torch
from torch.optim import AdamW
from torch.amp import GradScaler

# ---- Safety checks (e.g. that Cell 17 was run first) ----
# Each entry: the global names that must all exist, and the assertion message.
_REQUIRED_GLOBALS = (
    (("study",), "Optuna study가 없습니다. 먼저 Cell 17(Optuna)을 실행하세요."),
    (("NUM_CLASSES",), "NUM_CLASSES가 필요합니다."),
    (("PAD_ID",), "PAD_ID가 필요합니다."),
    (("HYPERPARAMS",), "HYPERPARAMS 딕셔너리가 필요합니다."),
    (("train_loader", "val_loader"), "train/val DataLoader가 필요합니다."),
    (("train_one_epoch", "eval_epoch"), "train_one_epoch / eval_epoch 함수가 필요합니다."),
    (("LSTMClassifier",), "LSTMClassifier 클래스가 필요합니다."),
)
_defined = globals()
for _names, _message in _REQUIRED_GLOBALS:
    assert all(_name in _defined for _name in _names), _message

# Optional globals: fall back to a default when an earlier cell did not set them.
_default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = _defined.get('device', _default_device)
BEST_PATH = _defined.get('BEST_PATH', './best_model.pt')
class_weights = _defined.get('class_weights', None)

# ---- 1) Copy the Optuna best trial's values into HYPERPARAMS ----
best_params = study.best_trial.params
_TUNED_KEYS = ("emb_dim", "pair_dim", "hid", "num_layers", "dropout", "lr", "weight_decay")
HYPERPARAMS.update({key: best_params[key] for key in _TUNED_KEYS})
print("[INFO] HYPERPARAMS updated with Optuna best:", HYPERPARAMS)

# (Optional) keep the best parameters as JSON next to the checkpoint.
# Best effort: a failure here only prints a warning.
try:
    from pathlib import Path
    BEST_JSON = Path(BEST_PATH).parent / "best_hparams.json"
    _hparam_payload = {"optuna_best": best_params, "hparams": HYPERPARAMS}
    with open(BEST_JSON, "w") as f:
        json.dump(_hparam_payload, f, indent=2)
    print(f"[SAVE] best_hparams -> {BEST_JSON}")
except Exception as e:
    print("[WARN] best_hparams 저장 생략:", e)

# ---- 2) Build model / optimizer / scheduler from the updated HYPERPARAMS ----
# AMP is only ever enabled on CUDA.
USE_AMP   = bool(HYPERPARAMS.get("use_amp", True)) if device == "cuda" else False
EPOCHS    = int(HYPERPARAMS.get("epochs", 50))
PATIENCE  = int(HYPERPARAMS.get("patience", 3))
GRAD_CLIP = float(HYPERPARAMS.get("grad_clip_norm", 1.0))

_final_model_kwargs = {name: HYPERPARAMS[name]
                       for name in ("emb_dim", "pair_dim", "hid", "num_layers", "dropout")}
model = LSTMClassifier(vocab_size=NUM_CLASSES, **_final_model_kwargs).to(device)

optimizer = AdamW(
    model.parameters(),
    lr=HYPERPARAMS["lr"],
    weight_decay=HYPERPARAMS["weight_decay"],
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)
scaler    = GradScaler("cuda", enabled=USE_AMP)

# ---- 3) Main training loop (early stopping & checkpoint saving) ----
best, bad = None, 0
for epoch in range(EPOCHS):
    tr = train_one_epoch(
        model, train_loader, optimizer, device,
        num_classes=NUM_CLASSES,
        class_weights=class_weights,
        use_amp=USE_AMP,
        grad_clip_norm=GRAD_CLIP,
        scaler=scaler,
    )
    ev = eval_epoch(
        model, val_loader, device,
        num_classes=NUM_CLASSES,
        class_weights=class_weights,
    )

    # Plateau scheduler, maximised metric: validation top-1
    scheduler.step(ev["top1"])

    print(f"[E{epoch:02d}] train loss={tr['loss']:.4f} acc={tr['acc']:.4f} | "
          f"val loss={ev['loss']:.4f} top1={ev['top1']:.4f} top3={ev['top3']:.4f} macroF1={ev['macro_f1']:.4f}")

    improved = best is None or ev["top1"] > best["top1"]
    if not improved:
        # No improvement: count the epoch and stop once PATIENCE is reached.
        bad += 1
        if bad >= PATIENCE:
            print("Early stopping at epoch", epoch)
            break
        continue

    # New best validation top-1: reset the counter and save a checkpoint.
    best, bad = {"epoch": epoch, **ev}, 0
    checkpoint = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "metrics": ev,
        "num_classes": NUM_CLASSES,
        "pad_id": PAD_ID,
        "hparams": HYPERPARAMS,
    }
    torch.save(checkpoint, str(BEST_PATH))
    print(f"[SAVE] best -> {BEST_PATH}  (epoch={epoch}, top1={ev['top1']:.4f})")

print("\n[BEST FINAL]", best)