Notebook que sirve de ejemplo para calcular el baricentro de un conjunto de datos de imágenes.

# Configuraciones iniciales

## Constantes

In [None]:
NOTEBOOK = 3  # Notebook number; used to name the figures folder and the log file
CLEAN_LOGS = True  # If True, previous log files of this notebook are deleted
SAVE_FIGS = True  # If True, save_fig() writes every figure as .pdf and .png

REPORT_EVERY = 10  # Passed as report_every to wrappers.ReportProxy
PLOT_EVERY = 20  # Passed as plot_every to plotters.PlotterComparisonProjected
MIN_ITER = 1_000  # NOTE(review): not referenced by any visible cell
MAX_ITER = 5_000  # Max number of iterations for the SGDW
BATCH_SIZE = 30  # Passed as batch_size to sgdw.DistributionDrawSGDW
PROJ_EVERY = 1  # Passed as project_every / proj_every to the projection wrappers
TOL = 0.05  # Passed as tol to sgdw.DistributionDrawSGDW
WASS_DIST_EVERY = 5  # Passed as wass_dist_every to sgdw.DistributionDrawSGDW

# MAX_ITER = 50; REPORT_EVERY = 5  # Uncomment for debugging

In [None]:
import random

import numpy as np
import torch

# Seed every random number generator used by the notebook so that runs are
# repeatable: PyTorch (CPU and CUDA), NumPy's legacy global RNG and the
# standard-library `random` module.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

In [None]:
from pathlib import Path
from icecream import ic


# Base name of the dataset; used later to build the .npy file name and the
# checkpoint folder name.
DS_NAME = "data"

# All paths are derived from the current working directory, so the notebook
# must be launched from its own folder. ic() just echoes each resolved path.
CURR_PATH = Path().absolute()
ic(CURR_PATH)
BASE_PATH = CURR_PATH.parent.parent  # two levels above the working directory
ic(BASE_PATH)
DATA_PATH = BASE_PATH / "data"
ic(DATA_PATH)
WGAN_PATH = BASE_PATH / "wgan_gp"
ic(WGAN_PATH)
NETS_PATH = WGAN_PATH / "networks"
ic(NETS_PATH)
# Per-notebook output folder for the figures; created here if it is missing.
IMGS_PATH = CURR_PATH / "imgs" / f"notebook-{NOTEBOOK:02d}"
IMGS_PATH.mkdir(parents=True, exist_ok=True)
ic(IMGS_PATH)

In [None]:
def save_fig(fig_, name_to_save: str) -> None:
    """
    Write a figure to ``IMGS_PATH`` as both a PDF and a PNG file.

    Nothing is written when ``SAVE_FIGS`` is falsy.

    :param fig_: Object exposing ``savefig(path)``.
    :param name_to_save: File name inside ``IMGS_PATH``; its suffix (if any)
        is replaced by ``.pdf`` and ``.png``.
    """
    if not SAVE_FIGS:
        return

    target = IMGS_PATH / name_to_save
    for extension in (".pdf", ".png"):
        fig_.savefig(target.with_suffix(extension))

## Importaciones generales

In [None]:
from icecream import ic
from bwb.sgdw import sgdw
from bwb.sgdw import wrappers
from bwb.sgdw import plotters
from bwb.distributions import *

## Configuraciones 

In [None]:
from bwb.config import conf

# Global configuration of the bwb library for this session.
conf.use_gpu()
conf.use_single_precision()
# NOTE(review): 1e-16 is far below float32 machine epsilon (~1.2e-7) while
# single precision was just requested; confirm how conf's eps is consumed.
conf.set_eps(1e-16)
conf  # last expression: display the resulting configuration

## Configuración del Logger

In [None]:
import time
from pathlib import Path


# Build the path of this run's log file: logs/notebook-XX_<timestamp>.log
LOG_PATH = (
    Path("logs")
    / f"notebook-{NOTEBOOK:02d}_{time.strftime('%Y%m%d_%H%M%S')}.log"
)
# Create the logs directory. `exist_ok=True` makes the cell safe to re-run
# and avoids the check-then-create race (same idiom as for IMGS_PATH).
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)

# Remove the log files left by previous runs of this notebook. The file for
# the current run has not been created yet, so it cannot be affected.
if CLEAN_LOGS:
    for log_file in LOG_PATH.parent.glob(f"notebook-{NOTEBOOK:02d}*.log"):
        log_file.unlink()

In [None]:
import logging
from bwb.logging_ import log_config


# Start from a clean state: drop every handler registered in log_config so
# that re-running this cell does not attach duplicated handlers.
log_config.remove_all_handlers()
ic(log_config.loggers)

# Send the log records to this run's log file.
fh = logging.FileHandler(LOG_PATH)
log_config.add_handler(fh)


# Logger used by the notebook itself, plus DEBUG level for the notebook
# logger and for each bwb.sgdw submodule logger listed below.
_log = log_config.get_logger("notebook")
log_config.set_level(level=logging.DEBUG, name="notebook")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.sgdw")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.utils")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.plotters")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.wrappers")

Esta celda es para configurar la información mostrada en el logger

In [None]:
# Set the default options for the report.
# This overrides a class-level attribute, so it applies to every ReportProxy
# created afterwards. Only `dt` is disabled; every other listed field is
# enabled. (Presumably each flag toggles one field of the logged report line;
# verify against wrappers.ReportOptions.)
wrappers.ReportProxy.INCLUDE_OPTIONS = wrappers.ReportOptions(
    dt=False,
    dt_per_iter=True,
    iter=True,
    step_schd=True,
    total_time=True,
    w_dist=True,
)

## Obtención del dataset

In [None]:
# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset
import quick_torch as qt
import torchvision.transforms.v2 as T

# Preprocessing applied to every dataset sample: resize to 32x32, convert to
# a torchvision image tensor, cast to the configured dtype (rescaling the
# values) and drop the singleton dimensions.
transform_ds = T.Compose([
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])


def get_ds(file_path: Path, transform=transform_ds) -> ModelDataset:
    """
    Build a ``ModelDataset`` from a ``.npy`` file of images.

    A ``QuickDraw`` dataset (FACE category) is created only to reuse its
    interface: its ``data`` is replaced by the array stored in ``file_path``,
    reshaped to ``(-1, 28, 28)``, and its ``targets`` by an all-ones integer
    array of the same length. The train split of that dataset is then wrapped
    in a ``ModelDataset``.

    :param file_path: Path of the ``.npy`` file to load.
    :param transform: Transform handed to ``QuickDraw``.
    :return: The wrapped train split.
    """
    ic(file_path)

    quickdraw = qt.QuickDraw(
        root=DATA_PATH,
        categories=[qt.Category.FACE],
        transform=transform,
        download=True,
        recognized=True,
    )

    # Swap the downloaded drawings for the images stored in the file.
    quickdraw.data = np.load(Path(file_path)).reshape(-1, 28, 28)
    quickdraw.targets = np.ones(len(quickdraw.data), dtype=int)

    train_split = quickdraw.get_train_data()
    ic(len(train_split))

    return ModelDataset(train_split)


# Location of the cleaned dataset: <WGAN_PATH>/dataset/cleaned/<DS_NAME>.npy
DS_PATH = WGAN_PATH / "dataset" / "cleaned" / f"{DS_NAME}.npy"
ds_models = get_ds(DS_PATH)

# Last expression: display the first element of the dataset as a sanity check.
ds_models.get(0)

In [None]:
ds_models

## Obtener GAN

De la misma manera, se puede definir un muestreador de distribuciones utilizando una GAN. Para ello, empezamos definiendo las redes neuronales a utilizar

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import (
    Generator,
    Encoder,
    LatentDistribution,
)
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


device = conf.device

# Hyper-parameters of the networks; NOISE and LATENT_DIM are also part of
# the checkpoint folder name below.
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]  # NOTE(review): not used by any visible cell

noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)


# Generator and encoder, moved to the configured device.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

# NOTE(review): DS_NAME is already defined with the same value in the paths
# cell; this re-assignment is redundant.
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"
ic(FACE_PATH)

# Load the stored weights into both networks.
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

# Switch both networks to evaluation mode.
G.eval()
E.eval()
print()  # keeps the repr of E.eval() from being shown as the cell output

In [None]:
noise_sampler

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning

disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Draw one latent sample and generate one image from it.
z = noise_sampler(1)
m = G(z)

# NOTE(review): transform_in is not referenced by any visible cell (the
# projector uses transform_in_proj instead).
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(conf.dtype),
    T.Normalize((0.5,), (0.5,)),
])

# Turns a generator output into non-negative weights that sum to one:
# cast to the configured dtype, drop singleton dimensions, shift so the
# minimum is zero and divide by the total sum.
transform_out_ = T.Compose([
    T.ToDtype(conf.dtype),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
])

# Same as transform_out_, followed by the wrapping into a DistributionDraw.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])


# Sanity check: convert the generated image and display it.
out: DistributionDraw = transform_out(m)
print(out.dtype)
out

## Definir Proyector

In [None]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms.v2 as T

# Transform applied to the input handed to the projector: rescale so the
# maximum is one, go through a PIL image to resize to 32x32, convert back to
# a tensor of the configured dtype and normalise with mean 0.5 and std 0.5,
# i.e. (x - 0.5) / 0.5.
transform_in_proj = T.Compose([
    # From pdf to grayscale
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    T.Normalize(
        [0.5 for _ in range(1)],
        [0.5 for _ in range(1)],
    ),
])

# Transform applied to the projector output: shift so the minimum is zero,
# rescale so the maximum is one, normalise so the weights sum to one and drop
# the leading dimension.
# NOTE(review): a constant input makes max (and sum) zero after the shift and
# the divisions would produce NaN; confirm this cannot happen here.
transform_out_proj = T.Compose([
    # Ensure the range is in [0, 1]
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.max(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    T.Lambda(lambda x: x.squeeze(0)),
])

# Projector built from the encoder E and the generator G loaded above.
_proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)


def proj(input_: torch.Tensor) -> torch.Tensor:
    """
    Project a tensor with ``_proj`` and match the input's tensor options.

    The result of ``_proj`` is converted with ``.to(input_)``, so it comes
    back with the dtype and device of ``input_``.

    :param input_: Tensor to project.
    :return: The projected tensor.
    """
    projected = _proj(input_)
    return projected.to(input_)

## Definir $\gamma_k$

Aquí se utiliza una función de la forma
\begin{equation*}
    \gamma_k = \frac{a}{(b^{1/c} + k)^c}
\end{equation*}

Con $a > 0$, $b \geq 0$ y $0.5 < c \leq 1$

La idea es que cuando $k = 0$ se tiene $\gamma_0 = \frac{a}{b}$ (para $b > 0$), es decir, la proporción entre $a$ y $b$, lo que permite ajustar el valor inicial.

In [None]:
from bwb.sgdw.utils import step_scheduler

window = 5


def test_gamma(gamma_) -> None:

    for t in range(window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 50
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 100
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 300
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 500
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 1_000
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 3_000
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 5_000
    for t in range(init, init + window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()


# Parameters of the step schedule gamma_k = a / (b^(1/c) + k)^c described in
# the markdown above. Here a = 3, b = 3.01 and c = 0.51, so the initial step
# is a / b, slightly below 1.
_a = 3
_eps = 1e-2
params = dict(a=_a, b=_a + 1e-2, c=0.5 + _eps)
# Same a and b for the interpolation schedule, with c = 0.501.
params_interp = dict(a=_a, b=_a + 1e-2, c=0.5 + _eps / 10)
# params = dict(a=1, b=1, c=1)

# NOTE(review): `gamma` is not used by the visible cells; they build a new
# scheduler with step_scheduler(**params) each time.
gamma = step_scheduler(**params)
# gamma = 0.1

# Print the schedule around a few iterations to inspect its decay.
test_gamma(step_scheduler(**params))

# Baricentro de imágenes

Para obtener el baricentro de un conjunto de imágenes, es necesario utilizar la clase `UniformDiscreteSampler` y ajustarla (con el método `fit`) a un objeto que tenga los siguientes métodos:
* `get(i) -> Distribution` que retorne la distribución $i$-ésima.
* `__len__() -> int` que retorne el tamaño del dataset.

In [None]:
distr_sampler = UniformDiscreteSampler[DistributionDraw]().fit(models=ds_models)

Luego definimos el algoritmo a utilizar. En este caso, utilizaremos `DebiesedDistributionDrawSGDW`, que realiza los transportes utilizando un método de convolución *debiased*.

In [None]:
def print_x(x: dict) -> None:
    """
    Debug callback: print the keys available in the callback payload.

    :param x: Dictionary received by the callback.
    """
    available = x.keys()
    print(available)


def register_wass_dist():
    """
    Create a callback that records the Wasserstein distances.

    Whenever the payload contains the key ``"wass_dist"``, the callback
    appends ``x["k"]``, ``x["wass_dist"]`` and ``x["wass_dist_smooth"]`` (in
    that order) to their respective lists; other payloads are ignored.

    :return: Tuple ``(callback, k_lst, wass_dist_lst, wass_dist_smooth_lst)``;
        the lists are filled in place by the callback.
    """
    k_lst, wass_dist_lst, wass_dist_smooth_lst = [], [], []
    # Each payload key paired with the list that stores it.
    targets = (
        ("k", k_lst),
        ("wass_dist", wass_dist_lst),
        ("wass_dist_smooth", wass_dist_smooth_lst),
    )

    def _record(x: dict) -> None:
        if "wass_dist" not in x:
            return
        for name, store in targets:
            store.append(x[name])

    return _record, k_lst, wass_dist_lst, wass_dist_smooth_lst


def register_key(key: str, map_=lambda x: x):
    """
    Create a callback that stores the value found under ``key``.

    On every call the callback appends ``map_(x[key])`` to a list. If
    ``key`` is missing, a message listing the available keys is printed and
    nothing is stored.

    :param key: Key to look up in the callback payload.
    :param map_: Function applied to the value before storing it (identity
        by default).
    :return: Tuple ``(callback, elements)``; ``elements`` is the list filled
        in place by the callback.
    """
    collected = []

    def _collect(x: dict) -> None:
        if key in x:
            collected.append(map_(x[key]))
        else:
            print(f"{key} not found. Available keys: {list(x.keys())}")

    return _collect, collected


def register_distr():
    """
    Create a callback that stores ``"pos_wgt"`` as a ``DistributionDraw``.

    :return: Tuple ``(callback, elements)`` as returned by ``register_key``;
        each stored element is built with
        ``DistributionDraw.from_grayscale_weights``.
    """

    def _to_distribution(pos_wgt):
        return DistributionDraw.from_grayscale_weights(pos_wgt)

    return register_key("pos_wgt", _to_distribution)


def register_pos_wgt_sampled():
    """
    Create a callback that stores the weights of the sampled distributions.

    :return: Tuple ``(callback, elements)`` as returned by ``register_key``;
        each stored element is the list of ``grayscale_weights`` of the
        distributions found under ``"lst_mu"``.
    """

    def _grayscale_weights(lst_mu):
        return [mu.grayscale_weights for mu in lst_mu]

    return register_key("lst_mu", _grayscale_weights)


def mix_callbacks(callbacks: list):
    """
    Combine several callbacks into a single one.

    :param callbacks: Callables taking the payload dictionary. The list is
        read on every call, so later changes to it are taken into account.
    :return: A callback that invokes every element of ``callbacks`` in order
        with the same payload.
    """

    def _dispatch(x: dict) -> None:
        for handler in callbacks:
            handler(x)

    return _dispatch


# register_wass_dist_callback, k_lst, wass_dist_lst, wass_dist_smooth_lst = (
#     register_wass_dist()
# )
# register_pos_wgt_callback, pos_wgt_lst_ = register_key("pos_wgt")
# register_pos_wgt_proj_callback, pos_wgt_proj_lst_ = register_key("pos_wgt_proj")
# register_pos_wgt_proj_interp_callback, pos_wgt_proj_interp_lst = (
#     register_key("pos_wgt_proj_interp")
# )  # fmt: skip
# register_pos_wgt_sampled_callback, pos_wgt_sampled_lst = (
#     register_pos_wgt_sampled()
# )
# callback_ = mix_callbacks([
#     # print_x,
#     register_wass_dist_callback,
#     register_pos_wgt_callback,
#     register_pos_wgt_proj_callback,
#     register_pos_wgt_proj_interp_callback,
#     register_pos_wgt_sampled_callback,
# ])

In [None]:
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
import numpy.typing as npt
import itertools
import typing as t


def plot_comparison[
    SGDWT: sgdw.SGDW[DistributionDraw, torch.Tensor]
](
    sgdw_: SGDWT,
    plot_every: int | None = None,
    n_cols: int = 12,
    n_rows: int = 2,
    factor: float = 1.5,
    cmap: str = "binary",
    **kwargs,
) -> (t.Callable[..., tuple[Figure, npt.NDArray[Axes]]] | None):
    if plot_every is None:
        return None

    if plot_every < n_rows * n_cols:
        raise ValueError(
            f"The plot_every must be greater than {n_rows * n_cols = }."
        )

    projected, interpolated = False, False
    if isinstance(sgdw_, wrappers.SGDWProjectedDecorator):
        projected: bool = True
        sgdw_: wrappers.SGDWProjectedDecorator
        interpolated: bool = sgdw_.interp_strategy is not None

    # Previous callback to be decorated
    previous_callback = sgdw_.callback

    # Needed callbacks for this plotter
    pos_wgt_callback, pos_wgt = register_key("pos_wgt")
    pos_wgt: list[torch.Tensor]

    pos_wgt_samp_callback, pos_wgt_samp = register_key("lst_mu")
    pos_wgt_samp: list[list[DistributionDraw]]

    pos_wgt_proj_callback, pos_wgt_proj = lambda x: None, []
    if projected:
        pos_wgt_proj_callback, pos_wgt_proj = register_key("pos_wgt_proj")
    pos_wgt_proj: list[torch.Tensor]

    pos_wgt_proj_interp_callback, pos_wgt_proj_interp = lambda x: None, []
    if interpolated:
        pos_wgt_proj_interp_callback, pos_wgt_proj_interp = register_key(
            "pos_wgt_proj_interp"
        )
    pos_wgt_proj_interp: list[torch.Tensor]

    # The decorated callback
    decorated_callback = mix_callbacks([
        previous_callback,
        pos_wgt_callback,
        pos_wgt_samp_callback,
        pos_wgt_proj_callback,
        pos_wgt_proj_interp_callback,
    ])

    def plot_fn(
        init: int = None,
        n_rows: int = n_rows,
        n_cols: int = n_cols,
    ) -> tuple[Figure, npt.NDArray[Axes]]:
        max_imgs = n_cols * n_rows
        max_k = sgdw_.iter_params.k
        init = max_k - max_imgs if init is None else init
        if init < 0:
            raise ValueError("The init must be greater than 0")
        if init > max_k - max_imgs:
            raise ValueError(
                f"The init must be less than {max_k - max_imgs}. "
                f"Currently: {init = }"
            )

        rows_per_iter = 2
        if projected:
            rows_per_iter += 1
        if interpolated:
            rows_per_iter += 1
        row, col = n_rows * rows_per_iter, n_cols

        fig, ax = plt.subplots(
            row,
            col,
            figsize=(col * factor, row * factor),
            subplot_kw={"xticks": [], "yticks": []},
        )
        fig: Figure

        title = "SGDW" if not projected else "Projected SGDW"
        fig.suptitle(kwargs.get("title", title))

        for i, j in itertools.product(range(n_rows), range(n_cols)):
            k = init + i * n_cols + j
            gamma_k = sgdw_.schd.step_schedule(k)

            ax0: Axes = ax[rows_per_iter * i, j]
            ax1: Axes = ax[rows_per_iter * i + 1, j]
            if projected:
                ax2: Axes = ax[rows_per_iter * i + 2, j]
            if interpolated:
                ax3: Axes = ax[rows_per_iter * i + 3, j]

            # Label the y-axis
            if j == 0:
                ax0.set_ylabel(f"Sample")
                ax1.set_ylabel(f"Step")
                if projected:
                    ax2.set_ylabel(f"Projected")
                if interpolated:
                    ax3.set_ylabel(f"Interpolated")

            # Plot the sample
            fig_sample: DistributionDraw = pos_wgt_samp[k][0]
            ax0.imshow(fig_sample.image, cmap=cmap)
            ax0.set_title(
                f"$k={k}$",
                size="x-small",
            )

            # Plot the step
            fig_step: DistributionDraw = sgdw_.as_distribution(pos_wgt[k])
            ax1.imshow(fig_step.image, cmap=cmap)
            ax1.set_title(
                f"$\\text{{step}}_k={gamma_k * 100:.1f}\\%$",
                size="x-small",
            )

            # Plot the projected
            if projected:
                fig_proj: DistributionDraw = sgdw_.as_distribution(
                    pos_wgt_proj[k]
                )
                ax2.imshow(fig_proj.image, cmap=cmap)

            if interpolated:
                # Plot the interpolated
                fig_interp: DistributionDraw = sgdw_.as_distribution(
                    pos_wgt_proj_interp[k]
                )
                ax3.imshow(fig_interp.image, cmap=cmap)
                interp_k = sgdw_.interp_step_schd(k)
                ax3.set_title(
                    f"$\\text{{interp}}_k={interp_k * 100:.1f}\\%$",
                    size="x-small",
                )

        plt.tight_layout()

        plt.show()

        return fig, ax

    def plot_callback(x: dict) -> None:
        decorated_callback(x)

        if x["k"] > 0 and x["k"] % plot_every == 0:
            plot_fn()

    sgdw_.callback = plot_callback

    return plot_fn

In [None]:
# Base algorithm: SGDW over the distributions drawn from the dataset sampler.
dist_draw_sgdw = sgdw.DistributionDrawSGDW(
    distr_sampler=distr_sampler,
    step_scheduler=step_scheduler(**params),
    batch_size=BATCH_SIZE,
    max_iter=MAX_ITER,
    tol=TOL,
    wass_dist_every=WASS_DIST_EVERY,
    # callback=print_x,
)

# First wrapper: report to the notebook logger every REPORT_EVERY iterations.
dist_draw_sgdw = wrappers.ReportProxy(
    dist_draw_sgdw,
    report_every=REPORT_EVERY,
    log=_log,
)

# Second wrapper: project with `proj` every PROJ_EVERY iterations, using the
# "geodesic" interpolation strategy and its own step schedule.
dist_draw_sgdw = wrappers.SGDWProjectedDecorator(
    dist_draw_sgdw,
    projector=proj,
    project_every=PROJ_EVERY,
    interp_strategy="geodesic",
    # interp_strategy="linear",
    interp_step_schd=step_scheduler(**params_interp),
)
dist_draw_sgdw  # last expression: display the wrapped algorithm

Creamos, mediante `plot_comparison`, una función graficadora para comparar las imágenes de las muestras con las iteraciones del algoritmo.

In [None]:
# Install the comparison plotter on the algorithm (this replaces its
# callback). plot_every must be at least n_rows * n_cols = 8, otherwise
# plot_comparison raises ValueError.
plot_fn = plot_comparison(
    dist_draw_sgdw,
    plot_every=8,  # PLOT_EVERY,
    n_cols=8,
    n_rows=1,
    cmap="binary_r",
)

In [None]:
_log.info(f"Running SGD-Wasserstein with '{DS_NAME}' DS barycenter")
bar_ds = dist_draw_sgdw.run()

In [None]:
# NOTE(review): at this point dist_draw_sgdw is already wrapped in
# SGDWProjectedDecorator, carries the callback installed by plot_comparison
# and has already been run in the previous cell. Confirm that wrapping it
# again with a projecting plotter and re-running it is intended.
plotter_comp = plotter_comp_ds = plotters.PlotterComparisonProjected(
    dist_draw_sgdw,
    projector=proj,
    proj_every=PROJ_EVERY,
    n_cols=8,
    n_rows=2,
    cmap="binary_r",
    plot_every=PLOT_EVERY,
    proj_kwargs=dict(
        interp_strategy="geodesic",
        # interp_strategy="linear",
        interp_step_schd=step_scheduler(**params_interp),
    ),
)
plotter_comp.sgdw  # last expression: display the wrapped algorithm

In [None]:
_log.info(f"Running SGD-Wasserstein with '{DS_NAME}' DS barycenter")
bar_ds = plotter_comp.run()

In [None]:
plotter_comp.sgdw

Obtenemos una visualización de las primeras imágenes.

In [None]:
fig, _ = plotter_comp.plot(0)
save_fig(fig, "first-iters-DS")

Obtenemos una visualización de las últimas imágenes.

In [None]:
fig, _ = plotter_comp.plot()
save_fig(fig, "last-iters-DS")

In [None]:
import bwb.utils.plotters as utils_plotters

# Plot the barycenter obtained from the dataset.
fig, _ = utils_plotters.plot_draw(bar_ds, title="Baricentro Proyectado del DS")
# Saving is disabled. The name was "bar-GAN" (copied from the GAN section);
# a DS-specific name avoids overwriting the GAN figure if this is re-enabled.
# save_fig(fig, "bar-DS")

# Baricentro de la GAN

## Definir el algoritmo

In [None]:
# Sampler that draws distributions from the generator G instead of the
# dataset. NOTE: this rebinds `distr_sampler`, which held the dataset sampler
# until now, so earlier cells re-run after this point will use the GAN one.
distr_sampler = GeneratorDistribSampler()
distr_sampler.fit(
    generator=G,
    noise_sampler=noise_sampler,
    transform_out=transform_out_,
)
distr_sampler.draw()  # last expression: display one drawn sample

In [None]:
# Base algorithm, now fed by the GAN sampler. This rebinds `dist_draw_sgdw`,
# which held the dataset algorithm until now.
dist_draw_sgdw = sgdw.DistributionDrawSGDW(
    distr_sampler=distr_sampler,
    step_scheduler=step_scheduler(**params),
    batch_size=BATCH_SIZE,
    max_iter=MAX_ITER,
    tol=TOL,
    wass_dist_every=WASS_DIST_EVERY,
)

# Wrapper exposing three lists through get_lists(). They are fetched before
# the run; presumably the proxy fills them in place while running - verify
# against wrappers.LogWassDistProxy.
dist_draw_sgdw = wrappers.LogWassDistProxy(dist_draw_sgdw)
iterations, wass_dist_list, wass_dist_smoothed_list = dist_draw_sgdw.get_lists()

# Report to the notebook logger every REPORT_EVERY iterations.
dist_draw_sgdw = wrappers.ReportProxy(
    dist_draw_sgdw, report_every=REPORT_EVERY, log=_log
)
dist_draw_sgdw  # last expression: display the wrapped algorithm

Definimos una clase para comparar las imágenes de las muestras con la iteración del algoritmo.

In [None]:
# Projecting plotter for the GAN run. Unlike the dataset section, the
# algorithm passed here has not been wrapped in SGDWProjectedDecorator
# beforehand; the projection settings are given through proj_kwargs.
# NOTE: this rebinds `plotter_comp`; the dataset plotter stays reachable as
# `plotter_comp_ds`.
plotter_comp = plotters.PlotterComparisonProjected(
    dist_draw_sgdw,
    projector=proj,
    proj_every=PROJ_EVERY,
    n_cols=8,
    n_rows=2,
    cmap="binary_r",
    plot_every=PLOT_EVERY,
    proj_kwargs=dict(
        interp_strategy="geodesic",
        # interp_strategy="linear",
        interp_step_schd=step_scheduler(**params_interp),
    ),
)
plotter_comp.sgdw  # last expression: display the wrapped algorithm

### Celda para correr algoritmo

In [None]:
_log.info(f"Running SGD-Wasserstein with '{DS_NAME}' GAN barycenter")
bar_gan = plotter_comp.run()
plotter_comp.sgdw

Obtenemos una visualización de las primeras imágenes.

In [None]:
fig, _ = plotter_comp.plot(0)
save_fig(fig, "first-iters-GAN")

Obtenemos una visualización de las últimas imágenes.

In [None]:
fig, _ = plotter_comp.plot()
save_fig(fig, "last-iters-GAN")

In [None]:
import bwb.utils.plotters as utils_plotters

fig, _ = utils_plotters.plot_draw(
    bar_gan, title="Baricentro Proyectado de la GAN"
)
# save_fig(fig, "bar-GAN")

In [None]:
# Compare the two barycenters (dataset vs. GAN).
# NOTE(review): `wass_distance` is never imported explicitly in this
# notebook; presumably it comes from `from bwb.distributions import *` -
# verify, and prefer an explicit import.
wass_distance(bar_ds, bar_gan)