-
Notifications
You must be signed in to change notification settings - Fork 4
/
datasets.py
172 lines (136 loc) · 5.38 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import glob
import random
import os
import warnings
import numpy as np
from PIL import Image
from PIL import ImageFile
# Let PIL load images whose file data is truncated on disk instead of
# raising an OSError mid-training.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def pad_to_square(img, pad_value):
    """Pad a CHW image tensor with a constant value so that height == width.

    Returns the padded tensor together with the pad tuple that was applied
    (left, right, top, bottom), as expected by ``F.pad``.
    """
    _, height, width = img.shape
    diff = np.abs(height - width)
    # Split the difference between the two sides; the extra pixel (for odd
    # differences) goes on the second side.
    before = diff // 2
    after = diff - before
    if height <= width:
        # Wide (or square) image: pad the rows (top / bottom).
        pad = (0, 0, before, after)
    else:
        # Tall image: pad the columns (left / right).
        pad = (before, after, 0, 0)
    padded = F.pad(img, pad, "constant", value=pad_value)
    return padded, pad
def resize(image, size):
    """Nearest-neighbour resize of a CHW tensor to a square (size, size)."""
    # interpolate wants a batch dimension, so add one and strip it again.
    batched = image.unsqueeze(0)
    scaled = F.interpolate(batched, size=size, mode="nearest")
    return scaled.squeeze(0)
class ImageFolder(Dataset):
    """Dataset over every file in a single folder, for label-free inference.

    Targets are zero placeholders so the same transform pipeline used for
    training can be applied unchanged.
    """

    def __init__(self, folder_path, transform=None):
        # Deterministic ordering so indices are stable across runs.
        self.files = sorted(glob.glob("%s/*.*" % folder_path))
        self.transform = transform

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        image = np.array(Image.open(path).convert('RGB'), dtype=np.uint8)
        # Placeholder targets (boxes + segmentation map) for the transform.
        dummy_boxes = np.zeros((1, 5))
        dummy_segmaps = np.zeros_like(image)
        if self.transform:
            image, _, _ = self.transform((image, dummy_boxes, dummy_segmaps))
        return path, image

    def __len__(self):
        return len(self.files)
class ListDataset(Dataset):
    """Detection + segmentation dataset driven by a text file of image paths.

    Each line of ``list_path`` is an image path that must contain a folder
    named ``images``; the matching label file lives under ``labels`` (as
    ``.txt``) and the segmentation mask under ``segmentations`` (as ``.png``).
    """

    def __init__(self, list_path, img_size=416, multiscale=True, transform=None):
        with open(list_path, "r") as file:
            self.img_files = file.readlines()

        # Derived companion paths share the same logic; build both lists
        # with a single helper instead of two copy-paste loops.
        self.label_files = [
            self._companion_file(path, "labels", '.txt')
            for path in self.img_files
        ]
        self.mask_files = [
            self._companion_file(path, "segmentations", '.png')
            for path in self.img_files
        ]

        self.img_size = img_size
        self.max_objects = 100
        self.multiscale = multiscale
        # Multiscale training samples sizes in +/- 3 strides of 32 pixels.
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0
        self.transform = transform

    @staticmethod
    def _companion_file(image_path, folder, extension):
        """Map .../images/x.jpg to .../<folder>/x<extension>.

        Only the last occurrence of 'images' in the directory is replaced,
        so image filenames containing 'images' are unaffected.
        """
        image_dir = os.path.dirname(image_path)
        target_dir = folder.join(image_dir.rsplit("images", 1))
        assert target_dir != image_dir, \
            f"Image path must contain a folder named 'images'! \n'{image_dir}'"
        target_file = os.path.join(target_dir, os.path.basename(image_path))
        # splitext also drops the trailing newline kept by readlines().
        return os.path.splitext(target_file)[0] + extension

    def __getitem__(self, index):
        """Return (img_path, img, bb_targets, mask_targets).

        Returns None (so collate_fn can drop the sample) when the image,
        label file, or mask cannot be read.
        """
        # ---------
        #  Image
        # ---------
        try:
            img_path = self.img_files[index % len(self.img_files)].rstrip()
            img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)
        except Exception:
            print(f"Could not read image '{img_path}'.")
            return

        # ---------
        #  Label
        # ---------
        try:
            label_path = self.label_files[index % len(self.img_files)].rstrip()
            # Ignore warning if file is empty
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                boxes = np.loadtxt(label_path).reshape(-1, 5)
        except Exception:
            print(f"Could not read label '{label_path}'.")
            return

        # ----------------------
        #  Segmentation Mask
        # ----------------------
        try:
            mask_path = self.mask_files[index % len(self.img_files)].rstrip()
            # // 127 maps the mask's 8-bit values down to small class ids.
            mask = np.array(Image.open(mask_path).convert('RGB')) // 127
        except FileNotFoundError:
            print(f"Could not load mask '{mask_path}'.")
            return

        # -----------
        #  Transform
        # -----------
        # NOTE(review): without a transform, bb_targets/mask_targets are
        # undefined below — this dataset assumes a transform is always given.
        if self.transform:
            try:
                img, bb_targets, mask_targets = self.transform(
                    (img, boxes, mask)
                )
            except Exception:
                print(f"Could not apply transform.")
                # Bare raise preserves the original traceback; the dead
                # `return` that followed the old `raise e` was removed.
                raise

        return img_path, img, bb_targets, mask_targets

    def collate_fn(self, batch):
        """Batch samples, resizing all images/masks to a common size."""
        self.batch_count += 1

        # Drop invalid images (None entries from failed __getitem__ calls).
        batch = [data for data in batch if data is not None]

        paths, imgs, bb_targets, mask_targets = list(zip(*batch))

        # Selects new image size every tenth batch (multiples of 32 only).
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(
                range(self.min_size, self.max_size + 1, 32))

        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])

        # Add sample index to targets so boxes can be matched to images
        # after concatenation.
        for i, boxes in enumerate(bb_targets):
            boxes[:, 0] = i
        bb_targets = torch.cat(bb_targets, 0)

        # Stack masks and drop the 2 duplicated channels
        mask_targets = torch.stack(
            [resize(mask, self.img_size)[0] for mask in mask_targets]).long()

        return paths, imgs, bb_targets, mask_targets

    def __len__(self):
        return len(self.img_files)