<a href="https://colab.research.google.com/github/njucs/notebook/blob/master/AutoEncoder_%26_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **一个简化版本**

In [None]:
!pip install visdom
#!python -m visdom.server

In [None]:
import torch
#import visdom
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch import nn, optim

class AE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses a flattened image (784 values) down to a 20-value
    code through 256- and 64-unit hidden layers; the decoder mirrors that path
    and ends in a sigmoid so every reconstructed pixel lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: [b, 784] -> [b, 256] -> [b, 64] -> [b, 20]
        encoder_layers = [
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: [b, 20] -> [b, 64] -> [b, 256] -> [b, 784]
        decoder_layers = [
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Reconstruct a batch of images.

        :param x: image batch of shape [b, 1, 28, 28]
        :return: reconstruction of shape [b, 1, 28, 28]
        """
        n_samples = x.size(0)
        # Collapse every non-batch dimension so the linear layers see [b, 784].
        flat = x.view(n_samples, -1)
        code = self.encoder(flat)
        recon = self.decoder(code)
        # Restore the image layout expected by callers.
        return recon.view(n_samples, 1, 28, 28)

def main():
    """Train the autoencoder on MNIST and reconstruct a test batch each epoch.

    Loads the MNIST train and test splits (root directory ``'mnist'``,
    downloading if needed), trains ``AE`` with Adam and a mean-squared-error
    reconstruction loss, and prints the loss of the last mini-batch every
    10 epochs. After each epoch the first test batch is passed through the
    model without gradient tracking; the visdom calls that would display it
    are left commented out.
    """
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    print(mnist_train)

    epochs = 1000
    lr = 1e-3
    model = AE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

#    viz = visdom.Visdom()
#    assert viz.check_connection()

    for epoch in range(epochs):
        # Labels are not needed for reconstruction, so they are discarded via "_".
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat = model(x)
            loss = criterion(x_hat, x)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            # Reports the loss of the final mini-batch of this epoch.
            print(epoch, 'loss:', loss.item())

        # Use the built-in next(): Python 3 iterators expose __next__, not a
        # .next() method, so iter(...).next() raises AttributeError.
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat = model(x)
#        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
#        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

# Start training only when this code is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

### **VAE**

In [None]:
import torch
import visdom
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

'''
Encode 以后的变量 h 要分成两半儿，利用 h.chunk(num, dim) 实现，num 表示要分成几块，dim 值表示在什么维度上进行。
然后随机采样出标准正态分布的数据，用 mu 和 sigma 对其进行变换。这里的 kld 指的是 KL Divergence
'''
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    The encoder maps a flattened image (784 values) to 20 values that are split
    into a 10-value mean ``mu`` and a 10-value scale ``sigma``. A latent sample
    is drawn with the reparameterization trick and decoded back to 784 pixels,
    which a final sigmoid squashes into [0, 1].
    """

    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        # mu: [b, 10]
        # sigma: [b, 10]
        self.encoder = nn.Sequential(
            # [b, 784] => [b, 256]
            nn.Linear(784, 256),
            nn.ReLU(),
            # [b, 256] => [b, 64]
            nn.Linear(256, 64),
            nn.ReLU(),
            # [b, 64] => [b, 20]
            # No activation after this layer: a trailing ReLU would clamp mu to
            # be non-negative (the N(0, I) prior is symmetric) and would produce
            # sigma == 0 for clamped units, which blows up the log term of the
            # KL divergence. The sign of sigma is irrelevant below because it
            # only appears squared or multiplied by symmetric noise.
            nn.Linear(64, 20),
        )
        self.decoder = nn.Sequential(
            # [b, 10] => [b, 64]
            nn.Linear(10, 64),
            nn.ReLU(),
            # [b, 64] => [b, 256]
            nn.Linear(64, 256),
            nn.ReLU(),
            # [b, 256] => [b, 784]
            nn.Linear(256, 784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode, sample a latent, decode, and compute the KL term.

        :param x: image batch of shape [b, 1, 28, 28]
        :return: tuple ``(x_hat, kld)`` where ``x_hat`` has shape
            [b, 1, 28, 28] and ``kld`` is a scalar tensor holding the KL
            divergence to N(0, I), summed over the batch and latent
            dimensions and divided by ``b * 28 * 28``
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, -1)
        # encoder
        # [b, 20] holding both the mean and the scale
        q = self.encoder(x)
        # [b, 20] => [b, 10] and [b, 10]
        mu, sigma = q.chunk(2, dim=1)
        # reparameterization trick, epsilon ~ N(0, 1)
        q = mu + sigma * torch.randn_like(sigma)

        # decoder
        x_hat = self.decoder(q)
        # reshape
        x_hat = x_hat.view(batchsz, 1, 28, 28)

        # KL(N(mu, sigma^2) || N(0, 1)); the 1e-8 keeps log() finite when
        # sigma is exactly zero. Dividing by the number of pixels in the batch
        # puts the term on the same per-element scale as a mean-reduced
        # reconstruction loss.
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2) +
            torch.pow(sigma, 2) -
            torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        ) / (batchsz*28*28)

        return x_hat, kld

def main():
    """Train the VAE on MNIST and visualise reconstructions with visdom.

    Loads the MNIST train and test splits (root directory ``'mnist'``,
    downloading if needed) and trains ``VAE`` with Adam on the sum of a
    mean-squared-error reconstruction loss and the KL term returned by the
    model. The loss and KL value of the last mini-batch are printed every
    10 epochs. After each epoch the first test batch and its reconstruction
    are sent to the visdom windows ``'x'`` and ``'x_hat'``.
    """
    mnist_train = datasets.MNIST('mnist', train=True, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    mnist_test = datasets.MNIST('mnist', train=False, transform=transforms.Compose([
        transforms.ToTensor()
    ]), download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32)

    epochs = 1000
    lr = 1e-3
    model = VAE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(model)

    viz = visdom.Visdom()
    for epoch in range(epochs):
        # Labels are not needed for reconstruction, so they are discarded via "_".
        for batchidx, (x, _) in enumerate(mnist_train):
            x_hat, kld = model(x)
            # Training objective: reconstruction error plus the KL term
            # (weight 1.0). The model always returns a KL tensor, so no
            # None check is needed.
            loss = criterion(x_hat, x) + 1.0 * kld

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 10 == 0:
            # Both values come from the final mini-batch of this epoch.
            print(epoch, 'loss:', loss.item(), 'kld', kld.item())

        # Use the built-in next(): Python 3 iterators expose __next__, not a
        # .next() method, so iter(...).next() raises AttributeError.
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))

# Start training only when this code is executed directly, not when it is imported.
if __name__ == '__main__':
    main()