# Semisupervised Adversarial Autoencoder
_(Requires Python 3, PyTorch 1.0.1, TorchVision 0.2.2)_

**Reference**: _A.Makhzani et al.,_ [Adversarial Autoencoders](https://arxiv.org/abs/1511.05644)

### Libraries
Import PyTorch and the submodules we need:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.utils.data as Data
import sys

We'll also need NumPy and Matplotlib:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Model
The design is essentially the same as the unsupervised adversarial autoencoder, with only one difference: a subset of input samples will have label information!

We will use this extra information to place different digits in different regions of the latent space. For instance, to accommodate the 10 digits we can construct a mixture of 10 Gaussian distributions and assign one mixture component to each digit.

#### Sample information as latent space regularizer

The prior for the latent space in this example will be a mixture of 10 Gaussian distributions, with standard deviation 1, and their means equally spaced on a circle of radius 10:

In [None]:
# Means of the ten mixture components, equally spaced on a circle of radius 10
x_mean = 10*np.cos(np.arange(0,10) * 2*np.pi/10)
y_mean = 10*np.sin(np.arange(0,10) * 2*np.pi/10)

# 1000 unit-variance samples around each mean, stacked so that rows
# k*1000:(k+1)*1000 belong to mixture component k
samples = np.concatenate([
    np.random.multivariate_normal(np.array([x_mean[digit], y_mean[digit]]), np.eye(2), size=1000)
    for digit in np.arange(10)
])

In [None]:
# Full prior: all ten mixture components, coloured by component index
plt.scatter(x=samples[:, 0], y=samples[:, 1], s=0.7,
            c=np.repeat(np.arange(10), 1000))
plt.xlim(left=-15, right=15)
plt.ylim(bottom=-15, top=15)
plt.show()

The label information will be used as a _mixture switch_, _i.e._ to pick one of the distributions from the mixture model, thereby forcing the encoder to place the digit in that region of the latent space.

For instance, when the input sample has label `0`, the following distribution will be passed to the discriminator:

In [None]:
# Prior shown to the discriminator for a sample labelled 0: component 0 only
plt.scatter(x=samples[:1000, 0], y=samples[:1000, 1], s=0.7,
            c=np.repeat(np.arange(10), 1000)[:1000])
plt.xlim(left=-15, right=15)
plt.ylim(bottom=-15, top=15)
plt.show()

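And the same for the digit `7`, which the model assigns to mixture component 6 (see the `digit_order` permutation in the `SSAAE` class below):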

In [None]:
# Rows 6000:7000 hold mixture component 6. The SSAAE class remaps labels to
# components with digit_order = [0,5,9,4,8,2,1,6,3,7], so label 7 is sent to
# component 6 -- this is the prior the discriminator sees for a labelled 7.
plt.scatter(x=samples[6000:7000, 0], y=samples[6000:7000, 1], s=0.7,
            c=np.repeat(np.arange(10), 1000)[6000:7000])
plt.xlim(left=-15, right=15)
plt.ylim(bottom=-15, top=15)
plt.show()

But what happens when the input sample has no label information? We let the encoder choose where to place it, but only within the limits of our mixture: in effect, we present the full mixture to the discriminator. We expect the unlabeled samples to eventually be placed close to the mixture component of their (unknown) class.

So, when the input sample is of `unknown` class, the presented prior will be:

In [None]:
# Prior for an unlabeled sample: the whole mixture, without class colouring
plt.scatter(x=samples[:, 0], y=samples[:, 1], s=0.7, c='gray')
plt.xlim(left=-15, right=15)
plt.ylim(bottom=-15, top=15)
plt.show()

#### Parameter check

In [None]:
def parse_args(hidden_dims, activations):
    """Normalize layer-width and activation arguments into matching lists.

    Args:
        hidden_dims: a single integer (Python or NumPy) for one hidden layer,
            or an iterable of integers, one per hidden layer.
        activations: a single callable applied after every hidden layer, or a
            list/tuple of callables with one entry per hidden layer.

    Returns:
        ``(hidden_dims, activations)`` as two lists of equal length.

    Raises:
        ValueError: if a list/tuple of activations does not have one entry
            per hidden layer.
    """
    # A lone integer means a single hidden layer (NumPy integers included,
    # since they are not subclasses of int).
    if isinstance(hidden_dims, (int, np.integer)):
        hidden_dims = [hidden_dims]
    else:
        hidden_dims = list(hidden_dims)

    if isinstance(activations, (list, tuple)):
        if len(activations) != len(hidden_dims):
            raise ValueError('activations and hidden_dims must have the same dimensions')
        activations = list(activations)
    else:
        # One activation shared by every hidden layer
        activations = [activations] * len(hidden_dims)

    return hidden_dims, activations


### Encoder

In [None]:
class Encoder(nn.Module):
    """Fully-connected encoder mapping flattened inputs to a 2-D latent code.

    Args:
        input_dim: size of each (flattened) input vector.
        hidden_dims: an int or a list of hidden layer widths.
        activations: one activation callable applied after every hidden
            layer, or a list with one callable per hidden layer.
        dropout: if True, apply dropout to each hidden pre-activation.
        dropout_p: dropout probability used when ``dropout`` is True.
    """

    def __init__(self, input_dim, hidden_dims, activations=func.relu, dropout=False, dropout_p=0.3):
        super(Encoder, self).__init__()

        # The latent space is fixed to two dimensions in this example
        latent_size = 2

        sizes, acts = parse_args(hidden_dims, activations)

        self.hidden_dims = np.array(sizes)
        self.activations = acts
        self.dropout = dropout
        self.dropout_p = dropout_p

        # Hidden layers: input_dim -> sizes[0] -> sizes[1] -> ...
        widths = [input_dim] + list(sizes)
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # Linear projection onto the latent space (no activation)
        self.out_layer = nn.Linear(widths[-1], latent_size)

    def forward(self, x):
        """Encode a batch ``x`` into latent codes."""
        h = x
        for linear, act in zip(self.layers, self.activations):
            pre = linear(h)
            if self.dropout:
                pre = func.dropout(pre, training=self.training, p=self.dropout_p)
            h = act(pre)
        return self.out_layer(h)

### Decoder

In [None]:
class Decoder(nn.Module):
    """Fully-connected decoder mapping 2-D latent codes back to input space.

    Args:
        hidden_dims: an int or a list of hidden layer widths.
        output_dim: size of each reconstructed (flattened) output vector.
        activations: one activation callable applied after every hidden
            layer, or a list with one callable per hidden layer.
        dropout: if True, apply dropout to each hidden pre-activation.
        dropout_p: dropout probability used when ``dropout`` is True.
    """

    def __init__(self, hidden_dims, output_dim, activations=func.relu, dropout=False, dropout_p=0.3):
        super(Decoder, self).__init__()

        # The latent space is fixed to two dimensions in this example
        latent_size = 2

        sizes, acts = parse_args(hidden_dims, activations)

        self.hidden_dims = np.array(sizes)
        self.activations = acts
        self.dropout = dropout
        self.dropout_p = dropout_p

        # Hidden layers: latent_size -> hidden_dims[0] -> hidden_dims[1] -> ...
        widths = [latent_size] + list(self.hidden_dims)
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # Output layer; forward() squashes it with a sigmoid
        self.out_layer = nn.Linear(widths[-1], output_dim)

    def forward(self, x):
        """Decode latent codes ``x`` into outputs with values in (0, 1)."""
        h = x
        for linear, act in zip(self.layers, self.activations):
            pre = linear(h)
            if self.dropout:
                pre = func.dropout(pre, training=self.training, p=self.dropout_p)
            h = act(pre)
        return torch.sigmoid(self.out_layer(h))

### Discriminator

In [None]:
class Discriminator(nn.Module):
    """Discriminator on (latent code, label) pairs.

    The label is appended to the latent code as a one-hot vector of length
    ``class_count + 1``; the extra slot (index ``class_count``) is the label
    this notebook gives to unlabeled samples.

    Args:
        hidden_dims: an int or a list of hidden layer widths.
        class_count: number of real classes.
        activations: one activation callable for every hidden layer, or a
            list with one callable per hidden layer.
        dropout: if True, apply dropout to each hidden pre-activation.
        dropout_p: dropout probability used when ``dropout`` is True.
    """

    def __init__(self, hidden_dims, class_count, activations=func.relu, dropout=False, dropout_p=0.3):
        super(Discriminator, self).__init__()

        # Latent dimensions is hard-coded in this example
        lat_dim = 2

        # Parse input arguments
        hidden_dims, activations = parse_args(hidden_dims, activations)

        # Store arguments
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations
        self.dropout     = dropout
        self.dropout_p   = dropout_p
        self.class_count = class_count

        # One-hot look-up table: row k is the one-hot code of label k
        self.to_one_hot = torch.eye(self.class_count+1)

        # Create layers (input = latent code + one-hot label)
        self.layers = nn.ModuleList()
        prev_d = lat_dim + class_count + 1
        for d in self.hidden_dims:
            self.layers.append(nn.Linear(prev_d, d))
            prev_d = d

        # Output layer
        self.out_layer = nn.Linear(prev_d, 1)

    def forward(self, z, l_int):
        """Score latent codes ``z`` conditioned on integer labels ``l_int``.

        Args:
            z: latent codes, shape (batch, 2).
            l_int: integer labels in [0, class_count], one per row of ``z``.

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1); the training
            loop uses 1 as the target for prior samples and 0 for encoder
            outputs.
        """
        # The one-hot table is a plain attribute, not a registered buffer, so
        # module.to()/.cuda()/.double() do not move or cast it. Match z's
        # device and dtype here so torch.cat works wherever the model lives.
        one_hot = self.to_one_hot.to(device=z.device, dtype=z.dtype)[l_int]
        x = torch.cat((z, one_hot), dim=1)

        for layer, activation in zip(self.layers, self.activations):
            h = layer(x)
            if self.dropout:
                h = func.dropout(h, training=self.training, p=self.dropout_p)
            x = activation(h)

        return torch.sigmoid(self.out_layer(x))

### Semisupervised Adversarial Autoencoder

In [None]:
class SSAAE(nn.Module):
    """Semi-supervised adversarial autoencoder.

    Bundles an encoder, a decoder and a label-conditioned discriminator, and
    draws samples from the label-dependent Gaussian-mixture prior.

    Args:
        encoder_module: maps inputs to 2-D latent codes.
        decoder_module: maps latent codes back to input space.
        discrim_module: called as ``discrim_module(z, labels)``; must expose
            a ``class_count`` attribute.
    """

    def __init__(self, encoder_module, decoder_module, discrim_module):
        super(SSAAE, self).__init__()

        # Parameters: taken from the discriminator passed in (not from a
        # global variable), so any discriminator instance can be used.
        self.class_count = discrim_module.class_count

        # AAE modules
        self.encoder = encoder_module
        self.decoder = decoder_module
        self.discrim = discrim_module

        # Prior distribution: one unit-variance Gaussian per class, with the
        # means equally spaced on a circle of radius class_count.
        x_mean = self.class_count * np.cos(np.arange(0,self.class_count) * 2*np.pi/self.class_count).reshape((-1,1))
        y_mean = self.class_count * np.sin(np.arange(0,self.class_count) * 2*np.pi/self.class_count).reshape((-1,1))
        self.p_mean = np.concatenate((x_mean, y_mean), axis=1)

        # Number mapping (place them by similarity): label k uses circle
        # position digit_order[k]. The permutation is specific to the ten
        # MNIST digits, so it is only applied when there are exactly 10 classes.
        digit_order = [0,5,9,4,8,2,1,6,3,7]
        if self.class_count == len(digit_order):
            self.p_mean = self.p_mean[digit_order]

    def p_samples(self, l):
        """Draw one latent sample per label from the label-dependent prior.

        Args:
            l: integer label tensor (expected 1-D); the value ``class_count``
               marks an unlabeled sample. It is not modified.

        Returns:
            Tensor of shape (len(l), 2) on ``l``'s device. A labelled entry is
            drawn from N(p_mean[label], I); an unlabeled entry from a mixture
            component chosen uniformly at random.
        """
        size = l.shape[0]

        # Work on a private copy: Tensor.numpy() shares memory with the
        # tensor, so writing into it would overwrite the caller's labels
        # (turning "unknown" into random digits before the discriminator
        # sees them). .cpu() also makes this work for CUDA label tensors.
        mean_id = l.detach().cpu().numpy().copy()
        # Find unknowns
        no_l = np.where(mean_id == self.class_count)[0]
        # Replace unknowns by random sample of means
        mean_id[no_l] = np.random.randint(0,self.class_count,len(no_l))
        # Generate gaussian mixtures
        s = self.p_mean[mean_id,:] + np.random.randn(size, 2)

        return torch.Tensor(s).to(l.device)

    def forward(self, x, l):
        """Run encoder, decoder and (if labels are given) discriminator.

        Returns:
            ``(y, d, z)``: reconstruction, discriminator output on the codes
            (``0`` when ``l`` is None) and latent codes.
        """
        # Encoder
        z = self.encoder(x)
        # Decoder
        y = self.decoder(z)
        # Discriminator
        d = self.discrim(z, l) if l is not None else 0

        return y, d, z

## Training example

### Training Data
Let's train the semi-supervised adversarial autoencoder on the MNIST dataset.

In [None]:
import torchvision

# MNIST train/test splits; ToTensor turns each image into a float tensor
train = torchvision.datasets.MNIST(root='./', train=True, download=True,
                                   transform=torchvision.transforms.ToTensor())
test = torchvision.datasets.MNIST(root='./', train=False, download=True,
                                  transform=torchvision.transforms.ToTensor())

### Instantiate Autoencoder
Again, we will use the same network presented in the original paper:

In [None]:
# Network hyper-parameters (layer sizes as in the original paper)
input_dim   = 28 * 28   # flattened MNIST image
class_count = 10        # 10 digits (MNIST)
enc_layers  = [1000] * 2
dec_layers  = [1000] * 2
dis_layers  = [1000] * 2

# Build the three sub-networks and wrap them in the SSAAE
encoder = Encoder(input_dim=input_dim, hidden_dims=enc_layers, dropout=False)
decoder = Decoder(hidden_dims=dec_layers, output_dim=input_dim, dropout=False)
discrim = Discriminator(hidden_dims=dis_layers, class_count=class_count, dropout=False)

ssaae = SSAAE(encoder_module=encoder, decoder_module=decoder, discrim_module=discrim)

In [None]:
# Last expression of the cell: the notebook displays the model's repr,
# i.e. a summary of the encoder/decoder/discriminator layers.
ssaae

### Optimizer
We will follow the same logic as described in episode 5.

In [None]:
# One Adam optimizer per sub-network; the discriminator uses betas=(0.5, 0.9)
# instead of Adam's defaults.
encoder_optimizer = torch.optim.Adam(params=ssaae.encoder.parameters(), lr=1e-3)
decoder_optimizer = torch.optim.Adam(params=ssaae.decoder.parameters(), lr=1e-3)
discrim_optimizer = torch.optim.Adam(params=ssaae.discrim.parameters(), lr=1e-3,
                                     betas=(0.5, 0.9))

### Training loop

In [None]:
%matplotlib notebook

num_labels = 30000
n_test_img = 6
epochs     = 20
batch_size = 128

# Reshape data
train_samples = train.data.view(-1, 28*28).type(torch.float32)/255.0
test_samples  = test.data.view(-1,28*28).type(torch.float32)/255.0
n_train = train_samples.shape[0]
n_test  = test_samples.shape[0]

# Allow some labels for SemiSupervised Training
num_unknown = n_train - num_labels
train_targets_copy = train.targets.clone()
train.targets[np.random.choice(n_train, replace=False, size=num_unknown)] = 10

# Lists to store training losses
train_loss = []
test_loss  = []

# Plot test input images
test_imgs = test.data[0:n_test_img,:].type(torch.float32).view(-1,28*28)/255.0
f, a = plt.subplots(2, n_test_img, figsize=(8, 3))
for i in range(n_test_img):
    a[0][i].imshow(255-np.reshape(test_imgs.data.numpy()[i], (28,28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())
    
loss_text = f.text(0, 0.01, "epoch: 0, loss: 0")

# Discriminator targets
target_real = torch.ones(batch_size,1)
target_fake = torch.zeros(batch_size,1)
target_test_real = torch.ones(n_test,1)
target_test_fake = torch.ones(n_test,1)

# Data iterator
train_batches = Data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True, drop_last=True)

# Loss records
aenc_train_loss = np.zeros(epochs)
disc_train_loss = np.zeros(epochs)
disc_train_acc  = np.zeros(epochs)
encd_train_loss = np.zeros(epochs)

aenc_test_loss = np.zeros(epochs)
disc_test_loss = np.zeros(epochs)
disc_test_acc  = np.zeros(epochs)
encd_test_loss = np.zeros(epochs)

# Learning rate Schedulers
encoder_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(encoder_optimizer, 'min', patience=10)
decoder_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(decoder_optimizer, 'min', patience=10)
discrim_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(discrim_optimizer, 'min', patience=10)

for e in np.arange(epochs):
    # Batch Loss
    aenc_bloss = 0
    disc_bloss = 0
    encd_bloss = 0
    
    # Batch Accuracy
    disc_bacc = 0
    
    for batch_no, (batch, batch_label) in enumerate(train_batches):
        # Input and target data (flatten)
        batch_input = batch.view(-1, 28*28)
        
        ## Train discriminator ##
        # Samples
        ssaae.eval()
        z_fake = ssaae.encoder(batch_input)
        z_real = ssaae.p_samples(batch_label)
        
        # Forward data
        ssaae.train()
        d = ssaae.discrim(torch.cat((z_fake.detach(), z_real), 0), torch.cat((batch_label, batch_label), 0))
        
        # Discriminator loss
        disc_loss = func.binary_cross_entropy(d, torch.cat((target_fake, target_real),0))
        disc_bloss += disc_loss
        disc_bacc += torch.sum((d[:batch_size] < 0.5) + (d[batch_size:] > 0.5)).data.numpy() / (2*batch_size)
        
        # Compute gradients
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        discrim_optimizer.zero_grad()
        disc_loss.backward()
        
        # Update discriminator parameters
        discrim_optimizer.step()
        
        
        ## Train autoencoder ##
        # Forward pass of the data through the network
        y, d, z = ssaae(batch_input, batch_label)
        # Autoencoder loss
        aenc_loss = func.binary_cross_entropy(y, batch_input)
        fool_loss = func.binary_cross_entropy(d, target_real)
        encd_loss = 0.99*aenc_loss + 0.01*fool_loss
        #aenc_loss = func.mse_loss(y, batch_input)
        aenc_bloss += aenc_loss
        encd_bloss += encd_loss
        # Compute gradients
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        discrim_optimizer.zero_grad()
        encd_loss.backward()
        # Update encoder-decoder parameters
        encoder_optimizer.step()
        decoder_optimizer.step()

        # Test images
        if batch_no % 50 == 0:
            ssaae.eval()
            test_out, _, _ = ssaae(test_imgs, None)
            for i in range(n_test_img):
                a[1][i].imshow(1.0-np.reshape(test_out.data.numpy()[i], (28,28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            #a[1][-1].clear()
            #a[1][-1].scatter(z.data.numpy()[:,0], z.data.numpy()[:,1], c=batch_labels, s=0.5)
            loss_text.set_text("epoch: {}, reconstruct: {:.3f}, encoder: {:.3f}, discrim: {:.3f} ({:.2f}%)".format(
                e+1,
                aenc_bloss/(batch_no+1),
                encd_bloss/(batch_no+1),
                disc_bloss/(batch_no+1),
                disc_bacc/(batch_no+1)*100))
            f.canvas.draw()

    ## End of epoch
    # Train loss
    aenc_train_loss[e] = aenc_bloss/batch_no
    disc_train_loss[e] = disc_bloss/batch_no
    encd_train_loss[e] = encd_bloss/batch_no
    disc_train_acc[e] = disc_bacc/batch_no
    # Decrease encoder dist LR
    encoder_sched.step(encd_train_loss[e])
    decoder_sched.step(aenc_train_loss[e])
    discrim_sched.step(disc_train_acc[e])
    # Test loss
    # autoencoder loss
    ssaae.eval()
    y, d, z_fake = ssaae(test_samples, test.targets)
    aenc_test_loss[e] = func.binary_cross_entropy(y, test_samples)
    # fool loss
    encd_test_loss[e] = 0.99*aenc_test_loss[e] + 0.01*func.binary_cross_entropy(d, target_test_real)
    # discriminator loss
    z_real = ssaae.p_samples(test.targets)
    d = ssaae.discrim(torch.cat((z_fake.detach(), z_real), 0), torch.cat((test.targets, test.targets),0))
    disc_test_loss[e] = func.binary_cross_entropy(d, torch.cat((target_test_fake, target_test_real),0))
    disc_test_acc[e]  = torch.sum((d[:n_test] < 0.5) + (d[n_test:] > 0.5)).data.numpy() / (2*n_test)
    
train.targets = train_targets_copy

## Results
### Loss plots
#### Reconstruction loss

In [None]:
# Reconstruction (autoencoder) loss per epoch
plt.figure()
plt.plot(aenc_train_loss, label='train')
plt.plot(aenc_test_loss, label='test')
plt.legend()

#### Discriminator Loss

In [None]:
# Discriminator loss per epoch
plt.figure()
plt.plot(disc_train_loss, label='train')
plt.plot(disc_test_loss, label='test')
plt.legend()

#### Discriminator Accuracy

In [None]:
# Discriminator accuracy per epoch (fraction of correctly classified codes)
plt.figure()
plt.plot(disc_train_acc, label='train')
plt.plot(disc_test_acc, label='test')
plt.legend()

#### Encoder Distribution Loss

In [None]:
# Combined encoder loss (reconstruction + fooling term) per epoch
plt.figure()
plt.plot(encd_train_loss, label='train')
plt.plot(encd_test_loss, label='test')
plt.legend()

### Latent space
The projections of the test images in the latent space after training look like this:

In [None]:
# Encode the whole test set in evaluation mode (disables dropout) and without
# building an autograd graph, which would otherwise hold every activation.
ssaae.eval()
with torch.no_grad():
    test_d = ssaae.encoder(test_samples).numpy()
plt.figure(figsize=(10,7))
plt.scatter(test_d[:,0], test_d[:,1], c=test.targets.numpy(), s=1.7)

In [None]:
# Same projection for the 60k training images, coloured by their true labels
# (the training cell restores train.targets after hiding some of them).
# no_grad avoids keeping an autograd graph for the whole training set.
ssaae.eval()
with torch.no_grad():
    train_d = ssaae.encoder(train_samples).numpy()
plt.figure(figsize=(10,7))
plt.scatter(train_d[:,0], train_d[:,1], c=train.targets.numpy(), s=1.7)

### Image generation
Generate images by decoding points on a regular grid in the latent space:

In [None]:
# Decode a 41x41 grid over [-10, 10] x [-10, 10] (step 0.5). Rows run from
# y=+10 (top) to y=-10 (bottom) and columns from x=-10 (left) to x=+10
# (right), so the mosaic has the same orientation as the latent scatter plots.
grid_step = 0.5
x_samples = np.arange(-10, 10.1, grid_step)   # 41 samples from -10 to 10 (left -> right)
y_samples = np.arange(10, -10.1, -grid_step)  # 41 samples from 10 to -10 (top -> bottom)
n_x, n_y = len(x_samples), len(y_samples)

# Latent points in row-major (row = y, column = x) order, shape (n_y*n_x, 2)
yy, xx = np.meshgrid(y_samples, x_samples, indexing='ij')
z_grid = torch.Tensor(np.stack((xx.ravel(), yy.ravel()), axis=1))

# Decode every grid point in a single batched forward pass
ssaae.eval()
with torch.no_grad():
    tiles = 1.0 - ssaae.decoder(z_grid).view(n_y, n_x, 28, 28)

# (n_y, n_x, 28, 28) -> (n_y*28, n_x*28) mosaic
imgs = tiles.permute(0, 2, 1, 3).reshape(n_y*28, n_x*28).numpy()

plt.figure(figsize=(10,10))
plt.imshow(imgs, cmap=plt.cm.gray)