In [30]:
# -*- coding: utf-8 -*-
""" Fashion-MNIST 의류 이미지 분류 (올인원) - 확장 CNN + BatchNorm + Dropout - 학습/검증/테스트 + 혼동행렬 + 오분류 샘플 + 학습곡선 자동 저장
- 결과물: ./outputs_fashion/
"""
import os, time, random, argparse
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [31]:
# -------------------- 공통 셋업 --------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def set_seed(seed=42):
 random.seed(seed); np.random.seed(seed)
 torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = True

CLASS_NAMES = ["T-shirt/top","Trouser","Pullover","Dress","Coat",
 "Sandal","Shirt","Sneaker","Bag","Ankle boot"]

In [32]:
# -------------------- 데이터 --------------------
def get_loaders(batch=64, num_workers=2):
 # 권장 통계: mean=0.2861, std=0.3530
 tf_train = transforms.Compose([
 transforms.RandomHorizontalFlip(p=0.5), # 약한 증강
 transforms.ToTensor(),
 transforms.Normalize((0.2861,), (0.3530,))
 ])
 tf_eval = transforms.Compose([
 transforms.ToTensor(),
 transforms.Normalize((0.2861,), (0.3530,))
 ])
 full_train = datasets.FashionMNIST("./data", train=True, download=True,
transform=tf_train)
 test_set = datasets.FashionMNIST("./data", train=False, download=True,
transform=tf_eval)
 train_len = 55000
 valid_len = len(full_train) - train_len
 train_set, valid_set = random_split(full_train, [train_len, valid_len],
 generator=torch.Generator().manual_seed(42))
 # valid에는 증강 없이 평가 변환 적용
 valid_set.dataset.transform = tf_eval
 train_loader = DataLoader(train_set, batch_size=batch, shuffle=True,
 num_workers=num_workers, pin_memory=True)
 valid_loader = DataLoader(valid_set, batch_size=batch, shuffle=False,
 num_workers=num_workers, pin_memory=True)
 test_loader = DataLoader(test_set, batch_size=batch, shuffle=False,
 num_workers=num_workers, pin_memory=True)
 return train_loader, valid_loader, test_loader

In [33]:
# -------------------- 모델 --------------------
class Block(nn.Module):

 def __init__(self, in_c, out_c):
  super().__init__()
  self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
  self.bn1 = nn.BatchNorm2d(out_c)
  self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
  self.bn2 = nn.BatchNorm2d(out_c)

 def forward(self, x):
  x = F.relu(self.bn1(self.conv1(x)))
  x = F.relu(self.bn2(self.conv2(x)))
  return x

In [34]:
class FashionCNN(nn.Module):
 """
 입력: (N,1,28,28)
 출력: 10 클래스 logits
 """
 def __init__(self, dropout=0.3):
  super().__init__()
  self.block1 = Block(1, 32) # 28x28
  self.pool1 = nn.MaxPool2d(2,2) # 14x14
  self.block2 = Block(32, 64) # 14x14
  self.pool2 = nn.MaxPool2d(2,2) # 7x7
  self.fc1 = nn.Linear(64*7*7, 256)
  self.drop = nn.Dropout(dropout)
  self.fc2 = nn.Linear(256, 10)
 def forward(self, x):
  x = self.pool1(self.block1(x))
  x = self.pool2(self.block2(x))
  x = x.view(x.size(0), -1)
  x = F.relu(self.fc1(x))
  x = self.drop(x)
  return self.fc2(x)

In [35]:
# -------------------- 루프 --------------------
def train_one_epoch(model, loader, opt, scaler, use_amp=False):
 model.train()
 crit = nn.CrossEntropyLoss()
 total, correct, loss_sum = 0, 0, 0.0
 for xb, yb in loader:
  xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE,
non_blocking=True)
  opt.zero_grad(set_to_none=True)
  if use_amp and DEVICE.type == "cuda":
   with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    logits = model(xb); loss = crit(logits, yb)
   scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
  else:
   logits = model(xb); loss = crit(logits, yb); loss.backward(); opt.step()
  loss_sum += loss.item() * xb.size(0)
  correct += (logits.argmax(1) == yb).sum().item()
  total += xb.size(0)
 return loss_sum/total, correct/total


In [36]:
@torch.no_grad()
def evaluate(model, loader):
 model.eval()
 crit = nn.CrossEntropyLoss()
 total, correct, loss_sum = 0, 0, 0.0
 all_preds, all_targets = [], []
 for xb, yb in loader:
  xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE,
non_blocking=True)
  logits = model(xb); loss = crit(logits, yb)
  loss_sum += loss.item() * xb.size(0)
  preds = logits.argmax(1)
  correct += (preds == yb).sum().item()
  total += xb.size(0)
  all_preds.append(preds.cpu().numpy()); all_targets.append(yb.cpu().numpy())
  return loss_sum/total, correct/total, np.concatenate(all_preds), np.concatenate(all_targets)

In [37]:
# -------------------- 시각화 --------------------
def plot_curves(hist, save_path):
 os.makedirs(os.path.dirname(save_path), exist_ok=True)
 ep = range(1, len(hist["train_loss"])+1)
 plt.figure(figsize=(10,4))
 plt.subplot(1,2,1)
 plt.plot(ep, hist["train_loss"], label="Train Loss")
 plt.plot(ep, hist["valid_loss"], label="Valid Loss")
 plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Loss Curve"); plt.legend()
 plt.subplot(1,2,2)
 plt.plot(ep, hist["train_acc"], label="Train Acc")
 plt.plot(ep, hist["valid_acc"], label="Valid Acc")
 plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.title("Accuracy Curve"); plt.legend()
 plt.tight_layout(); plt.savefig(save_path, dpi=150); plt.close()

In [38]:
def plot_confmat(y_true, y_pred, save_path):
 os.makedirs(os.path.dirname(save_path), exist_ok=True)
 cm = confusion_matrix(y_true, y_pred)
 disp = ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=CLASS_NAMES)
 fig, ax = plt.subplots(figsize=(7,7))
 disp.plot(ax=ax, cmap="Blues", xticks_rotation=45, colorbar=False,
values_format="d")
 plt.title("Confusion Matrix (Fashion-MNIST)"); plt.tight_layout()
 plt.savefig(save_path, dpi=150); plt.close()

def save_misclassified(y_true, y_pred, save_path, max_samples=25):
 os.makedirs(os.path.dirname(save_path), exist_ok=True)
 mis_idx = np.where(y_true != y_pred)[0][:max_samples]
 if len(mis_idx)==0:
  fig = plt.figure(figsize=(6,2)); plt.text(0.5,0.5,"No misclassified samples.", ha="center", va="center")
  plt.axis("off"); plt.savefig(save_path, dpi=150); plt.close();
  return
   # 정규화 되돌린 간단한 시각화
 inv = transforms.Normalize(mean=(-0.2861/0.3530,), std=(1/0.3530,))
 ds = datasets.FashionMNIST("./data", train=False, download=True,
                           transform=transforms.Compose([transforms.ToTensor(),
                                                         transforms.Normalize((0.2861,), (0.3530,))]))
 cols=5; rows=int(np.ceil(len(mis_idx)/cols))
 plt.figure(figsize=(cols*2.2, rows*2.6))
 for i, idx in enumerate(mis_idx):
  img, gt = ds[idx]
  img = inv(img).squeeze().numpy()
  pred = y_pred[idx]
  plt.subplot(rows, cols, i+1)
  plt.imshow(img, cmap="gray")
  plt.title(f"GT:{CLASS_NAMES[gt]}\nPred:{CLASS_NAMES[pred]}", fontsize=9)
  plt.axis("off")
 plt.tight_layout(); plt.savefig(save_path, dpi=150); plt.close()

In [39]:
# -------------------- 메인 --------------------
def run_train(args):
 set_seed(args.seed)
 os.makedirs(args.out_dir, exist_ok=True)
 print(f"✅ Device: {DEVICE} | AMP: {'ON' if (args.amp and DEVICE.type=='cuda') else 'OFF'}")

 tr_loader, va_loader, te_loader = get_loaders(args.batch_size, args.num_workers)
 model = FashionCNN(dropout=args.dropout).to(DEVICE)
 opt = Adam(model.parameters(), lr=args.lr)
 scaler = torch.amp.GradScaler(enabled=(args.amp and DEVICE.type=='cuda'))

 hist = {"train_loss":[], "valid_loss":[], "train_acc":[], "valid_acc":[]}
 best_acc, best_path = 0.0, os.path.join(args.out_dir, "best_fashion_cnn.pt")

 for epoch in range(1, args.epochs+1):
  t0=time.time()
  tr_loss, tr_acc = train_one_epoch(model, tr_loader, opt, scaler, use_amp=args.amp)
  va_loss, va_acc, _, _ = evaluate(model, va_loader)
  hist["train_loss"].append(tr_loss); hist["valid_loss"].append(va_loss)
  hist["train_acc"].append(tr_acc); hist["valid_acc"].append(va_acc)
  if va_acc > best_acc:
   best_acc = va_acc; torch.save(model.state_dict(), best_path)
  print(f"[Epoch {epoch:02d}/{args.epochs}] "
        f"Train loss {tr_loss:.4f}, acc {tr_acc:.4f} | "
        f"Valid loss {va_loss:.4f}, acc {va_acc:.4f} | BestAcc {best_acc:.4f} | "
        f"{time.time()-t0:.1f}s")

 plot_curves(hist, os.path.join(args.out_dir, "train_curve.png"))

 # Test
 model.load_state_dict(torch.load(best_path, map_location=DEVICE))
 te_loss, te_acc, y_pred, y_true = evaluate(model, te_loader)
 print(f"\n Test: loss={te_loss:.4f}, acc={te_acc:.4f}")

 plot_confmat(y_true, y_pred, os.path.join(args.out_dir, "confusion_matrix.png"))
 save_misclassified(y_true, y_pred, os.path.join(args.out_dir, "misclassified_samples.png"))

 print(f"\n 결과 저장 위치: {os.path.abspath(args.out_dir)}")
 print(" - 모델 가중치: best_fashion_cnn.pt")
 print(" - 학습곡선: train_curve.png")
 print(" - 혼동행렬: confusion_matrix.png")
 print(" - 오분류샘플: misclassified_samples.png")

In [40]:
def main():
 ap = argparse.ArgumentParser("Fashion-MNIST CNN (All-in-One)")
 ap.add_argument("--epochs", type=int, default=10)
 ap.add_argument("--batch-size", type=int, default=64)
 ap.add_argument("--lr", type=float, default=1e-3)
 ap.add_argument("--dropout", type=float, default=0.3)
 ap.add_argument("--num-workers", type=int, default=2)
 ap.add_argument("--amp", action="store_true")
 ap.add_argument("--seed", type=int, default=42)
 ap.add_argument("--out-dir", type=str, default="./outputs_fashion")
 args, _ = ap.parse_known_args()
 run_train(args)

if __name__ == "__main__":
 main()

✅ Device: cpu | AMP: OFF
[Epoch 01/10] Train loss 0.3925, acc 0.8580 | Valid loss 0.3167, acc 0.8281 | BestAcc 0.8281 | 39.9s
[Epoch 02/10] Train loss 0.2585, acc 0.9058 | Valid loss 0.2177, acc 0.9219 | BestAcc 0.9219 | 44.9s
[Epoch 03/10] Train loss 0.2199, acc 0.9201 | Valid loss 0.3204, acc 0.8750 | BestAcc 0.9219 | 44.9s
[Epoch 04/10] Train loss 0.1949, acc 0.9274 | Valid loss 0.2014, acc 0.9062 | BestAcc 0.9219 | 45.3s
[Epoch 05/10] Train loss 0.1711, acc 0.9362 | Valid loss 0.2488, acc 0.9062 | BestAcc 0.9219 | 41.6s
[Epoch 06/10] Train loss 0.1529, acc 0.9438 | Valid loss 0.1948, acc 0.9375 | BestAcc 0.9375 | 41.5s
[Epoch 07/10] Train loss 0.1329, acc 0.9507 | Valid loss 0.3821, acc 0.8594 | BestAcc 0.9375 | 41.1s
[Epoch 08/10] Train loss 0.1192, acc 0.9556 | Valid loss 0.2614, acc 0.9219 | BestAcc 0.9375 | 41.5s
[Epoch 09/10] Train loss 0.1045, acc 0.9606 | Valid loss 0.4213, acc 0.8906 | BestAcc 0.9375 | 41.4s
[Epoch 10/10] Train loss 0.0926, acc 0.9655 | Valid loss 0.2797, a