# **This notebook demonstrates training on the Lyft motion-prediction data on a TensorFlow TPU using a custom training loop.**

Key points of this notebook:


* Since the Lyft data is very large, I have zipped the TFRecord files so that more data fits into each file. The TFRecord files contain images of size 224 × 224 with 25 channels (num_history_channels = 10). I have added the target positions and availabilities to the validation TFRecords, just as in the training TFRecords; this allows evaluation during training and faster experiments.

* To increase the input throughput to the TPU, the user-defined parsing function is vectorized (it operates on whole batches): the dataset is batched *before* the `map` call. This increases throughput by up to 4× and significantly reduces training time. To make this work, `tf.io.parse_example` is used instead of `tf.io.parse_single_example`, which parses only one example at a time.

* `dataset.batch(drop_remainder=True)` is used for both training and validation, which is slightly faster, as discussed [here](https://www.kaggle.com/mgornergoogle/custom-training-loop-with-100-flowers-on-tpu). It drops some validation examples, but that is acceptable because we have a large number of validation samples.

* `dataset.cache()` is not used for validation: the data is so large that caching overflows memory, and the TPU throws a "socket closed" error after a few epochs.

* The custom loss and the transform-points function have been modified to operate on batches.

All other necessary information is given in comments in the code.

In [None]:

import tensorflow as tf
import pandas as pd
import numpy as np
from pathlib import Path
from kaggle_datasets import KaggleDatasets
from collections import namedtuple
import matplotlib.pyplot as plt
import time


# TPU Detection And Initialization

In [None]:

try:
    # Without arguments the resolver looks up the TPU from the TPU_NAME
    # environment variable (set automatically on Kaggle TPU sessions).
    # A ValueError is treated as "no TPU available".
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Fallback: TensorFlow's default strategy (CPU or a single GPU).
    strategy = tf.distribute.get_strategy()
else:
    # Attach to the TPU cluster and initialise it before creating the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Collect the TFRecord shards for training and validation.
# NOTE(review): training shards are globbed from the local ../input mount while
# validation shards come from a GCS path; a remote TPU worker may be unable to
# read local files — confirm, or switch to KaggleDatasets().get_gcs_path(...).
train_files = list(tf.io.gfile.glob('../input/tfreclyft/tfrecords' + '/training' + '/shard*.tfrecord'))

VALID_GCS_PATH = KaggleDatasets().get_gcs_path('lyft-validation-tfrecord-224')
valid_files = tf.io.gfile.glob(VALID_GCS_PATH + '/validation' + '/shard*.tfrecord')

# Randomise the order of the training shards (in place).
np.random.shuffle(train_files)

In [None]:
AUTO = tf.data.experimental.AUTOTUNE   # let tf.data pick parallelism / buffer sizes
IMG_DIM = 224                          # raster height and width
CHANNEL_DIM = 25                       # raster channels fed to the model
EPOCHS = 10
STEPS_PER_EPOCH = 500                  # training batches per epoch
VALIDATION_STEP = 200                  # validation batches evaluated per epoch
# 16 examples per replica; each global batch is split across the replicas.
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * 16

# Dataset PipeLine

In [None]:

# dtype each feature's serialized tensor is expected to have. The read_*
# functions pass matching dtypes to tf.io.parse_tensor. Note that
# target_availabilities is uint8 in the training records but float32 in the
# validation records.
feature_dtypes_train = {
    'image': tf.uint8,
    'target_positions': tf.float32,
    'target_availabilities': tf.uint8,
}

feature_dtypes_valid = {
    'image': tf.uint8,
    'target_positions': tf.float32,
    'target_availabilities': tf.float32,
    'target_yaws': tf.float32,
    'world_from_agent': tf.float64,
    'history_positions': tf.float32,
    'history_yaws': tf.float32,
    'history_availabilities': tf.float32,
    'world_to_image': tf.float64,
    'track_id': tf.int64,
    'timestamp': tf.int64,
    'centroid': tf.float64,
    'yaw': tf.float64,
    'extent': tf.float32,
}

# Every feature is stored as a single serialized tensor (a scalar byte string),
# so the tf.io.parse_example spec is the same for every key.
feature_descriptions_train = {name: tf.io.FixedLenFeature([], tf.string) for name in feature_dtypes_train}
feature_descriptions_valid = {name: tf.io.FixedLenFeature([], tf.string) for name in feature_dtypes_valid}

In [None]:
def read_train_tfrecord(example):
    """Decode one batch of serialized training examples.

    The dataset is batched before this map, so ``example`` holds
    GLOBAL_BATCH_SIZE serialized tf.Example protos that tf.io.parse_example
    parses in a single call. Each feature is a serialized tensor that is then
    decoded element by element with tf.io.parse_tensor.

    Returns:
        image: float32 (GLOBAL_BATCH_SIZE, IMG_DIM, IMG_DIM, CHANNEL_DIM), channels last.
        target_pos: (GLOBAL_BATCH_SIZE, 50, 2) target positions.
        target_avail: float32 (GLOBAL_BATCH_SIZE, 50); 1.0 where a target is
            available, 0.0 otherwise.
    """
    example = tf.io.parse_example(example, feature_descriptions_train)

    def _decode(key, dtype):
        # parse_tensor handles one serialized tensor at a time; squeeze drops any
        # size-1 axes stored with each element before the batch is stacked.
        return tf.stack([tf.squeeze(tf.io.parse_tensor(example[key][i], dtype))
                         for i in range(GLOBAL_BATCH_SIZE)])

    # uint8 pixels are rescaled into float32 [0, 1] by convert_image_dtype.
    image = tf.image.convert_image_dtype(_decode('image', tf.uint8), dtype=tf.float32)
    # Convert to (batch_size, height, width, channel).
    image = tf.transpose(image, [0, 2, 3, 1])
    image = tf.reshape(image, shape=(GLOBAL_BATCH_SIZE, IMG_DIM, IMG_DIM, CHANNEL_DIM))

    target_pos = tf.reshape(_decode('target_positions', tf.float32), shape=(GLOBAL_BATCH_SIZE, 50, 2))

    # Availabilities are flags, not pixel intensities: convert_image_dtype would
    # rescale a uint8 value of 1 to 1/255 and shrink the masked loss. Map every
    # non-zero flag to 1.0 instead; this is correct whether the flags were
    # stored as 0/1 or as 0/255.
    raw_avail = _decode('target_availabilities', tf.uint8)
    target_avail = tf.cast(tf.not_equal(raw_avail, 0), tf.float32)
    target_avail = tf.reshape(target_avail, shape=(GLOBAL_BATCH_SIZE, 50))
    return image, target_pos, target_avail



def read_validation_tfrecord(example):
    """Decode one batch of serialized validation examples.

    Like read_train_tfrecord, ``example`` holds GLOBAL_BATCH_SIZE serialized
    protos parsed in one tf.io.parse_example call. All features in
    feature_descriptions_valid are parsed, but only the four returned below are
    decoded.

    Returns:
        image: float32 (GLOBAL_BATCH_SIZE, IMG_DIM, IMG_DIM, CHANNEL_DIM), channels last.
        target_pos: (GLOBAL_BATCH_SIZE, 50, 2) target positions.
        target_avail: float32 (GLOBAL_BATCH_SIZE, 50) availabilities (stored as float32).
        world_from_agent: float64 (GLOBAL_BATCH_SIZE, 3, 3) transform matrices.
    """
    example = tf.io.parse_example(example, feature_descriptions_valid)

    def _decode(key, dtype):
        # Decode each element separately and stack into a batch. (The previous
        # version had stray trailing commas that wrapped these lists in 1-tuples,
        # adding a spurious leading axis that a global squeeze then had to remove.)
        return tf.stack([tf.squeeze(tf.io.parse_tensor(example[key][i], dtype))
                         for i in range(GLOBAL_BATCH_SIZE)])

    image = tf.image.convert_image_dtype(_decode('image', tf.uint8), dtype=tf.float32)
    # Convert to (batch_size, height, width, channel).
    image = tf.transpose(image, [0, 2, 3, 1])
    image = tf.reshape(image, shape=(GLOBAL_BATCH_SIZE, IMG_DIM, IMG_DIM, CHANNEL_DIM))

    target_pos = tf.reshape(_decode('target_positions', tf.float32), shape=(GLOBAL_BATCH_SIZE, 50, 2))
    target_avail = tf.reshape(_decode('target_availabilities', tf.float32), shape=(GLOBAL_BATCH_SIZE, 50))
    world_from_agent = tf.reshape(_decode('world_from_agent', tf.float64), shape=(GLOBAL_BATCH_SIZE, 3, 3))
    return image, target_pos, target_avail, world_from_agent

In [None]:
def load_dataset(filenames, labeled = True, ordered = False , training = True):
    """Build a batched, decoded dataset from TFRecord shards.

    Args:
        filenames: a TFRecord path or list of paths.
        labeled: unused; kept for interface compatibility.
        ordered: when False, deterministic ordering is disabled so records are
            consumed as soon as any file delivers them (faster; order is
            irrelevant because the data is shuffled anyway).
        training: choose read_train_tfrecord (True) or read_validation_tfrecord.

    Batching happens *before* map so the parse function runs on whole batches.
    NOTE(review): no compression_type is passed, so the records are read as
    uncompressed TFRecords.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    parse_batch = read_train_tfrecord if training else read_validation_tfrecord

    # Interleaved reads across files, then vectorised (batch-first) decoding.
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .batch(GLOBAL_BATCH_SIZE, drop_remainder=True)
        .map(parse_batch, num_parallel_calls=AUTO)
    )


def get_dataset(files , training = True):
    """Return the input pipeline for training (shuffled) or validation.

    Shuffling is applied to whole batches (the dataset is already batched by
    load_dataset). The validation data is deliberately not cached because it is
    too large and causes memory overflow.
    """
    dataset = load_dataset(files, training=training)
    if training:
        dataset = dataset.shuffle(128)
    return dataset.prefetch(AUTO)

# Visualization
Now we will visualize the model input


In [None]:

# Build the input pipelines. Training reads every shard collected (and shuffled)
# in `train_files` instead of a single hard-coded shard.
training_dataset = get_dataset(train_files, training = True)
validation_dataset = get_dataset(valid_files , training = False)
# Dataset transformations return a NEW dataset; keep the result, otherwise
# ignore_errors() has no effect. It skips elements that fail to decode.
training_dataset = training_dataset.apply(tf.data.experimental.ignore_errors())

image , target_pos , target_avail = next(iter(training_dataset))
print('shapes of images, target_positions, target_avail are ', (image.shape , target_pos.shape , target_avail.shape))

# Plot all 25 channels of one sample. Index 45 only exists when the batch has at
# least 46 examples (e.g. not with 16 on a single CPU/GPU), so clamp it.
sample_idx = min(45, image.shape[0] - 1)
plt.figure(figsize=(15, 15))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.imshow(image[sample_idx][:, :, i])
plt.show()

In [None]:
# Preview one validation batch; the 4th element (world_from_agent) is not needed here.
image , target_pos , target_avail,_ = next(iter(validation_dataset))
print('shapes of images, target_positions, target_avail are ', (image.shape , target_pos.shape , target_avail.shape))

# Plot all 25 channels of one sample. Index 45 only exists when the batch has at
# least 46 examples, so clamp to the last sample of smaller batches.
sample_idx = min(45, image.shape[0] - 1)
plt.figure(figsize=(15, 15))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.imshow(image[sample_idx][:, :, i])
plt.show()

# Custom Metric

In [None]:
def neg_multi_log_likelihood(gt, pred, confidences, avails):
    """Per-example negative multi-modal log-likelihood, computed on a batch.

    The per-mode error is the squared distance between prediction and ground
    truth, summed over coordinates and time, with masked timesteps removed.
    Modes are combined through a numerically stable log-sum-exp weighted by
    the (clipped) mode confidences.

    Args:
        gt: (batch, time, coords) ground-truth positions.
        pred: (batch, modes, time, coords) predicted trajectories.
        confidences: (batch, modes) per-mode confidences.
        avails: (batch, time) availability mask; it multiplies the error, so a
            0 removes that timestep from the loss.

    Returns:
        (batch, 1) tensor of per-example negative log-likelihoods.
    """
    assert len(pred.shape) == 4, f"expected 4D (B,M,T,C) array for pred, got {pred.shape}"
    batch_size, num_modes, future_len, num_coords = pred.shape

    assert gt.shape == (batch_size, future_len, num_coords), f"expected 3D (Batch, Time , Coords) array for gt, got {gt.shape}"
    assert confidences.shape == (batch_size, num_modes), f"expected 2D (Batch, Modes) array for confidences, got {confidences.shape}"
    assert avails.shape == (batch_size, future_len), f"expected 2D (Batch, Time) array for avails, got {avails.shape}"

    gt = tf.expand_dims(gt, axis=1)       # (B, 1, T, C): broadcast over modes
    avails = avails[:, None, :, None]     # (B, 1, T, 1): broadcast over modes and coords
    # Squared error summed over coords; unavailable timesteps contribute 0.
    error = tf.math.reduce_sum(((gt - pred) * avails) ** 2, axis=-1)   # (B, M, T)
    # Keep confidences away from 0 so log() never returns -inf.
    confidences = tf.clip_by_value(confidences, clip_value_min=tf.pow(0.1, 12), clip_value_max=1.0)
    error = tf.math.log(confidences) - 0.5 * tf.math.reduce_sum(error, axis=-1)   # (B, M)

    # log-sum-exp over modes, shifted by the per-example max for stability
    # (the values are negative here, so the max is the least-negative one).
    max_value = tf.math.reduce_max(error, axis=1, keepdims=True)   # (B, 1)
    error = -tf.math.log(tf.math.reduce_sum(tf.exp(error - max_value), axis=-1, keepdims=True)) - max_value
    return error


   
def transform_points(points, transf_matrix):
    """Apply the linear part of a batch of homogeneous matrices to points.

    Operates on whole batches.

    Args:
        points: rank-4 tensor (batch, modes, time, D) with D in {2, 3}.
        transf_matrix: rank-3 tensor (batch, D+1, D+1) of homogeneous matrices
            (presumably square — only the column count is checked below).

    Returns:
        float32 tensor with the same shape as ``points``.

    NOTE(review): only the top-left DxD block of each matrix is used; the
    translation column is NOT added. For a world_from_agent matrix this yields
    world-oriented offsets relative to the agent's origin rather than absolute
    world coordinates — confirm this matches how the validation targets are stored.
    """

    # (batch, R, C) -> (batch, R, C, 1) so both inputs are rank 4.
    transf_matrix = tf.expand_dims(transf_matrix , axis = -1)
    assert len(points.shape) == len(transf_matrix.shape) == 4, (
    f"dimensions mismatch, both points ({points.shape}) and "
    f"transf_matrix ({transf_matrix.shape}) needs to be tensors of rank 4."
    )

    # NOTE(review): the message below still describes unbatched (N, D) input.
    if points.shape[3] not in [2, 3]:
        raise AssertionError(f"Points input should be (N, 2) or (N, 3) shape, received {points.shape}")

    assert points.shape[3] == transf_matrix.shape[2] - 1, "points dim should be one less than matrix dim"

    # Multiply in float64 (world_from_agent is parsed as float64), return float32.
    points = tf.cast(points , tf.float64)
    # Drop the last row/column, then transpose (batch, D, D, 1) -> (batch, 1, D, D)
    # holding each block's transpose; matmul computes p @ M^T for every point p,
    # broadcasting the size-1 axis over the modes axis of `points`.
    points = tf.matmul(points , tf.transpose(transf_matrix[:, :-1, :-1, :] , perm = [0,3,2,1])) 
    return tf.cast(points , tf.float32)


In [None]:
# Changing the number of input channels to match the rasterizer output changes
# the kernel shape of the layer that takes the input, so that layer's pretrained
# weights must be adapted.
#
# ResNet50's first convolution has kernel shape (7, 7, 3, 64) for 3 input
# channels; 25 input channels require (7, 7, 25, 64).

def _expand_conv1_kernel(kernel, num_channels):
    """Widen a (kh, kw, C, filters) kernel to ``num_channels`` input channels.

    The kernel is tiled along the input-channel axis and the LAST
    ``num_channels`` slices are kept. For C=3 and 25 channels this reproduces
    the original layout: the last RGB slice followed by 8 full copies.
    """
    if num_channels < 1:
        raise ValueError(f"num_channels must be >= 1, got {num_channels}")
    reps = -(-num_channels // kernel.shape[-2])   # ceil division
    tiled = np.tile(kernel, [1, 1, reps, 1])
    return tiled[:, :, -num_channels:, :]


def modified_resnet50(num_channels=25):
    """Return an ImageNet-initialised ResNet50 (no top) taking ``num_channels`` inputs.

    Every layer copies its weights from the pretrained 3-channel model, except
    'conv1_conv', whose kernel is widened with _expand_conv1_kernel (bias copied).
    """
    # Model with 3 input channels and pretrained weights.
    pretrained_model = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=(None, None, 3))
    # Same architecture with num_channels inputs and no pretrained weights.
    modified_model = tf.keras.applications.ResNet50(include_top=False, weights=None, input_shape=(None, None, num_channels))

    layers_to_modify = {'conv1_conv'}   # the layer that consumes the input
    for pretrained_layer, modified_layer in zip(pretrained_model.layers, modified_model.layers):
        weights = pretrained_layer.get_weights()
        if pretrained_layer.name in layers_to_modify:
            kernel, bias = weights[0], weights[1]
            modified_layer.set_weights((_expand_conv1_kernel(kernel, num_channels), bias))
        else:
            modified_layer.set_weights(weights)

    return modified_model

# Define Model

In [None]:

class LyftModel(tf.keras.Model):
    """ResNet50 backbone plus a linear head predicting multi-modal trajectories.

    For each input raster the model returns ``num_modes`` candidate
    trajectories of ``future_pred_frames`` (x, y) positions and one
    softmax-normalised confidence per mode.
    """

    def __init__(self, num_modes=3, future_pred_frames=50):
        super(LyftModel, self).__init__()

        self.model = modified_resnet50()
        self.gap = tf.keras.layers.GlobalAveragePooling2D()
        self.dropout = tf.keras.layers.Dropout(0.2)

        # NOTE: despite its name, future_len is the number of trajectory
        # outputs for all modes (modes * frames * 2 coords), not a frame count.
        self.future_len = num_modes * future_pred_frames * 2
        self.future_pred_frames = future_pred_frames
        self.num_modes = num_modes

        # One head: all trajectory coordinates followed by num_modes confidence logits.
        self.dense1 = tf.keras.layers.Dense(self.future_len + self.num_modes)

    def call(self, inputs, training=None):
        """Return ``(pred, confidence)`` for a batch of rasters.

        Args:
            inputs: (batch, height, width, channels) rasters; channels must
                match the backbone built by modified_resnet50().
            training: forwarded to the backbone and dropout (Keras fills it in
                from ``model(x, training=...)``).

        Returns:
            pred: (batch, num_modes, future_pred_frames, 2) trajectories.
            confidence: (batch, num_modes), softmax over modes.
        """
        x = self.model(inputs, training=training)
        x = self.gap(x)
        x = self.dropout(x, training=training)
        x = self.dense1(x)

        pred, confidence = tf.split(x, num_or_size_splits=[self.future_len, self.num_modes], axis=1)
        # -1 instead of the static x.shape[0], which is None when the batch
        # size is unknown at trace time and would break the reshape.
        pred = tf.reshape(pred, shape=(-1, self.num_modes, self.future_pred_frames, 2))
        confidence = tf.nn.softmax(confidence, axis=1)
        return pred, confidence
        


In [None]:
# Build the model, optimizer, loss/transform helpers and metrics.

def get_model_and_variables():
    """Create everything the train/valid step functions use.

    Call this under ``strategy.scope()`` so the strategy creates the variables.

    Returns:
        (model, optimizer, loss_func, transf_points, training_loss, validation_loss)
    """
    model = LyftModel()
    model.build((GLOBAL_BATCH_SIZE, IMG_DIM, IMG_DIM, CHANNEL_DIM))
    model.summary()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    def loss_func(a, b, c, d):
        # Sum of the per-example losses divided by the GLOBAL batch size, so
        # summing the per-replica values yields the global mean loss.
        per_example = neg_multi_log_likelihood(a, b, c, d)
        return tf.nn.compute_average_loss(per_example, global_batch_size=GLOBAL_BATCH_SIZE)

    def transf_points(pred, world_from_agent):
        return transform_points(pred, world_from_agent)

    # Running sums; train_model divides them by the number of steps.
    training_loss = tf.keras.metrics.Sum()
    validation_loss = tf.keras.metrics.Sum()
    return model, optimizer, loss_func, transf_points, training_loss, validation_loss


# Train and Validation step function

In [None]:
# @tf.function traces the step into a TensorFlow graph so it can run on the TPU.

@tf.function
def train_step(image, target_pos , target_avail):
    """One optimisation step, executed on each replica with its share of the batch."""
    with tf.GradientTape() as tape:
        pred, confidence = model(image, training=True)
        loss_value = loss_func(target_pos, pred, confidence, target_avail)

    variables = model.trainable_variables
    gradients = tape.gradient(loss_value, variables)
    optimizer.apply_gradients(list(zip(gradients, variables)))
    training_loss.update_state(loss_value)

        
@tf.function
def valid_step(image , target_pos, target_avail, world_from_agent):
    """Evaluate one validation batch on a replica and accumulate its loss."""
    pred, confidence = model(image, training=False)
    # Map predictions with world_from_agent before scoring them.
    world_pred = transf_points(pred, world_from_agent)
    validation_loss.update_state(loss_func(target_pos, world_pred, confidence, target_avail))



# Custom Training Loop

In [None]:
# training loop

def train_model(train_dataset , valid_dataset ,EPOCHS =  5, STEPS_PER_EPOCH = 200 ,VALIDATION_STEP = 100):
    """Run the custom distributed training loop.

    Every STEPS_PER_EPOCH training steps form one epoch. At the end of each
    epoch up to VALIDATION_STEP validation batches are evaluated, the averaged
    losses are printed and recorded, and the weights are saved. Training stops
    after EPOCHS epochs or when the training dataset is exhausted.

    Uses the module-level strategy, model, train_step, valid_step,
    training_loss and validation_loss.

    Returns:
        A ``History`` namedtuple; ``history.history['train_loss']`` and
        ``['val_loss']`` hold one value per completed epoch.
    """
    # Distribute every global batch across the strategy's replicas; each replica
    # receives GLOBAL_BATCH_SIZE / strategy.num_replicas_in_sync examples.
    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
    valid_dist_dataset = strategy.experimental_distribute_dataset(valid_dataset)

    print("Steps per epoch:", STEPS_PER_EPOCH)
    History = namedtuple('History', 'history')
    history = History(history={'train_loss': [], 'val_loss': []})

    epoch = 0
    epoch_start_time = time.time()
    for step, (images, target_pos, target_avails) in enumerate(train_dist_dataset):
        # strategy.run executes train_step on every replica with its per-replica values.
        strategy.run(train_step, args=(images, target_pos, target_avails))

        if (step + 1) // STEPS_PER_EPOCH > epoch:
            # End of epoch: evaluate at most VALIDATION_STEP batches. A fresh
            # iterator over the validation data is created every epoch.
            val_steps = 0
            for val_images, val_pos, val_avails, world_from_agent in valid_dist_dataset:
                strategy.run(valid_step, args=(val_images, val_pos, val_avails, world_from_agent))
                val_steps += 1
                if val_steps >= VALIDATION_STEP:
                    break

            # The metrics hold sums of per-step losses: average over the steps
            # actually run (validation may end early on a short dataset).
            history.history['train_loss'].append(training_loss.result().numpy() / STEPS_PER_EPOCH)
            if val_steps:
                history.history['val_loss'].append(validation_loss.result().numpy() / val_steps)
            else:
                history.history['val_loss'].append(float('nan'))

            # show metrics
            epoch_time = time.time() - epoch_start_time
            print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
            print('time: {:0.1f}s'.format(epoch_time),
                  'loss: {:0.4f}'.format(history.history['train_loss'][-1]),
                  'val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),)

            # saving the model
            model.save_weights('./epoch {:d}, train_loss {:0.4f}, val_loss {:0.4f} model.h5'.format(epoch+1 ,
                                                                               history.history['train_loss'][-1],
                                                                               history.history['val_loss'][-1]  ))

            # set up next epoch
            epoch = (step + 1) // STEPS_PER_EPOCH
            epoch_start_time = time.time()
            validation_loss.reset_states()
            training_loss.reset_states()

        if epoch >= EPOCHS:
            break

    return history


In [None]:
# The model, optimizer and metrics are created inside strategy.scope() so the
# strategy (the TPU when available) owns and tracks their variables.
with strategy.scope():
    (model, optimizer, loss_func,
     transf_points, training_loss, validation_loss) = get_model_and_variables()

train_model(
    training_dataset,
    validation_dataset,
    EPOCHS=EPOCHS,
    STEPS_PER_EPOCH=STEPS_PER_EPOCH,
    VALIDATION_STEP=VALIDATION_STEP,
)