In [None]:
import pathlib

In [3]:
path = pathlib.Path()
path

WindowsPath('.')

In [4]:
path.absolute()

WindowsPath('c:/workspace_deep_learning')

In [11]:
p = path / "data" / "fruit"
p.absolute()

WindowsPath('c:/workspace_deep_learning/data/fruit')

In [None]:
for i in p.iterdir():
    list(i.glob("**/*.jpg"))

<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000001F01325A7D0>
<map object at 0x000001F01325A530>
<map object at 0x000

1. 이미지 경로를 설정
2. 디렉토리 이름을 별도로 가져와야 함
3. 각 디렉토리에 맞춰서 이미지를 관리해야 함
4. 이미지를 텐서로 변경해야 함 (레이블을 추가해야 함)
5. 전체 이미지를 7:2:1(학습, 검증, 테스트) 나눠야 함

In [50]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [51]:
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")

In [52]:

def load_data(root_dir):
    root_path = Path(root_dir)
    # 클래스(레이블) 생성
    classes = sorted([d.name for d in root_path.iterdir() if d.is_dir()])
    class_to_idx = {class_name : idx for idx, class_name in enumerate(classes)}
    idx_to_class = {idx : class_name for idx, class_name in enumerate(classes)}

    # 이미지 파일 수집
    images = []
    labels = []
    for cls_name in classes:
        cls_dir = root_path / cls_name
        for img_path in cls_dir.glob("*.jpg"):
            images.append(img_path)
            labels.append(class_to_idx[cls_name])

    return images, labels, class_to_idx, idx_to_class

In [59]:
def get_transform(image_size=64, augment=True):
    if augment:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    else :
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        
    return transform

In [54]:
class FruitDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        label = self.labels[index]
        
        image = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        
        label = torch.tensor(label, dtype=torch.long)
        return image, label 

In [55]:
def split_data(image, label, train_ratio = 0.8):
    n_sample = len(image)
    n_train = int(n_sample * train_ratio)

    indices = np.random.permutation(n_sample)
    train_indices = indices[:n_train]
    test_indices = indices[n_train:]

    train_images = [image[i] for i in train_indices]
    val_images = [image[i] for i in test_indices]
    train_labels = [label[i] for i in train_indices]
    val_labels = [label[i] for i in test_indices] 

    return train_images, val_images, train_labels, val_labels

In [56]:
data_dir = "data/fruit"
max_epochs = 3
batch_size = 32
image_size = 64

torch.manual_seed(42)
np.random.seed(42)

device = get_device()
print(f'{device}를 사용합니다.')

cpu를 사용합니다.


In [62]:
images, labels, class_to_idx, idx_to_class = load_data(data_dir)
train_images, val_images, train_labels, val_labels = split_data(images, labels)
transform_train = get_transform(image_size, augment=True)
transform_val = get_transform(image_size, augment=False)

train_dataset = FruitDataset(train_images, train_labels, transform_train)
val_dataset = FruitDataset(val_images, val_labels, transform_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


# 14장에 나오는 CNN을 해야 합니다


In [None]:
import torch.nn as nn

class SimpleCNN(nn.Module):
    pass