<a href="https://colab.research.google.com/github/haluowan/pytorch/blob/master/AE%26VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms,datasets


In [0]:
class AE(nn.Module):
    """Fully connected auto-encoder for 28x28 single-channel images.

    The encoder compresses a flattened 784-value image to a 20-value code
    (784 -> 256 -> 64 -> 20, ReLU after every layer). The decoder mirrors it
    (20 -> 64 -> 256 -> 784) and ends with a sigmoid, so reconstructed
    values lie in [0, 1].
    """

    def __init__(self):
        super(AE, self).__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(nn.Linear(784, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 64),
                                     nn.ReLU(),
                                     nn.Linear(64, 20),
                                     nn.ReLU())

        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(nn.Linear(20, 64),
                                     nn.ReLU(),
                                     nn.Linear(64, 256),
                                     nn.ReLU(),
                                     nn.Linear(256, 784),
                                     nn.Sigmoid())

    def forward(self, x):
        """Reconstruct a batch of images.

        Args:
            x: tensor whose first dimension is the batch size and which
                holds 784 values per sample (it is viewed as ``[b, 784]``).

        Returns:
            A ``(x_hat, None)`` tuple. ``x_hat`` has shape ``[b, 1, 28, 28]``.
            The second element is always ``None``: a plain auto-encoder has
            no KL term, and returning a pair lets a training loop that
            checks ``kld is not None`` handle this model as well.
        """
        batchsz = x.size(0)
        # flatten
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape back to image layout
        x = x.view(batchsz, 1, 28, 28)
        return x, None
    

In [0]:
class VAE(nn.Module):
    """Fully connected variational auto-encoder for 28x28 single-channel images.

    The encoder emits 20 values per sample, which ``forward`` splits into a
    10-value mean and a 10-value sigma. A latent sample drawn with the
    reparameterisation trick is decoded back to 784 values through a
    sigmoid-terminated decoder.
    """

    def __init__(self):
        super(VAE, self).__init__()

        # [b, 784] => [b, 20]; the 20 outputs are later chunked into
        # mu: [b, 10] and sigma: [b, 10].
        # NOTE(review): the trailing ReLU makes both halves non-negative,
        # including the mean -- confirm this is intended.
        encoder_layers = [
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # [b, 10] => [b, 784]; the input is the sampled latent, not the raw
        # 20-value encoder output.
        decoder_layers = [
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Not used inside this class.
        self.criteon = nn.MSELoss()

    def forward(self, x):
        """Encode, sample a latent, decode, and compute the KL term.

        Args:
            x: tensor whose first dimension is the batch size and which
                holds 784 values per sample (it is viewed as ``[b, 784]``).

        Returns:
            A ``(x_hat, kld)`` tuple. ``x_hat`` has shape ``[b, 1, 28, 28]``.
            ``kld`` is a scalar tensor: the KL divergence between
            N(mu, sigma^2) and N(0, 1) summed over the batch and latent
            dimensions, then divided by ``b * 28 * 28``.
        """
        n_samples = x.size(0)

        # [b, ...] => [b, 784] => [b, 20]
        flat = x.view(n_samples, 784)
        stats = self.encoder(flat)

        # [b, 20] => two [b, 10] halves
        mu, sigma = torch.chunk(stats, 2, dim=1)

        # Reparameterisation trick: h = mu + sigma * eps, eps ~ N(0, 1)
        noise = torch.randn_like(sigma)
        latent = mu + sigma * noise

        # [b, 10] => [b, 784] => [b, 1, 28, 28]
        x_hat = self.decoder(latent).view(n_samples, 1, 28, 28)

        # The 1e-8 keeps log() finite when sigma is exactly zero.
        variance = sigma.pow(2)
        kl_terms = mu.pow(2) + variance - torch.log(1e-8 + variance) - 1
        kld = 0.5 * kl_terms.sum() / (n_samples * 28 * 28)

        return x_hat, kld

    
        

In [0]:
def main(epochs=1000, batch_size=32, lr=1e-3):
    """Train the VAE on MNIST and print the last-batch loss after each epoch.

    MNIST is downloaded into the ``mnist`` directory if it is not already
    there. The model is trained with Adam on the sum of the MSE
    reconstruction loss and the KL term returned by the model.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size of the training loader.
        lr: Adam learning rate.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    train_data = datasets.MNIST(root='mnist',
                                train=True,
                                transform=to_tensor,
                                download=True)

    # Constructed only so the test split is downloaded alongside the
    # training split; no evaluation pass uses it yet.
    datasets.MNIST(root='mnist',
                   train=False,
                   transform=to_tensor,
                   download=True)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    # Use the built-in next(): loader iterators do not reliably expose a
    # Python-2 style ``.next()`` method.
    x, _ = next(iter(train_loader))
    print('x:', x.shape)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print('model:', model)

    for epoch in range(epochs):
        for x, _ in train_loader:
            x = x.to(device)

            x_hat, kld = model(x)

            loss = criteon(x_hat, x)

            if kld is not None:
                # Minimising (reconstruction + KL) is the same as maximising
                # the ELBO = -reconstruction - KL.
                loss = loss + 1.0 * kld

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # A model without a KL term returns kld=None; do not call .item()
        # on it.
        kld_value = kld.item() if kld is not None else None
        print('Epoch:', epoch, 'Loss:', loss.item(), 'Kld:', kld_value)
        
                    

In [0]:
# Start training only when run as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()