# Basic training of a CNN on ImageNet on multiple GPUs, reading TFRecord files with TensorFlow's `tf.data` API

Here we will run a simplified training loop for a CNN model on ImageNet. We will build an input pipeline with TensorFlow's [`tf.data` API](https://www.tensorflow.org/guide/data) to feed the model with ImageNet data stored in TFRecord files.

We use [TensorFlow Datasets](https://www.tensorflow.org/datasets) to convert a `tf.data.Dataset` dataset to an iterable of NumPy arrays:
```python
np_dataset = tfds.as_numpy(tf_dataset)
```
Each NumPy array is then converted to a `torch.Tensor` and moved to the GPU.

In this notebook we will use PyTorch's [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#all-together-automatic-mixed-precision).

In [None]:
import ipcmagic

In [None]:
%ipcluster start -n 2

In [None]:
%pxconfig --progress-after -1

In [None]:
%%px
import glob
import time
import numpy as np
import tensorflow_datasets as tfds
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import tensorflow as tf
from torch.nn.parallel import DistributedDataParallel
from torchvision import models
from pt_distr_env import DistributedEnviron

In [None]:
%%px
# Collect the training TFRecord files. `glob` returns matches in arbitrary
# order, and the ranks later split the records by position with
# `dataset.shard(world_size, rank)`, so the list is sorted to make sure every
# rank sees the files in the same order.
tfrec_files = sorted(
    glob.glob('/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k//train/*')
)

In [None]:
%%px
# Restrict TensorFlow to the CPU devices only.
# NOTE(review): presumably this keeps TensorFlow off the GPU, which the
# PyTorch model uses - confirm.
_cpu_devices = tf.config.list_physical_devices('CPU')
tf.config.set_visible_devices(_cpu_devices)

In [None]:
%%px
# NOTE(review): `DistributedEnviron` comes from the local `pt_distr_env`
# module; presumably it prepares the environment read by
# `init_process_group` - confirm against that module.
distr_env = DistributedEnviron()
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
rank = dist.get_rank()

# Give each process its own GPU instead of hard-coding device 0: when several
# GPUs are visible to a process, ranks sharing a node must not all use the
# same one. With a single visible GPU (or none) this still evaluates to 0.
device = rank % max(torch.cuda.device_count(), 1)
if torch.cuda.is_available():
    # Make it the current device so that CUDA calls without an explicit
    # device argument also target this GPU.
    torch.cuda.set_device(device)

In [None]:
%%px
# Number of samples in each batch produced by the input pipeline of one
# process (rank).
batch_size = 128

In [None]:
%%px
def decode(serialized_example):
    """Parse one serialized example into an `(image, label)` pair.

    The JPEG bytes stored under 'image/encoded' are decoded to 3 channels,
    centrally cropped or zero-padded to 224x224 (no rescaling is done),
    transposed so the channel axis comes first, and converted to float32.
    The integer stored under 'image/class/label' is shifted down by one.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example,
                                        features=feature_spec)

    pixels = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    pixels = tf.image.resize_with_crop_or_pad(pixels, 224, 224)
    # (height, width, channels) -> (channels, height, width)
    pixels = tf.transpose(pixels, (2, 0, 1))
    pixels = tf.image.convert_image_dtype(pixels, dtype=tf.float32)

    # Shift the stored labels so that they start at 0 -> [0-999]
    return pixels, parsed['image/class/label'] - 1

In [None]:
%%px
# Input pipeline: read the records, keep this rank's share, decode, batch.
dataset = tf.data.TFRecordDataset(tfrec_files)
# Shard before any expensive step: every rank keeps one record out of
# `world_size`, so it only decodes and batches the samples it will actually
# use instead of processing everything and discarding most of the batches.
dataset = dataset.shard(world_size, rank)
dataset = dataset.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size)
# Prefetch goes last so that it overlaps the production of the batches that
# are actually consumed with the training step.
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
%%px
# Wrap the `tf.data` pipeline so that iterating over it yields NumPy arrays,
# which can then be handed to `torch.from_numpy`.
dataset_np = tfds.as_numpy(dataset)

In [None]:
%%px
# Build a ResNet-50, place it on this process's GPU and wrap it for
# distributed data-parallel training.
_model = models.resnet50().to(device)
ddp_model = DistributedDataParallel(module=_model, device_ids=[device])

In [None]:
%%px
# Plain SGD over the parameters of the DDP-wrapped model.
optimizer = optim.SGD(params=ddp_model.parameters(), lr=0.01)

In [None]:
%%px
# Switch for automatic mixed precision; it is passed as `enabled` to both the
# gradient scaler below and the autocast context in `benchmark_step_amp`.
use_amp = True

# Gradient scaler used by `benchmark_step_amp` to scale the loss before the
# backward pass and to perform the optimizer step.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

def benchmark_step_amp(model, imgs, labels):
    """Run one training step with optional automatic mixed precision.

    The forward pass and the cross-entropy loss are evaluated inside an
    autocast context (float16 on CUDA, active when `use_amp` is true); the
    backward pass and the parameter update go through the module-level
    `scaler` and `optimizer`.

    Args:
        model: callable applied to `imgs` to produce the predictions.
        imgs: input batch passed to `model`.
        labels: targets passed to `F.cross_entropy`.

    Returns:
        None. The loss is not returned.
    """
    optimizer.zero_grad()

    amp_context = torch.autocast(device_type='cuda',
                                 dtype=torch.float16,
                                 enabled=use_amp)
    with amp_context:
        loss = F.cross_entropy(model(imgs), labels)

    # Backpropagate the scaled loss, then let the scaler drive the optimizer
    # step and refresh its scale factor.
    scaled_loss = scaler.scale(loss)
    scaled_loss.backward()
    scaler.step(optimizer)
    scaler.update()

In [None]:
%%px
# Benchmark loop: every "epoch" runs at most `num_iters` training steps and
# reports the throughput of this process.
print()
num_epochs = 5
num_iters = 10
imgs_sec = []
for epoch in range(num_epochs):
    num_imgs = 0
    t0 = time.perf_counter()
    # `zip` with `range(num_iters)` stops after exactly `num_iters` batches
    # (or earlier if the dataset runs out) without fetching an extra batch.
    for _, (imgs, labels) in zip(range(num_iters), dataset_np):
        imgs = torch.from_numpy(imgs).to(device)
        labels = torch.from_numpy(labels).to(device)
        benchmark_step_amp(ddp_model, imgs, labels)
        # Count the samples actually processed, so that a short last batch
        # or an early end of the dataset does not inflate the throughput.
        num_imgs += labels.shape[0]

    # CUDA kernels run asynchronously; wait for the queued work to finish
    # before stopping the timer.
    torch.cuda.synchronize(device)
    dt = time.perf_counter() - t0
    imgs_sec.append(num_imgs / dt)

    print(f' * Epoch {epoch:2d}: '
          f'{imgs_sec[epoch]:.2f} images/sec per GPU')

In [None]:
%ipcluster stop