# Hierarchical VQ-VAE

In [None]:
import os
import random
import itertools

from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Not shown in this notebook:
from vq_vae_2.examples.hierarchical.model import TopPrior, BottomPrior, make_vae
from vq_vae_2.examples.hierarchical.data import load_images, load_tiled_images  # SwipeCropper

In [None]:
# Use the default CUDA device when one is visible to torch, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

In [None]:
# Key into `training_sets`, and the checkpoint step every cell loads / resumes from.
NAME, LOAD_EPOCH = 'celebA256', 4500

In [None]:
def _training_set(state_dir, results_dir, data_train, data_test):
    """Build the path table for one training set.

    Checkpoint templates live under ./saved_states/hierarchical/<state_dir>/
    and result images under ./results/hierarchical/<results_dir>/. Every
    template takes one step number via ``str.format`` (zero-padded to 8
    digits). No optimizer-state paths are defined for the two priors.
    """
    state_root = './saved_states/hierarchical/%s/' % state_dir
    results_root = './results/hierarchical/%s/' % results_dir
    return {
        'VAE_PATH': state_root + 'vae_{:08d}.pt',
        'VAE_OPT_PATH': state_root + 'vae_opt_{:08d}.pt',
        'BOTTOM_PRIOR_PATH': state_root + 'bottom_prior_{:08d}.pt',
        'TOP_PRIOR_PATH': state_root + 'top_prior_{:08d}.pt',
        'DATA_TRAIN': data_train,
        'DATA_TEST': data_test,
        'RESULTS': results_root + 'reconstructed_{:08d}.png',
        'RESULTS_TEST': results_root + 'reconstructed_test_{:08d}.png',
    }


training_sets = {
    'celebA256': _training_set(
        'celebA256', 'celebA256',
        './data/celebA/data256x256/train', './data/celebA/data256x256/test',
    ),
    # NOTE(review): results go to 'celebA_tiled' while checkpoints go to
    # 'celebA256_tiled' -- kept as-is; confirm this is intentional.
    'celebA256_tiled': _training_set(
        'celebA256_tiled', 'celebA_tiled',
        './data/celebA/data1024x1024/train', './data/celebA/data1024x1024/test',
    ),
    'DIV2K_tiled': _training_set(
        'DIV2K_tiled', 'DIV2K_tiled',
        './data/DIV2K/DIV2K_train_HR', './data/DIV2K/DIV2K_test_HR',
    ),
}

# Expose the selected training set's templates/paths as module-level names.
# The *_OPT_PATH entries are not extracted: optimizer-state checkpointing is
# not used here.
(VAE_PATH, BOTTOM_PRIOR_PATH, TOP_PRIOR_PATH,
 DATA_TRAIN, RESULTS, RESULTS_TEST) = (
    training_sets[NAME][key]
    for key in ('VAE_PATH', 'BOTTOM_PRIOR_PATH', 'TOP_PRIOR_PATH',
                'DATA_TRAIN', 'RESULTS', 'RESULTS_TEST')
)

In [None]:
# Create every directory that the checkpoint and result templates write into.
# TOP_PRIOR_PATH is listed explicitly rather than relying on it happening to
# share a directory with VAE_PATH. Optimizer-state paths are not used, so no
# directories are created for them.
for _template in (VAE_PATH, BOTTOM_PRIOR_PATH, TOP_PRIOR_PATH, RESULTS, RESULTS_TEST):
    os.makedirs(os.path.dirname(_template.format(0)), exist_ok=True)
del _template

## Train a hierarchical VQ-VAE on 256x256 images.

In [None]:
def save_reconstructions(vae, images, RESULTS, i):
    """Save a comparison grid of VAE reconstructions next to the inputs.

    Each row of the saved image shows one example, left to right: the two
    outputs of ``vae.full_reconstructions`` (named top_recons / real_recons
    here, i.e. top-level-only and full reconstruction judging by the names)
    followed by the original input. Rows are stacked vertically over the batch.

    Args:
        vae: model exposing ``full_reconstructions(images)``, which must yield
            exactly two tensors shaped like ``images``.
        images: image batch, assumed (batch, channels, height, width) -- it is
            permuted to channels-last for PIL. Values are expected in [0, 1]
            (TODO confirm against the data loader).
        RESULTS: output path template with one format slot for the step.
        i: step number substituted into ``RESULTS``.

    The model's train/eval mode is restored afterwards, even on error.
    """
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            recons = [torch.clamp(x, 0, 1).permute(0, 2, 3, 1).cpu().numpy()
                      for x in vae.full_reconstructions(images)]
    finally:
        # Restore whatever mode the caller had, not unconditionally train().
        vae.train(was_training)
    top_recons, real_recons = recons
    images = images.permute(0, 2, 3, 1).detach().cpu().numpy()

    # Join along width (axis -2 of NHWC) so each example is one row, then
    # stack the batch vertically.
    columns = np.concatenate([top_recons, real_recons, images], axis=-2)
    columns = np.concatenate(columns, axis=0)
    # Clip before the uint8 cast so out-of-range inputs saturate, not wrap.
    pixels = (np.clip(columns, 0, 1) * 255).astype('uint8')
    Image.fromarray(pixels).save(RESULTS.format(i))

In [None]:
# Build the VAE and resume from the LOAD_EPOCH checkpoint if it exists.
model = make_vae()
vae_checkpoint = VAE_PATH.format(LOAD_EPOCH)
if os.path.exists(vae_checkpoint):
    model.load_state_dict(torch.load(vae_checkpoint, map_location='cpu'))
    # Checkpoints are saved after that step's update, so continue with the
    # next step instead of redoing (and overwriting) step LOAD_EPOCH.
    start_step = LOAD_EPOCH + 1
else:
    # Fresh weights: count from 0 so untrained weights are never saved under
    # a step number that suggests they were trained.
    start_step = 0
model.to(DEVICE)
optimizer = optim.Adam(model.parameters())
# Optimizer state is not checkpointed. To restore it, load into the
# optimizer (not the model):
# if os.path.exists(VAE_OPT_PATH.format(LOAD_EPOCH)):
#     optimizer.load_state_dict(torch.load(VAE_OPT_PATH.format(LOAD_EPOCH), map_location='cpu'))
# data = load_images(DATA_TRAIN)
data = load_tiled_images(DATA_TRAIN, batch_size=12, width=256, height=256)
for i in itertools.count(start_step):
    images = next(data).to(DEVICE)
    terms = model(images)
    optimizer.zero_grad()
    terms['loss'].backward()
    optimizer.step()
    model.revive_dead_entries()
    if not i % 100:
        print('step %d: mse=%f mse_top=%f' %
              (i, terms['losses'][-1].item(), terms['losses'][0].item()))
    if not i % 500:
        torch.save(model.state_dict(), VAE_PATH.format(i))
        # torch.save(optimizer.state_dict(), VAE_OPT_PATH.format(i))
        save_reconstructions(model, images, RESULTS, i)

## Train the bottom-level prior.

In [None]:
# Frozen VAE, used only to produce the discrete codes the prior learns.
vae = make_vae()
vae.load_state_dict(torch.load(VAE_PATH.format(LOAD_EPOCH), map_location='cpu'))
vae.to(DEVICE)
vae.eval()

bottom_prior = BottomPrior()
bottom_checkpoint = BOTTOM_PRIOR_PATH.format(LOAD_EPOCH)
if os.path.exists(bottom_checkpoint):
    bottom_prior.load_state_dict(torch.load(bottom_checkpoint, map_location='cpu'))
    # Checkpoints are saved after that step's update; continue with the next.
    start_step = LOAD_EPOCH + 1
else:
    # Fresh prior: count from 0 so untrained weights never get a trained-looking name.
    start_step = 0
bottom_prior.to(DEVICE)

optimizer = optim.Adam(bottom_prior.parameters(), lr=1e-4)
# Optimizer state is not checkpointed:
# if os.path.exists(BOTTOM_PRIOR_OPT_PATH.format(LOAD_EPOCH)):
#     optimizer.load_state_dict(torch.load(BOTTOM_PRIOR_OPT_PATH.format(LOAD_EPOCH), map_location=DEVICE))

loss_fn = nn.CrossEntropyLoss()

data = load_tiled_images(DATA_TRAIN, batch_size=4, width=256, height=256)
for i in itertools.count(start_step):
    images = next(data).to(DEVICE)
    # The codes are integer class indices (used as CrossEntropyLoss targets
    # below), so no gradient can reach the VAE; skip building an autograd
    # graph for these encoder passes.
    with torch.no_grad():
        bottom_enc = vae.encoders[0].encode(images)
        _, _, bottom_idxs = vae.encoders[0].vq(bottom_enc)
        _, _, top_idxs = vae.encoders[1](bottom_enc)
    logits = bottom_prior(bottom_idxs, top_idxs)
    # Move the class dimension last and flatten to (positions, classes) to
    # line up with the flattened index targets.
    logits = logits.permute(0, 2, 3, 1).contiguous()
    logits = logits.view(-1, logits.shape[-1])
    loss = loss_fn(logits, bottom_idxs.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if not i % 100:
        print('step %d: loss=%f' % (i, loss.item()))
    if not i % 500:
        torch.save(bottom_prior.state_dict(), BOTTOM_PRIOR_PATH.format(i))
        # torch.save(optimizer.state_dict(), BOTTOM_PRIOR_OPT_PATH.format(i))

## Train the top-level prior.

In [None]:
# Frozen VAE, used only to produce the discrete top-level codes the prior learns.
vae = make_vae()
vae.load_state_dict(torch.load(VAE_PATH.format(LOAD_EPOCH), map_location='cpu'))
vae.to(DEVICE)
vae.eval()

top_prior = TopPrior()
top_checkpoint = TOP_PRIOR_PATH.format(LOAD_EPOCH)
if os.path.exists(top_checkpoint):
    top_prior.load_state_dict(torch.load(top_checkpoint, map_location='cpu'))
    # Checkpoints are saved after that step's update; continue with the next.
    start_step = LOAD_EPOCH + 1
else:
    # Fresh prior: count from 0 so untrained weights never get a trained-looking name.
    start_step = 0
top_prior.to(DEVICE)

optimizer = optim.Adam(top_prior.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

data = load_tiled_images(DATA_TRAIN, batch_size=4, width=256, height=256)
for i in itertools.count(start_step):
    images = next(data).to(DEVICE)
    # The codes are integer class targets, so no gradient can reach the VAE;
    # skip building an autograd graph for the encoder passes.
    with torch.no_grad():
        _, _, encoded = vae.encoders[1](vae.encoders[0].encode(images))
    logits = top_prior(encoded)
    # Move the class dimension last and flatten to (positions, classes).
    logits = logits.permute(0, 2, 3, 1).contiguous()
    logits = logits.view(-1, logits.shape[-1])
    loss = loss_fn(logits, encoded.view(-1))
    if not i % 100:
        print('step %d: loss=%f' % (i, loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if not i % 500:
        torch.save(top_prior.state_dict(), TOP_PRIOR_PATH.format(i))

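## Generate samples using the top-level prior (the bottom-level codes are predicted by the VAE's own decoder and re-quantized, not sampled from the bottom prior).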

In [None]:
NUM_SAMPLES = 4  # number of images generated by the sampling cell


def sample_softmax(probs):
    """Draw one index from a discrete distribution by inverse-CDF sampling.

    A uniform draw in [0, 1) is reduced by each probability in turn; the
    first index at which it reaches zero or below is returned. If rounding
    leaves it positive after the final entry, the last index is returned
    (which is -1 for an empty ``probs``).
    """
    draws = itertools.chain([random.random()], probs)
    # Running "draw minus cumulative probability"; the first item is the raw draw.
    remaining = itertools.accumulate(draws, lambda left, p: left - p)
    for index, left in enumerate(itertools.islice(remaining, 1, None)):
        if left <= 0:
            return index
    return len(probs) - 1

# Load the trained VAE and top-level prior for inference. map_location keeps
# GPU-saved checkpoints loadable on CPU-only machines (as in the other cells).
vae = make_vae()
vae.load_state_dict(torch.load(VAE_PATH.format(LOAD_EPOCH), map_location='cpu'))
vae.to(DEVICE)
vae.eval()

top_prior = TopPrior()
top_prior.load_state_dict(torch.load(TOP_PRIOR_PATH.format(LOAD_EPOCH), map_location='cpu'))
top_prior.to(DEVICE)
top_prior.eval()  # inference: disable any train-only behaviour in the prior

# Autoregressively sample a 32x32 grid of top-level codes in raster order.
# np.int64 replaces np.long, a deprecated platform-dependent alias that newer
# NumPy releases no longer provide.
results = np.zeros([NUM_SAMPLES, 32, 32], dtype=np.int64)
for row in range(results.shape[1]):
    for col in range(results.shape[2]):
        # Feed the rows generated so far (0..row) and read position (row, col).
        partial_in = torch.from_numpy(results[:, :row + 1]).to(DEVICE)
        with torch.no_grad():
            outputs = torch.softmax(top_prior(partial_in), dim=1).cpu().numpy()
        for i, out in enumerate(outputs):
            results[i, row, col] = sample_softmax(out[:, row, col])
    print('done row', row)

# Decode: embed the sampled top codes, predict the bottom level with the
# VAE's top decoder, re-quantize it, then decode both levels to pixels.
with torch.no_grad():
    full_latents = torch.from_numpy(results).to(DEVICE)
    top_embedded = vae.encoders[1].vq.embed(full_latents)
    bottom_encoded = vae.decoders[0]([top_embedded])
    bottom_embedded, _, _ = vae.encoders[0].vq(bottom_encoded)
    decoded = torch.clamp(vae.decoders[1]([top_embedded, bottom_embedded]), 0, 1)
decoded = decoded.permute(0, 2, 3, 1).cpu().numpy()
# Place the samples side by side horizontally.
decoded = np.concatenate(decoded, axis=1)
Image.fromarray((decoded * 255).astype(np.uint8)).save(RESULTS_TEST.format(0))