-
Notifications
You must be signed in to change notification settings - Fork 68
/
ffhq_dataset.py
88 lines (69 loc) · 2.84 KB
/
ffhq_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import random
import numpy as np
from PIL import Image
import imgaug as ia
import imgaug.augmenters as iaa
from data.image_folder import make_dataset
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from data.base_dataset import BaseDataset
from utils.utils import onehot_parse_map
class FFHQDataset(BaseDataset):
    """FFHQ face dataset yielding synthetic low-quality / high-quality pairs.

    Images are listed from ``<dataroot>/imgs1024`` and masks from
    ``<dataroot>/masks512``. Both listings are sorted and paired by index, so
    the i-th image is matched with the i-th mask.

    Each sample is a dict with keys:
        'HR'       -- the resized (possibly grayscaled) image as a tensor in [-1, 1]
        'LR'       -- a degraded copy produced by ``complex_imgaug``, in [-1, 1]
        'HR_paths' -- the path of the source image file
        'Mask'     -- float tensor built by ``onehot_parse_map`` from the mask
    """

    def __init__(self, opt):
        """Index the image and mask folders under ``opt.dataroot``.

        Args:
            opt: options object; this method reads ``Pimg_size``,
                ``Gin_size``, ``Gout_size``, ``isTrain`` and ``dataroot``.
        """
        BaseDataset.__init__(self, opt)
        # Final size of the degraded LR image (passed to complex_imgaug).
        self.img_size = opt.Pimg_size
        # Not read within this class; presumably used elsewhere -- verify.
        self.lr_size = opt.Gin_size
        # Side length that HR images and masks are resized to.
        self.hr_size = opt.Gout_size
        self.shuffle = bool(opt.isTrain)
        # NOTE(review): images and masks are paired purely by sorted position,
        # so both folders must hold matching, identically ordered file sets.
        self.img_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'imgs1024')))
        self.mask_dataset = sorted(make_dataset(os.path.join(opt.dataroot, 'masks512')))
        # ToTensor scales uint8 HWC input to [0, 1]; Normalize then maps it to [-1, 1].
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Not used by __getitem__ below; kept for any external callers.
        self.random_crop = transforms.RandomCrop(self.hr_size)

    def __len__(self):
        """Return the number of images found under ``imgs1024``."""
        return len(self.img_dataset)

    def __getitem__(self, idx):
        """Load pair ``idx`` and synthesize its degraded LR counterpart."""
        img_path = self.img_dataset[idx]
        mask_path = self.mask_dataset[idx]

        # Close the file handles deterministically; convert() returns a new image.
        with Image.open(img_path) as img_file:
            hr_img = img_file.convert('RGB')
        with Image.open(mask_path) as mask_file:
            mask_img = mask_file.convert('RGB')

        hr_img = hr_img.resize((self.hr_size, self.hr_size))
        # Grayscale is applied before degradation, so LR is derived from the
        # same (possibly gray) HR image. random_gray returns a numpy array.
        hr_img = random_gray(hr_img, p=0.3)
        # Intermediate resolution for complex_imgaug, drawn from [32, 255].
        scale_size = np.random.randint(32, 256)
        lr_img = complex_imgaug(hr_img, self.img_size, scale_size)

        # The mask is presumably a colour-coded label map (it is converted to
        # RGB and decoded by onehot_parse_map). Nearest-neighbour resampling
        # keeps every pixel an original label colour instead of blending
        # adjacent classes, as PIL's default interpolating filter would.
        mask_img = mask_img.resize((self.hr_size, self.hr_size), Image.NEAREST)
        mask_label = torch.tensor(onehot_parse_map(mask_img)).float()

        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)
        return {'HR': hr_tensor, 'LR': lr_tensor, 'HR_paths': img_path, 'Mask': mask_label}
def complex_imgaug(x, org_size, scale_size):
    """Return a randomly degraded copy of a single RGB image.

    The pipeline applies these steps in order:
      1. Optionally blurs the image (p=0.5, one of Gaussian/average/median/motion).
      2. Resizes it to ``scale_size`` with a randomly chosen interpolation.
      3. Optionally adds Gaussian noise (p=0.2).
      4. Optionally applies JPEG compression (p=0.7).
      5. Resizes it to ``org_size``.

    Args:
        x: one RGB image, either a PIL Image or an (H, W, C) array.
        org_size: size passed to the final ``iaa.Resize``.
        scale_size: size passed to the intermediate ``iaa.Resize``.

    Returns:
        The degraded image as a numpy array.
    """
    # imgaug processes batches, so wrap the image in a batch of one.
    batch = np.array(x)[None, :, :, :]

    # Augmenters are created in the same order as the original nested literal.
    # Presumably each one derives its random state at construction time, so
    # the order matters for seeded reproducibility. TODO: confirm.
    blur = iaa.OneOf([
        iaa.GaussianBlur((3, 15)),
        iaa.AverageBlur(k=(3, 15)),
        iaa.MedianBlur(k=(3, 15)),
        iaa.MotionBlur((5, 25)),
    ])
    degrade = iaa.Sequential([
        iaa.Sometimes(0.5, blur),
        iaa.Resize(scale_size, interpolation=ia.ALL),
        iaa.Sometimes(0.2, iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.1 * 255), per_channel=0.5)),
        iaa.Sometimes(0.7, iaa.JpegCompression(compression=(10, 65))),
        iaa.Resize(org_size),
    ])
    return degrade(images=batch)[0]
def random_gray(x, p=0.5):
    """Convert a single RGB image to grayscale with probability ``p``.

    Args:
        x: one RGB image, either a PIL Image or an (H, W, C) array.
        p: probability that the grayscale conversion is applied.

    Returns:
        The (possibly grayscaled) image as a numpy array.
    """
    # imgaug processes batches, so wrap the image in a batch of one.
    batch = np.array(x)[None, :, :, :]
    maybe_gray = iaa.Sometimes(p, iaa.Grayscale(alpha=1.0))
    return maybe_gray(images=batch)[0]