데이터로더 뼈대 코드

In [None]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from PIL import Image

class SilhouetteDataset(Dataset):
    """Paired dataset of silhouette images and their original images.

    Samples are matched by file name: the silhouette is read from
    ``<root_dir>/pikachu_sil_128`` and the original from
    ``<root_dir>/pikachu_128`` under the same name.

    Args:
        root_dir: Directory that contains the two image sub-folders.
        transform: Optional callable applied to the silhouette and to the
            original image in two separate calls. Because the calls are
            independent, a transform with random behaviour will not be
            synchronized between the two images of a pair.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Sub-folder names; edit these to match the dataset being used.
        self.silhouette_dir = os.path.join(root_dir, 'pikachu_sil_128')
        self.original_dir = os.path.join(root_dir, 'pikachu_128')

        # os.listdir returns entries in arbitrary order, so sort them to keep
        # sample indices (and any seeded split built on them) reproducible.
        # Non-file entries such as notebook checkpoint folders are skipped
        # because they cannot be opened as images.
        self.silhouette_filenames = sorted(
            name
            for name in os.listdir(self.silhouette_dir)
            if os.path.isfile(os.path.join(self.silhouette_dir, name))
        )

    def __len__(self):
        """Return the number of silhouette files found."""
        return len(self.silhouette_filenames)

    def __getitem__(self, idx):
        """Return the ``(silhouette, original)`` pair stored at ``idx``.

        Both images are passed through ``self.transform`` when one is set;
        otherwise the loaded PIL images are returned as they are.
        """
        filename = self.silhouette_filenames[idx]

        silhouette_image = self._load_image(
            os.path.join(self.silhouette_dir, filename)
        )
        original_image = self._load_image(
            os.path.join(self.original_dir, filename)
        )

        if self.transform:
            silhouette_image = self.transform(silhouette_image)
            original_image = self.transform(original_image)

        return silhouette_image, original_image

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` fully and release its file handle."""
        # Image.open is lazy; loading inside the context manager reads the
        # pixel data before the underlying file is closed on exit.
        with Image.open(path) as image:
            image.load()
        return image

# Root directory of the dataset; edit this path to match your environment.
root_dir = '/content/drive/MyDrive/whosthatpok/pikachu_128x128png'

# Transform pipeline used for preprocessing and data augmentation.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Build the dataset.
dataset = SilhouetteDataset(root_dir, transform=transform)

train_size = 0.8
# Split index positions instead of the dataset object. Handing the Dataset
# itself to train_test_split makes it index every element, which decodes all
# images up front and returns plain lists rather than lazy datasets.
indices = list(range(len(dataset)))
train_indices, test_indices = train_test_split(
    indices, train_size=train_size, shuffle=True, random_state=42
)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

# Build the data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

테스트 코드

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Build the dataset.
dataset = SilhouetteDataset(root_dir, transform=transform)

# Preview a few samples: silhouettes on the top row, originals on the bottom.
num_samples_to_show = 4
fig, axes = plt.subplots(2, num_samples_to_show, figsize=(10, 4))

for i, (top_ax, bottom_ax) in enumerate(zip(axes[0], axes[1])):
    silhouette_image, original_image = dataset[i]

    # Convert each tensor into an array for display; the transpose moves
    # axis 0 to the last position, which is the layout imshow is given here.
    silhouette_image = (
        vutils.make_grid(silhouette_image, normalize=True)
        .numpy()
        .transpose(1, 2, 0)
    )
    original_image = (
        vutils.make_grid(original_image, normalize=True)
        .numpy()
        .transpose(1, 2, 0)
    )

    top_ax.imshow(silhouette_image, cmap='gray')
    top_ax.axis('off')
    top_ax.set_title('Silhouette')

    bottom_ax.imshow(original_image)
    bottom_ax.axis('off')
    bottom_ax.set_title('Original')

plt.show()