# VQ-VAE2

In [None]:
import os
import random

from PIL import Image
import numpy as np
import torch

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
import torchvision.transforms

from vq_vae_2.examples.mnist.model import Generator, make_vq_vae
from vq_vae_2.examples.mnist.train_generator import load_images

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

In [None]:
# Ensure the folders for result images and saved model states exist.
for output_dir in ('./results/mnist', './saved_states/mnist'):
    os.makedirs(output_dir, exist_ok=True)

## Train an encoder/decoder on the MNIST dataset.

In [None]:
def save_reconstructions(batch, decoded, epoch=0):
    """
    Save a PNG that places input images next to their reconstructions.

    The images of each tensor are stacked vertically; the inputs form the
    left column and the reconstructions the right column. Only index 0 of
    the (original) second axis is written, so the result is a single-channel
    image.

    Args:
        batch: a 4-D tensor of input images (assumed to be laid out as
          (N, C, H, W) with values in [0, 1] -- TODO confirm against the
          data loader). Values are scaled by 255 without clipping.
        decoded: a 4-D tensor of reconstructions with the same layout.
          Values are clipped to [0, 1] before scaling.
        epoch: integer embedded, zero-padded to five digits, in the output
          file name. Previously the function read an undefined `epoch`
          name while callers passed it as a keyword, which raised
          TypeError.
    """
    # Move the second axis last so each image is indexed [row, col, channel].
    batch = batch.detach().permute(0, 2, 3, 1).contiguous()
    decoded = decoded.detach().permute(0, 2, 3, 1).contiguous()

    # Concatenating over the first axis stacks the images on top of each other.
    input_images = (np.concatenate(batch.cpu().numpy(), axis=0) * 255).astype(np.uint8)
    output_images = np.concatenate(decoded.cpu().numpy(), axis=0)
    output_images = (np.clip(output_images, 0, 1) * 255).astype(np.uint8)

    # Inputs on the left, reconstructions on the right.
    joined = np.concatenate([input_images[..., 0], output_images[..., 0]], axis=1)
    path = './results/mnist/reconstructions_{epoch:05d}.png'.format(epoch=epoch)
    Image.fromarray(joined).save(path)

In [None]:
# Build the VQ-VAE and resume from the saved state when one exists.
# The existence check and the load share one path variable; the check used to
# look in a misspelt directory, so training always restarted from scratch.
vae_path = './saved_states/mnist/vae.pt'
vae = make_vq_vae()
if os.path.exists(vae_path):
    vae.load_state_dict(torch.load(vae_path, map_location=DEVICE))
vae.to(DEVICE)
optimizer = optim.Adam(vae.parameters())

# One optimizer step per batch. The loop runs for as long as load_images()
# keeps yielding batches; interrupt the cell to stop training.
for i, batch in enumerate(load_images()):
    batch = batch.to(DEVICE)
    terms = vae(batch)
    # NOTE(review): this formats the full `reconstructions` and `embedded`
    # entries on every step; consider logging scalar values only.
    print(
        'step {step}: loss={loss} losses={losses} reconstructions={reconstructions} embedded={embedded}'.format(
            step=i,
            loss=terms['loss'],
            losses=terms['losses'],
            reconstructions=terms['reconstructions'],
            embedded=terms['embedded'],
        )
    )
    optimizer.zero_grad()
    terms['loss'].backward()
    optimizer.step()
    vae.revive_dead_entries()
    # Save the weights every 10 steps and a reconstruction image every 100.
    if i % 10 == 0:
        torch.save(vae.state_dict(), vae_path)
    if i % 100 == 0:
        save_reconstructions(batch, terms['reconstructions'][-1], epoch=i)

## Train a PixelCNN on MNIST using a pre-trained VQ-VAE.

In [None]:
# Number of images per batch; passed as batch_size to the MNIST DataLoader.
BATCH_SIZE = 32
# Learning rate handed to the Adam optimizer that trains the generator.
LR = 1e-3

In [None]:
def load_images(train=True):
    """
    Yield batches of MNIST images forever.

    Args:
        train: forwarded to create_data_loader() to choose between the
          training and the test split.

    Yields:
        The image part of each (images, labels) batch; labels are dropped.
    """
    # Build the dataset and loader once instead of on every pass: the loader
    # is created with shuffle=True, so each new iteration over it already
    # draws a fresh ordering.
    loader = create_data_loader(train)
    while True:
        for data, _ in loader:
            yield data


def create_data_loader(train):
    """
    Build a shuffling DataLoader over torchvision's MNIST dataset.

    The dataset lives under './data/mnist' and is downloaded on demand;
    every sample is converted with torchvision's ToTensor transform.

    Args:
        train: passed through as the `train` flag of the MNIST dataset.

    Returns:
        A DataLoader producing shuffled batches of BATCH_SIZE samples.
    """
    to_tensor = torchvision.transforms.ToTensor()
    dataset = torchvision.datasets.MNIST(
        './data/mnist',
        train=train,
        download=True,
        transform=to_tensor,
    )
    loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
    return loader

In [None]:
# Load the trained VQ-VAE; it is only used to encode images here, so it is
# put in eval mode and never handed to the optimizer.
vae = make_vq_vae()
vae.load_state_dict(torch.load('./saved_states/mnist/vae.pt', map_location=DEVICE))
vae.to(DEVICE)
vae.eval()

# Build the generator and resume from its saved state when one exists.
# The existence check and the load share one path variable; the check used to
# look in a misspelt directory, so training always restarted from scratch.
gen_path = './saved_states/mnist/gen.pt'
generator = Generator()
if os.path.exists(gen_path):
    generator.load_state_dict(torch.load(gen_path, map_location=DEVICE))
generator.to(DEVICE)

optimizer = optim.Adam(generator.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

test_images = load_images(train=False)
for batch_idx, images in enumerate(load_images()):
    images = images.to(DEVICE)
    # losses[0] is computed on the training batch, losses[1] on a test batch.
    losses = []
    for img_set in [images, next(test_images).to(DEVICE)]:
        # The encoder's third output serves as both the generator input and
        # the classification target. The VQ-VAE is not being trained, so no
        # autograd graph is needed for this call.
        with torch.no_grad():
            _, _, encoded = vae.encoders[0](img_set)
        logits = generator(encoded)
        # Move the class axis last and flatten so that every position becomes
        # one row of class scores for the cross-entropy loss.
        logits = logits.permute(0, 2, 3, 1).contiguous()
        logits = logits.view(-1, logits.shape[-1])
        losses.append(loss_fn(logits, encoded.view(-1)))
    # Only the training loss is back-propagated; the test loss is just logged.
    optimizer.zero_grad()
    losses[0].backward()
    optimizer.step()
    print('train=%f test=%f' % (losses[0].item(), losses[1].item()))
    if batch_idx % 100 == 0:
        torch.save(generator.state_dict(), gen_path)

## Sample an image from a PixelCNN.

In [None]:
def sample_softmax(probs):
    """
    Pick an index from a sequence of probabilities.

    A uniform random number in [0, 1) is drawn and the entries of `probs`
    are subtracted from it one by one; the index of the first entry that
    brings the remainder to zero or below is returned.

    Args:
        probs: a sized iterable of probabilities.

    Returns:
        The selected index, or len(probs) - 1 when the remainder never
        reaches zero (e.g. because the entries sum to slightly less than 1).
    """
    remainder = random.random()
    for index, weight in enumerate(probs):
        remainder -= weight
        if remainder <= 0:
            return index
    # The remainder stayed positive: fall back to the last index.
    return len(probs) - 1

In [None]:
# Load the trained VQ-VAE (used for its codebook and decoder) and the
# trained generator.
vae = make_vq_vae()
vae.load_state_dict(torch.load('./saved_states/mnist/vae.pt', map_location=DEVICE))
vae.to(DEVICE)
vae.eval()
generator = Generator()
generator.load_state_dict(torch.load('./saved_states/mnist/gen.pt', map_location=DEVICE))
generator.to(DEVICE)

# Four 7x7 grids of code indices, filled in one position at a time.
# np.int64 is used explicitly: the `np.long` alias is missing from
# NumPy 1.24-1.26 and its width is platform-dependent in other releases.
inputs = np.zeros([4, 7, 7], dtype=np.int64)
for row in range(7):
    for col in range(7):
        with torch.no_grad():
            # Re-run the generator on the grid filled in so far and turn its
            # outputs into per-position probabilities over axis 1.
            outputs = torch.softmax(generator(torch.from_numpy(inputs).to(DEVICE)), dim=1)
            for i, out in enumerate(outputs.cpu().numpy()):
                probs = out[:, row, col]
                inputs[i, row, col] = sample_softmax(probs)
    print('done row', row)

# Look up the sampled indices in the codebook and decode them into images.
embedded = vae.encoders[0].vq.embed(torch.from_numpy(inputs).to(DEVICE))
decoded = torch.clamp(vae.decoders[0]([embedded]), 0, 1).detach().cpu().numpy()
# Join the four samples along axis 1 into a single strip, then save index 0
# of the first axis as an 8-bit image.
decoded = np.concatenate(decoded, axis=1)
Image.fromarray((decoded * 255).astype(np.uint8)[0]).save('./results/mnist/samples.png')