In [None]:
import ipywidgets

import matplotlib.pyplot as plt
import torch

# <center> Applying temperature + keeping only top K values</center>

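$T=\mbox{temperature}$: each logit $y_i$ is divided by $T$ before the softmax, $$\large P_i=\frac{e^{y_i/T}}{\sum_{k=1}^{n} e^{y_k/T}}$$
Top-$K$ filtering then keeps only the $K$ tokens with the largest logits. Every other token's logit is set to $-\infty$, so its $P_i=0$, and the denominator effectively sums over the kept tokens only. A small $T$ sharpens the distribution toward the arg-max token; a large $T$ flattens it toward uniform.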

In [None]:
def _temperature_top_k_probs(logits, temperature, top_k):
    """Turn raw logits into sampling probabilities via temperature scaling + top-K filtering.

    Parameters
    ----------
    logits : torch.Tensor
        1-D tensor of unnormalised scores, one per token.
    temperature : float
        Divisor applied to the logits before the softmax. Values <= 0 are
        treated as the T -> 0 limit, i.e. greedy decoding: all probability
        mass goes to the arg-max token.
    top_k : int
        Number of highest-scoring tokens to keep; clamped to ``[1, len(logits)]``.

    Returns
    -------
    torch.Tensor
        Probabilities with the same shape as ``logits``; every token outside
        the kept top K has probability exactly 0.
    """
    if logits.numel() == 0:
        return torch.empty_like(logits)
    top_k = max(1, min(top_k, logits.numel()))

    if temperature <= 0:
        # logits / 0 would give inf (or nan for a zero logit) and a nan softmax.
        # The T -> 0 limit of the softmax is a one-hot vector on the arg-max.
        probs = torch.zeros_like(logits)
        probs[torch.argmax(logits)] = 1.0
        return probs

    scaled = logits / temperature
    top_values, top_indices = torch.topk(scaled, top_k)
    # Write the K kept values into a -inf canvas so that exactly K tokens
    # survive, even when several logits tie with the K-th largest value
    # (a `< threshold` mask would keep all of the tied ones).
    filtered = torch.full_like(scaled, -torch.inf)
    filtered[top_indices] = top_values
    return torch.nn.functional.softmax(filtered, dim=0)


@ipywidgets.interact
def _(
    n_tokens=ipywidgets.IntSlider(min=4, max=30, value=8, continuous_update=False),
    random_state=ipywidgets.IntSlider(min=0, max=10, value=2, continuous_update=False),
    temperature=ipywidgets.FloatSlider(min=0, max=10, value=1, continuous_update=False),
    top_k=ipywidgets.IntSlider(min=1, max=30, value=8, continuous_update=False),
    ):
    """Plot softmax probabilities of random logits before and after temperature + top-K.

    Parameters (driven by the sliders)
    ----------------------------------
    n_tokens : int
        Number of tokens (length of the random logit vector).
    random_state : int
        Seed for the logits, so a given slider position always shows the same data.
    temperature : float
        Softmax temperature; 0 is shown as greedy (one-hot on the arg-max).
    top_k : int
        Number of highest-scoring tokens kept; values above ``n_tokens`` keep all tokens.
    """
    # Preparations: a local generator gives reproducible logits per seed
    # without reseeding torch's global RNG on every slider move.
    generator = torch.Generator().manual_seed(random_state)
    logits = 10 * torch.rand(n_tokens, generator=generator)

    # Plain softmax of the raw logits, for comparison.
    probs_orig = torch.nn.functional.softmax(logits, dim=0).numpy()
    # Temperature-scaled, top-K-filtered softmax.
    probs_new = _temperature_top_k_probs(logits, temperature, top_k).numpy()

    # Plotting: y-axis is shared, so the [0, 1] limit applies to both panels.
    fig, (ax_orig, ax_new) = plt.subplots(1, 2, sharey=True, figsize=(10, 2), dpi=100)
    x = range(n_tokens)

    ax_orig.bar(x, probs_orig)
    ax_orig.set_ylim((0, 1))
    ax_orig.set_title("Original")
    ax_orig.set_ylabel("probability")

    ax_new.bar(x, probs_new)
    ax_new.set_title("Temperature + top K")

    for ax in (ax_orig, ax_new):
        ax.set_xlabel("token index")

    plt.show()