**Instructors**: Prof. Keith Chugg (chugg@usc.edu) & Prof. B. Keith Jenkins (jenkins@sipi.usc.edu)

**Notebook**: Written by Prof. Keith Chugg.

# Multiclass Classifier for MNIST (and Fashion MNIST) Using MLP in TensorFlow
In this notebook, we will use TensorFlow to train an ANN/MLP for the MNIST/FMNIST datasets we previously explored with an MSE classifier.

In [11]:
import numpy as np 
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.utils import plot_model
from tensorflow.keras.datasets import fashion_mnist, mnist
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import load_model
from tensorflow.keras import regularizers


## Accessing the Data
First, let's load the MNIST or FMNIST data.  Previously, we used PyTorch to access these datasets, but TensorFlow (via Keras) also has them built in, so this data should already be familiar.

In [12]:
USE_FASHION_MNIST = False

#### get the dataset

if USE_FASHION_MNIST:
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    tag_name = 'FashionMNIST'
    label_names = ["top", "trouser", "pullover", "dress", "coat", "sandal", "shirt", "sneaker", "bag", "ankle boot"]
else:
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    tag_name = 'MNIST'
    label_names = [f'{i}' for i in np.unique(test_labels)]  # np.unique returns the labels in sorted order

# train_images.shape is (60000, 28, 28)
# test_images.shape is (10000, 28, 28)
num_pixels = 28 * 28 
train_images = train_images.reshape( (60000, num_pixels) ).astype(np.float32) / 255.0
test_images = test_images.reshape( (10000, num_pixels) ).astype(np.float32)  / 255.0
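
As a small illustration of the flatten-and-scale step above, here is a sketch on a synthetic batch of images. The array `fake_images` is a made-up stand-in, not real MNIST data, and using `-1` in `reshape` is just one way to avoid hard-coding the number of pixels.

```python
import numpy as np

# Synthetic stand-in for a small batch of 28x28 uint8 images (not real MNIST data)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each image to a vector of length 784 and scale pixel values to [0, 1]
flat = fake_images.reshape(fake_images.shape[0], -1).astype(np.float32) / 255.0

print(flat.shape, flat.dtype)
print(flat.min() >= 0.0, flat.max() <= 1.0)
```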

## ANN/MLP Model Definition
In TensorFlow, we first need to define the model.  Below, we define an ANN that takes a vector of length 784 as input, has one hidden layer, and ends with a softmax output layer.

In [13]:
## Some hyper-parameters for our model
reg_val = 0.0001
hidden_nodes = 48


# this uses the Functional API for defining the model
nnet_inputs = Input(shape=(num_pixels,), name='images')
z = Dense(hidden_nodes, activation='relu', kernel_regularizer=regularizers.l2(reg_val), bias_regularizer=regularizers.l2(reg_val), name='hidden')(nnet_inputs)
z = Dense(10, activation='softmax', kernel_regularizer=regularizers.l2(reg_val), bias_regularizer=regularizers.l2(reg_val), name='output')(z)

our_first_model = Model(inputs=nnet_inputs, outputs=z)

#this will print a summary of the model to the screen
our_first_model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 images (InputLayer)         [(None, 784)]             0         
                                                                 
 hidden (Dense)              (None, 48)                37680     
                                                                 
 output (Dense)              (None, 10)                490       
                                                                 
Total params: 38,170
Trainable params: 38,170
Non-trainable params: 0
_________________________________________________________________


Notice that calling `our_first_model.summary()` printed a model summary, which includes the layers and the number of trainable parameters.
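
As a quick sanity check, the parameter counts in the summary can be reproduced by hand: a dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. The helper `dense_param_count` below is a hypothetical name used only for this illustration.

```python
def dense_param_count(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

hidden_params = dense_param_count(28 * 28, 48)   # 784 inputs -> 48 hidden units
output_params = dense_param_count(48, 10)        # 48 hidden units -> 10 classes

print(hidden_params, output_params, hidden_params + output_params)
# 37680 490 38170
```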

In [14]:
#this will produce a diagram of the model -- requires pydot and graphviz installed
plot_model(our_first_model, to_file='our_first_model.png', show_shapes=True, show_layer_names=True)


You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


In [None]:
# w_norms = np.linalg.norm(W_hat, axis=0)
w_norms = np.linalg.norm(W_hat[1:], axis=0)

C = W_hat.shape[1]

plt.figure()
plt.stem(np.arange(C), w_norms)
plt.grid(linestyle=':')
plt.xlabel('class index (m)')
plt.ylabel(r'$|| {\bf w}_m ||$')


## Evaluating and visualizing the MSE Classifier


Now that we have trained our multiclass linear regression system, we need to use it to make decisions.  I wrote a function that takes in the weight vectors (as a 2D array) and the augmented data matrix, and evaluates the MSE multiclass classifier.  It also produces histograms of $g_m({\bf x})$ conditioned on the true class.  This allows us to see the misclassification rate when a given class is true.  It also shows us visually which classes are easily confused.  

I took some time to try to document this reasonably well because you can use this on your homework!
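
Before the full plotting function, here is a minimal sketch of just the decision rule and the error rates, run on synthetic data. The toy sizes, random weights, and labels are assumptions for illustration only, and the bias column is assumed to come first in the augmented data matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, C = 200, 5, 3   # toy sizes, not the MNIST dimensions

# Synthetic stand-ins for the augmented data matrix, weights, and labels
X_aug = np.hstack([np.ones((N, 1)), rng.normal(size=(N, D))])  # (N, D + 1), bias column first
W = rng.normal(size=(D + 1, C))                                # (D + 1, C)
y = np.arange(N) % C                                           # true classes 0..C-1, all present

g = X_aug @ W                   # g[n, m] = w_m^T x_n
y_hat = np.argmax(g, axis=1)    # decide the class with the largest discriminant

overall_error = 100 * np.mean(y_hat != y)
conditional_error = np.array([100 * np.mean(y_hat[y == c] != c) for c in range(C)])

print(f'overall: {overall_error:0.2f}%')
print('conditional:', np.round(conditional_error, 2))
```

Since the weights here are random, the error rates are only meaningful as a check of the bookkeeping: the overall rate is the average of the conditional rates weighted by class frequency.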

In [None]:

def plot_multiclass_histograms(X_aug, W, y, fname, norm_W=False, scale=1, class_names=None):
    """
    Keith Chugg, USC, 2023.

    X_aug: shape: (N, D + 1).  Augmented data matrix
    W: shape: (D + 1, C).  The matrix of augmented weight-vectors.  W.T[m] is the weight vector for class m
    y: length N array with int values with correct classes.  Classes are indexed from 0 up.
    fname: a pdf of the histograms will be saved to filename fname
    norm_W: boolean.  If True, the w-vectors for each class are normalized.
    scale: use scale < 1 to make the figure smaller, >1 to make it bigger
    class_names: pass a list of text, descriptive names for the classes.  

    This function takes in the weight vectors for a linear classifier and applies the "maximum value method" -- i.e., 
    it computes the argmax_m g_m(x), where g_m(x) = w_m^T x, to find the decision. For each true class c, it plots the histograms 
    of g_m(x) for every m when class c is true.  This gives insight into which classes are most easily confused -- i.e., similar to a 
    confusion matrix, but with more information.  

    Returns: the overall misclassification error percentage
    """
    if norm_W:
        W = W / np.linalg.norm(W[1:], axis=0)
    y_soft = X_aug @ W
    N, C = y_soft.shape
    y_hard = np.argmax(y_soft, axis=1)
    error_percent = 100 * np.sum(y != y_hard) / len(y) 

    fig, ax = plt.subplots(C, sharex=True, figsize=(12 * scale, 4 * C * scale))
    y_soft_cs = []
    conditional_error_rate = np.zeros(C)
    if class_names is None:
        class_names = [f'Class {i}' for i in range(C)]
    for c_true in range(C):
        y_soft_cs.append(X_aug[y == c_true] @ W)
        y_hard_c = np.argmax(y_soft_cs[c_true], axis=1)
        conditional_error_rate[c_true] = 100 * np.sum(y_hard_c != c_true) / len(y_hard_c)
    for c_true in range(C):
        peak = -100
        for c in range(C):
            hc = ax[c_true].hist(y_soft_cs[c_true].T[c], bins = 100, alpha=0.4, label=class_names[c])
            peak = np.maximum(np.max(hc[0]), peak)
            ax[c_true].legend()
            ax[c_true].grid(linestyle=':')
        ax[c_true].text(0, 0.9 * peak, f'True: {class_names[c_true]}\nConditional Error Rate = {conditional_error_rate[c_true] : 0.2f}%')
    if norm_W:
        ax[C-1].set_xlabel(r'normalized discriminant function $g_m(x) / || {\bf w} ||$')
    else:
        ax[C-1].set_xlabel(r'discriminant function $g_m(x)$')
    plt.savefig(fname, bbox_inches='tight')
    return error_percent

In [None]:
error_rate = plot_multiclass_histograms(x_train_aug, W_hat, y_train, f'img/hist_{tag_name}.pdf', scale=1, class_names=label_names)
print(f'\nOverall Misclassification Rate: {error_rate : 0.2f}%')

That is not too bad!  For FashionMNIST, we can see that when "shirt" is true, it is easily confused with many other classes!

Also, if you look above, class 5 has the w-vector with the smallest norm (for FashionMNIST).  Would it make more sense to maximize ${\bf w}_m^T {\bf x} / \| {\bf w}_m \|$?  Note that this is the signed distance to the decision boundary for the implicit one-vs-rest test.  
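
To see how normalizing by $\| {\bf w}_m \|$ can change the decision, here is a toy sketch with made-up weights for two classes. The numbers are assumptions chosen only so that the two rules disagree. As in `W_hat[1:]`, row 0 of `W` is taken to be the bias and is excluded from the norm.

```python
import numpy as np

# Toy augmented weights: row 0 is the bias, rows 1: are the non-augmented weights
W = np.array([[0.0, 0.0],
              [3.0, 0.0],
              [0.0, 1.0]])          # class 0 has norm 3, class 1 has norm 1
x_aug = np.array([1.0, 1.0, 2.0])   # the leading 1 is the augmentation

g = x_aug @ W                                # discriminants: [3., 2.]
g_norm = g / np.linalg.norm(W[1:], axis=0)   # normalized:    [1., 2.]

print(np.argmax(g), np.argmax(g_norm))
# 0 1
```

The unnormalized rule picks class 0 because its weight vector is larger, while the normalized rule picks class 1 because the point is farther from that class's decision boundary.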

Let's try that out.  It's built into the `plot_multiclass_histograms()` function via the optional `norm_W` argument.

In [None]:
error_rate = plot_multiclass_histograms(x_train_aug, W_hat, y_train, f'img/hist_{tag_name}_norm_W.pdf', norm_W=True, scale=1, class_names=label_names)  # separate file so the earlier pdf is not overwritten
print(f'\nOverall Misclassification Rate: {error_rate : 0.2f}%')

Note that using `norm_W` as `True` or `False` provides two different decision (or fusion) rules for combining the OvR classifiers.  Which do you think performs better?

We can also argue that there is a real issue with an MSE classifier.

For example, suppose the target for the true class is $+1$.  When the true class is $k$ and the discriminant function satisfies $g_k({\bf x}) \gg +1$, this should be a high-confidence decision that class $k$ is true.  However, the squared-error loss actually penalizes this: e.g., $g_k({\bf x}) = +10$ incurs a large penalty, while $g_k({\bf x}) = +1$ incurs no penalty.  
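
The size of this penalty is easy to tabulate. The sketch below evaluates the squared error against a target of $+1$ for a few illustrative values of the discriminant. The specific values in `g_values` are arbitrary choices for this example.

```python
import numpy as np

target = 1.0                                      # target for the true class
g_values = np.array([-1.0, 0.0, 1.0, 2.0, 10.0])  # illustrative discriminant values
squared_error = (g_values - target) ** 2

for g, e in zip(g_values, squared_error):
    print(f'g = {g:5.1f}   squared error = {e:5.1f}')
```

A confident, correct output of $+10$ is penalized with a squared error of 81, far more than the clearly wrong output of $-1$, which incurs only 4.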

This can be addressed with multiclass logistic regression, where the max operation over all $g_m({\bf x})$ is replaced by a multiclass softmax function.  It does not have a closed-form solution and requires an iterative method such as gradient descent to solve. 
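
As a rough sketch of the idea, the softmax and the cross-entropy loss commonly paired with it can be written in a few lines of NumPy. The function names `softmax` and `cross_entropy` and the example scores are assumptions for illustration. This is not a training routine, only a look at how the loss responds to a large discriminant for the true class.

```python
import numpy as np

def softmax(g):
    """Row-wise softmax; subtracting the row max is for numerical stability."""
    z = g - np.max(g, axis=1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=1, keepdims=True)

def cross_entropy(p, y):
    """Mean negative log-probability assigned to the true class."""
    return -np.mean(np.log(p[np.arange(len(y)), y]))

g = np.array([[1.0, 0.0, -1.0],
              [10.0, 0.0, -1.0]])   # second row is much more confident in class 0
y = np.array([0, 0])                # class 0 is true for both rows

p = softmax(g)
per_sample_loss = -np.log(p[np.arange(len(y)), y])

print(np.round(p, 4))
print(per_sample_loss)
print(cross_entropy(p, y))
```

Here the more confident row has the smaller loss, so a large discriminant for the true class is rewarded rather than penalized.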