In [1]:
from google.colab import drive

# Attach Google Drive at /gdrive; force_remount re-attaches it even when an
# earlier session in this runtime already mounted it.
# NOTE(review): no cell shown here reads from /gdrive (MNIST is downloaded to
# ./data), so this mount may be unnecessary.
drive.mount(mountpoint='/gdrive', force_remount=True)

Mounted at /gdrive


In [2]:
import torch
import easydict
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

In [3]:
# Hyper-parameters and run settings, exposed to later cells as `args.<key>`.
config = {
    'seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_workers': 4,  # DataLoader worker processes

    # Not read by any cell shown here; kept as run metadata.
    'model': 'VAE',
    'criterion': 'ELBO',

    'epoch': 10,
    # 1e-3 is Adam's customary default. A value of 0.1 is far larger than Adam
    # is normally run with and tends to make this VAE's training diverge or
    # collapse to a blurry mean image.
    'lr': 1e-3,
    'batch_size': 32,

    'printEvery': 100,  # print the loss and decode prior samples every N batches

    'latent_size': 2,
    # Encoder MLP widths start at 784 (= 28*28 flattened MNIST pixels); the
    # decoder mirrors it from the latent code back up to 784.
    'encoder_layer_channels': [784, 392, 196],
    'decoder_layer_channels': [196, 392, 784],
}

args = easydict.EasyDict(config)

In [4]:
import torch
import torch.nn as nn

class VAE(nn.Module):
  """Variational auto-encoder composed of an MLP encoder and an MLP decoder.

  Args:
    encoder_layer_channels: widths handed to ``VAE_Encoder``.
    decoder_layer_channels: widths handed to ``VAE_Decoder``.
    latent_size: dimensionality of the latent code z.
  """

  def __init__(self, encoder_layer_channels, decoder_layer_channels, latent_size):
    super().__init__()
    self.encoder = VAE_Encoder(encoder_layer_channels, latent_size)
    self.decoder = VAE_Decoder(decoder_layer_channels, latent_size)

  def forward(self, x):
    """Encode ``x``, sample z via reparameterization, and decode it.

    Returns:
      Tuple ``(mean, var, latent, sampled_x)``. ``var`` is the encoder's second
      output, which ``reparameterization`` interprets as a log-variance.
    """
    mu, log_var = self.encoder(x)
    z = self.reparameterization(mu, log_var)
    return mu, log_var, z, self.decoder(z)

  def inference(self, z):
    """Decode latent codes ``z`` (e.g. draws from the prior) into outputs."""
    return self.decoder(z)

  def reparameterization(self, mean, var):
    """Return ``mean + exp(0.5 * var) * eps`` with ``eps ~ N(0, I)``.

    ``var`` is treated as a log-variance. Sampling ``eps`` separately keeps the
    result differentiable with respect to ``mean`` and ``var``.
    """
    std = torch.exp(0.5 * var)
    # Standard-normal noise with the same shape/dtype/device as `std`.
    eps = torch.randn_like(std)
    return eps * std + mean

class VAE_Encoder(nn.Module):
  """MLP encoder producing the parameters of the approximate posterior q(z|x).

  Each consecutive pair of widths in ``layer_channels`` becomes Linear + ReLU;
  two linear heads then map the last width to ``latent_size``.

  Args:
    layer_channels: widths starting with the input width (e.g. [784, 392, 196]).
    latent_size: dimensionality of the latent code z.

  Raises:
    ValueError: if ``layer_channels`` is empty.

  ``forward`` returns ``(mean, log_var)``. The second head keeps the attribute
  name ``stdLayer`` so existing state_dict keys still load, but its output is
  used as a log-variance by this notebook's VAE.reparameterization
  (``exp(0.5 * .)``) and ELBO_loss (``.exp()``).
  """

  def __init__(self, layer_channels, latent_size):
    super().__init__()
    if len(layer_channels) == 0:
      raise ValueError("layer_channels must contain at least the input width")

    layers = []
    for in_ch, out_ch in zip(layer_channels[:-1], layer_channels[1:]):
      layers.append(nn.Linear(in_ch, out_ch))
      layers.append(nn.ReLU())

    self.MLP = nn.Sequential(*layers)
    self.meanLayer = nn.Linear(layer_channels[-1], latent_size)
    self.stdLayer = nn.Linear(layer_channels[-1], latent_size)

  def forward(self, x):
    """Return ``(mean, log_var)`` of q(z|x) for the batch ``x``."""
    hidden = self.MLP(x)
    return self.meanLayer(hidden), self.stdLayer(hidden)

class VAE_Decoder(nn.Module):
  """MLP decoder mapping latent codes to outputs squashed into [0, 1].

  Architecture: Linear(latent_size, layer_channels[0]), then ReLU + Linear for
  every consecutive pair of widths, and a final Sigmoid.

  Args:
    layer_channels: widths after the first projection, ending with the output
      width (e.g. [196, 392, 784]).
    latent_size: dimensionality of the latent code z.
  """

  def __init__(self, layer_channels, latent_size):
    super().__init__()

    blocks = [nn.Linear(latent_size, layer_channels[0])]
    for fan_in, fan_out in zip(layer_channels, layer_channels[1:]):
      blocks += [nn.ReLU(), nn.Linear(fan_in, fan_out)]
    blocks.append(nn.Sigmoid())

    self.MLP = nn.Sequential(*blocks)

  def forward(self, z):
    """Decode latent codes ``z`` into outputs in [0, 1]."""
    return self.MLP(z)

In [5]:
# Seed torch's global RNG before the model is built, so weight initialisation
# (and presumably DataLoader shuffling) is repeatable across runs. When
# training on the GPU, seed every CUDA device as well.
torch.manual_seed(args.seed)
if args.device == 'cuda':
  torch.cuda.manual_seed_all(args.seed)

In [6]:
def ELBO_loss(mean, var, x, sampled_x):
  """Negative ELBO averaged over the batch, for a Bernoulli-output VAE.

  Args:
    mean: mean of q(z|x) as produced by the encoder.
    var: log-variance of q(z|x); it is exponentiated below, matching how
      VAE.reparameterization uses exp(0.5 * var).
    x: target batch with values in [0, 1]; dimension 0 is the batch.
    sampled_x: decoder output (probabilities in [0, 1]) with the same number
      of elements per example as ``x``.

  Returns:
    Scalar tensor: (summed binary cross-entropy + KL(q(z|x) || N(0, I)))
    divided by the batch size.
  """
  batch_size = x.size(0)

  # Reconstruction error. Uses the `sampled_x` argument rather than any
  # global, and flattens each example so image-shaped and already-flattened
  # inputs both work.
  recon_err = torch.nn.functional.binary_cross_entropy(
      sampled_x.reshape(batch_size, -1), x.reshape(batch_size, -1),
      reduction='sum')

  # Regularisation: closed-form KL between N(mean, exp(var)) and N(0, I).
  kld = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())

  return (recon_err + kld) / batch_size

In [7]:
# Load the MNIST training split; ToTensor converts each image to a float tensor.
dataset = MNIST(
    root='data', train=True, transform=transforms.ToTensor(), download=True)
# Use the worker count declared in the config cell (previously ignored).
data_loader = DataLoader(
    dataset=dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.num_workers)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
def imshowAndWrite(x, path=None):
  """Draw the first 10 images of ``x`` on a 5x2 grid and optionally save it.

  Args:
    x: tensor whose rows each hold one flattened 28x28 image (e.g. decoder
      output). Only the first 10 rows are drawn; fewer rows leave empty panels.
    path: if given, the figure is also written to this file at dpi=300.
  """
  # A single figure per call; the previous extra plt.figure() left a stray
  # empty figure behind every time.
  fig, axes = plt.subplots(5, 2, figsize=(5, 10))
  images = x[:10].detach().cpu()
  for ax, img in zip(axes.flat, images):
    ax.imshow(img.reshape(28, 28).numpy())

  if path is not None:
    fig.savefig(path, dpi=300)

  # Display now and release the figure, so repeated calls inside the training
  # loop do not pile up open figures.
  plt.show()
  plt.close(fig)

In [9]:
model = VAE(args.encoder_layer_channels, args.decoder_layer_channels, args.latent_size)
model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

num_batches = len(data_loader)
for epoch in range(args.epoch):
  for step, (x, _) in enumerate(data_loader):  # labels are not needed
    # Flatten each image into a 784-vector for the MLP encoder.
    x = x.to(args.device).view(-1, 28*28)
    mean, var, z, recon_x = model(x)

    loss = ELBO_loss(mean, var, x, recon_x)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % args.printEvery == 0 or step == num_batches - 1:
      print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
          epoch, args.epoch, step, num_batches - 1, loss.item()))

      # Decode 10 draws from the N(0, I) prior to visualise progress. This
      # runs without autograd and uses separate names, so the training
      # tensors `x` / `z` are not overwritten.
      with torch.no_grad():
        prior_z = torch.randn([10, args.latent_size]).to(args.device)
        samples = model.inference(prior_z)
      imshowAndWrite(samples)

Output hidden; open in https://colab.research.google.com to view.