In [59]:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import torch.nn.functional as F

In [72]:
class Decoder(nn.Module):
    def __init__(self, x_dim, z_dim, hidden_dim, dropout=.2):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, x_dim)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)
        self.bn = nn.BatchNorm1d(x_dim, affine=True)
        
    def forward(self, z):
        z = self.dropout(z)
        hidden = self.softplus(self.fc1(z))
        loc_img = self.sigmoid(self.bn(self.fc21(hidden)))
        return loc_img
    
class Encoder(nn.Module):
    def __init__(self, x_dim, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(x_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        self.softplus = nn.Softplus()
        self.x_dim = x_dim
        
    def forward(self, x):
        x = x.reshape(-1, self.x_dim)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale
    
class VAE(nn.Module):
    def __init__(self, x_dim, z_dim=50, hidden_dim=400):
        super().__init__()
        self.encoder = Encoder(x_dim, z_dim, hidden_dim)
        self.decoder = Decoder(x_dim, z_dim, hidden_dim)
        self.x_dim = x_dim
        self.z_dim = z_dim
        
    def model(self, x):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            logz = pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
            z =  F.softmax(logz, -1)
            
            loc_img = self.decoder(z)
            pyro.sample('obs', dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, self.x_dim))
            
    def guide(self, x):
        pyro.module("encoder", self.encoder)
        with pyro.plate('data', x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample('latent', dist.Normal(z_loc, z_scale).to_event(1))
            
    def reconstruct_img(self, x):
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        loc_img = self.decoder(z)
        return loc_img

In [73]:
data = torch.tensor([1.,0,0]).repeat(1000,1)
data[0:5,:]
train_loader = torch.utils.data.DataLoader(data, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(data, batch_size=10, shuffle=True)

In [74]:
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

In [75]:
def train(svi, train_loader):
    epoch_loss = 0.
    for x in train_loader:
        epoch_loss += svi.step(x)
        
    return epoch_loss / len(train_loader.dataset)

def evaluate(svi, test_loader):
    test_loss = 0.
    for x in test_loader:
        test_loss += svi.evaluate_loss(x)
        
    return test_loss / len(test_loader.dataset)

In [76]:
vae = VAE(x_dim=3)
optimizer = Adam({'lr': 1.0e-3})
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

In [77]:
print(train(svi, train_loader))
evaluate(svi, test_loader)

17.014805236816407


16.633094055175782

In [78]:
pyro.clear_param_store()
vae = VAE(x_dim=3, z_dim=4)
adam_args = {'lr': 1.0e-3}
optimizer = Adam(adam_args)
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

train_elbo = []
test_elbo = []
for epoch in range(100):
    epoch_train_elbo = train(svi, train_loader)
    train_elbo.append(-epoch_train_elbo)
    if epoch % 5 == 0:
        epoch_test_elbo = train(svi, test_loader)
        test_elbo.append(-epoch_test_elbo)
        print("[epoch %03d] average test loss: %.4f" % (epoch, epoch_test_elbo))

[epoch 000] average test loss: 2.3459
[epoch 005] average test loss: 1.4520
[epoch 010] average test loss: 0.9930
[epoch 015] average test loss: 0.7336
[epoch 020] average test loss: 0.5588
[epoch 025] average test loss: 0.4683
[epoch 030] average test loss: 0.3519
[epoch 035] average test loss: 0.3385
[epoch 040] average test loss: 0.2559
[epoch 045] average test loss: 0.1931
[epoch 050] average test loss: 0.1909
[epoch 055] average test loss: 0.2020
[epoch 060] average test loss: 0.1798
[epoch 065] average test loss: 0.1326
[epoch 070] average test loss: 0.1466
[epoch 075] average test loss: 0.1247
[epoch 080] average test loss: 0.0860
[epoch 085] average test loss: 0.1146
[epoch 090] average test loss: 0.1317
[epoch 095] average test loss: 0.1191


In [79]:
vae.reconstruct_img(data[0:3,:])

tensor([[0.9979, 0.0021, 0.0020],
        [0.9979, 0.0021, 0.0019],
        [0.9979, 0.0021, 0.0023]], grad_fn=<SigmoidBackward>)