In [None]:
"""
Train an encoder/decoder on the MNIST dataset.
"""

import os

from PIL import Image
import numpy as np
import torch
import torch.optim as optim

from examples.mnist.model import make_vq_vae
from examples.mnist.train_generator import load_images

from tqdm.notebook import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first .to().
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Train vqvae2

In [None]:
def save_reconstructions(batch, decoded, i, out_dir=None):
    """
    Save a side-by-side PNG of input images and their reconstructions.

    Each tensor is permuted from (N, C, H, W) indexing to channel-last, the
    images of the batch are stacked vertically, and only channel 0 is kept.
    Inputs form the left column of the saved image, reconstructions the right.

    Args:
        batch: tensor of input images, indexed as (N, C, H, W). Values are
            presumably in [0, 1] -- TODO confirm against load_images. Values
            outside that range are clipped rather than wrapped.
        decoded: tensor of reconstructions with the same layout; it must
            match batch in image count and height so the columns line up.
        i: training step, embedded in the output file name.
        out_dir: directory to write the PNG into. When omitted, the
            notebook-level `folder` variable is used, so that variable must
            be defined before the first call.
    """
    if out_dir is None:
        # Backward-compatible default: the training cell defines `folder`.
        out_dir = folder

    def to_uint8_column(images):
        # NCHW -> NHWC, then stack the N images on top of each other.
        channel_last = images.detach().permute(0, 2, 3, 1).contiguous()
        stacked = np.concatenate(channel_last.cpu().numpy(), axis=0)
        # Clip before the cast: converting out-of-range floats to uint8
        # wraps around (or is undefined) and corrupts the picture.
        return (np.clip(stacked, 0, 1) * 255).astype(np.uint8)

    input_images = to_uint8_column(batch)
    output_images = to_uint8_column(decoded)
    # Keep only the first channel (MNIST is greyscale) and place the two
    # columns next to each other.
    joined = np.concatenate([input_images[..., 0], output_images[..., 0]], axis=1)
    Image.fromarray(joined).save(f'{out_dir}/i={i}reconstructions.png')


In [None]:
# Build the model and, when a checkpoint is present in the working directory,
# resume from it.
# NOTE(review): the loop below writes checkpoints to f'{folder}/i={i}_vae.pt',
# so resuming requires copying one of them to 'vae.pt' by hand -- confirm
# this is intended.
vae = make_vq_vae()

if os.path.exists('vae.pt'):
    # Map onto the configured device rather than a hard-coded 'cuda'.
    vae.load_state_dict(torch.load('vae.pt', map_location=DEVICE))

vae.to(DEVICE)
optimizer = optim.Adam(vae.parameters())

# Output directory for checkpoints and reconstruction images. The name
# `folder` is also read by save_reconstructions, so it must stay a global.
folder = 'test_run'
# torch.save and Image.save do not create missing directories.
os.makedirs(folder, exist_ok=True)

# Number of steps between checkpoints / reconstruction dumps.
SAVE_INTERVAL = 100

for i, batch in tqdm(enumerate(load_images())):
    batch = batch.to(DEVICE)
    terms = vae(batch)
    loss = terms['loss']
    print('step %d: loss=%f ' % (i, loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Presumably re-initialises unused codebook entries -- see the model
    # implementation to confirm.
    vae.revive_dead_entries()
    if i % SAVE_INTERVAL == 0:
        torch.save(vae.state_dict(), f'{folder}/i={i}_vae.pt')
        # The last element of 'reconstructions' is the one that gets saved.
        save_reconstructions(batch, terms['reconstructions'][-1], i)