In [None]:
from __future__ import division

import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy, SparseCategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN, CSVLogger
import tensorflow as tf
import time
import os

from tensorflow.keras.layers import Input, Dense, ReLU, Softmax, Dropout, Conv2D, MaxPool2D, Flatten, Reshape
from tensorflow.keras import Model

# %matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
def softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so that
    large scores do not overflow; this shift does not change the result.

    Parameters
    ----------
    x : array of scores.
    axis : axis along which the outputs sum to one (default: last axis).

    Returns
    -------
    Array of the same shape as ``x`` whose entries along ``axis`` sum to 1.
    """
    peak = np.amax(x, axis=axis, keepdims=True)
    exps = np.exp(x - peak)
    normaliser = exps.sum(axis=axis, keepdims=True)
    return exps / normaliser

In [None]:
def train_mnist(model, gen_train, gen_test=None, n_epochs=10, model_filename=None, verbose=1):
    """Fit ``model`` on ``gen_train`` and optionally checkpoint the best model.

    Parameters
    ----------
    model : object exposing a Keras-style ``fit`` method.
    gen_train : training data, passed as the first argument of ``model.fit``.
    gen_test : optional validation data, passed as ``validation_data``.
    n_epochs : number of epochs to train for.
    model_filename : path of the checkpoint file. If None, the best model is
        not saved and no checkpoint callback is installed.
    verbose : verbosity of the ``ModelCheckpoint`` callback (it is not
        forwarded to ``model.fit``).

    Returns
    -------
    (model, history) where ``history`` is whatever ``model.fit`` returned.
    """
    # Set other training parameters
    batch_size = 32
    # 60000 / 10000 are presumably the MNIST train / test set sizes -- TODO confirm.
    # Step counts must be integers; round up so a partial last batch is included.
    n_train_steps_per_epoch = int(np.ceil(60000 / batch_size))
    n_test_steps_per_epoch = int(np.ceil(10000 / batch_size))

    callbacks = []
    if model_filename is not None:
        # A bare filename has an empty dirname: nothing to create in that case.
        dir_model = os.path.dirname(model_filename)
        if dir_model:
            os.makedirs(dir_model, exist_ok=True)

        # 'val_loss' only exists when validation data is supplied; fall back to
        # the training loss so that a checkpoint is still written without it.
        monitor = 'val_loss' if gen_test is not None else 'loss'
        callbacks.append(ModelCheckpoint(filepath=model_filename,
                                         monitor=monitor,
                                         verbose=verbose,
                                         save_best_only=True,
                                         save_weights_only=False,
                                         mode='auto',
                                         save_freq='epoch'))

    # Train
    t = time.time()
    history = model.fit(gen_train, epochs=n_epochs,
                        steps_per_epoch=n_train_steps_per_epoch,
                        callbacks=callbacks,
                        validation_data=gen_test,
                        validation_steps=n_test_steps_per_epoch)
    print('\nTime to complete training: %f minutes.\n' % ((time.time() - t)/60))

    return model, history

In [None]:
def train_ramp(model, gen_train, gen_test=None, n_epochs=10, model_filename=None):
    """Fit ``model`` on ``gen_train`` for a fixed 1000 steps per epoch.

    Parameters
    ----------
    model : object exposing a Keras-style ``fit`` method.
    gen_train : training data, passed as the first argument of ``model.fit``.
    gen_test : accepted for signature parity with ``train_mnist``; it is not
        used, no validation is run here.
    n_epochs : number of epochs to train for.
    model_filename : path of the checkpoint file. If None, the best model is
        not saved and no checkpoint callback is installed.

    Returns
    -------
    (model, history) where ``history`` is whatever ``model.fit`` returned.
    """
    # Set other training parameters
    n_train_steps_per_epoch = 1000

    callbacks = []
    if model_filename is not None:
        # A bare filename has an empty dirname: nothing to create in that case.
        dir_model = os.path.dirname(model_filename)
        if dir_model:
            os.makedirs(dir_model, exist_ok=True)

        # No validation data is used, so the best model is chosen on training loss.
        callbacks.append(ModelCheckpoint(filepath=model_filename,
                                         monitor='loss',
                                         verbose=1,
                                         save_best_only=True,
                                         save_weights_only=False,
                                         mode='auto',
                                         save_freq='epoch'))

    # Train (Model.fit accepts generators; fit_generator is deprecated)
    t = time.time()
    history = model.fit(gen_train, epochs=n_epochs,
                        steps_per_epoch=n_train_steps_per_epoch,
                        callbacks=callbacks)
    print('\nTime to complete training: %f minutes.\n' % ((time.time() - t)/60))

    return model, history

In [None]:
def plot_mnist_digits(images, labels, predictions=None, n_samples_to_plot=30, width_scale=1.5):
    """Show a grid of 28x28 digit images titled with labels and predictions.

    Parameters
    ----------
    images : array-like; the first ``n_samples_to_plot`` entries are reshaped
        to ``(n, 28, 28)``.
    labels : sequence of ground-truth labels, indexed in parallel with
        ``images`` and formatted with ``%d``.
    predictions : optional 2-D array of class scores, one row per image. The
        rows are passed through ``softmax`` before display.
    n_samples_to_plot : maximum number of images drawn.
    width_scale : scaling factor so figure width fits nicely in notebook output.

    Notes
    -----
    Without predictions each title is ``<true label>``. With predictions it is
    ``<true label>, <predicted label>, <predicted label probability>`` and a
    two-pixel border is drawn: green for a correct prediction, blue when the
    predicted class index is 10 (treated as a null prediction), red otherwise.
    Nothing is drawn when there are no samples to plot.
    """
    n_shown = min(len(images), n_samples_to_plot)
    if n_shown <= 0:
        # Nothing to draw; this also avoids a division by zero columns below.
        return

    images = np.reshape(images[:n_shown], (n_shown, 28, 28))
    labels = labels[:n_shown]
    if predictions is not None:
        predictions = softmax(predictions[:n_shown], axis=-1)

    n_sp_cols = min(n_shown, 10)                 # number of subplot columns
    n_sp_rows = int(np.ceil(n_shown/n_sp_cols))  # number of subplot rows

    # extra 1.2 scaling of rows is to accommodate plot titles...
    plt.figure(figsize=(width_scale*n_sp_cols, width_scale*n_sp_rows*1.2))

    for i_im in range(n_shown):
        ax = plt.subplot(n_sp_rows, n_sp_cols, i_im+1)
        im = images[i_im]

        if predictions is None:
            title_str = '%d' % (labels[i_im])
        else:
            i_max = np.argmax(predictions[i_im])
            title_str = '%d, %d, %0.2f' % (labels[i_im], i_max, predictions[i_im, i_max])

            if labels[i_im] == i_max:
                chan = 1  # green: correct prediction
            elif i_max == 10:
                chan = 2  # blue: null prediction
            else:
                chan = 0  # red: wrong prediction

            # np.repeat returns a new array, so the caller's image is not modified.
            im = np.repeat(np.expand_dims(im, axis=2), 3, axis=2)  # convert to color
            # NOTE(review): a border value of 1.0 assumes pixels scaled to [0, 1] -- confirm.
            im[0:2, :, chan] = 1.0
            im[-2:, :, chan] = 1.0
            im[:, 0:2, chan] = 1.0
            im[:, -2:, chan] = 1.0

        ax.imshow(im, cmap='gray', aspect='equal')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title(title_str)
        
#     if predictions is not None:
#         caption='Values in image titles are: <true label>, <predicted label>, <predicted label score (probability)>.'
#     else:
#         caption='Values in image titles are ground truth labels.'
#     plt.figtext(0.5, 0.01, caption, wrap=True, horizontalalignment='center', fontsize=15)

In [None]:
def _plot_history_curves(metrics, keys, ylabel, figsize=(12, 6), fontsize=15):
    """Open a new figure and draw one labelled line per entry of ``keys``."""
    plt.figure(figsize=figsize)
    for k in keys:
        plt.plot(metrics[k], label=k)
    plt.legend(prop={'size': fontsize})
    plt.xlabel('Epoch', fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
    plt.tick_params(axis='both', which='major', labelsize=fontsize)
    plt.grid(True)


def plot_training_history(history):
    """Plot the loss and accuracy curves stored in a training history.

    Parameters
    ----------
    history : object whose ``history`` attribute maps metric names to
        per-epoch value sequences (e.g. the return value of ``model.fit``).

    One figure is always drawn for every metric whose name ends with
    ``'loss'``. A second figure is drawn for metrics whose name ends with
    ``'accuracy'``, but only when at least one such metric exists.
    """
    metrics = history.history
    loss_keys = [k for k in metrics if k.endswith('loss')]
    acc_keys = [k for k in metrics if k.endswith('accuracy')]

    # Plot losses
    _plot_history_curves(metrics, loss_keys, 'Loss')

    # Plot accuracies
    if acc_keys:
        _plot_history_curves(metrics, acc_keys, 'Accuracy')

In [None]:
def create_adversarial_pattern(model, input_image, input_label):
    """Return the sign of the loss gradient w.r.t. the input, and the gradient.

    The loss is sparse categorical cross-entropy computed with
    ``from_logits=True``, i.e. the output of ``model`` is treated as logits.
    The sign of the input gradient is the perturbation direction used by
    gradient-sign adversarial attacks.

    Parameters
    ----------
    model : callable mapping the input tensor to predictions.
    input_image : tensor or array-like input; converted to a tensor so it can
        be watched by the gradient tape.
    input_label : integer class label(s) passed to the loss as ``y_true``.

    Returns
    -------
    (signed_grad, gradient) : two NumPy arrays, the element-wise sign of the
        gradient and the gradient itself.
    """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # GradientTape.watch only accepts tensors/variables; converting first lets
    # callers pass NumPy arrays as well (an existing tensor is left as is).
    input_image = tf.convert_to_tensor(input_image)

    with tf.GradientTape() as tape:
        # The input is not a trainable variable, so it must be watched explicitly.
        tape.watch(input_image)
        prediction = model(input_image)
        loss = loss_object(input_label, prediction)

    # Get the gradients of the loss with respect to the input image.
    gradient = tape.gradient(loss, input_image)
    signed_grad = tf.sign(gradient)

    return np.array(signed_grad), np.array(gradient)