# data.py (180 lines, 158 loc, 9.15 KB)
# Note: the source repository was archived by its owner on Aug 18, 2020 and is now read-only.
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/08_vision.data.ipynb (unless otherwise specified).
# Names exported by `from <this module> import *`.
__all__ = [
    'get_grid', 'clip_remove_empty', 'bb_pad',
    'ImageBlock', 'MaskBlock', 'PointBlock', 'BBoxBlock', 'BBoxLblBlock',
    'ImageDataLoaders', 'SegmentationDataLoaders',
]
# Cell
from ..torch_basics import *
from ..data.all import *
from .core import *
# Cell
@delegates(subplots)
def get_grid(n, nrows=None, ncols=None, add_vert=0, figsize=None, double=False, title=None, return_fig=False, **kwargs):
    "Return a grid of `n` axes, `nrows` by `ncols` (as `(fig,axs)` when `return_fig`)"
    # Derive whichever dimension is missing from the one that was supplied, so a lone `ncols`
    # is honoured instead of being paired with an unrelated sqrt-based `nrows` (which could
    # yield a grid with fewer than `n` cells). `max(1, ...)` keeps `n=0` from dividing by
    # zero or requesting an empty grid.
    # `add_vert` is accepted for interface compatibility but is not used in this function.
    if nrows:
        ncols = ncols or max(1, int(np.ceil(n/nrows)))
    elif ncols:
        nrows = max(1, int(np.ceil(n/ncols)))
    else:
        nrows = max(1, int(math.sqrt(n)))
        ncols = max(1, int(np.ceil(n/nrows)))
    if double:
        # Two axes per item: twice the columns and twice the axes returned.
        ncols *= 2
        n *= 2
    fig,axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    # Keep the first `n` axes; surplus cells get `set_axis_off()` and are then dropped by the slice.
    axs = [ax if i<n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    if title is not None: fig.suptitle(title, weight='bold', size=14)
    return (fig,axs) if return_fig else axs
# Cell
def clip_remove_empty(bbox, label):
    "Clip bounding boxes to the [-1,1] image border and drop the empty ones together with their labels"
    bbox = torch.clamp(bbox, -1, 1)
    # The last axis is indexed as (0,1,2,3) with 2-0 and 3-1 taken as the two side lengths.
    side_a = bbox[...,2] - bbox[...,0]
    side_b = bbox[...,3] - bbox[...,1]
    # A box lying entirely outside the border collapses to zero area once clamped,
    # so the test must be `<=`; a strict `<` would keep those degenerate boxes.
    empty = (side_a*side_b <= 0.)
    return (bbox[~empty], label[~empty])
# Cell
def bb_pad(samples, pad_idx=0):
    "Clip the bboxes of every item in `samples`, then pad bboxes with zeros and labels with `pad_idx` to a common length."
    cleaned = [(sample[0], *clip_remove_empty(*sample[1:])) for sample in samples]
    # Target length is the largest label count left after clipping.
    longest = max([len(item[2]) for item in cleaned])
    padded = []
    for img, boxes, labels in cleaned:
        box_fill = boxes.new_zeros(longest - boxes.shape[0], 4)
        boxes = torch.cat([boxes, box_fill])
        label_fill = labels.new_zeros(longest - labels.shape[0]) + pad_idx
        labels = torch.cat([labels, label_fill])
        padded.append((img, boxes, labels))
    return padded
# Cell
@typedispatch
def show_batch(x:TensorImage, y, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show `samples` with image inputs: build a grid of at most `max_n` axes when `ctxs` is missing, then defer to the `object` dispatch"
    if ctxs is None:
        n_axes = min(len(samples), max_n)
        ctxs = get_grid(n_axes, nrows=nrows, ncols=ncols, figsize=figsize)
    return show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
# Cell
@typedispatch
def show_batch(x:TensorImage, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show image-to-image `samples`: element 0 of each sample on the even-indexed axes, element 1 on the odd-indexed ones"
    if ctxs is None:
        n_pairs = min(len(samples), max_n)
        ctxs = get_grid(n_pairs, nrows=nrows, ncols=ncols, add_vert=1, figsize=figsize, double=True)
    for pos in (0, 1):
        shown = []
        # `range(max_n)` only caps how many items are drawn.
        for item, ax, _ in zip(samples.itemgot(pos), ctxs[pos::2], range(max_n)):
            shown.append(item.show(ctx=ax, **kwargs))
        ctxs[pos::2] = shown
    return ctxs
# Cell
def ImageBlock(cls=PILImage):
    "Build a `TransformBlock` whose items are created with `cls.create`"
    block_args = dict(type_tfms=cls.create, batch_tfms=IntToFloatTensor)
    return TransformBlock(**block_args)
# Cell
def MaskBlock(codes=None):
    "Build a `TransformBlock` for segmentation masks, passing `codes` (if any) to `AddMaskCodes`"
    add_codes = AddMaskCodes(codes=codes)
    return TransformBlock(type_tfms=PILMask.create, item_tfms=add_codes, batch_tfms=IntToFloatTensor)
# Cell
# Ready-made block instances; each one is documented right after it is created.
PointBlock = TransformBlock(type_tfms=TensorPoint.create, item_tfms=PointScaler)
PointBlock.__doc__ = "A `TransformBlock` for points in an image"

# Bounding boxes additionally register `bb_pad` as the `before_batch` step.
BBoxBlock = TransformBlock(
    type_tfms=TensorBBox.create,
    item_tfms=PointScaler,
    dls_kwargs={'before_batch': bb_pad},
)
BBoxBlock.__doc__ = "A `TransformBlock` for bounding boxes in an image"
# Cell
def BBoxLblBlock(vocab=None, add_na=True):
    "Build a `TransformBlock` for bounding-box labels, optionally with a fixed `vocab`"
    categorize = MultiCategorize(vocab=vocab, add_na=add_na)
    return TransformBlock(type_tfms=categorize, item_tfms=BBoxLabeler)
# Cell
class ImageDataLoaders(DataLoaders):
    "Wrapper around several `DataLoader`s offering factory classmethods for computer vision problems"
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_folder(cls, path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, item_tfms=None,
                    batch_tfms=None, **kwargs):
        "Build from an imagenet-style `path` with `train` and `valid` subfolders, or split randomly when `valid_pct` is given"
        if valid_pct is None:
            splitter = GrandparentSplitter(train_name=train, valid_name=valid)
        else:
            splitter = RandomSplitter(valid_pct, seed=seed)
        # NOTE(review): this is a truthiness test, not `is None`, so `valid_pct=0` gets the random
        # splitter above but still lists only the `train`/`valid` folders — confirm that is intended.
        if valid_pct:
            get_items = get_image_files
        else:
            get_items = partial(get_image_files, folders=[train, valid])
        blocks = (ImageBlock, CategoryBlock(vocab=vocab))
        dblock = DataBlock(blocks=blocks, get_items=get_items, splitter=splitter, get_y=parent_label,
                           item_tfms=item_tfms, batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, path, path=path, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_path_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, item_tfms=None, batch_tfms=None, **kwargs):
        "Build from the file names `fnames`, labelling each one with `label_func` and splitting randomly by `valid_pct`"
        splitter = RandomSplitter(valid_pct, seed=seed)
        dblock = DataBlock(blocks=(ImageBlock, CategoryBlock), splitter=splitter, get_y=label_func,
                           item_tfms=item_tfms, batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, fnames, path=path, **kwargs)

    @classmethod
    def from_name_func(cls, path, fnames, label_func, **kwargs):
        "Build from `fnames`, applying `label_func` to the `name` attribute of each item"
        name_labeller = using_attr(label_func, 'name')
        return cls.from_path_func(path, fnames, name_labeller, **kwargs)

    @classmethod
    def from_path_re(cls, path, fnames, pat, **kwargs):
        "Build from `fnames`, labelling each item with a `RegexLabeller` built from `pat`"
        labeller = RegexLabeller(pat)
        return cls.from_path_func(path, fnames, labeller, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_name_re(cls, path, fnames, pat, **kwargs):
        "Build from `fnames`, applying a `RegexLabeller` built from `pat` to the `name` attribute of each item"
        labeller = RegexLabeller(pat)
        return cls.from_name_func(path, fnames, labeller, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_df(cls, df, path='.', valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None,
                y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, **kwargs):
        "Build from `df`, reading file names from `fn_col` and labels from `label_col`"
        # Prefix handed to the file-name reader: `path` (plus `folder` when given) followed by the OS separator.
        base = Path(path) if folder is None else Path(path)/folder
        pref = f'{base}{os.path.sep}'
        if y_block is None:
            # Several label columns or a label delimiter select the multi-category block.
            several_cols = is_listy(label_col) and len(label_col) > 1
            if several_cols or label_delim is not None:
                y_block = MultiCategoryBlock
            else:
                y_block = CategoryBlock
        if valid_col is None:
            splitter = RandomSplitter(valid_pct, seed=seed)
        else:
            splitter = ColSplitter(valid_col)
        dblock = DataBlock(blocks=(ImageBlock, y_block),
                           get_x=ColReader(fn_col, pref=pref, suff=suff),
                           get_y=ColReader(label_col, label_delim=label_delim),
                           splitter=splitter,
                           item_tfms=item_tfms,
                           batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, df, path=path, **kwargs)

    @classmethod
    def from_csv(cls, path, csv_fname='labels.csv', header='infer', delimiter=None, **kwargs):
        "Read `path/csv_fname` with pandas and build from the resulting frame via `from_df`"
        csv_path = Path(path)/csv_fname
        df = pd.read_csv(csv_path, header=header, delimiter=delimiter)
        return cls.from_df(df, path=path, **kwargs)

    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_lists(cls, path, fnames, labels, valid_pct=0.2, seed:int=None, y_block=None, item_tfms=None, batch_tfms=None,
                   **kwargs):
        "Build from parallel lists of `fnames` and `labels`"
        if y_block is None:
            # The target block is inferred from the first label only.
            first = labels[0]
            if is_listy(first) and len(first) > 1:
                y_block = MultiCategoryBlock
            elif isinstance(first, float):
                y_block = RegressionBlock
            else:
                y_block = CategoryBlock
        splitter = RandomSplitter(valid_pct, seed=seed)
        dblock = DataBlock.from_columns(blocks=(ImageBlock, y_block), splitter=splitter,
                                        item_tfms=item_tfms, batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, (fnames, labels), path=path, **kwargs)
# Re-apply `delegates` to the thin wrapper factories, each one pointing at the factory it forwards `**kwargs` to.
# NOTE(review): presumably this exposes the target's keyword arguments in the wrapper's signature — confirm
# against fastcore's `delegates`.
# `from_name_func` is handled before `from_name_re`, which delegates to it; keep this order unless `delegates`
# is confirmed to be order-independent.
ImageDataLoaders.from_csv = delegates(to=ImageDataLoaders.from_df)(ImageDataLoaders.from_csv)
ImageDataLoaders.from_name_func = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_name_func)
ImageDataLoaders.from_path_re = delegates(to=ImageDataLoaders.from_path_func)(ImageDataLoaders.from_path_re)
ImageDataLoaders.from_name_re = delegates(to=ImageDataLoaders.from_name_func)(ImageDataLoaders.from_name_re)
# Cell
class SegmentationDataLoaders(DataLoaders):
    "Wrapper around several `DataLoader`s offering a factory classmethod for segmentation problems"
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_label_func(cls, path, fnames, label_func, valid_pct=0.2, seed=None, codes=None, item_tfms=None, batch_tfms=None, **kwargs):
        "Build from the file names `fnames`, using `label_func` as `get_y` and a `MaskBlock` with `codes` as target."
        mask_block = MaskBlock(codes=codes)
        splitter = RandomSplitter(valid_pct, seed=seed)
        dblock = DataBlock(blocks=(ImageBlock, mask_block), splitter=splitter, get_y=label_func,
                           item_tfms=item_tfms, batch_tfms=batch_tfms)
        return cls.from_dblock(dblock, fnames, path=path, **kwargs)