-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
110 lines (104 loc) · 3.72 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torchvision.datasets as dset
import torchvision.transforms as transforms
from utils import Invert
from utils import Gray
# Root directory holding every dataset folder. It is a relative path, so it
# is resolved against the process's current working directory.
DATA_PATH = 'data'


def data_path(folder):
    """Return the location of ``folder`` inside :data:`DATA_PATH`.

    The result is the plain string produced by :func:`os.path.join`, e.g.
    ``data_path('celeba')`` gives ``'data/celeba'`` on POSIX systems.
    """
    parts = (DATA_PATH, folder)
    return os.path.join(*parts)
def _resize(size):
    """Return a transform that resizes the smaller image edge to ``size``.

    ``transforms.Scale`` is a deprecated alias of ``transforms.Resize`` and has
    been removed from later torchvision releases. Prefer ``Resize`` and fall
    back to ``Scale`` only on torchvision builds that predate ``Resize``.
    """
    resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
    return resize_cls(size)


def _image_folder(folder, spatial):
    """Load ``data_path(folder)`` as an RGB ImageFolder.

    ``spatial`` is an ordered list of image-level transforms (resize/crop)
    applied before ``ToTensor``. Each channel is then normalized with mean 0.5
    and std 0.5, which maps [0, 1] pixel values to [-1, 1].
    """
    return dset.ImageFolder(
        root=data_path(folder),
        transform=transforms.Compose(list(spatial) + [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )


# name -> (folder under DATA_PATH, resize of smaller edge, center-crop size).
# Each of these presets resizes first and then center-crops.
_RESIZE_CROP_PRESETS = {
    'coco': ('coco', 64, 64),
    'coco_256': ('coco', 256, 256),
    'footwear': ('shoes/ut-zap50k-images', 64, 64),
    'celeba_256': ('celeba', 256, 256),
    'celeba': ('celeba', 78, 64),
    'birds': ('birds/full', 64, 64),
}


def load_dataset(dataset_name, split='full', size=None):
    """Build a torchvision dataset by name from folders under DATA_PATH.

    Known names get a fixed preprocessing pipeline:

    * ``'mnist'``: torchvision MNIST in ``data/mnist`` (``download=True``,
      default ``train`` split), resized to 32, single-channel 0.5/0.5
      normalization.
    * ``'fonts'``: ImageFolder at ``data/fonts/full``; ToTensor, then the
      project's ``Invert`` and ``Gray`` transforms, then single-channel
      0.5/0.5 normalization.
    * ``'chairs'``: center crop 256, then resize to 128 (crop happens first).
    * names in ``_RESIZE_CROP_PRESETS``: resize the smaller edge, then
      center crop.

    Any other name is loaded as an ImageFolder at ``data_path(dataset_name)``
    with the smaller edge resized to ``size``.

    Args:
        dataset_name: Preset name, or a folder name under DATA_PATH.
        split: Currently unused. It is kept for caller compatibility.
        size: Only used by the fallback branch. It is converted with
            ``int()`` and defaults to 64 when ``None``.

    Returns:
        The constructed torchvision dataset.
    """
    if dataset_name == 'mnist':
        return dset.MNIST(
            root=data_path('mnist'),
            download=True,
            transform=transforms.Compose([
                _resize(32),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]),
        )

    if dataset_name == 'fonts':
        # No resize/crop here; a single-channel Normalize follows the
        # project's Gray transform.
        return dset.ImageFolder(
            root=data_path('fonts/full'),
            transform=transforms.Compose([
                transforms.ToTensor(),
                Invert(),
                Gray(),
                transforms.Normalize((0.5,), (0.5,)),
            ]),
        )

    if dataset_name == 'chairs':
        # Unlike the other presets, this one crops the 256 center first and
        # then downsizes.
        return _image_folder('chairs', [transforms.CenterCrop(256), _resize(128)])

    if dataset_name in _RESIZE_CROP_PRESETS:
        folder, resize_to, crop_to = _RESIZE_CROP_PRESETS[dataset_name]
        return _image_folder(folder, [_resize(resize_to), transforms.CenterCrop(crop_to)])

    # Fallback: treat the name as a folder under DATA_PATH.
    # NOTE(review): there is no CenterCrop here, so non-square images yield
    # tensors of differing shapes. Confirm that such folders hold square images.
    size = 64 if size is None else int(size)
    return _image_folder(dataset_name, [_resize(size)])