### To use DALI, we recommend running this notebook inside the [NGC PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch).

In [1]:
import os

import nvidia.dali as dali
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import torch
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from torch.utils.data import DataLoader

In [2]:
# Installed DALI version; the outputs saved in this notebook were produced with 0.29.0.
DALI_VERSION = dali.__version__
DALI_VERSION

'0.29.0'

### The test data comes from the [DALI_extra](https://github.com/NVIDIA/DALI_extra) repository; the paths below expect a `DALI_extra` checkout in the notebook's working directory.

In [3]:
# Dataset locations used by the DALI pipeline below (Caffe2 databases shipped with DALI_extra).
# Set the DALI_EXTRA_PATH environment variable to use an existing DALI_extra checkout;
# otherwise a ``DALI_extra`` directory in the current working directory is assumed.
DALI_EXTRA_PATH = os.environ.get('DALI_EXTRA_PATH', 'DALI_extra')

data_paths = {
    'train': os.path.join(DALI_EXTRA_PATH, 'db/MNIST/training/'),
    'valid': os.path.join(DALI_EXTRA_PATH, 'db/MNIST/testing/'),
}

class MNISTPipeline(Pipeline):
    """DALI pipeline that reads MNIST from a Caffe2 database, decodes and normalizes it.

    Each batch yields ``(images, labels)``, both on the GPU. Images are decoded to
    grayscale with the ``'mixed'`` decoder. They are then normalized by
    ``CropMirrorNormalize`` with mean/std 0.1307/0.3081 (presumably the usual MNIST
    statistics) multiplied by 255, to match the 0-255 range of decoded pixels.

    Args:
        mode: key of ``data_paths`` selecting the split (``'train'`` or ``'valid'``).
        batch_size: number of samples per batch.
        num_threads: number of CPU threads DALI uses.
        device_id: CUDA device the pipeline runs on.

    Raises:
        ValueError: if ``mode`` is not a key of ``data_paths``.
    """

    def __init__(
        self,
        mode: str = 'train',
        batch_size: int = 16,
        num_threads: int = 4,
        device_id: int = 0,
    ):
        # Fail early with a readable message instead of a bare KeyError further down.
        if mode not in data_paths:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(data_paths)}")
        super().__init__(
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=device_id
        )
        self.mode = mode

        # Only the training split needs shuffling. A fixed order for validation avoids
        # the shuffle cost and makes validation passes reproducible.
        self.input = ops.Caffe2Reader(
            path=data_paths[mode], random_shuffle=(mode == 'train'), name='Reader'
        )
        self.decode = ops.ImageDecoder(device = 'mixed', output_type = types.GRAY)
        self.cmn = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            std=[0.3081 * 255],
            mean=[0.1307 * 255],
            output_layout=types.NCHW,
        )

    def define_graph(self):
        """Wire the graph: read -> decode -> normalize.

        Returns:
            ``(images, labels)``; images come from the GPU normalizer and labels are
            explicitly moved to the GPU.
        """
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        images = self.cmn(images)
        return images, labels.gpu()

    def __len__(self):
        """Number of samples in one epoch of this split.

        Once the pipeline is built, the size is taken from the ``'Reader'`` operator, so
        it stays correct if ``data_paths`` points at a different dataset. Before the
        build, the standard MNIST split sizes (60000 train / 10000 otherwise) are returned.
        """
        try:
            return self.epoch_size('Reader')
        except RuntimeError:
            # ``epoch_size`` is only available after ``build()``.
            return 60000 if self.mode == 'train' else 10000

In [4]:
# Wrap the DALI pipeline in a DataLoader-compatible object so it can be used with catalyst.
class DALILoader(DataLoader):
    """Expose an ``MNISTPipeline`` through the ``DataLoader`` interface.

    The class subclasses ``DataLoader`` so it can be passed to catalyst's runner in place
    of a regular loader. ``DataLoader.__init__`` is intentionally not called: batching,
    decoding and any shuffling all happen inside the DALI pipeline. Iterating yields
    dicts with ``'features'`` (images) and ``'targets'`` (int64 labels).

    Args:
        mode: dataset split forwarded to ``MNISTPipeline`` (``'train'`` or ``'valid'``).
        batch_size: samples per batch.
        num_workers: number of DALI CPU threads (the pipeline's ``num_threads``).
        device_id: CUDA device the DALI pipeline runs on.
    """

    def __init__(
        self,
        mode: str = 'train',
        batch_size: int = 32,
        num_workers: int = 4,
        device_id: int = 0,
    ):
        self.batch_size = batch_size

        self.pipeline = MNISTPipeline(
            mode=mode, batch_size=batch_size, num_threads=num_workers, device_id=device_id
        )
        self.pipeline.build()

        self.loader = DALIGenericIterator(
            pipelines=self.pipeline,
            output_map=['features', 'targets'],
            size=len(self.pipeline),
            auto_reset=True,
            # Keep the final, smaller batch instead of dropping or padding it.
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def __len__(self):
        """Number of batches per epoch, as reported by the DALI iterator."""
        return len(self.loader)

    def __iter__(self):
        """Yield ``{'features': ..., 'targets': ...}`` dicts, one per batch."""
        for batch in self.loader:
            # A single pipeline was given, so ``batch`` is a one-element list of
            # {output_name: tensor} dicts.
            outputs = batch[0]
            yield {
                'features': outputs['features'],
                # Flatten to shape (B,), assuming one label per sample as before.
                # A bare ``squeeze()`` would turn a 1-sample batch (e.g. a partial last
                # batch) into a 0-d tensor, which CrossEntropyLoss rejects.
                'targets': outputs['targets'].reshape(-1).long(),
            }

    @property
    def sampler(self):
        """No torch sampler is involved; sample order is controlled by the DALI reader."""
        return None

    @property
    def batch_sampler(self):
        """No torch batch sampler is involved; batching is done by the DALI pipeline."""
        return None

In [5]:
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl

In [6]:
# Run configuration: values a reader may want to tune.
BATCH_SIZE = 32
NUM_WORKERS = 8  # passed to DALI as the number of CPU threads per pipeline
NUM_CLASSES = 10
SEED = 42

# Seed torch so the model's initial weights are reproducible across runs.
torch.manual_seed(SEED)

# Linear classifier over the flattened 28x28 single-channel image.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, NUM_CLASSES))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

loaders = {
    'train': DALILoader(mode='train', batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
    'valid': DALILoader(mode='valid', batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
}



In [7]:
# Metric callbacks. Each one reads predictions from the "logits" key and labels from
# the "targets" key.
metric_callbacks = [
    dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
    dl.PrecisionRecallF1SupportCallback(
        input_key="logits", target_key="targets", num_classes=10
    ),
    dl.AUCCallback(input_key="logits", target_key="targets"),
    # ConfusionMatrixCallback needs the optional extra: ``pip install catalyst[ml]``
    dl.ConfusionMatrixCallback(
        input_key="logits", target_key="targets", num_classes=10
    ),
]

# Train for one epoch, keeping the checkpoint with the lowest validation loss.
runner = dl.SupervisedRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    callbacks=metric_callbacks,
)

HBox(children=(HTML(value='1/1 * Epoch (train)'), FloatProgress(value=0.0, max=1875.0), HTML(value='')))


train (1/1) accuracy: 0.8600333333333344 | accuracy/std: 0.06956738615119247 | accuracy01: 0.8600333333333344 | accuracy01/std: 0.06956738615119247 | accuracy03: 0.9711833333333335 | accuracy03/std: 0.03746301216102217 | accuracy05: 0.9904 | accuracy05/std: 0.021567091111003207 | auc: 0.9589463472366333 | auc/_macro: 0.9589463472366333 | auc/_micro: 0.9643569778395064 | auc/_weighted: 0.9597810506820679 | auc/class_00: 0.9915761351585388 | auc/class_01: 0.9921147227287292 | auc/class_02: 0.9578843116760254 | auc/class_03: 0.948928952217102 | auc/class_04: 0.9723076224327087 | auc/class_05: 0.9283470511436462 | auc/class_06: 0.9845988154411316 | auc/class_07: 0.9818652272224426 | auc/class_08: 0.9049604535102844 | auc/class_09: 0.9268798232078552 | f1/_macro: 0.8582526362879174 | f1/_micro: 0.8600283333624017 | f1/_weighted: 0.8599756467422307 | f1/class_00: 0.9331256991150798 | f1/class_01: 0.9339300448150216 | f1/class_02: 0.848520581821531 | f1/class_03: 0.8315052041164961 | f1/cla

HBox(children=(HTML(value='1/1 * Epoch (valid)'), FloatProgress(value=0.0, max=313.0), HTML(value='')))


valid (1/1) accuracy: 0.8767000000000005 | accuracy/std: 0.05568253814711977 | accuracy01: 0.8767000000000005 | accuracy01/std: 0.05568253814711977 | accuracy03: 0.9773 | accuracy03/std: 0.02544159443652752 | accuracy05: 0.9931999999999997 | accuracy05/std: 0.014492446415298916 | auc: 0.9759805798530579 | auc/_macro: 0.9759805798530579 | auc/_micro: 0.9729051977777778 | auc/_weighted: 0.9763350486755371 | auc/class_00: 0.9960879683494568 | auc/class_01: 0.9964075684547424 | auc/class_02: 0.9756388068199158 | auc/class_03: 0.975641667842865 | auc/class_04: 0.9830971956253052 | auc/class_05: 0.9661995768547058 | auc/class_06: 0.9908614754676819 | auc/class_07: 0.9855570793151855 | auc/class_08: 0.9420429468154907 | auc/class_09: 0.9482711553573608 | f1/_macro: 0.8759836176129701 | f1/_micro: 0.8766950000285159 | f1/_weighted: 0.8775865870995762 | f1/class_00: 0.9440820901827106 | f1/class_01: 0.9570668515130788 | f1/class_02: 0.8349113495850965 | f1/class_03: 0.8058284587100432 | f1/cl