# MNIST AutoEncoder
Implementation of a vanilla (fully connected, no CNN) autoencoder.
#### Losses
* Reconstruction loss (e.g. MSE; the training loop below uses the Huber / smooth-L1 loss)

#### References
* [Paper](https://arxiv.org/pdf/1511.05644.pdf)
* https://github.com/neale/Adversarial-Autoencoder
* https://github.com/bfarzin/pytorch_aae
* https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/

In [1]:
import mnist_data_pytorch as data
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
from tqdm import tqdm
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('Device:', device)
print('Pytorch version:', torch.__version__)
# Tensorboard
import shutil
from torch.utils.tensorboard import SummaryWriter
# Start each run from an empty log directory. shutil.rmtree replaces the
# IPython-only `!rm -rf ./runs` shell escape so this cell is valid Python and
# also works where `rm` is unavailable; ignore_errors=True keeps the
# "no error if the directory does not exist" behaviour of `rm -rf`.
shutil.rmtree('./runs', ignore_errors=True)
writer = SummaryWriter('./runs/train')

# Metaparameters
gen_lr = 1e-3      # Adam learning rate, used for both the encoder and decoder optimizers
latent_size = 50   # size of the latent vector (z_dim of Encoder and Decoder)
num_epochs = 30    # number of passes over the training dataloader

Device: cuda:0
Pytorch version: 1.2.0


#### Define Encoder/Decoder

In [2]:
class Encoder(nn.Module):
    """Fully connected encoder: input vector -> latent code.

    Two 1000-unit hidden layers, each followed by dropout (p=0.25) and ReLU,
    then a linear projection to ``z_dim`` with no activation.

    Args:
        X_dim: number of input features.
        z_dim: size of the latent code.
    """

    def __init__(self, X_dim, z_dim):
        super().__init__()
        hidden_units = 1000
        self.lin1 = nn.Linear(X_dim, hidden_units)
        self.lin2 = nn.Linear(hidden_units, hidden_units)
        self.latent = nn.Linear(hidden_units, z_dim)

    def forward(self, x):
        """Return the latent code for ``x`` (last dimension must be ``X_dim``)."""
        # Both hidden stages are identical: linear -> dropout -> ReLU.
        # Dropout is only active while the module is in training mode.
        for hidden in (self.lin1, self.lin2):
            x = F.relu(F.dropout(hidden(x), p=0.25, training=self.training))
        return self.latent(x)
    

class Decoder(nn.Module):
    """Fully connected decoder: latent code -> reconstructed input vector.

    Two 1000-unit hidden layers, each followed by dropout (p=0.25) and ReLU,
    then a linear output layer squashed to (0, 1) by a sigmoid.

    Args:
        X_dim: number of output features (size of the reconstructed vector).
        z_dim: size of the latent code fed to the decoder.
    """

    def __init__(self, X_dim, z_dim):
        super(Decoder, self).__init__()
        self.lin1 = nn.Linear(z_dim, 1000)
        self.lin2 = nn.Linear(1000, 1000)
        self.lin3 = nn.Linear(1000, X_dim)

    def forward(self, x):
        """Return the reconstruction for latent batch ``x`` (last dim ``z_dim``)."""
        x = F.dropout(self.lin1(x), p=0.25, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.25, training=self.training)
        # Non-linearity after the second hidden layer, mirroring Encoder.
        # Without it, lin2 and lin3 are two stacked affine maps and collapse
        # into a single linear layer at inference time.
        x = F.relu(x)
        x = self.lin3(x)
        x = torch.sigmoid(x)
        return x
    
# Initialize Networks
# Inputs are flattened 28x28 images, hence 28 * 28 = 784 features per sample.
encoder = Encoder(28 * 28, latent_size)
decoder = Decoder(28 * 28, latent_size)
# nn.Module.to() moves the parameters in place, so the names above stay valid.
encoder.to(device)
decoder.to(device)

#### Initialize Optimizers

In [3]:
# One Adam optimizer per network; both use the same learning rate (gen_lr).
optim_encoder = optim.Adam(params=encoder.parameters(), lr=gen_lr)
optim_decoder = optim.Adam(params=decoder.parameters(), lr=gen_lr)

#### Train loop

In [4]:
# Train the encoder/decoder pair on the reconstruction loss only.
for epoch in tqdm(range(num_epochs)):
    # Accumulates (batch loss * batch size) so the epoch value below is
    # weighted correctly even when the last batch is smaller.
    running_loss_reconstruction = 0.0
    # Iterate over the data; the labels are ignored
    for inputs, _ in data.dataloaders['train']:
        inputs = inputs.to(device)
        # Collapse every dimension after the batch dimension into one vector
        inputs = torch.flatten(inputs, start_dim=1, end_dim=-1)

        # Zero gradients
        optim_encoder.zero_grad()
        optim_decoder.zero_grad()

        # Encode the batch into its latent representation
        z_sample = encoder(inputs)

        # Reconstruct X
        inputs_reconstruct = decoder(z_sample)
        # Huber (smooth L1) loss; F.mse_loss is a drop-in alternative
        reconstruct_loss = F.smooth_l1_loss(inputs_reconstruct, inputs)

        # Backprop from reconstruction loss
        reconstruct_loss.backward()
        # Update encoder and decoder weights
        optim_encoder.step()
        optim_decoder.step()

        # Update statistics
        running_loss_reconstruction += reconstruct_loss.item() * inputs.size(0)

    # Epoch ends: average the accumulated loss over the whole training set
    epoch_loss_reconstruction = running_loss_reconstruction / len(data.dataloaders['train'].dataset)

    # Send results to tensorboard
    writer.add_scalar('train/reconstruction', epoch_loss_reconstruction, epoch)

    # Send images to tensorboard (last batch of the epoch, reshaped to 1x28x28)
    writer.add_images('train/decoder_images', inputs_reconstruct.view(inputs.size(0), 1, 28, 28), epoch)
    writer.add_images('train/input_images', inputs.view(inputs.size(0), 1, 28, 28), epoch)

    # Send latent and pixel-value distributions of the last batch to tensorboard
    writer.add_histogram('train/latent', z_sample, epoch)
    writer.add_histogram('train/reconstruct_images_h', inputs_reconstruct, epoch)
    writer.add_histogram('train/input_images_h', inputs, epoch)
    

100%|██████████| 30/30 [01:46<00:00,  3.56s/it]
