# Constantes y Logger

In [1]:
# %cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters

In [2]:
# Output toggles.
SAVE_FIGS = True  # When True, the figure-saving cells write PDF and PNG files under IMGS_PATH.
RUN_MCMC = False  # When True, the MCMC is re-run even if a saved chain file exists; otherwise a saved chain is loaded when present.

# Sampler settings shared by every experiment; they are also embedded in the saved-chain file names.
BURN = 2_000  # Passed as `burn` to NUTSPosteriorSampler.
NUM_SAMPLES = 50_000  # Passed as `n_steps` to `post_pi_n.run`.
N_WALKERS = 8  # Passed as `n_walkers` to NUTSPosteriorSampler.

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from icecream import ic
import time

import bwb.plotters as plotters
from bwb.distributions.posterior_samplers import NUTSPosteriorSampler

In [4]:
import torch
import numpy as np
import random

# Seed every random number generator the notebook relies on (torch CPU and
# CUDA, NumPy's legacy global generator, and the stdlib `random` module).
SEED = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)

# Set the cuDNN `deterministic` flag.
torch.backends.cudnn.deterministic = True

In [5]:
from bwb import _logging as logging

# Logger for this notebook, obtained from the project's `bwb._logging` helper
# (imported above under the name `logging`) and keyed by the module name.
log = logging.get_logger(__name__)

In [6]:
from pathlib import Path

# Directory layout. Everything is derived from the directory the notebook is
# run from; the project root is taken to be two levels above it.
CURR_PATH = Path().absolute()
BASE_PATH = CURR_PATH.parent.parent
DATA_PATH = BASE_PATH / "data"
NETS_PATH = BASE_PATH / "wgan_gp" / "networks"
IMGS_PATH = CURR_PATH / "imgs" / "notebook-03"
MCMC_PATH = BASE_PATH / "saved_mcmc"
NUTS_PATH = MCMC_PATH / "NUTS"

# Only the figure directory is created here.
IMGS_PATH.mkdir(parents=True, exist_ok=True)

# Echo every path as "<NAME> = <repr>" so the resolved layout is visible.
for _name, _path in (
    ("CURR_PATH", CURR_PATH),
    ("BASE_PATH", BASE_PATH),
    ("DATA_PATH", DATA_PATH),
    ("NETS_PATH", NETS_PATH),
    ("IMGS_PATH", IMGS_PATH),
    ("MCMC_PATH", MCMC_PATH),
    ("NUTS_PATH", NUTS_PATH),
):
    print(f"{_name} = {_path!r}")

In [7]:
from bwb.config import conf

# Project-wide numeric configuration: switch the library to single precision
# and set its epsilon to 1e-20.
# NOTE(review): what `eps` is used for inside `bwb` is not visible here — confirm.
conf.use_single_precision()
conf.set_eps(1e-20)
# Bare expression so the notebook displays the resulting configuration.
conf

# Muestreador de distribuciones a posteriori

Al igual que los muestreadores de distribuciones anteriores, los muestreadores a posteriori heredan de `bwb.distributions.distribution_samplers.DistributionSampler`. En este caso, tenemos la clase abstracta
`bwb.distributions.posterior_samplers.BaseLatentMCMCPosteriorSampler`, que define un MCMC utilizando la librería `hamiltorch`.

Al igual que en `bwb.distributions.distribution_samplers.GeneratorDistribSampler`, la forma de ajustar esta clase es con un generador `generator`, una transformación `transform_out`, un muestreador de ruido `noise_sampler` y datos para la posterior `data`.

## Obtener el modelo para muestrear los datos

Definimos el dataset para obtener la primera cara y poder muestrear a partir de ella.

In [8]:
from quick_torch import QuickDraw
import torchvision.transforms.v2 as T
from pathlib import Path

# Per-image pipeline: convert to an image tensor, resize with size 32, cast to
# float32 with value scaling enabled, and drop singleton dimensions.
transforms = T.Compose([
    T.ToImage(),
    T.Resize(32),
    T.ToDtype(torch.float32, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])

# QuickDraw drawings of the "face" category, stored under DATA_PATH and
# requested with `download=True`.
ds = QuickDraw(
    DATA_PATH,
    categories="face",
    download=True,
    transform=transforms,
)

# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset

# NOTE: `ds` is rebound here from the raw dataset to its ModelDataset wrapper.
ds = ModelDataset(ds)

from bwb.distributions.distribution_samplers import UniformDiscreteSampler
# Sampler fitted on the wrapped dataset; `get_sampler` later calls its `draw()`.
ds_sampler = UniformDiscreteSampler().fit(ds)

# Inspect the first element of the model set: print its shape and plot it.
first_face = ds.get(0)
print(first_face.shape)
_ = plotters.plot_draw(ds.get(0), title="First face")

Obtenemos una muestra y la graficamos en un histograma.

## Obtener los datos

In [9]:
def get_sampler(sampler, n_data, plot=True, exp=None):
    """Draw one model from ``sampler`` and sample ``n_data`` observations from it.

    Parameters
    ----------
    sampler
        Object exposing ``draw()``. The drawn object must provide ``sample``
        and, when plotting, ``shape`` and ``enumerate_support_``.
    n_data : int
        Number of observations to sample from the drawn model.
    plot : bool, default True
        When True, plot the drawn model and a histogram of the jittered
        sample coordinates. When False, nothing is plotted.
    exp
        Unused; kept so that existing calls keep working.

    Returns
    -------
    tuple
        ``(face, data, fig_ax1, fig_ax2)``. With ``plot=True``, ``fig_ax1`` is
        whatever ``plotters.plot_draw`` returns (callers unpack it as
        ``(fig, ax)``) and ``fig_ax2`` is ``(ax.figure, ax)`` of the
        histogram. With ``plot=False`` both are ``(None, None)`` so callers
        can unpack the result the same way.
    """
    face = sampler.draw()
    data = face.sample((n_data,))

    # Defaults keep the return value well-defined when plotting is skipped;
    # previously `plot=False` raised UnboundLocalError at the return statement.
    fig_ax1 = (None, None)
    fig_ax2 = (None, None)

    if plot:
        shape = face.shape
        # Look up the support entries selected by the sampled indices and add
        # Gaussian jitter (std 0.1) before building the histogram.
        data_coords = (
            face.enumerate_support_()[data].cpu().numpy()
            + np.random.randn(n_data, 2) * 0.1
        )

        fig_ax1 = plotters.plot_draw(face, title="Face sampled from dataset")
        ax2 = plotters.plot_histogram_from_points(
            data_coords,
            rotate=True,
            shape=shape,
            histplot_kwargs=dict(bins=shape[0]),
        )
        fig_ax2 = (ax2.figure, ax2)

    return face, data, fig_ax1, fig_ax2

# Try the helper once with 100 observations; the face and the data are
# discarded, only the two figure/axes pairs are kept.
_, _, (fig1, ax1), (fig2, ax2) = get_sampler(ds_sampler, 100)

In [10]:
# Bare expression so the notebook renders the histogram figure in this cell.
fig2

## Definir red neuronal generadora y transformador

Se define la red neuronal de la misma manera que en el notebook anterior.

In [11]:
from wgan_gp.wgan_gp_vae.model_resnet import Generator, Encoder, LatentDistribution
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Settings of the pretrained networks.
# NOTE(review): NUM_FILTERS is not referenced anywhere in the visible notebook.
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]

# Checkpoint directory; its name is built from the dataset name, latent size
# and noise type.
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

# Latent-noise sampler plus generator / encoder networks, all on `device`.
noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)

G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

# Load both checkpoints first, then switch both networks to evaluation mode.
for _net, _tag in ((G, "generator"), (E, "encoder")):
    load_checkpoint(_net, FACE_PATH, _tag, device)

for _net in (G, E):
    _net.eval()

# Ends the cell on a statement that only prints an empty line.
print()

In [12]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# One latent draw pushed through the generator; used at the end of the cell to
# try out the output transform.
z = noise_sampler(1)
m = G(z)

# Input pipeline: divide by the maximum, round-trip through a PIL image resized
# with size 32, convert to float32 and normalise with mean 0.5 / std 0.5.
# NOTE(review): `transform_in` is not referenced in the visible part of the notebook.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.5,), (0.5,)),
])

# Output pipeline: cast to float64, drop singleton dimensions, shift so the
# minimum is 0 and divide by the total, which yields non-negative weights that
# sum to one (provided the shifted tensor is not identically zero).
transform_out_ = T.Compose([
    T.ToDtype(torch.float64),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
])

# Same pipeline, with the weights additionally wrapped in a DistributionDraw.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])

# Apply the full pipeline to the generated sample, print its dtype and display it.
out: DistributionDraw = transform_out(m)
print(out.dtype)
out

# Experimentos

## Experimento 1

In [13]:
# Experiment identifier; it appears in the names of the saved chain and figures.
EXP = 1

# Saved-chain file for this experiment; the name encodes the sampler settings.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH / f"exp-{EXP}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
).with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

In [14]:
# Draw a face and 100 observations from it for this experiment; only the data
# and the two figures are kept, the remaining return values go to `_`.
_, data, (fig1, _), (fig2, _) = get_sampler(ds_sampler, 100)

In [15]:
# Reuse the saved chain when it exists, unless RUN_MCMC forces a fresh run.
if NUTS_POSTERIOR_PATH.exists() and not RUN_MCMC:
    # NOTE(review): the loaded sampler is fitted with the `data` drawn just
    # above; presumably this has to be the data the chain was generated
    # with — confirm.
    post_pi_n = NUTSPosteriorSampler.load(NUTS_POSTERIOR_PATH)
    post_pi_n.fit(G, transform_out_, noise_sampler, data)

else:
    post_pi_n = NUTSPosteriorSampler(
        n_walkers=N_WALKERS,
        num_steps_per_sample=1,
        burn=BURN,
        desired_accept_rate=0.6,
    ).fit(G, transform_out_, noise_sampler, data)

    post_pi_n.run(n_steps=NUM_SAMPLES)

    # Time only the call that saves the sampler to disk.
    tic = time.perf_counter()
    post_pi_n.save(NUTS_POSTERIOR_PATH)
    toc = time.perf_counter()
    ic(toc - tic)

# Bare expression so the notebook displays the sampler.
post_pi_n

In [16]:
# Mean of the values returned by `get_autocorr_time()`, truncated to an int;
# the next cell derives the thinning factor from it.
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [17]:
# Thin by a tenth of the mean autocorrelation time. The lower bound of 1 covers
# autocorrelation times below 10, for which the bare expression evaluates to 0.
post_pi_n.shuffle_samples_cache(thin=max(1, int(mean_autocorr_time / 10)))

In [18]:
# Grid layout for the figure: one draw from `rvs` per grid cell.
n_rows = 6
n_cols = 12
max_imgs = n_rows * n_cols

fig, ax = plotters.plot_list_of_draws(
    post_pi_n.rvs(max_imgs),
    n_rows=n_rows,
    n_cols=n_cols,
    title="Samples from the MCMC",
)

In [19]:
if SAVE_FIGS:
    # Base names (no extension) for the sample grid, the sampled face and its histogram.
    PATH_TO_SAVE = IMGS_PATH / f"exp-{EXP}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}"
    PATH_TO_SAVE_FIG1 = IMGS_PATH / f"exp-{EXP}-face"
    PATH_TO_SAVE_FIG2 = IMGS_PATH / f"exp-{EXP}-hist"

    # Write each figure as PDF first, then PNG.
    for _figure, _base in (
        (fig, PATH_TO_SAVE),
        (fig1, PATH_TO_SAVE_FIG1),
        (fig2, PATH_TO_SAVE_FIG2),
    ):
        for _ext in (".pdf", ".png"):
            _figure.savefig(_base.with_suffix(_ext))

## Experimento 2

In [20]:
# Experiment identifier; it appears in the names of the saved chain and figures.
EXP = 2

# Saved-chain file for this experiment; the name encodes the sampler settings.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH / f"exp-{EXP}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
).with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

In [22]:
# Draw a face and 100 observations from it for this experiment; only the data
# and the two figures are kept, the remaining return values go to `_`.
_, data, (fig1, _), (fig2, _) = get_sampler(ds_sampler, 100)

In [23]:
# Reuse the saved chain when it exists, unless RUN_MCMC forces a fresh run.
if NUTS_POSTERIOR_PATH.exists() and not RUN_MCMC:
    # NOTE(review): the loaded sampler is fitted with the `data` drawn just
    # above; presumably this has to be the data the chain was generated
    # with — confirm.
    post_pi_n = NUTSPosteriorSampler.load(NUTS_POSTERIOR_PATH)
    post_pi_n.fit(G, transform_out_, noise_sampler, data)

else:
    post_pi_n = NUTSPosteriorSampler(
        n_walkers=N_WALKERS,
        num_steps_per_sample=1,
        burn=BURN,
        desired_accept_rate=0.6,
    ).fit(G, transform_out_, noise_sampler, data)

    post_pi_n.run(n_steps=NUM_SAMPLES)

    # Time only the call that saves the sampler to disk.
    tic = time.perf_counter()
    post_pi_n.save(NUTS_POSTERIOR_PATH)
    toc = time.perf_counter()
    ic(toc - tic)

# Bare expression so the notebook displays the sampler.
post_pi_n

In [24]:
# Mean of the values returned by `get_autocorr_time()`, truncated to an int;
# the next cell derives the thinning factor from it.
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [25]:
# Thin by a tenth of the mean autocorrelation time. The lower bound of 1 covers
# autocorrelation times below 10, for which the bare expression evaluates to 0.
post_pi_n.shuffle_samples_cache(thin=max(1, int(mean_autocorr_time / 10)))

In [26]:
# Grid layout for the figure: one draw from `rvs` per grid cell.
n_rows = 6
n_cols = 12
max_imgs = n_rows * n_cols

fig, ax = plotters.plot_list_of_draws(
    post_pi_n.rvs(max_imgs),
    n_rows=n_rows,
    n_cols=n_cols,
    title="Samples from the MCMC",
)

In [27]:
if SAVE_FIGS:
    # Base names (no extension) for the sample grid, the sampled face and its histogram.
    PATH_TO_SAVE = IMGS_PATH / f"exp-{EXP}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}"
    PATH_TO_SAVE_FIG1 = IMGS_PATH / f"exp-{EXP}-face"
    PATH_TO_SAVE_FIG2 = IMGS_PATH / f"exp-{EXP}-hist"

    # Write each figure as PDF first, then PNG.
    for _figure, _base in (
        (fig, PATH_TO_SAVE),
        (fig1, PATH_TO_SAVE_FIG1),
        (fig2, PATH_TO_SAVE_FIG2),
    ):
        for _ext in (".pdf", ".png"):
            _figure.savefig(_base.with_suffix(_ext))

## Experimento 3

In [28]:
# Experiment identifier; it appears in the names of the saved chain and figures.
EXP = 3

# Saved-chain file for this experiment; the name encodes the sampler settings.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH / f"exp-{EXP}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
).with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

In [29]:
# Draw a face and 100 observations from it for this experiment; only the data
# and the two figures are kept, the remaining return values go to `_`.
_, data, (fig1, _), (fig2, _) = get_sampler(ds_sampler, 100)

In [30]:
# Reuse the saved chain when it exists, unless RUN_MCMC forces a fresh run.
if NUTS_POSTERIOR_PATH.exists() and not RUN_MCMC:
    # NOTE(review): the loaded sampler is fitted with the `data` drawn just
    # above; presumably this has to be the data the chain was generated
    # with — confirm.
    post_pi_n = NUTSPosteriorSampler.load(NUTS_POSTERIOR_PATH)
    post_pi_n.fit(G, transform_out_, noise_sampler, data)

else:
    post_pi_n = NUTSPosteriorSampler(
        n_walkers=N_WALKERS,
        num_steps_per_sample=1,
        burn=BURN,
        desired_accept_rate=0.6,
    ).fit(G, transform_out_, noise_sampler, data)

    post_pi_n.run(n_steps=NUM_SAMPLES)

    # Time only the call that saves the sampler to disk.
    tic = time.perf_counter()
    post_pi_n.save(NUTS_POSTERIOR_PATH)
    toc = time.perf_counter()
    ic(toc - tic)

# Bare expression so the notebook displays the sampler.
post_pi_n

In [31]:
# Mean of the values returned by `get_autocorr_time()`, truncated to an int;
# the next cell derives the thinning factor from it.
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [32]:
# Thin by a tenth of the mean autocorrelation time. The lower bound of 1 covers
# autocorrelation times below 10, for which the bare expression evaluates to 0.
post_pi_n.shuffle_samples_cache(thin=max(1, int(mean_autocorr_time / 10)))

In [33]:
# Grid layout for the figure: one draw from `rvs` per grid cell.
n_rows = 6
n_cols = 12
max_imgs = n_rows * n_cols

fig, ax = plotters.plot_list_of_draws(
    post_pi_n.rvs(max_imgs),
    n_rows=n_rows,
    n_cols=n_cols,
    title="Samples from the MCMC",
)

In [34]:
if SAVE_FIGS:
    # Base names (no extension) for the sample grid, the sampled face and its histogram.
    PATH_TO_SAVE = IMGS_PATH / f"exp-{EXP}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}"
    PATH_TO_SAVE_FIG1 = IMGS_PATH / f"exp-{EXP}-face"
    PATH_TO_SAVE_FIG2 = IMGS_PATH / f"exp-{EXP}-hist"

    # Write each figure as PDF first, then PNG.
    for _figure, _base in (
        (fig, PATH_TO_SAVE),
        (fig1, PATH_TO_SAVE_FIG1),
        (fig2, PATH_TO_SAVE_FIG2),
    ):
        for _ext in (".pdf", ".png"):
            _figure.savefig(_base.with_suffix(_ext))

## Experimento 4

In [35]:
# Experiment identifier; it appears in the names of the saved chain and figures.
EXP = 4

# Saved-chain file for this experiment; the name encodes the sampler settings.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH / f"exp-{EXP}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
).with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

In [36]:
# Draw a face and 100 observations from it for this experiment; only the data
# and the two figures are kept, the remaining return values go to `_`.
_, data, (fig1, _), (fig2, _) = get_sampler(ds_sampler, 100)

In [37]:
# Reuse the saved chain when it exists, unless RUN_MCMC forces a fresh run.
if NUTS_POSTERIOR_PATH.exists() and not RUN_MCMC:
    # NOTE(review): the loaded sampler is fitted with the `data` drawn just
    # above; presumably this has to be the data the chain was generated
    # with — confirm.
    post_pi_n = NUTSPosteriorSampler.load(NUTS_POSTERIOR_PATH)
    post_pi_n.fit(G, transform_out_, noise_sampler, data)

else:
    post_pi_n = NUTSPosteriorSampler(
        n_walkers=N_WALKERS,
        num_steps_per_sample=1,
        burn=BURN,
        desired_accept_rate=0.6,
    ).fit(G, transform_out_, noise_sampler, data)

    post_pi_n.run(n_steps=NUM_SAMPLES)

    # Time only the call that saves the sampler to disk.
    tic = time.perf_counter()
    post_pi_n.save(NUTS_POSTERIOR_PATH)
    toc = time.perf_counter()
    ic(toc - tic)

# Bare expression so the notebook displays the sampler.
post_pi_n

In [38]:
# Mean of the values returned by `get_autocorr_time()`, truncated to an int;
# the next cell derives the thinning factor from it.
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [39]:
# Thin by a tenth of the mean autocorrelation time. The lower bound of 1 covers
# autocorrelation times below 10, for which the bare expression evaluates to 0.
post_pi_n.shuffle_samples_cache(thin=max(1, int(mean_autocorr_time / 10)))

In [40]:
# Grid layout for the figure: one draw from `rvs` per grid cell.
n_rows = 6
n_cols = 12
max_imgs = n_rows * n_cols

fig, ax = plotters.plot_list_of_draws(
    post_pi_n.rvs(max_imgs),
    n_rows=n_rows,
    n_cols=n_cols,
    title="Samples from the MCMC",
)

In [41]:
if SAVE_FIGS:
    # Base names (no extension) for the sample grid, the sampled face and its histogram.
    PATH_TO_SAVE = IMGS_PATH / f"exp-{EXP}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}"
    PATH_TO_SAVE_FIG1 = IMGS_PATH / f"exp-{EXP}-face"
    PATH_TO_SAVE_FIG2 = IMGS_PATH / f"exp-{EXP}-hist"

    # Write each figure as PDF first, then PNG.
    for _figure, _base in (
        (fig, PATH_TO_SAVE),
        (fig1, PATH_TO_SAVE_FIG1),
        (fig2, PATH_TO_SAVE_FIG2),
    ):
        for _ext in (".pdf", ".png"):
            _figure.savefig(_base.with_suffix(_ext))

## Experimento 5

In [42]:
# Experiment identifier; it appears in the names of the saved chain and figures.
EXP = 5

# Saved-chain file for this experiment; the name encodes the sampler settings.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH / f"exp-{EXP}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
).with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

In [43]:
# Draw a face and 100 observations from it for this experiment; only the data
# and the two figures are kept, the remaining return values go to `_`.
_, data, (fig1, _), (fig2, _) = get_sampler(ds_sampler, 100)

In [44]:
# Reuse the saved chain when it exists, unless RUN_MCMC forces a fresh run.
if NUTS_POSTERIOR_PATH.exists() and not RUN_MCMC:
    # NOTE(review): the loaded sampler is fitted with the `data` drawn just
    # above; presumably this has to be the data the chain was generated
    # with — confirm.
    post_pi_n = NUTSPosteriorSampler.load(NUTS_POSTERIOR_PATH)
    post_pi_n.fit(G, transform_out_, noise_sampler, data)

else:
    post_pi_n = NUTSPosteriorSampler(
        n_walkers=N_WALKERS,
        num_steps_per_sample=1,
        burn=BURN,
        desired_accept_rate=0.6,
    ).fit(G, transform_out_, noise_sampler, data)

    post_pi_n.run(n_steps=NUM_SAMPLES)

    # Time only the call that saves the sampler to disk.
    tic = time.perf_counter()
    post_pi_n.save(NUTS_POSTERIOR_PATH)
    toc = time.perf_counter()
    ic(toc - tic)

# Bare expression so the notebook displays the sampler.
post_pi_n

In [45]:
# Mean of the values returned by `get_autocorr_time()`, truncated to an int;
# the next cell derives the thinning factor from it.
mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
print(mean_autocorr_time)

In [46]:
# Thin by a tenth of the mean autocorrelation time. The lower bound of 1 covers
# autocorrelation times below 10, for which the bare expression evaluates to 0.
post_pi_n.shuffle_samples_cache(thin=max(1, int(mean_autocorr_time / 10)))

In [47]:
# Grid layout for the figure: one draw from `rvs` per grid cell.
n_rows = 6
n_cols = 12
max_imgs = n_rows * n_cols

fig, ax = plotters.plot_list_of_draws(
    post_pi_n.rvs(max_imgs),
    n_rows=n_rows,
    n_cols=n_cols,
    title="Samples from the MCMC",
)

In [48]:
if SAVE_FIGS:
    # Base names (no extension) for the sample grid, the sampled face and its histogram.
    PATH_TO_SAVE = IMGS_PATH / f"exp-{EXP}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}"
    PATH_TO_SAVE_FIG1 = IMGS_PATH / f"exp-{EXP}-face"
    PATH_TO_SAVE_FIG2 = IMGS_PATH / f"exp-{EXP}-hist"

    # Write each figure as PDF first, then PNG.
    for _figure, _base in (
        (fig, PATH_TO_SAVE),
        (fig1, PATH_TO_SAVE_FIG1),
        (fig2, PATH_TO_SAVE_FIG2),
    ):
        for _ext in (".pdf", ".png"):
            _figure.savefig(_base.with_suffix(_ext))