In [1]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images/255

mb_size = 64
Z_dim = 100
X_dim = train_images.shape[1]**2
h_dim = 128
c = 0
lr = 1e-3

def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)


In [2]:
# =============================== Q(z|X) ======================================

Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Whz_mu = xavier_init(size=[h_dim, Z_dim])
bhz_mu = Variable(torch.zeros(Z_dim), requires_grad=True)

Whz_var = xavier_init(size=[h_dim, Z_dim])
bhz_var = Variable(torch.zeros(Z_dim), requires_grad=True)


def Q(X):
    h = nn.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    z_mu = h @ Whz_mu + bhz_mu.repeat(h.size(0), 1)
    z_var = h @ Whz_var + bhz_var.repeat(h.size(0), 1)
    return z_mu, z_var


def sample_z(mu, log_var):
    eps = Variable(torch.randn(mu.shape[0], Z_dim))
    return mu + torch.exp(log_var / 2) * eps


In [3]:
# =============================== P(X|z) ======================================

Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def P(z):
    h = nn.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = nn.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X


In [4]:
train_images.shape

(60000, 28, 28)

In [5]:
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, target, transform=None):
        self.data = torch.from_numpy(data).float()
        self.data = torch.flatten(self.data, start_dim=1)
        self.target = torch.from_numpy(target).float()
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]
        
        if self.transform:
            x = self.transform(x)
        
        return x, y
    
    def __len__(self):
        return len(self.data)

In [6]:
dataset = MyDataset(train_images, train_labels)
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available()
)

for batch_idx, (data, target) in enumerate(loader):
    print('Batch idx {}, data shape {}, target shape {}'.format(
        batch_idx, data.shape, target.shape))


Batch idx 0, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 1, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 2, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 3, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 4, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 5, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 6, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 7, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 8, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 9, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 10, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 11, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 12, data shape torch.Size([64, 784]), target shape torch.Size([6

Batch idx 367, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 368, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 369, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 370, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 371, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 372, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 373, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 374, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 375, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 376, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 377, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 378, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 379, data shape torch.Size([64, 784]), tar

Batch idx 730, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 731, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 732, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 733, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 734, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 735, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 736, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 737, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 738, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 739, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 740, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 741, data shape torch.Size([64, 784]), target shape torch.Size([64])
Batch idx 742, data shape torch.Size([64, 784]), tar

In [None]:
# =============================== TRAINING ====================================

params = [Wxh, bxh, Whz_mu, bhz_mu, Whz_var, bhz_var,
          Wzh, bzh, Whx, bhx]

solver = optim.Adam(params, lr=lr)

for it in range(100):
    for x_batch, y_batch in loader:
        
        # Forward
        z_mu, z_var = Q(x_batch)
        z = sample_z(z_mu, z_var)
        x_sample = P(z)
            

        # Loss
        recon_loss = nn.binary_cross_entropy(x_sample, x_batch, size_average=False) / mb_size
        kl_loss = torch.mean(0.5 * torch.sum(torch.exp(z_var) + z_mu**2 - 1. - z_var, 1))
        loss = recon_loss + kl_loss

        # Backward
        loss.backward()

        # Update
        solver.step()

        # Housekeeping
        for p in params:
            if p.grad is not None:
                data = p.grad.data
                p.grad = Variable(data.new().resize_as_(data).zero_())

    # Print and plot every now and then
    if (it+1) % 10 == 0:
        print('Iter-{}; Loss: {:.4}'.format(it, loss.item()))

        samples = P(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)


Iter-9; Loss: 67.48
Iter-19; Loss: 67.73
Iter-29; Loss: 68.76
Iter-39; Loss: 64.8
Iter-49; Loss: 70.3
Iter-59; Loss: 65.09


In [14]:
loss.item()

71.52584838867188