In [None]:
%%capture
!pip install -U git+https://github.com/davidbau/baukit@main#egg=baukit

In [None]:
from baukit import PlotWidget, Range, Checkbox, show
import math, torch

# Sweep range for the logit gap z_1 - z_0 shown on every plot's x-axis.
xmin = -6.0
xmax = 6.0
# Two-class logits on a 201-point grid: row 0 pins z_0 at zero, row 1 sweeps
# z_1 from xmin to xmax, so row 1 is also the logit difference z_1 - z_0.
z = torch.stack((torch.zeros(201), torch.linspace(xmin, xmax, 201)), dim=0)
# Class probabilities at every grid point; each column sums to one.
p = z.softmax(dim=0)

def compare_loss(fig, y1=0.5, dokl=True, dose=True, doce=True, dol1=True):
    """Plot two-class loss curves against the logit gap z_1 - z_0.

    Reads the module-level grid ``z`` (row 1 is the swept logit z_1, row 0 is
    zero) and its softmax ``p``, and draws each selected loss as a curve.

    Args:
        fig: figure object exposing ``.axes`` with exactly one axes
            (presumably matplotlib, supplied by PlotWidget); that axes is
            cleared and redrawn.
        y1: target probability of class 1; class 0 gets ``1 - y1``. Any value
            in the closed interval [0, 1] is accepted, including one-hot
            targets.
        dokl: draw the KL divergence KL(y || p).
        dose: draw the squared error ||p - y||^2.
        doce: draw the cross-entropy CE(y, p).
        dol1: draw the expected L1 error against a one-hot target sampled
            from y.

    Raises:
        ValueError: if ``y1`` lies outside [0, 1].
    """
    if not 0.0 <= y1 <= 1.0:
        raise ValueError(f'target y1 must be in [0, 1], got {y1!r}')
    [ax1] = fig.axes
    y0 = 1.0 - y1
    ce = -(y0 * torch.log(p[0]) + y1 * torch.log(p[1]))
    # KL(y || p) = CE(y, p) - H(y). Apply the 0 * log 0 = 0 convention so that
    # one-hot targets (y1 == 0 or y1 == 1) do not evaluate math.log(0).
    neg_entropy = sum(y * math.log(y) for y in (y0, y1) if y > 0)
    kl = ce + neg_entropy
    se = ((p - torch.tensor([y0, y1])[:, None])**2).sum(0)
    # If the target is a one-hot vector drawn with P(class 1) = y1, its L1
    # error is 2 * p[other class]; averaging over that draw gives this curve.
    sampled_l1 = (2*y0*p[1] + 2*y1*p[0])
    ax1.clear()
    ax1.set_ylim(0, 3.0)
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Difference between logits $z_1 - z_0$')
    ax1.set_title(f'Loss curve on softmax when target $y_1={y1:.3f}$')

    if dokl: ax1.plot(z[1], kl, label='KL', color='b')
    if dose: ax1.plot(z[1], se, label='SE', color='r')
    if doce: ax1.plot(z[1], ce, label='CE', color='g', linestyle='dashed', alpha=0.6)
    if dol1: ax1.plot(z[1], sampled_l1, label='L1', color='orange', linestyle='dotted', alpha=0.7)
    # Only add a legend when at least one labelled curve was drawn.
    if dokl or dose or doce or dol1: ax1.legend()

def compare_grad(fig, y1=0.5, dokl=True, dose=True):
    """Plot the gradients of the KL and squared-error losses w.r.t. z_1.

    Reads the module-level grid ``z`` and its softmax ``p``, the same grid
    used for the loss curves.

    Args:
        fig: figure object exposing ``.axes`` with exactly one axes
            (presumably matplotlib, supplied by PlotWidget); that axes is
            cleared and redrawn.
        y1: target probability of class 1.
        dokl: draw dKL/dz_1, which is identical to dCE/dz_1.
        dose: draw dSE/dz_1.
    """
    [ax1] = fig.axes
    # CE and KL differ only by the constant -H(y), so both have the softmax
    # residual p_1 - y_1 as their derivative with respect to z_1.
    dkl_dz1 = p[1] - y1
    # With two classes, SE = 2 (p_1 - y_1)^2 and dp_1/dz_1 = p_1 p_0, so
    # dSE/dz_1 = 4 (p_1 - y_1) p_1 p_0. The p_1 p_0 factor goes to zero as
    # |z_1 - z_0| grows, flattening this gradient far from the origin.
    dse_dz1 = 4 * (p[1] - y1) * p[1] * p[0]
    ax1.clear()
    ax1.set_ylim(-0.7, 0.7)
    ax1.set_xlim(xmin, xmax)
    ax1.set_ylabel('Gradient')
    ax1.set_xlabel('Difference between logits $z_1 - z_0$')
    ax1.set_title(f'Gradient of loss with respect to $z_1$ when $y_1={y1:.3f}$')

    if dokl:
        ax1.plot(z[1], dkl_dz1, color='b', label=r'$\frac{\partial \mathrm{KL}}{\partial z_1}$' +
            r'=$\frac{\partial \mathrm{CE}}{\partial z_1}$')
    if dose:
        ax1.plot(z[1], dse_dz1, color='r', label=r'$\frac{\partial \mathrm{SE}}{\partial z_1}$')
    # Reference line at zero gradient.
    ax1.axhline(0, color='gray', linewidth=0.5)
    if dokl or dose:
        ax1.legend(loc='upper left')

# Target probability control for y1. Its bounds keep y1 strictly inside
# (0, 1), away from the log(0) endpoints of the KL term.
rw = Range(min=0.001, max=0.999, step=0.001, value=0.5)
# Toggles for which curves are drawn; CE and L1 start unchecked.
bkl = Checkbox('KL', value=True)
bce = Checkbox('CE', value=False)
bse = Checkbox('SE', value=True)
bl1 = Checkbox('L1', value=False)
# Each plotting argument is bound to a control through `.prop('value')`.
# Presumably baukit re-invokes the plotting function whenever a bound value
# changes -- confirm against baukit's PlotWidget docs. `bbox_inches='tight'`
# looks like a rendering option forwarded to matplotlib -- TODO confirm.
ploss = PlotWidget(compare_loss, y1=rw.prop('value'),
                   dokl=bkl.prop('value'), dose=bse.prop('value'),
                   doce=bce.prop('value'), dol1=bl1.prop('value'),
                   bbox_inches='tight')
# The gradient plot has no CE toggle: its KL curve is labelled as equal to
# the CE gradient.
pgrad = PlotWidget(compare_grad, y1=rw.prop('value'),
                   dokl=bkl.prop('value'), dose=bse.prop('value'),
                   bbox_inches='tight')
# Layout (presumably one row per inner list): the label, slider and
# checkboxes on top, with the loss and gradient plots side by side below.
# `show.style(flex=12)` appears to style the element that follows it (the
# slider) -- verify in baukit.
show([[show.raw_html('<div>target y<sub>1</sub> = </div>'),
                       show.style(flex=12), rw,
                       'Include:', bkl, bce, bl1, bse],
                      [ploss, pgrad]])