Notebook que sirve de ejemplo para calcular el baricentro de un conjunto de datos de imágenes.

# Configuraciones iniciales

## Constantes

In [1]:
NOTEBOOK = 70  # Notebook number; used in the names of the figures directory and of the log file
CLEAN_LOGS = True  # If True, delete this notebook's previous log files before running
SAVE_FIGS = True  # If True, `save_fig` writes the figures to disk

REPORT_EVERY = 100  # Passed as `report_every` to the report proxy that writes to the logger
PLOT_EVERY = 100  # NOTE(review): not referenced anywhere in the visible cells
MAX_ITER = 1_000  # Max number of iterations for the SGDW
BATCH_SIZE = 1  # Default `batch_size` of `experiment`
PROJ_EVERY = 1  # Default `proj_every` of `experiment`

# MAX_ITER = 50; REPORT_EVERY = 5  # Uncomment for a short debugging run

In [2]:
import random

import numpy as np
import torch

SEED = 42

# Seed every random number generator used by the notebook with the same value
# (PyTorch CPU, PyTorch CUDA, NumPy's legacy global generator, and `random`).
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    np.random.seed,
    random.seed,
):
    _seed_fn(SEED)

# Make cuDNN pick deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [3]:
from pathlib import Path
from icecream import ic


DS_NAME = "data"

# Resolve every directory used by the notebook from the working directory.
CURR_PATH = Path.cwd()
BASE_PATH = CURR_PATH.parent.parent
DATA_PATH = BASE_PATH / "data"
WGAN_PATH = BASE_PATH / "wgan_gp"
NETS_PATH = WGAN_PATH / "networks"
IMGS_PATH = CURR_PATH / "imgs" / f"notebook-{NOTEBOOK:02d}"

# The figures directory is the only one created here.
IMGS_PATH.mkdir(parents=True, exist_ok=True)

# Echo the resolved paths, in the order they were derived.
ic(CURR_PATH)
ic(BASE_PATH)
ic(DATA_PATH)
ic(WGAN_PATH)
ic(NETS_PATH)
ic(IMGS_PATH)

In [4]:
def save_fig(fig_, name_to_save: str) -> None:
    """
    Write ``fig_`` into ``IMGS_PATH`` as both a PDF and a PNG file.

    The file stem is ``name_to_save``; the extension is set with
    ``Path.with_suffix``. Nothing is written when ``SAVE_FIGS`` is false.
    """
    if not SAVE_FIGS:
        return

    target = IMGS_PATH / name_to_save
    for suffix in (".pdf", ".png"):
        fig_.savefig(target.with_suffix(suffix), bbox_inches="tight")

## Importaciones generales

In [5]:
from icecream import ic
from bwb.sgdw import sgdw
from bwb.sgdw import wrappers
from bwb.sgdw import plotters
from bwb.distributions import *

## Configuraciones 

In [6]:
from bwb.config import conf

# Global numeric configuration of the `bwb` library for this notebook.
conf.use_gpu()
conf.use_single_precision()
# NOTE(review): 1e-16 is far below float32 machine epsilon (~1.2e-7); confirm
# this eps is intended together with single precision.
conf.set_eps(1e-16)
# Last expression of the cell: displays the resulting configuration.
conf

## Configuración del Logger

In [7]:
import time
from pathlib import Path


# Build the timestamped path of this run's log file
LOG_PATH = Path("logs") / f"notebook-{NOTEBOOK:02d}_{time.strftime('%Y%m%d_%H%M%S')}.log"

# Create the logs directory; `exist_ok` avoids the check-then-create race and
# matches how the figures directory (IMGS_PATH) is created.
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)

# Clean the logs left by previous runs of this notebook
if CLEAN_LOGS:
    for log_file in LOG_PATH.parent.glob(f"notebook-{NOTEBOOK:02d}*.log"):
        log_file.unlink()

In [8]:
import logging
from bwb.logging_ import log_config


# Start from a clean slate: drop every handler registered so far
log_config.remove_all_handlers()
ic(log_config.loggers)

# Send the records to this run's log file
fh = logging.FileHandler(LOG_PATH)
log_config.add_handler(fh)

_log = log_config.get_logger("notebook")

# Use DEBUG level for the notebook logger and for the SGDW modules
for _logger_name in ("notebook", "bwb.sgdw.sgdw", "bwb.sgdw.plotters"):
    log_config.set_level(level=logging.DEBUG, name=_logger_name)

Esta celda es para configurar la información mostrada en el logger

In [9]:
# Set the default options for the report by replacing the class-level
# `INCLUDE_OPTIONS` attribute of `ReportProxy`.
# Each flag presumably toggles one field of the logged report line
# (TODO confirm the exact meaning of each one in `bwb.sgdw.wrappers`).
wrappers.ReportProxy.INCLUDE_OPTIONS = wrappers.ReportOptions(
    dt=False,
    dt_per_iter=True,
    iter=True,
    step_schd=True,
    total_time=True,
    w_dist=False
)

## Obtención del dataset

In [10]:
# You can use the wrapper to transform the usual DataSet into a model set
from bwb.distributions.models import ModelDataset
import quick_torch as qt
import torchvision.transforms.v2 as T

# Default per-sample preprocessing used by `get_ds`: resize to 32x32, convert
# to an image tensor of dtype `conf.dtype` (rescaling the values), and squeeze
# the singleton dimensions.
transform_ds = T.Compose([
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    T.Lambda(lambda x: x.squeeze()),
])

def get_ds(file_path: Path, transform=transform_ds) -> ModelDataset:
    """
    Build a ``ModelDataset`` from the drawings stored in ``file_path``.

    A ``QuickDraw`` dataset restricted to the FACE category is created
    (downloading it into ``DATA_PATH`` if needed) and its data is then
    replaced by the array loaded from ``file_path``, reshaped to
    ``(-1, 28, 28)``. All targets are set to one. The train split of that
    dataset is wrapped in a ``ModelDataset`` and returned.

    ``transform`` is forwarded to ``QuickDraw`` as its ``transform`` argument.
    """
    ic(file_path)
    dataset_ = qt.QuickDraw(
        root=DATA_PATH,
        categories=[qt.Category.FACE],
        transform=transform,
        download=True,
        recognized=True,
    )

    # Swap the downloaded drawings for the ones stored on disk
    dataset_.data = np.load(Path(file_path)).reshape(-1, 28, 28)
    n_samples = len(dataset_.data)
    dataset_.targets = np.ones(n_samples, dtype=int)

    dataset = dataset_.get_train_data()
    ic(len(dataset))

    return ModelDataset(dataset)

# Load the cleaned dataset named `DS_NAME` as a set of models
DS_PATH = WGAN_PATH / "dataset" / "cleaned" / f"{DS_NAME}.npy"
ds_models = get_ds(DS_PATH)

# Last expression of the cell: displays the first model of the set
ds_models.get(0)

In [11]:
ds_models

## Obtener GAN

De la misma manera, se puede definir un muestreador de distribuciones utilizando una GAN. Para ello, empezamos definiendo las redes neuronales a utilizar

In [12]:
from wgan_gp.wgan_gp_vae.model_resnet import Generator, Encoder, LatentDistribution
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


# Device selected through the `bwb` configuration
device = conf.device

# Hyper-parameters of the networks; they also determine the checkpoint
# directory name below.
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]  # NOTE(review): not used in the visible cells

# Sampler of latent vectors fed to the generator
noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)


# Generator and encoder, moved to the selected device
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)

# NOTE(review): `DS_NAME` is already assigned the same value in the paths cell
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"
ic(FACE_PATH)

# Load the weights stored under FACE_PATH into both networks
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

# Switch both networks to evaluation mode; the bare `print()` only emits an
# empty line.
G.eval(); E.eval()
print()

In [13]:
noise_sampler

In [14]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Draw one latent vector and generate one sample with the generator.
# NOTE(review): this runs with autograd enabled; consider `torch.no_grad()`
# if gradients are not needed here.
z = noise_sampler(1)
m = G(z)

# Rescales the input by its maximum, round-trips it through a PIL image
# resized to 32, and normalizes it with mean 0.5 and std 0.5.
# NOTE(review): `transform_in` is not used in the visible cells.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(conf.dtype),
    T.Normalize((0.5,), (0.5,)),
])

# Turns a generator output into non-negative weights that sum to one:
# cast to `conf.dtype`, squeeze, shift so the minimum is 0, divide by the sum.
transform_out_ = T.Compose([
    T.ToDtype(conf.dtype),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
])

# Same as `transform_out_`, followed by wrapping the weights in a
# `DistributionDraw`.
transform_out = T.Compose([
    transform_out_,
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])


# Convert the generated sample and display it as the cell output
out: DistributionDraw = transform_out(m)
print(out.dtype)
out

## Definir Proyector

In [15]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms.v2 as T

# Pre-processing applied before projecting: pdf -> normalized grayscale image
transform_in_proj = T.Compose([
    # From pdf to grayscale: rescale so that the maximum value is 1
    T.Lambda(lambda x: x / x.max()),
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToImage(),
    T.ToDtype(conf.dtype, scale=True),
    # Single channel: mean 0.5 and std 0.5
    T.Normalize([0.5], [0.5]),
])

# Post-processing applied after projecting: image -> weights that sum to one
transform_out_proj = T.Compose([
    # Ensure the range is in [0, 1]
    T.Lambda(lambda x: x - x.min()),
    T.Lambda(lambda x: x / x.max()),
    # Normalize the total mass to one and drop the leading dimension
    T.Lambda(lambda x: x / x.sum()),
    T.Lambda(lambda x: x.squeeze(0)),
])

_proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

def proj(input_: torch.Tensor) -> torch.Tensor:
    """
    Project ``input_`` with ``_proj`` and return the result converted with
    ``.to(input_)``, i.e. with the same dtype and device as ``input_``.
    """
    projected = _proj(input_)
    return projected.to(input_)

## Definir $\gamma_k$

Aquí se utiliza una función de la forma
\begin{equation*}
    \gamma_k = \frac{a}{(b^{1/c} + k)^c}
\end{equation*}

Con $a > 0$, $b \geq 0$ y $0.5 < c \leq 1$

La idea es que, cuando $k = 0$, se tiene $\gamma_0 = \frac{a}{b}$, es decir, la proporción entre $a$ y $b$, lo que permite ajustar el valor inicial.

In [16]:
from bwb.sgdw.utils import step_scheduler
window = 5

def test_gamma(gamma_) -> None:

    for t in range(window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 50
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 100
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 300
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 500
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 1_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 3_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()

    init = 5_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma_(t) = :.2%}")
    print()


# Parameters of the step schedule. Assuming `step_scheduler` implements the
# formula described in the markdown above: gamma_0 = a / b is just below 1
# (b is slightly larger than a), and c sits just above the 0.5 lower bound.
_a = 3
_eps = 1e-3
params = dict(a=_a, b=_a+1e-2, c=0.5+_eps)
# Alternative parameters:
# params = dict(a=1, b=1, c=1)

# Schedule used as the default `gamma_` of `experiment`
gamma = step_scheduler(**params)
# Alternative: constant step
# gamma = 0.1

# Print the schedule over several windows of iterations (a second scheduler
# is built from the same parameters for this check)
test_gamma(step_scheduler(**params))

# Experimentos

In [17]:
import bwb.utils.plotters as plotters_


def experiment(
    max_iter=MAX_ITER,
    gamma_=gamma,
    batch_size=BATCH_SIZE,
    proj_every=PROJ_EVERY,
    plot_every=None,
    col_row=(8, 2),
    register_iter=False,
):
    """
    Run one projected SGDW experiment over ``ds_models``.

    A uniform sampler is fitted on ``ds_models`` and used by a
    ``DebiesedDistributionDrawSGDW`` instance, which is wrapped in a
    ``ReportProxy`` (reporting every ``REPORT_EVERY`` iterations to ``_log``)
    and driven by a ``PlotterComparisonProjected`` that uses ``proj`` as
    projector.

    :param max_iter: forwarded to the SGDW as ``max_iter``.
    :param gamma_: forwarded to the SGDW as ``step_scheduler``.
    :param batch_size: forwarded to the SGDW as ``batch_size``.
    :param proj_every: forwarded to the plotter as ``proj_every``.
    :param plot_every: forwarded to the plotter as ``plot_every``.
    :param col_row: ``(n_cols, n_rows)`` of the plotter's grid.
    :param register_iter: if true, the plotter's SGDW is additionally wrapped
        in a ``LogDistrIterProxy`` before running.
    :return: tuple ``(plotter_comp, bar)``, where ``bar`` is the value
        returned by ``plotter_comp.run()``.
    """
    # Def sampler
    distr_sampler = UniformDiscreteSampler[DistributionDraw]().fit(models=ds_models)

    # Def sgdw
    dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
        distr_sampler=distr_sampler,
        step_scheduler=gamma_,
        batch_size=batch_size,
        max_iter=max_iter,
    )
    dist_draw_sgdw = wrappers.ReportProxy(
        dist_draw_sgdw,
        report_every=REPORT_EVERY,
        log=_log
    )

    # Def plotter (the unused `plotter_comp_ds` alias was removed)
    plotter_comp = plotters.PlotterComparisonProjected(
        dist_draw_sgdw,
        projector=proj,
        proj_every=proj_every,
        n_cols=col_row[0],
        n_rows=col_row[1],
        cmap="binary_r",
        plot_every=plot_every,
    )

    if register_iter:
        plotter_comp.sgdw = wrappers.LogDistrIterProxy(plotter_comp.sgdw)

    bar = plotter_comp.run()

    return plotter_comp, bar

Experimento: proyectar cada cierto número de iteraciones

In [18]:
# Run the experiment once per projection period `pe` and save three figures
# for each run: the first iterations, the last iterations and the result.
for pe in [1, 3, 5, 10]:
    plotter_comp, bar = experiment(
        gamma_=gamma,
        batch_size=1,
        proj_every=pe,
        col_row=(8, 2)
    )

    # The positional `0` presumably selects the iteration where the plotted
    # window starts -- TODO confirm in `bwb.sgdw.plotters`.
    fig, _ = plotter_comp.plot(
        0, title=f"Primeras iteraciones SGDWP, proyectando cada {pe}"
    )
    save_fig(fig, f"first-iters-pe-{pe}")

    # Without a position, the title indicates that the last iterations are shown
    fig, _ = plotter_comp.plot(
        title=f"Últimas iteraciones SGDWP, proyectando cada {pe}"
    )
    save_fig(fig, f"last-iters-pe-{pe}")

    # Plot the value returned by the run
    fig, _ = plotters_.plot_draw(bar, title=f"Baricentro usando SGDWP, proyectando cada {pe}")
    save_fig(fig, f"bar-SGDWP-pe-{pe}")

Experimento: repetir la ejecución varias veces

In [19]:
# Repeat the same experiment ten times; `i` only labels the figure of each
# run. The figures are never closed here.
for i in range(1, 11):
    plotter_comp, bar = experiment(
        gamma_=gamma,
        batch_size=1,
        proj_every=1,
        col_row=(8, 2)
    )

    # Plot and save the result of this run
    fig, _ = plotters_.plot_draw(bar, title=f"Baricentro usando SGDWP, experimento {i}")
    save_fig(fig, f"bar-SGDWP-exp-{i}")

Experimento paseo aleatorio

In [20]:

# Run with a constant step of 0.1 and register the iterates, so that they can
# be plotted afterwards.
plotter_comp, bar = experiment(
    max_iter=50,
    gamma_=0.1,
    batch_size=1,
    proj_every=1,
    col_row=(8, 2),
    register_iter=True,
)

sgdw_ = plotter_comp.sgdw

# NOTE(review): the loop starts at iteration 500 while `max_iter=50` above;
# unless `register_lst` holds more than 500 entries nothing is plotted.
# Confirm which of the two values is the intended one.
for i in range(500, len(sgdw_.register_lst), 100):
    d = sgdw_[i]
    # Bind the new figure: previously the result of `plot_draw` was discarded
    # and `save_fig` re-saved the stale `fig` left by an earlier cell.
    fig, _ = plotters_.plot_draw(d, title=f"Baricentro SGDWP en la iteración {i}")
    save_fig(fig, f"bar-SGDWP-random-walk-iter-{i}")