In [17]:
# -*- coding: utf-8 -*-
""" Fashion-MNIST 예측 + 시각화 스크립트
- 학습된 모델(best_fashion_cnn.pt) 로드
- 임의 이미지(내가 찍은 사진) 전처리 → 예측 → 요약 이미지 저장
- 결과물: ./outputs_fashion/<이름>_preprocessed.png, <이름>_summary.png,
summary_all.png
"""
import os, math, argparse
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
# ===== 기본 설정(고정 실행용 기본값) =====
DEFAULT_MODEL = "./outputs_fashion/best_fashion_cnn.pt"
DEFAULT_IMAGES = ["신발.jpg", "셔츠.jpg"] # 여기에 본인 파일명 넣어도 됨
DEFAULT_OUTDIR = "./outputs_fashion"
CLASS_NAMES = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot"]

In [19]:
# ===== 모델 정의 (훈련 코드와 동일) =====
class Block(nn.Module):
 def __init__(self, in_c, out_c):
  super().__init__()
  self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
  self.bn1 = nn.BatchNorm2d(out_c)
  self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
  self.bn2 = nn.BatchNorm2d(out_c)

 def forward(self, x):
  x = F.relu(self.bn1(self.conv1(x)))
  x = F.relu(self.bn2(self.conv2(x)))
  return x


class FashionCNN(nn.Module):
 def __init__(self, dropout=0.3):
  super().__init__()
  self.block1 = Block(1, 32) # 28x28
  self.pool1 = nn.MaxPool2d(2,2) # 14x14
  self.block2 = Block(32, 64) # 14x14
  self.pool2 = nn.MaxPool2d(2,2) # 7x7
  self.fc1 = nn.Linear(64*7*7, 256)
  self.drop = nn.Dropout(dropout)
  self.fc2 = nn.Linear(256, 10)

 def forward(self, x):
  x = self.pool1(self.block1(x))
  x = self.pool2(self.block2(x))
  x = x.view(x.size(0), -1)
  x = F.relu(self.fc1(x))
  x = self.drop(x)
  return self.fc2(x)

In [20]:
def _otsu_threshold(gray_u8: np.ndarray) -> int:
 hist, _ = np.histogram(gray_u8, bins=256, range=(0,256))
 total = gray_u8.size
 sum_total = np.dot(np.arange(256), hist)
 sum_b = w_b = max_var = 0.0
 thr = 0
 for t in range(256):
  w_b += hist[t]
  if w_b == 0: continue
  w_f = total - w_b
  if w_f == 0: break
  sum_b += t * hist[t]
  m_b = sum_b / w_b
  m_f = (sum_total - sum_b) / w_f
  var_between = w_b * w_f * (m_b - m_f) ** 2
  if var_between > max_var:
   max_var = var_between
   thr = t
  return thr


In [21]:
def preprocess_photo_to_fashion(path: str, save_preview: str | None = None):
 """
 1) Grayscale → 2) Otsu 이진화로 전경(BBOX) → 3) 정사각 패딩(여백 20%)
 4) 28×28 리사이즈 → 5) Tensor/Normalize(mean=0.2861, std=0.3530)
 반환:
 tens_norm: (1,28,28) float 텐서
 pre_u8: 28×28 uint8 (시각화용)
 """
 im = Image.open(path).convert("L")
 arr = np.array(im, dtype=np.uint8)

 # 전경(옷) 추정
 th = _otsu_threshold(arr)
 # 일반적으로 배경이 더 밝고, 물체(옷)가 더 어두운 경우가 많음
 bin_fg = (arr < th).astype(np.uint8)
 if bin_fg.sum() < 100: # 실패 시 반전 가정
  bin_fg = (arr > th).astype(np.uint8)

 ys, xs = np.where(bin_fg > 0)
 if len(xs) == 0 or len(ys) == 0:
  crop = arr # 전경 못 찾으면 전체 사용
 else:
  x0, x1 = xs.min(), xs.max()
  y0, y1 = ys.min(), ys.max()
  crop = arr[y0:y1+1, x0:x1+1]

 # 정사각 패딩 + 여백 20%
 h, w = crop.shape
 size = int(max(h, w) * 1.2)
 canvas = np.full((size, size), 255, dtype=np.uint8) # 흰 배경
 y_off = (size - h) // 2
 x_off = (size - w) // 2
 canvas[y_off:y_off+h, x_off:x_off+w] = crop

 # 28×28 리사이즈
 arr28 = np.array(Image.fromarray(canvas).resize((28, 28), Image.BILINEAR), dtype=np.uint8)

 # Tensor + Normalize(Fashion-MNIST 권장)
 tens = torch.from_numpy(arr28).float().unsqueeze(0) / 255.0 # [0,1]
 mean, std = 0.2861, 0.3530
 tens_norm = (tens - mean) / std # (1,28,28)
 if save_preview:
  Image.fromarray(arr28).save(save_preview)
 return tens_norm, arr28

In [22]:
# ===== 예측/시각화 =====
@torch.no_grad()
def predict(model, tens_norm, device):
 x = tens_norm.unsqueeze(0).to(device) # (1,1,28,28)
 logits = model(x)
 probs = torch.softmax(logits, dim=1).cpu().numpy().ravel()
 pred = int(probs.argmax())
 return pred, probs

def topk(probs: np.ndarray, k=3):
 idx = probs.argsort()[-k:][::-1]
 return [(int(i), float(probs[i])) for i in idx]

def topk_str(probs: np.ndarray, k=3):
 items = topk(probs, k)
 return ", ".join([f"{CLASS_NAMES[i]}:{p:.3f}" for i,p in items])

In [23]:
def make_summary(orig_path: str, pre_u8: np.ndarray, pred: int, probs: np.ndarray, save_path: str):
 plt.figure(figsize=(9,4))
 # 원본
 plt.subplot(1,2,1)
 img = Image.open(orig_path).convert("RGB")
 plt.imshow(img); plt.title("Original"); plt.axis("off")
 # 전처리 + 결과
 plt.subplot(1,2,2)
 plt.imshow(pre_u8, cmap="gray", interpolation="nearest")
 plt.title(f"Preprocessed (28×28)\nPred: {CLASS_NAMES[pred]}\nTop-3:{topk_str(probs,3)}")
 plt.axis("off")
 plt.tight_layout(); plt.savefig(save_path, dpi=150); plt.close()

def make_all_grid(summary_paths, save_path, cols=2):
 if not summary_paths: return
 ims = [Image.open(p).convert("RGB") for p in summary_paths]
 w, h = ims[0].size
 rows = int(np.ceil(len(ims)/cols))
 canvas = Image.new("RGB", (cols*w, rows*h), (255,255,255))
 for i, im in enumerate(ims):
  r, c = divmod(i, cols)
  canvas.paste(im, (c*w, r*h))
 canvas.save(save_path)

In [24]:
# ===== 메인 =====
def run(model_path: str, images: list[str], out_dir: str):
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs(out_dir, exist_ok=True)
 assert os.path.exists(model_path), f"모델이 없습니다: {model_path}"

 model = FashionCNN()
 model.load_state_dict(torch.load(model_path, map_location=device))
 model.to(device).eval()
 print(f"✅ 모델 로드 완료: {model_path}")

 summary_paths = []
 for path in images:
  assert os.path.exists(path), f"이미지 없음: {path}"
  base = os.path.splitext(os.path.basename(path))[0]
  pre_path = os.path.join(out_dir, f"{base}_preprocessed.png")
  sum_path = os.path.join(out_dir, f"{base}_summary.png")

  tens_norm, pre_u8 = preprocess_photo_to_fashion(path, save_preview=pre_path)
  pred, probs = predict(model, tens_norm, device)

  print(f"\n {path}")
  print(f" - Pred : {CLASS_NAMES[pred]} ({pred})")
  print(f" - Top-3: {topk_str(probs, 3)}")
  print(f" - Preprocessed saved: {pre_path}")

  make_summary(path, pre_u8, pred, probs, sum_path)
  print(f" - Summary saved: {sum_path}")
  summary_paths.append(sum_path)

 grid_path = os.path.join(out_dir, "summary_all.png")
 make_all_grid(summary_paths, grid_path, cols=2)
 print(f"\n 전체 요약 그리드: {grid_path}")

In [25]:
def parse_args():
 ap = argparse.ArgumentParser()
 ap.add_argument("--model", type=str, default=DEFAULT_MODEL)
 ap.add_argument("--images", nargs="*", default=DEFAULT_IMAGES)
 ap.add_argument("--out-dir", type=str, default=DEFAULT_OUTDIR)
 return ap.parse_known_args()

if __name__ == "__main__":
 args, _ = parse_args()
 run(args.model, args.images, args.out_dir)

✅ 모델 로드 완료: ./outputs_fashion/best_fashion_cnn.pt

 신발.jpg
 - Pred : Bag (8)
 - Top-3: Bag:0.995, T-shirt/top:0.002, Trouser:0.001
 - Preprocessed saved: ./outputs_fashion\신발_preprocessed.png
 - Summary saved: ./outputs_fashion\신발_summary.png

 셔츠.jpg
 - Pred : Bag (8)
 - Top-3: Bag:0.871, T-shirt/top:0.055, Ankle boot:0.020
 - Preprocessed saved: ./outputs_fashion\셔츠_preprocessed.png
 - Summary saved: ./outputs_fashion\셔츠_summary.png

 전체 요약 그리드: ./outputs_fashion\summary_all.png
