In [None]:
import os
import torch
from random import shuffle
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from PIL import Image
from functools import partial
from torchvision.transforms.functional import InterpolationMode
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.backend import TensorListGPU
from nvidia.dali import pipeline_def, Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import nvidia.dali.fn as fn

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

# Root directory of the dataset; int_pipeline reads images from its "train" subdirectory.
DATADIR = "/data"
# Samples per batch; passed as batch_size to both pipelines and to ExternalInputIterator.
BATCH_SIZE = 4 # batch size per GPU

In [None]:
# Reference (DALI decoder examples): https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/image_processing/decoder_examples.html
def show_images(image_batch):
    """Plot every image of a batch in a grid that is four columns wide.

    The grid is sized from the batch itself, so any batch length works:
    the row count is rounded up and only cells that have an image are drawn.

    Args:
        image_batch: CPU-side batch supporting ``len()`` and ``.at(index)``
            (e.g. a DALI ``TensorListCPU``). Each sample is handed to
            ``plt.imshow`` unchanged.
    """
    columns = 4
    num_images = len(image_batch)
    if num_images == 0:
        # Nothing to draw; GridSpec cannot be created with zero rows.
        return
    # Ceiling division: a partially filled last row still needs its own row.
    rows = (num_images + columns - 1) // columns
    plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    # Iterate over real samples only, never over empty trailing grid cells.
    for j in range(num_images):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(image_batch.at(j))

def show_pipeline_output(pipe):
    """Build ``pipe``, run it once and display the images of that run.

    Args:
        pipe: pipeline object whose ``run()`` returns exactly two outputs;
            the first is displayed and the second is ignored.
    """
    pipe.build()
    batch, _unused = pipe.run()
    # GPU-resident batches are copied to the host before plotting.
    host_batch = batch.as_cpu() if isinstance(batch, TensorListGPU) else batch
    show_images(host_batch)

In [None]:
@pipeline_def
def int_pipeline(crop=224):
    """Pipeline that reads files with DALI's own file reader.

    Stages: file reader on ``DATADIR/train`` -> mixed-device image decoder
    -> random resized crop -> crop/mirror/normalize.

    Args:
        crop: side length, in pixels, of the square output images.

    Returns:
        Tuple ``(images, labels)`` of DALI data nodes.
    """
    train_root = os.path.join(DATADIR, "train")
    output_size = [crop, crop]
    # Per-channel statistics, scaled to the 0-255 value range.
    channel_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    channel_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    encoded, labels = fn.readers.file(
        file_root=train_root,
        shard_id=0,
        num_shards=1,
        random_shuffle=True,
        pad_last_batch=True,
    )

    decoded = fn.decoders.image(
        encoded,
        device="mixed",
        output_type=types.RGB,
        device_memory_padding=211025920,
        host_memory_padding=140544512,
    )

    cropped = fn.random_resized_crop(
        decoded,
        device="gpu",
        size=output_size,
        interp_type=types.INTERP_LINEAR,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
        num_attempts=100,
        antialias=False,
    )

    # The layout is "HWC" here; an NCHW variant was previously present but disabled.
    normalized = fn.crop_mirror_normalize(
        cropped,
        device="gpu",
        dtype=types.FLOAT,
        output_layout="HWC",
        crop=(crop, crop),
        mean=channel_mean,
        std=channel_std,
    )
    return normalized, labels

In [None]:
class ExternalInputIterator(object):
    """Iterator yielding batches of encoded files and their class labels.

    The dataset directory is expected to contain one sub-directory per class.
    Class indices are assigned to those sub-directories in sorted order;
    plain files lying directly in the dataset directory are ignored. Files
    found anywhere below a class directory receive that class's index.

    Args:
        batch_size: number of samples returned by every ``__next__`` call.
        device_id: index of the shard this instance should serve.
        num_gpus: total number of shards the file list is split into.
        images_dir: dataset root; defaults to ``"/data/train/"``.

    Attributes:
        files: list of ``(path, label)`` pairs belonging to this shard.
        n: number of samples in this shard.
        data_set_len: number of samples in the whole (unsharded) data set.
    """

    def __init__(self, batch_size, device_id, num_gpus, images_dir="/data/train/"):
        self.images_dir = images_dir
        self.batch_size = batch_size

        # Only directories define classes; counting stray files as classes
        # would shift every label that sorts after them.
        class_names = sorted(
            entry
            for entry in os.listdir(self.images_dir)
            if os.path.isdir(os.path.join(self.images_dir, entry))
        )

        samples = []
        for label, class_name in enumerate(class_names):
            class_dir = os.path.join(self.images_dir, class_name)
            for root, subdirs, names in os.walk(class_dir):
                # Sorting in place makes os.walk descend in a fixed order, so
                # the file list (and thus each shard) is reproducible.
                subdirs.sort()
                samples.extend(
                    (os.path.join(root, name), label) for name in sorted(names)
                )

        # whole data set size
        self.data_set_len = len(samples)
        # based on the device_id and total number of GPUs - world size
        # get proper shard
        shard_begin = self.data_set_len * device_id // num_gpus
        shard_end = self.data_set_len * (device_id + 1) // num_gpus
        self.files = samples[shard_begin:shard_end]
        self.n = len(self.files)
        # Cursor into self.files; defined here so __next__ is usable even if
        # __iter__ has not been called yet.
        self.i = 0

    def __iter__(self):
        """Restart the epoch: reset the cursor and reshuffle this shard in place."""
        self.i = 0
        shuffle(self.files)
        return self

    def __next__(self):
        """Return the next ``(batch, labels)`` pair.

        ``batch`` is a list of 1-D ``uint8`` numpy arrays holding the raw file
        bytes; ``labels`` is a list of one-element ``int32`` torch tensors.
        A final incomplete batch is filled up by wrapping around to the start
        of the shard.

        Raises:
            StopIteration: once the shard has been consumed; the iterator is
                reset (and reshuffled) before the exception is raised.
        """
        batch = []
        labels = []

        if self.i >= self.n:
            self.__iter__()
            raise StopIteration

        for _ in range(self.batch_size):
            # Modulo wraps around so the last batch is always full.
            jpeg, label = self.files[self.i % self.n]
            batch.append(
                np.fromfile(jpeg, dtype=np.uint8)
            )  # we can use numpy
            labels.append(
                torch.tensor([label], dtype=torch.int32)
            )  # or PyTorch's native tensors
            self.i += 1
        return (batch, labels)

    def __len__(self):
        """Return the size of the whole data set, not of this shard (see ``n``)."""
        return self.data_set_len

    # Alias for callers that use the ``next()`` method name.
    next = __next__
    
@pipeline_def
def ext_pipeline(crop=224):
    """Pipeline fed by ``ExternalInputIterator`` through ``fn.external_source``.

    Stages: external source (encoded bytes + labels) -> mixed-device image
    decoder -> random resized crop -> crop/mirror/normalize.

    The input location is decided by ``ExternalInputIterator`` itself; this
    function does not pass a directory to it.

    Args:
        crop: side length, in pixels, of the square output images.

    Returns:
        Tuple ``(images, labels)`` of DALI data nodes.
    """
    # The iterator is created for shard 0 of 1, with BATCH_SIZE samples per batch.
    jpegs, labels = fn.external_source(
        source=ExternalInputIterator(BATCH_SIZE, 0, 1),
        num_outputs=2,
        dtype=[types.UINT8, types.INT32],
    )
    images = fn.decoders.image(
        jpegs,
        device="mixed",
        output_type=types.RGB,
        device_memory_padding=211025920,
        host_memory_padding=140544512,
    )
    images = fn.random_resized_crop(
        images,
        device="gpu",
        size=[crop, crop],
        interp_type=types.INTERP_LINEAR,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
        num_attempts=100,
        antialias=False,
    )
    # The layout is "HWC" here; an NCHW variant was previously present but disabled.
    images = fn.crop_mirror_normalize(
        images,
        device="gpu",
        dtype=types.FLOAT,
        output_layout="HWC",
        crop=(crop, crop),
        # Per-channel statistics, scaled to the 0-255 value range.
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )
    return images, labels

In [None]:
# Reader-based pipeline: instantiate on GPU 0, then run once and plot the batch.
int_pipe = int_pipeline(
    batch_size=BATCH_SIZE,
    num_threads=2,
    device_id=0,
)
show_pipeline_output(int_pipe)

# External-source pipeline: same settings, displayed the same way.
ext_pipe = ext_pipeline(
    batch_size=BATCH_SIZE,
    num_threads=2,
    device_id=0,
)
show_pipeline_output(ext_pipe)