In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline

DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50

# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   
torch.cuda.manual_seed(SEED)


class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512
        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]


# Train
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for e in range(NUM_EPOCHS):
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = to_onehot(labels)
        # Reconstruction images
        # Encode images
        x = images.view(-1, 28*28*1).to(DEVICE)
        mean, lnvar = model.encode(x, labels)
        std = lnvar.exp().sqrt()
        epsilon = torch.randn(ZDIM, device=DEVICE)

        # Decode latent variables
        z = mean + std * epsilon
        y = model.decode(z, labels)

        # Compute loss
        kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
        bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
        loss = (-1 * kld + bce).mean()

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.shape[0]

    print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


epoch: 1 epoch_loss: 200.63719370117187
epoch: 2 epoch_loss: 161.12568348795574
epoch: 3 epoch_loss: 148.00521254882813
epoch: 4 epoch_loss: 139.95780497233073
epoch: 5 epoch_loss: 133.9727405110677
epoch: 6 epoch_loss: 128.5624144124349
epoch: 7 epoch_loss: 124.46043256835938
epoch: 8 epoch_loss: 121.92354873453776
epoch: 9 epoch_loss: 119.49953055419923
epoch: 10 epoch_loss: 117.55944395345053
epoch: 11 epoch_loss: 116.65602159830729
epoch: 12 epoch_loss: 113.89352791341146
epoch: 13 epoch_loss: 113.69349411214192
epoch: 14 epoch_loss: 112.2301326538086
epoch: 15 epoch_loss: 111.06621162923177
epoch: 16 epoch_loss: 110.2488162516276
epoch: 17 epoch_loss: 109.30495798339844
epoch: 18 epoch_loss: 108.146364066569
epoch: 19 epoch_loss: 108.21288971761068
epoch: 20 epoch_loss: 107.08519996744792
epoch: 21 epoch_loss: 106.70227438151042
epoch: 22 epoch_loss: 106.23790797932942
epoch: 23 epoch_loss: 104.98591331380209
epoch: 24 epoch_loss: 105.06430674235025
epoch: 25 epoch_loss: 104.29053

In [2]:
# Generation data with label '5'
NUM_GENERATION = 100

os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
    z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
    label = torch.tensor([5], device=DEVICE)
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    # Save image
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
    plt.close(fig) 

In [3]:
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
target_image, label = list(test_dataset)[48]

x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
    mean, _ = model.encode(x, to_onehot(label))
z = mean

print(f'z = {z.cpu().detach().numpy().squeeze()}')

z = [-8.36085454e-02  7.50088692e-02  2.96378732e-02 -1.06805992e+00
 -2.80670404e+00  3.21322143e-01 -1.99517421e-02 -2.60175997e-03
  1.12871274e-01  1.03785515e+00 -6.50890768e-02  1.19150484e+00
 -5.43689504e-02 -1.30372897e-01 -5.77559888e-01 -3.43698680e-01]


In [4]:
os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/fat/img{label}')
    plt.close(fig) 