In [None]:
import os
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from functools import partial
from torchvision.transforms.functional import InterpolationMode
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.backend import TensorListGPU
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.fn as fn

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

# Dataset root; the training pipeline reads images from its `train/` sub-directory.
# Set the DATADIR environment variable to point elsewhere instead of editing the notebook.
DATADIR = os.environ.get("DATADIR", "/data")
BATCH_SIZE = 4  # batch size per GPU

In [None]:
# Plotting helpers based on the DALI decoder examples:
# https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/image_processing/decoder_examples.html
def _to_displayable(image):
    """Return ``image`` as a NumPy array that ``plt.imshow`` renders without clipping.

    Float images (e.g. mean/std-normalized pipeline output) are min-max rescaled
    into [0, 1], because matplotlib clips float RGB data outside that range.
    Non-float images are returned unchanged.
    """
    arr = np.asarray(image)
    if arr.size == 0 or not np.issubdtype(arr.dtype, np.floating):
        return arr
    lo, hi = float(arr.min()), float(arr.max())
    if hi <= lo:  # constant image: avoid dividing by zero
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def show_images(image_batch, columns=4):
    """Plot every sample of a host-side DALI batch on a grid.

    Args:
        image_batch: batch supporting ``len()`` and ``.at(j)`` (e.g. a
            ``TensorListCPU``). Each sample is assumed to be an HWC image;
            TODO confirm if the pipeline's output layout changes.
        columns: number of images per grid row (default 4).

    Raises:
        ValueError: if ``columns`` is less than 1.
    """
    if columns < 1:
        raise ValueError("columns must be >= 1")
    # Use the real batch length rather than the global BATCH_SIZE setting.
    num_images = len(image_batch)
    if num_images == 0:
        return
    rows = -(-num_images // columns)  # ceiling division so every image gets a cell
    plt.figure(figsize=(32, max(1, 32 // columns) * rows))
    grid = gridspec.GridSpec(rows, columns)
    for j in range(num_images):  # never index past the end of the batch
        plt.subplot(grid[j])
        plt.axis("off")
        plt.imshow(_to_displayable(image_batch.at(j)))

def show_pipeline_output(pipe):
    """Run ``pipe`` for one iteration and plot its image output.

    Builds the pipeline and runs it once. If the first output lives on the GPU,
    it is copied to host memory, then passed to ``show_images``. The second
    output (labels) is discarded.

    Args:
        pipe: a DALI pipeline whose ``run()`` yields exactly two outputs
            (images, labels).
    """
    pipe.build()
    batch, _labels = pipe.run()
    on_device = isinstance(batch, TensorListGPU)
    show_images(batch.as_cpu() if on_device else batch)

In [None]:
@pipeline_def
def pipeline(crop=224, data_dir=None):
    """Define a DALI training pipeline: read JPEGs, decode on GPU, augment, normalize.

    Args:
        crop: side length in pixels of the square output of the random-resized crop.
        data_dir: dataset root. ``None`` (default) uses the notebook-level
            ``DATADIR``. Files are read from ``<data_dir>/train``, presumably
            laid out as one sub-directory per class (``fn.readers.file``
            convention); confirm against the dataset.

    Returns:
        ``(images, labels)``: ``images`` is a float batch in HWC layout,
        ``crop x crop`` per sample, normalized with per-channel mean/std.
        ``labels`` come straight from the file reader.
    """
    root = DATADIR if data_dir is None else data_dir
    traindir = os.path.join(root, "train")
    jpegs, labels = fn.readers.file(
        file_root=traindir,
        shard_id=0,
        num_shards=1,
        random_shuffle=True,
        pad_last_batch=True,
    )
    # "mixed" = host input, GPU output. The paddings pre-size the decoder's
    # device/host buffers (in bytes) so larger images do not force reallocation.
    # The specific values presumably come from NVIDIA's ImageNet examples; tune as needed.
    images = fn.decoders.image(
        jpegs,
        device="mixed",
        output_type=types.RGB,
        device_memory_padding=211025920,
        host_memory_padding=140544512,
    )
    # Inception-style augmentation: sample 8%-100% of the area with an aspect
    # ratio in [3/4, 4/3], then resize to crop x crop.
    images = fn.random_resized_crop(
        images,
        device="gpu",
        size=[crop, crop],
        interp_type=types.INTERP_LINEAR,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
        num_attempts=100,
        antialias=False,
    )
    # ImageNet mean/std scaled by 255 to match 0-255 decoded pixel values.
    # No `crop` argument: random_resized_crop already produced crop x crop, so a
    # centered crop of the same size would be a no-op. No horizontal flip is applied.
    images = fn.crop_mirror_normalize(
        images,
        device="gpu",
        dtype=types.FLOAT,
        output_layout="HWC",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )
    return images, labels

In [None]:
# `seed` is a Pipeline constructor argument. DALI operators without an explicit
# seed derive theirs from it, so the shuffle and random crops repeat on re-run.
pipe = pipeline(batch_size=BATCH_SIZE, num_threads=2, device_id=0, seed=42)
show_pipeline_output(pipe)