<a href="https://colab.research.google.com/github/kp425/nlp_lab/blob/master/TPU_optimized_custom_training_loops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential
import random

# Seed TensorFlow's global RNG and Python's `random` module with the same
# fixed value before any random data is generated.
tf.random.set_seed(101)
random.seed(101)

In [None]:
def connect_to_tpu():
    """Return a distribution strategy, using a TPU when one can be resolved.

    A ``TPUClusterResolver`` is constructed first. If that raises
    ``ValueError`` the function falls back to whatever
    ``tf.distribute.get_strategy()`` returns (the default strategy, which
    works on CPU and a single GPU).

    Returns:
        A ``TPUStrategy`` built from the resolver, or the default strategy.
    """
    try:
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        print("Running on TPU  ", tpu_resolver.master())
    except ValueError:
        # No TPU could be resolved; use the non-TPU fallback below.
        tpu_resolver = None

    if not tpu_resolver:
        # Alternatives for other hardware:
        #   tf.distribute.MirroredStrategy()
        #       for GPU or multi-GPU machines
        #   tf.distribute.experimental.MultiWorkerMirroredStrategy()
        #       for clusters of multi-GPU machines
        return tf.distribute.get_strategy()

    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu_resolver)
    tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    return tf.distribute.experimental.TPUStrategy(tpu_resolver)


# Build the distribution strategy once; the rest of the notebook uses this
# module-level `strategy` object.
strategy = connect_to_tpu()

# Report how many replicas the strategy keeps in sync.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

# Sample Input data

In [3]:
# Dimensions of the synthetic data: images have shape (sample_size, w, h, c)
# and every label is an integer class id in [0, n_classes - 1].
w, h, c = 16, 16, 3
n_classes = 10
sample_size = 1280
BATCH_SIZE = 128

# Random images drawn from a normal distribution (mean 2.0, stddev 0.01).
images = tf.random.normal(shape=(sample_size, w, h, c), mean=2.0, stddev=0.01)

# Random class ids from Python's RNG, stored as a float32 tensor.
labels = tf.convert_to_tensor(
    [random.randint(0, n_classes - 1) for _ in range(sample_size)],
    dtype=tf.float32,
)

# `ds` is the plain batched dataset; `dist_ds` is the same dataset handed to
# the strategy for distribution.
ds = tf.data.Dataset.from_tensor_slices((images, labels))
ds = ds.cache()
ds = ds.batch(BATCH_SIZE)
dist_ds = strategy.experimental_distribute_dataset(ds)



Create identical models for each `train_step`

In [8]:
# Re-seed both RNGs at the top of this cell, before any model is created.
tf.random.set_seed(101)
random.seed(101)


def make_model():
    """Build a fresh classifier: Flatten -> Dense(32, relu) -> Dense(n_classes).

    The last layer has no activation, so the model outputs raw scores.

    Returns:
        A new, unbuilt ``Sequential`` model.
    """
    stack = [
        layers.Flatten(),
        layers.Dense(32, activation='relu'),
        layers.Dense(n_classes),
    ]
    return Sequential(stack)



reference_model = make_model()
# Call the model once on the images of the first batch so that its weights
# are created and can later be copied into the models built by refresh().
reference_model(next(iter(ds))[0])

# Module-level placeholders; refresh() rebinds each of these via `global`.
model = None
# Snapshot of the reference weights that refresh() loads into every new model.
model_weights = reference_model.get_weights()
train_loss = None
train_acc = None
test_loss = None
test_acc = None
optimizer = None
loss_object = None
loss_function = None
forward_pass = None

def refresh(apply_strategy=True):
    """Recreate the module-level model, metrics, optimizer and loss helpers.

    Each call rebinds the globals ``model``, ``train_loss``, ``train_acc``,
    ``test_loss``, ``test_acc``, ``optimizer``, ``loss_object``,
    ``loss_function`` and ``forward_pass``. The new model is loaded with the
    weights stored in ``model_weights``.

    Args:
        apply_strategy: If true, all objects are created inside
            ``strategy.scope()``; otherwise they are created outside it.
    """

    def _rebuild():
        global model, model_weights, train_loss, train_acc, \
                test_loss, test_acc, optimizer, loss_object, \
                loss_function, forward_pass

        # Fresh model loaded with the stored reference weights.
        model = make_model()
        model.build((None, w, h, c))
        model.set_weights(model_weights)

        metrics = tf.keras.metrics
        train_loss = metrics.Mean('training_loss', dtype=tf.float32)
        train_acc = metrics.SparseCategoricalAccuracy('training_acc', dtype=tf.float32)

        test_loss = metrics.Mean('test_loss', dtype=tf.float32)
        test_acc = metrics.SparseCategoricalAccuracy('test_acc', dtype=tf.float32)

        optimizer = tf.keras.optimizers.Adam(0.01)

        # Reduction.NONE: the loss object returns one value per example.
        loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )

        def loss_function(labels, preds):
            # Scale the per-example losses by the global batch size.
            return tf.nn.compute_average_loss(
                loss_object(labels, preds), global_batch_size=BATCH_SIZE)

        def forward_pass(ds_chunk):
            # Split the batch into inputs and labels, then run the model.
            features, targets = ds_chunk
            return model(features), targets

    if not apply_strategy:
        _rebuild()
        return

    with strategy.scope():
        _rebuild()



# Different types of custom strategy loops

Prefer `multiple_dist_train_steps_v1`/`multiple_dist_train_steps_v2` or `dist_train_epoch`: they reach 80-90% MXU utilization with 0% idle time on TPU.

In [None]:
@tf.function
def simple_train_step(ds_chunk):
    """Run one non-distributed training step on a single batch.

    Args:
        ds_chunk: An ``(inputs, labels)`` pair as consumed by ``forward_pass``.

    Side effects:
        Applies one optimizer update to ``model`` and updates the
        ``train_loss`` and ``train_acc`` metrics.
    """
    with tf.GradientTape() as tape:
        preds, labels = forward_pass(ds_chunk)
        # loss_object is built with Reduction.NONE: one loss value per example.
        per_example_loss = loss_object(labels, preds)
        # Differentiate the batch-averaged loss, exactly as loss_function does
        # in the distributed steps. Passing the unreduced vector to
        # tape.gradient would yield the gradient of the *summed* loss.
        loss_val = tf.nn.compute_average_loss(
            per_example_loss, global_batch_size=BATCH_SIZE)
    gradients = tape.gradient(loss_val, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                              experimental_aggregate_gradients=True)
    # The Mean metric averages the per-example values itself.
    train_loss.update_state(per_example_loss)
    train_acc.update_state(labels, preds)



@tf.function
def dist_train_step_v1(ds_chunk):
    """Run one distributed training step and record metrics outside the replicas.

    The replica function returns its loss, predictions and labels; the
    metrics are then updated once from the combined values.

    Args:
        ds_chunk: One element of the distributed dataset, an
            ``(inputs, labels)`` pair as consumed by ``forward_pass``.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            preds, labels = forward_pass(ds_chunk)
            loss_val = loss_function(labels, preds)
        gradients = tape.gradient(loss_val, model.trainable_variables)

        # Custom gradient aggregation can be done through the replica context:
        #
        #     grads = tape.gradient(loss, vars)
        #     grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
        #     # Process the aggregated gradients here.
        #     optimizer.apply_gradients(zip(grads, vars),
        #                               experimental_aggregate_gradients=False)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                                  experimental_aggregate_gradients=True)
        return loss_val, preds, labels

    # The step function runs on every replica in parallel.
    per_replica_losses, preds, labels = strategy.run(_train_step, args=(ds_chunk,))
    # The strategy already aggregates losses and gradients for the update;
    # this reduction is only needed to record the value in the metric.
    global_batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    # experimental_local_results returns a tuple of the local per-replica
    # values whether or not strategy.run wrapped its outputs, so this also
    # works with the single-replica fallback strategy (where the outputs have
    # no `.values` attribute).
    preds = tf.concat(strategy.experimental_local_results(preds), axis=0)
    labels = tf.concat(strategy.experimental_local_results(labels), axis=0)
    train_loss.update_state(global_batch_loss)
    train_acc.update_state(labels, preds)


@tf.function
def dist_train_step_v2(ds_chunk):
    """Run one distributed training step, updating metrics inside each replica.

    Args:
        ds_chunk: One element of the distributed dataset, an
            ``(inputs, labels)`` pair as consumed by ``forward_pass``.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            logits, targets = forward_pass(ds_chunk)
            replica_loss = loss_function(targets, logits)

        # tf.debugging.assert_shapes([(logits, (BATCH_SIZE//strategy.num_replicas_in_sync, n_classes))])
        # tf.debugging.assert_shapes([(targets, (BATCH_SIZE//strategy.num_replicas_in_sync,))])

        trainable = model.trainable_variables
        grads = tape.gradient(replica_loss, trainable)
        optimizer.apply_gradients(zip(grads, trainable),
                                  experimental_aggregate_gradients=True)

        # replica_loss is the loss value of a single replica. Because the
        # metric used here is a Mean, the value recorded at the end of the
        # batch would be scaled down to 1/num_replicas. To cancel that, each
        # replica's loss is multiplied by the number of replicas, so the mean
        # equals the real loss (total_loss_per_batch / batch_size).
        #
        # To avoid this complexity, a Sum metric can be used instead, e.g.:
        #     train_loss = tf.keras.metrics.Sum()
        #     train_loss.update_state(loss_val)
        #
        # https://www.tensorflow.org/tutorials/distribute/custom_training#tracking_training_loss_across_replicas
        # https://www.tensorflow.org/guide/tpu#input_datasets
        train_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        train_acc.update_state(targets, logits)

    strategy.run(_train_step, args=(ds_chunk,))



@tf.function
def multiple_dist_train_steps_v1(dist_iter, steps):
    """Run up to ``steps`` distributed training steps inside one tf.function.

    Metrics are updated outside the replica function from the combined
    per-replica outputs. The loop stops early when the iterator has no more
    data.

    Args:
        dist_iter: Iterator over the distributed dataset.
        steps: Maximum number of training steps to run.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            preds, labels = forward_pass(ds_chunk)
            loss_val = loss_function(labels, preds)

        gradients = tape.gradient(loss_val, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                                  experimental_aggregate_gradients=True)
        return loss_val, preds, labels

    for _ in tf.range(steps):
        optional_data = dist_iter.get_next_as_optional()
        if not optional_data.has_value():
            break
        per_replica_losses, preds, labels = strategy.run(_train_step, args=(optional_data.get_value(),))
        global_batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        # experimental_local_results returns a tuple of the local per-replica
        # values whether or not strategy.run wrapped its outputs, so this also
        # works with the single-replica fallback strategy (where the outputs
        # have no `.values` attribute).
        labels = tf.concat(strategy.experimental_local_results(labels), axis=0)
        preds = tf.concat(strategy.experimental_local_results(preds), axis=0)
        train_loss.update_state(global_batch_loss)
        train_acc.update_state(labels, preds)



@tf.function
def multiple_dist_train_steps_v2(dist_iter, steps):
    """Run up to ``steps`` distributed training steps inside one tf.function.

    Metrics are updated inside each replica. The loop stops early when the
    iterator has no more data.

    Args:
        dist_iter: Iterator over the distributed dataset.
        steps: Maximum number of training steps to run.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            preds, labels = forward_pass(ds_chunk)
            loss_val = loss_function(labels, preds)
        # Inside the replica function the batch is this replica's slice of
        # the global batch, so the expected leading dimension is the
        # per-replica batch size, not BATCH_SIZE. With one replica the two
        # are equal.
        per_replica_batch = BATCH_SIZE // strategy.num_replicas_in_sync
        tf.debugging.assert_shapes([(preds, (per_replica_batch, n_classes))])
        gradients = tape.gradient(loss_val, model.trainable_variables)

        optimizer.apply_gradients(zip(gradients, model.trainable_variables),
                                  experimental_aggregate_gradients=True)

        # Multiply by the replica count so the Mean metric, which averages
        # over the per-replica updates, reports the global batch loss.
        train_loss.update_state(loss_val * strategy.num_replicas_in_sync)
        train_acc.update_state(labels, preds)

    for _ in tf.range(steps):
        optional_data = dist_iter.get_next_as_optional()
        if not optional_data.has_value():
            break
        strategy.run(_train_step, args=(optional_data.get_value(),))
        # To inspect per-replica outputs, capture the return value of
        # strategy.run and print strategy.experimental_local_results(...) of it.


@tf.function
def dist_train_epoch(ds):
    """Train over every batch of ``ds`` inside a single tf.function call.

    See https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction

    Args:
        ds: The distributed dataset to iterate over.
    """
    def _train_step(ds_chunk):
        with tf.GradientTape() as tape:
            logits, targets = forward_pass(ds_chunk)
            replica_loss = loss_function(targets, logits)
        trainable = model.trainable_variables
        grads = tape.gradient(replica_loss, trainable)
        optimizer.apply_gradients(zip(grads, trainable),
                                  experimental_aggregate_gradients=True)

        # Scale by the replica count before feeding the Mean metric.
        train_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        train_acc.update_state(targets, logits)

    for batch in ds:
        strategy.run(_train_step, args=(batch,))


# Count the batches in the distributed dataset by iterating over it once.
steps_per_epoch = sum([1 for _ in dist_ds])
steps = 3
epochs = 1

# Each section below calls refresh() first, so every loop variant starts from
# a new model loaded with the same stored weights and from new metrics.

# Baseline: plain (non-distributed) step on the first `steps` batches of `ds`.
refresh(apply_strategy=False)
for i,chunk in enumerate(ds):
    if i==steps:break
    simple_train_step(chunk)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")

# One distributed step per Python-level call; metrics updated outside replicas.
refresh()
for i,chunk in enumerate(dist_ds):
    if i==steps:break
    dist_train_step_v1(chunk)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")

# One distributed step per Python-level call; metrics updated inside replicas.
refresh()
for i,chunk in enumerate(dist_ds):
    if i==steps:break
    dist_train_step_v2(chunk)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")

# `steps` distributed steps inside a single tf.function call (v1 metrics).
refresh()
multiple_dist_train_steps_v1(iter(dist_ds),steps)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")

# `steps` distributed steps inside a single tf.function call (v2 metrics).
refresh()
multiple_dist_train_steps_v2(iter(dist_ds),steps)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")

# Whole-dataset iteration inside a single tf.function call, once per epoch.
refresh()
for epoch in range(epochs):
    dist_train_epoch(dist_ds)
print(f"loss :{train_loss.result()}  acc:{train_acc.result()}")





# Test steps

In [26]:
@tf.function
def simple_test_step(ds_chunk):
    """Evaluate one batch without a distribution strategy.

    Args:
        ds_chunk: An ``(inputs, labels)`` pair as consumed by ``forward_pass``.

    Returns:
        A tuple of the value returned by ``loss_object`` for this batch and
        the current ``test_acc`` result.
    """
    logits, targets = forward_pass(ds_chunk)
    batch_losses = loss_object(targets, logits)
    test_loss.update_state(batch_losses)
    test_acc.update_state(targets, logits)
    return batch_losses, test_acc.result()

@tf.function
def dist_test_step_v1(ds_chunk):
    """Evaluate one distributed batch and record metrics outside the replicas.

    Args:
        ds_chunk: One element of the distributed dataset, an
            ``(inputs, labels)`` pair as consumed by ``forward_pass``.

    Returns:
        The labels of all local replicas concatenated along axis 0.
    """
    def _test_step(ds_chunk):
        preds, labels = forward_pass(ds_chunk)
        loss_val = loss_function(labels, preds)
        return loss_val, preds, labels

    per_replica_losses, preds, labels = strategy.run(_test_step, args=(ds_chunk,))
    global_batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    # experimental_local_results returns a tuple of the local per-replica
    # values whether or not strategy.run wrapped its outputs, so this also
    # works with the single-replica fallback strategy (where the outputs have
    # no `.values` attribute).
    labels = tf.concat(strategy.experimental_local_results(labels), axis=0)
    preds = tf.concat(strategy.experimental_local_results(preds), axis=0)
    test_loss.update_state(global_batch_loss)
    test_acc.update_state(labels, preds)
    return labels

@tf.function
def dist_test_step_v2(ds_chunk):
    """Evaluate one distributed batch, updating metrics inside each replica.

    Args:
        ds_chunk: One element of the distributed dataset, an
            ``(inputs, labels)`` pair as consumed by ``forward_pass``.
    """
    def _test_step(ds_chunk):
        logits, targets = forward_pass(ds_chunk)
        replica_loss = loss_function(targets, logits)
        # Scale by the replica count before feeding the Mean metric.
        test_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        test_acc.update_state(targets, logits)

    strategy.run(_test_step, args=(ds_chunk,))



@tf.function
def multiple_dist_test_steps_v1(dist_iter, steps):
    """Run up to ``steps`` distributed evaluation steps inside one tf.function.

    Metrics are updated outside the replica function from the combined
    per-replica outputs. The loop stops early when the iterator has no more
    data.

    Args:
        dist_iter: Iterator over the distributed dataset.
        steps: Maximum number of evaluation steps to run.
    """
    def _test_step(ds_chunk):
        preds, labels = forward_pass(ds_chunk)
        loss_val = loss_function(labels, preds)
        return loss_val, preds, labels

    for _ in tf.range(steps):
        optional_data = dist_iter.get_next_as_optional()
        if not optional_data.has_value():
            break
        per_replica_losses, preds, labels = strategy.run(_test_step, args=(optional_data.get_value(),))
        global_batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        # experimental_local_results returns a tuple of the local per-replica
        # values whether or not strategy.run wrapped its outputs, so this also
        # works with the single-replica fallback strategy (where the outputs
        # have no `.values` attribute).
        labels = tf.concat(strategy.experimental_local_results(labels), axis=0)
        preds = tf.concat(strategy.experimental_local_results(preds), axis=0)
        test_loss.update_state(global_batch_loss)
        test_acc.update_state(labels, preds)


@tf.function
def multiple_dist_test_steps_v2(dist_iter, steps):
    """Run up to ``steps`` distributed evaluation steps inside one tf.function.

    Metrics are updated inside each replica. The loop stops early when the
    iterator has no more data.

    Args:
        dist_iter: Iterator over the distributed dataset.
        steps: Maximum number of evaluation steps to run.
    """
    def _test_step(ds_chunk):
        logits, targets = forward_pass(ds_chunk)
        replica_loss = loss_function(targets, logits)
        # Scale by the replica count before feeding the Mean metric.
        test_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        test_acc.update_state(targets, logits)

    for _ in tf.range(steps):
        maybe_batch = dist_iter.get_next_as_optional()
        if not maybe_batch.has_value():
            break
        strategy.run(_test_step, args=(maybe_batch.get_value(),))


@tf.function
def dist_test_epoch(ds):
    """Evaluate every batch of ``ds`` inside a single tf.function call.

    See https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction

    Args:
        ds: The distributed dataset to iterate over.
    """
    def _test_step(ds_chunk):
        logits, targets = forward_pass(ds_chunk)
        replica_loss = loss_function(targets, logits)
        # Scale by the replica count before feeding the Mean metric.
        test_loss.update_state(replica_loss * strategy.num_replicas_in_sync)
        test_acc.update_state(targets, logits)

    for batch in ds:
        strategy.run(_test_step, args=(batch,))



# Count the batches in the distributed dataset by iterating over it once.
steps_per_epoch = sum([1 for _ in dist_ds])
epochs = 1
# Evaluate over every batch of the dataset.
steps = steps_per_epoch

# Each section below calls refresh() first, so every variant evaluates a new
# model loaded with the same stored weights, using new metrics.

# Re-batch with a batch size of 2000 and evaluate the first such batch in a
# single non-distributed call.
full = ds.unbatch().batch(2000)
chunk = next(iter(full))
refresh(apply_strategy=False)
simple_test_step(chunk)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")


# Non-distributed evaluation, one batch of `ds` per call.
refresh(apply_strategy=False)
for i,chunk in enumerate(ds):
    if i==steps:break
    simple_test_step(chunk)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")

# Distributed evaluation, one step per call; metrics updated outside replicas.
refresh()
for i,chunk in enumerate(dist_ds):
    if i==steps:break
    dist_test_step_v1(chunk)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")

# Distributed evaluation, one step per call; metrics updated inside replicas.
refresh()
for i,chunk in enumerate(dist_ds):
    if i==steps:break
    dist_test_step_v2(chunk)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")


# `steps` evaluation steps inside a single tf.function call (v1 metrics).
refresh()
multiple_dist_test_steps_v1(iter(dist_ds), steps)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")

# `steps` evaluation steps inside a single tf.function call (v2 metrics).
refresh()
multiple_dist_test_steps_v2(iter(dist_ds), steps)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")


# Whole-dataset iteration inside a single tf.function call.
refresh()
dist_test_epoch(dist_ds)
print(f"loss :{test_loss.result()}  acc:{test_acc.result()}")

loss :3.4065730571746826  acc:0.10234375298023224
loss :3.4065730571746826  acc:0.10234375298023224
loss :3.4080862998962402  acc:0.10234375298023224
loss :3.4080862998962402  acc:0.10234375298023224
loss :3.4080862998962402  acc:0.10234375298023224
loss :3.4080862998962402  acc:0.10234375298023224
loss :3.4080862998962402  acc:0.10234375298023224
