# Constantes y Logger

In [None]:
# %cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters

In [None]:
# Global switches and MCMC sizes read by the cells below.
SAVE_FIGS = True  # If you want to save the figures.
RUN_MCMC = True  # If you want to run the MCMC's algorithms or use saved chains

# Full-size MCMC settings.
BURN = 2_000
NUM_SAMPLES = 50_000
N_WALKERS = 8


# NOTE(review): these reassignments override the full-size values above, so
# the reduced settings below are the ones any later cell sees through these
# names. Presumably a quick-test configuration -- remove or comment out this
# group to use the full-size settings.
BURN = 100
NUM_SAMPLES = 1_000
N_WALKERS = 2

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from icecream import ic
import time

import bwb.utils.plotters as plotters
from bwb.distributions.posterior_samplers import NUTSPosteriorSampler

In [None]:
import torch
import numpy as np
import random

# Seed every random number generator used in the notebook (torch CPU, torch
# CUDA, NumPy's legacy global generator and Python's `random`) with the same
# value so that a fresh "Restart & Run All" draws the same numbers.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Ask cuDNN to select deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
from bwb import logging_ as logging

# Logger for this notebook, obtained from the project's `bwb.logging_` module
# (imported here under the name `logging`, which shadows the stdlib module).
log = logging.get_logger(__name__)

In [None]:
from pathlib import Path

# All paths are derived from the current working directory, so the notebook
# must be launched from its own folder for them to resolve as intended.
CURR_PATH = Path().absolute()
print(f"{CURR_PATH = }")
# Two levels above the working directory.
BASE_PATH = CURR_PATH.parent.parent
print(f"{BASE_PATH = }")
DATA_PATH = BASE_PATH / "data"
print(f"{DATA_PATH = }")
NETS_PATH = BASE_PATH / "wgan_gp" / "networks"
print(f"{NETS_PATH = }")
# Output folder for figures; this is the only directory created in this cell.
IMGS_PATH = CURR_PATH / "imgs" / "notebook-03-new"
IMGS_PATH.mkdir(parents=True, exist_ok=True)
print(f"{IMGS_PATH = }")
# Folders for saved MCMC chains.
# NOTE(review): neither MCMC_PATH nor NUTS_PATH is created here.
MCMC_PATH = BASE_PATH / "saved_mcmc"
print(f"{MCMC_PATH = }")
NUTS_PATH = MCMC_PATH / "NUTS"
print(f"{NUTS_PATH = }")

In [None]:
from bwb.config import conf

# Configure the project-wide `bwb` settings object.
# NOTE(review): judging by the method names, this selects single-precision
# floats and sets the library's epsilon to 1e-20 -- confirm against
# `bwb.config`.
conf.use_single_precision()
conf.set_eps(1e-20)
# Bare last expression: displays the resulting configuration as cell output.
conf

# Muestreador de Distribuciones a Posteriori

Al igual que los muestreadores de distribuciones anteriores, los muestreadores a posteriori heredan de `bwb.distributions.distribution_samplers.DistributionSampler`. En este caso, tenemos a la clase abstracta
`bwb.distributions.posterior_samplers.BaseLatentMCMCPosteriorSampler` que define un MCMC utilizando la librería `hamiltorch`.

Al igual que en `bwb.distributions.distribution_samplers.GeneratorDistribSampler`, la forma de ajustar esta clase es con un generador `generator`, una transformación `transform_out`, un muestreador de ruido `noise_sampler` y datos para la posterior `data`.

## Obtener el modelo para muestrear los datos

Definimos el dataset para obtener una cara de referencia y poder muestrear a partir de ella.


In [None]:
from quick_torch import QuickDraw
import torchvision.transforms.v2 as T
from pathlib import Path

# Preprocessing applied to every drawing: convert to an image tensor, resize
# to 32, convert to float32 with value scaling, and drop singleton dimensions.
transforms = T.Compose([
    T.ToImage(),
    T.Resize(32),
    T.ToDtype(torch.float32, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])

# QuickDraw drawings of the "face" category, stored (and downloaded if
# necessary) under DATA_PATH.
ds = QuickDraw(
    DATA_PATH,
    categories="face",
    download=True,
    transform=transforms,
)

# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset

ds = ModelDataset(ds)

from bwb.distributions.distribution_samplers import UniformDiscreteSampler

# Sampler fitted on the wrapped dataset.
ds_sampler = UniformDiscreteSampler().fit(ds)

# Reference face used in the experiments below. Despite the variable name
# `first_face`, this is the item at index 37, not index 0.
i = 37
first_face = ds.get(i)
print(first_face.shape)
_ = plotters.plot_draw(first_face, title=f"Cara $i={i}$")

Obtenemos una muestra y la graficamos en un histograma.

## Obtener data

In [None]:
def plot_img_hist(
    face,
    data,
    n_data=None,
    plot=True,
    exp=None,
    title="Imagen de la cara a muestrear",
    hist_title="$n={}$ muestras a partir de la imagen",
):
    """Plot a face next to a 2-D histogram of samples drawn from it.

    Parameters
    ----------
    face
        Object exposing ``shape`` and ``enumerate_support_()``; the latter is
        indexed with ``data`` to turn sample indices into 2-D coordinates.
    data
        Tensor of sample indices (must support ``clone()``, slicing and
        ``len()``).
    n_data
        If given, only the first ``n_data`` samples are plotted. ``None``
        plots all of them.
    plot
        When ``False`` nothing is drawn and both figure slots of the return
        value are ``(None, None)``.
    exp
        Unused; kept so existing positional callers keep working.
    title
        Title of the face plot.
    hist_title
        Format string for the histogram title; ``{}`` receives the number of
        plotted samples.

    Returns
    -------
    tuple
        ``(face, data_copy, fig_ax1, fig_ax2)`` where ``data_copy`` is a clone
        of the *full* input ``data`` (not truncated to ``n_data``) and the
        last two items are whatever ``plotters.plot_draw`` and
        ``plotters.plot_histogram_from_points`` return (callers in this
        notebook unpack them as ``(fig, ax)`` pairs).
    """
    data_ = data.clone()
    if n_data is not None:
        data = data[:n_data]
    # Use the number of samples actually kept: if fewer than ``n_data`` were
    # available, the jitter array below must still match row for row.
    n_data = len(data)

    # Placeholders keep the return value unpackable as two (fig, ax) pairs
    # when plotting is disabled (previously this path raised
    # UnboundLocalError).
    fig_ax1 = (None, None)
    ax2 = (None, None)

    if plot:
        shape = face.shape
        # Small Gaussian jitter so samples falling on the same pixel do not
        # sit exactly on top of each other.
        data_coords = (
            face.enumerate_support_()[data].cpu().numpy()
            + np.random.randn(n_data, 2) * 0.1
        )

        fig_ax1 = plotters.plot_draw(face, title=title)
        ax2 = plotters.plot_histogram_from_points(
            data_coords,
            title=hist_title.format(n_data),
            rotate=True,
            shape=shape,
            histplot_kwargs=dict(bins=shape[0]),
        )

    return face, data_, fig_ax1, ax2


def get_data(
    face,
    n_data,
    plot=True,
    exp=None,
    title="Imagen de la cara a muestrear",
    hist_title="$n={}$ muestras a partir de la imagen",
):
    """Draw ``n_data`` samples from ``face`` and hand them to ``plot_img_hist``.

    All remaining arguments are forwarded unchanged; the return value is
    exactly what ``plot_img_hist`` returns.
    """
    samples = face.sample((n_data,))
    return plot_img_hist(
        face=face,
        data=samples,
        n_data=n_data,
        plot=plot,
        exp=exp,
        title=title,
        hist_title=hist_title,
    )


def get_sampler(
    sampler,
    n_data,
    plot=True,
    exp=None,
    title="Face sampled from dataset",
    hist_title="Histogram of the distribution generated by a drawing",
):
    """Draw one face from ``sampler`` and delegate to ``get_data``.

    The face comes from ``sampler.draw()``; every other argument is passed
    through, and the result of ``get_data`` is returned as is.
    """
    drawn_face = sampler.draw()
    return get_data(
        face=drawn_face,
        n_data=n_data,
        plot=plot,
        exp=exp,
        title=title,
        hist_title=hist_title,
    )


# _, _, (fig1, ax1), (fig2, ax2) = get_sampler(ds_sampler, 100)
# Draw 100 samples from the reference face and plot the face together with
# the histogram of those samples.
_, data, (fig1, ax1), (fig2, ax2) = get_data(first_face, 100)
# Plot again using only the first 5 of the samples drawn above.
_ = plot_img_hist(first_face, data, 5)

In [None]:
# Display the histogram figure produced for the 100-sample plot.
fig2

## Definir red neuronal generadora y transformador

Se define la red neuronal de la misma manera que en el notebook anterior

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import (
    Generator,
    Encoder,
    LatentDistribution,
)
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network hyper-parameters. They also determine the checkpoint folder name
# below, so they presumably have to match the trained networks -- TODO confirm.
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
# NOTE(review): NUM_FILTERS is not referenced anywhere in the visible cells.
NUM_FILTERS = [256, 128, 64, 32]

# Sampler of latent vectors fed to the generator.
noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)

G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

# Folder holding the trained weights.
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

# Switch both networks to evaluation mode.
G.eval()
E.eval()
# `E.eval()` returns the module; ending the cell with print() keeps its repr
# from being shown as the cell output.
print()

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning

disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Generate one image from a single latent sample.
# NOTE(review): `z` and `m` are not used again in the visible cells.
z = noise_sampler(1)
m = G(z)

# Image -> network input: divide by the maximum, convert to a PIL image,
# resize to 32, convert back to a float32 tensor, and normalize with mean 0.5
# and std 0.5.
# NOTE(review): `transform_in` is not referenced in the visible cells.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.5,), (0.5,)),
])

# Network output -> weights: cast to float64, drop singleton dimensions,
# shift so the minimum is 0, then divide by the total so the entries sum to 1.
transform_out_ = T.Compose([
    T.ToDtype(torch.float64),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
])

# Same as `transform_out_`, followed by wrapping the weights in a
# `DistributionDraw`.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])

# Experimentos

### Experimentos con una cara fija, n variable

#### n = 5

In [None]:
# Scratch check: with the "burn" entry commented out, `.get("burn", None)`
# evaluates to None, which is what this cell displays.
# NOTE(review): this lowercase `nuts_params` is a separate variable from the
# NUTS_PARAMS dict defined in the next cell; it looks like leftover debugging.
nuts_params = dict(
    n_walkers=N_WALKERS,
    num_steps_per_sample=1,
    # burn=BURN,
    desired_accept_rate=0.6,
)
nuts_params.get("burn", None)

In [None]:
# Number of observations used in the first experiment.
N = 5

# Keyword arguments passed to the `NUTSPosteriorSampler` constructor.
NUTS_PARAMS = dict(
    n_walkers=N_WALKERS,
    num_steps_per_sample=1,
    burn=BURN,
    desired_accept_rate=0.6,
)

# Keyword arguments passed to the sampler's `run` method.
RUN_PARAMS = dict(n_steps=NUM_SAMPLES)

# Layout of the grid of posterior samples (3 x 6 = 18 images).
MCMC_PLOT_PARAMS = dict(
    n_rows=3,
    n_cols=6,
)


def save_fig(fig, name_to_save: str, imgs_path=IMGS_PATH):
    """Write ``fig`` to ``imgs_path`` as both a PDF and a PNG file.

    ``name_to_save`` is the file name without extension. The extension is set
    with ``Path.with_suffix``, so anything after a final dot in the name is
    replaced rather than kept.
    """
    target = imgs_path / name_to_save
    for extension in (".pdf", ".png"):
        fig.savefig(target.with_suffix(extension), bbox_inches="tight")
    print("Images saved with name", name_to_save)


def run_mcmc_experiment(
    face,
    data,
    n_data,
    thin=None,
    nuts_params=NUTS_PARAMS,
    run_params=RUN_PARAMS,
    mcmc_plot_params=MCMC_PLOT_PARAMS,
    run_mcmc=RUN_MCMC,
    mcmc_path=NUTS_PATH,
):
    """Run (or load) a NUTS posterior chain for ``n_data`` observations and plot it.

    Parameters
    ----------
    face
        Face the observations were sampled from; only used for plotting.
    data
        Tensor of observations; the first ``n_data`` entries are used.
    n_data
        Number of observations to condition on.
    thin
        Thinning passed to ``shuffle_samples_cache``. When falsy, the integer
        mean autocorrelation time of the chain is used instead.
    nuts_params
        Keyword arguments for the ``NUTSPosteriorSampler`` constructor.
    run_params
        Keyword arguments for the sampler's ``run`` method.
    mcmc_plot_params
        Keyword arguments for ``plotters.plot_list_of_draws``. The dict is not
        modified; the title is added to a private copy.
    run_mcmc
        If true, the chain is always recomputed. If false, a saved chain is
        loaded when its file exists and computed otherwise.
    mcmc_path
        Directory where chains are stored.

    Returns
    -------
    tuple
        ``(sampler, (face_fig, hist_fig, samples_fig))``.
    """

    def _fmt(value):
        # "_" digit grouping for numbers; a missing (None) setting used to
        # raise TypeError when formatted with ":_".
        return "None" if value is None else f"{value:_}"

    # Constants
    burn = nuts_params.get("burn", None)
    num_samples = run_params.get("n_steps", None)
    n_walkers = nuts_params.get("n_walkers", None)
    n_rows = mcmc_plot_params.get("n_rows", 6)
    n_cols = mcmc_plot_params.get("n_cols", 12)

    # Path to save the chain (file names are unchanged for numeric settings,
    # so existing cached chains are still found).
    mcmc_2_save_path = (
        mcmc_path
        / f"n-{n_data}-burn-{_fmt(burn)}-num_samples-{_fmt(num_samples)}-n_walkers-{n_walkers}"
    )
    mcmc_2_save_path = mcmc_2_save_path.with_suffix(".pkl.gz")

    # Getting the data
    data = data.clone()[:n_data]
    face, data, (fig1, _), (hist1, _) = plot_img_hist(face, data, n_data)
    print(data.shape)

    # Train the MCMC, or load if there are one in the cache
    if not mcmc_2_save_path.exists() or run_mcmc:
        # Create the target folder before the (possibly long) run so a
        # missing directory cannot make the final save fail.
        mcmc_2_save_path.parent.mkdir(parents=True, exist_ok=True)

        post_pi_n = NUTSPosteriorSampler(**nuts_params).fit(
            G,
            transform_out_,
            noise_sampler,
            data,
        )

        post_pi_n.run(**run_params)

        # Only the time spent saving the chain is measured and reported.
        tic = time.perf_counter()
        post_pi_n.save(mcmc_2_save_path)
        toc = time.perf_counter()
        ic(toc - tic)

    else:
        post_pi_n = NUTSPosteriorSampler.load(mcmc_2_save_path)
        post_pi_n.fit(G, transform_out_, noise_sampler, data)

    # Computing the mean autocorr time
    mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
    print(mean_autocorr_time)

    title = (
        "Muestras del MCMC para "
        + r"$n = "
        + f"{n_data}$, \n"
        + r"$\hat\tau_\mathrm{mean}"
        + f"={mean_autocorr_time}$"
    )
    # Work on a copy: the default argument is the module-level
    # MCMC_PLOT_PARAMS dict, which must not pick up the title of one call and
    # carry it over to later calls.
    plot_params = {**mcmc_plot_params, "title": title}

    # NOTE(review): if the mean autocorrelation time truncates to 0 and no
    # explicit `thin` is given, 0 is passed on -- confirm that
    # `shuffle_samples_cache` accepts it.
    post_pi_n.shuffle_samples_cache(thin=thin or mean_autocorr_time)

    # Plot some samples from the posterior
    max_imgs = n_rows * n_cols
    fig2, _ = plotters.plot_list_of_draws(
        post_pi_n.sample(max_imgs), **plot_params
    )

    return post_pi_n, (fig1, hist1, fig2)


# Draw 1,000 observations from the reference face; the experiment below uses
# only the first N of them.
data = first_face.sample((1_000,))

# Run the experiment with the default parameter dicts defined above.
post_pi_n, (fig1, hist1, fig2) = run_mcmc_experiment(first_face, data, N)

# NOTE(review): only the face and histogram figures are saved here; `fig2`
# (the grid of posterior samples) is not.
if SAVE_FIGS:
    save_fig(fig1, f"image-sampler-n-{N}")
    save_fig(hist1, f"samples-hist-n-{N}")

In [None]:
# Folder (next to the notebook) where the sampled observations are cached,
# created on demand.
POST_DATA_PATH = CURR_PATH / "data"  # / f"n_data-{N_DATA}.pkl"
POST_DATA_PATH.mkdir(parents=True, exist_ok=True)
# One cache file per face index `i` (set in the dataset cell above).
DATA_PATH_ = POST_DATA_PATH / f"data-{i}.pkl"
print(DATA_PATH_)

In [None]:
import pickle

# NOTE(review): a fresh sample is always drawn here, even when a cached file
# exists and replaces it a few lines below; the draw still advances the
# random state.
data = first_face.sample((1_000,))

# First run: store the sample. Later runs: reuse the stored one so every
# execution conditions on the same observations.
# `ic(...)` prints the condition and is relied on to return its argument.
if ic(not DATA_PATH_.exists()):
    with open(DATA_PATH_, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # NOTE(review): unpickling executes arbitrary code; only load cache files
    # this notebook wrote itself.
    with open(DATA_PATH_, "rb") as f:
        data = pickle.load(f)

# Repeat the experiment with an increasing number of observations.
for n in [5, 20, 50, 100]:
    print("Ejecutando la cadena con n =", n)
    # NOTE(review): the sampler settings below are hard-coded and do not use
    # the BURN / NUM_SAMPLES / N_WALKERS constants from the configuration
    # cell.
    post_pi_n, (fig1, hist1, fig2) = run_mcmc_experiment(
        first_face,
        data,
        n,
        thin=None,
        nuts_params=dict(
            n_walkers=8,
            num_steps_per_sample=1,
            burn=2_000,
            desired_accept_rate=0.6,
        ),
        run_params=dict(
            n_steps=150_000,
        ),
        mcmc_plot_params=MCMC_PLOT_PARAMS,
        run_mcmc=RUN_MCMC,
        mcmc_path=NUTS_PATH,
    )

    if SAVE_FIGS:
        # The face figure's file name depends only on `i`, so each iteration
        # overwrites the same file.
        save_fig(fig1, f"image-sampler-i-{i}")
        save_fig(hist1, f"samples-hist-n-{n}")
        post_name = post_pi_n.__class__.__name__
        save_fig(fig2, f"mcmc-n-{n}-{post_name}")