# Importing Packages

In [30]:
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras import layers, initializers, regularizers
from tensorflow.keras import losses, metrics, optimizers, callbacks
from tensorflow import keras

# Import Data

In [31]:
(train_dataset, val_dataset, test_dataset), dataset_info = tfds.load('stanford_dogs', split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'], with_info=True, as_supervised=True)



In [32]:
def preprocess(image, label):
    # A function that preprocesses each image by normalizing
    # its values in [0, 1].
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

In [33]:
def resize(image,label):
  image = tf.image.resize(image, (500, 375))
  image = tf.cast(image, tf.float32) / 255.0

  return image,label 

In [34]:
def augment(image, label):
  # We perform data augmentation in the data pipeline.
  # As an alternative, we can add layers of image augmentation
  # at the beginning of our model.
  #image = tf.image.random_brightness(image, 0.1)
  image = tf.image.random_flip_left_right(image)
  return image, label

In [35]:
train_dataset = train_dataset.map(resize)
val_dataset = val_dataset.map(resize)
test_dataset = test_dataset.map(resize)

### Filter only breeds of label 1 to 10

In [7]:
'''
def breeds_of_interest_bas(image,label):
  condition1 = tf.math.equal(label, 1)
  condition2 = tf.math.equal(label, 2)
  condition3 = tf.math.equal(label, 3)
  condition4 = tf.math.equal(label, 4)
  condition5 = tf.math.equal(label, 5)
  condition6 = tf.math.equal(label, 6)
  condition7 = tf.math.equal(label, 7)
  condition8 = tf.math.equal(label, 8)
  condition9 = tf.math.equal(label, 9)
  return tf.reduce_any([condition1, condition2, condition3, condition4, condition5,
                        condition6, condition7, condition8, condition9])
'''

'\ndef breeds_of_interest_bas(image,label):\n  condition1 = tf.math.equal(label, 1)\n  condition2 = tf.math.equal(label, 2)\n  condition3 = tf.math.equal(label, 3)\n  condition4 = tf.math.equal(label, 4)\n  condition5 = tf.math.equal(label, 5)\n  condition6 = tf.math.equal(label, 6)\n  condition7 = tf.math.equal(label, 7)\n  condition8 = tf.math.equal(label, 8)\n  condition9 = tf.math.equal(label, 9)\n  return tf.reduce_any([condition1, condition2, condition3, condition4, condition5,\n                        condition6, condition7, condition8, condition9])\n'

In [36]:
def breeds_of_interest_eleg(image,label):
  return tf.reduce_any(tf.equal(label, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

In [37]:
train_dataset_filtered = train_dataset.filter(breeds_of_interest_eleg)
val_dataset_filtered = val_dataset.filter(breeds_of_interest_eleg) 
test_dataset_filtered = test_dataset.filter(breeds_of_interest_eleg)

## Basic Neural Network **without** Early Exits

In [38]:
train_data_p = train_dataset_filtered.shuffle(1000).batch(12).map(augment)
val_data_p = val_dataset_filtered.batch(12)
test_data_p = test_dataset_filtered.batch(12)

In [39]:
def add_conv_block(x, n_filters):
    # This function applies a simple "CNN block" to the input,
    # built as Conv2D -> BN -> ReLU -> MaxPool2D.
    x = layers.Conv2D(n_filters, 5, padding='same', kernel_regularizer=regularizers.L2(10e-3))(x)
    x = layers.BatchNormalization()(x)
    x = tf.nn.relu(x)
    return layers.MaxPool2D(2)(x)

In [12]:
def classification_layer(x_inp):
  x = layers.GlobalAvgPool2D()(x_inp) # Output shape: (None, 32)
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  x = layers.Dense(10)(x)  
  return x

In [18]:
def build_model():
  # Input part
  inp = layers.Input(shape=(500, 375, 3))
  # First Convolutional Block     # Output: (None, 500, 375, 24)
  x = add_conv_block(inp, 24)     

  # Second Convolutional Block    # Output: (None, 250, 187, 48) 
  x = add_conv_block(x, 48)       

  # Third Convolutional Block     # Output: (None, 125, 93, 96)
  x = add_conv_block(x, 96)

  # Fourth Convolutional Block    # Output: (None, 62, 46, 192)
  x = add_conv_block(x, 192)     

  x = classification_layer(x)
  return tf.keras.Model(inputs=inp, outputs=x)

In [21]:
#model = build_model()

In [22]:
#model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 500, 375, 3)]     0         
                                                                 
 conv2d_4 (Conv2D)           (None, 500, 375, 24)      1824      
                                                                 
 batch_normalization_4 (Batc  (None, 500, 375, 24)     96        
 hNormalization)                                                 
                                                                 
 tf.nn.relu_4 (TFOpLambda)   (None, 500, 375, 24)      0         
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 250, 187, 24)     0         
 2D)                                                             
                                                                 
 conv2d_5 (Conv2D)           (None, 250, 187, 48)      28848 

In [40]:
cross_entropy = losses.SparseCategoricalCrossentropy(from_logits=True) #Remove this "from_logits" if put the softmax activation in last dense layer
accuracy = metrics.SparseCategoricalAccuracy()
optimizer = optimizers.Adam()

# Callbacks are objects that provide additional functionalities during training,
# allowing to plug-in things at will (in this case, we add a callback to immediately
# terminate when a NaN value is encountered, a callback to perform early stopping,
# and a callback to log the results for TensorBoard visualization).
cbs = [
    callbacks.TerminateOnNaN(),
    callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5, 
                            restore_best_weights=True, verbose=1),
    callbacks.TensorBoard(log_dir='logs', update_freq=50)      
]

In [None]:
#model.compile(loss=cross_entropy, optimizer=optimizer, metrics=[accuracy])

In [None]:
#model.fit(train_data_p, validation_data=val_data_p, epochs=5, callbacks=cbs)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f94575755e0>

# Neural Network WITH Early Exits

In [41]:
def early_exit(x, training = True):
  
  x = layers.GlobalAvgPool2D()(x) #Possibly have to change names
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee_output = layers.Dense(10, activation='softmax')(x)
  return ee_output

In [42]:
def early_exit_loss(y_true,y_preds):
  scce = tf.keras.losses.SparseCategoricalCrossentropy()
  loss_ee1 = scce(y_true, y_preds[0])
  loss_ee2 = scce(y_true, y_preds[1])
  loss_final = scce(y_true, y_preds[2])

  return loss_final + loss_ee1*1 + loss_ee2*1

In [43]:
def build_model(training = True):
  #*****INPUT******
  inp = layers.Input(shape=(500, 375, 3))
  #*******************************FIRST & SECOND CONVOLUTIONAL BLOCK***********************************#

  x_cb1_t = add_conv_block(inp, 24)


  x_cb2_t = add_conv_block(x_cb1_t, 48)

  #******************************FIRST & SECOND CONVOLUTIONAL BLOCK************************************#


  #-------------------------------------START FIRST EARLY EXIT--------------------------------------#
  # Classification done always. Input --> x_cb2_t || Output --> ee1_output
  x = layers.GlobalAvgPool2D()(x_cb2_t) 
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee1_output = layers.Dense(10, activation='softmax')(x)

  # Stack to calculate Loss. All confidences of EE1 saved in piu[0]
  piu = tf.stack(ee1_output, axis=0) 
    
  # Only during inference... filtering!
  if training == False: 

    threshold_ee1 = 0.9 
    batch_size = 12
    #Auxiliary tensor keeps track of the id of the images that haven't exited yet.
    auxiliary_tensor = tf.range(batch_size) # used to sort images that take EE.
    auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(ee1_output, axis = -1) # TAKING RESULT OF THE EARLY EXIT 1

    #Thresholding operation. 
    #New tensor: 0's where confidence < threshold --> shall be passed to subsequent layers
    #1's where confidence > threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee1, 0, 1), tf.int32)

    #Select images with confidences ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(ee1_output, mask_exiters) #Take probability vector for the exiter instances

    #Identify images that exited, to then sort them
    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take those that aren't zero, so they exited.
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Make 1 tensor with id of exited images.

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0) 
    input_non_exiters = tf.boolean_mask(x_cb2_t, mask_non_exiters) #Take the PIXELS of images that didn't exit here.

    
    #Appending and sortering section
    #Array containing vector of confidences for images that took the EE.
    output_list = []
    output_list.append(output)

    #Array containing id and order of the images that took the EE.
    sorting_list = []
    sorting_list.append(sorting_tensor)

    #Update auxiliary tensor by removing the ids of the images that took the early exit.
    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]
  #-------------------------------------END FIRST EARLY EXIT--------------------------------------#

  #**********************************THIRD CONVOLUTIONAL BLOCK***********************************#
  '''
  Difference must be made. If training, the 3rd CB should take pixels of ALL the images in Batch. 
  If inference, 3rd CB takes pixels of ONLY images that did't take the 1st EE.
  '''
  if training == True:
    input_3cb = x_cb2_t
  else: 
    input_3cb = input_non_exiters

  
  x_cb3_t = add_conv_block(input_3cb, 96)
  #**********************************THIRD CONVOLUTIONAL BLOCK***********************************#

  #-------------------------------------START SECOND EARLY EXIT--------------------------------------#
  # Classification done always. Input --> x_cb3_t || Output --> ee2_output
  
  x = layers.GlobalAvgPool2D()(x_cb3_t) 
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee2_output = layers.Dense(10, activation='softmax')(x)

  # Stack to calculate Loss. All confidences of EE2 saved in piu[1]
  piu = tf.stack([piu,ee2_output], axis=0) 

  # Only during inference... filtering!
  if training == False: 

    threshold_ee2 = 0.8 
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(ee2_output, axis = -1) # TAKING RESULT OF THE EARLY EXIT 2

    #Thresholding operation. 
    #New tensor: 0's where confidence < threshold --> shall be passed to subsequent layers
    #1's where confidence > threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee2, x = 0, y = 1), tf.int32)

    #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(ee2_output, mask_exiters) #Take probability vector for the exiter instances

    #Identify images that exited, to then sort them
    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take those that aren't zero, so they exited.
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Make 1 tensor with id of exited images.

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0) 
    input_non_exiters = tf.boolean_mask(x_cb3_t, mask_non_exiters) #Take the PIXELS of images that didn't exit here.

    #Appending and sortering section
    #Array containing vector of confidences for images that took the EE.
    output_list.append(output)

    #Array containing id and order of the images that took the EE.
    sorting_list.append(sorting_tensor)

    #Update auxiliary tensor by removing the ids of the images that took the early exit.
    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]

  #**********************************FOURTH CONVOLUTIONAL BLOCK***********************************#
  '''
  Difference must be made. If training, the 4rd CB should take pixels of ALL the images in Batch. 
  If inference, 4th CB takes pixels of ONLY images that did't take the 1st EE.
  '''
  if training == True:
    input_4cb = x_cb3_t
  else: 
    input_4cb = input_non_exiters

  
  x_cb4_t = add_conv_block(input_4cb, 192)
  #**********************************FOURTH CONVOLUTIONAL BLOCK***********************************#
   
  #-------------------------------------START FINAL EXIT--------------------------------------#
  # Classification done always. Input --> x_cb4_t || Output --> final_output

  x = layers.GlobalAvgPool2D()(x_cb4_t) 
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  final_output = layers.Dense(10, activation='softmax')(x)
  

  final_out = tf.expand_dims(final_output, axis=0)
  piu = tf.concat([piu, final_out], axis=0)                       
  assert piu.shape == (3, None, 10)
  x = piu
  
  if training == False:
    output_list.append(final_output)
    sorting_list.append(auxiliary_tensor)

    #Sorting operation
    sorting_list = tf.concat(sorting_list, axis=0)
    sorting_idx = tf.argsort(sorting_list)
    output_list = tf.concat(output_list, axis=0)

    x = tf.gather(output_list, sorting_idx, batch_dims = 0)

  return tf.keras.Model(inputs=inp, outputs=x)

In [44]:
model_ee = build_model()

In [45]:
model_ee.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 500, 375, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_12 (Conv2D)             (None, 500, 375, 24  1824        ['input_4[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization_12 (BatchN  (None, 500, 375, 24  96         ['conv2d_12[0][0]']              
 ormalization)                  )                                                           

In [46]:
model_ee.compile(loss=early_exit_loss, optimizer=optimizer, metrics=[accuracy])

In [47]:
model_ee.fit(train_data_p, validation_data=val_data_p, epochs=1, callbacks=cbs)



<keras.callbacks.History at 0x7fe3e9168e50>

# Sketching a new way...

In [None]:
def build_model(training = True):
  #*****INPUT******
  inp = layers.Input(shape=(500, 375, 3))
  #*****FIRST CONVOLUTIONAL BLOCK******

  if training == True:
    x_cb1_t = add_conv_block(inp, 6)
  else:
    x_cb1 = add_conv_block(inp,6)

  

  #*****FIRST EARLY EXIT******
  if training == True:
     
    x = layers.GlobalAvgPool2D()(x_cb1_t) #Possibly have to change names
    x = layers.Dense(100, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    ee1_output = layers.Dense(10, activation='softmax')(x)

    piu = tf.stack(ee1_output, axis=0) 
    

  else: 
    x = layers.GlobalAvgPool2D()(x_cb1) 
    x = layers.Dense(100, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    output = layers.Dense(10, activation='softmax')(x) 

    threshold_ee1 = 0.9 
    batch_size = 12
    auxiliary_tensor = tf.range(batch_size)
    auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(output, axis = -1)

    #Thresholding operation. 
    #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
    #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee1, 0, 1), tf.int32)

    #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(output, mask_exiters) #Take probability vector for the exiter instances

    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0)
    input_non_exiters = tf.boolean_mask(x_cb1, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed


    output_list = []
    output_list.append(output)

    sorting_list = []
    sorting_list.append(sorting_tensor)

    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]
    #*****FIRST EARLY EXIT******

  
  #*****SECOND CONVOLUTIONAL BLOCK******
  if training == True:
    x_cb2_t = add_conv_block(x_cb1_t, 6)
  else:
    x_cb2 = add_conv_block(input_non_exiters,6)

  #*****SECOND EARLY EXIT******
  if training == True:
     
    x = layers.GlobalAvgPool2D()(x_cb2_t) #Possibly have to change names
    x = layers.Dense(100, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    ee2_output = layers.Dense(10, activation='softmax')(x) 

    piu = tf.stack([piu,ee2_output], axis=0)                                                             
    

  else: 
    x = layers.GlobalAvgPool2D()(x_cb2) 
    x = layers.Dense(100, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    output = layers.Dense(10, activation='softmax')(x) 

    threshold_ee2 = 0.8 #define threshold?
  
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(output, axis = -1)

    #Thresholding operation. 
    #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
    #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee2, x = 0, y = 1), tf.int32)

    #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(output, mask_exiters) #Take probability vector for the exiter instances

    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0)
    input_non_exiters = tf.boolean_mask(x_cb2, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed

    #Appending operation
    output_list.append(output)
    sorting_list.append(sorting_tensor)

    #Auxiliary tensor update --> possibly dispensable
    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]

    #*****SECOND EARLY EXIT******

    #*****FINAL EXIT******
  if training == True:
    x = layers.GlobalAvgPool2D()(x_cb2_t) 
    x = layers.Dense(100, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    final_output = layers.Dense(10, activation='softmax')(x)
    
    final_output = tf.expand_dims(final_output, axis=0)
    piu = tf.concat([piu, final_output], axis=0)
    #piu = tf.stack([piu,final_output], axis=0)                         
    assert piu.shape == (3, None, 10)
    x = piu

  else:
    output = layers.GlobalAvgPool2D()(input_non_exiters) 
    output = layers.Dense(100, activation='relu')(output)
    output = layers.Dropout(0.3)(output)
    output = layers.Dense(10, activation='softmax')(output)    

    #Appending Operation
    output_list.append(output)
    sorting_list.append(sorting_tensor)
    
    #Sorting operation
    sorting_list = tf.concat(sorting_list, axis=0)
    sorting_idx = tf.argsort(sorting_list)
    output_list = tf.concat(output_list, axis=0)

    x = tf.gather(output_list, sorting_idx, batch_dims = 0)


  return tf.keras.Model(inputs=inp, outputs=x)

In [None]:
class CustomModel(keras.Model):
    
    
    def build_model(training = None):
      #*****INPUT******
      inp = layers.Input(shape=(500, 375, 3))
      #*****FIRST CONVOLUTIONAL BLOCK******

      if training == True:
        x_cb1_t = add_conv_block(inp, 6)
      else:
        x_cb1 = add_conv_block(inp,6)

      #*****FIRST EARLY EXIT******
      if training == True:
     
        x = layers.GlobalAvgPool2D()(x_cb1_t) #Possibly have to change names
        x = layers.Dense(100, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        ee1_output = layers.Dense(10, activation='softmax')(x)

        piu = tf.stack(ee1_output, axis=0) 
        print(piu)

      else: 
        x = layers.GlobalAvgPool2D()(x_cb1) 
        x = layers.Dense(100, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        output = layers.Dense(10, activation='softmax')(x) 

        threshold_ee1 = 0.9 
        batch_size = 12
        auxiliary_tensor = tf.range(batch_size)
        auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
    
        #Take for each image of the Batch the category with highest confidence after the softmax.
        max_confidence = tf.reduce_max(output, axis = -1)

        #Thresholding operation. 
        #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
        #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
        exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee1, 0, 1), tf.int32)

        #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
        mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
        output = tf.boolean_mask(output, mask_exiters) #Take probability vector for the exiter instances

        sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
        mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
        sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

        #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
        mask_non_exiters = tf.equal(exiting_instances, 0)
        input_non_exiters = tf.boolean_mask(x_cb1, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed


        output_list = []
        output_list.append(output)

        sorting_list = []
        sorting_list.append(sorting_tensor)

        auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
        auxiliary_tensor = auxiliary_tensor[0]
        #*****FIRST EARLY EXIT******

  
      #*****SECOND CONVOLUTIONAL BLOCK******
      if training == True:
        x_cb2_t = add_conv_block(x_cb1_t, 6)
      else:
        x_cb2 = add_conv_block(input_non_exiters,6)

      #*****SECOND EARLY EXIT******
      if training == True:
     
        x = layers.GlobalAvgPool2D()(x_cb2_t) #Possibly have to change names
        x = layers.Dense(100, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        ee2_output = layers.Dense(10, activation='softmax')(x) 

        piu = tf.stack([piu,ee2_output], axis=0)                                                             
        print(f"Piu EE2: {piu}")

      else: 
        x = layers.GlobalAvgPool2D()(x_cb2) 
        x = layers.Dense(100, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        output = layers.Dense(10, activation='softmax')(x) 

        threshold_ee2 = 0.8 #define threshold?
  
    
        #Take for each image of the Batch the category with highest confidence after the softmax.
        max_confidence = tf.reduce_max(output, axis = -1)

        #Thresholding operation. 
        #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
        #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
        exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee2, x = 0, y = 1), tf.int32)

        #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
        mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
        output = tf.boolean_mask(output, mask_exiters) #Take probability vector for the exiter instances

        sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
        mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
        sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

        #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
        mask_non_exiters = tf.equal(exiting_instances, 0)
        input_non_exiters = tf.boolean_mask(x_cb2, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed

        #Appending operation
        output_list.append(output)
        print(f"Early Exit 1: {output_list}")
        sorting_list.append(sorting_tensor)

        #Auxiliary tensor update --> possibly dispensable
        auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
        auxiliary_tensor = auxiliary_tensor[0]

        #*****SECOND EARLY EXIT******

        #*****FINAL EXIT******
      if training == True:
        x = layers.GlobalAvgPool2D()(x_cb2_t) 
        x = layers.Dense(100, activation='relu')(x)
        x = layers.Dropout(0.3)(x)
        final_output = layers.Dense(10, activation='softmax')(x)
    
        final_output = tf.expand_dims(final_output, axis=0)
        piu = tf.concat([piu, final_output], axis=0)
        #piu = tf.stack([piu,final_output], axis=0)                         
        assert piu.shape == (3, None, 10)
        x = piu

      else:
        output = layers.GlobalAvgPool2D()(input_non_exiters) 
        output = layers.Dense(100, activation='relu')(output)
        output = layers.Dropout(0.3)(output)
        output = layers.Dense(10, activation='softmax')(output)    

        #Appending Operation
        output_list.append(output)
        sorting_list.append(sorting_tensor)
    
        #Sorting operation
        sorting_list = tf.concat(sorting_list, axis=0)
        sorting_idx = tf.argsort(sorting_list)
        output_list = tf.concat(output_list, axis=0)

        x = tf.gather(output_list, sorting_idx, batch_dims = 0)


      return tf.keras.Model(inputs=inp, outputs=x)
    
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

In [None]:
model = CustomModel.build_model()

In [None]:
model.summary()

In [None]:
def early_exit_loss(y_true,y_preds):
  scce = tf.keras.losses.SparseCategoricalCrossentropy()
  loss_ee1 = scce(y_true, y_preds[0])
  loss_ee2 = scce(y_true, y_preds[1])
  loss_final = scce(y_true, y_preds[2])

  return loss_final + loss_ee1*1 + loss_ee2*1

In [None]:
cross_entropy = losses.SparseCategoricalCrossentropy(from_logits=True) 
accuracy = metrics.SparseCategoricalAccuracy()
optimizer = optimizers.Adam()

In [None]:
model.compile(loss=early_exit_loss, optimizer=optimizer, metrics=[accuracy])

In [None]:
model.compile(loss=early_exit_loss, optimizer=optimizer, metrics=[accuracy])
model.fit(train_data_p, validation_data=val_data_p, epochs=1, callbacks=cbs)



KeyboardInterrupt: ignored

# Experiments on tensor indexing

#### Testing the Boolean Mask trick - Retriving **Exiters** members of Batch

In [None]:
import tensorflow as tf
example_tensor = tf.constant([[0.33, 0.33, 0.34],
                             [0.9, 0.05, 0.05],
                             [0.2, 0.6, 0.2]])

In [None]:
max_confidence = tf.reduce_max(example_tensor, axis = -1)
max_confidence

In [None]:
exiting_elements = tf.cast(tf.where(max_confidence < 0.8, x=0, y=1), tf.int32)
#exiting_elements = tf.reshape(exiting_elements, [1, -1])
exiting_elements

In [None]:
ind_tensor = tf.cast(tf.constant([0,1,0]),tf.int32)
ind_tensor

In [None]:
mask = tf.equal(exiting_elements, 1)
result = tf.boolean_mask(example_tensor, mask)
result

#### Testing the Boolean Mask trick - Retriving **Non Exiters** members of Batch

In [None]:
for xb, yb in train_data_p:
  first_element_batch = xb
  second_element_batch = yb
  print(xb.shape)
  print(yb.shape)
  break

In [None]:
first_element_batch.shape

In [None]:
ind_tensor = tf.cast(tf.constant([1,0,1,0,0,0,0,0,0,0,0,0]),tf.int32)
ind_tensor

In [None]:
mask = tf.equal(ind_tensor, 0)
result = tf.boolean_mask(first_element_batch, mask)
result.shape

In [None]:
index_tensor = tf.constant([0, 2])

#### Testing the function to create the **sortering_tensor**
Why do we need to reorder the output? Because we're gonna give to the EarlyExit layer only the X (pixels of image) and not the Y (true label). Then, during inference we need to see which images exit in which early exit to compute the costs but, since the images won't exit in an ordered way and the order of the true labels (y) is static, we need to reorder the vector of class probabilities such that they match their corresponding true label. 

In [None]:
import tensorflow as tf
example_tensor = tf.constant([[0.33, 0.33, 0.34],
                             [0.9, 0.05, 0.05],
                             [0.2, 0.6, 0.2],
                              [0.85, 0.05, 0.1]])

In [None]:
max_confidence = tf.reduce_max(example_tensor, axis = -1)
max_confidence

In [None]:
exiting_elements = tf.cast(tf.where(max_confidence < 0.8, x=0, y=1), tf.int32)
exiting_elements

In [None]:
mask = tf.equal(exiting_elements, 1)
result = tf.boolean_mask(example_tensor, mask)
result

In [None]:
input_index = tf.range(example_tensor.shape[0])
input_index

In [None]:
exit_order = tf.math.multiply(input_index,exiting_elements)
exit_order

In [None]:
mask_non_zeros = tf.not_equal(exit_order, 0)
mask_non_zeros

In [None]:
exit_order = tf.boolean_mask(exit_order, mask_non_zeros)
exit_order

#### **Discovering how to the update the Auxiliary tensor**

In [None]:
''''
Create ficticious tensor that represents the auxiliary tensor 
'''
import tensorflow as tf
auxiliary_tensor = tf.constant([[0,1,2,3,4,5,6,7,8,9,10,11]])
auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
auxiliary_tensor

In [None]:
'''
Recreate situation in which image 3 and 7 have taken the EarlyExit
'''
sorting_tensor = tf.constant([[3,7]])
sorting_tensor = tf.reshape(sorting_tensor, [-1])
sorting_tensor

In [None]:
'''
Out of the EarlyExit layer, we remove from the Auxiliary Tensor, those elements that took the early exit.
This is done since we want to feed this new auxiliary_tensor to the 2nd EarlyExit Layer, which mustn't
contain the elements that exited on previous EarlyExit. Images must be identified with the same index
regardless of where they exit.

See that auxiliary_tensor doesn't contain the exited images anymore. 
'''
auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
auxiliary_tensor = auxiliary_tensor[0]
auxiliary_tensor

In [None]:
#For the difference with the auxiliary tensor during training.
'''
However, see that the EarlyExit layer outputs always 3 elements:
1) outputs: vector of probabilities.
2) sorting tensor: used to keep track of which images took the early exit
3) input_non_exiters: tensor of pixels of images that didn't take the EarlyExit
and we need to give to subsequent Layers 

During training and inference, they have different behaviors. 
1.1 Outputs - Training: contains vector of probabilities for all the instances, since all took the EarlyExit
1.2 Outputs - Inference: contains vector of probabilities of ONLY those who took the EarlyExit and EXITED. 

2.1 Sorting_tensor - Training: empty tensor bc we don't need to update the auxiliary_tensor 
(operation done out of the EarlyExit Layer) bc all instances took the early exit, therefore they're sorted. 
2.2 Sorting_tensor - Inference: it identifies the images that took the EarlyExit. We need this tensor
to update, out of the EarlyExit Layer, the auxiliary tensor that will keep track of which images will
enter the next EarlyExit. 

3.3 Input_non_exiters - Training: contain all the pixels, bc we need to pass all images to subsequent layers.
3.4 Input_non_exiters - Testing: contain pixels of only those images that didn't take the early exit. 

So, this empty tensor will be the sorting_tensor DURING TRAINING. This way, during training, the sorting_tensor output
has one single element -1 (tf.zeros or tf.ones no bc there are 0 and 1 indexes, but there's no index or element -1), 
such that the update operation of the auxiliary_tensor (done out of the earlyExitLayer) doesn't remove any element of 
the tensor. This last because during training all images enter the EarlyExit.
'''
empty_tensor = tf.constant([[-1]], dtype=tf.int32)
empty_tensor = tf.reshape(empty_tensor, [-1])
empty_tensor
#empty_tensor = tf.experimental.numpy.empty([1, 0], dtype=tf.float32)
#empty_tensor

In [None]:
'''
Operation done out of the EarlyExit layer. It updates the auxiliary_tensor such that
it contains the identifiers of the images that will enter the next EarlyExit. 
'''
#Testing the difference
auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, empty_tensor,index_dtype=tf.dtypes.int32)
auxiliary_tensor

#### **Recreating the sortering...during INFERENCE**

In [None]:
import tensorflow as tf
output = tf.constant([[0.33, 0.33, 0.34],
                             [0.9, 0.05, 0.05],
                             [0.2, 0.6, 0.2],
                              [0.85, 0.05, 0.1]])

In [None]:
auxiliary_tensor = tf.range(output.shape[0])
auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
auxiliary_tensor

In [None]:
max_confidence = tf.reduce_max(output, axis = -1) #Highest probs
exiting_instances = tf.cast(tf.where(max_confidence < 0.8, x=0, y=1), tf.int32) #Thresholding. 1 and 0 tensor
mask_exiters = tf.equal(exiting_instances, 1) #Mask. Extract only those who exited
output = tf.boolean_mask(output, mask_exiters) 
output

In [None]:
sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged
sorting_tensor

In [None]:
output_list = []
output_list.append(output)

In [None]:
sorting_list = []
sorting_list.append(sorting_tensor)

In [None]:
auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
auxiliary_tensor = auxiliary_tensor[0]
auxiliary_tensor

**Second Layer...**

In [None]:
output = tf.constant([[0.33, 0.33, 0.34],
                      [0.2, 0.6, 0.2]])

In [None]:
max_confidence = tf.reduce_max(output, axis = -1) #Highest probs
exiting_instances = tf.cast(tf.where(max_confidence < 0.6, x=0, y=1), tf.int32) #Thresholding. 1 and 0 tensor
mask_exiters = tf.equal(exiting_instances, 1) #Mask. Extract only those who exited
output = tf.boolean_mask(output, mask_exiters) 
output

In [None]:
sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged
sorting_tensor

In [None]:
auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
auxiliary_tensor = auxiliary_tensor[0]
auxiliary_tensor

**Appending step. Done after all the exits**

In [None]:
output_list.append(output)
sorting_list.append(sorting_tensor)

In [None]:
#Visualizing
output_list

In [None]:
#Visualizing
sorting_list

In [None]:
#Done only after all the exits
sorting_list = tf.concat(sorting_list, axis=0)
sorting_list

In [None]:
sorting_idx = tf.argsort(sorting_list)
sorting_idx

In [None]:
xxx = output_list
xxx = tf.concat(xxx, axis=0)
xxx
#output_list[sorting_idx]

In [None]:
tf.gather(xxx, sorting_idx, batch_dims = 0)

#### **Testing the output stacking for training**

In [None]:
import tensorflow as tf
output = tf.constant([[0.33, 0.33, 0.34],
                             [0.9, 0.05, 0.05],
                             [0.2, 0.6, 0.2],
                              [0.85, 0.05, 0.1]])

In [None]:
auxiliary_tensor = tf.range(output.shape[0])
auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
auxiliary_tensor

In [None]:
max_confidence = tf.reduce_max(output, axis = -1) #Highest probs
exiting_instances = tf.cast(tf.where(max_confidence < 0.8, x=0, y=1), tf.int32) #Thresholding. 1 and 0 tensor
mask_exiters = tf.equal(exiting_instances, 1) #Mask. Extract only those who exited
output = tf.boolean_mask(output, mask_exiters) 
output

In [None]:
ccc = tf.stack(output, axis=0)
ccc

In [None]:
ttt = tf.stack([ccc,output], axis = 0)
ttt

In [None]:
ttt[0]

In [None]:
ttt[1]

# Back to basics...

In [None]:
train_data_p = train_dataset_filtered.shuffle(1000).batch(12).map(augment)
val_data_p = val_dataset_filtered.batch(12)
test_data_p = test_dataset_filtered.batch(12)

In [None]:
def add_conv_block(x, n_filters): #12
    # This function applies a simple "CNN block" to the input,
    # built as Conv2D -> BN -> ReLU -> MaxPool2D.
    x = layers.Conv2D(n_filters, 3, padding='same', kernel_regularizer=regularizers.L2(10e-3))(x)
    x = layers.BatchNormalization()(x)
    x = tf.nn.relu(x)
    return layers.MaxPool2D(2)(x)

In [None]:
def classification_layer(x_inp):
  x = layers.GlobalAvgPool2D()(x_inp) # Output shape: (None, 32)
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  x = layers.Dense(10)(x)  
  return x

In [None]:
def early_exit(x, training = True):
  
  x = layers.GlobalAvgPool2D()(x) #Possibly have to change names
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee_output = layers.Dense(10, activation='softmax')(x)
  return ee_output

In [None]:
def early_exit_loss(y_true,y_preds):
  scce = tf.keras.losses.SparseCategoricalCrossentropy()
  loss_ee1 = scce(y_true, y_preds[0])
  loss_ee2 = scce(y_true, y_preds[1])
  loss_final = scce(y_true, y_preds[2])

  return loss_final + loss_ee1*1 + loss_ee2*1

In [None]:
def filter_ee1_res(vect_of_confidences, inp_bef_ee): #Vector of confidences of EE and input of the early exit. Call this only in testing
  threshold_ee1 = 0.9 
  batch_size = 12
  auxiliary_tensor = tf.range(batch_size)
  auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
    
  #Take for each image of the Batch the category with highest confidence after the softmax.
  max_confidence = tf.reduce_max(vect_of_confidences, axis = -1)

  #Thresholding operation. 
  #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
  #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
  exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee1, 0, 1), tf.int32)

  #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
  mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
  output = tf.boolean_mask(vect_of_confidences, mask_exiters) #Take probability vector for the exiter instances

  sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
  mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
  sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

  #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
  mask_non_exiters = tf.equal(exiting_instances, 0)
  input_non_exiters = tf.boolean_mask(inp_bef_ee, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed

  #list of number of exiters here

  output_list = []
  output_list.append(output)

  sorting_list = []
  sorting_list.append(sorting_tensor)

  auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
  auxiliary_tensor = auxiliary_tensor[0]

In [None]:
def build_model():
  # Input part
  inp = layers.Input(shape=(500, 375, 3))
  # First Convolutional Block     # Output: (None, 500, 375, 24)
  x = add_conv_block(inp, 24)     

  # Second Convolutional Block    # Output: (None, 250, 187, 48) 
  x = add_conv_block(x, 48)       

  # 1st Early Exit
  ee1 = early_exit(x) #These are predictions. Have to save this object in an array
  
  if training = False: 
    filter_ee1_res(ee1,x)
  # Third Convolutional Block     # Output: (None, 125, 93, 96)
  x = add_conv_block(x, 96)

  # 2nd Early Exit

  # Fourth Convolutional Block    # Output: (None, 62, 46, 192)
  x = add_conv_block(x, 192)     

  x = classification_layer(x)
  return tf.keras.Model(inputs=inp, outputs=x)

In [None]:
net = build_model()

In [None]:
net.summary()

In [None]:
cross_entropy = losses.SparseCategoricalCrossentropy(from_logits=True) #Remove this "from_logits" if put the softmax activation in last dense layer
accuracy = metrics.SparseCategoricalAccuracy()
optimizer = optimizers.Adam()

cbs = [
    callbacks.TerminateOnNaN(),
    callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5, 
                            restore_best_weights=True, verbose=1),
    callbacks.TensorBoard(log_dir='logs', update_freq=50)      
]

In [None]:
net.compile(loss=cross_entropy, optimizer=optimizer, metrics=[accuracy])

In [None]:
net.fit(train_data_p, validation_data=val_data_p, epochs=100, callbacks=cbs)

# Other possible (failed) approach...

In [None]:
net = tf.keras.Sequential([
    tf.keras.layers.Conv2D(n_filters, 3, padding='same', kernel_regularizer=regularizers.L2(10e-3), input_shape = (500,375,3)),
    tf.keras.layers.Dense(50, activation = keras.activations.relu),
    tf.keras.layers.Dense(3, activation = keras.activations.softmax)

In [None]:
class CustomModel(tf.keras.Model): #New class, inheriting keras.Model.
  def __init__(self):
    self.net  = [] #Keep track of CB
    self.classifiers   = [] #Keep track of EE

  def add_conv_struct(self, n_filters, set_shape = False, input_shape=None):
    struc_array = []
    if set_shape:
      struc_array.append(layers.Conv2D(n_filters, 3, padding='same', kernel_regularizer=regularizers.L2(10e-3), input_shape=input_shape))
    else:
      struc_array.append(layers.Conv2D(n_filters, 3, padding='same', kernel_regularizer=regularizers.L2(10e-3)))
    
    struc_array.append(layers.BatchNormalization())
    struc_array.append(tf.nn.relu())
    struc_array.append(layers.MaxPool2D(2))
    self.net.append(struc_array)

  def classifiers(self):
    layers = []
    layers.append(layers.GlobalAvgPool2D())
    layers.append(layers.Dense(100, activation='relu'))
    layers.append(layers.Dropout(0.3))
    layers.append(layers.Dense(10, activation='softmax'))
    self.classifiers.append(layers)

  def build_model(self):
  # First Convolutional Block     # Output: (None, 500, 375, 24)
  self.add_conv_struct(24, set_shape = True, input_shape = (500, 375, 3))   

  # Second Convolutional Block    # Output: (None, 250, 187, 48) 
  self.add_conv_struct(48)        

  # 1st Early Exit
  self.classifiers() 
  
  # Third Convolutional Block     # Output: (None, 125, 93, 96)
  self.add_conv_struct(96)

  # 2nd Early Exit
  self.classifiers() 

  # Fourth Convolutional Block    # Output: (None, 62, 46, 192)
  self.add_conv_struct(192)     

  # Final Classifier
  self.classifiers()

  def early_exit_loss(self, y_true,y_preds):
    scce = tf.keras.losses.SparseCategoricalCrossentropy()
    loss = scce(y_true, y_preds)
    return loss


  def apply(input,layers):
    for layer in layers:
      input = layer(input)

    return input

  def fit(self, images, num_epochs):
    for epoch in range(num_epochs):
      loss = 0
      print(f"Epoch number: " {epoch})

      for batch_idx, (x_batch, y_batch) in enumerate(images):
        with tf.GradientTape() as tape:
          
          output_1 = apply(x_batch, self.net[0]) #First CB
          output_2 = apply(output_1,self.net[1]) #Second CB

          loss += early_exit_loss(y_batch, apply(output_2,self.classifiers[0])

          output_3 = apply(output_2, self.net[2]) #Third CB

          loss += early_exit_loss(y_batch, apply(output_3,self.classifiers[1]))

          output_4 = apply(output_3, self.net[3]) #Fourth CB

          loss += early_exit_loss(y_batch, apply(output_4,self.classifiers[2]))

        model = tf.keras.Sequential(self.net[0])


# Original approach - correcting mistakes

In [None]:
def build_model(training = True):
  #*****INPUT******
  inp = layers.Input(shape=(500, 375, 3))
  #*****FIRST CONVOLUTIONAL BLOCK******

  x_cb1_t = add_conv_block(inp, 6)

  #*****FIRST EARLY EXIT******
  x = layers.GlobalAvgPool2D()(x_cb1_t) #Possibly have to change names
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee1_output = layers.Dense(10, activation='softmax')(x)

  piu = tf.stack(ee1_output, axis=0) 
    
  if training == False: #Only during inference

    threshold_ee1 = 0.9 
    batch_size = 12
    auxiliary_tensor = tf.range(batch_size)
    auxiliary_tensor = tf.reshape(auxiliary_tensor, [-1])
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(ee1_output, axis = -1) # TAKING RESULT OF THE EARLY EXIT 1

    #Thresholding operation. 
    #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
    #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee1, 0, 1), tf.int32)

    #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(ee1_output, mask_exiters) #Take probability vector for the exiter instances

    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) 
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) 

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0)
    input_non_exiters = tf.boolean_mask(x_cb1_t, mask_non_exiters) 

    output_list = []
    output_list.append(output)

    sorting_list = []
    sorting_list.append(sorting_tensor)

    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]
  #*****FIRST EARLY EXIT******

  #*****SECOND CONVOLUTIONAL BLOCK******
  if training == True:
    input_2cb = x_cb1_t
  else: 
    input_2cb = input_non_exiters

  
  x_cb2_t = add_conv_block(input_2cb, 6)

  #*****SECOND EARLY EXIT******
  x = layers.GlobalAvgPool2D()(x_cb2_t) #Possibly have to change names
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  ee2_output = layers.Dense(10, activation='softmax')(x)

  piu = tf.stack([piu,ee2_output], axis=0) 

  if training == False: #Only during inference

    threshold_ee2 = 0.8 #define threshold?
    
    #Take for each image of the Batch the category with highest confidence after the softmax.
    max_confidence = tf.reduce_max(ee2_output, axis = -1)

    #Thresholding operation. 
    #New tensor with 0's where confidence is below threshold --> shall be passed to subsequent layers
    #1's where confidence is above threshold --> shall NOT be passed to subsequent layers.
    exiting_instances = tf.cast(tf.where(max_confidence < threshold_ee2, x = 0, y = 1), tf.int32)

    #Update output with elements ABOVE threshold that DON'T need to be given to subsequent layers
    mask_exiters = tf.equal(exiting_instances, 1) #Mask those elements of batch that took the early exit
    output = tf.boolean_mask(ee2_output, mask_exiters) #Take probability vector for the exiter instances

    sorting_tensor = tf.math.multiply(auxiliary_tensor,exiting_instances) #Make zero indexes of Batch that didn't exit
    mask_non_zeros = tf.not_equal(sorting_tensor, 0) #Take indexes that aren't zero
    sorting_tensor = tf.boolean_mask(sorting_tensor, mask_non_zeros) #Tensor to sort the output according to instance to which they belonged

    #Update input_non_exiters with elements BELOW threshold that NEED to be given to subsequent layers
    mask_non_exiters = tf.equal(exiting_instances, 0)
    input_non_exiters = tf.boolean_mask(x_cb2_t, mask_non_exiters) #¿?¿?In testing, I update which members of the Batch didn't take EE --> must be passed

    #Appending operation
    output_list.append(output)
    sorting_list.append(sorting_tensor)

    #Auxiliary tensor update --> possibly dispensable
    auxiliary_tensor = tf.compat.v1.setdiff1d(auxiliary_tensor, sorting_tensor,index_dtype=tf.dtypes.int32)
    auxiliary_tensor = auxiliary_tensor[0]
                                    

    #*****FINAL EXIT******
  if training == True:
    input_final_exit = x_cb2_t
  else: 
    input_final_exit = input_non_exiters

  x = layers.GlobalAvgPool2D()(input_final_exit) 
  x = layers.Dense(100, activation='relu')(x)
  x = layers.Dropout(0.3)(x)
  final_output = layers.Dense(10, activation='softmax')(x)
  
  final_out = tf.expand_dims(final_output, axis=0)
  piu = tf.concat([piu, final_out], axis=0)                       
  assert piu.shape == (3, None, 10)
  x = piu
  '''
  if training == False:

    #Appending Operation
    output_list.append(final_output)
    sorting_list.append(sorting_tensor)
    
    #Sorting operation
    sorting_list = tf.concat(sorting_list, axis=0)
    sorting_idx = tf.argsort(sorting_list)
    output_list = tf.concat(output_list, axis=0)

    x = tf.gather(output_list, sorting_idx, batch_dims = 0)
  '''

  return tf.keras.Model(inputs=inp, outputs=x)

In [None]:
model = build_model()

In [None]:
model.summary()

In [None]:
def early_exit_loss(y_true,y_preds):
  scce = tf.keras.losses.SparseCategoricalCrossentropy()
  loss_ee1 = scce(y_true, y_preds[0])
  print(y_true.shape)
  print(y_preds.shape)
  loss_ee2 = scce(y_true, y_preds[1])
  loss_final = scce(y_true, y_preds[2])

  return loss_final + loss_ee1*1 + loss_ee2*1

In [None]:
#cross_entropy = losses.SparseCategoricalCrossentropy(from_logits=True) #Remove this "from_logits" if put the softmax activation in last dense layer
accuracy = metrics.SparseCategoricalAccuracy()
optimizer = optimizers.Adam()

# Callbacks are objects that provide additional functionalities during training,
# allowing to plug-in things at will (in this case, we add a callback to immediately
# terminate when a NaN value is encountered, a callback to perform early stopping,
# and a callback to log the results for TensorBoard visualization).
cbs = [
    callbacks.TerminateOnNaN(),
    callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5, 
                            restore_best_weights=True, verbose=1),
    callbacks.TensorBoard(log_dir='logs', update_freq=50)      
]

In [None]:
model.compile(loss=early_exit_loss, optimizer=optimizer, metrics=[accuracy])

In [None]:
model.fit(train_data_p, validation_data=val_data_p, epochs=1, callbacks=cbs)

(None,)
(3, None, 10)
(None,)
(3, None, 10)
     67/Unknown - 31s 78ms/step - loss: 7.0423 - sparse_categorical_accuracy: 0.1053(None,)
(3, None, 10)


<keras.callbacks.History at 0x7f48e908e970>

In [None]:
model.predict(test_data_p)

In [None]:
model.evaluate(test_data_p)



[6.98460578918457, 0.07744107395410538]