-
Notifications
You must be signed in to change notification settings - Fork 527
/
dataset.py
200 lines (156 loc) · 5.91 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import glob
import scipy
import torch
import random
import numpy as np
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from PIL import Image
from scipy.misc import imread
from skimage.feature import canny
from skimage.color import rgb2gray, gray2rgb
from .utils import create_mask
class Dataset(torch.utils.data.Dataset):
    """Inpainting dataset yielding (image, grayscale, edge, mask) tensors.

    Image, external-edge and mask file lists are read with ``load_flist``.
    Behaviour is driven by ``config``:

    * ``INPUT_SIZE`` -- square output side length; 0 disables resizing.
    * ``SIGMA`` -- Canny sigma; 0 draws a random integer sigma in [1, 4] per
      item, -1 yields an all-zero edge map (Canny branch only).
    * ``EDGE`` -- 1 computes Canny edges; any other value loads external edges.
    * ``MASK`` -- mask source: 1 random block, 2 half, 3 external (random
      file), 4 = 1 or 3, 5 = one of 1-3, 6 external indexed one-to-one.
    * ``NMS`` -- when 1, external edges are multiplied by a Canny edge map.
    * ``MODE`` -- when 2 (test mode), the mask type is forced to 6.
    """

    def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True):
        super(Dataset, self).__init__()
        self.augment = augment
        self.training = training
        self.data = self.load_flist(flist)
        self.edge_data = self.load_flist(edge_flist)
        self.mask_data = self.load_flist(mask_flist)

        self.input_size = config.INPUT_SIZE
        self.sigma = config.SIGMA
        self.edge = config.EDGE
        self.mask = config.MASK
        self.nms = config.NMS

        # in test mode, there's a one-to-one relationship between mask and image;
        # masks are loaded non-randomly
        if config.MODE == 2:
            self.mask = 6

    def __len__(self):
        """Return the number of images in the image file list."""
        return len(self.data)

    def __getitem__(self, index):
        """Load item ``index``; on failure, report it and fall back to item 0."""
        try:
            item = self.load_item(index)
        except Exception:
            # Best effort: keep training running past an unreadable file.
            # (Narrowed from a bare except so Ctrl-C / SystemExit still propagate.)
            print('loading error: ' + self.data[index])
            item = self.load_item(0)

        return item

    def load_name(self, index):
        """Return the base file name of image ``index``."""
        name = self.data[index]
        return os.path.basename(name)

    def load_item(self, index):
        """Build the (image, grayscale, edge, mask) tensor tuple for ``index``."""
        size = self.input_size

        # load image
        img = imread(self.data[index])

        # gray to rgb
        if len(img.shape) < 3:
            img = gray2rgb(img)

        # resize/crop if needed
        if size != 0:
            img = self.resize(img, size, size)

        # create grayscale image
        img_gray = rgb2gray(img)

        # load mask
        mask = self.load_mask(img, index)

        # load edge (needs the mask so test-mode Canny ignores masked regions)
        edge = self.load_edge(img_gray, index, mask)

        # augment data: horizontal flip of all four maps together, with p = 0.5
        if self.augment and np.random.binomial(1, 0.5) > 0:
            img = img[:, ::-1, ...]
            img_gray = img_gray[:, ::-1, ...]
            edge = edge[:, ::-1, ...]
            mask = mask[:, ::-1, ...]

        return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask)

    def load_edge(self, img, index, mask):
        """Return the edge map for grayscale ``img``.

        With ``self.edge == 1`` edges come from Canny (float array); otherwise
        the external edge file ``self.edge_data[index]`` is loaded and resized
        to ``img``'s height/width, optionally gated by Canny when ``nms == 1``.
        """
        sigma = self.sigma

        # in test mode images are masked (with masked regions),
        # using 'mask' parameter prevents canny to detect edges for the masked regions
        # (builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24)
        mask = None if self.training else (1 - mask / 255).astype(bool)

        # canny
        if self.edge == 1:
            # no edge
            if sigma == -1:
                return np.zeros(img.shape).astype(float)

            # random sigma
            if sigma == 0:
                sigma = random.randint(1, 4)

            return canny(img, sigma=sigma, mask=mask).astype(float)

        # external
        else:
            imgh, imgw = img.shape[0:2]
            edge = imread(self.edge_data[index])
            edge = self.resize(edge, imgh, imgw)

            # non-max suppression
            # NOTE(review): ``sigma`` is passed through unchanged here; the 0 (random)
            # and -1 (disabled) special cases above only apply to the Canny branch.
            if self.nms == 1:
                edge = edge * canny(img, sigma=sigma, mask=mask)

            return edge

    def load_mask(self, img, index):
        """Return a mask sized to ``img`` according to ``self.mask``.

        Raises:
            ValueError: if the configured mask type is not one of 1-6.
        """
        imgh, imgw = img.shape[0:2]
        mask_type = self.mask

        # external + random block
        if mask_type == 4:
            mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3

        # external + random block + half
        elif mask_type == 5:
            mask_type = np.random.randint(1, 4)

        # random block
        if mask_type == 1:
            return create_mask(imgw, imgh, imgw // 2, imgh // 2)

        # half
        if mask_type == 2:
            # randomly choose right or left
            return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)

        # external
        if mask_type == 3:
            mask_index = random.randint(0, len(self.mask_data) - 1)
            mask = imread(self.mask_data[mask_index])
            mask = self.resize(mask, imgh, imgw)
            # NOTE(review): unlike type 6 there is no rgb2gray here -- presumably
            # mask files are single-channel; confirm.
            mask = (mask > 0).astype(np.uint8) * 255       # threshold due to interpolation
            return mask

        # test mode: load mask non random
        if mask_type == 6:
            mask = imread(self.mask_data[index])
            mask = self.resize(mask, imgh, imgw, centerCrop=False)
            mask = rgb2gray(mask)
            mask = (mask > 0).astype(np.uint8) * 255
            return mask

        # Previously fell through and returned None, which only failed later
        # inside to_tensor with an unrelated-looking error.
        raise ValueError('unsupported mask type: {}'.format(mask_type))

    def to_tensor(self, img):
        """Convert a numpy image to a float torch tensor via PIL + torchvision."""
        img = Image.fromarray(img)
        img_t = F.to_tensor(img).float()
        return img_t

    def resize(self, img, height, width, centerCrop=True):
        """Resize ``img`` to (height, width), center-cropping to a square first.

        The crop only happens when ``centerCrop`` is true and ``img`` is not
        already square.
        """
        imgh, imgw = img.shape[0:2]

        if centerCrop and imgh != imgw:
            # center crop
            side = np.minimum(imgh, imgw)
            j = (imgh - side) // 2
            i = (imgw - side) // 2
            img = img[j:j + side, i:i + side, ...]

        # NOTE(review): scipy.misc.imresize (and the module's scipy.misc.imread)
        # were removed from recent SciPy releases; porting to PIL/imageio is needed
        # to run on a modern SciPy.
        img = scipy.misc.imresize(img, [height, width])

        return img

    def load_flist(self, flist):
        """Resolve ``flist`` into a sequence of file paths.

        ``flist`` may be a list (returned as-is), a directory (its ``*.jpg`` and
        ``*.png`` files, sorted), a text file with one path per line, or a
        single file path. Anything else yields an empty list.
        """
        if isinstance(flist, list):
            return flist

        # flist: image file path, image directory path, text file flist path
        if isinstance(flist, str):
            if os.path.isdir(flist):
                flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
                flist.sort()
                return flist

            if os.path.isfile(flist):
                try:
                    # genfromtxt squeezes a one-line file to a 0-d array (no len(),
                    # not indexable); atleast_1d keeps it a 1-d array of paths.
                    # Builtin ``str``: the ``np.str`` alias was removed in NumPy 1.24.
                    return np.atleast_1d(np.genfromtxt(flist, dtype=str, encoding='utf-8'))
                except Exception:
                    # not a readable text list (e.g. an image file): treat as one path
                    return [flist]

        return []

    def create_iterator(self, batch_size):
        """Yield batches forever, building a fresh DataLoader for each pass."""
        while True:
            sample_loader = DataLoader(
                dataset=self,
                batch_size=batch_size,
                drop_last=True
            )

            for item in sample_loader:
                yield item