In [None]:
import ipcmagic

In [None]:
# Start the IPython-parallel cluster managed by ipcmagic. `-n 2` presumably
# requests two engines, i.e. the two ranks of the distributed run below —
# TODO confirm against the ipcmagic docs for this site.
%ipcluster start -n 2

In [None]:
# Configure how %%px cells report on the engines; in ipyparallel,
# `--progress-after -1` turns off the progress bar for parallel execution.
%pxconfig --progress-after -1

In [None]:
%%px
import glob
import itertools
import time

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torchvision import models

from pt_distr_env import DistributedEnviron

In [None]:
%%px
data_dir = '/scratch/snx3000/datasets/imagenet/ILSVRC2012_1k/'

# glob.glob returns files in filesystem-dependent order. The input pipeline
# shards this data with `shard(world_size, rank)`, which only yields disjoint
# per-rank subsets if every rank sees the files in the same order, so sort.
tfrec_files = sorted(glob.glob(f'{data_dir}/train/*'))

In [None]:
%%px
# Make only CPU devices visible to TensorFlow. Here TF just runs the tf.data
# input pipeline, so this keeps it from initialising or allocating memory on
# the GPU that PyTorch trains on. Must run before TF initialises its devices.
tf.config.set_visible_devices(
    devices=tf.config.list_physical_devices(device_type='CPU'),
)

In [None]:
%%px
# Every rank trains on CUDA device 0. NOTE(review): this assumes each process
# sees exactly one GPU (e.g. one rank per node) — confirm for other layouts.
device = 0

# DistributedEnviron comes from the project-local pt_distr_env module; it is
# presumably what exports the rendezvous variables that the default "env://"
# init of init_process_group reads — TODO confirm.
distr_env = DistributedEnviron()
dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

In [None]:
%%px
# Per-rank batch size: each rank batches its own shard of the data, so the
# effective global batch is batch_size * world_size.
batch_size: int = 128

In [None]:
%%px
def decode(serialized_example):
    """Turn one serialized tf.train.Example into an ``(image, label)`` pair.

    Steps:
      * parse the JPEG bytes and the int64 class label;
      * decode the JPEG as 3-channel RGB;
      * centre-crop or pad to 224x224;
      * move channels to the front (CHW, the layout torchvision models take);
      * convert to float32, which rescales the integer pixels into [0, 1].

    Args:
        serialized_example: scalar ``tf.string`` tensor holding one record.

    Returns:
        ``(image, label)``: a float32 ``[3, 224, 224]`` image tensor and an
        int64 scalar label. One is subtracted from the stored label;
        presumably the records use 1-based class ids — verify against the
        dataset.
    """
    feature_spec = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

    pixels = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    pixels = tf.image.resize_with_crop_or_pad(pixels, 224, 224)
    pixels = tf.transpose(pixels, perm=(2, 0, 1))  # HWC -> CHW
    pixels = tf.image.convert_image_dtype(pixels, dtype=tf.float32)

    label = tf.cast(parsed['image/class/label'], tf.int64)
    return pixels, label - 1

In [None]:
%%px
# Per-rank input pipeline: read -> shard -> decode -> batch -> prefetch.
dataset = tf.data.TFRecordDataset(tfrec_files)
# Keep every `world_size`-th raw record, starting at `rank`. Sharding before
# `map` means each rank decodes only the JPEGs it will actually train on,
# instead of decoding everything and discarding the other ranks' share.
# Disjointness requires all ranks to list `tfrec_files` in the same order.
dataset = dataset.shard(world_size, rank)
dataset = dataset.map(decode, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(batch_size)
# Prefetch as the final stage so it overlaps this rank's batches with training.
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
%%px
# Expose the tf.data pipeline as NumPy (images, labels) batches for torch.
# The epoch loop iterates `dataset_np` afresh every epoch, which relies on
# tfds.as_numpy returning a re-iterable object rather than a one-shot
# iterator — confirm for the installed tfds version before changing this.
dataset_np = tfds.as_numpy(dataset)

In [None]:
%%px
# ResNet-50 with randomly initialised weights (no pretrained weights are
# requested), placed on this rank's GPU. nn.Module.to() moves the module in
# place and returns it, so `_model` is the on-device module.
_model = models.resnet50().to(device)

# DDP broadcasts rank 0's parameters when it is constructed, so all ranks
# start from identical weights even though initialisation is unseeded.
ddp_model = DistributedDataParallel(_model, device_ids=[device])

In [None]:
%%px
# Plain SGD (defaults: no momentum, no weight decay) over the DDP-wrapped
# parameters; benchmark_step reads this module-level name.
optimizer = optim.SGD(params=ddp_model.parameters(), lr=0.01)

In [None]:
%%px
def benchmark_step(model, imgs, labels):
    """Run one training iteration: forward, cross-entropy loss, backward, update.

    Uses the module-level ``optimizer`` rather than taking one as an argument,
    so ``model`` must be the model whose parameters that optimizer updates.

    Args:
        model: module mapping ``imgs`` to class logits.
        imgs: input batch, on the same device as ``model``.
        labels: integer class targets for ``imgs``.

    Returns:
        None; the effect is the in-place parameter update.
    """
    optimizer.zero_grad()
    logits = model(imgs)
    F.cross_entropy(logits, labels).backward()
    optimizer.step()

In [None]:
%%px
# Time `num_iters` training steps per epoch and report per-rank throughput.
num_epochs = 5
num_iters = 10
imgs_sec = []
for epoch in range(num_epochs):
    # Drain any queued GPU work so it is not billed to this epoch.
    torch.cuda.synchronize(device)
    n_imgs = 0
    t0 = time.perf_counter()
    # islice yields exactly `num_iters` batches (fewer only if this rank's
    # shard runs out) and never pulls an extra batch from the pipeline.
    for imgs, labels in itertools.islice(dataset_np, num_iters):
        imgs = torch.tensor(imgs).to(device)
        labels = torch.tensor(labels).to(device)
        benchmark_step(ddp_model, imgs, labels)
        # Count real images so a short final batch is not over-reported.
        n_imgs += imgs.shape[0]

    # CUDA kernels run asynchronously: wait for the last step to finish
    # before reading the clock, otherwise its compute time is not measured.
    torch.cuda.synchronize(device)
    dt = time.perf_counter() - t0
    imgs_sec.append(n_imgs / dt)

    print(f' * Epoch {epoch:2d}: '
          f'{imgs_sec[epoch]:.2f} images/sec per GPU')

In [None]:
# Shut down the ipcmagic cluster that was started at the top of the notebook.
%ipcluster stop