# Configuraciones iniciales

## Constantes

In [None]:
# Notebook identifier; used to name this notebook's image directory and logs.
NOTEBOOK = 5
CLEAN_LOGS = True  # If you want to clean the logs directory
SAVE_FIGS = True  # If you want to save the figures.
# If True, the freshly sampled observations overwrite the pickled ones.
NEW_DATA = False
# Grid layout (rows, columns) used by the plotters.
N_ROWS, N_COLS = 1, 6

# MCMC Configurations
# If True, the chain is run even when a saved chain already exists.
RUN_MCMC = False
BURN = 2_500  # Passed as `burn` to the NUTS posterior sampler
NUM_SAMPLES = 25_000  # Passed as `n_steps` when running the chain
N_WALKERS = 16  # Passed as `n_walkers` to the NUTS posterior sampler

# Posterior
N_DATA = 30  # Number of observations kept from the sampled data

REPORT_EVERY = 100  # To report at the logger
PLOT_EVERY = 250  # Passed as `plot_every` to the comparison plotter
MAX_ITER = 5_000  # Max number of iterations for the SGDW
BATCH_SIZE = 1  # Passed as `batch_size` to the SGDW
# Only referenced by commented-out projector options in the visible code.
PROJ_EVERY = None
# Sample sizes iterated over by the experiment loop.
LIST_N_DATA = [5, 10, 25, 50]

# MAX_ITER = 50; REPORT_EVERY = 5  # Uncomment to debug
# BURN = 200
# NUM_SAMPLES = 1_000
# N_WALKERS = 2
# LIST_N_DATA = [5]

In [None]:
import random

import numpy as np
import torch

# Seed every random number generator used in this notebook (PyTorch, PyTorch
# CUDA, NumPy and the standard library) with the same value.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Ask cuDNN to use deterministic algorithms.
# NOTE(review): the PyTorch reproducibility notes also mention
# `torch.backends.cudnn.benchmark = False` -- confirm whether it is needed.
torch.backends.cudnn.deterministic = True

In [None]:
from pathlib import Path
from icecream import ic


# Base name (without extension) of the dataset file loaded further below.
DS_NAME = "data"

# Directory layout, resolved from the current working directory. Every path
# is echoed with `ic` so it can be inspected in the cell output.
CURR_PATH = Path().absolute()
ic(CURR_PATH)
BASE_PATH = CURR_PATH.parent.parent  # Two levels above the working directory
ic(BASE_PATH)
DATA_PATH = BASE_PATH / "data"
ic(DATA_PATH)
WGAN_PATH = BASE_PATH / "wgan_gp"
ic(WGAN_PATH)
NETS_PATH = WGAN_PATH / "networks"
ic(NETS_PATH)
# Per-notebook image directory; created here if it does not exist yet.
IMGS_PATH = CURR_PATH / "imgs" / f"notebook-{NOTEBOOK:02d}"
IMGS_PATH.mkdir(parents=True, exist_ok=True)
ic(IMGS_PATH)
# NOTE(review): unlike IMGS_PATH, these two directories are not created here.
MCMC_PATH = BASE_PATH / "saved_mcmc"
ic(MCMC_PATH)
NUTS_PATH = MCMC_PATH / "NUTS"
ic(NUTS_PATH)

In [None]:
def save_fig(fig, name_to_save):
    """Write ``fig`` under ``IMGS_PATH`` as a PDF and then as a PNG.

    Nothing is written when the module-level ``SAVE_FIGS`` flag is falsy.
    The file extension is set with ``Path.with_suffix``.
    """
    if not SAVE_FIGS:
        return
    target = IMGS_PATH / name_to_save
    for extension in (".pdf", ".png"):
        fig.savefig(target.with_suffix(extension))

## Importaciones generales

In [None]:
from icecream import ic
from bwb.sgdw import sgdw
from bwb.sgdw import wrappers
from bwb.sgdw import plotters as plotters_
from bwb.distributions import *
import bwb.utils.plotters as plotters
import matplotlib.pyplot as plt
from bwb.distributions.posterior_samplers import NUTSPosteriorSampler

## Configuraciones 

In [None]:
from bwb.config import conf

# Configure the `bwb` library through its imported `conf` object.
conf.use_gpu()
conf.use_single_precision()
# NOTE(review): 1e-16 is far below float32 machine epsilon (~1.2e-7); confirm
# how `conf` uses this value once single precision is selected.
conf.set_eps(1e-16)
conf  # Bare expression: shows the configuration as the cell output

## Configuración del Logger

In [None]:
import time
from pathlib import Path


# Log file for this run: logs/notebook-<NN>_<timestamp>.log
LOG_PATH = (
    Path("logs")
    / f"notebook-{NOTEBOOK:02d}_{time.strftime('%Y%m%d_%H%M%S')}.log"
)
# Create the logs directory. `exist_ok=True` replaces the previous
# check-then-create sequence and matches the other `mkdir` calls in this
# notebook.
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)

# Clean the logs left by previous runs of this notebook
if CLEAN_LOGS:
    for log_file in LOG_PATH.parent.glob(f"notebook-{NOTEBOOK:02d}*.log"):
        # `missing_ok` keeps the cleanup from failing if a file disappears
        # between the glob and the unlink.
        log_file.unlink(missing_ok=True)

In [None]:
import logging
from bwb.logging_ import log_config
from bwb import logging_


# Remove the handlers
log_config.remove_all_handlers()
ic(log_config.loggers)

# Define and add FileHandler (writes to the log file chosen above)
fh = logging.FileHandler(LOG_PATH)
log_config.add_handler(fh)


# Logger used by this notebook, plus DEBUG level for it and for the two
# `bwb.sgdw` loggers named below.
_log = log_config.get_logger("notebook")
log_config.set_level(level=logging.DEBUG, name="notebook")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.sgdw")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.plotters")

Esta celda es para configurar la información mostrada en el logger

In [None]:
# Set the default options for the report
# This assigns a class-level attribute on `ReportProxy`; the flags set to True
# are presumably the fields included in each report line -- confirm against
# `wrappers.ReportOptions`.
wrappers.ReportProxy.INCLUDE_OPTIONS = wrappers.ReportOptions(
    dt=False,
    dt_per_iter=True,
    iter=True,
    step_schd=True,
    total_time=True,
    w_dist=False,
)

## Obtención del dataset

In [None]:
# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset
import quick_torch as qt
import torchvision.transforms.v2 as T

# Transform applied to every dataset item: resize to 32x32, convert to an
# image tensor, cast to the configured dtype (with value scaling) and drop
# singleton dimensions.
transform_ds = T.Compose([
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])


def get_ds(file_path, transform=transform_ds):
    """Build a ``ModelDataset`` from the drawings stored in a ``.npy`` file.

    A ``QuickDraw`` dataset (category ``FACE``) is created under
    ``DATA_PATH`` and its ``data`` is then replaced with the array loaded
    from ``file_path`` reshaped to ``(-1, 28, 28)``; ``targets`` becomes a
    vector of ones of matching length. The dataset returned by
    ``get_train_data()`` is wrapped in ``ModelDataset``.

    :param file_path: Path of the ``.npy`` file to load.
    :param transform: Transform handed to ``QuickDraw``.
    :return: The wrapped dataset.
    """
    ic(file_path)
    quickdraw = qt.QuickDraw(
        root=DATA_PATH,
        categories=[qt.Category.FACE],
        transform=transform,
        download=True,
        recognized=True,
    )

    # Swap the stock drawings for the ones stored on disk.
    quickdraw.data = np.load(Path(file_path)).reshape(-1, 28, 28)
    n_drawings = len(quickdraw.data)
    quickdraw.targets = np.ones(n_drawings, dtype=int)

    train_split = quickdraw.get_train_data()
    ic(len(train_split))

    return ModelDataset(train_split)


# Load the cleaned dataset and fit a uniform sampler over its models.
DS_PATH = WGAN_PATH / "dataset" / "cleaned" / f"{DS_NAME}.npy"
ds_models = get_ds(DS_PATH)
ds_dist_sampler = UniformDiscreteSampler().fit(ds_models)

# Pick one model by index, plot it and save the figure.
i = 37
first_face = ds_models.get(i)
fig, _ = plotters.plot_draw(first_face, title=f"Cara $i={i}$")
save_fig(fig, "first_face")

In [None]:
from quick_torch import QuickDraw
import torchvision.transforms.v2 as T
from pathlib import Path

# Transform for the stock QuickDraw dataset: image tensor, resize to 32,
# float32 with value scaling, and singleton dimensions removed.
transforms = T.Compose([
    T.ToImage(),
    T.Resize(32),
    T.ToDtype(torch.float32, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])

ds = QuickDraw(
    DATA_PATH,
    categories="face",
    download=True,
    transform=transforms,
)

# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset

ds = ModelDataset(ds)

from bwb.distributions.distribution_samplers import UniformDiscreteSampler

ds_sampler = UniformDiscreteSampler().fit(ds)

# NOTE(review): this rebinds `i` and `first_face`, which the previous cell
# took from the cleaned dataset; every later cell therefore uses the face
# taken from `ds` -- confirm this is intended.
i = 37
first_face = ds.get(i)
print(first_face.shape)
_ = plotters.plot_draw(first_face, title=f"Cara $i={i}$")

## Obtener data

In [None]:
# Draw 1_000 samples from the face and keep only the first N_DATA of them.
data = first_face.sample((1_000,))[:N_DATA]

shape = first_face.shape
# Look up the support points indexed by `data` and add Gaussian jitter
# (standard deviation 0.1) to each coordinate before plotting.
data_coords = (
    first_face.enumerate_support_()[data].cpu().numpy()
    + np.random.randn(len(data), 2) * 0.1
)

# Histogram of the jittered points, with one bin per `shape[0]`.
plotters.plot_histogram_from_points(
    data_coords, rotate=True, shape=shape, histplot_kwargs=dict(bins=shape[0])
)
plt.show()

In [None]:
# Directory (created if missing) and pickle file used to persist the sampled
# observations of face `i`.
POST_DATA_PATH = CURR_PATH / "data"  # / f"n_data-{N_DATA}.pkl"
POST_DATA_PATH.mkdir(parents=True, exist_ok=True)
DATA_PATH_ = POST_DATA_PATH / f"data-{i}.pkl"
print(DATA_PATH_)

In [None]:
def plot_img_hist(
    face,
    data,
    n_data=None,
    plot=True,
    exp=None,
    title="Imagen de la cara a muestrear",
    hist_title="$n={}$ muestras a partir de la imagen",
):
    """Plot a face and a histogram of observations sampled from it.

    Parameters
    ----------
    face
        Object to plot. When ``plot`` is true it must provide ``shape`` and
        ``enumerate_support_()``.
    data
        Tensor of observations, used to index ``face.enumerate_support_()``.
    n_data
        If given, only the first ``n_data`` observations are plotted.
        ``None`` plots all of them.
    plot
        If false, nothing is drawn and both figure slots of the result are
        ``(None, None)``.
    exp
        Unused; kept so existing callers keep working.
    title
        Title of the face plot.
    hist_title
        ``str.format`` template for the histogram title; it receives the
        number of plotted observations.

    Returns
    -------
    tuple
        ``(face, data_copy, fig_ax1, fig_ax2)``: ``data_copy`` is a clone of
        the full ``data`` argument (not the truncated one), ``fig_ax1`` is
        the result of ``plotters.plot_draw`` and ``fig_ax2`` the result of
        ``plotters.plot_histogram_from_points``.
    """
    data_ = data.clone()
    if n_data is not None:
        data = data[:n_data]
    # Use the number of observations actually kept: the slice may return
    # fewer than ``n_data`` rows, and the jitter below must match it.
    n_data = len(data)

    # Placeholders so the return value (and two-element unpacking by the
    # callers) still works when nothing is plotted.
    fig_ax1 = fig_ax2 = (None, None)

    if plot:
        shape = face.shape
        # Support points selected by ``data`` plus Gaussian jitter (std 0.1).
        data_coords = (
            face.enumerate_support_()[data].cpu().numpy()
            + np.random.randn(n_data, 2) * 0.1
        )

        fig_ax1 = plotters.plot_draw(face, title=title)
        fig_ax2 = plotters.plot_histogram_from_points(
            data_coords,
            title=hist_title.format(n_data),
            rotate=True,
            shape=shape,
            histplot_kwargs=dict(bins=shape[0]),
        )

    return face, data_, fig_ax1, fig_ax2


def get_data(
    face,
    n_data,
    plot=True,
    exp=None,
    title="Imagen de la cara a muestrear",
    hist_title="$n={}$ muestras a partir de la imagen",
):
    """Sample ``n_data`` observations from ``face`` and plot them.

    The samples come from ``face.sample((n_data,))``; every other argument
    is forwarded unchanged to ``plot_img_hist``, whose result is returned.
    """
    samples = face.sample((n_data,))
    return plot_img_hist(
        face,
        samples,
        n_data=n_data,
        plot=plot,
        exp=exp,
        title=title,
        hist_title=hist_title,
    )


def get_sampler(
    sampler,
    n_data,
    plot=True,
    exp=None,
    title="Face sampled from dataset",
    hist_title="Histogram of the distribution generated by a drawing",
):
    """Draw one face from ``sampler`` and delegate to ``get_data``.

    The face is obtained with ``sampler.draw()``; the remaining arguments
    are forwarded unchanged and the result of ``get_data`` is returned.
    """
    drawn_face = sampler.draw()
    return get_data(
        drawn_face,
        n_data,
        plot=plot,
        exp=exp,
        title=title,
        hist_title=hist_title,
    )


# _, _, (fig1, ax1), (fig2, ax2) = get_sampler(ds_sampler, 100)
# A fresh sample of 100 observations is always drawn (and plotted) here, even
# when the pickled observations are loaded right after.
_, data, (fig1, ax1), (fig2, ax2) = get_data(first_face, 100)

import pickle

# Persist the fresh sample when requested or when no pickle exists yet;
# otherwise replace it with the pickled observations.
if NEW_DATA or not DATA_PATH_.exists():
    print("Salvando data...")
    with open(DATA_PATH_, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
    print("Obteniendo data...")
    # NOTE: unpickling executes arbitrary code; only load files written by
    # this notebook.
    with open(DATA_PATH_, "rb") as f:
        data = pickle.load(f)

# Plot the face again with only the first 5 observations.
_ = plot_img_hist(first_face, data, 5)

## Obtener GAN

De la misma manera, se puede definir un muestreador de distribuciones utilizando una GAN. Para ello, empezamos definiendo las redes neuronales a utilizar.

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import (
    Generator,
    Encoder,
    LatentDistribution,
)
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


device = conf.device

# Settings used to build the networks and the checkpoint directory name.
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
# NOTE(review): not referenced anywhere in the visible code.
NUM_FILTERS = [256, 128, 64, 32]

noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)


# Generator and encoder, moved to the configured device.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

# Rebinds DS_NAME to the same value assigned earlier in the notebook.
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"
ic(FACE_PATH)

load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

# Switch both networks to evaluation mode.
G.eval()
E.eval()
print()

In [None]:
# Bare expression: shows the latent noise sampler as the cell output.
noise_sampler

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning

disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Sample one latent vector and generate an image from it.
z = noise_sampler(1)
m = G(z)

# Input-side pipeline: divide by the maximum, go through a PIL image, resize
# to 32, convert back to an image tensor of the configured dtype and normalize
# with mean 0.5 and std 0.5.
# NOTE(review): `transform_in` is not used anywhere in the visible code.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(conf.dtype),
    T.Normalize((0.5,), (0.5,)),
])

# Output-side pipeline: cast, drop singleton dimensions, shift so the minimum
# is zero and divide by the sum so the entries add up to one.
# NOTE(review): a constant image would make the sum zero -- confirm the
# generator cannot produce one.
transform_out_ = T.Compose([
    T.ToDtype(conf.dtype),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
])

# Same as above, followed by the conversion into a `DistributionDraw`.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])


out: DistributionDraw = transform_out(m)
print(out.dtype)
out

## Definir Proyector

In [None]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms.v2 as T

# Transform applied before the projector's networks: rescale by the maximum,
# go through a PIL image, resize to 32x32, convert to an image tensor of the
# configured dtype (with value scaling) and normalize with mean/std 0.5.
transform_in_proj = T.Compose([
    # From pdf to grayscale
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    T.Normalize(
        [0.5 for _ in range(1)],
        [0.5 for _ in range(1)],
    ),
])

# Transform applied to the projector's output: shift/scale into [0, 1], divide
# by the sum so the entries add up to one, and drop the leading dimension.
transform_out_proj = T.Compose([
    # Ensure the range is in [0, 1]
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.max(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    T.Lambda(lambda x: x.squeeze(0)),
])

# Projector built from the encoder, the generator and the two transforms.
_proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)


def proj(input_: torch.Tensor) -> torch.Tensor:
    """
    Apply the module-level ``_proj`` projector to ``input_``.

    The projected tensor is passed through ``.to(input_)`` before being
    returned.
    """
    projected = _proj(input_)
    return projected.to(input_)

## Definir $\gamma_k$

Aquí se utiliza una función de la forma
\begin{equation*}
    \gamma_k = \frac{a}{(b^{1/c} + k)^c}
\end{equation*}

Con $a > 0$, $b \geq 0$ y $0.5 < c \leq 1$

La idea es que cuando $k=0$, $\gamma_0 = \frac{a}{b}$ es la proporción entre $a$ y $b$ (bien definida solo si $b > 0$), lo que permite ajustar el valor inicial del paso.

In [None]:
from bwb.sgdw.utils import step_scheduler

window = 5


def test_gamma(gamma, starts=(0, 50, 100, 300, 500, 1_000, 3_000, 5_000)):
    """Print the value of ``gamma`` over several short iteration windows.

    For every offset in ``starts``, ``gamma(t)`` is printed as a percentage
    for the ``window`` consecutive values of ``t`` beginning at that offset,
    followed by a blank line. ``window`` is the module-level variable.

    :param gamma: Callable mapping an iteration number to a step size.
    :param starts: First iteration of each printed window. The default
        reproduces the windows that used to be written out one by one.
    """
    for init in starts:
        for t in range(init, init + window):
            print(f"{t = :_}; {gamma(t) = :.2%}")
        print()


# Parameters of the step schedule. With these values a / b = 3 / 3.01, and c
# sits just above 0.5.
_a = 3
_eps = 1e-3
params = dict(a=_a, b=_a + 1e-2, c=0.5 + _eps)
# params = dict(a=1, b=1, c=1)

# Step schedule used later by the SGDW.
gamma = step_scheduler(**params)

# Print the schedule at several iteration windows (uses a second scheduler
# built from the same parameters).
test_gamma(step_scheduler(**params))

## Definir distribución a posteriori con MCMC

In [None]:
# File name that encodes the face index and the MCMC settings.
# NOTE(review): `run_mcmc_experiment` builds its own file name from its
# arguments and does not read this variable in the visible code.
NUTS_POSTERIOR_PATH = (
    NUTS_PATH
    / f"bayes-bar-i={i}-n_data-{N_DATA}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}"
)
NUTS_POSTERIOR_PATH = NUTS_POSTERIOR_PATH.with_suffix(".pkl.gz")
print(NUTS_POSTERIOR_PATH)

# NOTE(review): this rebinds DATA_PATH_, which an earlier cell pointed at the
# pickle of sampled observations -- confirm that no later cell expects the
# old value.
POST_DATA_PATH = CURR_PATH / "data"  # / f"n_data-{N_DATA}.pkl"
POST_DATA_PATH.mkdir(parents=True, exist_ok=True)
DATA_PATH_ = (
    POST_DATA_PATH
    / f"i-{i}-n_data-{N_DATA}-burn-{BURN:_}-num_samples-{NUM_SAMPLES:_}-n_walkers-{N_WALKERS}.pkl"
)
print(DATA_PATH_)

In [None]:
# NOTE(review): `N` is not referenced anywhere in the visible code.
N = 5

# Keyword arguments for the `NUTSPosteriorSampler` constructor.
NUTS_PARAMS = dict(
    n_walkers=N_WALKERS,
    num_steps_per_sample=1,
    burn=BURN,
    desired_accept_rate=0.6,
    use_half=True,
)

# Keyword arguments for the sampler's `run` method.
RUN_PARAMS = dict(
    n_steps=NUM_SAMPLES,
)

# Keyword arguments for `plotters.plot_list_of_draws`; `n_rows` and `n_cols`
# also determine how many posterior samples are drawn for the plot.
MCMC_PLOT_PARAMS = dict(
    n_rows=3,
    n_cols=6,
)


def save_fig(fig, name_to_save: str, imgs_path=IMGS_PATH):
    """Save ``fig`` as a PDF and a PNG inside ``imgs_path``.

    This definition replaces the earlier ``save_fig`` of the notebook. Like
    that one, it does nothing when the module-level ``SAVE_FIGS`` flag is
    falsy, so figures are never written when saving is disabled.

    :param fig: Figure to save; it must provide ``savefig``.
    :param name_to_save: File name; the extension is set with
        ``Path.with_suffix``.
    :param imgs_path: Directory where both files are written.
    """
    if not SAVE_FIGS:
        return
    path_to_save = imgs_path / name_to_save
    for suffix in (".pdf", ".png"):
        fig.savefig(path_to_save.with_suffix(suffix), bbox_inches="tight")
    print("Images saved with name", name_to_save)


def run_mcmc_experiment(
    face,
    data,
    n_data,
    thin=None,
    nuts_params=NUTS_PARAMS,
    run_params=RUN_PARAMS,
    mcmc_plot_params=MCMC_PLOT_PARAMS,
    run_mcmc=RUN_MCMC,
    mcmc_path=NUTS_PATH,
):
    """Run (or load) a NUTS posterior chain and plot some of its samples.

    The chain is stored under ``mcmc_path`` in a file whose name encodes
    ``n_data``, the burn-in, the number of samples and the number of
    walkers. If that file exists and ``run_mcmc`` is false, the chain is
    loaded instead of being run again.

    :param face: Face the observations were sampled from (only plotted).
    :param data: Tensor of observations; the first ``n_data`` are used.
    :param n_data: Number of observations used to fit the posterior.
    :param thin: Thinning handed to ``shuffle_samples_cache``. When falsy,
        the mean autocorrelation time (truncated to ``int``) is used.
    :param nuts_params: Keyword arguments for ``NUTSPosteriorSampler``.
    :param run_params: Keyword arguments for the sampler's ``run`` method.
    :param mcmc_plot_params: Keyword arguments for
        ``plotters.plot_list_of_draws``. The mapping is not modified.
    :param run_mcmc: If true, run the chain even if a saved one exists.
    :param mcmc_path: Directory where chains are saved and loaded.
    :return: ``(post_pi_n, (fig1, hist1, fig2))``: the posterior sampler,
        the face figure, the histogram figure and the samples figure.
    """
    # Constants
    burn = nuts_params.get("burn", None)
    num_samples = run_params.get("n_steps", None)
    n_walkers = nuts_params.get("n_walkers", None)
    n_rows = mcmc_plot_params.get("n_rows", N_ROWS)
    n_cols = mcmc_plot_params.get("n_cols", N_COLS)

    # Path to save the chain
    # NOTE(review): the file name does not identify the face, so chains
    # computed for different faces with the same settings share one file.
    mcmc_2_save_path = (
        mcmc_path
        / f"n-{n_data}-burn-{burn:_}-num_samples-{num_samples:_}-n_walkers-{n_walkers}"
    )
    mcmc_2_save_path = mcmc_2_save_path.with_suffix(".pkl.gz")

    # Getting the data
    data = data.clone()[:n_data]
    face, data, (fig1, _), (hist1, _) = plot_img_hist(face, data, n_data)
    print(f"{data.shape=}")

    # Train the MCMC, or load if there are one in the cache
    if not mcmc_2_save_path.exists() or run_mcmc:
        # Make sure the target directory exists before the (long) run, so
        # the chain is not lost to a missing directory at save time.
        mcmc_2_save_path.parent.mkdir(parents=True, exist_ok=True)

        post_pi_n = NUTSPosteriorSampler(**nuts_params).fit(
            G,
            transform_out_,
            noise_sampler,
            data,
        )

        post_pi_n.run(**run_params)

        # Time how long saving the chain takes.
        tic = time.perf_counter()
        post_pi_n.save(mcmc_2_save_path)
        toc = time.perf_counter()
        ic(toc - tic)

    else:
        post_pi_n = NUTSPosteriorSampler.load(mcmc_2_save_path)
        post_pi_n.fit(G, transform_out_, noise_sampler, data)

    # Computing the mean autocorr time
    mean_autocorr_time = int(post_pi_n.get_autocorr_time().mean())
    print(f"{mean_autocorr_time=}")

    title = (
        "Muestras del MCMC para "
        + r"$n = "
        + f"{n_data}$, \n"
        + r"$\hat\tau_\mathrm{mean}"
        + f"={mean_autocorr_time}$"
    )
    # Build a new mapping instead of updating ``mcmc_plot_params`` in place:
    # its default is a module-level dict shared by every call.
    plot_kwargs = {**mcmc_plot_params, "title": title}

    # NOTE(review): if the mean autocorrelation time truncates to 0 and no
    # ``thin`` is given, 0 is passed on -- confirm the sampler accepts it.
    post_pi_n.shuffle_samples_cache(thin=thin or mean_autocorr_time)

    # Plot some samples form the posterior
    max_imgs = n_rows * n_cols
    fig2, ax = plotters.plot_list_of_draws(
        post_pi_n.sample(max_imgs), **plot_kwargs
    )

    return post_pi_n, (fig1, hist1, fig2)

In [None]:
# Posterior samplers obtained for each sample size, in loop order.
LIST_POST_PI_N = []
post_pi_n = None

for n in LIST_N_DATA:
    # Run the MCMC
    print("Ejecutando la cadena con n =", n)
    post_pi_n, (fig_mcmc, hist_mcmc, fig_samples) = run_mcmc_experiment(
        first_face,
        data,
        n,
        thin=None,
    )

    LIST_POST_PI_N.append(post_pi_n)

    # Compute the barycenter
    # SGDW with the posterior
    # The number of iterations is capped by the number of cached samples.
    dist_draw_sgdw = wrappers.ReportProxy(
        sgdw.DebiesedDistributionDrawSGDW(
            distr_sampler=post_pi_n,
            step_scheduler=gamma,
            batch_size=BATCH_SIZE,
            max_iter=int(min(MAX_ITER, post_pi_n.n_cached_samples - 1)),
        ),
        report_every=REPORT_EVERY,
        log=_log,
    )

    # Plotter
    plotter_comp = plotters_.PlotterComparison(
        dist_draw_sgdw,
        # projector=proj,
        # proj_every=PROJ_EVERY,
        n_cols=N_COLS,
        n_rows=N_ROWS,
        cmap="binary_r",
        plot_every=PLOT_EVERY,
    )

    # Run the plotter
    # NOTE(review): the message says "projected" although the projector
    # options above are commented out.
    _log.info(
        f"Running SGD-Wasserstein with '{DS_NAME}' bayesian projected"
        " barycenter"
    )
    # NOTE(review): `timer` is not read afterwards in the visible code.
    with logging_.register_total_time(_log) as timer:
        bar = plotter_comp.run()

    # Figures of the first iterations, the last iterations and the result.
    fig_first_iter, _ = plotter_comp.plot(0)
    save_fig(fig_first_iter, f"first-iters-n-data-{n}")

    fig_last_iter, _ = plotter_comp.plot()
    save_fig(fig_last_iter, f"last-iters-n-data-{n}")

    fig_bwb, _ = plotters.plot_draw(bar, title=r"BWB with $n={}$".format(n))
    save_fig(fig_bwb, f"BWB-n-data-{n}")

    # Figures returned by the MCMC experiment.
    if SAVE_FIGS:
        save_fig(fig_mcmc, f"image-sampler-i-{i}")
        save_fig(hist_mcmc, f"samples-hist-n-{n}")
        post_name = post_pi_n.__class__.__name__
        save_fig(fig_samples, f"mcmc-n-{n}-{post_name}")

In [None]:
# shape = first_face.shape
# data_coords = first_face.enumerate_support_()[data].cpu().numpy() + np.random.randn(len(data), 2) * 0.1
#
# fig, _ = plotters.plot_histogram_from_points(data_coords, rotate=True, shape=shape, histplot_kwargs=dict(bins=shape[0]))
# save_fig(fig, f"n_data-{N_DATA}")
# plt.show()

In [None]:
# post_pi_n.mean_autocorr_time

In [None]:
# mean_autocorr_time = int(autocorr_time.mean())
# ic(mean_autocorr_time)
# max_autocorr_time = int(autocorr_time.max())
# ic(max_autocorr_time)

# post_pi_n.shuffle_samples_cache()

In [None]:
# from copy import copy
# post_pi_n_ = copy(post_pi_n)
# n_rows, n_cols = 6, 12
# max_imgs = n_rows * n_cols
# fig, ax = plotters.plot_list_of_draws(
#     post_pi_n_.sample(max_imgs),
#     n_rows=n_rows, n_cols=n_cols,
#     title=f"Muestras a partir del MCMC con "
# )
# save_fig(fig, f"n-data-{N_DATA}-{post_pi_n.__class__.__name__}-{n_rows}x{n_cols}")
# del post_pi_n_

# Cálculo del Baricentro

In [None]:
# dist_draw_sgdw = wrappers.ReportProxy(
#     sgdw.DebiesedDistributionDrawSGDW(
#         distr_sampler=post_pi_n,
#         step_scheduler=gamma,
#         batch_size=BATCH_SIZE,
#         max_iter=int(min(MAX_ITER, post_pi_n.n_cached_samples - 1)),
#     ),
#     report_every=REPORT_EVERY,
#     log=_log
# )
# dist_draw_sgdw

In [None]:
# plotter_comp = plotters_.PlotterComparison(
#     dist_draw_sgdw,
#     # projector=proj,
#     # proj_every=PROJ_EVERY,
#     n_cols=N_COLS,
#     n_rows=N_ROWS,
#     cmap="binary_r",
#     plot_every=PLOT_EVERY,
# )
# plotter_comp.sgdw

In [None]:
# _log.info(f"Running SGD-Wasserstein with '{DS_NAME}' bayesian projected barycenter")
# with logging_.register_total_time(_log) as timer:
#     bar = plotter_comp.run()
# ic(timer.elapsed_time)

In [None]:
# plotter_comp.sgdw

In [None]:
# fig, _ = plotter_comp.plot(0)
# save_fig(fig, F"first-iters-n-data-{N_DATA}")

In [None]:
# fig, _ = plotter_comp.plot()
# save_fig(fig, f"last-iters-n-data-{N_DATA}")

In [None]:
# fig, _ = plotters.plot_draw(bar, title=r"BWB with $n={}$".format(N_DATA))
# save_fig(fig, f"BWB-n-data-{N_DATA}")