Reference: [Adversarial Autoencoders with PyTorch](https://blog.paperspace.com/adversarial-autoencoders-with-pytorch/)

In [4]:
import argparse
import pickle
import numpy as np

In [5]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [11]:
# Hyper-parameters for the notebook, wrapped in an argparse.Namespace below so
# the rest of the code can read them as ``args.<name>``.
# NOTE: despite its name, ``parser`` is a plain dict, not an ArgumentParser.
parser = {
    'n_classes': 10,
    # Size of the latent (Gaussian) code. This dict used to list 'z_dim'
    # twice (2, then 10); Python keeps the last duplicate, so 10 was the value
    # actually in effect. Only that effective entry is kept here.
    'z_dim': 10,
    # Input size fed to the encoder (784 = 28*28, presumably flattened MNIST).
    'X_dim': 784,
    'y_dim': 10,
    # Hidden-layer width of the encoder, decoder and discriminator.
    'N': 1000,
    'batch_size': 16,
    'epochs': 10,
    'no_cuda': True,
    'seed': 7,
    'log_interval': 10,
    'h_dim': 200,
    'lr': 0.001,
    'betas': (0.9, 0.999),
    'lr_decay': 0.95,
}

args = argparse.Namespace(**parser)

In [14]:
# Use the GPU only if it was not disabled in ``args`` and one is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    # The previous ``torch.cuda.manual`` does not exist and raised
    # AttributeError as soon as CUDA was enabled.
    torch.cuda.manual_seed(args.seed)

# Extra DataLoader options; only worth setting when training on the GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# Define networks

In [15]:
# Encoder
class Q_net(nn.Module):
    """Encoder: map a flat input vector to a Gaussian latent code ``z``.

    Pipeline (dropout with p=0.2 is active only in training mode)::

        x -> lin1 -> dropout -> relu -> lin2 -> dropout -> relu -> lin3gauss -> z

    Args:
        X_dim: input size; defaults to ``args.X_dim``.
        N: hidden-layer width; defaults to ``args.N``.
        z_dim: latent code size; defaults to ``args.z_dim``.
    """

    def __init__(self, X_dim=None, N=None, z_dim=None):
        super().__init__()
        # Fall back to the notebook-level ``args`` so ``Q_net()`` keeps working.
        X_dim = args.X_dim if X_dim is None else X_dim
        N = args.N if N is None else N
        z_dim = args.z_dim if z_dim is None else z_dim
        self.lin1 = nn.Linear(X_dim, N)
        self.lin2 = nn.Linear(N, N)
        # Gaussian code (z)
        self.lin3gauss = nn.Linear(N, z_dim)

    def forward(self, x):
        """Return the latent code for ``x`` (last dimension must be ``X_dim``)."""
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)
        return self.lin3gauss(x)

In [16]:
# Decoder
class P_net(nn.Module):
    """Decoder: map a latent code ``z`` back to an input-sized vector in (0, 1).

    Pipeline (dropout with p=0.2 is active only in training mode)::

        z -> lin1 -> dropout -> relu -> lin2 -> dropout -> lin3 -> sigmoid

    Args:
        X_dim: output size; defaults to ``args.X_dim``.
        N: hidden-layer width; defaults to ``args.N``.
        z_dim: latent code size; defaults to ``args.z_dim``.
    """

    def __init__(self, X_dim=None, N=None, z_dim=None):
        super().__init__()
        # Fall back to the notebook-level ``args`` so ``P_net()`` keeps working.
        X_dim = args.X_dim if X_dim is None else X_dim
        N = args.N if N is None else N
        z_dim = args.z_dim if z_dim is None else z_dim
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, X_dim)

    def forward(self, x):
        """Decode ``x`` (last dimension must be ``z_dim``) to values in (0, 1)."""
        x = self.lin1(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.dropout(x, p=0.2, training=self.training)
        # NOTE(review): unlike the encoder and discriminator, there is no ReLU
        # between lin2 and lin3, so these two layers compose to one affine map.
        # Kept as-is so previously trained weights behave the same; confirm
        # whether this is intended.
        x = self.lin3(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        return torch.sigmoid(x)

In [17]:
class D_net_gauss(nn.Module):
    """Discriminator on the latent code: map ``z`` to a single score in (0, 1).

    The score is presumably the probability that ``z`` was drawn from the
    target prior rather than produced by the encoder (confirm against the
    training loop). Pipeline (dropout with p=0.2 is active only in training mode)::

        z -> lin1 -> dropout -> relu -> lin2 -> dropout -> relu -> lin3 -> sigmoid

    Args:
        z_dim: latent code size; defaults to ``args.z_dim``.
        N: hidden-layer width; defaults to ``args.N``.
    """

    def __init__(self, z_dim=None, N=None):
        super().__init__()
        # Fall back to the notebook-level ``args`` so ``D_net_gauss()`` keeps working.
        z_dim = args.z_dim if z_dim is None else z_dim
        N = args.N if N is None else N
        self.lin1 = nn.Linear(z_dim, N)
        self.lin2 = nn.Linear(N, N)
        self.lin3 = nn.Linear(N, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of scores in (0, 1) for codes ``x``."""
        x = F.dropout(self.lin1(x), p=0.2, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=0.2, training=self.training)
        x = F.relu(x)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        return torch.sigmoid(self.lin3(x))

# Utility functions

In [18]:
def save_model(model, filename):
    """Write ``model.state_dict()`` to ``filename`` with ``torch.save``.

    A notice is printed before saving. This helper does not check whether
    the model is actually the best one; the caller decides that.
    """
    notice = 'Best model so far, saving it...'
    print(notice)
    weights = model.state_dict()
    torch.save(weights, filename)

In [19]:
def report_loss(epoch, D_loss_gauss, G_loss, recon_loss):
    '''
    Print the discriminator, generator and reconstruction losses for an epoch.

    Each loss may be a single-element tensor (0-dim or shape (1,)) or a plain
    Python number; it is converted with ``float``. The former ``.data[0]``
    indexing raises IndexError on the 0-dim tensors that PyTorch >= 0.4 loss
    functions return.
    '''
    print('Epoch-{}; D_loss_gauss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'.format(
        epoch,
        float(D_loss_gauss),
        float(G_loss),
        float(recon_loss)))

In [20]:
def create_latent(Q, loader):
    '''
    Encode every sample in ``loader`` with ``Q`` and collect the latent codes.

    Args:
        Q: encoder module. It is switched to eval mode and left in eval mode.
        loader: iterable of ``(X, target)`` batches, e.g. a DataLoader.
    Returns:
        z_values: float64 numpy array with one latent row per sample
            (an empty array if ``loader`` yields no batches).
        labels: numpy array of the targets, aligned row-for-row with z_values.
    '''
    Q.eval()
    # Send inputs to wherever Q's parameters live. The original consulted an
    # undefined global ``cuda`` and raised NameError.
    first_param = next(Q.parameters(), None)
    device = first_param.device if first_param is not None else None

    z_chunks = []
    labels = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for X, target in loader:
            # Presumably undoes transforms.Normalize((0.1307,), (0.3081,))
            # (MNIST mean/std). Confirm against the loader's transforms.
            X = X * 0.3081 + 0.1307
            # Flatten each sample to a vector, e.g. (B, 1, 28, 28) -> (B, 784).
            # This is a no-op for input that is already flat.
            X = X.reshape(X.size(0), -1)
            labels.extend(target.tolist())
            if device is not None:
                X = X.to(device)
            z_chunks.append(Q(X).cpu().double().numpy())

    # Concatenate once at the end rather than per batch, which was quadratic.
    z_values = np.concatenate(z_chunks) if z_chunks else np.empty((0,))
    labels = np.array(labels)

    return z_values, labels
