In [None]:
# Move to the parent directory: the config, checkpoint and output paths used
# below ("experiments/...", "saved/...", "data/...") are relative to it.
%cd ..

In [None]:
import argparse
import collections
import lightning
import numpy as np
import pyro
import torch
import logger, train

In [None]:
# Debugging switches, off by default. Uncomment to enable Pyro's argument /
# distribution validation and PyTorch autograd anomaly detection; both add
# runtime overhead, so keep them disabled for normal evaluation runs.
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [None]:
# Build the experiment components (config, data module, model, trainer) from
# the JSON experiment config. Change CONFIG_PATH to evaluate another experiment.
CONFIG_PATH = "experiments/ppc_celeba_config.json"
config, (data, model, trainer) = train.from_file(CONFIG_PATH)

In [None]:
# Logger for the evaluation ('valid') phase, obtained from the experiment config.
# NOTE(review): this rebinds `logger`, shadowing the `logger` module imported in
# the imports cell; the module is no longer reachable by that name afterwards.
logger = config.get_logger('valid')

In [None]:
# Restore the trained model from a saved checkpoint of a specific run.
# NOTE(review): `trainer.fit(..., ckpt_path=...)` *resumes training* from the
# checkpoint. It only behaves as a pure load when the checkpoint is already at
# the trainer's final epoch; confirm that holds for this run.
CHECKPOINT_PATH = "saved/models/CelebA_Ppc/0519_223059/checkpoint_99.ckpt"
trainer.fit(model, data, ckpt_path=CHECKPOINT_PATH)

In [None]:
# Prepare for evaluation: clear the model graph, switch to eval mode, and move
# the model to the GPU when one exists (on CPU-only machines it stays put
# instead of raising). Using a statement also avoids echoing the model repr.
model.graph.clear()
model.eval()
if torch.cuda.is_available():
    model.cuda()

In [None]:
# Take one validation batch for qualitative inspection. next(iter(...)) fetches
# only the first batch; list(...)[0] would load every validation batch first.
# The discarded middle element is the target (see the evaluation loop unpacking).
xs, _, indices = next(iter(data.val_dataloader()))
xs = xs.to(model.device)
# NOTE(review): private model helper; the meaning of the `False` flag is not
# visible from this notebook — confirm against the model implementation.
model._load_particles(indices, False)

In [None]:
# Run the joint conditioned on the observed batch repeatedly, logging the free
# energy (negative mean log weight) after every pass. Each pass is given a
# learning rate, so it presumably updates this batch's particles — confirm.
NUM_EVAL_PASSES = 300
EVAL_LR = 1e-1

with model.graph.condition(X=xs) as joint:
    for i in range(1, NUM_EVAL_PASSES + 1):
        trace, log_weight = joint(B=len(xs), lr=EVAL_LR, P=model.num_particles)
        logger.info("Free energy at evaluation %d: %f" % (i, -log_weight.mean()))
        # Release references every pass, including the last, so no trace is
        # kept alive afterwards (e.g. through IPython's `_`).
        del trace
        del log_weight

In [None]:
# Reconstruct the batch: pin the latent `z` to its current value and run the
# graph in prior mode, then average over dim 0 (presumably the particle axis).
with model.graph.condition(z=model.graph.nodes['z']['value']) as predictive:
    x_hats = (
        predictive(B=len(xs), mode="prior", P=model.num_particles)
        .mean(dim=0)
    )

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Originals (top row) vs. reconstructions (bottom row) for up to 8 images.
# Images are assumed to be (C, H, W) tensors — confirm against the dataset.
# imshow expects (H, W, C), so permute(1, 2, 0); transpose(0, -1) would give
# (W, H, C) and draw every image transposed.
n_show = min(8, len(xs))
# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(nrows=2, ncols=n_show, sharex="all", sharey="all",
                         layout="compressed", squeeze=False)

for i in range(n_show):
    orgs = xs[i].detach().permute(1, 2, 0).cpu()
    estimates = x_hats[i].detach().permute(1, 2, 0).cpu()
    axes[0, i].imshow(orgs)
    axes[1, i].imshow(estimates)

fig.savefig("ppc_celeba_recons.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Drop the inspected batch and its reconstructions before sampling.
del xs, x_hats

In [None]:
# Reset the model graph before drawing unconditional samples in prior mode.
model.graph.clear()

In [None]:
# Unconditional samples: run the graph in prior mode (lr=1e-1) for a batch of 8
# repeatedly, then average the final output over dim 0 (presumably particles).
# NOTE(review): this makes 301 graph calls in total (300 here + 1 final), one
# more than the 300 passes of the free-energy cell — confirm that is intended.
for _ in range(1 + 299):
    model.graph(B=8, lr=1e-1, mode="prior", P=model.num_particles)
x_hats = model.graph(B=8, lr=1e-1, mode="prior", P=model.num_particles).mean(dim=0)

In [None]:
# One row with every sample from the previous cell. After squeeze() each sample
# is assumed to be (C, H, W) — confirm. imshow expects (H, W, C), so
# permute(1, 2, 0); transpose(0, -1) would give (W, H, C), i.e. a transposed image.
n_show = len(x_hats)
# squeeze=False keeps `axes` 2-D regardless of the number of columns.
fig, axes = plt.subplots(nrows=1, ncols=n_show, sharex="all", sharey="all",
                         layout="compressed", squeeze=False)

for i in range(n_show):
    estimates = x_hats[i].squeeze().detach().permute(1, 2, 0).cpu()
    axes[0, i].imshow(estimates)

fig.savefig("ppc_celeba_samples.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Free the sample batch before the quantitative evaluation.
del x_hats

In [None]:
# Seeds for the repeated validation runs; metrics are summarised as mean/std over them.
SEEDS = [123, 456, 789, 101_112, 131_415]

In [None]:
# One accumulator slot per seed, filled by the validation loop. float64 keeps
# the running sums over every validation batch from losing precision before
# they are normalised by the dataset size. (Plain zeros never require grad.)
LOG_LIKELIHOODS = torch.zeros(len(SEEDS), dtype=torch.float64)
MEAN_SQUARED_ERROR = torch.zeros(len(SEEDS), dtype=torch.float64)

In [None]:
# For each seed, accumulate over all validation batches:
#   * the log-probability of X under the model's likelihood, and
#   * the squared reconstruction error (summed over the batch, averaged over
#     the remaining dims),
# then divide both by the number of validation examples.
with torch.no_grad():
    for s, seed in enumerate(SEEDS):
        # Seeds Python's `random`, numpy and torch in one call.
        pyro.set_rng_seed(seed)
        val_loader = data.val_dataloader()
        for b, (xs, target, indices) in enumerate(val_loader):
            xs = xs.to(model.device)
            model._load_particles(indices, False)

            with model.graph.condition(X=xs) as joint:
                trace, log_weight = joint(B=len(xs), lr=1e-5, P=model.num_particles)
            x_dist = trace.nodes['X']['fn']
            # Average the likelihood location over dim 0 (presumably particles).
            x_hats = x_dist.base_dist.loc.mean(dim=0)
            # NOTE(review): if log_prob keeps a particle dimension, .sum() also
            # sums over particles, scaling the value by P — confirm the estimator.
            LOG_LIKELIHOODS[s] += x_dist.log_prob(xs).sum().item()
            MEAN_SQUARED_ERROR[s] += ((xs - x_hats) ** 2).sum(dim=0).mean().item()

            del xs, x_hats, x_dist, trace, log_weight, target, indices
            # Report the seed value itself, not its position in SEEDS.
            logger.info("Evaluated likelihood for valid batch %d under seed %s" % (b, seed))

        num_val = len(val_loader.dataset)
        LOG_LIKELIHOODS[s] /= num_val
        MEAN_SQUARED_ERROR[s] /= num_val

In [None]:
# Across-seed summary of the per-example validation log-likelihood.
(
    LOG_LIKELIHOODS.mean(),  # mean over seeds
    LOG_LIKELIHOODS.std(),   # sample (Bessel-corrected) std over seeds
)

In [None]:
# Across-seed summary of the validation reconstruction error.
(
    MEAN_SQUARED_ERROR.mean(),  # mean over seeds
    MEAN_SQUARED_ERROR.std(),   # sample (Bessel-corrected) std over seeds
)

In [None]:
# Reset the model graph before generating samples for export.
model.graph.clear()

In [None]:
# Number of sample images to export, and a running count of how many are written.
NUM_SAMPLES, num_samples = 200, 0

In [None]:
import utils.util as util

In [None]:
# Output directory for exported sample images (`ensure_dir` presumably creates
# it when missing — confirm in utils.util).
util.ensure_dir('data/celeba_ppc')

In [None]:
# Export unconditional samples as individual JPEG files. Disabled; uncomment to run.
# NOTE(review): before enabling, consider that
#   * the loop stops only once num_samples >= NUM_SAMPLES, so it can write up
#     to data.batch_size - 1 images more than NUM_SAMPLES;
#   * plt.imshow draws onto the same current axes every iteration without
#     clearing or closing, so images pile up and memory grows;
#   * transpose(0, -1) on a (C, H, W) tensor yields (W, H, C), i.e. a transposed
#     image; permute(1, 2, 0) gives the (H, W, C) layout imshow expects;
#   * plt.savefig saves the whole figure (axes, ticks, padding), not just pixels.
# plt.set_loglevel("error")

# while num_samples < NUM_SAMPLES:
#     x_hats = model.graph(B=data.batch_size, mode="prior", P=model.num_particles).mean(dim=0)
#     for k in range(data.batch_size):
#         fig = plt.imshow(x_hats[k].squeeze().detach().transpose(0, -1).cpu())
#         plt.savefig("data/celeba_ppc/%d.jpg" % (num_samples + k))
#     num_samples += data.batch_size

#     logger.info("Generated %d sample images" % num_samples)