In [9]:
import  torch
from    torch.utils.data import DataLoader
from    torch import nn, optim
from    torchvision import transforms, datasets
import  visdom
import numpy as np

In [2]:
class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()


        # [b, 784] => [b, 20]
        # u: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

        self.criteon = nn.MSELoss()

    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        # [b, 20], including mean and sigma
        h_ = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = h_.chunk(2, dim=1)
        # reparametrize trick, epison~N(0, 1)
        h = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(h)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

In [3]:
mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([transforms.ToTensor()]), download=False)
mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([transforms.ToTensor()]), download=False)

In [4]:
mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

In [5]:
x, _ = iter(mnist_train).next()
print('x:', x.shape)

x: torch.Size([32, 1, 28, 28])


In [None]:
viz = visdom.Visdom(use_incoming_socket=False)

In [6]:
model = VAE()

In [7]:
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
model

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=20, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=784, bias=True)
    (5): Sigmoid()
  )
  (criteon): MSELoss()
)

In [10]:
for epoch in range(1000):
    for batchidx, (x, _) in enumerate(mnist_train):
        # [b, 1, 28, 28]
        x = x
        x_hat, kld = model(x)
        loss = criteon(x_hat, x)

        if kld is not None:
            elbo = - loss - 1.0 * kld
            loss = - elbo

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(epoch, 'loss:', loss.item(), 'kld:', kld.item())

    x, _ = iter(mnist_test).next()
    x = x
    with torch.no_grad():
        x_hat, kld = model(x)
    viz.images(x, nrow=8, win='x', opts=dict(title='x'))
    viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))  

0 loss: 0.07325191050767899 kld: 0.017683329060673714
1 loss: 0.048956796526908875 kld: 0.006203962955623865
2 loss: 0.04813229292631149 kld: 0.007824454456567764
3 loss: 0.04703172296285629 kld: 0.007814062759280205
4 loss: 0.0467158742249012 kld: 0.00792806874960661
5 loss: 0.05149685591459274 kld: 0.008520587347447872
6 loss: 0.050130680203437805 kld: 0.008276141248643398
7 loss: 0.044426411390304565 kld: 0.008716377429664135
8 loss: 0.05141933634877205 kld: 0.008179866708815098
9 loss: 0.0463092215359211 kld: 0.00917123258113861
10 loss: 0.051698166877031326 kld: 0.009055194444954395
11 loss: 0.044617678970098495 kld: 0.008537468500435352
12 loss: 0.04720044881105423 kld: 0.0082466471940279
13 loss: 0.051399149000644684 kld: 0.009027311578392982
14 loss: 0.03806211054325104 kld: 0.0077248793095350266
15 loss: 0.045137397944927216 kld: 0.009074234403669834
16 loss: 0.04852191358804703 kld: 0.008916136808693409
17 loss: 0.048965368419885635 kld: 0.009279765188694
18 loss: 0.051609273

149 loss: 0.046198952943086624 kld: 0.009330139495432377
150 loss: 0.04643646627664566 kld: 0.009423859417438507
151 loss: 0.048245400190353394 kld: 0.009266117587685585
152 loss: 0.04093409329652786 kld: 0.008479460142552853
153 loss: 0.040981318801641464 kld: 0.009870738722383976
154 loss: 0.04640789330005646 kld: 0.009199514985084534
155 loss: 0.043319664895534515 kld: 0.009655829519033432
156 loss: 0.04760870710015297 kld: 0.009455283172428608
157 loss: 0.04425114020705223 kld: 0.00992922205477953
158 loss: 0.046265602111816406 kld: 0.009494359605014324
159 loss: 0.047667764127254486 kld: 0.00899045355618
160 loss: 0.043427638709545135 kld: 0.009602256119251251
161 loss: 0.047472499310970306 kld: 0.009851501323282719
162 loss: 0.04680495336651802 kld: 0.008952760137617588
163 loss: 0.043476518243551254 kld: 0.00965746957808733
164 loss: 0.042239606380462646 kld: 0.009317144751548767
165 loss: 0.041390568017959595 kld: 0.010507535189390182
166 loss: 0.041437502950429916 kld: 0.00992

296 loss: 0.04528867080807686 kld: 0.010691010393202305
297 loss: 0.044421978294849396 kld: 0.009740461595356464
298 loss: 0.04662606865167618 kld: 0.01030943263322115
299 loss: 0.03952508047223091 kld: 0.009123452007770538
300 loss: 0.05017346888780594 kld: 0.010204884223639965
301 loss: 0.04720431566238403 kld: 0.010538021102547646
302 loss: 0.046429142355918884 kld: 0.0099199702963233
303 loss: 0.04683499038219452 kld: 0.009679744951426983
304 loss: 0.049577079713344574 kld: 0.009940994903445244
305 loss: 0.04196963831782341 kld: 0.009566581808030605
306 loss: 0.04659895971417427 kld: 0.009853796102106571
307 loss: 0.04491090029478073 kld: 0.009692935273051262
308 loss: 0.043321870267391205 kld: 0.00941340159624815
309 loss: 0.044671431183815 kld: 0.01030008401721716
310 loss: 0.045286983251571655 kld: 0.009551960974931717
311 loss: 0.0447261743247509 kld: 0.010248369537293911
312 loss: 0.043523158878088 kld: 0.010767542757093906
313 loss: 0.04210415482521057 kld: 0.0091744828969240

KeyboardInterrupt: 