# Reading ImageNet from TFRecord files using NVIDIA DALI

Here we use the [tfrecord](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/operations/nvidia.dali.fn.readers.tfrecord.html#nvidia-dali-fn-readers-tfrecord) reader of [NVIDIA DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html) to read the ImageNet dataset stored as TFRecord files.

Besides the TFRecord files themselves, DALI needs an index for each file that records where every record starts within it. These index files are created with DALI's utility script [tfrecord2idx](https://github.com/NVIDIA/DALI/blob/main/tools/tfrecord2idx).
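To make the role of the index concrete, here is a minimal sketch of how such an index can be built. It assumes the standard TFRecord framing (an 8-byte little-endian payload length, a 4-byte CRC of that length, the payload, then a 4-byte CRC of the payload) and that the index holds one `offset length` pair per line, which is our understanding of what `tfrecord2idx` writes. `build_tfrecord_index` is a hypothetical helper that skips the CRCs without verifying them; for real data, use the official script.

```python
import struct


def build_tfrecord_index(tfrecord_path, index_path):
    """Hypothetical helper: write one 'offset length' line per record.

    Assumes standard TFRecord framing; CRC fields are skipped, not checked.
    """
    entries = []
    with open(tfrecord_path, 'rb') as f:
        while True:
            offset = f.tell()
            header = f.read(8)  # uint64 payload length, little-endian
            if len(header) < 8:
                break  # end of file
            (payload_len,) = struct.unpack('<Q', header)
            f.seek(4 + payload_len + 4, 1)  # skip length CRC, payload, payload CRC
            entries.append((offset, f.tell() - offset))
    with open(index_path, 'w') as idx:
        for offset, length in entries:
            idx.write(f'{offset} {length}\n')
    return entries
```

For example, a file holding two records with 5- and 10-byte payloads yields the lines `0 21` and `21 26`, because each record adds 16 bytes of framing around its payload.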

In [None]:
import glob
import matplotlib.pyplot as plt
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

In [None]:
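# Location of the ImageNet (ILSVRC2012) TFRecord files on the cluster
data_dir = '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k'

# The reader pairs tfrec_files[i] with index_files[i], so both lists must be in the same order
tfrec_files = sorted(glob.glob(f'{data_dir}/train/*'))
index_files = sorted(glob.glob(f'{data_dir}/idx_files/train/*'))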
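Since each TFRecord file is matched with the index file at the same position in the list, it is worth checking that the two sorted lists really line up before building the pipeline. The sketch below assumes each index file is named after its TFRecord file, optionally with an extra suffix such as `.idx`; `check_index_pairing` and this naming convention are hypothetical, so adapt them to how your index files were actually named.

```python
import os


def check_index_pairing(tfrec_files, index_files, suffix=''):
    """Hypothetical check: index_files[i] must be named after tfrec_files[i] + suffix."""
    if len(tfrec_files) != len(index_files):
        raise ValueError(
            f'{len(tfrec_files)} TFRecord files but {len(index_files)} index files')
    for rec, idx in zip(tfrec_files, index_files):
        expected = os.path.basename(rec) + suffix
        if os.path.basename(idx) != expected:
            raise ValueError(f'{idx} does not match {rec}')
```

In the notebook this would be called as `check_index_pairing(tfrec_files, index_files)`; it raises on the first mismatch instead of letting the reader silently use the wrong offsets.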

In [None]:
tfrec.FixedLenFeature?

In [None]:
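pipe = Pipeline(batch_size=64,
                num_threads=12,
                device_id=0)

with pipe:
    # The reader returns a dict mapping each feature key to its data
    example = fn.readers.tfrecord(
        path=tfrec_files,
        index_path=index_files,
        features={
            'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ''),
            'image/class/label': tfrec.FixedLenFeature((), tfrec.int64, -1),
        },
        name='Reader',  # lets the iterator query the epoch size from the reader
    )
    # The stored labels are 1-based; shift them to the 0..999 range
    label = example['image/class/label'] - 1
    # Decode JPEGs with the 'mixed' backend (CPU input, GPU output), then resize on the GPU
    image = fn.decoders.image(example['image/encoded'], device='mixed', output_type=types.RGB)
    image = fn.resize(image, device='gpu', size=(224, 224))
    pipe.set_outputs(image, label)

In [None]:
pipe.build()

In [None]:
loader = DALIClassificationIterator(
    pipe,
    reader_name='Reader',  # epoch size and last-batch padding come from the reader
    auto_reset=True,
    last_batch_policy=LastBatchPolicy.DROP,
)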
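With `LastBatchPolicy.DROP`, an incomplete final batch is discarded, so one epoch yields `floor(N / batch_size)` batches; `FILL` and `PARTIAL` both return the last batch (filled up or shortened), giving `ceil(N / batch_size)`. The hypothetical helper below just does this arithmetic for a single shard, taking the policy as a plain string so that it runs without DALI installed.

```python
import math


def batches_per_epoch(num_samples, batch_size, policy='DROP'):
    """Hypothetical helper: batches per epoch for one shard under a LastBatchPolicy name."""
    if policy == 'DROP':
        return num_samples // batch_size  # incomplete last batch is discarded
    if policy in ('FILL', 'PARTIAL'):
        return math.ceil(num_samples / batch_size)  # last batch is kept
    raise ValueError(f'unknown policy: {policy}')


# ImageNet-1k has 1,281,167 training images
print(batches_per_epoch(1_281_167, 64, 'DROP'))     # 20018
print(batches_per_epoch(1_281_167, 64, 'PARTIAL'))  # 20019
```

Since 64 × 20018 = 1,281,152, the `DROP` policy used here leaves out the 15 samples that do not fill a complete final batch.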

In [None]:
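# Fetch a dozen batches and keep the last one for inspection
for i, samples in enumerate(loader):
    # The iterator yields one dict per pipeline; there is a single pipeline here
    imgs, labels = samples[0]['data'], samples[0]['label']
    if i > 10:
        break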

In [None]:
imgs.shape

In [None]:
labels.shape

In [None]:
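# imgs is a uint8 (batch, 224, 224, 3) tensor on the GPU; copy only one image to host for plotting
plt.imshow(imgs[22].cpu().numpy())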
plt.axis('off')
plt.show()
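To look at several samples at once instead of a single image, one option is to tile part of the batch into a grid and show it with one `imshow` call. This is a sketch: `make_grid_hwc` is a hypothetical helper that assumes equally sized images in NHWC layout, which is what the pipeline above produces after resizing every image to (224, 224).

```python
import numpy as np


def make_grid_hwc(images, ncols):
    """Hypothetical helper: tile an (N, H, W, C) array into one (rows*H, ncols*W, C) image."""
    images = np.asarray(images)
    n, h, w, c = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((nrows * h, ncols * w, c), dtype=images.dtype)  # unused cells stay black
    for k in range(n):
        r, col = divmod(k, ncols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = images[k]
    return grid
```

In the notebook this could be used as `plt.imshow(make_grid_hwc(imgs[:16].cpu().numpy(), ncols=4))`, followed by `plt.axis('off')` and `plt.show()`.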