In [51]:
# Mount Google Drive (Colab only) so the dataset folders under /content/drive can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [62]:
# import
import torch
import zipfile, os
import glob
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split, Subset, ConcatDataset
import torchvision.transforms as transforms

In [63]:
# Preprocessing pipeline applied to every image before it is handed to the model.
# NOTE(review): the 178 px centre crop looks tuned to CelebA's aligned images —
# confirm it also frames the FFHQ images sensibly.
transform = transforms.Compose(
    [
        # Keep the central 178x178 region.
        transforms.CenterCrop(size=178),
        # Rescale the crop to 64x64.
        transforms.Resize(size=(64, 64)),
        # PIL image -> tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel, i.e. normalise to the -1 ~ 1 range.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# image dataset
class ImageDataset(Dataset):
  """Unlabelled image dataset built from the image files directly inside one directory.

  Args:
    img_dir: Directory that is searched (non-recursively) for image files.
    transform: Optional callable applied to each loaded RGB PIL image.
    extensions: File extensions (without the leading dot) to collect.
      Defaults to ("jpg", "png"), which matches the original behaviour.

  Each item is the transformed image only; no label is returned.
  """

  def __init__(self, img_dir, transform=None, extensions=("jpg", "png")):
    matched = []
    for ext in extensions:
      matched.extend(glob.glob(os.path.join(img_dir, f"*.{ext}")))
    # Sort so that an index always refers to the same file across runs;
    # persisted index files are only meaningful with a stable ordering.
    self.img_paths = sorted(matched)
    self.transform = transform

  def __len__(self):
    """Return the number of image files that were found."""
    return len(self.img_paths)

  def __getitem__(self, idx):
    """Load the image at position `idx`, convert it to RGB and apply the transform."""
    # The context manager closes the underlying file handle deterministically;
    # convert() returns an independent image that stays valid afterwards.
    with Image.open(self.img_paths[idx]) as img:
      image = img.convert("RGB")
    # Compare against None explicitly so a transform object that happens to be
    # falsy (e.g. an empty container) is still applied.
    if self.transform is not None:
      image = self.transform(image)
    return image


In [73]:
# Build one dataset per source directory on the mounted Drive; both share the same preprocessing.
CelebADataset = ImageDataset(
    img_dir="/content/drive/MyDrive/Colab Notebooks/Proj. DeepFake/dataset/CelebA/",
    transform=transform,
)
FFHQDataset = ImageDataset(
    img_dir="/content/drive/MyDrive/Colab Notebooks/Proj. DeepFake/dataset/FFHQ/",
    transform=transform,
)

# Chain the two sources into a single dataset (CelebA items first, then FFHQ)
# and enumerate every valid index into it.
full_dataset = ConcatDataset([CelebADataset, FFHQDataset])
full_indices = [*range(len(full_dataset))]

In [78]:
# Report how many images each source contributed, followed by the combined total.
for _label, _dataset in (
    ("CelebA 이미지 수:", CelebADataset),
    ("FFHQ 이미지 수:", FFHQDataset),
    ("합쳐진 전체 이미지 수:", full_dataset),
):
  print(_label, len(_dataset))

# Show the concrete container types of the combined and the single-source dataset.
for _dataset in (full_dataset, CelebADataset):
  print(type(_dataset))

CelebA 이미지 수: 13385
FFHQ 이미지 수: 881
합쳐진 전체 이미지 수: 14266
<class 'torch.utils.data.dataset.ConcatDataset'>
<class '__main__.ImageDataset'>


In [79]:
# Sampling configuration: draw (up to) 14266 images; batches of 32 are used by the loaders.
sample_size = 14266
batch_size = 32

# Never request more images than the combined dataset actually holds.
sampled_count = min(len(full_dataset), sample_size)

# Seed the global `random` module so the same indices are drawn on every run.
random.seed(42)
sampled_indices = random.sample(full_indices, k=sampled_count)

# View over full_dataset restricted to the drawn indices.
subset_dataset = Subset(dataset=full_dataset, indices=sampled_indices)

# Persist the drawn indices so the selection can be reloaded later.
np.save(
    '/content/drive/MyDrive/Colab Notebooks/Proj. DeepFake/dataset/train_indices.npy',
    sampled_indices,
)


In [80]:
# Sanity check on the first sampled item: print its type, then its shape.
# Expected: <class 'torch.Tensor'> and (3, 64, 64).
sample = subset_dataset[0]
print(type(sample), sample.shape, sep="\n")

<class 'torch.Tensor'>
torch.Size([3, 64, 64])


In [82]:
import time, tqdm
from PIL import ImageFile
# Let PIL decode truncated (incompletely written) image files instead of raising an error,
# so a single damaged file does not abort a full pass over the dataset.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Measure loading speed: time one full pass over the sampled subset.
load_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
# perf_counter() is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments during the (multi-minute) pass.
start = time.perf_counter()
# Label the progress bar with the actual number of sampled images rather than a hard-coded 14266.
for _ in tqdm.tqdm(load_loader, desc=f"Loading {sampled_count} images"):
  pass  # batches are discarded on purpose; only the loading time matters here
print(f"Loaded {sampled_count} images in {(time.perf_counter()-start):.2f}s")


# Train/Val split: 80% of the sampled images for training, the remainder for validation.
train_img = int(0.8 * sampled_count)
val_img = sampled_count - train_img
# Pass a seeded generator so the train/val membership is identical on every run;
# without it random_split draws from torch's global RNG and the split changes
# each time the notebook is executed.
train_ds, val_ds = random_split(
    subset_dataset,
    [train_img, val_img],
    generator=torch.Generator().manual_seed(42),
)

# Build the DataLoaders: shuffle only the training data, keep validation order fixed.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# val_indices.npy: indices of full_dataset that were NOT drawn into sampled_indices.
# NOTE(review): these are not the indices of val_ds, and with sample_size equal to
# the size of the full dataset this list is empty — confirm what downstream code
# expects to find in this file.
# sorted() gives the file a canonical order instead of arbitrary set order.
unused_indices = sorted(set(full_indices) - set(sampled_indices))
# Force an integer dtype: np.save on an empty Python list would otherwise write a
# float64 array, which cannot be used for indexing after loading.
np.save(
    '/content/drive/MyDrive/Colab Notebooks/Proj. DeepFake/dataset/val_indices.npy',
    np.asarray(unused_indices, dtype=np.int64),
)

Loading 14266 images: 100%|██████████| 446/446 [02:23<00:00,  3.11it/s]

Loaded 14266 images in 143.34s



