# Basic training of a CNN on ImageNet from TFRecord files using TensorFlow's `tf.data` API

Here we will run a simplified training loop for a CNN model on ImageNet. We will create an input pipeline based on TensorFlow's [`tf.data` API](https://www.tensorflow.org/guide/data) to feed the model with ImageNet data stored in TFRecord files. We will apply random transformations to the images as done [here](https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset).

We use [TensorFlow Datasets](https://www.tensorflow.org/datasets) to convert a `tf.data.Dataset` dataset to an iterable of NumPy arrays:
```python
np_dataset = tfds.as_numpy(tf_dataset)
```
Each batch of NumPy arrays is then converted to a `torch.Tensor` and moved to the GPU.

In [None]:
import glob
import time
import numpy as np
import torch
import torch.nn.functional as F
import tensorflow_datasets as tfds
import torch.optim as optim
import tensorflow as tf
from torchvision import models

In [None]:
# Collect the training TFRecord files. `glob.glob` returns matches in
# arbitrary order, so sort them to make the input file order reproducible.
tfrec_files = sorted(glob.glob('/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k//train/*'))

In [None]:
# Restrict TensorFlow to the CPU devices only.
# NOTE(review): presumably this keeps TensorFlow from claiming the GPU that
# the PyTorch model uses -- confirm.
cpu_devices = tf.config.list_physical_devices('CPU')
tf.config.set_visible_devices(cpu_devices)

In [None]:
# Number of samples per batch; used both to batch the `tf.data` pipeline and
# in the throughput computation of the benchmark loop.
batch_size = 128

In [None]:
def decode(serialized_example):
    """Parse one serialized example into a decoded image and its label.

    Args:
        serialized_example: A scalar string tensor holding one serialized
            example with the features 'image/encoded' and
            'image/class/label'.

    Returns:
        A tuple `(image, label)`: the JPEG-decoded 3-channel image and the
        stored label shifted down by one.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example,
                                        features=feature_spec)
    # Shift the stored label down by one -> [0-999].
    label = parsed['image/class/label'] - 1
    image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    return image, label

In [None]:
# Side length (in pixels) of the square images produced by the pipeline.
IMG_SIZE = 224

def resize_and_rescale(image, label):
    """Convert the image to float32, resize it and divide it by 255.

    Args:
        image: Image tensor to be processed.
        label: Label associated with the image; returned unchanged.

    Returns:
        A tuple `(image, label)` where the image is a float32 tensor resized
        to `IMG_SIZE` x `IMG_SIZE` and divided by 255.
    """
    as_float = tf.cast(image, tf.float32)
    resized = tf.image.resize(as_float, [IMG_SIZE, IMG_SIZE])
    return resized / 255.0, label

def _augment(image, label, seed):
    """Apply seeded random crop and brightness augmentation to one image.

    The image is first passed through `resize_and_rescale`, then grown to
    `IMG_SIZE + 6` per side with `resize_with_crop_or_pad`, randomly cropped
    back to `IMG_SIZE`, randomly brightened, clipped to [0, 1] and finally
    transposed from (height, width, channels) to (channels, height, width).

    Args:
        image: Decoded image tensor.
        label: Label associated with the image; returned unchanged.
        seed: Seed for the stateless random ops. It is used directly for the
            crop; a seed derived from it is used for the brightness change.

    Returns:
        A tuple `(image, label)` with the augmented, channels-first image.
    """
    image, label = resize_and_rescale(image, label)
    enlarged_size = IMG_SIZE + 6
    image = tf.image.resize_with_crop_or_pad(image, enlarged_size,
                                             enlarged_size)
    # Derive a second seed so that crop and brightness do not share one.
    brightness_seed = tf.random.experimental.stateless_split(seed, num=1)[0, :]
    # Random crop back to the original size.
    cropped = tf.image.stateless_random_crop(
        image, size=[IMG_SIZE, IMG_SIZE, 3], seed=seed)
    brightened = tf.image.stateless_random_brightness(
        cropped, max_delta=0.5, seed=brightness_seed)
    # The brightness shift can leave the [0, 1] range; clip back into it.
    clipped = tf.clip_by_value(brightened, 0, 1)
    # Move the channel axis to the front.
    return tf.transpose(clipped, (2, 0, 1)), label

# Random generator from which the per-sample augmentation seeds are drawn.
rng = tf.random.Generator.from_seed(123, alg='philox')

def augment(x, y):
    """Draw a fresh seed from `rng` and augment one `(image, label)` pair.

    Args:
        x: Decoded image tensor.
        y: Label associated with the image.

    Returns:
        The `(image, label)` tuple returned by `_augment`.
    """
    seed = rng.make_seeds(2)[0]
    return _augment(x, y, seed)

In [None]:
# Input pipeline: read the TFRecord files, decode and augment the samples in
# parallel, group them into batches and prefetch batches ahead of consumption.
dataset = (
    tf.data.TFRecordDataset(tfrec_files)
    .map(decode, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Wrap the `tf.data` pipeline so that iterating over it yields NumPy arrays,
# which can then be handed to `torch.from_numpy`.
dataset_np = tfds.as_numpy(dataset)

In [None]:
# Index of the device on which the model and the batches are placed.
device = 0

# ResNet-50 built without passing any weights argument, moved to `device`
# (`Module.to` returns the module itself).
model = models.resnet50().to(device)

In [None]:
# SGD over all model parameters with a learning rate of 0.01; no other
# optimizer options are passed. Used as a global by `benchmark_step`.
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
def benchmark_step(model, imgs, labels):
    """Run one training step: forward pass, loss, backward pass, update.

    Uses the module-level `optimizer` to reset the gradients and to apply
    the parameter update.

    Args:
        model: Callable model applied to `imgs`.
        imgs: Batch of input images passed to the model.
        labels: Target class indices for the cross-entropy loss.
    """
    optimizer.zero_grad()
    loss = F.cross_entropy(model(imgs), labels)
    loss.backward()
    optimizer.step()

In [None]:
# Benchmark: for each epoch, train on at most `num_iters` batches and report
# the measured throughput in images per second.
num_epochs = 5
num_iters = 10
imgs_sec = []
for epoch in range(num_epochs):
    num_images = 0
    t0 = time.perf_counter()
    for step, (imgs, labels) in enumerate(dataset_np):
        # `>=` so that exactly `num_iters` batches are processed
        # (`>` would run one batch more than is accounted for below).
        if step >= num_iters:
            break

        imgs = torch.from_numpy(imgs).to(device)
        labels = torch.from_numpy(labels).to(device)
        benchmark_step(model, imgs, labels)
        # Count what was really processed: the dataset may end early and the
        # last batch may be smaller than `batch_size`.
        num_images += imgs.shape[0]

    # CUDA kernels run asynchronously; wait for the queued work to finish so
    # that the elapsed time covers the whole computation.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    imgs_sec.append(num_images / dt)

    print(f' * Epoch {epoch:2d}: '
          f'{imgs_sec[epoch]:.2f} images/sec per GPU')