# Reading ImageNet from TFRecord files using TensorFlow's `tf.data` API

Here we use TensorFlow's [`tf.data` API](https://www.tensorflow.org/guide/data) to read the ImageNet dataset stored as TFRecord files.

In [None]:
import os
import glob
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Directory holding the ImageNet (ILSVRC2012) training TFRecord shards.
IMAGENET_TRAIN_DIR = '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k/train'

# glob returns matches in arbitrary (filesystem-dependent) order; sort so the
# record order -- and therefore the batches read below -- is reproducible.
tfrec_files = sorted(glob.glob(os.path.join(IMAGENET_TRAIN_DIR, '*')))

In [None]:
def decode(serialized_example):
    """Turn one serialized ImageNet record into an ``(image, label)`` pair.

    The record is parsed for two features: the encoded image bytes under
    ``'image/encoded'`` and an integer class id under ``'image/class/label'``.

    Args:
        serialized_example: a scalar string tensor holding one serialized
            ``tf.train.Example``.

    Returns:
        ``(image, label)`` where ``image`` is the JPEG-decoded picture forced to
        3 channels and center-cropped and/or padded to 224x224, and ``label`` is
        the stored class id minus one.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

    # channels=3 gives every decoded image the same channel count, even when
    # the source JPEG is grayscale.
    decoded = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    cropped = tf.image.resize_with_crop_or_pad(decoded, 224, 224)

    # Stored ids appear to be 1-based; shift them into the range [0, 999].
    zero_based_label = parsed['image/class/label'] - 1
    return cropped, zero_based_label

In [None]:
# Input pipeline: read serialized records, decode them, group them into
# batches of 64, and prefetch so reading/decoding overlaps with consumption.
dataset = tf.data.TFRecordDataset(tfrec_files)
# Decode several records concurrently; the output order stays deterministic
# unless tf.data options say otherwise.
dataset = dataset.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(64)
# Let tf.data prepare upcoming batches in the background.
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# Pull up to 10 batches through the pipeline as a quick read/decode check.
# Only the last batch read is kept (in imgs/labels) for the cells below.
imgs, labels = None, None
for imgs, labels in dataset.take(10):
    pass
# Without this check an empty dataset would leave imgs/labels as None and
# the inspection cells would fail with a confusing error.
if imgs is None:
    raise RuntimeError('The dataset yielded no batches; check that tfrec_files is non-empty.')

In [None]:
# Shape of the image batch; expected (64, 224, 224, 3) given batch(64) and the
# 3-channel, 224x224 decode above.
imgs.get_shape()

In [None]:
# Shape of the label batch; expected (64,), one class id per image.
labels.get_shape()

In [None]:
# Display a single decoded image (index 17) from the batch kept above,
# without axis ticks or labels.
fig, ax = plt.subplots()
ax.imshow(imgs[17])
ax.axis('off')
plt.show()