In [None]:
from baulab import show, Range, Numberbox, PlotWidget, set_requires_grad
import torch, numpy

# The Generative Modeling Setting

This notebook sets up the generative modeling game to help build intuition about the problem setting. We define a true distribution, which is fixed but treated as unknown (you only get to see samples from it). We also define a generated distribution that is explicitly under your control.

We plot both: the **true** distribution is in orange.

Your **generated** distribution is in blue.

In [None]:
# This notebook only evaluates models (no training), so autograd is turned off globally.
torch.set_grad_enabled(False)

num_samples = 1000

# Fixed latent noise Z that is fed to the user-controlled (generated) model.
z = torch.from_numpy(numpy.random.RandomState(1).randn(num_samples, 2)).float()

# The "true" model is a randomly initialized linear map applied to its own seeded
# Gaussian noise. Seed torch first so nn.Linear's random init (and therefore the
# orange target distribution) is identical on every Restart & Run All.
torch.manual_seed(0)
true_model = torch.nn.Linear(2, 2)
data = true_model(torch.from_numpy(numpy.random.RandomState(2).randn(num_samples, 2)).float())

def draw_scatter(fig, A=1.0, B=0.0, C=0.0, D=1.0, X=0.0, Y=0.0, title='Gaussian', minibatch=num_samples):
    '''Redraw ``fig`` with generated samples overlaid on the true samples.

    The generated (blue) samples are the fixed latent noise ``z`` pushed through
    the affine map ``W z + V`` with ``W = [[A, B], [C, D]]`` and ``V = [X, Y]``,
    computed as ``z @ W.T + V`` (the ``torch.nn.Linear`` convention). The true
    (orange) samples are the precomputed ``data``.

    Args:
        fig: matplotlib figure holding exactly one axes; it is cleared and redrawn.
        A, B, C, D: entries of the 2x2 weight matrix ``W``, row by row.
        X, Y: entries of the offset vector ``V``.
        title: title for the axes.
        minibatch: number of points to draw from each distribution. Clamped to
            ``[0, num_samples]``; a fresh random subset is drawn on every call.
    '''
    [ax] = fig.axes
    ax.clear()

    # Apply the affine map directly instead of writing into a fresh nn.Linear's
    # parameters: no in-place writes to leaf tensors (which fail when autograd is
    # enabled) and no consumption of torch's global RNG for a throwaway init.
    weight = torch.tensor([[A, B], [C, D]], dtype=z.dtype)
    bias = torch.tensor([X, Y], dtype=z.dtype)
    generated = (z @ weight.T + bias).numpy()
    true_points = data.detach().numpy()

    # Widget values may arrive as floats; clamp to a valid sample count.
    size = max(0, min(int(minibatch), num_samples))
    # Sample without replacement so a minibatch of k shows k distinct points.
    # The same indices are used for both clouds; z and data come from
    # independently seeded noise, so this does not couple them.
    batch = numpy.random.choice(num_samples, size, replace=False)

    ax.scatter(generated[batch, 0], generated[batch, 1], s=1, label='Generated samples')
    ax.scatter(true_points[batch, 0], true_points[batch, 1], s=1, label='True samples')
    ax.set_title(title)
    ax.set_aspect(1.0)
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.legend(loc='upper left')

# Text shown next to the sliders. The layout of W must match the weight matrix
# draw_scatter builds, [[A, B], [C, D]], with offset V = [X, Y].
model_description = '''
X = W Z + V

    [ A  B ]      [ X ]
W = [      ]  V = [   ]
    [ C  D ]      [ Y ]
'''

## Interactive generative modeling

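Your model is defined by $X = WZ + V$, where $Z$ is a standard normal random vector. $W = \begin{bmatrix} A & B \\ C & D \end{bmatrix}$ is a matrix and $V = \begin{bmatrix} X \\ Y \end{bmatrix}$ is a vector, and you control their entries with the sliders.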

The game: by looking at the blue and orange dots, adjust the parameters of $W$ and $V$ until the two distributions match.

In [None]:
# Two plots of the same scene: `before` keeps draw_scatter's default (identity)
# parameters, while `after` is wired to the controls built below.
before = PlotWidget(draw_scatter, figsize=(5, 5), dpi=100, title="Before transform")
after = PlotWidget(draw_scatter, figsize=(5, 5), dpi=100, title="After transform")

# One row per model parameter: its label, a slider, and a numeric box, both
# bound to the same property of the `after` plot.
parameter_rows = []
for param in 'ABCDXY':
    parameter_rows.append([
        param,
        show.style(flex=5),
        Range(value=after.prop(param), min=-2.0, max=2.0, step=0.01),
        show.style(width=50),
        Numberbox(value=after.prop(param)),
    ])

# Control for how many points of each distribution are drawn.
minibatch_row = [
    show.style(textAlign='right', flex=3),
    'Minibatch size',
    Numberbox(value=after.prop('minibatch')),
]

show([
    ['Train a generative model by hand: match the orange distribution.'],
    [before, after],
    [show.style(textAlign='center'), model_description, parameter_rows],
    minibatch_row,
])

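Now try playing the game with a small minibatch size (for example, 10 or 20). It gets harder. Each redraw shows a different random subset of points, so it becomes difficult to tell real differences between the distributions from sampling noise.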

## Things to think about

What did you do to solve the game?

 * You had to estimate the divergence between the two distributions.
 * You had to decide which noise to ignore, and what details to pay attention to.
 * You had to decide which direction to push each slider to improve the match.
 * You had to decide when to stop...

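Also, we had to choose a model family (here, an affine map with six parameters) that is expressive enough to get close to the true distribution. Here that is easy, because the true samples were themselves generated by a linear map applied to Gaussian noise.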