-
Notifications
You must be signed in to change notification settings - Fork 58
/
imagenet_v2.py
45 lines (37 loc) · 1.6 KB
/
imagenet_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""ImageNet-v2 tf.data input pipeline.
Uses TFDS https://www.tensorflow.org/datasets/catalog/imagenet_v2.
"""
import functools
from typing import Dict, Iterator, Tuple
import tensorflow_datasets as tfds
from algorithmic_efficiency import data_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \
input_pipeline
def get_imagenet_v2_iter(data_dir: str,
global_batch_size: int,
mean_rgb: Tuple[float, float, float],
stddev_rgb: Tuple[float, float, float],
image_size: int,
resize_size: int) -> Iterator[Dict[str, spec.Tensor]]:
"""Always caches and repeats indefinitely."""
ds = tfds.load(
'imagenet_v2/matched-frequency:3.0.0',
split='test',
data_dir=data_dir,
decoders={
'image': tfds.decode.SkipDecoding(),
})
def _decode_example(example: Dict[str, float]) -> Dict[str, float]:
image = input_pipeline.preprocess_for_eval(example['image'],
mean_rgb,
stddev_rgb,
image_size,
resize_size)
return {'inputs': image, 'targets': example['label']}
ds = ds.map(_decode_example, num_parallel_calls=16)
ds = ds.batch(global_batch_size)
shard_pad_fn = functools.partial(
data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size)
it = map(shard_pad_fn, iter(ds))
return it