In [9]:
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data
from time import time
%matplotlib inline  

In [2]:
error_sigma = 1e-1
nobs = 1_000
xtrue = npr.randn(nobs, 1)
ytrue = np.zeros((nobs,2))
ytrue[:,0] = list(abs(xtrue)*np.cos(xtrue))
ytrue[:,1] = list(abs(xtrue)*np.sin(xtrue))
ytrue = 3*ytrue + npr.randn(*np.shape(ytrue))*error_sigma

In [3]:
mb_size = 10
Z_dim = 1
X_dim = 2
h_dim = 10
c = 0
lr = 1e-3

In [4]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

In [5]:
# =============================== Q(z|X) ======================================
# ============================== Encoding ===================================== 

Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)


def Q(X):
    h = nn.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_var = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_var


def sample_z(mu, log_var):
    eps = Variable(torch.randn(mb_size, Z_dim))
    return mu + torch.exp(log_var / 2) * eps

In [6]:
# =============================== P(X|z) ======================================
# ============================== Decoding ===================================== 

Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

def P(z):
    h = nn.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = torch.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

In [12]:
params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)

T = 100_000
start = time()
for it in range(T) :
    mb = np.random.choice(nobs,mb_size)
    X = ytrue[mb]
    X = Variable(torch.from_numpy(X)).float()

    # Forward
    z_mu, z_var = Q(X)
    z = sample_z(z_mu, z_var)
    X_sample = P(z)

    # Loss
#     recon_loss = nn.binary_cross_entropy(X_sample, X, reduction='sum')/mb_size
    recon_loss = torch.nn.MSELoss()(X_sample, X)
    kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
    loss = recon_loss + kl_loss

    loss.backward()
    solver.step()

    # Housekeeping
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())

    # Print and plot every now and then
    if (it+1) % (T/10) == 0:
        print('Iter: {}; Loss: {:.4}; Time: {:.2}'.format(it+1, loss.data, time()-start))

Iter: 10000; Loss: 1.757; Time: 8.5
Iter: 20000; Loss: 2.799; Time: 1.7e+01
Iter: 30000; Loss: 3.004; Time: 2.6e+01
Iter: 40000; Loss: 3.28; Time: 3.4e+01
Iter: 50000; Loss: 2.88; Time: 4.2e+01
Iter: 60000; Loss: 4.769; Time: 5.1e+01
Iter: 70000; Loss: 9.353; Time: 6e+01
Iter: 80000; Loss: 4.316; Time: 6.8e+01
Iter: 90000; Loss: 3.929; Time: 7.7e+01
Iter: 100000; Loss: 4.985; Time: 8.6e+01


In [14]:
X_sample

tensor([[9.9922e-01, 8.9497e-14],
        [9.9923e-01, 5.5440e-01],
        [9.9965e-01, 2.0847e-10],
        [3.7163e-01, 1.0511e-36],
        [9.8759e-01, 9.8477e-01],
        [9.9981e-01, 9.7100e-08],
        [1.3820e-02, 1.0000e+00],
        [9.9843e-01, 2.7191e-16],
        [9.8996e-01, 9.7950e-01],
        [1.6816e-02, 1.0000e+00]], grad_fn=<SigmoidBackward>)