# Autoencoding beyond pixels using a learned similarity metric
* paper: [Autoencoding beyond pixels using a learned similarity metric](https://arxiv.org/pdf/1512.09300.pdf)

## Implementation Details
* **Training Data**: `FGVC-Aircraft`
* **Goal**: Generate Random New Images of Aircraft
* **Encoder Layers**: *In progress*
* **Decoder Layers**: *In progress*

## Notes
* The results of this notebook are compared against those from a basic implementation of a vanilla [VAE](https://github.com/bradley-ray/implementing-deep-learn-papers/blob/master/variational_autoencoder.ipynb).

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as T

from torchvision import datasets
from torch.utils.data import DataLoader, Subset


In [2]:
# Mini-batch size handed to the training DataLoader.
BS = 64
# Presumably the number of training epochs; no visible code reads it yet
# (train_model is still a stub) -- TODO confirm once training is written.
EPOCHS = 10

# Get Dataset

In [3]:
# FGVC-Aircraft dataset: both splits share the same preprocessing recipe.

img_size = (64, 64)
# Convert each sample to a tensor first, then resize it to img_size.
ts = [T.ToTensor(), T.Resize(img_size)]

# Build the train and test splits; each one gets its own Compose instance.
train, test = (
    datasets.FGVCAircraft(
        root='data',
        split=split_name,
        download=True,
        transform=T.Compose(ts),
    )
    for split_name in ('train', 'test')
)


# When use_subset is set, only the first 750 training samples are loaded
# (presumably for quicker experiments); otherwise the full split is used.
use_subset = True
if not use_subset:
    train_loader = DataLoader(train, batch_size=BS, shuffle=True)
else:
    subset = Subset(train, range(750))
    train_loader = DataLoader(subset, batch_size=BS, shuffle=True)

Downloading https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz to data/fgvc-aircraft-2013b.tar.gz


  0%|          | 0/2753340328 [00:00<?, ?it/s]

Extracting data/fgvc-aircraft-2013b.tar.gz to data


# Network

In [5]:
# Currently uses the architecture from the paper, with LeakyReLU in place of ReLU.
class Encoder(nn.Module):
    """Convolutional encoder that maps images to latent mean and log-variance.

    Three 5x5 convolutions with stride 2 (each followed by batch norm and
    LeakyReLU) halve the spatial size three times. The flattened feature map
    then feeds two separate linear heads: one for the mean and one for the
    log-variance.

    Args:
        img_channels: Number of channels in the input images.
        z_dim: Size of the latent vector produced by each head.
        flat_size: Number of features after flattening the conv output. The
            default ``8*8*256`` matches 64x64 inputs (64 -> 32 -> 16 -> 8
            spatially, with 256 output channels).
    """

    def __init__(self, img_channels: int=3, z_dim: int=2048, flat_size: int=8*8*256):
        super().__init__()
        self.downsamples = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(.2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(.2),
        )

        self.mean = nn.Linear(flat_size, z_dim)

        # Tanh keeps the predicted log-variance inside (-1, 1).
        self.log_var = nn.Sequential(
            nn.Linear(flat_size, z_dim),
            nn.Tanh(),
        )

    def forward(self, X):
        """Encode a batch of images.

        Args:
            X: 4-D image batch ``(N, img_channels, H, W)``. After the three
                stride-2 convolutions the flattened features must contain
                exactly ``flat_size`` values per sample.

        Returns:
            Tuple ``(mean, log_var)``, each of shape ``(N, z_dim)``.
        """
        # Collapse (C, H, W) into one feature axis for the linear heads.
        # The per-batch debug print of the flattened shape was removed so the
        # forward pass no longer writes to stdout on every call.
        features = torch.flatten(self.downsamples(X), 1, 3)
        mean = self.mean(features)
        log_var = self.log_var(features)
        return mean, log_var

class Generator(nn.Module):
    """Decoder/generator that maps latent vectors to images.

    A linear layer expands the latent vector to ``flat_size`` features, which
    are reshaped into a feature map. Three stride-2 transposed convolutions
    each double the spatial size, and a final stride-1 transposed convolution
    with Tanh produces the image.

    Args:
        img_channels: Number of channels in the generated images.
        z_dim: Size of the input latent vector.
        flat_size: Number of features produced by the linear layer; must equal
            the product of the ``size`` tuple given to ``forward``
            (default ``8*8*256`` == 256 * 8 * 8).
    """

    def __init__(self, img_channels: int=3, z_dim: int=2048, flat_size: int=8*8*256):
        super().__init__()

        self.fc = nn.Sequential(
            nn.Linear(z_dim, flat_size),
            nn.BatchNorm1d(flat_size),
            nn.LeakyReLU(.2),
        )

        # TODO: experiment with tanh on last layer with some other activation func
        # TODO: investigate convtranspose dim calculations (had to lookup to figure out how to get dim to match)
        # With kernel 5, stride 2, padding 2 and output_padding 1, each
        # transposed conv maps size s to (s - 1) * 2 - 4 + 5 + 1 == 2 * s.
        self.upsamples = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(.2),

            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(.2),

            nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(.2),

            # Stride 1 keeps the spatial size; Tanh bounds the output to [-1, 1].
            # NOTE(review): the data cell feeds images through ToTensor without
            # normalisation -- confirm the pixel ranges agree before comparing
            # generated and real images.
            nn.ConvTranspose2d(in_channels=32, out_channels=img_channels, kernel_size=5, padding=2),
            nn.Tanh(),
        )

    def forward(self, X, size=(256,8,8)):
        """Generate images from latent vectors.

        Args:
            X: Latent batch of shape ``(N, z_dim)``.
            size: ``(channels, height, width)`` that the linear output is
                reshaped to before upsampling. The first transposed conv
                expects 256 channels.

        Returns:
            Image batch with values in [-1, 1]; with the default ``size`` the
            shape is ``(N, img_channels, 64, 64)``.
        """
        fc_out = self.fc(X)
        # The per-batch debug print of the fc output shape was removed so the
        # forward pass no longer writes to stdout on every call.
        reshape = fc_out.view(-1, size[0], size[1], size[2])
        out = self.upsamples(reshape)
        return out

class Discriminator(nn.Module):
    """Convolutional discriminator that also exposes an intermediate feature map.

    A stride-1 5x5 convolution is followed by three stride-2 downsampling
    stages and a two-layer fully connected head ending in a sigmoid.

    Args:
        img_channels: Number of channels in the input images.
        z_dim: Accepted for signature parity with the other networks; it is
            not used by this class.
        flat_size: Number of features after flattening the last downsampling
            stage (default ``8*8*256``).
    """

    @staticmethod
    def _down_block(in_ch, out_ch):
        """Return the layers of one stride-2 conv -> batch norm -> LeakyReLU stage."""
        return [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(.2),
        ]

    def __init__(self, img_channels: int=3, z_dim: int=2048, flat_size: int=8*8*256):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=32, kernel_size=5, padding=2),
            nn.LeakyReLU(.2),
        )

        # The "l" layer for the dis_l loss is chosen to be the output of the
        # third downsampling stage in the discriminator.
        widths = (32, 128, 256, 256)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.extend(self._down_block(c_in, c_out))
        self.downsamples = nn.Sequential(*stages)

        self.fc = nn.Sequential(
            nn.Linear(flat_size, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(.2),

            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Score a batch of images.

        Args:
            z: Image batch ``(N, img_channels, H, W)``; despite the name it is
                fed straight into the first convolution.

        Returns:
            Tuple ``(out, downsamples_out)``: the sigmoid score of shape
            ``(N, 1)`` and the feature map after the last downsampling stage,
            which is needed for the dis_l loss.
        """
        features = self.downsamples(self.initial(z))
        score = self.fc(torch.flatten(features, start_dim=1, end_dim=3))
        return score, features

# Training

In [None]:
# Loss functions
def kl_loss():
    """Placeholder for the KL-divergence loss term; not implemented yet.

    NOTE(review): presumably this will take the encoder's mean and
    log-variance -- signature to be confirmed when it is implemented.
    """
    ...

def loss_gan():
    """Placeholder for the GAN (adversarial) loss; not implemented yet.

    NOTE(review): presumably computed from the discriminator's scores --
    signature to be confirmed when it is implemented.
    """
    ...

def like_loss():
    """Placeholder for the likelihood/similarity loss; not implemented yet.

    NOTE(review): presumably the dis_l loss mentioned in the Discriminator
    comments, computed on its intermediate feature map -- TODO confirm.
    """
    ...

def get_sample():
    """Placeholder for a sampling helper; not implemented yet.

    NOTE(review): presumably draws a latent sample (e.g. from the encoder's
    mean and log-variance) -- TODO confirm when it is implemented.
    """
    ...

# Training function
def train_model():
    """Placeholder for the training loop; not implemented yet.

    Currently takes no arguments and returns ``None``.
    """
    ...

In [None]:
# Placeholder invocation: train_model currently accepts no parameters, so
# passing `...` raises TypeError until its real signature is written.
train_model(...)

# Testing