In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.losses import CategoricalCrossentropy
from keras.utils import plot_model
from keras.optimizers import Adam,SGD
from keras.callbacks import (
    EarlyStopping,
    TerminateOnNaN,
    ModelCheckpoint,
    TensorBoard,
    ReduceLROnPlateau
)
import tensorflow_addons as tfa
import tensorflow_models as tfm
import tensorflow_datasets as tfds
import sys, math, time
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
)
import os
import logging
import shutil
from utils import generate_gan_samples, generate_gan_samples1

print('Python version:', sys.version)
print('TensorFlow version:', tf.__version__)

# Training is only intended to run on a GPU: fail fast instead of silently
# falling back to a (much slower) CPU run.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print(f'GPU found at: {device_name}')
else:
    raise SystemError('GPU device not found')

# Keep cell output readable: no tfds progress bars, only TF errors in the log.
tfds.disable_progress_bar()
tf.get_logger().setLevel(logging.ERROR)

In [None]:
# Reset the global Keras state left over from earlier runs in this kernel
# (`keras` here is `tensorflow.keras`, i.e. the same module as `tf.keras`).
keras.backend.clear_session()

### Global constants

In [None]:
## Define constants and hyperparameters
EPOCHS = 100                # max epochs per policy (EarlyStopping may stop sooner)
NUMBER_POLICIES = 5         # number of data policies compared in the training loop
DATASET_NAME = "cifar10"    # tfds name; "fmnist" is mapped to "fashion_mnist" when loading
NETWORK = "resnet"          # "resnet" (ResNet50V2) or "inception" (InceptionV3)
# Keras' InceptionV3 needs inputs of at least 75x75; ResNet50V2 accepts 32x32.
RESIZE_TO = (32,32,3) if NETWORK == "resnet" else (75, 75, 3)
BATCH_SIZE = 512
AUTO = tf.data.AUTOTUNE
SEED = 42
# Seed Python's `random`, NumPy and TensorFlow in one call so shuffling,
# augmentation and weight initialisation are repeatable between runs
# (GPU kernels may still introduce some nondeterminism).
tf.keras.utils.set_random_seed(SEED)
# Output directory for plots, CSVs and weights, e.g. ./cifar10/resnet
FOLDER = os.path.join(".", DATASET_NAME, NETWORK)

### Create folders

In [None]:
# Wipe the output directory of a previous run so that plots, CSVs and weights
# on disk always belong to the current run.
if os.path.exists(FOLDER):
    shutil.rmtree(FOLDER)

# Recreate the folder(s) that will hold images, models and checkpoints.
newpaths = [FOLDER]
for newpath in newpaths:
    if os.path.exists(newpath):
        continue
    os.makedirs(newpath)

In [None]:
(train_set, valid_set, test_set), info = tfds.load(
    # tfds registers Fashion-MNIST as "fashion_mnist"; "fmnist" is this notebook's short name.
    "fashion_mnist" if DATASET_NAME == "fmnist" else DATASET_NAME,
    # The last 10% of the official train split is held out for validation.
    split=["train[:90%]", "train[90%:]", "test"],
    as_supervised=True,  # yield (image, label) tuples
    with_info=True,
    shuffle_files=True,
)
# Extract informative features
class_names = info.features["label"].names
num_classes = info.features["label"].num_classes
input_shape = info.features['image'].shape
# tfds image shapes are (height, width, channels); unpack in that order.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = input_shape
NUM_TRAIN = len(train_set)
print("Image shape: {}".format(input_shape))
print("Classes: {}".format(class_names))
print("Number of classes: {}".format(num_classes))

In [None]:
# Generate one GAN sample per real training example; labels are the regular
# integer class ids (0 to 9), as expected by `resize_and_rescale`'s one-hot step.
gan_set = generate_gan_samples(DATASET_NAME, NUM_TRAIN)
n_generated = len(gan_set)
assert n_generated == NUM_TRAIN, "GAN samples not generated correctly"
print(f"Generated {n_generated} GAN samples")

In [None]:
# RandAugment from the TF Model Garden; op list and magnitude semantics:
# https://github.com/tensorflow/models/blob/v2.12.0/official/vision/ops/augment.py#L2048-L2194
# "Cutout" is removed from the pool of candidate ops.
exclude_list = ["Cutout"]
augmenter = tfm.vision.augment.RandAugment(
    num_layers=2,
    magnitude=9,
    translate_const=4,
    exclude_ops=exclude_list,
)

In [None]:
def resize_and_rescale(image, label):
    """tf.data map function: convert one (image, label) example to model input format.

    Grayscale images are expanded to 3 channels. The image is resized to
    RESIZE_TO[:2] without distortion (aspect-preserving resize plus zero padding)
    and cast to float32. The integer label is one-hot encoded over `num_classes`.

    Despite the name, pixel values are NOT rescaled here. They keep their
    original range (presumably 0..255 from tfds' uint8 images). Rescaling is
    done by the model's Rescaling layer.
    """
    if IMG_CHANNELS == 1:
        image = tf.image.grayscale_to_rgb(image)
    target_height, target_width = RESIZE_TO[:2]
    padded = tf.image.resize_with_pad(image, target_height, target_width)
    return tf.cast(padded, tf.float32), tf.one_hot(label, num_classes)

def prepare_dataset(dataset, shuffle=False, augment=False):
    """Build a batched, prefetched input pipeline from an unbatched dataset.

    Args:
        dataset: unbatched tf.data.Dataset of (image, integer label) pairs.
        shuffle: if True, shuffle with a buffer spanning the whole dataset
            (requires a known cardinality, since `len(dataset)` is used).
        augment: if True, apply RandAugment (`augmenter.distort`) to each
            image individually.

    Returns:
        Dataset of (float32 image batch, one-hot label batch) with BATCH_SIZE
        examples per batch, prefetched with AUTOTUNE.
    """
    # Deterministic preprocessing is cached so it runs once. Random steps
    # (shuffle, augmentation) come after the cache so they vary per epoch.
    dataset = dataset.map(resize_and_rescale, num_parallel_calls=AUTO).cache()
    if shuffle:
        dataset = dataset.shuffle(len(dataset))
    if augment:
        # Augment per example, BEFORE batching. RandAugment.distort draws its
        # random op choices once per call and treats a 4-D tensor as a sequence
        # of frames, so distorting a whole batch gives every image in it the
        # identical augmentation.
        dataset = dataset.map(
            lambda image, label: (augmenter.distort(image), label),
            num_parallel_calls=AUTO,
        )
    dataset = dataset.batch(BATCH_SIZE)
    return dataset.prefetch(AUTO)

In [None]:
def visualize_dataset(dataset, title="Dataset samples"):
    """Show up to 9 examples from the first batch of `dataset` in a 3x3 grid.

    Args:
        dataset: batched dataset of (images, one-hot labels), e.g. the output
            of `prepare_dataset`.
        title: figure super-title.
    """
    fig, axes = plt.subplots(3, 3)
    fig.suptitle(title, fontsize=14)
    for ax in axes.flat:
        ax.axis("off")
    for images, labels in dataset.take(1):
        # zip stops at the shorter sequence, so a first batch with fewer than
        # 9 examples no longer raises IndexError.
        for ax, image, label in zip(axes.flat, images, labels):
            ax.imshow(tf.keras.utils.array_to_img(image))
            ax.set_title(class_names[int(tf.argmax(label))])
    plt.show()

In [None]:
# Define a model building utility function
def get_training_model(model_name):
    """Build an ImageNet-pretrained classifier for inputs of shape RESIZE_TO.

    Args:
        model_name: "resnet" (ResNet50V2) or "inception" (InceptionV3).

    Returns:
        Uncompiled keras Sequential model named `model_name` that takes raw
        pixel values and outputs `num_classes` logits (no softmax).

    Raises:
        NotImplementedError: for any other `model_name`.
    """
    backbones = {
        "resnet": tf.keras.applications.ResNet50V2,
        "inception": tf.keras.applications.InceptionV3,
    }
    if model_name not in backbones:
        raise NotImplementedError("network not supported")
    network = backbones[model_name](
        weights="imagenet",
        include_top=False,
        input_shape=RESIZE_TO,
    )

    model = tf.keras.Sequential(
        [
            keras.layers.Input(RESIZE_TO),
            # ResNetV2 and InceptionV3 ImageNet weights were trained on inputs in
            # [-1, 1] (their `preprocess_input` computes x / 127.5 - 1), so map
            # raw [0, 255] pixels to that range rather than to [0, 1].
            keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0),
            network,
            keras.layers.GlobalAveragePooling2D(),
            # Raw logits: training uses CategoricalCrossentropy(from_logits=True).
            keras.layers.Dense(num_classes),
        ],
        name=model_name,
    )
    return model

In [None]:
def train_model(train_ds):
    """Train a fresh copy of NETWORK on `train_ds`, validating on the global `val_ds`.

    Every call reloads INITIAL_WEIGHTS, so all policies start from the same
    weights. The epoch with the lowest val_loss is checkpointed to FINAL_WEIGHTS
    and reloaded before returning.

    Args:
        train_ds: batched training dataset of (images, one-hot labels).

    Returns:
        [history, model, latency]: the keras History, the trained model (holding
        the best-val_loss weights whenever a checkpoint was written), and the
        wall-clock duration of `model.fit` in seconds.
    """
    model = get_training_model(NETWORK)
    model.load_weights(INITIAL_WEIGHTS)
    model.compile(
        optimizer=Adam(3e-4),
        loss=CategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # FINAL_WEIGHTS is shared by all policies. Remove the previous policy's
    # checkpoint so it can never be reloaded as this run's best weights.
    if os.path.exists(FINAL_WEIGHTS):
        os.remove(FINAL_WEIGHTS)
    callbacks = [
        TerminateOnNaN(),
        ReduceLROnPlateau(factor=1/3.0),
        EarlyStopping(patience=40, restore_best_weights=True),
        ModelCheckpoint(
            filepath=FINAL_WEIGHTS, save_weights_only=True, save_best_only=True
        ),
    ]
    start = time.perf_counter()
    history = model.fit(
        train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=callbacks, verbose=1
    )
    latency = time.perf_counter() - start
    # In tf.keras 2.x, EarlyStopping only restores the best weights when it
    # actually stops training early. If all EPOCHS run, the model would keep its
    # last-epoch weights, so reload the best checkpoint (both monitor val_loss).
    if os.path.exists(FINAL_WEIGHTS):
        model.load_weights(FINAL_WEIGHTS)
    return [history, model, latency]

In [None]:
def plot_history(history, policy):
    """Plot training vs. validation loss and accuracy per epoch for one policy.

    Saves the figure to <FOLDER>/<dataset>_<network>_policy<policy>_plot.png,
    then shows it.
    """
    fig, axes = plt.subplots(1, 2, sharex=True, figsize=(10, 6))
    fig.suptitle("Policy {} for {} on {}".format(policy, DATASET_NAME, NETWORK))
    panels = (
        ("Training Loss", "loss", "val_loss", "Loss"),
        ("Training Accuracy", "accuracy", "val_accuracy", "Accuracy"),
    )
    for ax, (panel_title, train_key, val_key, y_label) in zip(axes, panels):
        ax.set_title(panel_title)
        ax.plot(history.history[train_key], "--")
        ax.plot(history.history[val_key], "--")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        ax.legend(["training", "validation"], loc="best")
    out_path = FOLDER + "/{}_{}_policy{}_plot.png".format(DATASET_NAME, NETWORK, policy)
    plt.savefig(out_path)
    plt.show()

In [None]:
def plot_combined_history(histories,training=True,validation=True):
    """Overlay the loss and accuracy curves of several policies in one figure.

    Args:
        histories: sequence of keras History objects. Each `history` dict must
            contain "loss", "val_loss", "accuracy", "val_accuracy" and a
            "policy" list; its first entry labels the curves.
        training: draw training curves (solid lines).
        validation: draw validation curves (dashed lines).

    Raises:
        ValueError: if both `training` and `validation` are False.

    The figure is saved to <FOLDER>/<dataset>_<network>_combined<suffix>_plot.png.
    The suffix is empty when both kinds of curves are drawn, "_training" for
    training only and "_validation" for validation only, so the different views
    do not overwrite each other. The figure is then shown.
    """
    if not (training or validation):
        raise ValueError("At least one of `training` or `validation` must be True")
    if training and not validation:
        title, suffix = "Training for all policies", "_training"
    elif validation and not training:
        title, suffix = "Validation for all policies", "_validation"
    else:
        title, suffix = "Training and Validation for all policies", ""

    fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, figsize=(10, 6))
    fig.suptitle(title)
    # (axis, panel title, history key, y-axis label)
    panels = (
        (ax1, "Model Loss", "loss", "Loss"),
        (ax2, "Model Accuracy", "accuracy", "Accuracy"),
    )
    for history in histories:
        policy = history.history["policy"][0]
        for ax, _, metric, _ in panels:
            if training:
                ax.plot(history.history[metric], label=f"train - policy {policy}")
            if validation:
                ax.plot(history.history[f"val_{metric}"], "--", label=f"val - policy {policy}")

    # Titles, labels and legends only need to be set once, after all curves are drawn.
    for ax, panel_title, _, y_label in panels:
        ax.set_title(panel_title)
        ax.set_xlabel("Epochs")
        ax.set_ylabel(y_label)
        if ax.lines:  # avoid matplotlib's "no artists" warning for empty input
            ax.legend()
    fig.savefig(FOLDER + "/{}_{}_combined{}_plot.png".format(DATASET_NAME, NETWORK, suffix))
    plt.show()

In [None]:
def get_performance_metrics(model, latency, policy):
    """Evaluate `model` on the global `test_ds`, print reports and return a metrics row.

    Prints sklearn's classification report and confusion matrix. Returns a dict
    ready for `pd.DataFrame` (one row): dataset and model ids, parameter count,
    policy id, accuracy, macro-averaged precision/recall/F1 and training time.
    """
    # test_ds is built without shuffling, so a second pass over it yields the
    # true labels in the same order as the predictions.
    y_pred = model.predict(test_ds, verbose=0).argmax(axis=1)
    y_true = np.concatenate([labels for _, labels in test_ds]).argmax(axis=1)

    report = classification_report(y_true, y_pred, output_dict=True)
    print("\n**CLASSIFICATION REPORT**")
    print(pd.DataFrame(report).transpose())

    conf_matrix = confusion_matrix(y_true, y_pred)
    print("\n**CONFUSION MATRIX**")
    print(pd.DataFrame(conf_matrix))

    macro = report["macro avg"]
    return {
        "dataset": DATASET_NAME,
        "model": NETWORK,
        "params": NUMBER_PARAMETERS,
        "policy": policy,
        "accuracy": [report["accuracy"]],
        "precision": [macro["precision"]],
        "recall": [macro["recall"]],
        "f1-score": [macro["f1-score"]],
        "training-time(s)": [latency],
    }

In [None]:
def save_model_metrics(metrics):
    """Write all per-policy metric dicts to one CSV (one row per policy).

    Args:
        metrics: iterable of dicts as returned by `get_performance_metrics`.

    The file is <FOLDER>/<dataset>_<network>_policy_results.csv.
    """
    file_name = FOLDER + "/{}_{}_policy_results.csv".format(DATASET_NAME, NETWORK)
    frames = [pd.DataFrame(metric) for metric in metrics]
    pd.concat(frames).to_csv(file_name, index=False)

In [None]:
def save_train_history(train_history):
    """Write the per-epoch histories of all policies to one CSV.

    Args:
        train_history: iterable of keras History objects. Each contributes one
            row per epoch; the "policy" column identifies its source.

    The file is <FOLDER>/<dataset>_<network>_history.csv.
    """
    file_name = FOLDER + "/{}_{}_history.csv".format(DATASET_NAME, NETWORK)
    frames = [pd.DataFrame(history.history) for history in train_history]
    pd.concat(frames).to_csv(file_name, index=False)

In [None]:
# Validation dataset
val_ds = prepare_dataset(valid_set)
# Testing dataset
test_ds = prepare_dataset(test_set)
# Training dataset (policy 1)
train_ds = prepare_dataset(train_set, shuffle=True)
# Augmented training dataset (policy 2)
train_ds_aug = prepare_dataset(train_set, shuffle=True, augment=True)
# GAN dataset (policy 3)
gan_ds = prepare_dataset(gan_set, shuffle=True)
# GAN plus basic training dataset (policy 4).
# train_ds and gan_ds are already batched, so sampling from them directly would
# pick whole batches: each batch would be 100% real or 100% GAN, and the
# backbone's BatchNorm layers would only ever see single-source batches.
# Instead, mix the two (shuffled) sources per example, then re-batch.
gan_train_ds = (
    tf.data.Dataset.sample_from_datasets(
        [train_ds.unbatch(), gan_ds.unbatch()], weights=[0.5, 0.5]
    )
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
# Augment GAN dataset (policy 5)
gan_ds_aug = prepare_dataset(gan_set, shuffle=True, augment=True)

In [None]:
# Preview one batch from each main data source.
for preview_ds, preview_title in (
    (train_ds, "Basic samples"),
    (train_ds_aug, "Augmented samples"),
    (gan_ds, "GAN samples"),
):
    visualize_dataset(preview_ds, title=preview_title)

In [None]:
# Weight files shared by all training runs.
INITIAL_WEIGHTS = FOLDER + f"/{DATASET_NAME}_{NETWORK}_initial_weights.h5"
FINAL_WEIGHTS = FOLDER + f"/{DATASET_NAME}_{NETWORK}_final_weights.h5"

# Build the network once to record its size and structure.
initial_model = get_training_model(NETWORK)
NUMBER_PARAMETERS = initial_model.count_params()
initial_model.summary()
# For reproducibility, save the initial weights: train_model reloads them, so
# every policy starts from identical weights.
initial_model.save_weights(INITIAL_WEIGHTS)
# Save a diagram of the network structure.
plot_model(
    initial_model,
    to_file=FOLDER + f"/{DATASET_NAME}_{NETWORK}_model.png",
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
%%time
train_history = []
train_metrics = []
# five policies
for i in range(2, 3):
    if i == 1:
        history, model, latency = train_model(train_ds)
    elif i == 2:
        history, model, latency = train_model(train_ds_aug)
    elif i == 3:
        history, model, latency = train_model(gan_ds)
    elif i == 4:
        history, model, latency = train_model(gan_train_ds)
    elif i == 5:
        history, model, latency = train_model(gan_ds_aug)
    else:
        raise NotImplementedError("Policy does not exist")
    # compute performance metrics
    metrics = get_performance_metrics(model, latency, i)
    # add policy number to history
    history.history["policy"] = [i] * len(history.history["loss"])
    # show plot for policy
    plot_history(history, i)
    # save history and metrics of policy to array
    train_history.append(history)
    train_metrics.append(metrics)
    # save model
    model.save(FOLDER + "/{}_{}_policy{}_model.h5".format(DATASET_NAME, NETWORK, i))

In [None]:
# Loss/accuracy curves of every trained policy: first training + validation,
# then validation only.
plot_combined_history(train_history, training=True, validation=True)
plot_combined_history(train_history, training=False, validation=True)

In [None]:
# Persist per-policy test metrics and per-epoch training histories as CSV files in FOLDER.
save_model_metrics(metrics=train_metrics)
save_train_history(train_history=train_history)