# Constantes

In [None]:
# Whether to clean the logs directory
CLEAN_LOGS = True

# Set to True for a quick smoke run (few iterations, frequent reports).
DEBUG = False

S_k = 25  # batch size passed to the SGDW solvers
PROJ_EVERY = 3  # passed as `proj_every` to the projected SGDW runs
REPORT_EVERY = 100
MAX_ITER = 5_000
MAX_IMGS = 22 * 4  # maximum number of images drawn per plot

if DEBUG:
    # Shrink the runs so every cell finishes quickly while debugging.
    MAX_ITER = MAX_IMGS
    REPORT_EVERY = 5

# Derived after the optional debug override so both iteration budgets agree.
MAX_ITER_PROJ = MAX_ITER

# Importaciones Generales

In [None]:
# NOTE(review): hard-coded, machine-specific project root; adjust it (or the
# commented Windows variant below) before running on another machine.
%cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters
# %cd D:\CodeProjects\Python\Bayesian-Learning-with-Wasserstein-Barycenters\

In [None]:
from bwb import sgdw, utils
from bwb.distributions import *
from bwb.transports import *

In [None]:
# Display the SGDW solver class used throughout this notebook.
sgdw.DebiesedDistributionDrawSGDW

## Configuración del Logger

In [None]:
import time
from pathlib import Path


# Timestamped log file for this notebook session.
LOG_PATH = Path("logs") / f"notebook_{time.strftime('%Y%m%d_%H%M%S')}.log"
# Idempotent: no check-then-create race, and parents are created if missing.
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)

# Remove old log files from the same directory LOG_PATH lives in. This runs
# before the FileHandler for LOG_PATH is created, so the new log is kept.
if CLEAN_LOGS:
    for log_file in LOG_PATH.parent.glob("*.log"):
        log_file.unlink(missing_ok=True)

In [None]:
import logging
from bwb._logging import log_config


# Drop any previously attached handlers (presumably so re-running this cell
# does not attach a second file handler).
log_config.remove_all_handlers()

# Write records to the timestamped file, using the project's default formatter.
fh = logging.FileHandler(LOG_PATH)
log_config.set_default_formatter(fh)
log_config.add_handler(fh)


_log = log_config.get_logger("notebook")

# Per-logger verbosity: DEBUG for the notebook and the SGDW solver, INFO elsewhere.
for _logger_name, _logger_level in {
    "notebook": logging.DEBUG,
    "bwb.utils": logging.INFO,
    "bwb.sgdw": logging.DEBUG,
    "bwb.transports": logging.INFO,
}.items():
    log_config.set_level(level=_logger_level, name=_logger_name)
del _logger_name, _logger_level

In [None]:
# Fields included in the SGDW progress report: per-iteration time, iteration
# number, step schedule and total time are shown; `dt` and `w_dist` are hidden.
INCLUDE_OPTIONS: sgdw.ReportOptions = dict(
    dt=False,
    dt_per_iter=True,
    iter=True,
    step_schd=True,
    total_time=True,
    w_dist=False,
)

# Install them as the class-wide default for every report.
sgdw.Report.INCLUDE_OPTIONS = INCLUDE_OPTIONS

# Importar las redes neuronales

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import Generator, Encoder, LatentDistribution
import torch
from wgan_gp.wgan_gp_vae.utils import load_checkpoint


# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Latent-space / architecture settings (presumably must match the checkpoints).
NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]  # NOTE(review): not used in the visible cells

noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)

CURR_PATH = Path(".")
NETS_PATH = CURR_PATH / "wgan_gp" / "networks"

DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

# Build the generator/encoder pair and restore the trained weights.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

G.eval()
E.eval()
print()  # last expression of the cell, so E's repr is not displayed

In [None]:
# Display the latent noise sampler for inspection.
noise_sampler

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning
disable_beta_transforms_warning()

import torchvision.transforms.v2 as T


# Decode one latent sample. This is pure inference: `G.eval()` does not turn
# off gradient tracking, so without `no_grad` the output would typically carry
# an autograd graph (and require grad) for no reason.
z = noise_sampler(1)
with torch.no_grad():
    m = G(z)

# Image -> 32x32 float32 tensor rescaled to [-1, 1] (assumes a nonnegative
# input, so that x / max(x) lies in [0, 1]).
# NOTE(review): `transform_in` is not used in the visible cells.
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImage(),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.5,), (0.5,)),
])

# Generator output -> DistributionDraw: cast to float64, drop singleton
# dimensions, shift so the minimum is 0, then normalize to unit total mass.
# NOTE(review): a constant image has zero mass after the shift and yields NaNs.
transform_out = T.Compose([
    T.ToDtype(torch.float64),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])

out: DistributionDraw = transform_out(m)
print(out.dtype)
out

# Proyector

In [None]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
# NOTE: this rebinds `T` from torchvision.transforms.v2 (previous cell) to the
# v1 API; every later cell that uses `T` gets the v1 transforms.
import torchvision.transforms as T

# pdf weights -> one-channel 32x32 tensor in [-1, 1] (presumably the encoder's input).
transform_in_proj = T.Compose([
    T.Lambda(lambda pdf: pdf / torch.max(pdf)),  # pdf -> grayscale with peak 1
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),  # single channel: [0, 1] -> [-1, 1]
])

# Decoded image -> nonnegative weights with unit total mass, channel dim dropped.
transform_out_proj = T.Compose([
    T.Lambda(lambda img: img - torch.min(img)),
    T.Lambda(lambda img: img / torch.max(img)),
    T.Lambda(lambda img: img / torch.sum(img)),
    T.Lambda(lambda img: img.squeeze(0)),
])

proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

# Obtener Dataset

In [None]:

import bwb.distributions as dist

class DatasetWrapper:
    """Expose an indexable dataset as a collection of ``DistributionDraw`` objects.

    Only element 0 of every dataset item (the image) is used; the remaining
    elements (e.g. the label) are ignored.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def get(self, i: int, **kwargs) -> dist.DistributionDraw:
        """Return item ``i`` as a ``DistributionDraw``; ``kwargs`` are ignored."""
        image = self.dataset[i][0]
        return dist.DistributionDraw.from_grayscale_weights(image)

In [None]:
import quick_torch as qt

def get_ds(file_path, transform):
    """Build a face dataset whose drawings come from a cleaned ``.npy`` file.

    A QuickDraw FACE dataset is instantiated (``download=True``) to reuse its
    ``transform`` / ``get_train_data`` machinery. Its ``data`` and ``targets``
    are then replaced by the drawings loaded from ``file_path``.

    :param file_path: path to a ``.npy`` array that can be reshaped to
        ``(-1, 28, 28)``.
    :param transform: transform handed to the QuickDraw dataset.
    :return: a ``DatasetWrapper`` around ``get_train_data()`` of that dataset.
    """
    # Local import: `np` is not imported explicitly anywhere in this notebook.
    import numpy as np

    dataset_ = qt.QuickDraw(
        root="dataset",
        categories=[qt.Category.FACE],
        transform=transform,
        download=True,
        recognized=True,
    )
    print(len(dataset_))  # size of the QuickDraw dataset before its data is replaced
    dataset_.data = np.load(Path(file_path)).reshape(-1, 28, 28)
    # Every loaded drawing gets the same target label (1).
    dataset_.targets = np.ones(len(dataset_.data), dtype=int)

    return DatasetWrapper(dataset_.get_train_data())

# Raw drawings -> 32x32 tensors with singleton dimensions squeezed out.
transform_ds = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
    T.Lambda(lambda x: x.squeeze()),
])

# Definir $\gamma_k$

Aquí se utiliza una función de la forma
\begin{equation*}
    \gamma_k = \frac{a}{(b^{1/c} + k)^c}
\end{equation*}

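Con $a > 0$, $b > 0$ y $0.5 < c \leq 1$. Como $\gamma_k \sim a\,k^{-c}$ para $k$ grande, esta condición sobre $c$ garantiza $\sum_k \gamma_k = \infty$ y $\sum_k \gamma_k^2 < \infty$ (condiciones de Robbins–Monro).

La idea es que, cuando $k=0$, $\gamma_0 = \frac{a}{(b^{1/c})^c} = \frac{a}{b}$, es decir, la proporción entre $a$ y $b$, lo que permite ajustar el valor inicial del paso.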

In [None]:
from bwb.sgdw import Gamma
window = 5

def test_gamma(gamma):

    for t in range(window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 50
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 100
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 300
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 500
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 1_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 3_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    init = 5_000
    for t in range(init, init+window):
        print(f"{t = :_}; {gamma(t) = :.2%}")
    print()

    # init = 10_000
    # for t in range(init, init+window):
    #     print(f"{t = :_}; {gamma(t) = :.2%}")
    # print()

    # init = 20_000
    # for t in range(init, init+window):
    #     print(f"{t = :_}; {gamma(t) = :.2%}")
    # print()

    # init = 50_000
    # for t in range(init, init+window):
    #     print(f"{t = :_}; {gamma(t) = :.2%}")
    # print()

    # init = 100_000
    # for t in range(init, init+window):
    #     print(f"{t = :_}; {gamma(t) = :.2%}")


# Step-size parameters: b is slightly above a, so gamma_0 = a / b is just below 1,
# and c = 0.5 + 0.2 = 0.7 lies inside the admissible range (0.5, 1].
_a = 3
_eps = 0.2
params = {"a": _a, "b": _a + 1e-2, "c": 0.5 + _eps}
# params = {"a": 1, "b": 1, "c": 1}

gamma = Gamma(**params)

test_gamma(Gamma(**params))

# Variedad de caritas 1: Caras normales

## Baricentro de la red

### Versión sin proyectar

In [None]:
DS_NAME = "data"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

# Fresh generator/encoder pair restored from the checkpoints trained on DS_NAME.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)
G.eval()
E.eval()

# Sampler fitted on G, the latent noise sampler and `transform_out`.
pi_n: GeneratorDistribSampler[DistributionDraw] = GeneratorDistribSampler()
pi_n.fit(generator=G, noise_sampler=noise_sampler, transform_out=transform_out)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)


In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión proyectada

#### Iteración 1

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj. Iter 1")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 2

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj. Iter 2")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 3

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj. Iter 3")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 4

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj. Iter 4")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 5

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj. Iter 5")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

## Baricentro de las imágenes

### Versión sin proyectar

In [None]:
DS_PATH = Path("./wgan_gp/dataset") / "cleaned" / f"{DS_NAME}.npy"
models = get_ds(DS_PATH, transform_ds)
models.get(1)

In [None]:
pi_n: UniformDiscreteSampler[DistributionDraw] = UniformDiscreteSampler().fit(models=models)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión Proyectada

#### Iteración 1

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj. Iter 1")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 2

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj. Iter 2")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 3

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj. Iter 3")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 4

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj. Iter 4")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

#### Iteración 5

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj. Iter 5")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

---

¡Hasta aquí termina!

# Variedad de caritas 2: Caras sin contorno

## Baricentro de la red

### Versión sin proyectar

In [None]:
DS_NAME = "data_sin_contorno"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

# Fresh generator/encoder pair restored from the checkpoints trained on DS_NAME.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)
G.eval()
E.eval()

# Sampler fitted on G, the latent noise sampler and `transform_out`.
pi_n: GeneratorDistribSampler[DistributionDraw] = GeneratorDistribSampler()
pi_n.fit(generator=G, noise_sampler=noise_sampler, transform_out=transform_out)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión proyectada

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

## Baricentro de las imágenes

### Versión sin proyectar

In [None]:
DS_PATH = Path("./wgan_gp/dataset") / "cleaned" / f"{DS_NAME}.npy"
models = get_ds(DS_PATH, transform_ds)
models.get(1)

In [None]:
pi_n: UniformDiscreteSampler[DistributionDraw] = UniformDiscreteSampler().fit(models=models)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión Proyectada

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

---

¡Hasta aquí termina!

# Variedad de caritas 3: Caras sin contorno arriba

## Baricentro de la red

### Versión sin proyectar

In [None]:
DS_NAME = "data_sin_contorno_arriba"
FACE_PATH = NETS_PATH / f"cleaned_{DS_NAME}_zDim{LATENT_DIM}_{NOISE}_bs_128"

# Fresh generator/encoder pair restored from the checkpoints trained on DS_NAME.
G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)
G.eval()
E.eval()

# Sampler fitted on G, the latent noise sampler and `transform_out`.
pi_n: GeneratorDistribSampler[DistributionDraw] = GeneratorDistribSampler()
pi_n.fit(generator=G, noise_sampler=noise_sampler, transform_out=transform_out)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión proyectada

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} GAN bar. with proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

## Baricentro de las imágenes

### Versión sin proyectar

In [None]:
DS_PATH = Path("./wgan_gp/dataset") / "cleaned" / f"{DS_NAME}.npy"
models = get_ds(DS_PATH, transform_ds)
models.get(1)

In [None]:
pi_n: UniformDiscreteSampler[DistributionDraw] = UniformDiscreteSampler().fit(models=models)
pi_n.draw()

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. without proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

### Versión Proyectada

In [None]:
proj = ProjectorOnManifold(
    E, G, 
    transform_in=transform_in_proj,
    transform_out=transform_out_proj,
)

In [None]:
face = pi_n.draw()
face

In [None]:
DistributionDraw.from_grayscale_weights(proj(face.grayscale_weights))

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    projector=proj,
    proj_every=PROJ_EVERY,
    learning_rate=Gamma(**params),
    batch_size=S_k,
    max_iter=MAX_ITER_PROJ,
    report_every=REPORT_EVERY,
).set_geodesic_params(
    reg=1e-2,
    stopThr=1e-3,
)
dist_draw_sgdw.det_params, dist_draw_sgdw.hist

In [None]:
_log.info(f"Running SGD-Wasserstein with {DS_NAME} DS bar. with proj.")
bar, hist = dist_draw_sgdw.run(
    distr_hist=True,
    distr_samp_hist=True,
)
dist_draw_sgdw.iter_params

In [None]:
bar

In [None]:
utils.plot_list_of_draws(hist.distr, max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp], max_images=MAX_IMGS)

In [None]:
utils.plot_list_of_draws(hist.distr[-MAX_IMGS:], max_images=MAX_IMGS)
utils.plot_list_of_draws([x[0] for x in hist.distr_samp[-MAX_IMGS:]], max_images=MAX_IMGS)

---

¡Hasta aquí termina!