In [14]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [15]:
def get_device():
    """Pick the compute device to run on.

    Preference order is CUDA, then Apple MPS, then CPU; the first backend
    that reports itself as available wins.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)

In [16]:
def load_data(root_dir):
    """Collect image paths and integer labels from a class-per-folder layout.

    Every immediate sub-directory of ``root_dir`` is treated as one class.
    Class indices are assigned by sorted directory name, and within each
    class the files matching ``*.jpg`` are returned in sorted order so that
    the result is the same on every run and every filesystem.

    Args:
        root_dir: path (str or Path) of the directory holding one
            sub-directory per class.

    Returns:
        tuple: ``(images, labels, class_to_idx, idx_to_class)`` where
            ``images`` is a list of ``Path`` objects, ``labels`` is the
            parallel list of integer class indices, and the two dicts map
            class name -> index and index -> class name.

    Raises:
        OSError: propagated from ``Path.iterdir`` when ``root_dir`` cannot
            be listed (e.g. it does not exist).
    """
    root_path = Path(root_dir)

    # Build the label mappings from the sorted sub-directory names.
    classes = sorted(d.name for d in root_path.iterdir() if d.is_dir())
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
    idx_to_class = {idx: cls_name for cls_name, idx in class_to_idx.items()}

    # Gather the image files class by class.
    images = []
    labels = []
    for cls_name in classes:
        # glob() yields entries in whatever order the filesystem reports
        # them; sorting keeps a seeded train/validation split reproducible.
        cls_images = sorted((root_path / cls_name).glob("*.jpg"))
        images.extend(cls_images)
        labels.extend([class_to_idx[cls_name]] * len(cls_images))
    return images, labels, class_to_idx, idx_to_class

In [17]:
def split_data(images, labels, train_ratio=0.8):
    """Randomly split parallel image/label lists into train and validation.

    The shuffle uses NumPy's global random state, so seed it with
    ``np.random.seed`` beforehand for a reproducible split. The split is
    purely random; it is not stratified by label.

    Args:
        images: sequence of samples (e.g. image paths).
        labels: sequence of labels, same length as ``images``.
        train_ratio: fraction of samples placed in the training subset;
            the training size is ``int(len(images) * train_ratio)``.

    Returns:
        tuple: ``(train_images, val_images, train_labels, val_labels)``,
            each a list.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length, or if
            ``train_ratio`` is outside ``[0, 1]``.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio}")

    n_sample = len(images)
    n_train = int(n_sample * train_ratio)

    # One permutation drives both lists so image/label pairs stay aligned.
    indices = np.random.permutation(n_sample)
    train_indices = indices[:n_train]
    val_indices = indices[n_train:]

    train_images = [images[i] for i in train_indices]
    val_images = [images[i] for i in val_indices]
    train_labels = [labels[i] for i in train_indices]
    val_labels = [labels[i] for i in val_indices]

    return train_images, val_images, train_labels, val_labels

In [18]:
def get_trasforms(image_size=64, augment=True):
    """Build the torchvision preprocessing pipeline for the images.

    Both variants resize to a square, convert to a tensor and normalise
    per channel. With ``augment=True`` a random horizontal flip, a random
    rotation of up to 10 degrees and a brightness/contrast jitter are
    inserted between the resize and the tensor conversion.

    Note: the function name is misspelled ("trasforms"); it is kept as-is
    because other cells call it by this name.

    Args:
        image_size: side length of the square output image.
        augment: whether to include the random augmentation steps.

    Returns:
        transforms.Compose: the assembled pipeline.
    """
    steps = [transforms.Resize((image_size, image_size))]

    if augment:
        # Random augmentations operate on the resized image.
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]

    # Shared tail: tensor conversion followed by per-channel normalisation
    # (the constants appear to be the commonly used ImageNet statistics).
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)

In [19]:
class FruitDataset(Dataset):
    """Fruit image dataset backed by a list of image paths.

    Images are opened lazily on access, converted to RGB, and optionally
    passed through ``transform``.
    """

    def __init__(self, images, labels, transform=None):
        """Store the sample lists and the optional transform.

        Args:
            images: sequence of image file paths.
            labels: sequence of labels parallel to ``images``.
            transform: optional callable applied to the RGB PIL image.
        """
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: position of the sample in the stored lists.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB PIL image,
                or the transform's output when a transform was given.
        """
        img_path = self.images[index]
        label = self.labels[index]

        # The context manager releases the underlying file handle as soon
        # as the pixel data has been copied out by convert().
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [20]:
data_dir = "data/fruit" # location of the dataset (one sub-directory per class)
num_epochs = 3 # number of repetitions (epochs)
batch_size = 32 # number of samples per batch
image_size = 64 # side length passed to the transform pipeline

torch.manual_seed(42) # fixed seed for reproducibility (PyTorch)
np.random.seed(42) # fixed seed for reproducibility (NumPy global state)

device = get_device() # choose the compute device
print(f"{device}를 사용합니다.")

cpu를 사용합니다.


In [26]:
# Discover the image files and split them into training / validation subsets.
images, labels, class_to_idx, idx_to_class = load_data(data_dir)
assert images, f"No .jpg images found under {data_dir!r}"
train_images, val_images, train_labels, val_labels = split_data(images, labels)
assert len(train_images) + len(val_images) == len(images)

# Augmentation is applied to the training data only.
transform_train = get_trasforms(image_size, augment=True)
transform_val = get_trasforms(image_size, augment=False)

train_dataset = FruitDataset(train_images, train_labels, transform=transform_train)
val_dataset = FruitDataset(val_images, val_labels, transform=transform_val)

# NOTE(review): with num_workers > 0 on platforms that start workers via
# "spawn" (Windows/macOS), a Dataset class defined inside the notebook may
# fail to be pickled into the workers - confirm, or fall back to num_workers=0.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Validation does not benefit from shuffling; a fixed order keeps evaluation
# runs comparable.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# TODO: choose the activation function
# TODO: choose the loss function
# TODO: choose the optimizer

In [23]:
# TODO: use the CNN introduced in chapter 14 here.
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Placeholder for the CNN classifier.

    No layers or forward pass are defined yet; the class only inherits the
    default ``nn.Module`` behaviour.
    """