In [None]:
%%bash
# Install torchkit only when the google/colab package directory exists
# (presumably: only when running on Google Colab). `stat` fails when the path
# is missing; `!` inverts that failure into success, so `exit` runs and the
# pip install below is skipped.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
pip install git+https://github.com/davidbau/torchkit

In [None]:
from torchkit.plotwidget import PlotWidget
from torchkit.labwidget import Range, Textbox
from torchkit import show
import torch

def gaussian_model(x, mean, variance):
    """Return the normal density N(mean, variance) evaluated at ``x``.

    Computes exp(-(x - mean)**2 / (2 * variance)) / sqrt(2 * pi * variance)
    elementwise, with the usual torch broadcasting between the arguments.

    Args:
        x: point(s) at which to evaluate the density.
        mean: mean of the distribution (tensor or Python number).
        variance: variance (not standard deviation) of the distribution
            (tensor or Python number); must be positive for a finite result.

    Returns:
        A tensor of density values.
    """
    # as_tensor lets callers pass plain numbers; torch.sqrt requires a tensor.
    mean = torch.as_tensor(mean)
    variance = torch.as_tensor(variance)
    # The factor of 2 in the exponent is what makes this a normalized density
    # whose variance is actually ``variance``.
    return torch.exp(-(x - mean) ** 2 / (2 * variance)) / torch.sqrt(2 * torch.pi * variance)

def make_gaussian_model(mean, variance):
    """Return a one-argument density function for N(mean, variance).

    The returned callable maps ``x`` to ``gaussian_model(x, mean, variance)``.

    Args:
        mean: mean of the distribution (Python number or tensor).
        variance: variance of the distribution (Python number or tensor).

    Returns:
        A callable ``density(x)``.
    """
    # Convert once here rather than on every call; as_tensor (unlike
    # torch.tensor) also accepts existing tensors without a copy warning.
    mean_t = torch.as_tensor(mean)
    variance_t = torch.as_tensor(variance)

    def density(x):
        return gaussian_model(x, mean_t, variance_t)

    return density

def discriminator(x, p_true, p_gen):
    """Return D(x) = p_t(x) / (p_t(x) + p_g(x)).

    This is the optimal GAN discriminator for a fixed generator.

    Args:
        x: point(s) at which to evaluate.
        p_true: callable giving the target density at ``x``.
        p_gen: callable giving the generator density at ``x``.

    Returns:
        For non-negative densities, values in [0, 1]. The result is NaN
        wherever both densities are exactly zero (e.g. after floating-point
        underflow), since D is undefined there.
    """
    # Evaluate the target density once instead of calling p_true twice.
    true_density = p_true(x)
    return true_density / (true_density + p_gen(x))

def value_fn(x, p_true, p_gen):
    """Return the pointwise GAN value V(x) under the optimal discriminator.

    V(x) = p_t(x) * log D(x) + p_g(x) * log(1 - D(x)),
    where D(x) = p_t(x) / (p_t(x) + p_g(x)).

    The same quantity is computed as p_t*log(p_t/s) + p_g*log(p_g/s) with
    s = p_t + p_g. ``torch.xlogy`` makes a density of exactly zero contribute
    0 (the limit of p*log p) instead of 0 * -inf = NaN. Forming p_g/s directly
    also avoids D rounding to 1.0 when p_g is tiny, which would otherwise give
    log(0) = -inf.

    Args:
        x: point(s) at which to evaluate.
        p_true: callable giving the target density at ``x``; assumed to
            return a floating-point tensor.
        p_gen: callable giving the generator density at ``x``; assumed to
            return a floating-point tensor.

    Returns:
        A tensor of (non-positive, for valid densities) value contributions.
    """
    true_density = p_true(x)
    gen_density = p_gen(x)
    total = true_density + gen_density
    # Where both densities underflow to 0, avoid 0/0 = NaN; each term then
    # becomes xlogy(0, 0) == 0.
    total = total.clamp_min(torch.finfo(total.dtype).tiny)
    return (torch.xlogy(true_density, true_density / total)
            + torch.xlogy(gen_density, gen_density / total))

def redraw_rule(fig, gen_mean=0.0, gen_variance=1.0, true_mean=3.0, true_variance=0.5):
    """Redraw the GAN illustration on ``fig`` for the given Gaussians.

    On the primary axis this draws:
      * the generator density p_G,
      * the target density p_t,
      * the discriminator D.

    On a secondary y-axis (created on the first call) it shades the pointwise
    value V.

    Presumably invoked by ``PlotWidget`` with the current values of its bound
    properties.

    Args:
        fig: matplotlib figure. It must hold exactly one axis on the first call
            (the twin axis is then added here), and these two axes afterwards.
        gen_mean: mean of the generator Gaussian.
        gen_variance: variance of the generator Gaussian.
        true_mean: mean of the target Gaussian.
        true_variance: variance of the target Gaussian.
    """
    # NOTE(review): values bound to Textbox widgets may arrive as strings;
    # float() accepts numbers and numeric strings alike — confirm in torchkit.
    gen_mean, gen_variance = float(gen_mean), float(gen_variance)
    true_mean, true_variance = float(true_mean), float(true_variance)

    if len(fig.axes) == 1:
        # Create the secondary y-axis once; later redraws reuse it.
        fig.axes[0].twinx()
    ax, ax2 = fig.axes

    x_range = torch.arange(-3, 6, 0.1)
    true_model = make_gaussian_model(true_mean, true_variance)
    gen_model = make_gaussian_model(gen_mean, gen_variance)

    ax2.clear()
    ax2.set_ylim(-1.05, 0.05)
    ax2.fill_between(x_range, value_fn(x_range, true_model, gen_model),
                     alpha=0.5, color='red', lw=0, label='$V$ value')

    ax.clear()
    ax.set_title('The GAN game: try to minimize the maximum of the discriminator')
    ax.plot(x_range, gen_model(x_range), linewidth=3, label='$p_G$ generator distribution')
    ax.plot(x_range, true_model(x_range), label='$p_t$ target distribution')
    ax.plot(x_range, discriminator(x_range, true_model, gen_model), label='$D$ discriminator')
    # Fixed limits keep D's [0, 1] range readable; tall densities are clipped.
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper left')
    ax2.legend(loc='center right')

# Interactive plot: each generator parameter gets a slider and a textbox,
# both bound to the same PlotWidget property, so (presumably) editing either
# control triggers redraw_rule with the new value.
plot = PlotWidget(redraw_rule)

# Generator mean controls.
mean_slider = Range(
    value=plot.prop('gen_mean'),
    min=-5,
    max=5,
    step=0.01,
)
mean_input = Textbox(value=plot.prop('gen_mean'))

# Generator variance controls (kept strictly positive).
var_slider = Range(
    value=plot.prop('gen_variance'),
    min=0.1,
    max=1.9,
    step=0.01,
)
var_input = Textbox(value=plot.prop('gen_variance'))

# One row: the plot, followed by labeled controls.
show([[
    plot,
    'Generator Mean', mean_slider, mean_input,
    'Generator Variance', var_slider, var_input,
]])
