## The softmax function

To convert the model to multiclass classification, we need to make a few changes to the metrics and training parameters. Previously, we used the sigmoid function to convert logits to probabilities, then rounded those probabilities to get a predicted label. However, now that there are multiple possible classes, we need to use the generalization of the sigmoid function, known as the softmax function.
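One way to see why softmax is called a generalization of sigmoid: a two-class softmax over the logits `[x, 0]` gives the first class the same probability that sigmoid gives `x`. The sketch below checks this numerically in plain Python. The helper names `sigmoid` and `softmax` and the sample logit are illustrative assumptions, not code from this course.

```python
import math

def sigmoid(x):
    # Standard logistic function: 1 / (1 + e^(-x))
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits):
    # Subtracting the max logit avoids overflow and does not change the result
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

logit = 0.7  # arbitrary example value
p_sigmoid = sigmoid(logit)
p_softmax = softmax([logit, 0.0])[0]  # second class fixed at logit 0

print(p_sigmoid, p_softmax)
```

The two printed values should agree up to floating-point rounding.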

The softmax function takes in a vector of numbers (the logits for each class) and converts them to a probability distribution: every probability lies between 0 and 1, and the probabilities across all the classes sum to 1. Each class's probability is the exponential of its logit divided by the sum of the exponentials of all the classes' logits, so a larger logit always produces a larger probability.
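Written as a formula, the standard definition of softmax is shown below. The symbols are notation introduced here for illustration: $z$ is the vector of logits, $K$ is the number of classes, and $z_i$ is the logit for class $i$.

$$
\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \qquad i = 1, \dots, K
$$

Because every $e^{z_j}$ is positive, each output is positive, and dividing by the sum of all the exponentials makes the outputs add up to 1.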

In [11]:
import tensorflow as tf

with tf.compat.v1.Session() as sess:
    t = tf.constant([[0.4, -0.8, 1.3],
                     [0.2, -1.2, -0.4]])
    softmax_t = tf.nn.softmax(t)

    print('{}\n'.format(repr(sess.run(t))))
    print('{}\n'.format(repr(sess.run(softmax_t))))

array([[ 0.4, -0.8,  1.3],
       [ 0.2, -1.2, -0.4]], dtype=float32)

array([[0.2659011 , 0.08008787, 0.65401113],
       [0.5569763 , 0.13734867, 0.30567506]], dtype=float32)
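The probabilities printed above can be reproduced without TensorFlow. The following is a minimal pure-Python sketch that applies softmax row by row to the same logits. The helper name `softmax` is an assumption for illustration, and subtracting the row maximum is a common numerical-stability trick rather than something the text above prescribes.

```python
import math

def softmax(logits):
    # Shift by the max logit for numerical stability; the output is unchanged
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Same logits as the TensorFlow example, one row per data observation
rows = [[0.4, -0.8, 1.3],
        [0.2, -1.2, -0.4]]

probs = [softmax(row) for row in rows]
for p in probs:
    print([round(v, 4) for v in p], 'sum =', round(sum(p), 6))
```

Each row should match the corresponding row of the TensorFlow output up to float32 rounding, and each row should sum to 1.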



## Predictions

Our model's prediction now becomes the class with the highest probability. Since we label each class with a unique index, the prediction is the index at which the probability is largest. TensorFlow provides a function for this, tf.math.argmax (also available under the shorter alias tf.argmax, which the code below uses).

The function takes in a required input tensor, as well as a few keyword arguments. One of the more important keyword arguments is axis, which tells the function which dimension to search along for the maximum. Setting axis=-1 uses the final dimension. For a 2-D tensor with one row per observation and one column per class, this returns, for each row, the column index of the largest value.
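To make the role of axis concrete, here is a minimal pure-Python sketch of what an argmax over the last axis and over the first axis computes for a 2-D input. The helper names `argmax_last_axis` and `argmax_first_axis` and the example scores are illustrative assumptions, not part of TensorFlow.

```python
def argmax_last_axis(matrix):
    # Like axis=-1 on a 2-D input: for each row, the column index of the largest entry
    return [max(range(len(row)), key=row.__getitem__) for row in matrix]

def argmax_first_axis(matrix):
    # Like axis=0 on a 2-D input: for each column, the row index of the largest entry
    columns = list(zip(*matrix))
    return [max(range(len(col)), key=col.__getitem__) for col in columns]

# Hypothetical scores: 2 observations (rows) by 3 classes (columns)
scores = [[0.1, 0.6, 0.3],
          [0.5, 0.2, 0.4]]

per_row = argmax_last_axis(scores)      # one index per observation
per_column = argmax_first_axis(scores)  # one index per class

print(per_row)
print(per_column)
```

For classification, the per-row result is the one we want, since it yields one predicted class index per observation.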

In [12]:
with tf.compat.v1.Session() as sess:
    probs = tf.constant([[0.4, 0.3, 0.3],
                         [0.2, 0.7, 0.1]])
    preds = tf.argmax(probs, axis=-1)

    print('{}\n'.format(repr(sess.run(probs))))
    print('{}\n'.format(repr(sess.run(preds))))

array([[0.4, 0.3, 0.3],
       [0.2, 0.7, 0.1]], dtype=float32)

array([0, 1], dtype=int64)
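Putting the two steps together, a prediction goes from logits to probabilities to a class index. The sketch below does this in plain Python for the logits from the softmax example earlier, and also checks a useful property: because softmax preserves the ordering of the logits, taking the argmax of the raw logits picks the same class. The helper names `softmax` and `argmax` are assumptions for illustration.

```python
import math

def softmax(logits):
    # Shift by the max logit for numerical stability; the output is unchanged
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    # Index of the largest value in a flat list
    return max(range(len(values)), key=values.__getitem__)

# Logits from the softmax example: 2 observations, 3 classes
logits = [[0.4, -0.8, 1.3],
          [0.2, -1.2, -0.4]]

preds_from_probs = [argmax(softmax(row)) for row in logits]
preds_from_logits = [argmax(row) for row in logits]

print(preds_from_probs)
print(preds_from_logits)
```

Both lists should be identical, which is why the softmax step can be skipped when only the predicted label is needed and not the probabilities themselves.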

