-
-
Notifications
You must be signed in to change notification settings - Fork 4.8k
/
Copy pathdataset_factory.py
186 lines (174 loc) · 6.83 KB
/
dataset_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
""" Dataset Factory
Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try:
from torchvision.datasets import Places365
has_places365 = True
except ImportError:
has_places365 = False
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
except ImportError:
has_inaturalist = False
try:
from torchvision.datasets import QMNIST
has_qmnist = True
except ImportError:
has_qmnist = False
try:
from torchvision.datasets import ImageNet
has_imagenet = True
except ImportError:
has_imagenet = False
from .dataset import IterableImageDataset, ImageDataset
# Lower-case dataset name -> torchvision dataset class, for the datasets that
# create_dataset() constructs with a boolean `train=` argument.
_TORCH_BASIC_DS = {
    'cifar10': CIFAR10,
    'cifar100': CIFAR100,
    'mnist': MNIST,
    'kmnist': KMNIST,
    'fashion_mnist': FashionMNIST,
}
# Accepted spellings for the train / eval splits. Only the keys are used
# (membership tests and in-order iteration); the None values are placeholders.
_TRAIN_SYNONYM = dict.fromkeys(('train', 'training'))
_EVAL_SYNONYM = dict.fromkeys(('val', 'valid', 'validation', 'eval', 'evaluation'))
def _search_split(root, split):
    """ Resolve a split-specific sub-folder of `root`, if one exists.

    Any slice suffix on the split (everything from the first '[') is ignored.
    The exact split name is tried first; if that folder is missing and the name
    is a known train / eval synonym, every synonym of that family is tried in
    declaration order.

    Args:
        root: base dataset folder
        split: split name, optionally followed by a '[...]' suffix

    Returns:
        Path of the first existing candidate folder, or `root` unchanged when
        nothing matches.
    """
    base_split = split.split('[')[0]

    # an exact match on the requested split name always wins
    direct = os.path.join(root, base_split)
    if os.path.exists(direct):
        return direct

    # otherwise fall back to the synonym family the split belongs to, if any
    if base_split in _TRAIN_SYNONYM:
        aliases = _TRAIN_SYNONYM
    elif base_split in _EVAL_SYNONYM:
        aliases = _EVAL_SYNONYM
    else:
        return root

    for alias in aliases:
        candidate = os.path.join(root, alias)
        if os.path.exists(candidate):
            return candidate
    return root
def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        seed=42,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parenthesis after each arg are the type of dataset supported for each arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * HFDS - Hugging Face Datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * WDS - Webdataset
      * all - any of the above

    The dataset type is selected by the (lower-cased) prefix of `name`:
    `torch/`, `hfds/`, `tfds/`, `wds/`; anything else uses the folder dataset.

    Args:
        name: dataset name, empty (or None) is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child folder from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch).
            Not forwarded for HFDS by this function, and dropped for torch/imagenet.
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS)
        batch_size: batch size hint for (TFDS, WDS)
        seed: seed for iterable datasets (TFDS, WDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object

    Raises:
        AssertionError: if a `torch/` dataset name is unknown, or the requested
            torchvision dataset class is not available in the installed torchvision.
    """
    # A missing name is treated like the empty string, which selects the folder dataset.
    name = (name or '').lower()
    if name.startswith('torch/'):
        # keep only the last of (at most three) '/'-separated components
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        # NOTE: availability / unknown-name checks below raise AssertionError explicitly
        # instead of using `assert`, so they still fire under `python -O` (where the
        # old asserts vanished and the failure showed up later as NameError /
        # UnboundLocalError). The exception type and messages are unchanged.
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            if not has_inaturalist:
                raise AssertionError('Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist')
            # split may be '<target_type>[_<target_type>...]/<version>'; without a '/'
            # the whole string is the version and target_type stays 'full'
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    # a single target type is passed as a plain string, not a list
                    target_type = target_type[0]
            split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            if not has_places365:
                raise AssertionError('Please update to a newer PyTorch and torchvision for Places365 dataset.')
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'qmnist':
            if not has_qmnist:
                raise AssertionError('Please update to a newer PyTorch and torchvision for QMNIST dataset.')
            use_train = split in _TRAIN_SYNONYM
            ds = QMNIST(train=use_train, **torch_kwargs)
        elif name == 'imagenet':
            if not has_imagenet:
                raise AssertionError('Please update to a newer PyTorch and torchvision for ImageNet dataset.')
            if split in _EVAL_SYNONYM:
                split = 'val'
            # torchvision's ImageNet cannot be downloaded; current releases have no
            # `download` parameter and forward unknown kwargs to ImageFolder, which
            # rejects it with TypeError. Drop it rather than pass it through.
            torch_kwargs.pop('download', None)
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            raise AssertionError(f"Unknown torchvision dataset {name}")
    elif name.startswith('hfds/'):
        # NOTE right now, HF datasets default arrow format is a random-access Dataset,
        # There will be a IterableDataset variant too, TBD
        ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root,
            reader=name,
            split=split,
            class_map=class_map,
            is_training=is_training,
            download=download,
            batch_size=batch_size,
            repeats=repeats,
            seed=seed,
            **kwargs
        )
    elif name.startswith('wds/'):
        ds = IterableImageDataset(
            root,
            reader=name,
            split=split,
            class_map=class_map,
            is_training=is_training,
            batch_size=batch_size,
            repeats=repeats,
            seed=seed,
            **kwargs
        )
    else:
        # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, reader=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds