In [2]:
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms



# Dataset class for the dog and cat image files.
class DogCat2DImageDataset(Dataset):
  """In-memory dataset of one dog image and three cat images.

  All image files are read once in ``__init__``, resized to 256x256,
  converted to tensors and stacked into ``self.images`` with shape
  (N, 3, 256, 256). ``self.image_labels`` has shape (N, 1) and uses
  0 for "dog" and 1 for "cat".

  The data directories are resolved relative to the current working
  directory (two levels up, then ``_00_data``), so the class must be
  instantiated from the expected location in the project tree.
  """

  def __init__(self):
    super().__init__()

    # Resize every image to a common 256x256 size, then convert it to a tensor.
    self.image_transforms = transforms.Compose([
      transforms.Resize(size=(256, 256)),
      transforms.ToTensor()
    ])

    # Locations of the raw image files.
    dogs_dir = os.path.join(os.path.pardir, os.path.pardir, "_00_data", "a_image-dog")
    cats_dir = os.path.join(os.path.pardir, os.path.pardir, "_00_data", "b_image-cats")

    image_paths = [
      os.path.join(dogs_dir, "bobby.jpg"),  # originally (1280, 720, 3)
      os.path.join(cats_dir, "cat1.png"),  # originally (256, 256, 3)
      os.path.join(cats_dir, "cat2.png"),  # originally (256, 256, 3)
      os.path.join(cats_dir, "cat3.png"),  # originally (256, 256, 3)
    ]

    image_tensors = []
    for image_path in image_paths:
      # Image.open is lazy and keeps the file handle open; the context
      # manager releases it as soon as the pixels have been transformed.
      with Image.open(image_path) as img:
        # Force 3 channels so torch.stack below cannot fail on an RGBA,
        # palette or grayscale file; images that are already RGB are unchanged.
        image_tensors.append(self.image_transforms(img.convert("RGB")))

    # Stack along a new leading dimension: (N, 3, 256, 256).
    self.images = torch.stack(image_tensors, dim=0)

    # 0: "dog", 1: "cat" -- same order as image_paths.
    self.image_labels = torch.tensor([[0], [1], [1], [1]])

  def __len__(self):
    """Return the total number of images."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return the ``(image, label)`` pair stored at position ``idx``."""
    return self.images[idx], self.image_labels[idx]

  def __str__(self):
    """Return a one-line summary: sample count, input shape and target shape."""
    return "Data Size: {0}, Input Shape: {1}, Target Shape: {2}".format(
      len(self.images), self.images.shape, self.image_labels.shape
    )


# Demo script: build the dataset, inspect it, split it and iterate in batches.
if __name__ == "__main__":
  dog_cat_2d_image_dataset = DogCat2DImageDataset()

  # Print the dataset summary (sample count, input shape, target shape).
  print(dog_cat_2d_image_dataset)

  print("#" * 50, 1)

  # Print every sample. The loop variable is deliberately not called `input`:
  # names bound here are module-level (notebook-global) and would shadow the
  # builtin input() for all code that runs afterwards.
  for idx, (image, target) in enumerate(dog_cat_2d_image_dataset):
    print(f"{idx} - {image.shape}: {target}")

  # Split the dataset 7:3 into train and test subsets (fractional lengths).
  train_dataset, test_dataset = random_split(dog_cat_2d_image_dataset, [0.7, 0.3])

  print("#" * 50, 2)

  # Print the size of each split.
  print(len(train_dataset), len(test_dataset))

  print("#" * 50, 3)

  # Wrap the training split in a DataLoader: shuffled mini-batches of 2.
  train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True
  )

  # Print the shape of each batch's images together with its targets.
  for idx, (images, targets) in enumerate(train_data_loader):
    print(f"{idx} - {images.shape}: {targets}")


Data Size: 4, Input Shape: torch.Size([4, 3, 256, 256]), Target Shape: torch.Size([4, 1])
################################################## 1
0 - torch.Size([3, 256, 256]): tensor([0])
1 - torch.Size([3, 256, 256]): tensor([1])
2 - torch.Size([3, 256, 256]): tensor([1])
3 - torch.Size([3, 256, 256]): tensor([1])
################################################## 2
3 1
################################################## 3
0 - torch.Size([2, 3, 256, 256]): tensor([[0],
        [1]])
1 - torch.Size([1, 3, 256, 256]): tensor([[1]])
