In [None]:
import glob
import matplotlib.pyplot as plt
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import nvidia.dali.tfrecord as tfrec
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

In [None]:
# Root of the ImageNet-1k dataset: TFRecord shards live in train/ and their DALI
# index files in idx_files/train/. No trailing slash, so the f-string paths below
# don't contain '//'.
data_dir = '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k'

# Sorting both listings is what lines each shard up with its index file; the
# TFRecord reader expects one index file per data file, in the same order.
tfrec_files = sorted(glob.glob(f'{data_dir}/train/*'))
index_files = sorted(glob.glob(f'{data_dir}/idx_files/train/*'))

# Fail early with a readable message instead of an opaque DALI error at build time.
if not tfrec_files:
    raise FileNotFoundError(f'No TFRecord files found in {data_dir}/train')
if len(tfrec_files) != len(index_files):
    raise ValueError(
        f'Found {len(tfrec_files)} TFRecord files but {len(index_files)} index files; '
        'every TFRecord shard needs exactly one index file.'
    )

In [None]:
batch_size = 64

# Parsing spec for every TFRecord example: the encoded image bytes and the class label
# (a length-1 int64 feature, defaulting to -1 when absent).
tfrecord_features = {
    'image/encoded': tfrec.FixedLenFeature((), tfrec.string, ""),
    'image/class/label': tfrec.FixedLenFeature([1], tfrec.int64, -1),
}

pipe = Pipeline(batch_size=batch_size, num_threads=12, device_id=0)

with pipe:
    # Stream records from the sharded TFRecord files, using the prebuilt index files.
    records = fn.readers.tfrecord(
        path=tfrec_files,
        index_path=index_files,
        features=tfrecord_features,
    )
    # Decode on the "mixed" backend so the images end up on the GPU for the GPU resize.
    decoded = fn.decoders.image(records["image/encoded"], device="mixed", output_type=types.RGB)
    # Both output dimensions are fixed to 224x224.
    images_224 = fn.resize(decoded, device="gpu", size=(224, 224))
    # Shift every label down by one; presumably the shards store 1-based class ids — confirm.
    labels_zero_based = records["image/class/label"] - 1
    pipe.set_outputs(images_224, labels_zero_based)

In [None]:
# Build the pipeline graph defined in the `with pipe:` cell so it can be queried and run.
pipe.build()

In [None]:
# Wrap the pipeline in a PyTorch-style iterator yielding {'data', 'label'} batches.
#
# The iterator must know which reader defines an epoch. Without `reader_name` (or
# `size`) it cannot detect the epoch boundary: DALI readers loop forever, so
# iteration never ends and LastBatchPolicy.DROP / auto_reset have no effect.
# The reader in this pipeline has no explicit name, so look up the name the built
# pipeline reports for it.
reader_names = list(pipe.epoch_size())
if len(reader_names) != 1:
    raise RuntimeError(f'Expected exactly one reader in the pipeline, got {reader_names}')

loader = DALIClassificationIterator(
    pipe,
    reader_name=reader_names[0],
    last_batch_padded=False,  # superseded by the reader's own padding setting once reader_name is given
    auto_reset=True,  # reset automatically at the end of each epoch
    last_batch_policy=LastBatchPolicy.DROP,  # skip the incomplete final batch of an epoch
)

In [None]:
# Pull a dozen batches through the loader, keeping only the most recent one for inspection.
n_batches_to_fetch = 12
for step, batch in enumerate(loader, start=1):
    imgs = batch[0]['data']
    labels = batch[0]['label']
    if step == n_batches_to_fetch:
        break

In [None]:
# Image tensor of the last fetched batch; presumably (batch_size, 224, 224, 3), i.e. HWC per image — confirm.
imgs.shape

In [None]:
# Label tensor of the last fetched batch; labels were parsed as length-1 features, so presumably (batch_size, 1).
labels.shape

In [None]:
# Show one decoded, resized image from the last fetched batch, titled with its label.
sample_idx = 22  # must be < batch_size; the DROP policy means every batch is full
sample_img = imgs[sample_idx].cpu().numpy()  # copy just this sample to the host, not the whole batch
sample_label = labels[sample_idx].item()

fig, ax = plt.subplots()
ax.imshow(sample_img)
ax.set_title(f'label = {sample_label}')
ax.axis('off')
plt.show()