In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')

import numpy as np
from bokeh.io import show, output_notebook
from bokeh.plotting import figure
from torch import optim, nn
import torch

from plotting import plot_model_outputs

import beer

output_notebook()

In [137]:
# Toy dataset: a 2-D Gaussian whose second coordinate is bent by a quadratic
# function of the first one, i.e. points scattered around the parabola
# y = (x - 3)**2 + 2 in the raw (un-normalised) coordinates.
seed = 0  # fixed so that Restart & Run All reproduces the same dataset
nb_points = 100
gen_mean = np.array([3., 2.])
gen_cov = np.array([[2., 0.], [0., .2]])

rng = np.random.default_rng(seed)
raw_data = rng.multivariate_normal(gen_mean, gen_cov, size=nb_points)
# Bend the second coordinate: y <- y + (x - mean_x)**2.
raw_data[:, 1] += (raw_data[:, 0] - gen_mean[0]) ** 2

# Standardise each dimension to zero mean and unit variance. `mean` / `var`
# are the *sample* statistics, distinct from the generating `gen_mean`.
mean = raw_data.mean(axis=0)
var = raw_data.var(axis=0)
data = (raw_data - mean) / np.sqrt(var)

fig = figure(
    title='Non-Linear subspace',
    width=400,
    height=400,
    x_axis_label='x (standardised)',
    y_axis_label='y (standardised)',
)
# `scatter` draws screen-sized markers (default marker: circle).
fig.scatter(data[:, 0], data[:, 1])
show(fig)

In [159]:
# VAE hyper-parameters.
obs_dim = 2
latent_dim = 2
hidden_dim = 10    # width of the hidden layers in encoder and decoder
nb_samples = 10
init_logvar = -.5  # initial value of the log-variance output biases

torch.manual_seed(0)  # reproducible weight initialisation

# Encoder: observations -> diagonal Gaussian over the *latent* space, so its
# output dimension is latent_dim (it only worked before because
# obs_dim == latent_dim).
# NOTE(review): assumes MLPNormalDiag(structure, hidden_size, out_dim) — the
# second argument matches the last hidden layer's width; confirm in beer.
enc_struct = nn.Sequential(
    nn.Linear(obs_dim, hidden_dim),
    nn.Tanh(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.Tanh()
)
encoder = beer.models.MLPNormalDiag(enc_struct, hidden_dim, latent_dim)

# Decoder: latent samples -> diagonal Gaussian over the observation space.
dec_struct = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU()
)
decoder = beer.models.MLPNormalDiag(dec_struct, hidden_dim, obs_dim)

# Prior over the latent space; its dimension must follow latent_dim.
latent_model = beer.models.NaturalIsotropicGaussian(latent_dim)

model = beer.models.VAE(encoder, decoder, latent_model, nb_samples)

# Start both networks with a small predicted variance. The existing bias
# parameters are filled in place, so their shape always matches each layer's
# output size (one entry per output dimension) instead of a hard-coded size.
# NOTE(review): assumes `hid_to_logvar` is an nn.Linear with a bias — confirm.
nn.init.constant_(model.encoder.hid_to_logvar.bias, init_logvar)
nn.init.constant_(model.decoder.hid_to_logvar.bias, init_logvar)

In [160]:
# Optimisation settings. `weight_decay` is passed to Adam as an L2 penalty on
# every parameter returned by model.parameters().
learning_rate = 1e-3
weight_decay = 1e-1
report_interval = 250  # epochs between reported progress lines (see log below)

optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
history = beer.inference.History(report_interval=report_interval)

In [186]:
# Train the VAE for `nb_epochs` epochs.
# NOTE(review): re-running this cell presumably continues from the current
# parameters and history — the logged epochs (10250-11000) indicate it was
# executed many times, so a fresh Restart & Run All will not reproduce that log.
nb_epochs = 1000
model.sample = True  # attribute read by beer's VAE (presumably enables sampling) — TODO confirm
beer.inference.run_training(
    data,
    model,
    optimizer,
    nb_epochs,
    history,
    batch_size=20,
    lrate_latent_model=0.,  # presumably keeps the latent prior fixed
    kl_weight=1.0,
)

Epoch: 10250 	elbo: -12.916033 llh: -234.236256 kld: 24.084405
Epoch: 10500 	elbo: -12.856605 llh: -232.276443 kld: 24.855655
Epoch: 10750 	elbo: -12.872988 llh: -230.917252 kld: 26.542503
Epoch: 11000 	elbo: -12.863875 llh: -231.972294 kld: 25.305198


In [187]:
# Visualise the model's outputs on (at most) the first `nb_plot_points` points.
nb_plot_points = 100
plot_model_outputs(model, data[:nb_plot_points])

In [188]:
# Plot the values recorded in `history` during training (presumably the
# elbo / llh / kld reported every `report_interval` epochs).
history.plot()