In [1]:
import torch
import json
import os
import sys
import pprint
from glob import glob
import pickle
import numpy as np
from itertools import zip_longest
import io
from einops import rearrange, reduce, repeat
from PIL import Image, ImageDraw
import cv2
from collections import defaultdict

import albumentations as A
from albumentations.pytorch import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torchvision.utils import save_image
from torchvision import datasets, transforms


In [2]:
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h)) 
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var


In [3]:
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, 3, padding='same'),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),

            torch.nn.Conv2d(8, 16, 3, padding='same'),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(16, 8, 2, stride=2),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),

            torch.nn.ConvTranspose2d(8, 1, 2, stride=2),
            torch.nn.BatchNorm2d(1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        
        return decoded

    def predict(self, x):
        pred = self.encoder(x)
        
        return pred

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        mnist = MNIST(root='/home/hebb/ml/datasets/', download=True, train=True)
        self.data = mnist.train_data

    def __getitem__(self, idx):
        image = self.data[idx] / 255
        image = image.reshape(shape=[1, 28, 28])
        return image

    def __len__(self):
        return len(self.data)

In [5]:
def criterion(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='mean')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [7]:
# build model
model = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2)
if torch.cuda.is_available():
    model.cuda()

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)


# ds = Dataset()
# loader = torch.utils.data.DataLoader(ds, batch_size=128)
# # criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.01)


In [9]:
model.train()

for epoch in range(100):
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.cuda()
        optimizer.zero_grad()
        
        recon_batch, mu, log_var = model(data)
        loss = criterion(recon_batch, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(train_loss)

AttributeError: 'list' object has no attribute 'view'

In [None]:
with torch.no_grad():
    z = torch.randn(64, 2).cuda()
    sample = model.decoder(z).cuda()
    
    save_image(sample.view(64, 1, 28, 28), './samples/sample_' + '.png')


In [None]:
with torch.no_grad():
    latent = torch.randn(10, 20)
    recon = model.decode(latent[6]).squeeze().reshape(28, 28)
    plt.imshow(recon, cmap='gray')

AttributeError: 'VAE' object has no attribute 'decode'