In [1]:
import os
import glob
import torch
import torchvision.transforms as T

from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [2]:
class Cifar10(Dataset):
    """Image dataset laid out as one sub-directory of ``root`` per class.

    Every ``*.png`` file found directly under ``root/<class_name>/`` becomes
    one sample. The integer label of a sample is the index of its class
    directory name in the sorted list of class directories.

    Args:
        root: Directory containing one sub-directory per class.
        transform: Optional callable applied to the RGB PIL image in
            ``__getitem__``. When ``None`` the PIL image is returned as-is.

    Attributes:
        data: List of ``(image_path, label)`` tuples, one per sample.
        classes: Sorted class directory names; ``classes[label]`` is the
            directory name that produced ``label``.
    """

    def __init__(self, root, transform=None):
        super().__init__()
        self.make_dataset(root)
        self.transform = transform

    def make_dataset(self, root):
        """Scan ``root`` and rebuild ``self.data`` and ``self.classes``.

        Errors from ``os.listdir`` propagate unchanged (e.g.
        ``FileNotFoundError`` when ``root`` does not exist).
        """
        self.data = []

        # Only directories count as classes. A stray file in the root (README,
        # .DS_Store, Thumbs.db, ...) would otherwise occupy a label index and
        # shift the label of every class that sorts after it.
        self.classes = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )

        # Class name -> integer label, then collect every PNG of that class.
        for label, category in enumerate(self.classes):
            # glob returns paths in arbitrary order; sort so that sample
            # indices are reproducible across runs and machines.
            for image_path in sorted(glob.glob(f'{root}/{category}/*.png')):
                self.data.append((image_path, label))

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is read from disk, converted to RGB, and passed through
        ``self.transform`` when one was given.
        """
        path, label = self.data[idx]
        image = self.read_image(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def read_image(self, path):
        """Open the image at ``path`` and return it converted to RGB."""
        # The context manager closes the underlying file deterministically;
        # convert() returns a new image, so the result stays valid afterwards.
        with Image.open(path) as image:
            return image.convert('RGB')

In [8]:
# Preprocessing pipeline: convert the image to a tensor, then normalize each
# of the three channels with mean 0.5 and std 0.25.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
])

# Build the training dataset from the training image directory.
train_root = 'data/Cifar10/train'
train_data = Cifar10(root=train_root, transform=transform)

# Sanity check: fetch only the first sample and show its shape and label.
first_sample = next(iter(train_data), None)
if first_sample is not None:
    image, label = first_sample
    print(image.shape, label)

FileNotFoundError: [WinError 3] 지정된 경로를 찾을 수 없습니다: './data/Cifar10/train'

In [12]:
# Preprocessing pipeline: convert the image to a tensor, then normalize each
# of the three channels with mean 0.5 and std 0.25.
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
])

# Training split: shuffled batches of 64; drop_last=True discards the final
# incomplete batch.
train_root = 'data/Cifar10/train'
train_data = Cifar10(root=train_root, transform=transform)
train_loader = DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

# Test split: one sample per batch, loader defaults otherwise.
test_root = 'data/Cifar10/test'
test_data = Cifar10(root=test_root, transform=transform)
test_loader = DataLoader(dataset=test_data, batch_size=1)

# Sanity check: number of samples and number of batches for each split.
print(f'{len(train_data)} {len(train_loader)}')
print(f'{len(test_data)} {len(test_loader)}')

50000 781
10000 10000
