In [None]:
import  torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def plot_batch(x):
    """Show a batch of grayscale images side by side in a single row.

    Args:
        x: numpy array or torch tensor indexed as ``x[i]`` -> one image
           (callers in this notebook pass shape (batch, 28, 28)). Tensors are
           detached and copied to host memory before plotting.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that require grad can be converted to numpy.
        x = x.detach().cpu().numpy()

    n_cols = x.shape[0]
    # squeeze=False keeps `axs` 2-D even for a single image; otherwise
    # plt.subplots returns a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, n_cols, squeeze=False)
    for ax, img in zip(axs[0], x):
        ax.imshow(img, cmap="gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

def get_latent_img(decoder, n=30, d_z=2, random_samples=False, device=None, img_size=28):
    """Decode an n x n set of latent codes into one image mosaic.

    Args:
        decoder: callable mapping a (batch, d_z) float tensor to a
            (batch, img_size * img_size) output.
        n: number of tiles per side.
        d_z: latent dimensionality. Must be 2 unless ``random_samples`` is True.
        random_samples: if True, draw each latent code from N(0, I) with
            numpy's global RNG. If False, sweep a regular grid over [-2, 2]^2.
        device: device for the latent tensor. If None, it is taken from the
            decoder's first parameter, falling back to CPU.
        img_size: side length of each decoded tile.

    Returns:
        numpy array of shape (n * img_size, n * img_size).
        In grid mode, the tile column follows z[0] (increasing left to right)
        and the tile row follows z[1] (increasing bottom to top). This matches
        the axes of ``scatter_with_legend``.

    Raises:
        ValueError: if ``random_samples`` is False and ``d_z != 2``.
    """
    if device is None:
        first_param = next(decoder.parameters(), None) if isinstance(decoder, nn.Module) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    if random_samples:
        latents = np.random.normal(0, 1, size=[n * n, d_z])
    else:
        if d_z != 2:
            raise ValueError("grid mode requires d_z == 2, got d_z={}".format(d_z))
        grid = np.linspace(-2, 2, n)
        # 'xy' meshgrid: z0[r, c] = grid[c], z1[r, c] = grid[n - 1 - r]
        # (row 0 is the top of the image, so z[1] is largest there).
        z0, z1 = np.meshgrid(grid, grid[::-1])
        latents = np.stack([z0.ravel(), z1.ravel()], axis=1)

    latents = torch.as_tensor(latents, dtype=torch.float32, device=device)
    # A single batched forward pass instead of n * n one-sample calls.
    with torch.no_grad():
        x_gen = decoder(latents).cpu().numpy()

    # (n*n, s*s) -> (row, col, py, px) -> (row, py, col, px) -> mosaic.
    tiles = x_gen.reshape(n, n, img_size, img_size)
    return tiles.transpose(0, 2, 1, 3).reshape(n * img_size, n * img_size)

def scatter_with_legend(Z, Y):
    """Scatter 2-D points coloured by class label, with a legend.

    Args:
        Z: array-like indexed as Z[:, 0] (x-axis) and Z[:, 1] (y-axis).
        Y: array-like of class labels, one per row of Z.
    """
    Z = np.asarray(Z)
    Y = np.asarray(Y)
    classes = np.unique(Y)
    # Spread classes evenly over the jet colormap. Guard the single-class
    # case, where len(classes) - 1 == 0 would otherwise divide by zero.
    denom = max(len(classes) - 1, 1)
    for i, label in enumerate(classes):
        mask = Y == label
        plt.scatter(Z[mask, 0], Z[mask, 1], color=plt.cm.jet(i / denom), label=label, alpha=0.5)
    plt.legend()

    plt.show()

class Encoder(nn.Module):
    """MLP encoder producing the parameters of a diagonal Gaussian posterior.

    Args:
        d_in: size of the flattened input vector.
        d_hidden: width of the shared hidden layer.
        d_z: latent dimensionality.

    ``forward(x)`` returns ``(z_mu, z_logvar)``, each with last dimension ``d_z``.
    """

    def __init__(self, d_in, d_hidden, d_z):
        super().__init__()
        self.d_z = d_z
        # Shared trunk: a single hidden layer followed by ReLU.
        self.net = nn.Sequential(nn.Linear(d_in, d_hidden), nn.ReLU())
        # Two linear heads on top of the trunk. They stay wrapped in
        # one-element Sequentials so state_dict keys remain
        # "net_mean.0.*" / "net_logvar.0.*".
        self.net_mean = nn.Sequential(nn.Linear(d_hidden, d_z))
        self.net_logvar = nn.Sequential(nn.Linear(d_hidden, d_z))

    def forward(self, x):
        hidden = self.net(x)
        # The second head predicts log(std**2), so std = exp(0.5 * logvar).
        # This keeps std positive without a constrained activation.
        return self.net_mean(hidden), self.net_logvar(hidden)

class Decoder(nn.Module):
    """MLP decoder mapping latent codes to per-feature values in (0, 1).

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid.

    Args:
        d_z: latent dimensionality (input size).
        d_hidden: width of both hidden layers.
        d_out: output size (e.g. number of pixels).
    """

    def __init__(self, d_z, d_hidden, d_out):
        super().__init__()
        layers = [
            nn.Linear(d_z, d_hidden), nn.ReLU(),
            nn.Linear(d_hidden, d_hidden), nn.ReLU(),
            nn.Linear(d_hidden, d_out), nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.

    Gradients flow through ``mu`` and ``logvar``. The randomness comes only
    from a standard-normal draw shaped like ``logvar``.
    """
    noise = torch.randn_like(logvar)
    sigma = (0.5 * logvar).exp()
    return mu + sigma * noise

def init_weights(m):
    """Initialize plain ``nn.Linear`` modules: Xavier-uniform weights, bias 0.01.

    Meant for ``module.apply(init_weights)``. Other module types are left
    untouched. The exact type check (not isinstance) is kept on purpose, so
    Linear subclasses are also left alone. Linear layers built with
    ``bias=False`` have ``m.bias is None`` and are now handled.
    """
    if type(m) is nn.Linear:
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)

        
def fn_loss_rec(x_ori, x_rec):
    """Per-sample Bernoulli reconstruction loss (binary cross-entropy).

    Args:
        x_ori: targets with values in [0, 1].
        x_rec: predictions in [0, 1] (here the decoder's sigmoid output),
            same shape as ``x_ori``.

    Returns:
        BCE summed over dim 1, i.e. one value per sample for 2-D inputs.
    """
    # F.binary_cross_entropy clamps its log terms to >= -100. A saturated
    # sigmoid (exactly 0.0 or 1.0 in float32) therefore gives a large finite
    # loss instead of log(0) = -inf. It also avoids 0 * -inf = NaN when the
    # target is 0 or 1.
    loss = F.binary_cross_entropy(x_rec, x_ori, reduction="none")
    return loss.sum(dim=1)

def fn_loss_kld(mu, logvar):
    """Per-sample KL(N(mu, exp(logvar)) || N(0, I)), summed over dim 1.

    Closed form: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
    """
    terms = 1 + logvar - mu.pow(2) - torch.exp(logvar)
    return terms.sum(dim=1).mul(-0.5)

def fn_loss(x_ori, x_rec, mu, logvar, alpha=1, beta=1):
    """Batch-averaged weighted VAE objective: alpha * REC + beta * KLD.

    Returns:
        (loss, loss_rec, loss_kld), where the two components are the batch
        means of ``fn_loss_rec`` and ``fn_loss_kld`` respectively.
    """
    rec = fn_loss_rec(x_ori, x_rec).mean()
    kld = fn_loss_kld(mu, logvar).mean()
    total = alpha * rec + beta * kld
    return total, rec, kld

In [None]:
# --- Optimization ---
lr = 1e-3
num_epochs = 20
batch_size = 128
# --- Model sizes ---
d_in = 28**2      # flattened 28x28 MNIST image
d_z = 2           # latent size; the grid / scatter plots only run when d_z == 2
d_hidden = 512
# --- Loss weights: loss = alpha * REC + beta * KLD ---
# NOTE: the training cell reassigns alpha/beta before training.
alpha = 1
beta = 1
# --- Reproducibility / hardware ---
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
# Fall back to CPU so the notebook also runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# DATA
# One shared preprocessing pipeline: ToTensor gives floats in [0, 1], then the
# image is flattened to a d_in-length vector. Normalization is left off: the
# decoder's sigmoid output and the BCE reconstruction loss both expect targets
# in [0, 1].
mnist_transform = T.Compose([
    T.ToTensor(),
    T.Lambda(torch.flatten),
])

dl_train = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./files/', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Validation is not shuffled, so the visualized batch and plots are reproducible.
dl_valid = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./files/', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

In [None]:
#### RECONSTRUCTION SAMPLE IMGS
# Take a fixed handful of validation images. They are reused during training
# to monitor reconstructions.
n_vis = 8
x_vis = next(iter(dl_valid))[0][:n_vis, :]
x_vis_ori = x_vis.reshape(n_vis, 28, 28).numpy()
plot_batch(x_vis_ori)

# NOISE IMGS (placeholder: no code yet)

In [None]:
# MODEL
encoder = Encoder(d_in, d_hidden, d_z)
decoder = Decoder(d_z, d_hidden, d_in)
# Re-initialize every nn.Linear in both networks (encoder first, then decoder).
for module in (encoder, decoder):
    module.apply(init_weights)

# OPTIMIZER: a single Adam over the encoder's and decoder's parameters.
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=lr)

In [None]:
# Loss weights for this run. These reassign the alpha/beta set in the config
# cell (beta is 10 here).
alpha = 1
beta = 10
LOG_EVERY = 100   # print losses every N optimizer steps
VIS_EVERY = 500   # plot reconstructions / latent views every N optimizer steps
#
encoder = encoder.to(device)
decoder = decoder.to(device)


def show_latent_img(img):
    """Display a decoded-latent mosaic as one 8x8-inch grayscale figure."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.imshow(img, cmap="gray")
    plt.show()


def visualize_progress():
    """Plot reconstructions, decoded latents and (for d_z == 2) the latent scatter.

    Switches both models to eval mode while plotting. Train mode is restored
    afterwards, so the rest of the epoch keeps training in train mode.
    """
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        mu_vis, logvar_vis = encoder(x_vis.to(device))
        x_vis_rec = decoder(reparametrize(mu_vis, logvar_vis))
        x_vis_rec = x_vis_rec.cpu().numpy().reshape((-1, 28, 28))
    plot_batch(x_vis_ori)
    plot_batch(x_vis_rec)

    if d_z == 2:
        show_latent_img(get_latent_img(decoder, d_z=d_z, random_samples=False))
    # Pass d_z explicitly: the helper defaults to d_z=2, which breaks for other latent sizes.
    show_latent_img(get_latent_img(decoder, d_z=d_z, random_samples=True))

    if d_z == 2:
        Z, Y = [], []
        with torch.no_grad():
            for x_val, y_val in dl_valid:
                mu_val, logvar_val = encoder(x_val.to(device))
                Z.append(reparametrize(mu_val, logvar_val).cpu().numpy())
                Y.append(y_val.numpy())
        scatter_with_legend(np.concatenate(Z), np.concatenate(Y))

    encoder.train()
    decoder.train()


step = 0
for epoch_idx in range(num_epochs):
    encoder.train()
    decoder.train()
    for x, _ in dl_train:
        x = x.to(device)
        mu, logvar = encoder(x)
        z = reparametrize(mu, logvar)
        x_rec = decoder(z)

        loss, loss_rec, loss_kld = fn_loss(x, x_rec, mu, logvar, alpha, beta)
        #
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #
        if step % LOG_EVERY == 0:
            print_str = "Epoch [{:4}/{:4}] step {:5}: Loss: {:.4f}, REC: {:.4f} KLD: {:.4f}"
            print_str = print_str.format(epoch_idx, num_epochs, step, loss.item(),
                                         loss_rec.item(), loss_kld.item())
            print(print_str)
        if step % VIS_EVERY == 0:
            visualize_progress()
        step += 1