# SETUP 및 CIFAR-10 DATASET 준비

In [None]:
#0. Setup

import os, time, random
import numpy as np
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then sets the cuDNN ``deterministic`` flag
    and turns the cuDNN ``benchmark`` flag off.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Applied in a fixed order: Python, NumPy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed the Python, NumPy and PyTorch RNGs before anything random runs.
seed_everything(42)
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Bare expression: displayed as the notebook cell's output.
device


In [None]:
#1. CIFAR-10 load & split

# Directory handed to the torchvision CIFAR-10 datasets as `root`.
DATA_ROOT = "./data"

# Per-channel mean / std passed to transforms.Normalize below
# (presumably the CIFAR-10 training-set statistics in RGB order — TODO confirm).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD  = (0.2470, 0.2435, 0.2616)


# Preprocessing without augmentation: convert to a tensor, then normalize.
tf_noaug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# The test pipeline is the very same object as the no-augmentation pipeline.
tf_test = tf_noaug

def make_splits(n, val_ratio=0.1, seed=42):
    """Randomly partition ``range(n)`` into train and validation index lists.

    The indices are shuffled with a NumPy ``Generator`` seeded by ``seed``.
    The first ``int(n * val_ratio)`` shuffled indices become the validation
    split and the remaining ones the training split.

    Args:
        n: Number of samples to split.
        val_ratio: Fraction of samples assigned to validation, within ``[0, 1]``.
        seed: Seed of the private generator, so the same ``(n, val_ratio, seed)``
            always produces the same split.

    Returns:
        ``(train_idx, val_idx)``: two disjoint lists of ints that together
        cover ``range(n)``.

    Raises:
        ValueError: If ``val_ratio`` lies outside ``[0, 1]``.
    """
    # Reject out-of-range ratios explicitly: a negative ratio would become a
    # negative slice bound and a ratio above 1 would leave the training split
    # empty, both silently.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio}")

    rng = np.random.default_rng(seed)
    idx = np.arange(n)
    rng.shuffle(idx)
    val_size = int(n * val_ratio)
    val_idx = idx[:val_size].tolist()
    train_idx = idx[val_size:].tolist()
    return train_idx, val_idx

class CIFARSplit(Dataset):
    """View over a subset of a base dataset with its own transform.

    Item ``i`` is read from ``base_dataset`` at position ``indices[i]``. The
    image goes through ``transform`` (when one is given) and is returned
    together with the unchanged label.
    """

    def __init__(self, base_dataset, indices, transform):
        self.base = base_dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        # The split is exactly as long as its index list.
        return len(self.indices)

    def __getitem__(self, i):
        image, label = self.base[self.indices[i]]
        if self.transform is None:
            return image, label
        return self.transform(image), label

# Raw CIFAR-10 train/test datasets with no transform attached; each CIFARSplit
# wrapper below applies its own transform instead. download=True is passed so
# torchvision may fetch the data into DATA_ROOT.
base_train_raw = datasets.CIFAR10(root=DATA_ROOT, train=True,  download=True, transform=None)
base_test_raw  = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=None)

# Hold out 10% of the training set as a validation split (fixed seed).
train_idx, val_idx = make_splits(len(base_train_raw), val_ratio=0.1, seed=42)

# Train and validation share base_train_raw but use different index lists;
# the test split wraps every sample of base_test_raw.
train_full = CIFARSplit(base_train_raw, train_idx, transform=tf_noaug)
val_set    = CIFARSplit(base_train_raw, val_idx,   transform=tf_noaug)
test_set   = CIFARSplit(base_test_raw,  list(range(len(base_test_raw))), transform=tf_test)

print("full train:", len(train_full), "val:", len(val_set), "test:", len(test_set))


# ALEXNET IMPLEMENTATION & OVERFITTING

In [None]:
#2. AlexNet style model implementation

class AlexNetCIFAR(nn.Module):
    """AlexNet-style CNN built from 3x3 convolutions and 2x2 max-pooling.

    ``features`` holds five conv+ReLU stages interleaved with three pooling
    layers. ``classifier`` is a three-layer MLP whose first two linear layers
    are each preceded by dropout; its first linear layer expects
    ``256 * 4 * 4`` input features.

    Args:
        num_classes: Size of the output logit vector.
        dropout_p: Drop probability used by both dropout layers.
    """

    def __init__(self, num_classes=10, dropout_p=0.0):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # 3x3 convolution (stride 1, padding 1) followed by an in-place ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]

        def pool():
            return nn.MaxPool2d(kernel_size=2, stride=2)

        # The layer order fixes the Sequential indices, which are also the
        # state_dict keys used by saved checkpoints; do not rearrange it.
        feature_layers = [
            *conv_relu(3, 64), pool(),
            *conv_relu(64, 192), pool(),
            *conv_relu(192, 384),
            *conv_relu(384, 256),
            *conv_relu(256, 256), pool(),
        ]
        self.features = nn.Sequential(*feature_layers)

        classifier_layers = [
            nn.Dropout(p=dropout_p),
            nn.Linear(256 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_p),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, x):
        # Flatten everything except the batch dimension before the MLP head.
        flat = self.features(x).flatten(start_dim=1)
        return self.classifier(flat)

def init_weights(m):
    """Kaiming-normal initializer for conv and linear layers.

    Conv2d weights are drawn with ``mode="fan_out"``; Linear weights use the
    initializer's default mode. Both use ``nonlinearity="relu"``. A bias, when
    present, is zeroed. Any other module type is left untouched.

    Args:
        m: A single module, as passed in by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Conv2d):
        weight_kwargs = {"mode": "fan_out", "nonlinearity": "relu"}
    elif isinstance(m, nn.Linear):
        weight_kwargs = {"nonlinearity": "relu"}
    else:
        return

    nn.init.kaiming_normal_(m.weight, **weight_kwargs)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Build a dropout-free model on the selected device and apply the custom init.
model = AlexNetCIFAR(dropout_p=0.0).to(device)
model.apply(init_weights)

# Smoke test: push one shuffled batch of 8 training images through the
# network and print the shape of the resulting logits.
x, y = next(iter(DataLoader(train_full, batch_size=8, shuffle=True)))
logits = model(x.to(device))
print("logits shape:", logits.shape)


In [None]:
#3. Overfitting

# Hyperparameters of the deliberate-overfitting run.
OVERFIT_SUBSET = 1000
BATCH_SIZE = 128
EPOCHS = 200
LR = 0.05
MOMENTUM = 0.9

# 1) Training subset: the first OVERFIT_SUBSET entries of train_full
#    (or all of it when train_full is smaller).
train_overfit = torch.utils.data.Subset(
    train_full,
    list(range(min(OVERFIT_SUBSET, len(train_full))))
)

# Training batches are shuffled; validation batches are not.
train_loader = DataLoader(train_overfit, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_set,       batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# 2) Model with no regularization (dropout probability 0); this rebinds the
#    global `model` to a freshly initialized network.
model = AlexNetCIFAR(dropout_p=0.0).to(device)
model.apply(init_weights)

# 3) Loss / optimizer: cross-entropy with SGD + momentum and zero weight decay.
crit = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=0.0)

def accuracy(logits, y):
    """Return the fraction of rows in ``logits`` whose argmax equals ``y``.

    Args:
        logits: Score tensor; predictions are taken with ``argmax(dim=1)``.
        y: Targets compared element-wise with the predictions.

    Returns:
        The mean of the match mask as a Python float.
    """
    predicted = torch.argmax(logits, dim=1)
    matches = (predicted == y)
    return matches.float().mean().item()

# Per-epoch metric curves of the overfitting run.
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}


# Output directory and checkpoint path of the overfit model.
OUT_OVER_DIR = "./outputs/overfit"
os.makedirs(OUT_OVER_DIR, exist_ok=True)
OVERFIT_CKPT_PATH = os.path.join(OUT_OVER_DIR, "overfit.pth")

# Early-stop settings. The training loop below stops at the first epoch whose
# train accuracy reaches TARGET_TRAIN_ACC. GAP_THRESHOLD only feeds the `hit`
# streak counter maintained in that loop; `hit` is never read to decide
# stopping and PATIENCE is not referenced anywhere else in this notebook.
TARGET_TRAIN_ACC = 0.99
GAP_THRESHOLD = 0.15
PATIENCE = 2
hit = 0

# Wall-clock seconds spent on each epoch of the overfitting run.
epoch_times_over = []

# Training loop of the deliberate-overfitting run. Every epoch trains on the
# small subset, evaluates on the validation split, appends both results to
# `history`, and stops as soon as train accuracy reaches TARGET_TRAIN_ACC.
# A checkpoint is always written for THIS run: either at the early stop or,
# through the loop's `else` clause, after the last epoch.
for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    # train
    model.train()
    tr_loss, tr_acc, n = 0.0, 0.0, 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = crit(logits, y)
        loss.backward()
        optimizer.step()

        # Weight the batch means by batch size so the epoch figures are
        # per-sample averages even when the last batch is smaller.
        bs = x.size(0)
        tr_loss += loss.item() * bs
        tr_acc  += accuracy(logits, y) * bs
        n += bs

    tr_loss /= n
    tr_acc  /= n

    # val
    model.eval()
    va_loss, va_acc, n = 0.0, 0.0, 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = crit(logits, y)

            bs = x.size(0)
            va_loss += loss.item() * bs
            va_acc  += accuracy(logits, y) * bs
            n += bs

    va_loss /= n
    va_acc  /= n

    history["train_loss"].append(tr_loss)
    history["train_acc"].append(tr_acc)
    history["val_loss"].append(va_loss)
    history["val_acc"].append(va_acc)

    # Log the first epoch, every 10th epoch and the last scheduled epoch.
    if epoch == 1 or epoch % 10 == 0 or epoch == EPOCHS:
        print(f"Epoch {epoch:03d}/{EPOCHS} | "
              f"train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f} | "
              f"{time.time()-t0:.1f}s")

    epoch_times_over.append(time.time() - t0)

    # early stop check
    # a. train accuracy has reached the target
    cond_train_fit = (tr_acc >= TARGET_TRAIN_ACC)

    # b. the train/val accuracy gap is at least GAP_THRESHOLD
    cond_gap = ((tr_acc - va_acc) >= GAP_THRESHOLD)

    # `hit` counts consecutive epochs satisfying both conditions. It is
    # bookkeeping only: stopping below depends on cond_train_fit alone.
    if cond_train_fit and cond_gap:
        hit += 1
    else:
        hit = 0

    # Stop immediately once the training subset is (almost) perfectly fit.
    if cond_train_fit:
        print(f"Early stop: train acc reached {tr_acc:.4f} at epoch {epoch}. Saving checkpoint...")
        torch.save({
            "model": model.state_dict(),
            "meta": {
                "subset": OVERFIT_SUBSET,
                "epoch": epoch,
                "train_acc": tr_acc,
                "val_acc": va_acc,
                "stop_reason": "train_acc_threshold"
            }
        }, OVERFIT_CKPT_PATH)
        break
else:
    # Reached only when the loop finished without `break`. Deciding this from
    # the loop outcome rather than from the file's existence matters: a
    # checkpoint left on disk by an earlier run would otherwise suppress this
    # save, and later cells would load stale weights.
    print("Early stop not triggered. Saving final checkpoint...")
    torch.save({
        "model": model.state_dict(),
        "meta": {
            "subset": OVERFIT_SUBSET,
            "epoch": EPOCHS,
            "train_acc": history["train_acc"][-1],
            "val_acc": history["val_acc"][-1]
        }
    }, OVERFIT_CKPT_PATH)

# Keep a dedicated reference to this run's curves for the later comparison cells.
history_overfit = history

print("Final train acc:", history["train_acc"][-1], "Final val acc:", history["val_acc"][-1])
print("Saved overfit checkpoint to:", OVERFIT_CKPT_PATH)


# Evaluate the model left in memory by the overfitting run on the test split.
test_loader_over = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

model.eval()
te_loss, te_acc, n = 0.0, 0.0, 0
with torch.no_grad():
    for x, y in test_loader_over:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = crit(logits, y)

        # Batch means are weighted by batch size to give per-sample averages.
        bs = x.size(0)
        te_loss += loss.item() * bs
        te_acc  += accuracy(logits, y) * bs
        n += bs

te_loss /= n
te_acc  /= n
print(f"[overfit] test loss = {te_loss:.4f} | test acc = {te_acc:.4f}")


# Print the accuracies recorded for the last epoch.
print("Final train acc:", history["train_acc"][-1], "Final val acc:", history["val_acc"][-1])

# One figure per metric, with the train and validation curves overlaid.
for _metric in ("loss", "acc"):
    plt.figure()
    for _split in ("train", "val"):
        _key = f"{_split}_{_metric}"
        plt.plot(history[_key], label=_key)
    plt.legend()
    plt.title(f"Overfit: {_metric}")
    plt.show()


# REGULARIZATION & DATA AUGMENTATION

In [None]:
#4. Regularization & Data Augmentation

# a. Augmentation transform: random crop and horizontal flip, followed by the
#    same tensor conversion and normalization as the no-augmentation pipeline.
tf_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # aug #1
    transforms.RandomHorizontalFlip(),      # aug #2
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# b. Rebuild the FULL training split with the augmentation transform.
train_full_aug = CIFARSplit(base_train_raw, train_idx, transform=tf_aug)

# Hyperparameters of the regularized run.
BATCH_SIZE_REG = 128
EPOCHS_REG = 100
LR_REG = 0.01
MOMENTUM_REG = 0.9

DROPOUT_P = 0.5        # regularization #1
WEIGHT_DECAY = 5e-4    # regularization #2

# Only the training loader sees augmented data; validation and test keep the
# non-augmented splits.
train_loader_reg = DataLoader(train_full_aug, batch_size=BATCH_SIZE_REG, shuffle=True,
                              num_workers=2, pin_memory=True)
val_loader_reg   = DataLoader(val_set,       batch_size=BATCH_SIZE_REG, shuffle=False,
                              num_workers=2, pin_memory=True)
test_loader      = DataLoader(test_set,      batch_size=BATCH_SIZE_REG, shuffle=False,
                              num_workers=2, pin_memory=True)

# 3) Build the model (dropout ON)
model_reg = AlexNetCIFAR(dropout_p=DROPOUT_P).to(device)
model_reg.apply(init_weights)

# Note: this rebinds the global `crit` and `optimizer` to the regularized run.
crit = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_reg.parameters(), lr=LR_REG, momentum=MOMENTUM_REG, weight_decay=WEIGHT_DECAY)

def accuracy(logits, y):
    """Compute top-1 accuracy of ``logits`` against the targets ``y``.

    The predicted class of each row is its ``argmax`` along dimension 1; the
    result is the mean of the element-wise match mask, as a Python float.
    """
    top1 = logits.argmax(dim=1)
    is_hit = (top1 == y).float()
    return is_hit.mean().item()

def eval_loop(model, loader):
    """Evaluate ``model`` over ``loader`` and return ``(mean_loss, mean_acc)``.

    Switches the model to eval mode and runs without gradient tracking. Uses
    the module-level ``device``, ``crit`` and ``accuracy``. Per-batch means are
    weighted by batch size, so both results are per-sample averages.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Tuple ``(average loss, average accuracy)`` over all samples seen.
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            batch_loss = crit(outputs, targets)

            batch_len = inputs.size(0)
            loss_sum += batch_loss.item() * batch_len
            acc_sum += accuracy(outputs, targets) * batch_len
            seen += batch_len
    return loss_sum / seen, acc_sum / seen

# Per-epoch metric curves of the regularized run.
history_reg = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
# Best validation accuracy so far; starts below any reachable accuracy so the
# first epoch always produces a checkpoint.
best_val_acc = -1.0

# Output directory and checkpoint path of the regularized model.
OUT_REG_DIR = "./outputs/regularized"
os.makedirs(OUT_REG_DIR, exist_ok=True)
REG_CKPT_PATH = os.path.join(OUT_REG_DIR, "regularized.pth")

# Wall-clock seconds spent on each epoch of the regularized run.
epoch_times_reg = []


# Training loop of the regularized run: train one epoch on the augmented
# loader, evaluate on validation, record the curves, and overwrite the
# checkpoint whenever validation accuracy improves.
for epoch in range(1, EPOCHS_REG + 1):
    t0 = time.time()

    # train
    model_reg.train()
    tr_loss, tr_acc, n = 0.0, 0.0, 0
    for x, y in train_loader_reg:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model_reg(x)
        loss = crit(logits, y)
        loss.backward()
        optimizer.step()

        # Weight batch means by batch size to obtain per-sample epoch averages.
        bs = x.size(0)
        tr_loss += loss.item() * bs
        tr_acc  += accuracy(logits, y) * bs
        n += bs

    tr_loss /= n
    tr_acc  /= n

    # val
    va_loss, va_acc = eval_loop(model_reg, val_loader_reg)

    history_reg["train_loss"].append(tr_loss)
    history_reg["train_acc"].append(tr_acc)
    history_reg["val_loss"].append(va_loss)
    history_reg["val_acc"].append(va_acc)

    # Save the best model so far (strict improvement in validation accuracy).
    if va_acc > best_val_acc:
        best_val_acc = va_acc
        torch.save({
            "model": model_reg.state_dict(),
            "meta": {
                "dropout_p": DROPOUT_P,
                "weight_decay": WEIGHT_DECAY,
                "aug": ["RandomCrop(32,pad=4)", "RandomHorizontalFlip"],
                "epoch": epoch,
                "best_val_acc": best_val_acc
            }
        }, REG_CKPT_PATH)

    # Log the first epoch, every 10th epoch and the last epoch.
    if epoch == 1 or epoch % 10 == 0 or epoch == EPOCHS_REG:
        print(f"[regularized] Epoch {epoch:03d}/{EPOCHS_REG} | "
              f"train {tr_loss:.4f}/{tr_acc:.4f} | val {va_loss:.4f}/{va_acc:.4f} | "
              f"{time.time()-t0:.1f}s")


    epoch_times_reg.append(time.time() - t0)


# Test evaluation: reload the best-validation checkpoint into model_reg first,
# so the test metrics come from those weights rather than the last epoch's.
ckpt = torch.load(REG_CKPT_PATH, map_location=device)
model_reg.load_state_dict(ckpt["model"])
test_loss_reg, test_acc_reg = eval_loop(model_reg, test_loader)

print(f"[regularized] BEST val acc = {best_val_acc:.4f} | test acc = {test_acc_reg:.4f}")

# Persist the curves and final metrics so the plotting cell can reload them.
with open(os.path.join(OUT_REG_DIR, "log.json"), "w") as f:
    json.dump({**history_reg, "best_val_acc": best_val_acc, "test_loss": test_loss_reg, "test_acc": test_acc_reg}, f, indent=2)


In [None]:
# regularized model 결과 시각화

import os, json
import matplotlib.pyplot as plt

# Fall back to the saved log when history_reg is missing, None or empty in the
# current session.
if "history_reg" not in globals() or history_reg is None or len(history_reg.get("train_loss", [])) == 0:
    log_path = os.path.join("./outputs/regularized", "log.json")
    assert os.path.exists(log_path), f"log.json not found: {log_path}"
    with open(log_path, "r") as f:
        history_reg = json.load(f)

# Side-by-side panels: loss on the left, accuracy on the right.
plt.figure(figsize=(10,4))

# loss
plt.subplot(1,2,1)
plt.plot(history_reg["train_loss"], label="train_loss")
plt.plot(history_reg["val_loss"],   label="val_loss")
plt.legend()
plt.title("Regularized: loss")
plt.xlabel("epoch")
plt.ylabel("loss")

# acc
plt.subplot(1,2,2)
plt.plot(history_reg["train_acc"], label="train_acc")
plt.plot(history_reg["val_acc"],   label="val_acc")
plt.legend()
plt.title("Regularized: acc")
plt.xlabel("epoch")
plt.ylabel("accuracy")

plt.tight_layout()
plt.show()


# GENERALIZATION ANALYSIS

In [None]:
#5-1. Quantitative comparison

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader


# Short aliases for the two runs' metric histories. Both names must already
# exist in the session; this cell has no fallback for them.
h_over = history_overfit
h_reg  = history_reg


# Checkpoint paths of the two trained models (same locations the training
# cells wrote to).
OVERFIT_CKPT_PATH = "./outputs/overfit/overfit.pth"
REG_CKPT_PATH     = "./outputs/regularized/regularized.pth"

@torch.no_grad()
def eval_acc(model, loader):
    """Return the accuracy of ``model`` over every sample yielded by ``loader``.

    Puts the model in eval mode, moves each batch to the module-level
    ``device`` and counts the predictions (``argmax`` along dimension 1) that
    equal the targets.

    Args:
        model: Network to evaluate.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Number of correct predictions divided by the number of targets.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predictions = model(inputs).argmax(dim=1)
        n_correct += (predictions == targets).sum().item()
        n_seen += targets.numel()
    return n_correct / n_seen

def load_model_from_ckpt(ckpt_path, dropout_p):
    """Build an ``AlexNetCIFAR`` and load its weights from ``ckpt_path``.

    Args:
        ckpt_path: File to read with ``torch.load`` (mapped to the module-level
            ``device``).
        dropout_p: Dropout probability the new model is constructed with.

    Returns:
        The model, on ``device``, with the loaded weights.
    """
    loaded = torch.load(ckpt_path, map_location=device)
    # A dict holding a "model" entry is unwrapped; anything else is treated
    # as a bare state_dict.
    if isinstance(loaded, dict) and "model" in loaded:
        weights = loaded["model"]
    else:
        weights = loaded
    net = AlexNetCIFAR(dropout_p=dropout_p).to(device)
    net.load_state_dict(weights)
    return net


# Rebuild both models from their checkpoints.
m_over = load_model_from_ckpt(OVERFIT_CKPT_PATH, dropout_p=0.0)
m_reg  = load_model_from_ckpt(REG_CKPT_PATH,     dropout_p=DROPOUT_P)  # same dropout setting as in training

# same evaluation sets
EVAL_BS = 256

# The overfit model's "train" set is the subset it was trained on.
train_eval_loader_over = DataLoader(train_overfit, batch_size=EVAL_BS, shuffle=False, num_workers=2, pin_memory=True)

# The regularized model is scored on the non-augmented full training split.
train_eval_loader = DataLoader(train_full, batch_size=EVAL_BS, shuffle=False, num_workers=2, pin_memory=True)

val_eval_loader   = DataLoader(val_set,    batch_size=EVAL_BS, shuffle=False, num_workers=2, pin_memory=True)
test_eval_loader  = DataLoader(test_set,   batch_size=EVAL_BS, shuffle=False, num_workers=2, pin_memory=True)

# accuracies
final_train_over = eval_acc(m_over, train_eval_loader_over)  # (fixed) measured on the 1000-sample subset
final_val_over   = eval_acc(m_over, val_eval_loader)
final_test_over  = eval_acc(m_over, test_eval_loader)

final_train_reg = eval_acc(m_reg, train_eval_loader)
final_val_reg   = eval_acc(m_reg, val_eval_loader)
final_test_reg  = eval_acc(m_reg, test_eval_loader)


# Aliases read by the bar chart in the qualitative-comparison cell.
test_acc_over, test_acc_reg = final_test_over, final_test_reg

# One row per model: name, train / val / test accuracy, train-val gap.
rows = [
    ["Overfit",     final_train_over, final_val_over, final_test_over, final_train_over - final_val_over],
    ["Regularized", final_train_reg,  final_val_reg,  final_test_reg,  final_train_reg  - final_val_reg],
]

print("=== Final Accuracy Summary (SAME val/test; Overfit-train is on its 1000-sample subset) ===")
print(f"{'Model':<12} | {'Train Acc':>9} | {'Val Acc':>7} | {'Test Acc':>8} | {'Train-Val Gap':>12}")
print("-"*78)
for r in rows:
    print(f"{r[0]:<12} | {r[1]:9.4f} | {r[2]:7.4f} | {r[3]:8.4f} | {r[4]:12.4f}")


# Curves. Both histories are truncated to the shorter run so they share one
# epoch axis.
m = min(len(h_over["train_loss"]), len(h_reg["train_loss"]))
xs = np.arange(1, m+1)

_runs = (("overfit", h_over), ("reg", h_reg))
_panels = (
    ("loss", "loss", "Train vs Val Loss (Both Models on same plot)"),
    ("acc", "accuracy", "Train vs Val Accuracy (Both Models on same plot)"),
)

# One figure per metric; each shows train and val curves of both models.
for _metric, _ylabel, _title in _panels:
    plt.figure()
    for _tag, _hist in _runs:
        for _split in ("train", "val"):
            plt.plot(
                xs,
                _hist[f"{_split}_{_metric}"][:m],
                label=f"{_tag}_{_split}_{_metric}",
            )
    plt.xlabel("epoch")
    plt.ylabel(_ylabel)
    plt.title(_title)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Per-epoch train-minus-val accuracy gap of each model.
gap_over = np.subtract(np.array(h_over["train_acc"][:m]), np.array(h_over["val_acc"][:m]))
gap_reg = np.subtract(np.array(h_reg["train_acc"][:m]), np.array(h_reg["val_acc"][:m]))

plt.figure()
for _label, _gap in (("overfit_gap(train-val)", gap_over), ("reg_gap(train-val)", gap_reg)):
    plt.plot(xs, _gap, label=_label)
plt.xlabel("epoch")
plt.ylabel("train-val acc gap")
plt.title("Overfitting evidence: Train-Val Accuracy Gap")
plt.legend()
plt.tight_layout()
plt.show()



In [None]:
#5-2. Qualitative comparison

import math
import matplotlib.pyplot as plt
import numpy as np
import torch


# Reload either model from its checkpoint when it is not already defined in
# the session.
if "m_over" not in globals():
    m_over = load_model_from_ckpt(OVERFIT_CKPT_PATH, dropout_p=0.0)
if "m_reg" not in globals():
    m_reg  = load_model_from_ckpt(REG_CKPT_PATH, dropout_p=DROPOUT_P)


# Display names indexed by class id in the image captions
# (presumably the standard CIFAR-10 label order — TODO confirm).
CLASS_NAMES = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

@torch.no_grad()
def collect_correct_incorrect(model, base_test_raw, tf_test, n_each=8):
    """Collect the first correctly and incorrectly classified samples.

    Walks ``base_test_raw`` in index order, classifying one image at a time
    on the module-level ``device``, and stops once both lists hold ``n_each``
    entries or the dataset is exhausted.

    Args:
        model: Network used for prediction; switched to eval mode.
        base_test_raw: Indexable dataset yielding ``(image, label)`` pairs.
        tf_test: Transform turning an image into the model's input tensor.
        n_each: Maximum number of samples kept in each list.

    Returns:
        ``(correct, incorrect)``: lists of ``(idx, image, label, pred)``
        tuples holding the untransformed image.
    """
    model.eval()
    correct = []
    incorrect = []
    for idx in range(len(base_test_raw)):
        img_pil, y = base_test_raw[idx]
        batch = tf_test(img_pil).unsqueeze(0).to(device)
        pred = model(batch).argmax(dim=1).item()

        # Route the sample to its list; lists that are already full ignore it.
        bucket = correct if pred == y else incorrect
        if len(bucket) < n_each:
            bucket.append((idx, img_pil, y, pred))

        if min(len(correct), len(incorrect)) >= n_each:
            break
    return correct, incorrect

def show_grid(samples, title):
    """Plot ``samples`` in a 4-column image grid with per-image captions.

    Args:
        samples: Sequence of ``(idx, image, label, pred)`` tuples; ``label``
            and ``pred`` index into the module-level ``CLASS_NAMES``.
        title: Figure-level title.
    """
    n_cols = 4
    n_rows = math.ceil(len(samples) / n_cols)
    plt.figure(figsize=(n_cols * 3, n_rows * 3))
    for position, (idx, img, y, pred) in enumerate(samples, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(img)
        plt.axis("off")
        caption = f"GT:{CLASS_NAMES[y]}\nPred:{CLASS_NAMES[pred]}\nidx={idx}"
        plt.title(caption, fontsize=9)
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# a. Visualize correctly / incorrectly classified test images of both models.
correct_over, incorrect_over = collect_correct_incorrect(m_over, base_test_raw, tf_test, n_each=8)
correct_reg,  incorrect_reg  = collect_correct_incorrect(m_reg,  base_test_raw, tf_test, n_each=8)

show_grid(correct_over,   "Overfit model: Correctly classified (test)")
show_grid(incorrect_over, "Overfit model: Misclassified (test)")
show_grid(correct_reg,    "Regularized model: Correctly classified (test)")
show_grid(incorrect_reg,  "Regularized model: Misclassified (test)")

# b. Training-time comparison (only when both timing lists exist in the session).
if "epoch_times_over" in globals() and "epoch_times_reg" in globals():
    print("=== Training Time Comparison ===")
    print(f"Overfit:     total {sum(epoch_times_over):.1f}s | per-epoch {np.mean(epoch_times_over):.2f}s")
    print(f"Regularized: total {sum(epoch_times_reg):.1f}s | per-epoch {np.mean(epoch_times_reg):.2f}s")

    plt.figure()
    plt.plot(epoch_times_over, label="overfit_epoch_time")
    plt.plot(epoch_times_reg,  label="reg_epoch_time")
    plt.xlabel("epoch"); plt.ylabel("seconds")
    plt.title("Training Time per Epoch (if recorded)")
    plt.legend(); plt.tight_layout(); plt.show()
else:
    print("NOTE: epoch_times_over/epoch_times_reg not found. (Optional) If you want, I can show how to log epoch times during training.")

# c. Overfitting evidence: validation loss of both runs, truncated to the
#    shorter history. h_over / h_reg come from the quantitative-comparison cell.
m = min(len(h_over["val_loss"]), len(h_reg["val_loss"]))
xs = np.arange(1, m+1)

plt.figure()
plt.plot(xs, h_over["val_loss"][:m], label="overfit_val_loss")
plt.plot(xs, h_reg["val_loss"][:m],  label="reg_val_loss")
plt.xlabel("epoch"); plt.ylabel("val loss")
plt.title("Overfitting evidence: Validation Loss (Overfit vs Regularized)")
plt.legend(); plt.tight_layout(); plt.show()

# d. Training dynamics: how fast validation accuracy rises or falls per epoch.
val_acc_over = np.array(h_over["val_acc"][:m])
val_acc_reg  = np.array(h_reg["val_acc"][:m])

# Prepending the first value keeps the delta arrays as long as xs; the first
# delta is therefore 0.
delta_over = np.diff(val_acc_over, prepend=val_acc_over[0])
delta_reg  = np.diff(val_acc_reg,  prepend=val_acc_reg[0])

plt.figure()
plt.plot(xs, delta_over, label="overfit Δval_acc")
plt.plot(xs, delta_reg,  label="reg Δval_acc")
plt.xlabel("epoch"); plt.ylabel("change in val_acc")
plt.title("Training dynamics: Val accuracy improvement speed (Δval_acc)")
plt.legend(); plt.tight_layout(); plt.show()

# e. Final generalization performance: grouped bar chart of train / val / test
#    accuracy, using the values computed in the quantitative-comparison cell.
train_accs = [final_train_over, final_train_reg]
val_accs   = [final_val_over,   final_val_reg]
test_accs  = [test_acc_over,    test_acc_reg]
labels = ["Overfit", "Regularized"]
x = np.arange(len(labels))
w = 0.25

plt.figure()
plt.bar(x - w, train_accs, width=w, label="train")
plt.bar(x,     val_accs,   width=w, label="val")
plt.bar(x + w, test_accs,  width=w, label="test")
plt.xticks(x, labels)
plt.ylim(0, 1.0)
plt.ylabel("accuracy")
plt.title("Final Generalization Performance (Train/Val/Test)")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# 6. 모델 저장

import os, torch

# Directory receiving the state_dict-only copies of both models.
SUB_DIR = "./submission_models"
os.makedirs(SUB_DIR, exist_ok=True)

# Source checkpoints written by the two training runs.
OVERFIT_SRC = "./outputs/overfit/overfit.pth"
REG_SRC     = "./outputs/regularized/regularized.pth"

def save_state_dict_only(src_path, dst_path):
    """Re-save a checkpoint as a bare state_dict.

    Loads ``src_path`` onto the CPU, unwraps the value stored under ``"model"``
    when the loaded object is a dict containing that key (otherwise keeps the
    object as loaded) and writes the result to ``dst_path``.

    Args:
        src_path: Checkpoint file to read.
        dst_path: Destination file for the unwrapped object.
    """
    loaded = torch.load(src_path, map_location="cpu")
    is_wrapped = isinstance(loaded, dict) and "model" in loaded
    torch.save(loaded["model"] if is_wrapped else loaded, dst_path)

# Export both trained models as state_dict-only files into SUB_DIR.
save_state_dict_only(OVERFIT_SRC, os.path.join(SUB_DIR, "overfit_state_dict.pth"))
save_state_dict_only(REG_SRC,     os.path.join(SUB_DIR, "regularized_state_dict.pth"))

# Report where the exported files were written.
print("✅ Saved clean model parameters (state_dict only) to:", SUB_DIR)
print(" -", os.path.join(SUB_DIR, "overfit_state_dict.pth"))
print(" -", os.path.join(SUB_DIR, "regularized_state_dict.pth"))
