In [None]:
# Move the working directory up one level; the relative paths used below
# ("experiments/...", "saved/...") are resolved from there.
# NOTE: not idempotent - re-running this cell moves up one more directory.
%cd ..

In [None]:
import argparse
import collections
import lightning
import numpy as np
import pyro
import torch
import tqdm

import logger, train

In [None]:
# pyro.enable_validation(True)
# torch.autograd.set_detect_anomaly(True)

In [None]:
# train.from_file returns the config plus a (data, model, trainer) triple built
# from the experiment's JSON file; the path is relative to the working directory.
config, (data, model, trainer) = train.from_file("experiments/dcpc_celeba_config.json")

In [None]:
# NOTE(review): this rebinds `logger`, shadowing the `logger` module imported at
# the top of the notebook; from here on `logger` is the object returned by
# config.get_logger('valid') and the module is no longer reachable by that name.
logger = config.get_logger('valid')

In [None]:
# Run fit() starting from a checkpoint of one specific run (0903_161845,
# checkpoint 149); the path is hard-coded and relative to the working directory.
# Assuming `trainer` is a Lightning Trainer, ckpt_path restores the saved state
# before fitting continues - TODO confirm.
trainer.fit(model, data, ckpt_path="saved/models/Heteroskedastic_CelebA_Dcpc/0903_161845/checkpoint_149.ckpt")

In [None]:
# Clear the model's graph, then put the model in evaluation mode on the GPU.
# presumably graph.clear() drops values cached during fitting - verify.
model.graph.clear()
model.eval()
# Requires a CUDA device. As the last expression of the cell, the value returned
# by model.cuda() is echoed in the notebook output.
model.cuda()

In [None]:
# One pass over the training loader: for every batch, load the particles stored
# for its indices (second argument True) and run the graph once conditioned on
# the batch. The outputs are discarded; only the side effects on the model are kept.
for xs, _, indices in tqdm.tqdm(data.train_dataloader()):
    xs = xs.to(model.device)
    model._load_particles(indices, True)
    with model.graph.condition(X=xs) as joint:
        trace, log_weight = joint(B=len(xs), lr=1e-3, P=model.num_particles)
    # Drop the references straight away so the tensors can be freed before the
    # next batch is processed.
    del trace, log_weight, xs

In [None]:
# Take only the first validation batch. next(iter(...)) stops after one batch,
# whereas list(...)[0] would iterate (and hold in memory) the whole validation
# set just to keep its first element.
xs, _, indices = next(iter(data.val_dataloader()))
xs = xs.to(model.device)
# Load the particles stored for this batch's indices (second argument False,
# in contrast to the True used for training batches).
model._load_particles(indices, False)

In [None]:
# Number of conditioned sweeps to run on the validation batch `xs`.
NUM_EVAL_SWEEPS = 300

# Run the graph repeatedly conditioned on X=xs and log the negated mean log
# weight ("free energy") after every sweep. All sweeps go through the same loop,
# so the final trace is released like the others instead of staying bound to `_`.
with model.graph.condition(X=xs) as joint:
    for sweep in range(1, NUM_EVAL_SWEEPS + 1):
        trace, log_weight = joint(B=len(xs), lr=1e-3, P=model.num_particles)
        logger.info("Free energy at evaluation %d: %f" % (sweep, -log_weight.mean()))
        # Release per-sweep outputs before the next call allocates new ones.
        del trace, log_weight

In [None]:
# Condition the graph on the current value of node 'z' and run it in "prior"
# mode, then average the result over its leading dimension.
# assumes dim 0 indexes particles - TODO confirm.
with model.graph.condition(z=model.graph.nodes['z']['value']) as predictive:
    x_hats = predictive(B=len(xs), mode="prior", P=model.num_particles)
    x_hats = x_hats.mean(dim=0)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show the first 64 entries of `xs` as an 8x8 grid and save the figure as a PDF.
fig, axes = plt.subplots(nrows=8, ncols=8, sharex="all", sharey="all", layout="compressed")

# axes.flat walks the grid row by row, so panel k displays xs[k].
for k, ax in enumerate(axes.flat):
    original = data.reverse_transform(xs[k].detach().cpu())
    # NOTE(review): transpose(0, -1) swaps the first and last axes; if
    # reverse_transform returns (C, H, W) the image is drawn as (W, H, C),
    # i.e. transposed - confirm this is intended.
    ax.imshow(original.transpose(0, -1))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

fig.savefig("dcpc_celeba_orgs.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Show the first 64 reconstructions in `x_hats` as an 8x8 grid and save as PDF.
fig, axes = plt.subplots(nrows=8, ncols=8, sharex="all", sharey="all", layout="compressed")

# axes.flat walks the grid row by row, so panel k displays x_hats[k].
for k, ax in enumerate(axes.flat):
    recon = data.reverse_transform(x_hats[k].detach().cpu())
    # Values are clamped to [0, 1] before display.
    # NOTE(review): transpose(0, -1) swaps the first and last axes, which shows
    # a (C, H, W) image transposed - confirm this is intended.
    ax.imshow(recon.transpose(0, -1).clamp(0, 1))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

fig.savefig("dcpc_celeba_recons.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Drop the validation batch and its reconstructions so their memory can be freed.
del xs, x_hats

In [None]:
# Clear the graph before building the posterior-predictive samples below.
# presumably this discards the node values left by the evaluation sweeps - verify.
model.graph.clear()

In [None]:
# Pool the stored training and validation particles: for every key, concatenate
# the two detached tensors along dim 1.
train_particles = model.particles["train"]
valid_particles = model.particles["valid"]
posterior = {}
for name, train_value in train_particles.items():
    pooled = (train_value.detach(), valid_particles[name].detach())
    posterior[name] = torch.cat(pooled, dim=1)

# Predict from the pooled particles, then merge the first two dimensions of the
# result so that individual samples can be indexed as x_hats[k].
x_hats = model.graph.predict(B=64 // model.num_particles, P=model.num_particles, **posterior)
x_hats = x_hats.flatten(0, 1)

In [None]:
# Show the first 64 posterior-predictive samples as an 8x8 grid and save as PDF.
fig, axes = plt.subplots(nrows=8, ncols=8, sharex="all", sharey="all", layout="compressed")

# axes.flat walks the grid row by row, so panel k displays x_hats[k].
for k, ax in enumerate(axes.flat):
    sample = data.reverse_transform(x_hats[k].squeeze().detach().cpu())
    # Values are clamped to [0, 1] before display.
    # NOTE(review): transpose(0, -1) swaps the first and last axes, which shows
    # a (C, H, W) image transposed - confirm this is intended.
    ax.imshow(sample.transpose(0, -1).clamp(0, 1))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

fig.savefig("dcpc_celeba_predictive.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Drop the posterior-predictive samples; `x_hats` is reassigned further below.
del x_hats

In [None]:
# Clear the graph again before drawing samples in "prior" mode below.
# presumably this discards node values from the predictive run - verify.
model.graph.clear()

In [None]:
# Arguments shared by every graph call in this cell; B is chosen so that
# B * num_particles is (at most) 64 samples.
sampler_kwargs = dict(B=64 // model.num_particles, lr=1e-3, P=model.num_particles)

# One call in "prior" mode, then 299 calls in the default mode, then a final
# "prior"-mode call whose output is kept.
model.graph(mode="prior", **sampler_kwargs)
for _ in range(299):
    model.graph(**sampler_kwargs)
x_hats = model.graph(mode="prior", **sampler_kwargs)

# Merge the first two dimensions so individual samples are x_hats[k].
x_hats = torch.flatten(x_hats, 0, 1)

In [None]:
# Show the first 64 prior-mode samples as an 8x8 grid and save as PDF.
fig, axes = plt.subplots(nrows=8, ncols=8, sharex="all", sharey="all", layout="compressed")

# axes.flat walks the grid row by row, so panel k displays x_hats[k].
for k, ax in enumerate(axes.flat):
    sample = data.reverse_transform(x_hats[k].squeeze().detach().cpu())
    # Values are clamped to [0, 1] before display.
    # NOTE(review): transpose(0, -1) swaps the first and last axes, which shows
    # a (C, H, W) image transposed - confirm this is intended.
    ax.imshow(sample.transpose(0, -1).clamp(0, 1))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

fig.savefig("dcpc_celeba_priors.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Drop the prior samples; `x_hats` is reused inside the evaluation loop below.
del x_hats

In [None]:
# Random seeds for the repeated validation runs below; the result tensors hold
# one entry per seed, in this order.
SEEDS = [123, 456, 789, 101112, 131415]

In [None]:
# Per-seed accumulators (one slot per entry of SEEDS), starting at zero.
# torch.zeros already yields tensors that do not require gradients.
LOG_LIKELIHOODS = torch.zeros(len(SEEDS))
MEAN_SQUARED_ERROR = torch.zeros_like(LOG_LIKELIHOODS)

In [None]:
# For every seed in SEEDS, sweep the validation loader once and accumulate
#   * the log-probability of each batch under the 'X' node of the conditioned trace, and
#   * the squared error between the batch and a "prior"-mode graph output,
# then divide both totals by the number of validation examples.
with torch.no_grad():
    for s, SEED in enumerate(SEEDS):
        torch.manual_seed(SEED)
        np.random.seed(SEED)

        # Start from zero for this seed so that re-running the cell does not add
        # to the (already normalised) totals left by a previous run.
        LOG_LIKELIHOODS[s] = 0.
        MEAN_SQUARED_ERROR[s] = 0.

        # Build the loader once per seed (after seeding) and reuse it for the
        # dataset size instead of constructing extra loaders just to read it.
        val_loader = data.val_dataloader()
        for b, (xs, target, indices) in enumerate(val_loader):
            xs = xs.to(model.device)
            model._load_particles(indices, False)
            x_hats = model.graph(B=len(xs), mode="prior", P=model.num_particles)
            with model.graph.condition(X=xs) as predictive:
                trace, _ = predictive(B=len(xs), P=model.num_particles)
            LOG_LIKELIHOODS[s] += trace.nodes['X']['fn'].log_prob(xs).sum().cpu()
            # NOTE(review): the squared error is summed over dim 0 and averaged
            # over the remaining dims; confirm which axis dim 0 is (batch vs.
            # particles) after broadcasting xs against x_hats.
            MEAN_SQUARED_ERROR[s] += ((xs - x_hats) ** 2).sum(dim=0).mean().cpu()

            # Release per-batch tensors before the next batch is loaded.
            del xs, x_hats, trace, target, indices
            # Report the actual seed value rather than its position in SEEDS.
            logger.info("Evaluated likelihood for valid batch %d under seed %s" % (b, SEED))

        num_valid = len(val_loader.dataset)
        LOG_LIKELIHOODS[s] /= num_valid
        MEAN_SQUARED_ERROR[s] /= num_valid

In [None]:
# Mean and standard deviation of the per-seed log-likelihood values
# (one entry per seed in SEEDS); displayed as the cell output.
LOG_LIKELIHOODS.mean(), LOG_LIKELIHOODS.std()

In [None]:
# Mean and standard deviation of the per-seed squared-error values
# (one entry per seed in SEEDS); displayed as the cell output.
MEAN_SQUARED_ERROR.mean(), MEAN_SQUARED_ERROR.std()

In [None]:
# Clear the graph once more before the test-set evaluation below.
# presumably this discards node values from the validation sweeps - verify.
model.graph.clear()

In [None]:
# Display the `scale` attribute of the graph's likelihood as the cell output.
model.graph.likelihood.scale

In [None]:
# Run the test set through model.test_step ten times. Every value returned by
# test_step is collected per metric name, and one FID value is computed per pass.
fids = []
metrics = collections.defaultdict(list)
data.setup("test")

for repeat in range(10):
    test_batches = tqdm.tqdm(data.test_dataloader(), desc='Test set FIDs')
    for b, batch in enumerate(test_batches):
        # NOTE(review): test_step is called directly rather than through the
        # trainer; confirm it handles device placement and gradient mode itself.
        for k, v in model.test_step(batch, b).items():
            metrics[k].append(v)
    # Read this pass's FID, then reset the metric so passes do not mix.
    fid_metric = model.metrics['fid']
    fids.append(fid_metric.compute())
    fid_metric.reset()
    # Clear the graph's `gmm` attribute before the next pass.
    # presumably forces it to be rebuilt - verify.
    model.graph.gmm = None

fids = torch.stack(fids, dim=0)
fids.mean(), fids.std()

In [None]:
# Replace each metric's list of collected values with a single tensor.
# Keys are snapshotted first; only existing entries are overwritten.
for name in list(metrics):
    metrics[name] = torch.tensor(metrics[name])

In [None]:
# Mean of every collected test metric over its last dimension; expects the
# entries of `metrics` to have been converted to tensors already.
{m: v.mean(dim=-1) for m, v in metrics.items()}

In [None]:
# Standard deviation of every collected test metric over its last dimension;
# expects the entries of `metrics` to have been converted to tensors already.
{m: v.std(dim=-1) for m, v in metrics.items()}