In [None]:
from tensorflow.keras.utils import Progbar
from tensorflow.python.eager import backprop

In [None]:
def accumulated_gradients(gradients, step_gradients, batch_multiplier):
    """Add one step's gradients to a running per-variable sum.

    Args:
        gradients: Running sums, one entry per variable, or ``None`` if
            nothing has been accumulated yet. When given, the list is
            updated in place and returned.
        step_gradients: Gradients of the current step, in the same order as
            ``gradients``. Entries may be ``None`` (no gradient for that
            variable in this step).
        batch_multiplier: Not used here; kept so existing callers keep
            working. Any scaling has to be done by the caller.

    Returns:
        The list of accumulated gradients.

    Raises:
        ValueError: If the two gradient lists differ in length.
    """
    if gradients is None:
        # Copy, so that later in-place accumulation never writes into the
        # list object that the caller handed in as ``step_gradients``.
        return list(step_gradients)

    step_gradients = list(step_gradients)
    if len(gradients) != len(step_gradients):
        raise ValueError(
            "gradient lists differ in length: {} vs {}".format(
                len(gradients), len(step_gradients)))

    for i, g in enumerate(step_gradients):
        if g is None:
            # Nothing to add for this variable in this step.
            continue
        # Use ``a + g`` rather than ``a += g`` so mutable gradient objects
        # from earlier steps are never modified in place.
        gradients[i] = g if gradients[i] is None else gradients[i] + g
    return gradients

In [None]:
def train_with_batchsize(model, batch_size, batch_multiplier, num_epochs):
    """Train `model` with a custom loop that accumulates gradients.

    Gradients of `batch_multiplier` consecutive mini-batches are summed
    (each loss is divided by `batch_multiplier` first) before a single
    optimizer update is applied.

    Reads the notebook-level names `trainX`, `trainY`, `testX`, `testY`,
    `num_training_samples`, `MIN_LR` and `SGD`. The model must expose a metric
    named 'accuracy', because the logs are indexed with that key.

    Args:
        model: Compiled Keras model to train.
        batch_size: Number of samples per mini-batch.
        batch_multiplier: Number of mini-batches accumulated per optimizer
            update; must be >= 1.
        num_epochs: Number of passes over the training data.

    Returns:
        A dict with two entries, each holding one value per epoch:
        'train' -> [loss_history, accuracy_history] and
        'validation' -> [val_loss_history, val_accuracy_history].

    Raises:
        ValueError: If `batch_multiplier` is smaller than 1 or the training
            dataset yields no batches.
    """
    if batch_multiplier < 1:
        raise ValueError("batch_multiplier must be >= 1")

    # Generate training data set
    train_dataset = tf.data.Dataset.from_tensor_slices((trainX, trainY))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Initialization
    opt = SGD(learning_rate=MIN_LR, momentum=0.9)  # Instantiate the optimizer.
    loss_fn = tf.keras.losses.CategoricalCrossentropy()  # Instantiate a loss function.

    # Ceiling division: the last batch may be partial, and no extra step is
    # counted when the sample count is an exact multiple of batch_size.
    steps_per_epoch = -(-num_training_samples // batch_size)

    # Initialize loggers
    loss_history_epochs = []
    acc_history_epochs = []
    val_loss_history_epochs = []
    val_acc_history_epochs = []

    # Train on specified device
    with tf.device('/GPU:0'):

        for epoch in range(num_epochs):

            ## Progress Bar
            print("\nepoch {}/{}".format(epoch+1,num_epochs))
            pb_i = Progbar(steps_per_epoch, verbose=1, interval=0.08)
            ##

            # Initialization
            model.reset_metrics()  # Reset metrics at the beginning of each epoch
            acc_gradients = None
            pending = 0            # mini-batches accumulated since the last update
            train_logs = None
            loss_sum = 0.0         # sum of the (unscaled) per-batch losses
            num_steps = 0

            # Start the current epoch
            for x_batch_train, y_batch_train in train_dataset:

                # Record the forward pass so the loss can be differentiated.
                with tf.GradientTape() as tape:
                    y_pred = model(x_batch_train, training=True)

                    # Scale the loss so that the sum over `batch_multiplier`
                    # mini-batches equals their mean.
                    loss_value = loss_fn(y_batch_train, y_pred) / batch_multiplier

                # Gradients of the trainable variables with respect to the loss.
                step_gradients = tape.gradient(loss_value, model.trainable_variables)

                # accumulate gradients
                acc_gradients = accumulated_gradients(acc_gradients,
                                                      step_gradients,
                                                      batch_multiplier)
                pending += 1

                # Update only once `batch_multiplier` mini-batches have been
                # accumulated (the first step of an epoch is no exception).
                if pending == batch_multiplier:
                    opt.apply_gradients(zip(acc_gradients, model.trainable_variables))
                    acc_gradients = None
                    pending = 0

                # update metrics
                model.compiled_metrics.update_state(y_batch_train, y_pred)
                train_logs = {m.name: float(m.result()) for m in model.metrics}
                step_loss = float(loss_value) * batch_multiplier  # undo the scaling
                values = [('loss', step_loss), ('acc', train_logs['accuracy'])]
                pb_i.add(1, values=values)

                # Log step
                loss_sum += step_loss
                num_steps += 1

            if num_steps == 0:
                raise ValueError("training dataset produced no batches")

            # Flush gradients left over from an incomplete accumulation group
            # instead of discarding them. They are rescaled so the update uses
            # the mean gradient of the mini-batches actually accumulated.
            if pending:
                scale = batch_multiplier / pending
                leftover = [g if g is None else g * scale for g in acc_gradients]
                opt.apply_gradients(zip(leftover, model.trainable_variables))
                acc_gradients = None

            ## Log result of current epoch
            # Average training loss over all mini-batches of this epoch
            loss_history_epochs.append(loss_sum / num_steps)
            acc_history_epochs.append(train_logs['accuracy'])

            # Validation on 10 batches of 32 samples
            val_logs = model.evaluate(testX, testY,
                                      batch_size=32,
                                      steps=10,
                                      return_dict=True,
                                      verbose=0)
            val_logs = {'val_' + name: val for name, val in val_logs.items()}

            # Log validation results
            val_loss_history_epochs.append(val_logs['val_loss'])
            val_acc_history_epochs.append(val_logs['val_accuracy'])

            print(val_logs)

    return {'train': [loss_history_epochs, acc_history_epochs],
            'validation': [val_loss_history_epochs, val_acc_history_epochs]}

In [None]:
def create_model(lr=MIN_LR):
    """Build and compile a MiniGoogLeNet model (template, still to be completed).

    The arguments of both `MiniGoogLeNet.build` and `compile` are `...`
    placeholders that have to be filled in before this can be used.

    Args:
        lr: Learning rate, defaulting to `MIN_LR`. It is not referenced by the
            body as written.

    Returns:
        The object returned by `MiniGoogLeNet.build`, after `compile` has been
        called on it.
    """
    # to complete: replace the `...` placeholders with real arguments
    net = MiniGoogLeNet.build(...)
    net.compile(...)

    return net

In [None]:
# Configurations to compare; each entry is [batch_size, batch_multiplier].
batches = [
    [64, 1],
    [128, 1],
    [256, 1],
    [512, 1],
    [1024, 1],
    [2048, 1],
    [2048, 2],
    [2048, 4],
]

In [None]:
# Train a fresh model for every [batch_size, batch_multiplier] configuration.
# Each run's result is stored under its (batch_size, batch_multiplier) key so
# that earlier runs are not lost when `history` is reassigned.
histories = {}
for batch_size, batch_multiplier in batches:

    # Train Model
    # NOTE: `...` is a placeholder that is passed to create_model as `lr`.
    model = create_model(...)
    history = train_with_batchsize(model, batch_size=batch_size, batch_multiplier=batch_multiplier, num_epochs=5)
    histories[(batch_size, batch_multiplier)] = history