# Decoding ImageNet TFRecords with NVIDIA DALI

In [None]:
%reload_ext autoreload
%autoreload 2

import os
from glob import glob
import socket
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils import data
import torchvision.models as models

from torchsummary import summary
from tqdm.notebook import trange, tqdm

In [None]:
from dali_pt_dataloader import dali_dataloader

# Root directory holding the ImageNet-1k TFRecord shards and their index files.
imagenet_root = '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k'


def _imagenet_loader(split, batch_size, training):
    """Build a single-shard DALI loader for one ImageNet split.

    `split` names the sub-directory ('train' or 'validation') used both under
    the root and under `idx_files/`. Both globs are sorted, presumably so that
    TFRecord files and their index files line up by position — confirm
    against `dali_dataloader`.
    """
    return dali_dataloader(
        batch_size=batch_size,
        num_threads=os.cpu_count(),
        tfrec_filenames=sorted(glob(f'{imagenet_root}/{split}/*')),
        tfrec_idx_filenames=sorted(glob(f'{imagenet_root}/idx_files/{split}/*')),
        shard_id=0,
        num_shards=1,
        gpu_aug=True,
        gpu_out=True,
        training=training,
    )


train_loader = _imagenet_loader('train', batch_size=256, training=True)
valid_loader = _imagenet_loader('validation', batch_size=200, training=False)

dict(train_loader=len(train_loader), valid_loader=len(valid_loader))

In [None]:
# Number of classifier outputs and the square input resolution fed to the model.
n_classes, img_size = 1000, 224

class ConvBlock(nn.Sequential):
    """Conv2d -> BatchNorm2d -> SiLU, exposed as children 'conv', 'bn', 'act'.

    With the defaults (kernel 4, padding 1, stride 2) an even spatial size is
    halved, e.g. 224 -> 112. The convolution has no bias by default.
    """

    def __init__(self, channels_in, channels_out, kernel_size=4, padding=1, stride=2, bias=False):
        conv = nn.Conv2d(
            channels_in,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
        )
        layers = OrderedDict()
        layers['conv'] = conv
        layers['bn'] = nn.BatchNorm2d(channels_out)
        layers['act'] = nn.SiLU()
        super().__init__(layers)

# For a 224x224 input, five stride-2 ConvBlocks take the resolution down to 7x7
# (each block is named after its input size). A stride-1 3x3 block then widens
# to 1024 channels ahead of global pooling, dropout and a bias-free linear head.
_downsampling_blocks = [
    ('b224', 3, 32),
    ('b112', 32, 64),
    ('b56', 64, 128),
    ('b28', 128, 256),
    ('b14', 256, 512),
]
_layers = OrderedDict(
    (block_name, ConvBlock(c_in, c_out)) for block_name, c_in, c_out in _downsampling_blocks
)
_layers['b7'] = ConvBlock(512, 1024, kernel_size=3, padding=1, stride=1)
_layers['avg'] = nn.AdaptiveAvgPool2d(1)
_layers['flat'] = nn.Flatten()
_layers['drop'] = nn.Dropout(0.2)
_layers['classifier'] = nn.Linear(1024, n_classes, bias=False)

model = nn.Sequential(_layers)

summary(model, (3, img_size, img_size), verbose=0)

In [None]:
%%writefile train_classifier.py
import torch
import torch.nn.functional as F
from tqdm.auto import trange, tqdm


def train(model, train_loader, optimizer, scheduler, logger, log_every=10):
    """Run one training epoch, stepping the LR scheduler after every batch.

    Each item of `train_loader` is unpacked as a 1-tuple holding a dict with
    'data' (model inputs) and 'label' (targets, flattened with `ravel`).
    Every `log_every` steps, `logger` receives a string with the running
    accuracy, the running loss and the learning rate of the first parameter
    group.
    """
    running_loss = RunningAverage('Loss')
    running_acc = RunningAverage('Acc')
    model.train()

    progress = tqdm(train_loader, leave=False)
    for step, (batch,) in enumerate(progress):
        images = batch['data']
        labels = batch['label'].ravel()

        # log_softmax followed by nll_loss is the cross-entropy loss.
        log_probs = F.log_softmax(model(images), dim=1)
        loss = F.nll_loss(log_probs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler is stepped per batch, not per epoch.
        scheduler.step()

        with torch.no_grad():
            running_loss.update(loss.item())
            hits = log_probs.argmax(dim=1) == labels
            running_acc.update(hits.float().mean().item())
            if step % log_every == 0:
                lr_now = optimizer.param_groups[0]["lr"]
                logger(f'{running_acc} {running_loss} lr:{lr_now:.01e}')


@torch.no_grad()
def validate(model, valid_loader):
    """Evaluate `model` over `valid_loader` with gradients disabled.

    Returns `(losses, acc)`: two AverageMeter instances holding the NLL loss
    and top-1 accuracy. Each batch mean is weighted by the batch size, so the
    averages are per-sample.
    """
    loss_meter = AverageMeter('Loss')
    acc_meter = AverageMeter('Acc')
    model.eval()

    for (batch,) in tqdm(valid_loader, leave=False):
        images, labels = batch['data'], batch['label'].ravel()
        batch_size = len(labels)

        log_probs = F.log_softmax(model(images), dim=1)
        batch_loss = F.nll_loss(log_probs, labels).item()
        batch_acc = (log_probs.argmax(dim=1) == labels).float().mean().item()

        loss_meter.update(batch_loss, batch_size)
        acc_meter.update(batch_acc, batch_size)

    return loss_meter, acc_meter


class AverageMeter(object):
    """Weighted running mean of a scalar metric.

    `update(val, n)` adds `val` with weight `n` (e.g. a batch mean and the batch
    size), so `avg` is the weighted mean of everything seen since the last
    `reset()`. `str(meter)` renders as '<name>:<avg>' using the `fmt` spec.
    """

    def __init__(self, name, fmt=':.03f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Accumulate `val` with weight `n` and refresh `avg`."""
        self.sum += val * n
        self.count += n
        # Guard so a zero-weight update on an empty meter cannot divide by zero.
        self.avg = self.sum / self.count if self.count else 0

    def all_reduce(self, device='cuda'):
        """Sum `sum` and `count` across all ranks of the default process group.

        Requires `torch.distributed` to be initialised. `device` must be one the
        backend can reduce on (e.g. 'cuda' for NCCL, 'cpu' for Gloo).
        """
        import torch.distributed as dist

        # torch.tensor rather than the legacy torch.FloatTensor constructor
        # (which rejects non-CPU devices). float64 avoids float32 rounding of
        # large sums and counts.
        total = torch.tensor([self.sum, self.count], dtype=torch.float64, device=device)
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)

    
class RunningAverage(object):
    """Exponential moving average of a scalar metric.

    The first `update` seeds the average with that value. Later updates blend as
    `avg = beta * avg + (1 - beta) * val`. Before any update `avg` is None, and
    `str()` renders it as 'nan' instead of raising.
    """

    def __init__(self, name, fmt=':.03f', beta=0.98):
        self.name = name
        self.fmt = fmt
        self.beta = beta
        self.reset()

    def reset(self):
        """Forget the history; the next update re-seeds the average."""
        self.avg = None

    def update(self, val):
        """Fold `val` into the moving average."""
        if self.avg is None:
            self.avg = val
        else:
            self.avg = self.beta * self.avg + (1 - self.beta) * val

    def __str__(self):
        # Formatting None with a float spec would raise TypeError before the first update.
        avg = float('nan') if self.avg is None else self.avg
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(name=self.name, avg=avg)

In [None]:
from train_classifier import train, validate

epochs = 10
history = []
lr = 0.01

# The loaders are built with gpu_out=True, presumably delivering batches on the
# GPU, so place the model there explicitly instead of relying on earlier cells.
# Module.cuda() moves parameters in place; doing it before the optimizer is
# created keeps the optimizer bound to the on-device parameters.
model = model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=1e-5, nesterov=True)
# OneCycleLR is sized for one scheduler.step() per training batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                max_lr=lr,
                                                epochs=epochs,
                                                steps_per_epoch=len(train_loader),
                                                pct_start=0.01,
                                                final_div_factor=100)

pbar = trange(epochs)
for e in pbar:
    train(model, train_loader, optimizer, scheduler, logger=pbar.set_description)
    losses, acc = validate(model, valid_loader)
    # Record per-epoch validation metrics.
    history.append(dict(epoch=e, loss=losses.avg, acc=acc.avg))
    print(f'{e}: {acc} {losses}')