-
Notifications
You must be signed in to change notification settings - Fork 136
/
cifar10.py
88 lines (76 loc) · 2.37 KB
/
cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""CIFAR10 streaming dataset conversion scripts."""
from argparse import ArgumentParser, Namespace
from torchvision.datasets import CIFAR10
from streaming.base.util import get_list_arg
from streaming.vision.convert.base import convert_image_class_dataset
def parse_args() -> Namespace:
    """Parse command-line arguments.

    The arguments are read from ``sys.argv`` (argparse's default source).

    Returns:
        Namespace: command-line arguments.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--in_root',
        type=str,
        required=True,
        help='Local directory path of the input raw dataset',
    )
    parser.add_argument(
        '--out_root',
        type=str,
        required=True,
        help='Directory path to store the output dataset',
    )
    # Comma-separated list given as one string; ``main`` turns it into a list.
    parser.add_argument(
        '--splits',
        type=str,
        default='train,val',
        help='Split to use. Default: train,val',
    )
    # The empty string is the "no compression" default advertised as ``None`` in the help text.
    parser.add_argument(
        '--compression',
        type=str,
        default='',
        help='Compression algorithm to use. Default: None',
    )
    # Comma-separated list given as one string; ``main`` turns it into a list.
    parser.add_argument(
        '--hashes',
        type=str,
        default='sha1,xxh64',
        help='Hashing algorithms to apply to shard files. Default: sha1,xxh64',
    )
    parser.add_argument(
        '--size_limit',
        type=int,
        default=1 << 20,
        help='Shard size limit, after which point to start a new shard. Default: 1 << 20',
    )
    # Boolean-like switches are taken as integers (0 or 1) rather than store_true flags.
    parser.add_argument(
        '--progress_bar',
        type=int,
        default=1,
        help='tqdm progress bar. Default: 1 (True)',
    )
    parser.add_argument(
        '--leave',
        type=int,
        default=0,
        help='Keeps all traces of the progressbar upon termination of iteration. Default: 0 '
        '(False)',
    )
    return parser.parse_args()
def main(args: Namespace) -> None:
    """Main: create streaming CIFAR10 dataset.

    For each requested split, builds a torchvision ``CIFAR10`` dataset rooted at
    ``args.in_root`` and passes it to ``convert_image_class_dataset`` together with the
    output root, split name, compression, hashes, shard size limit and progress-bar options.

    Args:
        args (Namespace): command-line arguments.
    """
    # Both options arrive as single strings (e.g. 'train,val'); get_list_arg presumably
    # splits them on commas -- confirm against streaming.base.util.
    splits = get_list_arg(args.splits)
    hashes = get_list_arg(args.hashes)
    for split in splits:
        # Only the exact name 'train' selects the training set; every other split name
        # (including the default 'val') is loaded with train=False.
        # download=True is passed so torchvision may fetch the data into in_root.
        dataset = CIFAR10(root=args.in_root, train=(split == 'train'), download=True)
        # Arguments are positional; the trailing 'pil' is presumably the image encoding
        # -- verify against convert_image_class_dataset's signature.
        convert_image_class_dataset(dataset, args.out_root, split, args.compression, hashes,
                                    args.size_limit, args.progress_bar, args.leave, 'pil')
# Script entry point: parse the command line and run the conversion.
if __name__ == '__main__':
    main(parse_args())