# Hierarchical VQ-VAE

In [None]:
import argparse
import itertools
import os

from PIL import Image
import numpy as np
import torch
import torch.optim as optim

from vq_vae_2.examples.hierarchical.data import load_images, load_tiled_images, SwipeCropper
from vq_vae_2.examples.hierarchical.model import make_vae

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

In [None]:
# Create the top-level output folders (no error if they already exist).
for _base_dir in ('./results/hierarchical', './saved_states/hierarchical'):
    os.makedirs(_base_dir, exist_ok=True)

In [None]:
# Paths per dataset. VAE_PATH and RESULTS are templates whose '{}' is filled
# with the training step when checkpoints / reconstruction images are saved.
training_sets = dict(
    celebA256=dict(
        VAE_PATH='./saved_states/hierarchical/celebA256/celebA256_vae_{}.pt',
        DATA_TRAIN='./data/celebA/data256x256/train',
        RESULTS='./results/hierarchical/celebA256/celebA_{}.png',
    ),
    celebA256_tiled=dict(
        VAE_PATH='./saved_states/hierarchical/celebA256_tiled/vae_{}.pt',
        DATA_TRAIN='./data/celebA/data512x512/train',
        RESULTS='./results/hierarchical/celebA_tiled/reconstructed_{}.png',
    ),
)
# Key of the entry in `training_sets` to train on.
NAME = 'celebA256_tiled'
# Checkpoint index to resume from, i.e. VAE_PATH.format(LOAD_EPOCH). The
# training loop names checkpoints by step, so despite the name this is a step.
LOAD_EPOCH = 0

In [None]:
# Pull the selected dataset's path templates out into module-level names.
VAE_PATH, DATA_TRAIN, RESULTS = (
    training_sets[NAME][key] for key in ('VAE_PATH', 'DATA_TRAIN', 'RESULTS')
)

In [None]:
# Ensure the folders that will hold this dataset's checkpoints and
# reconstruction images exist (index 0 is only used to render a path).
for _path_template in (VAE_PATH, RESULTS):
    os.makedirs(os.path.dirname(_path_template.format(0)), exist_ok=True)

## Train a hierarchical VQ-VAE on 256x256 inputs (for the `celebA256_tiled` set these are drawn from 512x512 images via `load_tiled_images`).

In [None]:
def save_reconstructions(vae, images, RESULTS, i):
    """Save an image comparing the VAE's reconstructions with its inputs.

    Each example in the batch becomes one row, laid out left to right as:
    top-level reconstruction, full reconstruction, original input. The rows
    for successive batch elements are stacked vertically into one image.

    Args:
        vae: model exposing ``full_reconstructions(images)``, which must
            yield exactly two tensors (top-level and full reconstruction),
            presumably NCHW like ``images`` — confirm in the model code.
        images: batch of input images, 4-D and presumably NCHW with values
            in [0, 1]; they are clipped to that range before saving.
        RESULTS: output path template containing a ``{}`` placeholder.
        i: value substituted into ``RESULTS`` (e.g. the training step).
    """
    # Reconstruct in eval mode, then restore whatever mode the caller had —
    # even if reconstruction raises — instead of forcing train mode.
    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            recons = [torch.clamp(x, 0, 1).permute(0, 2, 3, 1).detach().cpu().numpy()
                      for x in vae.full_reconstructions(images)]
    finally:
        vae.train(was_training)
    top_recons, real_recons = recons
    images = images.permute(0, 2, 3, 1).detach().cpu().numpy()

    # axis=-2 of an NHWC array is width: put the three views side by side.
    columns = np.concatenate([top_recons, real_recons, images], axis=-2)
    # Concatenating over the leading (batch) axis stacks the rows vertically.
    columns = np.concatenate(columns, axis=0)
    # The inputs are not clamped like the reconstructions; clip so values
    # outside [0, 1] cannot wrap around when cast to uint8.
    pixels = (np.clip(columns, 0, 1) * 255).astype('uint8')
    Image.fromarray(pixels).save(
        RESULTS.format(i)
    )

In [None]:
model = make_vae()
# Resume from checkpoint LOAD_EPOCH if it exists. Checkpoints are written
# below as VAE_PATH.format(step) after that step's update, so a resumed run
# continues from the next step rather than restarting at 0 (which would
# overwrite and mislabel earlier checkpoints and reconstruction images).
start_step = 0
checkpoint_path = VAE_PATH.format(LOAD_EPOCH)
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    start_step = LOAD_EPOCH + 1
model.to(DEVICE)
# NOTE(review): optimizer state is not checkpointed, so Adam starts with
# fresh moment estimates after a resume.
optimizer = optim.Adam(model.parameters())
# Alternative loader (presumably for the non-tiled 'celebA256' set):
# data = load_images(DATA_TRAIN)
data = load_tiled_images(DATA_TRAIN, width=256, height=256)
save_every = 30  # steps between checkpoints / reconstruction snapshots
# Runs until interrupted (or until `data` is exhausted and next() raises).
for i in itertools.count(start_step):
    batch = next(data)
    images = batch.to(DEVICE)
    terms = model(images)
    # losses[-1] is reported as the full-resolution MSE, losses[0] as the
    # top-level MSE.
    print('step %d: mse=%f mse_top=%f' %
          (i, terms['losses'][-1].item(), terms['losses'][0].item()))
    optimizer.zero_grad()
    terms['loss'].backward()
    optimizer.step()
    # presumably re-initialises unused codebook entries — confirm in model code
    model.revive_dead_entries()
    if not i % save_every:
        torch.save(model.state_dict(), VAE_PATH.format(i))
        save_reconstructions(model, images, RESULTS, i)