In [None]:
!pip install ultralytics

In [None]:
import os
import shutil
from collections import Counter

import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from ultralytics import YOLO

# Load Fashion-MNIST (downloaded into ./data on first use).
# Only the raw `.data` / `.targets` tensors are used below. A dataset `transform` is
# applied only when samples are fetched by indexing (e.g. through a DataLoader), which
# never happens here, so the former ToTensor/Normalize pipeline was dead code and its
# "normalized to [-1, 1]" comment did not describe the arrays actually used. The raw
# uint8 images are what the OpenCV export step below expects.
fashion_mnist_train = datasets.FashionMNIST(root='data', train=True, download=True)
fashion_mnist_test = datasets.FashionMNIST(root='data', train=False, download=True)

# Extract images and labels as NumPy arrays.
x_train = fashion_mnist_train.data.numpy()
y_train = fashion_mnist_train.targets.numpy()
x_test = fashion_mnist_test.data.numpy()
y_test = fashion_mnist_test.targets.numpy()

# Sanity checks: 8-bit grayscale 28x28 images, exactly one label per image.
assert x_train.dtype == np.uint8 and x_train.shape[1:] == (28, 28)
assert x_test.dtype == np.uint8 and x_test.shape[1:] == (28, 28)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

# Convert to the YOLO classification folder format.
# The YOLOv8 classifier reads 3-channel images, so the grayscale images get extra
# channels and are resized before being written to disk.
def preprocess_and_save(data, labels, base_dir, size=224):
    """Export grayscale images as 3-channel JPEGs in a one-folder-per-class layout.

    Writes ``<base_dir>/<label>/<i>.jpg`` for the i-th (image, label) pair.

    Args:
        data: sequence of grayscale images (assumed 2-D uint8 arrays, as produced
            by FashionMNIST ``.data`` -- TODO confirm for other callers).
        labels: class labels, same length as ``data``; ``str(label)`` is used as
            the class folder name.
        base_dir: output directory; created if it does not exist.
        size: side length in pixels of the square output images (default 224).

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths (``zip``
            would otherwise silently drop the extra items).
        OSError: if OpenCV fails to write an image (``cv2.imwrite`` reports
            failure only through its return value).
    """
    if len(data) != len(labels):
        raise ValueError(
            f"data and labels must have the same length, got {len(data)} and {len(labels)}"
        )
    os.makedirs(base_dir, exist_ok=True)

    created_dirs = set()  # avoid a makedirs call for every single image
    for i, (img, label) in enumerate(zip(data, labels)):
        # Replicate the gray channel into 3 channels. cv2.imwrite expects BGR order;
        # all three channels are identical, so this yields the same pixels as GRAY2RGB.
        img_3ch = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        # Resize to the square training resolution.
        img_resized = cv2.resize(img_3ch, (size, size))

        # One sub-folder per class.
        label_dir = os.path.join(base_dir, str(label))
        if label_dir not in created_dirs:
            os.makedirs(label_dir, exist_ok=True)
            created_dirs.add(label_dir)

        out_path = os.path.join(label_dir, f"{i}.jpg")
        if not cv2.imwrite(out_path, img_resized):
            raise OSError(f"cv2.imwrite failed to write {out_path}")

# Split the official training set 80/20 into train/val and export every split as
# per-class image folders under datasets/fashion_mnist/.
# The split outputs get new names (x_train_split, y_train_split) so x_train/y_train keep
# the full training set: re-running this cell re-splits the full set instead of
# shrinking an already-split x_train by another 20%.
# stratify=y_train keeps the class proportions identical in both splits.
x_train_split, x_val, y_train_split, y_val = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

dataset_splits = {
    "train": (x_train_split, y_train_split),
    "val": (x_val, y_val),
    "test": (x_test, y_test),  # test data is exported too (optional)
}
for split_name, (split_images, split_labels) in dataset_splits.items():
    split_dir = os.path.join("datasets/fashion_mnist", split_name)
    # Clear images from a previous run. Files are named by index, so a different split
    # would otherwise leave stale copies in other class folders (train/val leakage).
    shutil.rmtree(split_dir, ignore_errors=True)
    # preprocess_and_save creates split_dir and its class sub-folders itself.
    preprocess_and_save(split_images, split_labels, split_dir)

In [None]:
# Fine-tune the pretrained YOLOv8-nano classification checkpoint on the exported
# per-class image folders (the dataset root holds the train/ and val/ sub-folders
# written by the export cell above).
model = YOLO('yolov8n-cls.pt')

data_path = "datasets/fashion_mnist"

# Training hyperparameters gathered in one place so they are easy to tune.
train_kwargs = dict(
    data=data_path,
    epochs=10,
    imgsz=224,  # same 224x224 resolution the images were exported at
    batch=32,
)
model.train(**train_kwargs)

In [None]:
# Model evaluation.
# NOTE(review): val() is called without a `data` argument, so it presumably evaluates
# on the validation split configured by the preceding train() call -- confirm.
metrics = model.val()
print(metrics)

In [None]:
# Model prediction: classify the validation images of class "0" and summarize the
# predicted top-1 classes.
PREDICT_SOURCE = "datasets/fashion_mnist/val/0"

# stream=True yields one Results object at a time instead of accumulating every image's
# Results (each holding its decoded image) in a list. This folder holds roughly a
# thousand images, so the list form can use a lot of RAM. save=True still writes each
# annotated image as the generator is consumed.
results = model.predict(source=PREDICT_SOURCE, save=True, stream=True)

# Printing the raw Results objects dumps every attribute of every image; report the
# distribution of predicted class names instead (ideally dominated by "0").
predicted_names = [model.names[r.probs.top1] for r in results]
print(f"{len(predicted_names)} images predicted from {PREDICT_SOURCE}")
print(Counter(predicted_names).most_common())