In [2]:
from pathlib import Path
import tensorflow as tf
import numpy as np
from tqdm.notebook import tqdm
from model.dali_pipe import dali_generator
from model.resnet import Resnet50
from model.lars import LARS
from model.scheduler import WarmupExponentialDecay
import horovod.tensorflow as hvd
import tensorflow_addons as tfa

In [3]:
# Horovod must be initialised before rank()/size()/local_rank() are queried.
hvd.init()

data_dir = Path('/workspace/shared_workspace/data/imagenet/')
index_dir = Path('/workspace/shared_workspace/data/imagenet_index/')
# glob() yields entries in an arbitrary, filesystem-dependent order. The two
# lists are passed to dali_generator side by side, so both are sorted to make
# the order deterministic and identical on every Horovod worker.
# NOTE(review): assumes each index file sorts into the same position as the
# data shard it belongs to (e.g. same file name in both directories) - confirm.
train_files = sorted(i.as_posix() for i in data_dir.glob('*1024'))
train_index = sorted(i.as_posix() for i in index_dir.glob('*1024'))

# Batch size summed over all workers; each worker gets an equal integer share.
global_batch = 512
per_gpu_batch = global_batch//hvd.size()
# Number of training images counted per epoch; determines steps_per_epoch.
image_count = 1282048
steps_per_epoch = image_count//global_batch
# Base learning rate: 0.01 per 256 images, scaled linearly with the global batch.
learning_rate = 0.01*global_batch/256
# Second argument handed to WarmupExponentialDecay.
# NOTE(review): presumably the rate reached at the end of warm-up - confirm in model.scheduler.
scaled_rate = 3.7

# Run Keras in float16 by default, with a larger epsilon to go with the reduced
# precision, and turn on XLA JIT compilation.
tf.keras.backend.set_floatx('float16')
tf.keras.backend.set_epsilon(1e-4)
tf.config.optimizer.set_jit(True)

# Let TensorFlow grow GPU memory on demand, then restrict this process to the
# single GPU whose index equals its Horovod local rank.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

In [4]:
# Learning-rate schedule; the base rate is cast to float16 before being handed
# to the project-local WarmupExponentialDecay.
scheduler = WarmupExponentialDecay(
    tf.cast(learning_rate, tf.float16),
    scaled_rate,
    steps_per_epoch,
    steps_per_epoch*10,
    0.0001,
)

# Training input pipeline, parameterised with this worker's Horovod rank and
# the total worker count.
train_tf = dali_generator(
    train_files,
    train_index,
    per_gpu_batch,
    num_threads=8,
    device_id=hvd.rank(),
    total_devices=hvd.size(),
)

# Untrained Keras ResNet50 with 1000 output classes on 224x224 RGB inputs.
# (The project-local alternative is kept for reference: model = Resnet50())
model = tf.keras.applications.ResNet50(
    weights=None,
    input_shape=(224, 224, 3),
    classes=1000,
)

# LARS optimizer driven by the schedule above, plus the loss on integer labels.
optimizer = LARS(scheduler, use_nesterov=False, clip=False)
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

In [5]:
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a single batch and return its loss.

    The forward pass and loss are recorded on a gradient tape, which is then
    wrapped in ``hvd.DistributedGradientTape`` so the gradients are combined
    across all Horovod workers before the optimizer applies them. Without the
    wrapper every worker would update its own copy of the model from its local
    batch only, and the copies would drift apart.

    Args:
        images: batch of input images fed to ``model``.
        labels: integer class labels for ``images``.

    Returns:
        The loss computed by ``loss_func`` on this worker's batch (not reduced
        across workers).
    """
    with tf.GradientTape() as tape:
        pred = model(images, training=True)
        loss = loss_func(labels, pred)
    # Allreduce the gradients over all workers instead of using local ones.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

In [6]:
# Train for 30 epochs, showing the mean of the last 100 batch losses.
for epoch in range(30):
    loss = []
    progressbar = tqdm(range(steps_per_epoch))
    for batch in progressbar:
        images, labels = next(train_tf)
        loss.append(train_step(images, labels).numpy())
        if epoch == 0 and batch == 0:
            # The model is randomly initialised on every worker, so copy rank
            # 0's weights to all workers once so they train the same model.
            # This is done after the first step, as the Horovod TF2 guide
            # recommends.
            # NOTE(review): the guide also broadcasts the optimizer state; the
            # LARS class is not visible here, so that is left out - confirm
            # whether optimizer.variables() is available and add it.
            hvd.broadcast_variables(model.variables, root_rank=0)
        progressbar.set_description("train_loss: {0:.4f}".format(np.array(loss[-100:]).mean()))

HBox(children=(FloatProgress(value=0.0, max=2504.0), HTML(value='')))

KeyboardInterrupt: 

In [6]:
def validation_step(images, labels):
    """Evaluate one batch and count top-1 / top-5 hits.

    The hit counts are computed with vectorised tensor ops rather than a
    Python loop over individual samples.

    Args:
        images: batch of input images fed to ``model`` in inference mode.
        labels: integer class labels, one per image (shape ``(batch,)`` or
            ``(batch, 1)``).

    Returns:
        A tuple ``(loss, top_1, top_5)`` where ``loss`` is the value returned
        by ``loss_func`` and ``top_1`` / ``top_5`` are Python ints giving the
        number of samples whose label is among the 1 / 5 highest-scoring
        classes.
    """
    with tf.device('/gpu:0'):
        pred = model(images, training=False)
        loss = loss_func(labels, pred)
        # top_k returns indices sorted by descending score, so the first
        # column of the top-5 indices is the top-1 prediction.
        top_5_pred = tf.math.top_k(pred, k=5)[1]
        top_1_pred = top_5_pred[:, :1]
        # Column vector of labels so it broadcasts against (batch, k) indices.
        labels = tf.reshape(tf.cast(labels, tf.int32), [-1, 1])
        top_1_hits = tf.reduce_any(tf.equal(top_1_pred, labels), axis=1)
        top_5_hits = tf.reduce_any(tf.equal(top_5_pred, labels), axis=1)
        top_1 = tf.math.count_nonzero(top_1_hits)
        top_5 = tf.math.count_nonzero(top_5_hits)
    return loss, int(top_1.numpy()), int(top_5.numpy())

def validation(steps = 128):
    """Evaluate the model on ``steps`` batches drawn from ``validation_tdf``.

    Accuracy is normalised by the number of samples actually evaluated, read
    from each label batch, so the function does not depend on a separate
    global batch-size variable.

    Args:
        steps: number of batches to pull from the validation generator.

    Returns:
        A tuple ``(mean_loss, top_1_accuracy, top_5_accuracy)``: the mean of
        the per-batch losses and the fraction of evaluated samples counted as
        top-1 / top-5 hits by ``validation_step``.
    """
    loss_tracker = []
    top_1_tracker = 0
    top_5_tracker = 0
    sample_count = 0
    for _ in range(steps):
        images, labels = next(validation_tdf)
        loss, top_1, top_5 = validation_step(images, labels)
        loss_tracker.append(loss.numpy())
        top_1_tracker += top_1
        top_5_tracker += top_5
        # Leading dimension of the label batch = samples in this batch.
        sample_count += int(labels.shape[0])
    return sum(loss_tracker)/len(loss_tracker), top_1_tracker/sample_count, top_5_tracker/sample_count


In [7]:
# glob() order is arbitrary; sort both lists so they are deterministic and
# position-aligned when passed to dali_generator side by side.
# NOTE(review): assumes each index file sorts into the same position as the
# data shard it belongs to - confirm.
validation_files = sorted(i.as_posix() for i in data_dir.glob('*0128'))
validation_index = sorted(i.as_posix() for i in index_dir.glob('*0128'))
# `batch_size` was not assigned in any cell of this notebook, so a fresh kernel
# failed here with a NameError. 128 is the smallest value consistent with the
# recorded validation output (13 top-1 hits out of 128 steps * 128 samples).
# NOTE(review): confirm this is the intended per-worker validation batch size.
batch_size = 128
validation_tdf = dali_generator(validation_files, validation_index, batch_size, num_threads=4, device_id=hvd.rank(), total_devices=hvd.size())

In [8]:
# Evaluate on the default 128 validation batches; the displayed tuple is
# (mean loss, top-1 accuracy, top-5 accuracy).
validation()

(7.04437255859375, 0.00079345703125, 0.00439453125)