In [2]:
"""
Train a PixelCNN on MNIST using a pre-trained VQ-VAE.
"""

import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets
import torchvision.transforms

from examples.mnist.model import Generator, make_vq_vae

BATCH_SIZE = 32  # number of images per mini-batch handed to the DataLoader
LR = 1e-3  # learning rate given to the Adam optimizer
# Prefer the GPU, but fall back to the CPU so the notebook can still be
# executed on a machine without CUDA instead of failing on the first .to().
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
def load_images(train=True):
    """Yield batches of images forever, discarding the labels.

    Args:
        train: forwarded to ``create_data_loader`` to choose which split
            is read.

    Yields:
        The first element (the image batch) of every ``(data, label)``
        pair produced by the loader. Once the loader is exhausted it is
        iterated again, so this generator never terminates on its own.
    """
    # Build the dataset/loader once rather than once per epoch: a DataLoader
    # can be iterated repeatedly (and draws a new order on each pass when
    # shuffling is enabled), so re-creating it only repeats the dataset setup.
    loader = create_data_loader(train)
    while True:
        for data, _ in loader:
            yield data


def create_data_loader(train):
    """Build a shuffling DataLoader over MNIST stored under ``./data``.

    Args:
        train: passed straight through as the ``train`` flag of
            ``torchvision.datasets.MNIST``.

    Returns:
        A ``torch.utils.data.DataLoader`` producing shuffled batches of
        ``BATCH_SIZE`` samples, with images converted by ``ToTensor``.
        The dataset is downloaded into ``./data`` if it is not present.
    """
    to_tensor = torchvision.transforms.ToTensor()
    dataset = torchvision.datasets.MNIST(
        './data',
        train=train,
        download=True,
        transform=to_tensor,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    return loader


In [6]:
# Load the pre-trained VQ-VAE; it is only used as a frozen encoder below.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
vae = make_vq_vae()
vae.load_state_dict(torch.load('test_run/i=900_vae.pt', map_location=DEVICE))
vae.to(DEVICE)
vae.eval()

generator = Generator()

# Resume from an existing generator checkpoint when one is present.
# NOTE(review): checkpoints are written to 'test_run/i=<step>_gen.pt' below,
# but resuming reads 'gen.pt' -- confirm this is intentional (e.g. the file
# is copied by hand), otherwise a restart always begins from scratch.
if os.path.exists('gen.pt'):
    generator.load_state_dict(torch.load('gen.pt', map_location=DEVICE))

generator.to(DEVICE)

optimizer = optim.Adam(generator.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()


def _prior_loss(batch):
    """Cross-entropy of the generator's predictions against the VQ-VAE codes of ``batch``."""
    # The VQ-VAE is not being trained here, so do not record an autograd
    # graph for its encoder.
    with torch.no_grad():
        _, _, encoded = vae.encoders[0](batch)
    logits = generator(encoded)
    # Move dim 1 of the generator output to the end and flatten everything
    # else, so each row of `logits` lines up with one entry of encoded.view(-1).
    logits = logits.permute(0, 2, 3, 1).contiguous()
    logits = logits.view(-1, logits.shape[-1])
    return loss_fn(logits, encoded.view(-1))


test_images = load_images(train=False)
# load_images() never stops by itself; interrupt the kernel to end training.
for batch_idx, images in enumerate(load_images()):
    train_loss = _prior_loss(images.to(DEVICE))
    # The held-out loss is only reported, never back-propagated, so skip
    # building (and keeping alive) its autograd graph.
    with torch.no_grad():
        test_loss = _prior_loss(next(test_images).to(DEVICE))
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    print('train=%f test=%f' % (train_loss.item(), test_loss.item()))
    # Checkpoint on step 0 and every 100 steps after that.
    if not batch_idx % 100:
        torch.save(generator.state_dict(), f'test_run/i={batch_idx}_gen.pt')

train=4.489676 test=4.582826
train=3.662408 test=3.619313
train=3.161065 test=2.905653
train=2.759594 test=2.802075
train=2.460760 test=2.441480
train=2.505965 test=2.402090
train=2.416065 test=2.307115
train=2.326545 test=2.267982
train=2.245868 test=2.189433
train=2.233014 test=2.291986
train=2.176854 test=2.159055
train=2.087788 test=2.039011
train=2.101034 test=2.096434
train=2.098418 test=2.055263
train=2.022511 test=1.979339
train=2.086345 test=2.017610
train=1.940372 test=2.112549
train=1.927207 test=1.869560
train=1.957178 test=1.941396
train=1.997023 test=2.012407
train=1.895596 test=1.971034
train=1.863542 test=1.835842
train=1.836429 test=1.884603
train=1.874629 test=2.025145
train=2.001813 test=1.892579
train=1.859140 test=1.855186
train=1.830120 test=1.821462
train=1.922757 test=1.832204
train=1.780086 test=1.727088
train=1.791612 test=1.910289
train=1.823741 test=1.779297
train=1.786595 test=1.806064
train=1.835795 test=1.843180
train=1.844270 test=1.725634
train=1.780206

KeyboardInterrupt: 