# Dataset sub-class

In [1]:
import os
from glob import glob

import torch
from torchvision import datasets, transforms

from PIL import Image

# `torch.utils.data.Dataset`을 직접 상속받아서 데이터셋 구현하기
## 1. dir 만들기

In [5]:
# Root folder of the course material; the dataset paths are derived from it.
raw_path = "E:/공부/제로베이스/Part 10. 텐서플로 & Part 11. 파이토치/deeplearning_frameworks_zerobaseDSS/"

# raw_path already ends with "/", so the sub-folder is appended directly.
cifar_dir = f"{raw_path}datasets/cifar/"
os.listdir(cifar_dir)

['labels.txt', 'test_dataset.csv', 'train_dataset.csv', 'test', 'train']

In [7]:
# Training-image folder; cifar_dir ends with "/", so no separator is added.
train_dir = f"{cifar_dir}train"

os.listdir(train_dir)

['0_frog.png',
 '10000_automobile.png',
 '10001_frog.png',
 '10002_frog.png',
 '10003_ship.png',
 '10004_ship.png',
 '10005_cat.png',
 '10006_deer.png',
 '10007_frog.png',
 '10008_airplane.png',
 '10009_frog.png',
 '1000_truck.png',
 '10010_airplane.png',
 '10011_cat.png',
 '10012_frog.png',
 '10013_frog.png',
 '10014_dog.png',
 '10015_deer.png',
 '10016_ship.png',
 '10017_cat.png',
 '10018_bird.png',
 '10019_frog.png',
 '1001_deer.png',
 '10020_airplane.png',
 '10021_cat.png',
 '10022_automobile.png',
 '10023_deer.png',
 '10024_airplane.png',
 '10025_frog.png',
 '10026_frog.png',
 '10027_bird.png',
 '10028_horse.png',
 '10029_frog.png',
 '1002_cat.png',
 '10030_truck.png',
 '10031_airplane.png',
 '10032_deer.png',
 '10033_dog.png',
 '10034_horse.png',
 '10035_automobile.png',
 '10036_frog.png',
 '10037_horse.png',
 '10038_truck.png',
 '10039_automobile.png',
 '1003_bird.png',
 '10040_horse.png',
 '10041_horse.png',
 '10042_ship.png',
 '10043_airplane.png',
 '10044_cat.png',
 '10045_ho

In [8]:
# Test-image folder; cifar_dir ends with "/", so no separator is added.
test_dir = f"{cifar_dir}test"
os.listdir(test_dir)

['0_cat.png',
 '1000_dog.png',
 '1001_airplane.png',
 '1002_ship.png',
 '1003_deer.png',
 '1004_ship.png',
 '1005_automobile.png',
 '1006_automobile.png',
 '1007_ship.png',
 '1008_truck.png',
 '1009_frog.png',
 '100_deer.png',
 '1010_airplane.png',
 '1011_ship.png',
 '1012_frog.png',
 '1013_automobile.png',
 '1014_cat.png',
 '1015_deer.png',
 '1016_automobile.png',
 '1017_frog.png',
 '1018_airplane.png',
 '1019_dog.png',
 '101_dog.png',
 '1020_automobile.png',
 '1021_automobile.png',
 '1022_airplane.png',
 '1023_airplane.png',
 '1024_cat.png',
 '1025_dog.png',
 '1026_airplane.png',
 '1027_airplane.png',
 '1028_frog.png',
 '1029_frog.png',
 '102_frog.png',
 '1030_cat.png',
 '1031_cat.png',
 '1032_frog.png',
 '1033_cat.png',
 '1034_frog.png',
 '1035_frog.png',
 '1036_airplane.png',
 '1037_horse.png',
 '1038_bird.png',
 '1039_bird.png',
 '103_cat.png',
 '1040_horse.png',
 '1041_dog.png',
 '1042_dog.png',
 '1043_bird.png',
 '1044_ship.png',
 '1045_dog.png',
 '1046_bird.png',
 '1047_automob

- label이 파일 이름에 적혀있다(dog, frog 등)
- path를 넘겨야 label 정보를 추출할 수 있다

## 2. label 정보 처리
- labels.txt 파일을 사용해 label의 인덱스 번호를 알아낼 예정


In [10]:
# Read the class names, one per line; a name's position in the list is its
# integer label. The explicit encoding avoids depending on the platform
# default, and blank lines are skipped so they cannot become "" labels.
with open(os.path.join(cifar_dir, "labels.txt"), encoding="utf-8") as f:
    label_list = [line.strip() for line in f if line.strip()]

label_list

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

- `cifar_dir`이 `/`로 끝나기 때문에 `os.path.join(cifar_dir, "labels.txt")`의 결과는 `cifar_dir + "labels.txt"`와 같다

In [11]:
# Look up the integer index of the class name "deer" in label_list.
label_list.index("deer")

4

In [13]:
# Collect the path of every PNG file in the train and test folders.
train_paths, test_paths = (
    glob(f"{folder}/*.png") for folder in (train_dir, test_dir)
)

In [24]:
class Dataset(torch.utils.data.Dataset):
    """Image dataset whose class label is encoded in each file name.

    File names are expected to look like ``<id>_<label>.png`` (for example
    ``10006_deer.png``). The integer label returned for an item is the
    position of ``<label>`` in the list of class names.

    Args:
        data_paths: Sequence of image file paths.
        transform: Optional callable applied to the loaded PIL image.
        labels: Optional list/tuple of class names used to map a label name
            to its index. When omitted, the module-level ``label_list`` is
            looked up at access time.
    """

    def __init__(self, data_paths, transform=None, labels=None):
        # The zero-argument form really runs the parent initializer;
        # ``super(Dataset).__init__()`` only initialised an unbound super
        # object and never reached the base class.
        super().__init__()
        self.data_paths = data_paths
        self.transform = transform
        self.labels = labels

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.data_paths)

    def _label_index(self, path):
        """Return the integer label parsed from the file name of ``path``.

        Raises:
            ValueError: If the parsed label name is not a known class name.
        """
        # Parse only the file name, so underscores or dots in parent folder
        # names cannot influence the result.
        stem = os.path.splitext(os.path.basename(path))[0]
        label_name = stem.rsplit("_", 1)[-1].strip()
        names = label_list if self.labels is None else self.labels
        try:
            return names.index(label_name)
        except ValueError:
            raise ValueError(
                f"label {label_name!r} parsed from {path!r} is not a known class name"
            ) from None

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the path at index ``idx``.

        The image is the loaded PIL image, passed through ``transform`` when
        one was given; the label is the integer class index.
        """
        path = self.data_paths[idx]
        label = self._label_index(path)

        # Read the pixel data while the file is open so the file handle is
        # released right away instead of lingering on a lazily loaded image.
        with Image.open(path) as image:
            image.load()

        if self.transform is not None:
            image = self.transform(image)

        return image, label

## 3. 데이터 불러오기

In [25]:
# Number of samples per mini-batch.
batch_size = 32
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [26]:
# Training batches: images converted to tensors, order reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=Dataset(train_paths, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

# Test batches: same conversion, files kept in their listed order.
test_loader = torch.utils.data.DataLoader(
    dataset=Dataset(test_paths, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=False,
)

In [27]:
# Pull the first batch from the training loader: x holds the images, y the labels.
x, y = next(iter(train_loader))

In [28]:
# Show the shapes of the image batch and the label batch.
x.shape, y.shape

(torch.Size([32, 3, 32, 32]), torch.Size([32]))