In [82]:
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader

import visdom

from elbo_decomposition import elbo_decomposition
import lib.dist as dist
import lib.flows as flows
import lib.datasets as dset

from vae_quant import VAE, setup_data_loaders

In [100]:
def setup_data_loaders(args, use_cuda=False):
    if args.dataset == 'shapes':
        train_set = dset.Shapes()
    elif args.dataset == 'faces':
        train_set = dset.Faces()
    elif args.dataset == 'guppies':
        train_set = dset.Guppies()
    else:
        raise ValueError('Unknown dataset ' + str(args.dataset))

    kwargs = {'num_workers': 16, 'pin_memory': use_cuda}
    train_loader = DataLoader(dataset=train_set,
        batch_size=args.batch_size, shuffle=True, **kwargs)
    return train_loader

def load_model_and_dataset(checkpt_filename):
    print('Loading model and dataset.')
    checkpt = torch.load(checkpt_filename, map_location=lambda storage, loc: storage)
    args = checkpt['args']
    state_dict = checkpt['state_dict']
    
    print(args)

    # model
    if not hasattr(args, 'dist') or args.dist == 'normal':
        prior_dist = dist.Normal()
        q_dist = dist.Normal()
    elif args.dist == 'laplace':
        prior_dist = dist.Laplace()
        q_dist = dist.Laplace()
    elif args.dist == 'flow':
        prior_dist = flows.FactorialNormalizingFlow(dim=args.latent_dim, nsteps=4)
        q_dist = dist.Normal()
    vae = VAE(z_dim=args.latent_dim, use_cuda=True, prior_dist=prior_dist, q_dist=q_dist, conv=args.conv)

    vae.load_state_dict(state_dict, strict=False)

    # dataset loader
    loader = setup_data_loaders(args)
    return vae, loader, args

win_samples = None
win_test_reco = None
win_latent_walk = None
win_train_elbo = None

def display_samples(model, x, vis):
    global win_samples, win_test_reco, win_latent_walk

    # plot random samples
    sample_mu = model.model_sample(batch_size=100).sigmoid()
    sample_mu = sample_mu
    images = list(sample_mu.view(-1, x.size(1), x.size(2), x.size(3)).data.cpu())
    win_samples = vis.images(images, 10, 2, opts={'caption': 'samples'}, win=win_samples)

    # plot the reconstructed distribution for the first 50 test images
    test_imgs = x[:50, :]
    _, reco_imgs, zs, _ = model.reconstruct_img(test_imgs)
    reco_imgs = reco_imgs.sigmoid()
    test_reco_imgs = torch.cat([
        test_imgs.view(x.size(1), -1, x.size(2), x.size(3)), reco_imgs.view(x.size(1), -1, x.size(2), x.size(3))], 0).transpose(0, 1)
    win_test_reco = vis.images(
        list(test_reco_imgs.contiguous().view(-1, x.size(1), x.size(2), x.size(3)).data.cpu()), 10, 2,
        opts={'caption': 'test reconstruction image'}, win=win_test_reco)

    # plot latent walks (change one variable while all others stay the same)
    zs = zs[0:3]
    batch_size, z_dim = zs.size()
    xs = []
    delta = torch.autograd.Variable(torch.linspace(-2, 2, 7), volatile=True).type_as(zs)
    for i in range(z_dim):
        vec = Variable(torch.zeros(z_dim)).view(1, z_dim).expand(7, z_dim).contiguous().type_as(zs)
        vec[:, i] = 1
        vec = vec * delta[:, None]
        zs_delta = zs.clone().view(batch_size, 1, z_dim)
        zs_delta[:, :, i] = 0
        zs_walk = zs_delta + vec[None]
        xs_walk = model.decoder.forward(zs_walk.view(-1, z_dim)).sigmoid()
        xs.append(xs_walk)

    xs = list(torch.cat(xs, 0).data.cpu())
    win_latent_walk = vis.images(xs, 7, 2, opts={'caption': 'latent walk'}, win=win_latent_walk)

In [101]:
decomp_fname = 'guppy_10_20/elbo_decomposition.pth'
checkpt_fname = 'guppy_10_20/checkpt-0000.pth'

In [102]:
vae, loader, args = load_model_and_dataset(checkpt_fname)

Loading model and dataset.
Namespace(batch_size=100, beta=20.0, beta_anneal=False, conv=True, dataset='guppies', dist='normal', exclude_mutinfo=False, gpu=0, lambda_anneal=False, latent_dim=10, learning_rate=0.001, log_freq=200, mss=True, num_epochs=10000, save='guppy_10_20', tcvae=True, visdom=False)


In [103]:
vis = visdom.Visdom(env=args.save, port=8097)



In [110]:
for x in loader:
    x = x.cuda(async=True)
    x = Variable(x)
    display_samples(vae, x, vis)

ValueError: only one element tensors can be converted to Python scalars

In [75]:
N = len(loader.dataset)  # number of data samples
K = vae.z_dim                    # number of latent variables
nparams = vae.q_dist.nparams
vae.eval()
qz_params = torch.Tensor(N, K, nparams)
n = 0
for xs in loader:
    batch_size = xs.size(0)
    with torch.no_grad():
        xs = Variable(xs.view(batch_size, xs.size(1), xs.size(2), xs.size(3)).cuda())
    qz_params[n:n + batch_size] = vae.encoder.forward(xs).view(batch_size, vae.z_dim, nparams).data
    n += batch_size
qz_means = qz_params[:, :, 0]
var = torch.std(qz_means.contiguous().view(N, K), dim=0).pow(2)

In [96]:
loader.dataset[0].size()

torch.Size([3, 256, 256])

In [77]:
x_ = vae.reconstruct_img(loader.dataset)

AttributeError: 'Guppies' object has no attribute 'view'

In [55]:
z = vae.encoder.forward(xs).data

In [56]:
z.shape

torch.Size([89, 20])

In [58]:
recon = vae.decoder.forward(z).data

RuntimeError: Given transposed=1, weight of size 10 512 1 1, expected input[89, 20, 1, 1] to have 10 channels, but got 20 channels instead

In [51]:
qz_means[0]

tensor([-0.2559, -1.3302,  1.2629,  0.2521, -0.8213, -1.3000, -1.2153,  1.4347,
         0.7383, -1.1549])

In [2]:
elbo_decomp = torch.load(decomp_fname)

In [3]:
elbo_decomp

{'logpx': tensor(-112506.8516, device='cuda:0'),
 'dependence': tensor(4.7775, device='cuda:0'),
 'information': tensor(6.8764, device='cuda:0'),
 'dimwise_kl': tensor(12.6220, device='cuda:0'),
 'analytical_cond_kl': tensor(24.2759, device='cuda:0'),
 'marginal_entropies': tensor([ 0.2425,  0.2372,  0.5986,  0.1648, -0.0074,  0.1511,  0.2688,  0.3027,
          0.4559,  0.1735], device='cuda:0'),
 'joint_entropy': tensor([-2.1898], device='cuda:0')}