# Logits to Probabilities Demo

In [5]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Function to calculate softmax
def softmax(logits, axis=None):
    """Return the softmax of *logits*.

    Parameters
    ----------
    logits : array_like
        Raw scores; a plain list such as ``[2.0, 1.0, 0.1]`` is accepted.
    axis : int or None, optional
        Axis along which to normalize. ``None`` (the default) normalizes over
        every element, which is the original behavior. Pass ``axis=-1`` to
        treat a 2-D input as a batch of independent logit rows.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as *logits*. Its entries are non-negative
        and sum to 1 along *axis* (or over the whole array when ``axis`` is
        ``None``).
    """
    logits = np.asarray(logits)
    # Softmax is invariant to adding a constant, so subtracting the maximum
    # changes nothing mathematically but keeps np.exp from overflowing on
    # large logits. keepdims=True lets the result broadcast back for any axis.
    e_logits = np.exp(logits - np.max(logits, axis=axis, keepdims=True))
    return e_logits / e_logits.sum(axis=axis, keepdims=True)

# Predefined example logit vectors, one per dropdown entry (keys are the
# button labels).
logits_set = {
    "Set 1": [2.0, 1.0, 0.1],
    "Set 2": [1.0, 2.0, 3.0],
    "Set 3": [1.0, 1.0, 1.0],  # equal logits -> uniform probabilities
    "Set 4": [3.0, 1.0, 0.5],
    "Set 5": [0.5, 0.2, 0.1],
}

# Show the first entry initially rather than a hard-coded key, so the starting
# bars always correspond to the first dropdown button (Plotly's default active
# button) even if the sets are renamed or reordered.
initial_logits = next(iter(logits_set.values()))
initial_probs = softmax(initial_logits)

fig = make_subplots(rows=1, cols=2, subplot_titles=("Logits", "Softmax Probabilities"))

# Left subplot: the raw logits, one category per position ('Logit 1', ...).
logits_bar = go.Bar(
    x=[f'Logit {i}' for i in range(1, len(initial_logits) + 1)],
    y=initial_logits,
    name="Logits",
)
fig.add_trace(logits_bar, row=1, col=1)

# Right subplot: the corresponding softmax probabilities ('Class 1', ...).
softmax_bar = go.Bar(
    x=[f'Class {i}' for i in range(1, len(initial_probs) + 1)],
    y=initial_probs,
    name="Softmax Probabilities",
)
fig.add_trace(softmax_bar, row=1, col=2)

# Function to create update args for a given set of logits
def create_update_args(logits):
    """Build the dropdown-button ``args`` payload that displays *logits*.

    Returns a one-element list holding a trace-update dict. Each per-trace
    list follows the order in which the traces were added to the figure:
    the logits bar chart first, then the softmax bar chart.

    Both ``x`` and ``y`` are replaced. The category labels are rebuilt in the
    same ``'Logit i'`` / ``'Class i'`` format used for the initial traces, so
    a set whose length differs from the initial one still lines up with its
    bars instead of keeping stale labels.
    """
    positions = range(1, len(logits) + 1)
    logit_labels = [f'Logit {i}' for i in positions]
    class_labels = [f'Class {i}' for i in positions]
    return [
        {
            "x": [logit_labels, class_labels],
            "y": [logits, softmax(logits)],
        }
    ]

# Update layout: title, a fixed probability axis, and a dropdown with one
# button per entry of logits_set that swaps the data in both bar charts.
fig.update_layout(
    title="Interactive Softmax Calculation",
    showlegend=False,
    # Set y-axis limit for softmax plot: probabilities lie in [0, 1], so a
    # fixed range keeps bar heights comparable when switching sets.
    yaxis2=dict(range=[0, 1]),
    updatemenus=[
        {
            # Label each button with its dictionary key so the dropdown can
            # never drift out of sync with logits_set.
            "buttons": [
                {
                    "args": create_update_args(logits),
                    "label": label,
                    "method": "update",
                }
                for label, logits in logits_set.items()
            ],
            "direction": "down",
            "showactive": True,
            # Centered horizontally, just above the plotting area.
            "x": 0.5,
            "xanchor": "center",
            "y": 1.15,
            "yanchor": "top",
        }
    ],
)

fig.show()
