In [None]:
%cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters

In [None]:
from bwb import _logging as logging


RUN_MCMC = True

log = logging.get_logger(__name__)

# Posterior Explícita

In [None]:
import numpy as np

from bwb.distributions.data_loaders import DistributionDrawDataLoader
from bwb.distributions import ExplicitPosteriorSampler

ExplicitPosteriorSampler.set_save_samples(True)

expl_posterior = ExplicitPosteriorSampler()
expl_posterior

In [None]:
from pathlib import Path

data_path = Path("./data/face.npy")

models_array = np.load(data_path)
n_faces, _ = models_array.shape
print(f"{n_faces = }")

faces = DistributionDrawDataLoader(models_array, (28, 28))

In [None]:
face0 = faces[0]
face0

In [None]:
import torch

torch.manual_seed(42)

data = faces[0].sample((100,))
data

In [None]:
face0.enumerate_support_()

In [None]:
import bwb.plotters as plotters
import matplotlib.pyplot as plt

# Jitter every sampled pixel coordinate so that coincident points spread out
# in the histogram. The noise array is sized from `data` (not a hard-coded
# 100) and drawn from a locally seeded generator so the figure is reproducible.
jitter = np.random.default_rng(42).normal(scale=0.1, size=(len(data), 2))
data_coords = face0.enumerate_support_()[data].cpu().numpy() + jitter
# data_coords = face0.enumerate_support_()[face0.sample((10_000,))].cpu().numpy() + np.random.randn(10_000, 2) * 0.1

plotters.plot_histogram_from_points(
    data_coords, rotate=True, histplot_kwargs=dict(bins=28)
)
plt.show()

In [None]:
data_5 = data[:5]
data_50 = data[:50]
data_100 = data[:100]
data_coords_5 = data_coords[:5]
data_coords_50 = data_coords[:50]
data_coords_100 = data_coords[:100]

In [None]:
max_images = 3 * 13
expl_posterior = ExplicitPosteriorSampler()
expl_posterior.fit(data=data_5, models=faces)
expl_posterior.rvs(1_000, seed=42)

print(f"Total time: {expl_posterior.total_time:.2f} seconds")

expl_posterior

In [None]:
import bwb.utils as utils

most_common = [
    faces[i] for i, _ in expl_posterior.samples_counter.most_common(max_images)
]
plotters.plot_list_of_draws(
    most_common,
    labels=utils.freq_labels_dist_sampler(expl_posterior),
    cmap="binary_r",
)
print()

In [None]:
max_images = 3 * 13
expl_posterior = ExplicitPosteriorSampler()
expl_posterior.fit(data=data[:10], models=faces)
expl_posterior.rvs(1_000, seed=42)
print(f"Total time: {expl_posterior.total_time:.2f} seconds")
most_common = [
    faces[i] for i, _ in expl_posterior.samples_counter.most_common(max_images)
]
plotters.plot_list_of_draws(
    most_common,
    labels=utils.freq_labels_dist_sampler(expl_posterior),
)
print()

In [None]:
max_images = 3 * 13
expl_posterior = ExplicitPosteriorSampler()
expl_posterior.fit(data=data[:20], models=faces)
expl_posterior.rvs(1_000, seed=42)
print(f"Total time: {expl_posterior.total_time:.2f} seconds")
most_common = [
    faces[i] for i, _ in expl_posterior.samples_counter.most_common(max_images)
]
plotters.plot_list_of_draws(
    most_common,
    labels=utils.freq_labels_dist_sampler(expl_posterior),
)
print()

In [None]:
from bwb.distributions.distribution_samplers import UniformDiscreteSampler

max_images = 3 * 13
uniform_sampler = UniformDiscreteSampler()
uniform_sampler.fit(models=faces)
uniform_sampler.rvs(1_000, seed=42)
print(f"Total time: {uniform_sampler.total_time:.2f} seconds")
most_common = [
    faces[i] for i, _ in uniform_sampler.samples_counter.most_common(max_images)
]
plotters.plot_list_of_draws(
    most_common,
    labels=utils.freq_labels_dist_sampler(uniform_sampler),
)
uniform_sampler

In [None]:
old_dict = dict(a=1, b=2)
old_dict.setdefault("c", 2)
new_dict = dict(a=2, c=4)
old_dict.update(new_dict)
old_dict

In [None]:
from quick_torch import QuickDraw
import torchvision.transforms.v2 as T
from pathlib import Path

# noinspection PyProtectedMember
from bwb.utils import array_like_t
from torch.utils.data import DataLoader
from bwb.config import config
import multiprocessing as mp
import bwb.distributions as dist
from bwb.distributions.models import BaseDiscreteWeightedModelSet

ds = QuickDraw(
    Path("./data"),
    categories="face",
    download=True,
    transform=T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x.squeeze()),
    ]),
)


# noinspection PyMethodOverriding,PyShadowingNames
class DatasetWrapper(BaseDiscreteWeightedModelSet):
    """
    Expose a map-style dataset of images as a discrete weighted model set.

    Every item is read as ``dataset[i][0]``; in ``compute_likelihood`` each
    image is flattened and normalised to sum to one, so it is treated as a
    discrete distribution over its pixels.

    :param dataset: Dataset whose items are ``(features, label)`` pairs.
    :param dataloader_args: Optional keyword arguments for ``DataLoader``.
        Missing ``batch_size``, ``shuffle`` and ``num_workers`` entries are
        filled with defaults; the dictionary passed in is not modified.
    :param device: Device used for the computation; falls back to
        ``config.device``.
    :param dtype: Floating dtype of the features; falls back to
        ``config.dtype``.
    :param eps: Constant added before taking logarithms; falls back to
        ``config.eps``.
    """

    def __init__(
        self, dataset, dataloader_args=None, device=None, dtype=None, eps=None
    ):
        self.dataset = dataset

        # Work on a copy: the ``setdefault`` calls below must not leak the
        # defaults into a dictionary owned by the caller.
        self.dataloader_args = dict(dataloader_args or {})
        default_args = dict(
            batch_size=1024, shuffle=False, num_workers=mp.cpu_count()
        )
        for key, value in default_args.items():
            self.dataloader_args.setdefault(key, value)

        self.device = torch.device(device or config.device)
        # If we are working in cuda
        if self.device.type == "cuda":
            self.dataloader_args.setdefault("pin_memory", True)

        self.dtype = dtype or config.dtype

        self.eps = eps or config.eps

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def compute_likelihood(self, data: array_like_t, **kwargs) -> torch.Tensor:
        """
        Return the normalised likelihood of ``data`` under every dataset item.

        :param data: Indices into the flattened image, one per observation.
        :return: 1-D tensor with one entry per dataset item, summing to one.
        """
        dataloader = DataLoader(self.dataset, **self.dataloader_args)
        data = torch.as_tensor(data, device=self.device).reshape(1, -1)

        log_likelihoods = []

        for features, _ in dataloader:
            # Transfer the features to the appropriate device
            features = torch.as_tensor(
                features, device=self.device, dtype=self.dtype
            )
            # Flatten the features
            features = features.reshape(features.size(0), -1)
            # Normalize the features
            features = features / features.sum(dim=1, keepdim=True)
            # Compute the log of the features
            log_features = torch.log(features + self.eps)
            # Log-likelihood of the observations under each model of the batch
            evaluations = torch.take_along_dim(log_features, data, 1)
            log_likelihoods.append(evaluations.sum(dim=1))

        log_likelihood_cache = torch.cat(log_likelihoods, dim=0)

        # Normalise in log space. softmax(l)_i = exp(l_i) / sum_j exp(l_j) is
        # the same quantity as exponentiating first and dividing by the total,
        # but it does not underflow to all zeros when the summed
        # log-likelihoods are very negative (many observations).
        probabilities = torch.softmax(log_likelihood_cache, dim=0)

        return probabilities

    def get(self, i: int, **kwargs) -> dist.DistributionDraw:
        """Return item ``i`` of the dataset as a ``DistributionDraw``."""
        return dist.DistributionDraw.from_grayscale_weights(self.dataset[i][0])


ds_wrapped = DatasetWrapper(ds)

ds_wrapped.get(1)
ds_wrapped.compute_likelihood(data)

In [None]:
max_images = 3 * 13
expl_posterior = ExplicitPosteriorSampler(save_samples=True)
expl_posterior.fit(data=data[:20], models=ds_wrapped)
print(f"Total time: {expl_posterior.total_time:.2f} seconds")
expl_posterior.rvs(1_000, seed=42)
print(f"Total time: {expl_posterior.total_time:.2f} seconds")
# The sampler was fitted on `ds_wrapped`, so the sampled indices refer to that
# model set. Resolve them there rather than in `faces`, which would only match
# if both containers hold the same images in the same order.
most_common = [
    ds_wrapped.get(i)
    for i, _ in expl_posterior.samples_counter.most_common(max_images)
]
plotters.plot_list_of_draws(
    most_common,
    labels=utils.freq_labels_dist_sampler(expl_posterior),
)
print()

In [None]:
expl_posterior.models_.get(0)

In [None]:
%%time

expl_posterior.fit(data=data, models=faces)

In [None]:
%%time

expl_posterior.rvs(size=1000, seed=42)
print()

In [None]:
expl_posterior.total_time

In [None]:
expl_posterior2 = ExplicitPosteriorSampler()
expl_posterior2.total_time, expl_posterior.total_time

In [None]:
try:
    expl_posterior2.draw()
except Exception as e:
    print(e)

In [None]:
len(expl_posterior.samples_counter)

In [None]:
print(expl_posterior.samples_counter.total())
expl_posterior.draw()
print(expl_posterior.samples_counter.total())

In [None]:
expl_posterior.draw()

In [None]:
del expl_posterior, expl_posterior2, faces, models_array

# Importar redes

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import (
    Generator,
    Encoder,
    LatentDistribution,
)
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]

noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)

CURR_PATH = Path(".")
NETS_PATH = CURR_PATH / "wgan_gp" / "networks"

DS_NAME = "data"

G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

G.eval()
E.eval()
print()

In [None]:
noise_sampler

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning

disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Draw one latent vector and push it through the generator to get a sample
# image on which the output transforms below are tried out.
z = noise_sampler(1)
m = G(z)

# Pre-processing applied to an image before it is fed to the encoder:
# rescale by the maximum, go through a PIL image to resize to 32, convert back
# to a float32 image tensor and normalise with mean 0.5 and std 0.5.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.5,), (0.5,)),
])

# Post-processing of a generator output into weights: cast to float64, drop
# singleton dimensions, shift so the minimum is zero and divide by the sum.
transform_out_ = T.Compose([
    T.ToDtype(torch.float64),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    # T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])

# Same as `transform_out_`, followed by wrapping the weights in a
# DistributionDraw.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])

out: DistributionDraw = transform_out(m)
print(out.dtype)
out

In [None]:
torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from bwb.distributions.distribution_samplers import GeneratorDistribSampler

distr_sampler = GeneratorDistribSampler()
distr_sampler.fit(G, transform_out_, noise_sampler)
distr_sampler.transform_noise(z)
distr_sampler.rvs(3)

In [None]:
distr_sampler.draw()

In [None]:
# Fixed seed for the sampling below.
seed = 4102165607616432379
print(f"{seed = }")
torch.manual_seed(seed)

# Encode the first face into latent space and decode it again with the
# generator; no gradients are needed for this forward pass.
with torch.no_grad():
    m = transform_in(face0.grayscale_weights).unsqueeze(0).to(device)
    z = E(m)
    m = G(z)

# Turn the reconstructed image into a DistributionDraw.
face = transform_out(m)

# face = bwb_dist.DistributionDraw.from_grayscale_weights(m)
# Note: this rebinds `data`, replacing the sample drawn from `face0` earlier.
data = face.sample((100,))
print(f"{data = }")
face

In [None]:
shape = face.shape
data_coords = (
    face.enumerate_support_()[data].cpu().numpy()
    + np.random.randn(len(data), 2) * 0.1
)

plotters.plot_histogram_from_points(
    data_coords, rotate=True, shape=shape, histplot_kwargs=dict(bins=28)
)
plt.show()

# MCMC

![](https://quicklatex.com/cache3/e9/ql_ee61247290f642a5afc1fc6205cc98e9_l3.png)

Realizaremos MCMC utilizando la librería pyro

In [None]:
# from pyro.infer.mcmc import NUTS, MCMC

In [None]:
# DEV = "cpu"
# means1 = torch.tensor([-5, 5], device=DEV)
# means2 = -means1
#
# ndim = means1.size(0)
# cov = torch.eye(ndim, device=DEV) * 5
#
# # @torch.jit.script
# def prob(
#     params: dict[str, torch.Tensor],
#     means1: torch.Tensor = means1,
#     means2: torch.Tensor = means2,
#     cov: torch.Tensor = cov,
#     const: torch.Tensor = torch.tensor(0.7),
# ):
#     x = params["x"]
#     diff1 = x - means1
#     log_prob1 = 0.5 * torch.dot(diff1, torch.linalg.solve(cov, diff1))
#
#     diff2 = x - means2
#     log_prob2 = 0.5 * torch.dot(diff2, torch.linalg.solve(cov, diff2))
#     return -torch.log(const * torch.exp(-log_prob1) + (1-const) * torch.exp(-log_prob2))
#
# prob({"x": torch.zeros((ndim,), device=DEV)})

In [None]:
# nwalkers = 2
# p0 = torch.randn((nwalkers, ndim), device=DEV)
# p0.shape

In [None]:
# prob({"x": p0[0]})

In [None]:
# kernel = NUTS(
#     potential_fn=prob,
#     jit_compile=True,
#     target_accept_prob=0.6,
#     # full_mass=True,
# )

In [None]:
# mcmc = MCMC(kernel, warmup_steps=1000, num_samples=10_000, initial_params={"x": p0}, num_chains=nwalkers)

In [None]:
# mcmc.run(means1, means2, cov)

In [None]:
# mcmc.diagnostics()

In [None]:
# mcmc.summary()

In [None]:
# samples = mcmc.get_samples()["x"]
# samples.shape

In [None]:
# import matplotlib.pyplot as plt
#
# samples_ = samples.cpu()
# plt.hist(samples_[:, 0], 100, color="k", histtype="step")
# plt.xlabel(r"$\theta_1$")
# plt.ylabel(r"$p(\theta_1)$")
# plt.gca().set_yticks([]);

In [None]:
# import seaborn as sns
# x = samples[:, 0].cpu()
# y = samples[:, 1].cpu()
# g = sns.jointplot(x=x, y=y, alpha=0.1)
# g.plot_joint(sns.kdeplot, color="r")
# g.plot_marginals(sns.rugplot, color="r", height=-.15, clip_on=False)

In [None]:
# @torch.jit.script
def log_prior(
    z,
) -> torch.Tensor:
    """
    Log-density, up to an additive constant, of a standard Normal(0, I) prior.

    The result is ``-0.5 * sum(z ** 2)``. The negative sign matters: this
    value is added to a log-likelihood in ``log_posterior``, so both terms
    must be log-densities and not potentials.

    :param z: Latent tensor; singleton dimensions are squeezed out first.
    :return: Scalar tensor with the unnormalised log-prior of ``z``.
    """
    z = z.squeeze()
    return -0.5 * torch.sum(z**2)


# log_prior = torch.jit.script(_log_prior, example_inputs=[({"z": noise_sampler(1)},)])

log_prior(noise_sampler(1))

In [None]:
def log_likelihood_latent(
    z,
    data=data,
    generator=G,
    transform_out=transform_out_,
):
    """
    Log-likelihood of ``data`` under the image generated from latent ``z``.

    ``z`` is reshaped to ``(1, -1, 1, 1)`` and passed through ``generator``
    without tracking gradients. The output is mapped through ``transform_out``
    and flattened; the entries at the positions given by ``data`` are then
    read, offset by the machine epsilon of ``z``'s dtype to avoid ``log(0)``,
    and their logarithms are summed.

    :param z: Latent tensor.
    :param data: Index tensor of observed positions in the flattened image.
    :param generator: Callable mapping the reshaped latent to an image.
    :param transform_out: Callable mapping the generator output to weights.
    :return: Scalar tensor, the sum over i of log m(x_i).
    """
    tiny = torch.finfo(z.dtype).eps

    latent = z.reshape(1, -1, 1, 1)
    with torch.no_grad():
        generated = generator(latent)
    weights = transform_out(generated).reshape(-1)

    # m(x_i) + eps for every observation x_i
    observed = weights.take(data) + tiny

    return observed.log().sum()


# log_likelihood_latent = torch.jit.script(_log_likelihood_latent, example_inputs=[({"z": noise_sampler(1)}, data, G, transform_out_)])

log_likelihood_latent(noise_sampler(1), data, G, transform_out_)

In [None]:
def log_posterior(
    z,
    data=data,
    generator=G,
    transform_out=transform_out_,
):
    """
    Sum of ``log_prior(z)`` and ``log_likelihood_latent(z, ...)``.

    :param z: Latent tensor.
    :param data: Passed to ``log_likelihood_latent``.
    :param generator: Passed to ``log_likelihood_latent``.
    :param transform_out: Passed to ``log_likelihood_latent``.
    :return: Value of the prior term plus the likelihood term.
    """
    prior_term = log_prior(z)
    likelihood_term = log_likelihood_latent(z, data, generator, transform_out)
    return prior_term + likelihood_term


log_posterior(noise_sampler(1), data, G, transform_out_)

In [None]:
# nwalkers = 2
# z0 = noise_sampler(nwalkers).squeeze()
# z0.shape

In [None]:
# kernel = NUTS(
#     potential_fn=log_posterior,
#     # jit_compile=True,
#     target_accept_prob=0.6,
# )

In [None]:
# mcmc = MCMC(
#     kernel,
#     warmup_steps=1_000,
#     num_samples=10_000,
#     initial_params={"z": z0},
#     num_chains=nwalkers,
#     mp_context="spawn"
# )

In [None]:
# mcmc.run()

In [None]:
# z = noise_sampler(1).squeeze()
# z_abs = torch.abs(z)
# constraint = torch.maximum(z_abs - 1, torch.zeros_like(z_abs))
# penalizations = 1e6 * constraint ** 2
# constraint, penalizations.sum()

In [None]:
# import typing as t
#
#
# @torch.jit.script
# def _log_prior_unif(z: torch.Tensor, radius=torch.tensor(1), penalizaton=torch.tensor(1e6)):
#     """
#     Compute the log-prior of the latent variable z. The prior is uniform [-1, 1].
#     """
#     z = z.squeeze()
#     z_abs = torch.abs(z)
#     constraint = torch.maximum(z_abs - radius, torch.zeros_like(z_abs))
#     penalizations = penalizaton * constraint ** 2
#     return -torch.sum(penalizations)
#
#
# @torch.jit.script
# def _log_prior_norm(z: torch.Tensor, radius=torch.tensor(3), penalizaton=torch.tensor(1e6)):
#     """
#     Compute the log-prior of the latent variable z.
#     """
#     z = z.squeeze()
#     # min_value = torch.finfo(z.dtype).min
#     norm_z_2 = torch.sum(z ** 2)  # \|z\|^2
#     n = z.shape[0]
#     constraint = torch.maximum(norm_z_2 / n - radius ** 2, torch.zeros_like(norm_z_2))
#     penalizations = penalizaton * constraint ** 2
#     return -norm_z_2 / 2 - penalizations  # -\frac{1}{2} \|z\|^2
#
#
# # @torch.jit.script
# def _log_prior(z: torch.Tensor, G, radius=torch.tensor(1), penalizaton=torch.tensor(1e6)):
#     """
#     Compute the log-prior of the latent variable z.
#     """
#     if G._latent_distr(1).name == "unif":
#         return _log_prior_unif(z, radius, penalizaton)
#     elif G._latent_distr(1).name == "norm":
#         return _log_prior_norm(z, radius, penalizaton)
#     raise ValueError(f"unknown latent distribution {G._latent_distr(1).name}")
#
#
# # @torch.jit.script
# def _log_likelihood_latent(z: torch.Tensor, data: torch.Tensor, generator: torch.jit.ScriptModule, transform_out: t.Callable[[torch.Tensor], torch.Tensor]):
#     """
#     Compute the log-likelihood of the data given the latent variable z.
#     This is done by first generating the image x from z, then transforming it to a DistributionDraw object.
#
#     The original likelihood is:
#     .. math::
#         \Pi_n(dm)
#         \propto \Pi(dm) \mathcal{L}_n(m)
#         = \int_{\mathcal{Z}} P_Z(dz) \Pi(dm | z) \mathcal{L}_n(m)
#         \propto \int_{\mathcal{Z}} dz \Pi(dm | z) \mathcal{L}_n(m) e^{-\frac{1}{2} \|z\|^2}
#         = \int_{\mathcal{Z}} dz \Pi(dm | z) \prod_{i=1}^n m(x_i) e^{-\frac{1}{2} \|z\|^2}
#
#     So, the log-likelihood with respecto to z is:
#     .. math::
#         \ell_n(z) = \sum_{i=1}^n \log m(x_i) - \frac{1}{2} \|z\|^2
#     """
#     eps = torch.finfo(z.dtype).eps
#
#     z = torch.reshape(z, (1, -1, 1, 1))
#     with torch.no_grad():
#         m = generator(z)
#     m = transform_out(m)
#     m = m.reshape((-1,))
#
#     # m_data = m.take(data) + eps  # to avoid log(0)
#     # logits = torch.log(m_data)  # log m(x_i)
#
#     m_data = m.take(data)
#     # m_data_zeros = m_data == 0
#     m_data = m_data + eps  # to avoid log(0)
#     logits = torch.log(m_data)  # log m(x_i)
#     # logits[m_data_zeros] = logits[m_data_zeros] * 3
#
#     return torch.sum(logits)  # \sum_{i=1}^n \log m(x_i)
#
#
# # @torch.jit.script
# def _log_posterior(
#     z: torch.Tensor,
#     data: torch.Tensor,
#     generator: torch.nn.Module,
#     transform_out,
#     radius=torch.tensor(1),
#     penalization=torch.tensor(1e6)
# ):
#     """
#     Compute the log-posterior of the latent variable z.
#     .. math::
#         \log \Pi_n(z) = \log \Pi_Z(z) + \log \Pi_n(dm)
#     """
#     return _log_prior_unif(z, radius, penalization) + _log_likelihood_latent(z, data, generator, transform_out)
#
#
# z = noise_sampler(1).squeeze()
#
# _log_likelihood_latent(z, data, G, transform_out), _log_posterior(z, data, G, transform_out)

In [None]:
# @torch.jit.script
# def _log_likelihood_true_latent(z: torch.Tensor, data: torch.Tensor, generator: torch.nn.Module, transform_out):
#     """
#     Compute the log-likelihood of the data given the latent variable z.
#     This is done by first generating the image x from z, then transforming it to a DistributionDraw object.
#
#     The original likelihood is:
#     .. math::
#         \Pi_n(dm)
#         \propto \Pi(dm) \mathcal{L}_n(m)
#         = \int_{\mathcal{Z}} P_Z(dz) \Pi(dm | z) \mathcal{L}_n(m)
#         \propto \int_{\mathcal{Z}} dz \Pi(dm | z) \mathcal{L}_n(m) e^{-\frac{1}{2} \|z\|^2}
#         = \int_{\mathcal{Z}} dz \Pi(dm | z) \prod_{i=1}^n m(x_i) e^{-\frac{1}{2} \|z\|^2}
#
#     So, the log-likelihood with respecto to z is:
#     .. math::
#         \ell_n(z) = \sum_{i=1}^n \log m(x_i) - \frac{1}{2} \|z\|^2
#     """
#     eps = torch.finfo(z.dtype).eps
#
#     z = torch.reshape(z, (1, -1, 1, 1))
#     with torch.no_grad():
#         m = generator(z)
#     m = transform_out(m)
#     m = m.reshape((-1,))
#
#     m_data = m.take(data)
#     m_data_zeros = m_data == 0
#     m_data = m_data + eps  # to avoid log(0)
#     logits = torch.log(m_data)  # log m(x_i)
#     logits[m_data_zeros] = logits[m_data_zeros] * 10
#
#     return torch.sum(logits)  # \sum_{i=1}^n \log m(x_i)
#
#
# _log_likelihood_true_latent(z, data, G, transform_out)

In [None]:
# for _ in range(10):
#     z = G.sample_noise(1).squeeze()
#     print(_log_likelihood_true_latent(z, data, G, transform_out), _log_prior(z), _log_posterior(z, data, G, transform_out))

## MCMC Clase Base

In [None]:
# from bwb.distributions import BaseGeneratorDistribSampler
#
#
# class BaseLatentMCMCPosteriorSampler(BaseGeneratorDistribSampler[DistributionDraw]):
#     def __init__(self):

## MCMC Clase Base

In [None]:
import torch
from hamiltorch import Sampler, Integrator, Metric

from bwb.distributions.posterior_samplers import (
    BaseLatentMCMCPosteriorSampler as _LatentMCMCPosteriorPiN,
)


_LatentMCMCPosteriorPiN()

## MCMC paralelo

In [None]:
from bwb.distributions.posterior_samplers import (
    LatentMCMCPosteriorSampler as LatentMCMCPosteriorPiN,
)


LatentMCMCPosteriorPiN(parallel=True), LatentMCMCPosteriorPiN(
    n_workers=2
), LatentMCMCPosteriorPiN(n_walkers=2)

## No-U-Turn Sampler (NUTS)

In [None]:
burn = 5_000
num_samples = 500_000
n_walkers = 8

In [None]:
import bwb.utils as bwb_utils
from bwb.distributions.posterior_samplers import (
    NUTSPosteriorSampler as NUTSPosteriorPiN,
)

In [None]:
from pathlib import Path


# Where the NUTS chain is cached between runs. This name was used below but
# never assigned anywhere in the notebook, so the cell failed with NameError
# on a fresh kernel.
# NOTE(review): the location and extension are placeholders; adjust them to
# whatever `save` / `load` expect.
NUTS_POSTERIOR_PATH = Path("saved_mcmc") / "nuts_posterior.pkl"
NUTS_POSTERIOR_PATH.parent.mkdir(parents=True, exist_ok=True)

# Re-run the chain when explicitly requested or when no cached result exists;
# otherwise reuse the saved sampler.
if RUN_MCMC or not NUTS_POSTERIOR_PATH.exists():
    post_pi_n = (
        NUTSPosteriorPiN(
            n_walkers=n_walkers,
            num_steps_per_sample=1,
            burn=burn,
            desired_accept_rate=0.6,
        )
        .fit(G, transform_out_, noise_sampler, data[:100])
        .run(n_steps=num_samples)
    )
    post_pi_n.save(NUTS_POSTERIOR_PATH)
else:
    post_pi_n = NUTSPosteriorPiN.load(NUTS_POSTERIOR_PATH)

post_pi_n

In [None]:
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [None]:
# Thin by a tenth of the mean autocorrelation time, but never by less than 1:
# a short autocorrelation time (< 10) would otherwise give thin=0.
post_pi_n.shuffle_samples_cache(thin=max(1, mean_autocorr_time // 10))

In [None]:
max_images = 12
# `post_pi_n` is the sampler built in the NUTS cell above; the previous
# `post_pi_n_` (trailing underscore) was an undefined name.
plotters.plot_list_of_draws(post_pi_n.rvs(max_images), n_rows=2, n_cols=6)
print()

In [None]:
post_pi_n

In [None]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms as T

transform_in_proj = T.Compose([
    # From pdf to grayscale
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToTensor(),
    T.Normalize(
        [0.5 for _ in range(1)],
        [0.5 for _ in range(1)],
    ),
])

transform_out_proj = T.Compose([
    # Ensure the range is in [0, 1]
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.max(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    T.Lambda(lambda x: x.squeeze(0)),
])

proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
import bwb.sgdw.sgdw as sgdw
from bwb.sgdw.utils import gamma
from bwb.sgdw.plotters import PlotterComparison

dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=post_pi_n,
    step_scheduler=gamma(a=0.5, b=0.51, c=0.51),
    projector=proj,
    proj_every=5,
    max_iter=len(post_pi_n.samples_cache) - 5,
    report_every=10,
).set_geodesic_params(
    reg=0.01,
    stop_thr=1e-3,
)

plotter_comp = PlotterComparison(
    dist_draw_sgdw, plot_every=50, n_cols=12, n_rows=2, cmap="binary_r"
)

bar = plotter_comp.run(
    include_dict=dict(total_time=True),
)

In [None]:
plotter_comp.plot(2000)

In [None]:
# max_images = 36
# plotters.plot_list_of_images(post_pi_n.rvs(max_images))
# bwb_utils.plot_list_of_draws(post_pi_n.rvs(max_images), max_images=max_images)

### Experimentos

In [None]:
post_pi_n_dict = {}
times_autocorr = {}
total_times = {}

In [None]:
burn = 400
num_samples = 15_000

In [None]:
bar = "=" * 10

# Grid over (steps per sample, number of observations, target accept rate).
# Results are stored in three nested dictionaries keyed in that order.
for step_p_sample in [1, 2, 3]:
    post_pi_n_dict[step_p_sample] = {}
    times_autocorr[step_p_sample] = {}
    total_times[step_p_sample] = {}

    for n_data in [5, 50, 100]:
        post_pi_n_dict[step_p_sample][n_data] = {}
        times_autocorr[step_p_sample][n_data] = {}
        total_times[step_p_sample][n_data] = {}

        for desired_acc_rate in [0.3, 0.6]:
            print(
                bar
                + f" {step_p_sample = }, {n_data = }, {desired_acc_rate = } "
                + bar
            )
            post_pi_n = NUTSPosteriorPiN(
                n_walkers=8,
                num_steps_per_sample=step_p_sample,
                burn=burn,
                desired_accept_rate=desired_acc_rate,
            )
            # Same call convention as the main NUTS run earlier in this
            # notebook: fit(generator, tensor-valued transform, noise sampler,
            # data) followed by run(n_steps=...).
            post_pi_n.fit(G, transform_out_, noise_sampler, data[:n_data])
            post_pi_n.run(n_steps=num_samples)

            mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
            total_time = post_pi_n.total_time

            post_pi_n_dict[step_p_sample][n_data][desired_acc_rate] = post_pi_n
            times_autocorr[step_p_sample][n_data][
                desired_acc_rate
            ] = mean_autocorr_time
            total_times[step_p_sample][n_data][desired_acc_rate] = total_time

            print(post_pi_n)
            print(f"{mean_autocorr_time = }")
            print(f"{total_time = }")
            print()

Se demoró 13.85 horas en correr la celda anterior.

In [None]:
for n_data in [5, 50, 100]:
    for desired_acc_rate in [0.6, 0.3]:
        for step_p_sample in [1, 2, 3]:
            print(
                bar
                + f" {step_p_sample = }, {n_data = }, {desired_acc_rate = } "
                + bar
            )
            # print(post_pi_n_dict[step_p_sample][n_data][desired_acc_rate])
            print(
                "autocorr time:"
                f" {times_autocorr[step_p_sample][n_data][desired_acc_rate]}"
            )
            print(
                "time spend:"
                f" {total_times[step_p_sample][n_data][desired_acc_rate] / 60:.2f} mins"
            )
            print()
        print()
    print(bar * 8 + "\n")

Aquí tomamos aquellas cadenas que tuvieron un tiempo de autocorrelación menor.

In [None]:
post_pi_n_5 = post_pi_n_dict[1][5][0.3]
post_pi_n_50 = post_pi_n_dict[2][50][0.3]
post_pi_n_100 = post_pi_n_dict[3][100][0.3]

Obtenemos sus tiempos de autocorrelación promedio.

In [None]:
mean_autocorr_time_n_5 = int(post_pi_n_5.get_autocorr_time().mean())
print(f"{mean_autocorr_time_n_5 = }")

mean_autocorr_time_n_50 = int(post_pi_n_50.get_autocorr_time().mean())
print(f"{mean_autocorr_time_n_50 = }")

mean_autocorr_time_n_100 = int(post_pi_n_100.get_autocorr_time().mean())
print(f"{mean_autocorr_time_n_100 = }")

Ahora haremos muestreos utilizando el tiempo de autocorrelación por cada cadena.

In [None]:
post_pi_n_5.shuffle_samples_cache(thin=mean_autocorr_time_n_5)

In [None]:
post_pi_n_50.shuffle_samples_cache(thin=mean_autocorr_time_n_50)

In [None]:
post_pi_n_100.shuffle_samples_cache(thin=mean_autocorr_time_n_100)

Y visualizaremos los resultados.

In [None]:
max_images = 4 * 16
bwb_utils.plot_list_of_draws(post_pi_n_5.rvs(max_images), max_images=max_images)

In [None]:
bwb_utils.plot_list_of_draws(
    post_pi_n_50.rvs(max_images), max_images=max_images
)

In [None]:
bwb_utils.plot_list_of_draws(
    post_pi_n_100.rvs(max_images), max_images=max_images
)

In [None]:
face0

## Hamiltonian Monte Carlo (HMC)

In [None]:
class HMCPosteriorPiN(LatentMCMCPosteriorPiN):
    """
    Latent MCMC posterior sampler preconfigured with ``sampler=Sampler.HMC``.

    All arguments are forwarded to the parent constructor.

    :param log_prob_fn: Log-probability function. Defaults to the
        ``log_posterior`` defined in this notebook; the previous default,
        ``_log_posterior``, only exists in commented-out code and made the
        class definition raise NameError.
    :param num_samples: Forwarded to the parent constructor.
    :param num_steps_per_sample: Forwarded to the parent constructor.
    :param burn: Forwarded to the parent constructor.
    :param step_size: Forwarded to the parent constructor.
    :param kwargs: Any further keyword arguments for the parent constructor.
    """

    # NOTE(review): assumes the parent constructor accepts these keyword
    # arguments -- confirm against the current bwb API.
    def __init__(
        self,
        log_prob_fn=log_posterior,
        num_samples=10,
        num_steps_per_sample=5,
        burn=10,
        step_size=0.1,
        **kwargs,
    ) -> None:
        super().__init__(
            log_prob_fn=log_prob_fn,
            num_samples=num_samples,
            num_steps_per_sample=num_steps_per_sample,
            burn=burn,
            step_size=step_size,
            sampler=Sampler.HMC,
            **kwargs,
        )


post_pi_n = HMCPosteriorPiN()
post_pi_n.fit(data[:5], G, transform_out)

max_images = 4 * 9
post_pi_n.reset_samples().run(num_samples=max_images)

bwb_utils.plot_list_of_draws(post_pi_n.rvs(max_images), max_images=max_images)

print(post_pi_n)
print(post_pi_n.total_time)

## Riemannian Manifold Hamiltonian Monte Carlo (RMHMC)

In [None]:
class RMHMCPosteriorPiN(LatentMCMCPosteriorPiN):
    """
    Latent MCMC posterior sampler preconfigured with ``sampler=Sampler.RMHMC``.

    All arguments are forwarded to the parent constructor.

    :param log_prob_fn: Log-probability function. Defaults to the
        ``log_posterior`` defined in this notebook; the previous default,
        ``_log_posterior``, only exists in commented-out code and made the
        class definition raise NameError.
    :param num_samples: Forwarded to the parent constructor.
    :param num_steps_per_sample: Forwarded to the parent constructor.
    :param burn: Forwarded to the parent constructor.
    :param step_size: Forwarded to the parent constructor.
    :param fixed_point_max_iterations: Forwarded to the parent constructor.
    :param fixed_point_threshold: Forwarded to the parent constructor.
    :param explicit_binding_const: Forwarded to the parent constructor.
    :param integrator: ``Integrator`` member, forwarded to the parent.
    :param metric: ``Metric`` member, forwarded to the parent.
    :param kwargs: Any further keyword arguments for the parent constructor.
    """

    # NOTE(review): assumes the parent constructor accepts these keyword
    # arguments -- confirm against the current bwb API.
    def __init__(
        self,
        log_prob_fn=log_posterior,
        num_samples=10,
        num_steps_per_sample=5,
        burn=10,
        step_size=0.1,
        fixed_point_max_iterations=1000,
        fixed_point_threshold=1e-5,
        explicit_binding_const=100,
        integrator=Integrator.IMPLICIT,
        metric=Metric.HESSIAN,
        **kwargs,
    ) -> None:
        super().__init__(
            log_prob_fn=log_prob_fn,
            num_samples=num_samples,
            num_steps_per_sample=num_steps_per_sample,
            burn=burn,
            step_size=step_size,
            fixed_point_max_iterations=fixed_point_max_iterations,
            fixed_point_threshold=fixed_point_threshold,
            explicit_binding_const=explicit_binding_const,
            sampler=Sampler.RMHMC,
            integrator=integrator,
            metric=metric,
            **kwargs,
        )

    def _additional_repr_(self, sep):
        """Append the name of the configured metric to the parent's repr text."""
        to_return = super()._additional_repr_(sep)
        to_return += f"metric={self.hamiltorch_kwargs['metric'].name}" + sep
        return to_return


post_pi_n = RMHMCPosteriorPiN(step_size=0.1)
post_pi_n

### Implicit RMHMC

In [None]:
class ImplicitRMHMCPosteriorPiN(RMHMCPosteriorPiN):
    """
    RMHMC posterior sampler fixed to ``integrator=Integrator.IMPLICIT``.

    All other arguments are forwarded to ``RMHMCPosteriorPiN``.

    :param log_prob_fn: Log-probability function. Defaults to the
        ``log_posterior`` defined in this notebook; the previous default,
        ``_log_posterior``, only exists in commented-out code and made the
        class definition raise NameError.
    :param num_samples: Forwarded to the parent constructor.
    :param num_steps_per_sample: Forwarded to the parent constructor.
    :param burn: Forwarded to the parent constructor.
    :param step_size: Forwarded to the parent constructor.
    :param fixed_point_max_iterations: Forwarded to the parent constructor.
    :param fixed_point_threshold: Forwarded to the parent constructor.
    :param metric: ``Metric`` member, forwarded to the parent.
    :param kwargs: Any further keyword arguments for the parent constructor.
    """

    def __init__(
        self,
        log_prob_fn=log_posterior,
        num_samples=10,
        num_steps_per_sample=10,
        burn=10,
        step_size=0.1,
        fixed_point_max_iterations=1000,
        fixed_point_threshold=1e-5,
        metric=Metric.HESSIAN,
        **kwargs,
    ) -> None:
        super().__init__(
            log_prob_fn=log_prob_fn,
            num_samples=num_samples,
            num_steps_per_sample=num_steps_per_sample,
            burn=burn,
            step_size=step_size,
            integrator=Integrator.IMPLICIT,
            metric=metric,
            fixed_point_max_iterations=fixed_point_max_iterations,
            fixed_point_threshold=fixed_point_threshold,
            **kwargs,
        )

    def _additional_repr_(self, sep):
        """Append the fixed-point settings to the parent's repr text."""
        to_return = super()._additional_repr_(sep)
        to_return += (
            f"fixed_point_max_iterations={self.hamiltorch_kwargs['fixed_point_max_iterations']}"
            + sep
        )
        to_return += (
            f"fixed_point_threshold={self.hamiltorch_kwargs['fixed_point_threshold']}"
            + sep
        )
        return to_return


post_pi_n = ImplicitRMHMCPosteriorPiN(num_steps_per_sample=1)
post_pi_n.fit(data[:5], G, transform_out)

max_images = 4 * 9
num_samples = max_images
post_pi_n.reset_samples().run(num_samples=num_samples)

bwb_utils.plot_list_of_draws(post_pi_n.rvs(max_images), max_images=max_images)

print(post_pi_n)
print(post_pi_n.total_time)

### Explicit RMHMC

In [None]:
class ExplicitRMHMCPosteriorPiN(RMHMCPosteriorPiN):
    """
    RMHMC posterior sampler fixed to ``integrator=Integrator.EXPLICIT``.

    All other arguments are forwarded to ``RMHMCPosteriorPiN``.

    :param log_prob_fn: Log-probability function. Defaults to the
        ``log_posterior`` defined in this notebook; the previous default,
        ``_log_posterior``, only exists in commented-out code and made the
        class definition raise NameError.
    :param num_samples: Forwarded to the parent constructor.
    :param num_steps_per_sample: Forwarded to the parent constructor.
    :param burn: Forwarded to the parent constructor.
    :param step_size: Forwarded to the parent constructor.
    :param explicit_binding_const: Forwarded to the parent constructor.
    :param metric: ``Metric`` member, forwarded to the parent.
    :param kwargs: Any further keyword arguments for the parent constructor.
    """

    def __init__(
        self,
        log_prob_fn=log_posterior,
        num_samples=10,
        num_steps_per_sample=10,
        burn=10,
        step_size=0.1,
        explicit_binding_const=100,
        metric=Metric.HESSIAN,
        **kwargs,
    ) -> None:
        super().__init__(
            log_prob_fn=log_prob_fn,
            num_samples=num_samples,
            num_steps_per_sample=num_steps_per_sample,
            burn=burn,
            step_size=step_size,
            integrator=Integrator.EXPLICIT,
            metric=metric,
            explicit_binding_const=explicit_binding_const,
            **kwargs,
        )

    def _additional_repr_(self, sep):
        """Append the explicit binding constant to the parent's repr text."""
        to_return = super()._additional_repr_(sep)
        to_return += (
            f"explicit_binding_const={self.hamiltorch_kwargs['explicit_binding_const']}"
            + sep
        )
        return to_return


post_pi_n = ExplicitRMHMCPosteriorPiN()
post_pi_n.fit(data[:5], G, transform_out)

max_images = 4 * 9
post_pi_n.reset_samples().run(num_samples=max_images)

bwb_utils.plot_list_of_draws(post_pi_n.rvs(max_images), max_images=max_images)

print(post_pi_n)
print(post_pi_n.total_time)

# Experimentos

In [None]:
import traceback


def mcmc_experiment(
    data,
    num_steps_per_sample,
    burn=50,
    G=G,
    transform_out=transform_out,
    max_images=4 * 9,
):
    """
    Fit a NUTS posterior sampler on ``data``, run it and plot its draws.

    Errors raised while fitting, sampling or plotting are reported with a
    printed traceback instead of being propagated, so one failing
    configuration does not abort the remaining experiment cells.

    :param data: Observations passed to the sampler's ``fit``.
    :param num_steps_per_sample: Forwarded to ``NUTSPosteriorPiN``.
    :param burn: Forwarded to ``NUTSPosteriorPiN``.
    :param G: Generator passed to the sampler's ``fit``.
    :param transform_out: Output transform passed to the sampler's ``fit``.
    :param max_images: Number of samples to run and number of draws to plot.
    :return: ``None``.
    """
    try:
        post_pi_n = NUTSPosteriorPiN(
            burn=burn,
            num_steps_per_sample=num_steps_per_sample,
            desired_accept_rate=0.3,
        )
        # NOTE(review): the main NUTS cell earlier in this notebook calls
        # fit(G, transform_out_, noise_sampler, data) and run(n_steps=...);
        # confirm which call convention the current sampler expects.
        post_pi_n.fit(data, G, transform_out)

        post_pi_n.reset_samples().run(num_samples=max_images)

        bwb_utils.plot_list_of_draws(
            post_pi_n.rvs(max_images), max_images=max_images
        )

        print(f"{post_pi_n = }")
        print(f"{post_pi_n.total_time = :.4f}")
    except Exception:
        # `Exception`, not a bare `except:`, so that KeyboardInterrupt and
        # SystemExit still stop a long-running chain when the user asks.
        traceback.print_exc()

## Num steps per sample = 10

In [None]:
num_steps_ps = 10

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

---

## Num steps per sample = 50

In [None]:
num_steps_ps = 50

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

---

## Num steps per sample = 100

In [None]:
num_steps_ps = 100

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

---

## Num steps per sample = 500

In [None]:
num_steps_ps = 500

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps)

---

## Num steps per sample = 1.000

In [None]:
num_steps_ps = 1_000
burn = 20

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

---

## Num steps per sample = 10.000

In [None]:
num_steps_ps = 10_000
burn = 10

### N Data = 5

In [None]:
n_data = 5
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 10

In [None]:
n_data = 10
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 25

In [None]:
n_data = 25
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)

### N Data = 50

In [None]:
n_data = 50
mcmc_experiment(data[:n_data], num_steps_per_sample=num_steps_ps, burn=burn)