In [1]:
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
batch_size = 128

In [3]:
# Data
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

train_dataset = datasets.MNIST('./data', transform=transform, download=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST('./data', transform=transform, download=True, train=False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# utils
def one_hot(x, max_x):
    return torch.eye(max_x+1)[x]

In [5]:
x = torch.tensor([3, 4, 5])
y = torch.tensor([
    [0, 0, 0, 1, 0, 0, 0]
])

In [6]:
def calculate_loss(model, dataloader, loss_fn=nn.MSELoss(), flatten=True,
                  conditional=False):
    losses = []
    for batch, labels in dataloader:
        batch = batch.to(device)
        labels = labels.to(device)
        
        if flatten:
            batch = batch.view(batch.size(0), 28*28)
        if conditional:
            loss = loss_fn(batch, model(batch, labels))
        else:
            loss = loss_fn(batch, model(batch))
        
        losses.append(loss)
    
    return (sum(losses)/len(losses)).item()

In [7]:
def show_visual_progress(model, test_dataloader, row=5, flatten=True,
                         vae=False, conditional=False, title=None):
    if title:
        plt.title(title)
    
    iter(test_dataloader)
    
    image_rows = []
    
    for idx, (batch, label) in enumerate(test_dataloader):
        if row == idx:
            break
        
        batch = batch.to(device)
        if flatten:
            batch = batch.view(batch.size(0), 28*28)
        
        if not conditional:
            images = model(batch).detach().cpu().numpy().reshape(batch.size(0), 28, 28)
        else:
            images = model(batch, label).detach().cpu().numpy().reshape(batch.size(0), 28, 28)
        
        image_idxs = [list(label.numpy()).index(x) for x in range(10)]
        combined_images = np.concatenate([images[x].reshape(28, 28) for x in image_idxs], 1)
        
        image_rows.append(combined_images)
        plt.imshow(np.concatenate(image_rows))
        
        plt.show()

In [11]:
def evaluate(losses, autoencoder, dataloader, flatten=True, vae=False,
             conditional=False):
    if vae and conditional:
        model = lambda x, y: autoencoder(x, y)[0]
    elif vae:
        model = lambda x: autoencoder(x)[0]
    else:
        model = autoencoder
    
    loss = calculate_loss(model, dataloader, flatten=flatten, conditional=conditional)
#     show_visual_progress(model, test_dataloader, flatten=flatten, vae=vae, 
#                          conditional=conditional)
    print(loss)
    losses.append(loss)

In [9]:
def train(net, dataloader, test_dataloader, epochs=5, flatten=False,
          loss_fn=nn.MSELoss()):
    optim = torch.optim.Adam(net.parameters())
    train_losses = []
    validation_losses = []
    for i in range(epochs):
        for batch, labels in dataloader:
            batch = batch.to(device)
            if flatten:
                batch = batch.view(batch.size(0), 28*28)
            
            optim.zero_grad()
            loss = loss_fn(batch, net(batch))
            loss.backward()
            optim.step()
            
            train_losses.append(loss.item())
        
        evaluate(validation_losses, net, test_dataloader, flatten)

In [10]:
def calculate_nparameters(model):
    def times(shape):
        parameters = 1
        for layer in list(shape):
            parameters *= layer
        
        return parameters
    layer_params = [times(x.size()) for x in list(model.parameters())]
    
    return sum(layer_params)

# Vanilla AE

In [12]:
class Autoencoder(nn.Module):
    def __init__(self, input_size, hidden=10):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, hidden))
        
        self.decoder = nn.Sequential(
            nn.Linear(hidden, 128),
            nn.ReLU(),
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, input_size))
    
    def forward(self, x):
        return self.decoder(self.encoder(x))

In [13]:
autoencoder = Autoencoder(28*28).to(device)
train(autoencoder, train_dataloader, test_dataloader, epochs=10, flatten=True)

0.10473746061325073
0.09068146347999573
0.08406133204698563
0.0804152637720108
0.07708583027124405
0.07402276992797852
0.07214543223381042
0.0711098313331604
0.06941736489534378
0.0687880888581276


# Convolutional AE

In [14]:
class CNNAutoencoder(nn.Module):
    def __init__(self):
        super(CNNAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, (3, 3), stride=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 8, (3, 3), stride=2, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=1))
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, (3, 3), stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, (5, 5), stride=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),
            nn.Tanh())
    
    def forward(self, x):
        x = self.decoder(self.encoder(x))
        
        return x

In [16]:
cnn_ae = CNNAutoencoder().to(device)
train(cnn_ae, train_dataloader, test_dataloader, epochs=20)

0.20470894873142242
0.15844248235225677
0.14456099271774292
0.1364603340625763
0.12973830103874207
0.12505628168582916
0.12228311598300934
0.11984138190746307
0.11817315220832825
0.1166527271270752
0.11594672501087189
0.11423417925834656
0.11302926391363144
0.11225937306880951
0.11140009760856628
0.11065704375505447
0.10954054445028305
0.10923619568347931
0.10844851285219193
0.10860192030668259


# VAE

In [21]:
class VAE(nn.Module):
    def __init__(self, input_size):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_size, 512)
        self.fc21 = nn.Linear(512, 10)
        self.fc22 = nn.Linear(512, 10)
        
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(10, 512)
        self.fc4 = nn.Linear(512, input_size)
    
    def encoder(self, x):
        x = self.relu(self.fc1(x))
        return self.fc21(x), self.fc22(x)
    
    def decoder(self, z):
        z = self.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(z))
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.rand_like(std)
        return eps.mul(std).add_(mu)
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x = self.decoder(z)
        
        return x, mu, logvar        

In [18]:
def vae_loss_fn(x, recon_x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5*torch.sum(1+logvar-mu.pow(2)-logvar.exp())
    
    return BCE+KLD

In [19]:
def train_vae(net, dataloader, test_dataloader, flatten=True, epochs=10):
    validation_losses = []
    optim = torch.optim.Adam(net.parameters())
    
    for i in range(epochs):
        for batch in dataloader:
            batch = batch[0].to(device)
            if flatten:
                batch = batch.view(batch.size(0), 28*28)
            
            optim.zero_grad()
            x, mu, logvar = net(batch)
            loss = vae_loss_fn(batch, x, mu, logvar)
            loss.backward()
            optim.step()
        
        evaluate(validation_losses, net, test_dataloader, vae=True)

In [22]:
vae = VAE(28*28).to(device)
train_vae(vae, train_dataloader, test_dataloader)

0.9095661044120789
0.8928028345108032
0.8846414089202881
0.8809676170349121
0.8797862529754639
0.8788581490516663
0.8773050308227539
0.8755577206611633
0.8746713399887085
0.8739920258522034


In [24]:
vae[0]

TypeError: 'VAE' object is not subscriptable