In [None]:
from tensorflow.keras.models import load_model
from keras.datasets import mnist, cifar10, fashion_mnist
from keras.applications import ResNet50, MobileNetV2, MobileNetV3Small
from keras.utils import to_categorical
import tensorflow as tf

import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

import pandas as pd
import numpy as np
import pathlib

from gand.config import MLConfig
from gand.data import data
from gand.models import models, architecture
from gand.visualisation import visualise
from gand.preprocessing import utils

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [None]:
# Per-dataset metadata shared by the plotting cells below:
#   class_names -- display name for each integer class id (list position == class id)
#   input_shape -- channels-last shape of a single image
DATASET_INFO = dict(
    mnist=dict(
        class_names=list(map(str, range(10))),
        input_shape=(28, 28, 1),
    ),
    fashion_mnist=dict(
        class_names=['t-shirt/top', 'trouser', 'pullover', 'dress', 'coat',
                     'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'],
        input_shape=(28, 28, 1),
    ),
    cifar10=dict(
        class_names=['airplane', 'car', 'bird', 'cat', 'deer',
                     'dog', 'frog', 'horse', 'ship', 'truck'],
        input_shape=(32, 32, 3),
    ),
)

In [None]:
# Sanity check: is the `gans-Mobilenetv3small` CIFAR-10 checkpoint from the E_020 run on disk?
import os.path

CHECKPOINT_PATH = 'reports/reports/cifar10/models/E_020/gans-Mobilenetv3small.h5'
os.path.isfile(CHECKPOINT_PATH)

In [None]:
def confusion_matrix_plot(path=None, data=None,
                          dataset_name='mnist',
                          name='training',
                          model=None, figsize=(5, 5),
                          fontsize=20, savefig=False):
    """Plot (and optionally save) the confusion matrix of ``model`` on ``data``.

    Parameters
    ----------
    path : pathlib.Path or str, optional
        Directory the figure is written to; required when ``savefig`` is True.
    data : tuple
        ``(X, y)``: ``X`` is passed to ``model.predict``; ``y`` holds the true
        labels, either one-hot encoded (N, C) or integer class ids (N,) / (N, 1).
    dataset_name : str
        Key into ``DATASET_INFO``; selects the class names used as tick labels.
    name : str
        File stem of the saved figure (``<path>/<name>.png``).
    model : object
        Anything with a Keras-style ``predict(X, verbose=0)`` returning per-class scores.
    figsize : tuple
        Matplotlib figure size in inches.
    fontsize : int
        Size of axis labels and title; cell annotations and tick labels use ``fontsize - 5``.
    savefig : bool
        Write the figure to ``path`` before showing it.

    Raises
    ------
    ValueError
        If ``model`` or ``data`` is missing, or ``savefig`` is True without a ``path``.
    """
    if model is None or data is None:
        raise ValueError('confusion_matrix_plot needs both `model` and `data`.')
    if savefig and path is None:
        raise ValueError('`path` is required when savefig=True.')

    class_names = DATASET_INFO[dataset_name]['class_names']
    small_fontsize = fontsize - 5

    X, y = data
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] > 1:
        y_true = np.argmax(y, axis=1)   # one-hot targets, e.g. from to_categorical
    else:
        y_true = y.ravel()              # integer class ids
    y_pred = np.argmax(model.predict(X, verbose=0), axis=1)

    # Pin the label set so the matrix is always len(class_names) square. Without it,
    # a class missing from both y_true and y_pred drops a row/column and the tick
    # labels below would no longer line up with the cells.
    result = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

    sns.set_theme(style='darkgrid')
    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(result, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax,
                annot_kws={'size': small_fontsize},
                xticklabels=class_names, yticklabels=class_names)

    ax.set_xlabel('Predicted', fontsize=fontsize)
    ax.set_ylabel('Actual', fontsize=fontsize)
    ax.set_title(f'Confusion Matrix for {dataset_name}', fontsize=fontsize)
    ax.tick_params(axis='x', labelrotation=90)
    ax.tick_params(axis='y', labelrotation=0)
    ax.tick_params(axis='both', which='major', labelsize=small_fontsize)

    fig.tight_layout()
    if savefig:
        fig.savefig(Path(path).joinpath(f'{name}.png'))
    plt.show()
    plt.close(fig)

In [None]:
# Confusion matrix of one trained model on the full test split of one dataset.
dataset = cifar10
dataset_name = 'cifar10'
data_variant = 'normal'   # training-data variant in file names: 'normal', 'imbalanced' or 'gans'
model_name = 'deepModel'
EPOCHS = 100

class_names = DATASET_INFO[dataset_name]['class_names']

(_, _), (X_test, y_test) = dataset.load_data()

if X_test.ndim == 3:
    # Greyscale arrays come as (N, H, W); add the channel axis matching input_shape (H, W, 1).
    X_test = np.expand_dims(X_test, axis=-1)

X_test = X_test.astype('float32') / 255.0
y_test = to_categorical(y_test, num_classes=len(class_names))

run_dir = Path.cwd() / 'reports' / 'reports' / dataset_name
model_path = run_dir / 'models' / f'E_{EPOCHS:03d}' / f'{data_variant}-{model_name}.h5'
print(model_path)
model = load_model(model_path)

conf_path = run_dir / 'figures' / f'E_{EPOCHS:03d}' / 'confusion_matrix'
conf_path.mkdir(parents=True, exist_ok=True)

confusion_matrix_plot(model=model, data=(X_test, y_test),
                      dataset_name=dataset_name,
                      figsize=(12, 12), fontsize=30,
                      name=f'{data_variant}-{model_name}', path=conf_path, savefig=True)

# Plotting Predictions

In [None]:
# Pick one random test image per class for the prediction plot in the next cell.
#   sample_images / sample_labels -- model-ready inputs (scaled to [0, 1]) and one-hot targets
#   saved_copy                    -- the same images exactly as loaded (unscaled, no added axis), for display
dataset = cifar10
NUM_CLASSES = 10
SEED = 42  # fixed so Restart & Run All draws the same samples

(_, _), (X_raw, y_raw) = dataset.load_data()

X_test = X_raw
if X_test.ndim == 3:
    # Greyscale arrays come as (N, H, W); the models expect a trailing channel axis.
    X_test = np.expand_dims(X_test, axis=-1)
X_test = X_test.astype('float32') / 255.0
y_test = to_categorical(y_raw, num_classes=NUM_CLASSES)

rng = np.random.default_rng(SEED)
test_class_ids = np.argmax(y_test, axis=-1)   # computed once, not once per class
sample_idx = [rng.choice(np.flatnonzero(test_class_ids == c)) for c in range(NUM_CLASSES)]

sample_images = X_test[sample_idx]
sample_labels = y_test[sample_idx]
# Index the raw array from the single load above instead of reloading the dataset
# (which also used to overwrite the prepared X_test / y_test with raw data).
saved_copy = X_raw[sample_idx]

In [None]:
# Grid of predictions: each sampled image next to the model's class-score bar chart.
# Blue = correct prediction / true class, red = wrong prediction.
# NOTE: dataset_name must match the dataset sampled in the previous cell.
dataset_name = 'cifar10'
data_variant = 'gans'   # training-data variant used in file names
model_name = 'deepModel'
EPOCHS = 100
FONTSIZE = 20

class_names = DATASET_INFO[dataset_name]['class_names']
run_dir = Path.cwd() / 'reports' / 'reports' / dataset_name

model_path = run_dir / 'models' / f'E_{EPOCHS:03d}' / f'{data_variant}-{model_name}.h5'
print(model_path)
model = load_model(model_path)

# One batched forward pass instead of one predict() call per image.
pred_values = model.predict(sample_images, verbose=0)
pred_labels = np.argmax(pred_values, axis=1)
true_labels = np.argmax(sample_labels, axis=1)

n_samples = len(sample_images)
n_rows = -(-n_samples // 2)  # two samples per row, each using an image axis + a bar axis
fig, axs = plt.subplots(n_rows, 4, figsize=(20, 4 * n_rows))
axs = axs.flatten()

for k in range(n_samples):
    img_ax, bar_ax = axs[2 * k], axs[2 * k + 1]
    probs = pred_values[k]
    pred_label, true_label = pred_labels[k], true_labels[k]

    img_ax.imshow(saved_copy[k], cmap='gray_r')
    img_ax.set_xticks([])
    img_ax.set_yticks([])
    img_ax.set_xlabel("{}: {:2.0f}% ({})".format(class_names[pred_label],
                                                 100 * np.max(probs),
                                                 class_names[true_label]),
                      color='blue' if pred_label == true_label else 'red',
                      fontsize=FONTSIZE)

    n_classes = len(probs)
    bar_ax.set_xticks(range(n_classes))
    bar_ax.set_yticks([])
    bar_ax.set_ylim([0, 1])
    bars = bar_ax.bar(range(n_classes), probs, color="#777777")
    bars[pred_label].set_color('red')
    bars[true_label].set_color('blue')   # applied last so a correct prediction shows blue

# Blank out the spare pair of axes when the sample count is odd.
for spare_ax in axs[2 * n_samples:]:
    spare_ax.axis('off')

pred_path = run_dir / 'figures' / f'E_{EPOCHS:03d}' / 'predictions'
pred_path.mkdir(parents=True, exist_ok=True)

fig.tight_layout()
fig.savefig(pred_path / f'{data_variant}-{model_name}.png', bbox_inches='tight')
plt.show()
plt.close(fig)

In [None]:
import pandas as pd

In [None]:
def _integer_epoch_ticks(epochs, num_ticks=10):
    """Return up to ``num_ticks`` evenly spread, distinct integer epoch indices in ``[0, epochs - 1]``."""
    return np.unique(np.linspace(0, epochs - 1, num_ticks).round().astype(int))


def metric_plot(path: Path = None, history=None,
                dataset_name=None, epochs=None,
                fontsize=20, figsize=(10, 8),
                savefig=True, show_fig=False):
    """Plot training vs. validation loss (top) and accuracy (bottom) per epoch.

    Parameters
    ----------
    path : pathlib.Path or str
        Output file for the figure; required when ``savefig`` is True.
    history : mapping or pandas.DataFrame
        Must provide ``loss``, ``val_loss``, ``accuracy`` and ``val_accuracy``
        sequences (e.g. the CSV training log read with ``pd.read_csv``).
    dataset_name : str
        Shown in the figure title.
    epochs : int, optional
        Number of epochs to plot; defaults to ``len(history['loss'])``.
    fontsize : int
        Base font size; tick labels and final-value annotations are smaller.
    figsize : tuple
        Matplotlib figure size in inches.
    savefig : bool
        Write the figure to ``path``.
    show_fig : bool
        Display the figure. It is closed afterwards either way.

    Raises
    ------
    ValueError
        If ``savefig`` is True and ``path`` is None.
    """
    if savefig and path is None:
        raise ValueError('`path` is required when savefig=True.')
    if epochs is None:
        epochs = len(history['loss'])

    colors = ('#1f77b4', '#ff7f0e')   # (training, validation)
    legend_loc = {'loss': 'upper right', 'accuracy': 'lower right'}
    # Axes-fraction y positions of the final-value annotations: (training, validation).
    text_y = {'loss': (0.90, 0.80), 'accuracy': (0.15, 0.05)}

    sns.set_theme(style='darkgrid')
    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=figsize)
    X = np.arange(epochs)

    for ax, key in zip(axs, ('loss', 'accuracy')):
        train_values = np.asarray(history[key])[:epochs]
        val_values = np.asarray(history[f'val_{key}'])[:epochs]

        sns.lineplot(x=X, y=train_values, ax=ax,
                     linewidth=2.5, label='Training', color=colors[0])
        sns.lineplot(x=X, y=val_values, ax=ax,
                     linewidth=2.5, linestyle='--', label='Testing', color=colors[1])

        ax.legend(loc=legend_loc[key], fontsize=fontsize)
        ax.set_ylabel(key.title(), fontsize=fontsize)
        ax.tick_params(axis='y', labelsize=fontsize - 6)
        for spine in ('top', 'right'):
            ax.spines[spine].set_visible(False)

        ax.text(0.55, text_y[key][0], '{}: {:.4f}'.format(key, train_values[-1]),
                transform=ax.transAxes, ha='right', fontsize=fontsize - 3, color=colors[0])
        ax.text(0.55, text_y[key][1], 'val_{}: {:.4f}'.format(key, val_values[-1]),
                transform=ax.transAxes, ha='right', fontsize=fontsize - 3, color=colors[1])

    # Figure-level decorations are set once, not once per subplot.
    axs[0].set_title(f'Loss and Accuracy Curves ({dataset_name})', fontsize=fontsize)
    axs[1].set_xlabel('Epochs', fontsize=fontsize)
    # Ticks must sit on whole epochs: raw linspace positions are fractional, and
    # their truncated integer labels would be drawn at the wrong x positions.
    xticks = _integer_epoch_ticks(epochs)
    axs[1].set_xticks(xticks)
    axs[1].set_xticklabels([str(t) for t in xticks], fontsize=fontsize - 6)

    fig.tight_layout()
    if savefig:
        fig.savefig(path)
    if show_fig:
        plt.show()
    plt.close(fig)

In [None]:
# Loss/accuracy curves for a single training run.
dataset_name = 'mnist'
data_variant = 'gans'   # training-data variant used in file names
model_name = 'Mobilenetv3small'
EPOCHS = 20

run_dir = Path.cwd() / 'reports' / 'reports' / dataset_name
history = pd.read_csv(run_dir / 'history' / f'E_{EPOCHS:03d}' / f'{data_variant}-{model_name}.csv')

vis_path = run_dir / 'figures' / f'E_{EPOCHS:03d}'
vis_path.mkdir(parents=True, exist_ok=True)
print(vis_path)

# Use the variables above rather than repeating literals, so title and epochs stay in sync.
metric_plot(show_fig=True, history=history, dataset_name=dataset_name,
            savefig=True, epochs=EPOCHS, fontsize=25,
            path=vis_path / f'{data_variant}-{model_name}.png')

In [None]:
# Loss/accuracy curves for every dataset x data variant x model (E_020 runs).
for dataset_name in ['mnist', 'fashion_mnist', 'cifar10']:
    run_dir = Path.cwd() / 'reports' / 'reports' / dataset_name
    for data_variant in ['normal', 'imbalanced', 'gans']:
        for model_name in ['Resnet50', 'Mobilenetv3small']:
            for EPOCHS in [20]:
                epoch_dir = f'E_{EPOCHS:03d}'
                history = pd.read_csv(run_dir / 'history' / epoch_dir / f'{data_variant}-{model_name}.csv')

                vis_path = run_dir / 'figures' / epoch_dir
                vis_path.mkdir(parents=True, exist_ok=True)
                print(vis_path)

                # Pass the loop's dataset and epoch count so each figure is titled and scaled correctly.
                metric_plot(show_fig=True, history=history, dataset_name=dataset_name,
                            savefig=True, epochs=EPOCHS, fontsize=25,
                            path=vis_path / f'{data_variant}-{model_name}.png')

In [None]:
# Compare validation loss of the two E_020 models on one dataset / data variant.
# The deepModel run is logged under E_100 (100 epochs) and is not overlaid here.
dataset_name = 'mnist'
data_variant = 'normal'   # training-data variant used in file names
EPOCHS = 20

history_dir = Path.cwd() / 'reports' / 'reports' / dataset_name / 'history' / f'E_{EPOCHS:03d}'

fig, ax = plt.subplots()
for model_name in ['Resnet50', 'Mobilenetv3small']:
    history = pd.read_csv(history_dir / f'{data_variant}-{model_name}.csv')
    ax.plot(history['val_loss'], marker='o', label=model_name)

ax.set(title=f'Validation loss ({dataset_name}, {data_variant})',
       xlabel='Epoch', ylabel='Validation loss')
ax.legend();



In [None]:
# Loss/accuracy curves for deepModel on every dataset x data variant (E_100 runs).
for dataset_name in ['mnist', 'fashion_mnist', 'cifar10']:
    run_dir = Path.cwd() / 'reports' / 'reports' / dataset_name
    for data_variant in ['normal', 'imbalanced', 'gans']:
        for model_name in ['deepModel']:
            for EPOCHS in [100]:
                epoch_dir = f'E_{EPOCHS:03d}'
                history = pd.read_csv(run_dir / 'history' / epoch_dir / f'{data_variant}-{model_name}.csv')

                vis_path = run_dir / 'figures' / epoch_dir
                vis_path.mkdir(parents=True, exist_ok=True)
                print(vis_path)

                # Pass the loop's dataset so each figure is titled with the dataset actually plotted.
                metric_plot(show_fig=True, history=history, dataset_name=dataset_name,
                            savefig=True, epochs=EPOCHS, fontsize=25,
                            path=vis_path / f'{data_variant}-{model_name}.png')