# Load Trained PixelCNN and Sample/Evaluate

In [1]:
import time
import os
import argparse
import torch
import torch.nn as nn
from torch.nn import DataParallel
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
from tensorboardX import SummaryWriter
from utils import * 
from model import * 
from PIL import Image
import os

Get list of trained models

In [2]:
import glob

files = glob.glob('*/*/*.pth')

import re

ptrn = re.compile(r'.*_(\d+)\.pth')

g = ptrn.match(files[0])

batch = [int(ptrn.match(f).groups()[0]) for f in files]

In [3]:
parser = argparse.ArgumentParser()
# data I/O
parser.add_argument('-i', '--data_dir', type=str,
                    default='data', help='Location for the dataset')
parser.add_argument('-o', '--save_dir', type=str, default='saved',
                    help='Location for parameter checkpoints and samples')
parser.add_argument('-d', '--dataset', type=str,
                    default='cifar', help='Can be either cifar|mnist')
parser.add_argument('-p', '--print_every', type=int, default=50,
                    help='how many iterations between print statements')
parser.add_argument('-t', '--save_interval', type=int, default=10,
                    help='Every how many epochs to write checkpoint/samples?')
parser.add_argument('-r', '--load_params', type=str, default=None,
                    help='Point to checkpoint file to restore training from')
parser.add_argument('-z', '--resume', type=int, default=None,
                    help='Epoch number to resume from')
# model
parser.add_argument('-q', '--nr_resnet', type=int, default=5,
                    help='Number of residual blocks per stage of the model')
parser.add_argument('-n', '--nr_filters', type=int, default=160,
                    help='Number of filters to use across the model. Higher = larger model.')
parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10,
                    help='Number of logistic components in the mixture. Higher = more flexible model')
parser.add_argument('-l', '--lr', type=float,
                    default=0.0002, help='Base learning rate')
parser.add_argument('-e', '--lr_decay', type=float, default=0.999995,
                    help='Learning rate decay, applied every step of the optimization')
parser.add_argument('-b', '--batch_size', type=int, default=64,
                    help='Batch size during training per GPU')
parser.add_argument('-x', '--max_epochs', type=int,
                    default=5000, help='How many epochs to run in total?')
parser.add_argument('-s', '--seed', type=int, default=1,
                    help='Random seed to use')
args = parser.parse_args('')

In [4]:
args.resume=159
args.batch_size = 12

In [5]:
model_dir = os.path.join(args.save_dir, 'models')
image_dir = os.path.join(args.save_dir, 'images')
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(image_dir):
    os.makedirs(image_dir)

In [6]:
sample_batch_size = 25
obs = (1, 28, 28) if 'mnist' in args.dataset else (1, 32, 32) # use modified grayscale CIFAR 10
input_channels = obs[0]
rescaling     = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x  + .5
kwargs = {'num_workers':2, 'pin_memory':True, 'drop_last':True}
transform_list = [transforms.ToTensor(), rescaling]

In [7]:
transform_list = [transforms.Grayscale(), transforms.ToTensor(), rescaling]
ds_transforms = transforms.Compose(transform_list)
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=True, 
    download=True, transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(args.data_dir, train=False, 
                transform=ds_transforms), batch_size=args.batch_size, shuffle=True, **kwargs)

loss_op   = lambda real, fake : discretized_mix_logistic_loss_1d(real, fake)
sample_op = lambda x : sample_from_discretized_mix_logistic_1d(x, args.nr_logistic_mix)

Files already downloaded and verified


In [None]:
def load_and

In [None]:
model_name = 'pcnn_lr:{:.5f}_nr-resnet{}_nr-filters{}'.format(args.lr, args.nr_resnet, args.nr_filters)

model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters, 
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = DataParallel(model, [5])
#model = model.cuda(5)
epoch_start = 0

## Load pre-trained network

In [27]:
epoch = 129 # args.resume
load_part_of_model(model, '{}/{}_{}.pth'.format(model_dir, model_name, epoch))
epoch_start = args.resume + 1
print('model from epoch {} loaded'.format(epoch))

added 100% of params:
model from epoch 129 loaded


In [28]:
model.eval()
test_loss = 0.
for batch_idx, (input,_) in enumerate(test_loader):
    if batch_idx == 20:
        break
    input = input.cuda(5, async=True)
    input_var = Variable(input)
    output = model(input_var)
    loss = loss_op(input_var, output)
    test_loss += loss.data[0]
    del loss, output

In [29]:
deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.)
print('test loss : %s' % (test_loss / deno))

test loss : 4.623482959478233


In [32]:
def load_and_eval(epoch=139):
    load_part_of_model(model, '{}/{}_{}.pth'.format(model_dir, model_name, epoch))
    epoch_start = args.resume + 1
    print('model from epoch {} loaded'.format(epoch))
    model.eval()
    test_loss = 0.
    for batch_idx, (input,_) in enumerate(test_loader):
        input = input.cuda(5, async=True)
        input_var = Variable(input)
        output = model(input_var)
        loss = loss_op(input_var, output)
        test_loss += loss.data[0]
        del loss, output
    deno = batch_idx * args.batch_size * np.prod(obs) * np.log(2.)
    print('test loss : %s' % (test_loss / deno))
    return test_loss / deno

In [37]:
batch = sorted(batch)

In [39]:
scores = []
for b in batch:
    if b < 100:
        continue
    scores.append(load_and_eval(b))

added 100% of params:
model from epoch 109 loaded
test loss : 4.545966741246176
added 100% of params:
model from epoch 119 loaded
test loss : 4.543572381968458
added 100% of params:
model from epoch 129 loaded
test loss : 4.589228407378185
added 100% of params:
model from epoch 139 loaded
test loss : 4.549490092211239
added 100% of params:
model from epoch 149 loaded
test loss : 4.551606368496334
added 100% of params:
model from epoch 159 loaded
test loss : 4.551182236679022
added 100% of params:
model from epoch 169 loaded
test loss : 4.5530677320908115
added 100% of params:
model from epoch 179 loaded
test loss : 4.56078158090689
added 100% of params:
model from epoch 189 loaded
test loss : 4.5539982089448126
added 100% of params:
model from epoch 199 loaded
test loss : 4.557636452313435
added 100% of params:
model from epoch 209 loaded


Process Process-38:
Process Process-37:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()


KeyboardInterrupt: 

  File "/usr/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 342, in get
    with self._rlock:
  File "/usr/lib/python3.5/multiprocessing/queues.py", line 343, in get
    res = self._reader.recv_bytes()
  File "/usr/lib/python3.5/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
  File "/usr/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
   

## Sample from the network

In [68]:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

def sample(model, h=None, w=None, n=None):
    n = n or sample_batch_size
    h = h or obs[1]
    w = w or obs[2]
    model.train(False)
    data = torch.zeros(n, obs[0], h, w)
    data = data.cuda(5)
    for i in range(h):
        for j in range(w):
            data_v = Variable(data, volatile=True)
            out   = model(data_v, sample=True)
            out_sample = sample_op(out)
            data[:, :, i, j] = out_sample.data[:, :, i, j]
    return data

In [None]:
sample_t = sample(model, 256, 256, 1)
sample_t = rescaling_inv(sample_t)

utils.save_image(sample_t,'large.png', nrow=1, padding=0)