# In [None]:
from torchkit.plotwidget import PlotWidget
from torchkit.labwidget import Range
from torchkit import show
import torch

def gaussian_model(x, mean, variance):
    """Evaluate the normal probability density N(mean, variance) at ``x``.

    Computes, elementwise with broadcasting::

        exp(-(x - mean)**2 / (2 * variance)) / sqrt(2 * pi * variance)

    Args:
        x: point(s) at which to evaluate the density (tensor or number).
        mean: mean of the distribution (tensor or number).
        variance: variance (sigma squared) of the distribution; should be > 0.

    Returns:
        A tensor of density values broadcast from ``x``, ``mean`` and
        ``variance``.
    """
    # as_tensor passes existing tensors through unchanged and lets plain
    # Python numbers work too (torch.sqrt rejects a bare float).
    x = torch.as_tensor(x)
    mean = torch.as_tensor(mean)
    variance = torch.as_tensor(variance)
    # The factor 2 in the exponent is what makes this consistent with the
    # sqrt(2*pi*variance) normaliser, so the curve integrates to 1.
    return torch.exp(-(x - mean) ** 2 / (2 * variance)) / torch.sqrt(2 * torch.pi * variance)

def make_gaussian_model(mean, variance):
    """Return a one-argument function ``f(x) = gaussian_model(x, mean, variance)``.

    Args:
        mean: mean of the distribution (number or tensor).
        variance: variance of the distribution (number or tensor).

    Returns:
        A callable taking ``x`` and returning the density values as a tensor.
    """
    # Convert the parameters once rather than on every call of the closure.
    # torch.as_tensor also avoids the copy-construct warning torch.tensor
    # emits when handed an existing tensor.
    mean_t = torch.as_tensor(mean)
    variance_t = torch.as_tensor(variance)
    return lambda x: gaussian_model(x, mean_t, variance_t)

def discriminator(x, p_true, p_gen):
    """Evaluate D(x) = p_true(x) / (p_true(x) + p_gen(x)).

    Args:
        x: point(s) at which to evaluate the discriminator.
        p_true: callable returning the true-data density at ``x`` as a tensor.
        p_gen: callable returning the generator density at ``x`` as a tensor.

    Returns:
        Tensor of discriminator values (in [0, 1] for non-negative
        densities). Where both densities are exactly zero, e.g. after
        underflow far in the tails, the ratio would be 0/0. There 0.5 is
        returned instead of NaN, because neither distribution is favoured.
    """
    # Evaluate each density once instead of calling p_true twice.
    true_density = p_true(x)
    gen_density = p_gen(x)
    total = true_density + gen_density
    # Masked-out entries may hold NaN from 0/0, but torch.where discards them.
    return torch.where(total > 0, true_density / total, torch.full_like(total, 0.5))

def redraw_rule(fig, gen_mean=0.0, gen_variance=1.0, true_mean=3.0, true_variance=0.5):
    """Redraw the GAN-game plot for the given distribution parameters.

    Clears the figure's single axes and draws:
      * the true density ``p_t``,
      * the generator density ``p_G``,
      * the discriminator ``D = p_t / (p_t + p_G)``,
      * a dashed horizontal line at ``max(D)`` over the plotted range.

    Args:
        fig: matplotlib-style figure. The ``[ax] = fig.axes`` unpacking
            raises ValueError unless it has exactly one axes.
        gen_mean: mean of the generator Gaussian.
        gen_variance: variance of the generator Gaussian.
        true_mean: mean of the true-data Gaussian.
        true_variance: variance of the true-data Gaussian.
    """
    [ax] = fig.axes
    x_range = torch.arange(-3, 6, 0.1)
    true_model = make_gaussian_model(true_mean, true_variance)
    gen_model = make_gaussian_model(gen_mean, gen_variance)
    # Compute the discriminator curve once and reuse it for its maximum.
    disc = discriminator(x_range, true_model, gen_model)
    ax.clear()
    ax.set_title('The GAN game: try to minimize the maximum of the discriminator')
    ax.plot(x_range, true_model(x_range), label='$p_t$ true distribution')
    ax.plot(x_range, gen_model(x_range), label='$p_G$ generator distribution')
    ax.plot(x_range, disc, label='$D$ discriminator')
    # .item() hands axhline a plain Python float rather than a 0-dim tensor.
    ax.axhline(disc.max().item(), ls='--', color='green', label="$D$ max")
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper left')


# Interactive demo: two sliders for the generator's mean and variance. Each
# slider's 'value' property is passed to PlotWidget as the gen_mean /
# gen_variance argument of redraw_rule (presumably a live binding, so moving a
# slider triggers a redraw; confirm against torchkit's PlotWidget docs).
mean_slider = Range(value=0, min=-5, max=5, step=0.01)
var_slider = Range(value=1.0, min=0.1, max=1.9, step=0.01)
plot = PlotWidget(redraw_rule, gen_mean=mean_slider.prop('value'), gen_variance=var_slider.prop('value'))
show([[plot, 'Generator Mean', mean_slider, 'Generator Variance', var_slider]])


    