# MNIST Variational Autoencoder (VAE) 예제

Reference implementation: https://github.com/lyeoni/pytorch-mnist-VAE/blob/master/pytorch-mnist-VAE.ipynb


In [3]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn # 인공 신경망 모델들 모아놓은 모듈
import torch.nn.functional as F #그중 자주 쓰이는것들을 F로
from torchvision import transforms, datasets
import cv2
from torchvision import transforms, datasets
import pandas as pd
import os
from glob import glob
import torchvision.models as models
import sys

# Run on the GPU whenever CUDA is visible, otherwise fall back to the CPU.
# (To force CPU execution, replace the expression with torch.device('cpu').)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using Pytorch version : ',torch.__version__,' Device : ',DEVICE)

Using Pytorch version :  1.10.2  Device :  cuda


In [4]:
# Mini-batch size shared by the training and test loaders.
bs = 100

# MNIST splits, converted to float tensors by ToTensor().
train_dataset = datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# download=False relies on the files fetched for the training split above
# living under the same root.
test_dataset = datasets.MNIST(
    root='./data/',
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

# Input pipeline: reshuffle training batches every epoch, keep test order fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [11]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: x_dim -> h_dim1 -> h_dim2 -> (mu, log_var), each of size z_dim.
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, squashed into [0, 1] by a sigmoid.

    Args:
        x_dim: number of input features per sample (784 for flattened 28x28 MNIST).
        h_dim1: width of the first (outer) hidden layer.
        h_dim2: width of the second (inner) hidden layer.
        z_dim: dimensionality of the latent space.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        # Stored so forward() can flatten inputs for any x_dim instead of a hard-coded 784.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # head producing mu
        self.fc32 = nn.Linear(h_dim2, z_dim)  # head producing log_var
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs (batch, x_dim) to (mu, log_var), each (batch, z_dim)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw z = mu + eps * std with eps ~ N(0, I) (reparameterisation trick).

        Sampling through eps keeps z differentiable with respect to mu and log_var.
        """
        std = torch.exp(0.5 * log_var)  # log_var -> standard deviation
        eps = torch.randn_like(std)     # standard-normal noise, same shape as std
        return mu + eps * std           # z sample

    def decoder(self, z):
        """Map latent codes (batch, z_dim) to reconstructions (batch, x_dim) in [0, 1]."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode x, sample a latent z, and decode it.

        Args:
            x: batch of inputs with x_dim elements per sample, e.g. (batch, 1, 28, 28).

        Returns:
            (reconstruction of shape (batch, x_dim), mu, log_var).
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        # Why mu and log_var? The encoder places the input in a region of latent space
        # centred at mu and spread out according to the variance; a random z is drawn
        # from that region and handed to the decoder.
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# build model: 784-pixel inputs, hidden widths 512 and 256, 2-D latent space,
# placed on whichever device DEVICE selected.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2).to(DEVICE)

In [12]:
# Display the model's layer structure (rich repr of the nn.Module).
vae

VAE(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=2, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

In [19]:

# Adam over every VAE parameter; lr=1e-3 is Adam's default, written out for visibility.
optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)
def loss_function(recon_x, x, mu, log_var):
    """Reconstruction error + KL divergence (negative ELBO), summed over the batch.

    Args:
        recon_x: decoder output with values in [0, 1], shape (batch, x_dim).
        x: target images holding the same number of elements as recon_x,
            e.g. (batch, 1, 28, 28); flattened to recon_x's shape.
        mu: encoder means, shape (batch, z_dim).
        log_var: encoder log-variances, same shape as mu.

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over samples, so callers
        divide by the number of samples to get a per-sample loss.
    """
    # Match the target's shape to the reconstruction instead of hard-coding 784 pixels.
    target = x.reshape_as(recon_x)
    # reduction='sum' adds the per-pixel binary cross-entropy over all pixels and samples.
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims and samples.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [20]:
def train(epoch, log_interval=100):
    """Run one optimisation pass of the module-level ``vae`` over ``train_loader``.

    Uses the notebook-level ``vae``, ``optimizer``, ``loss_function``, ``train_loader``
    and ``DEVICE``.

    Args:
        epoch: epoch number, used only in the log lines.
        log_interval: number of batches between progress lines (default 100, as before).

    Prints the current batch's per-sample loss every ``log_interval`` batches, then the
    average per-sample training loss for the epoch.
    """
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        # Follow DEVICE instead of a hard-coded .cuda(), which crashed on CPU-only machines.
        data = data.to(DEVICE)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % log_interval == 0:
            # loss is summed over the batch, so divide by len(data) for a per-sample value.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [22]:
def test():
    """Evaluate the module-level ``vae`` on ``test_loader`` and print the mean per-sample loss.

    Gradients are disabled. ``VAE.sampling`` still draws random noise in eval mode,
    so the reported loss varies slightly between calls.
    """
    vae.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            # Follow DEVICE instead of a hard-coded .cuda(), which crashed on CPU-only machines.
            data = data.to(DEVICE)
            recon, mu, log_var = vae(data)

            # loss_function returns the loss summed over the batch
            test_loss += loss_function(recon, data, mu, log_var).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [23]:
# Train for NUM_EPOCHS epochs (numbered from 1), evaluating on the test split after each.
NUM_EPOCHS = 50
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test()



====> Epoch: 1 Average loss: 180.3331
====> Test set loss: 163.0279
====> Epoch: 2 Average loss: 158.2077
====> Test set loss: 155.2007
====> Epoch: 3 Average loss: 152.7877
====> Test set loss: 151.4892
====> Epoch: 4 Average loss: 149.4538
====> Test set loss: 148.8531
====> Epoch: 5 Average loss: 147.2732
====> Test set loss: 146.8766
====> Epoch: 6 Average loss: 145.8046
====> Test set loss: 145.5623
====> Epoch: 7 Average loss: 144.5511
====> Test set loss: 145.0618
====> Epoch: 8 Average loss: 143.8190
====> Test set loss: 143.9456
====> Epoch: 9 Average loss: 142.9844
====> Test set loss: 143.6042
====> Epoch: 10 Average loss: 142.2589
====> Test set loss: 143.0996
====> Epoch: 11 Average loss: 141.6128
====> Test set loss: 142.3903
====> Epoch: 12 Average loss: 141.0952
====> Test set loss: 142.0226
====> Epoch: 13 Average loss: 140.5512
====> Test set loss: 141.8452
====> Epoch: 14 Average loss: 140.4849
====> Test set loss: 141.9205
====> Epoch: 15 Average loss: 140.0238
====

====> Epoch: 22 Average loss: 138.0359
====> Test set loss: 139.7064
====> Epoch: 23 Average loss: 138.0366
====> Test set loss: 140.0233
====> Epoch: 24 Average loss: 137.5555
====> Test set loss: 139.2645
====> Epoch: 25 Average loss: 137.6961
====> Test set loss: 139.4900
====> Epoch: 26 Average loss: 137.4585
====> Test set loss: 139.2033
====> Epoch: 27 Average loss: 137.1781
====> Test set loss: 139.3538
====> Epoch: 28 Average loss: 137.1863
====> Test set loss: 139.1786
====> Epoch: 29 Average loss: 137.0940
====> Test set loss: 139.3295
====> Epoch: 30 Average loss: 137.0338
====> Test set loss: 139.2800
====> Epoch: 31 Average loss: 136.7554
====> Test set loss: 139.0135
====> Epoch: 32 Average loss: 136.7249
====> Test set loss: 138.9689
====> Epoch: 33 Average loss: 136.5435
====> Test set loss: 139.7506
====> Epoch: 34 Average loss: 136.3882
====> Test set loss: 139.1953
====> Epoch: 35 Average loss: 136.1520
====> Test set loss: 138.4473
====> Epoch: 36 Average loss: 136.

====> Epoch: 44 Average loss: 135.2151
====> Test set loss: 138.2302
====> Epoch: 45 Average loss: 135.1306
====> Test set loss: 138.4556
====> Epoch: 46 Average loss: 134.8991
====> Test set loss: 137.6444
====> Epoch: 47 Average loss: 134.9415
====> Test set loss: 137.9413
====> Epoch: 48 Average loss: 134.9872
====> Test set loss: 137.9888
====> Epoch: 49 Average loss: 134.8480
====> Test set loss: 138.1786
====> Epoch: 50 Average loss: 134.8789
====> Test set loss: 138.0251


In [28]:
from torchvision.utils import save_image

In [31]:
# Decode 64 latent vectors drawn from N(0, I) and save them as one 28x28 image grid.
vae.eval()
with torch.no_grad():
    # Latent size comes from the decoder's input layer instead of a hard-coded 2,
    # and z is created on DEVICE instead of a hard-coded .cuda().
    z = torch.randn(64, vae.fc4.in_features, device=DEVICE)
    # The decoder output is already on the model's device; the old extra .cuda()
    # was redundant and crashed on CPU-only machines.
    sample = vae.decoder(z)

    # save_image writes through PIL, which does not create a missing ./samples directory.
    os.makedirs('./samples', exist_ok=True)
    save_image(sample.view(64, 1, 28, 28), './samples/sample_.png')

In [30]:
# Shape of the decoded batch: (number of samples, flattened pixels).
sample.shape

torch.Size([64, 784])