In [4]:
#imports
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact,fixed
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display


In [13]:
#softmax def
# inspired by the numerically stable version in
# https://jaykmody.com/blog/stable-softmax/
def softmax(logits, temperature):
    """Return the temperature-scaled softmax of ``logits``.

    Computes ``exp(x_i / T) / sum_j exp(x_j / T)`` over *all* elements of
    ``logits``. The largest logit is subtracted before exponentiating; this
    does not change the result but keeps ``np.exp`` from overflowing.

    Parameters
    ----------
    logits : array_like
        Unnormalized scores. Normalization runs over every element, so the
        returned array sums to 1 as a whole.
    temperature : float
        Strictly positive scaling factor ``T``. ``T = 1`` gives the ordinary
        softmax; larger values flatten the output, smaller values sharpen it.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``logits`` whose entries sum to 1.

    Raises
    ------
    ValueError
        If ``logits`` is empty or ``temperature`` is not strictly positive
        (zero would produce NaN, a negative value would invert the ranking).
    """
    scores = np.asarray(logits)
    if scores.size == 0:
        raise ValueError("logits must contain at least one value")
    # `not all(> 0)` also rejects NaN, which a plain `<= 0` test would let through.
    if not np.all(np.asarray(temperature) > 0):
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    # Shift by the max so the largest exponent is exactly 0 (no overflow).
    e_x = np.exp((scores - scores.max()) / temperature)
    return e_x / e_x.sum()

def plot_softmax(temperature):
    """Draw the current logits next to their temperature-scaled softmax.

    Reads the module-level ``logits`` array (set by ``get_logits``) and
    creates a new two-panel figure: raw logits on the left, the output of
    ``softmax(logits, temperature)`` on the right.

    Parameters
    ----------
    temperature : float
        Temperature forwarded to ``softmax``.
    """
    global logits  # only read here; get_logits is what assigns it
    positions = np.arange(len(logits))
    probabilities = softmax(logits, temperature)

    fig, ax = plt.subplots(ncols=2, figsize=(16, 9))
    ax[0].barh(y=positions, width=logits, color='b', label='logits', alpha=0.4)
    ax[1].barh(y=positions, width=probabilities, color='g', label='softmax', alpha=0.4)

    ax[0].set_title("Logits")
    ax[0].set_xlabel("Logit value")
    ax[1].set_title("Softmax")
    ax[1].set_xlabel("Probability")
    # Softmax outputs always lie in [0, 1]; pinning the axis keeps bar lengths
    # comparable between slider positions instead of rescaling on every redraw.
    ax[1].set_xlim(0, 1)

    for panel in ax:
        panel.set_ylabel("Item index")
        panel.set_yticks(positions)  # one integer tick per item, no fractional ticks
def get_logits(n):
    """Sample ``n`` fresh logits and store them in the global ``logits``.

    Values are drawn uniformly from [0, 1) using NumPy's global random state.
    The new array replaces the module-level ``logits`` and is printed to
    stdout; nothing is returned.
    """
    global logits
    sampled = np.random.uniform(low=0.0, high=1.0, size=n)
    logits = sampled
    print(sampled)



# Softmax!
The good ol' normalizer for classification, defined as
$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} $$

The typical use case is at the end of a network, to convert logits (unnormalized predictions from a model) into a probability distribution, where every entry is strictly positive and the entries sum to 1.

There's a "spiked" version of softmax that has an added temperature parameter T

$$\text{softmax}(x)_i = \frac{e^{x_i/T}}{\sum_j e^{x_j/T}} $$

The intuition: as the temperature gets hotter (as T approaches infinity), the output becomes more uniform, with all values approaching the same probability. As it gets cooler (as T approaches 0), the output becomes more deterministic: softmax turns into a one-hot argmax that puts all the weight on the largest logit (shared equally among ties).

When T = 1, we get back to "regular" softmax

Visually...

In [12]:
# Number of logits to sample; each change calls get_logits and draws a new vector.
dropdown = widgets.Dropdown(
    options=list(range(2,11)),
    value=2,
    description='Total items',
    disabled=False,
    layout={'width': 'max-content'}
)
# Log-scale slider: min/max are exponents of `base`, so T spans 1e-4 .. 1e4.
temperature_widget = widgets.FloatLogSlider(
    value=1,
    base=10,
    min=-4,
    max=4,
    step=0.1,
    description='Temperature (Log Scale)',
    # The default label width truncates this long description; let it size to its text.
    style={'description_width': 'initial'}
)

# Order matters: interact() calls its function once immediately, and
# plot_softmax reads the global `logits` that get_logits creates.
interact(get_logits,n=dropdown)

# NOTE(review): changing the dropdown resamples `logits` but does not redraw
# the plot; the figure refreshes the next time the temperature slider moves.
# Trailing ';' keeps the returned function's repr out of the cell output.
interact(plot_softmax,temperature=temperature_widget);

interactive(children=(Dropdown(description='Total items', layout=Layout(width='max-content'), options=(2, 3, 4…

interactive(children=(FloatLogSlider(value=1.0, description='Temperature (Log Scale)', min=-4.0), Output()), _…

<function __main__.plot_softmax(temperature)>