# Importaciones generales

In [None]:

%cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters
# %cd D:\CodeProjects\Python\Bayesian-Learning-with-Wasserstein-Barycenters\
# Soy una nueva linea

In [None]:
from torch import linalg as LA

from bwb import sgdw, utils
from bwb.distributions import *
from bwb.distributions.data_loaders import *
from bwb.geodesics import *
from bwb.transports import *

In [None]:
import logging
from bwb.logging import log_config
import time
from pathlib import Path

# Start from a clean slate: drop every handler registered so far.
log_config.remove_all_handlers()

# Send the log output to a timestamped file under ./logs.
LOG_PATH = Path("logs") / f"notebook_{time.strftime('%Y%m%d_%H%M%S')}.log"
# logging.FileHandler opens the file right away and does not create missing
# parent directories, so make sure the folder exists before building it.
LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
fh = logging.FileHandler(LOG_PATH)
log_config.set_default_formatter(fh)
log_config.add_handler(fh)

_log = log_config.get_logger("notebook")
print(_log.level)
# Per-logger verbosity: notebook and SGDW internals at DEBUG, the rest at INFO.
log_config.set_level(level=logging.DEBUG, name="notebook")
log_config.set_level(level=logging.INFO, name="bwb.utils")
log_config.set_level(level=logging.DEBUG, name="bwb.sgdw.sgdw")
log_config.set_level(level=logging.INFO, name="bwb.transports")

In [None]:
_log.level

# Cargar datos

In [None]:
CURRENT_PATH = Path(".")
DATA_PATH = CURRENT_PATH / Path("data")
FACE_PATH = DATA_PATH / Path("face.npy")
FACE_PATH

In [None]:
arr = np.load(FACE_PATH)
arr.shape

In [None]:
shape = (28, 28)
faces = DistributionDrawDataLoader(arr, shape)
faces

In [None]:
first_face = faces[0]
first_face

In [None]:
max_images = 36
faces_list = []
for k in range(max_images):
    faces_list.append(faces[k])

In [None]:
utils.plot_list_of_draws(faces_list, max_images=max_images)

## Construir Posterior

In [None]:
x = first_face.sample((3,))
x

In [None]:
from bwb.config import config

gen = torch.Generator(device=config.device).manual_seed(2147483647)
gen.initial_seed()
# gen.seed()
gen.initial_seed()

In [None]:
%%time

pi_n: ExplicitPosteriorSampler[DistributionDraw] = (
    ExplicitPosteriorSampler().fit(data=x, models=faces)
)
pi_n

# Calcular baricentros

## Definir $\gamma_k$

Aquí se utiliza una función de la forma
\begin{equation*}
    \gamma_k = \frac{a}{(b^{1/c} + k)^c}
\end{equation*}

Con $a > 0$, $b \geq 0$ y $0.5 < c \leq 1$

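La idea es que cuando $k=0$ (y $b > 0$), $\gamma_0 = \frac{a}{b}$ es la proporción entre $a$ y $b$, lo que permite ajustar el valor inicial. Si $b = 0$, $\gamma_0$ no está definido (división por cero) y la sucesión debe comenzar en $k = 1$.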

In [None]:
class Gamma:
    """Step-size schedule ``gamma_k = a / (b ** (1 / c) + k) ** c``.

    Instances are callables: ``Gamma(a, b, c)(k)`` evaluates the schedule at
    iteration ``k``. No validation is performed on the parameters, so e.g.
    ``b == 0`` together with ``k == 0`` raises ``ZeroDivisionError``.
    """

    def __init__(self, a=1, b=0, c=1):
        # Keep the three schedule parameters exactly as given.
        self.a, self.b, self.c = a, b, c

    def __call__(self, k):
        """Return the step size for iteration ``k``."""
        offset = self.b ** (1 / self.c)
        denominator = (offset + k) ** self.c
        return self.a / denominator

In [None]:
window = 3

a_ = 0.5
params = {"a": a_, "b": a_ + 0.1, "c": 0.51}
gamma = Gamma(**params)

# Print the first `window` step sizes of the schedule ...
for t in range(window):
    print(f"{t = }; {gamma(t) = :.8f}")

# ... and then `window` consecutive values starting at several later
# iterations, each group separated from the previous one by a blank line.
for init in (300, 500, 3_000, 10_000):
    print()
    for t in range(init, init + window):
        print(f"{t = }; {gamma(t) = :.8f}")

## Baricentro con 100 caras, 3 datos

In [None]:
faces = DistributionDrawDataLoader(arr[:100, :], shape)
faces

In [None]:
x = first_face.sample((3,))
x

In [None]:
pi_n: ExplicitPosteriorSampler[DistributionDraw] = (
    ExplicitPosteriorSampler().fit(data=x, models=faces)
)
pi_n

In [None]:
probs = pi_n.probabilities_
torch.round(
    pi_n.probabilities_[torch.round(probs, decimals=4) > 0] * 100, decimals=2
)

In [None]:
transport = EMDTransport(max_iter=250_000)

In [None]:
X_k, m, pos_hist, samples_hist = sgdw.sgdw.compute_bwb_discrete_distribution(
    transport=transport,
    distrib_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=1,
    alpha=0.01,
    tol=0.0,
    max_iter=500,
    max_time=0.5,
    position_history=True,
    distrib_sampler_history=True,
    report_every=20,
)

### Primeras iteraciones

In [None]:
utils.plot_list_of_draws(
    [DistributionDraw(m, shape, X) for X in pos_hist[:36]], max_images=36
)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=33)

### Últimas iteraciones

In [None]:
utils.plot_list_of_draws(
    [DistributionDraw(m, shape, X) for X in pos_hist[-36:]], max_images=36
)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist[-33:]], max_images=33)

## Baricentro usando la clase, con 100 caras y 3 datos

In [None]:
dd_sgdw = sgdw.DiscreteDistributionSGDW(
    transport=transport,
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    batch_size=1,
    alpha=0.01,
    tol=0.0,
    max_iter=500,
    max_time=10,
    report_every=20,
)

In [None]:
dd_sgdw.hist

In [None]:
# from bwb.sgdw.plotters import PlotterComparison

# plotter = PlotterComparison(dd_sgdw)

print(dd_sgdw.iter_params)
bar, hist = dd_sgdw.run(
    pos_wgt_hist=True,
    distr_samp_hist=True,
)
# out, out2 = dd_sgdw.run(pos_wgt_hist=True, distr_hist=True)
print(dd_sgdw.iter_params)
pos_wgt_hist, distr_samp_hist = hist.pos_wgt, hist.distr_samp

In [None]:
import matplotlib.pyplot as plt

row, col = 3, 12
factor = 1.5

fig, axes = plt.subplots(
    row,
    col,
    figsize=(col * factor, row * factor),
    subplot_kw={"xticks": [], "yticks": []},
)

fig.suptitle("TITULO")

for j in range(col):
    img: DistributionDraw = hist.distr_samp[j][0]
    ax1, ax2, ax3 = axes[0, j], axes[1, j], axes[2, j]
    # ax1.set_axis_off()
    if j == 0:
        ax1.set_ylabel("Abcde")
        ax2.set_ylabel("Fghi")
        ax3.set_ylabel("Jklm")

    ax1.imshow(img.image, cmap="binary")
    ax2.imshow(img.image, cmap="binary_r")
    ax3.imshow(img.image, cmap="binary_r")
    ax1.set_title(f"$k={j}$")
    ax3.set_title(f"$k={j}$")

plt.tight_layout(w_pad=0.1)

In [None]:
img.image

In [None]:
dd_sgdw.iter_params

In [None]:
DistributionDraw.from_discrete_distribution(bar, shape)

### Primeras iteraciones

In [None]:
utils.plot_list_of_draws(
    [DistributionDraw(m, shape, X) for (X, m) in pos_wgt_hist[:36]],
    max_images=36,
)

In [None]:
utils.plot_list_of_draws([x[0] for x in distr_samp_hist[:33]], max_images=33)

### Últimas iteraciones

In [None]:
utils.plot_list_of_draws(
    [DistributionDraw(m, shape, X) for (X, m) in pos_wgt_hist[-36:]],
    max_images=36,
)

In [None]:
utils.plot_list_of_draws([x[0] for x in distr_samp_hist[-33:]], max_images=33)

## Baricentro utilizando la GAN

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import (
    Generator,
    Encoder,
    LatentDistribution,
)
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NOISE = "norm"
LATENT_DIM = 128
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]

G = Generator(LATENT_DIM, CHANNELS_IMG, latent_distr=NOISE).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG).to(device)
noise_sampler = LatentDistribution(NOISE, LATENT_DIM, device)

CURR_PATH = Path(".")
NETS_PATH = CURR_PATH / "wgan_gp" / "networks"

FACE_PATH = NETS_PATH / f"cleaned_clustered_zDim128_norm_bs_128"
# FACE_PATH = NETS_PATH / f"_resnet_face_zDim{LATENT_DIM}_{NOISE}_bs_128_cleaned_augmented_WAE_WGAN_loss_l1_32p32"
# FACE_PATH = NETS_PATH / f"_resnet_face_zDim{LATENT_DIM}_{NOISE}_bs_128_cleaned_sin_contorno_augmented_WAE_WGAN_loss_l1_32p32"

from wgan_gp.wgan_gp_vae.utils import load_checkpoint

load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

G.eval()
E.eval()
print()

In [None]:
noise_sampler

In [None]:
# Allocate on the `device` chosen when the networks were built (CUDA if
# available, otherwise CPU) instead of hard-coding "cuda", so this cell also
# runs on machines without a GPU.
z_ = torch.zeros((1,), device=device)
# # z = G.sample_noise(100, type_as=z_)
# G_script = torch.jit.script(G)
# # G_script(z)
# G, G_ = torch.jit.freeze(G_script), G
# G
# Keep a second handle to the (unscripted) generator.
G_ = G

In [None]:
with torch.cuda.amp.autocast():
    z = noise_sampler(1)
    print(z.dtype)

In [None]:
from bwb.distributions import DistributionDraw
from torchvision import disable_beta_transforms_warning

disable_beta_transforms_warning()

import torchvision.transforms.v2 as T

with torch.cuda.amp.autocast():
    z = noise_sampler(1)
    m = G(z)
    print(m.dtype)
print(m.dtype)
transform_in = T.Compose([
    T.Lambda(lambda x: x / torch.max(x)),
    T.ToPILImage(),
    T.Resize(32),
    T.ToImageTensor(),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.5,), (0.5,)),
])
transform_out = T.Compose([
    T.ToDtype(torch.float64),
    T.Lambda(lambda x: x.squeeze()),
    T.Lambda(lambda x: x - torch.min(x)),
    T.Lambda(lambda x: x / torch.sum(x)),
    T.Lambda(lambda x: DistributionDraw.from_grayscale_weights(x)),
])
out: DistributionDraw = transform_out(m)
print(out.dtype)
out

In [None]:
z = noise_sampler(1)
for v in z:
    print(v.unsqueeze(0).shape)

In [None]:
pi_n: GeneratorDistribSampler[DistributionDraw] = GeneratorDistribSampler()
pi_n.fit(generator=G, noise_sampler=noise_sampler, transform_out=transform_out)
pi_n.draw()

In [None]:
pi_n.draw()

In [None]:
mu_k, dist_hist, samples_hist = sgdw.sgdw.compute_bwb_distribution_draw(
    distrib_sampler=pi_n,
    learning_rate=Gamma(**params),
    reg=3e-3,
    max_iter=100,
    max_time=1,
    distribution_history=True,
    distrib_sampler_history=True,
)

In [None]:
max_images = 22 * 4
utils.plot_list_of_draws(dist_hist, max_images=max_images)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=max_images)

In [None]:
utils.plot_list_of_draws(dist_hist[-max_images:], max_images=max_images)

In [None]:
utils.plot_list_of_draws(
    [x[0] for x in samples_hist[-max_images:]], max_images=max_images
)

## Baricentro utilizando la GAN, y utilizando la clase

In [None]:
dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    learning_rate=Gamma(**params),
    max_iter=500,
    # max_time=10,
    report_every=10,
)
dist_draw_sgdw.det_params

In [None]:
from bwb.sgdw.plotters import Plotter
from bwb.distributions import DistributionDraw
from bwb.sgdw.sgdw import BaseSGDW
from bwb.utils import _DistributionT
from bwb.sgdw.utils import _PosWgt


# noinspection PyShadowingNames
class PlotterComparison(Plotter[DistributionDraw, torch.Tensor]):
    """Plot, for a range of iterations, a sample above the matching SGDW step.

    For every plotted iteration ``k`` two stacked images are drawn: one built
    from the first element of ``hist.pos_wgt_samp[k]`` (row labelled
    "Sample") and one built from ``hist.pos_wgt[k]`` (row labelled "Step").

    NOTE(review): ``plot`` reads these histories from a module-level variable
    named ``hist`` rather than from ``self`` or ``self.sgdw``. That name must
    be bound (e.g. to the history object returned by ``run``) before ``plot``
    is called -- confirm this coupling is intended.
    """

    def __init__(
        self,
        sgdw: BaseSGDW[_DistributionT, _PosWgt],
        plot_every=50,
        n_cols=12,
        n_rows=1,
        factor=1.5,
        cmap="binary",
    ):
        """Forward all arguments to ``Plotter.__init__`` and validate the grid.

        :param sgdw: SGDW instance whose iterations are plotted.
        :param plot_every: Must be at least ``n_rows * n_cols``.
        :param n_cols: Number of iterations shown per row of the grid.
        :param n_rows: Number of sample/step row pairs in the grid.
        :param factor: Scale applied to each subplot when sizing the figure.
        :param cmap: Matplotlib colormap name passed to ``imshow``.
        :raises ValueError: If ``n_rows * n_cols > plot_every``.
        """
        super().__init__(sgdw, plot_every, n_cols, n_rows, factor, cmap)
        # The grid shows n_rows * n_cols consecutive iterations, so it cannot
        # be larger than the plotting period.
        if n_rows * n_cols > plot_every:
            msg = (
                "'plot_every' should not be less than n_rows * n_cols."
                f" Currently: {plot_every = } < {n_rows * n_cols = }"
            )
            raise ValueError(msg)
        # NOTE(review): these flags are set on the plotter itself; presumably
        # they ask the run to record the position/weight histories -- confirm
        # that the base class (or the SGDW object) actually reads them.
        self.pos_wgt_hist = True
        self.pos_wgt_samp_hist = True

    def plot(self, init: int = None):
        """Draw the sample/step comparison grid and return ``(fig, ax)``.

        :param init: First iteration index to plot. When ``None`` the grid
            ends at the current iteration ``self.sgdw.iter_params.k``.
        :return: The matplotlib figure and the array of axes.
        """
        create_distr = self.sgdw.create_distribution
        max_imgs = self.n_rows * self.n_cols
        max_k = self.sgdw.iter_params.k
        # Default window: the last `max_imgs` iterations up to `max_k`.
        init = max_k - max_imgs + 1 if init is None else init

        # Two subplot rows (sample + step) per logical row of iterations.
        row, col = self.n_rows * 2, self.n_cols

        fig, ax = plt.subplots(
            row,
            col,
            figsize=(col * self.factor, row * self.factor),
            subplot_kw={"xticks": [], "yticks": []},
        )

        fig.suptitle("SGDW")

        for i in range(self.n_rows):
            for j in range(self.n_cols):
                # Iteration index shown in grid cell (i, j), row-major order.
                k = init + j + i * self.n_cols
                # print(f"{i = }, {j = }, {k = }")
                ax0, ax1 = ax[i * 2, j], ax[i * 2 + 1, j]

                # Label the y-axis
                if j == 0:
                    ax0.set_ylabel("Sample")
                    ax1.set_ylabel("Step")

                # Plot the sample
                # NOTE(review): `hist` is a module-level name, not an attribute.
                fig_sample: DistributionDraw = create_distr(
                    hist.pos_wgt_samp[k][0]
                )
                ax0.imshow(fig_sample.image, cmap=self.cmap)
                ax0.set_title(f"$k={k}$")

                # Plot the step
                fig_step: DistributionDraw = create_distr(hist.pos_wgt[k])
                ax1.imshow(fig_step.image, cmap=self.cmap)

        plt.tight_layout(pad=0.3)

        plt.show()

        return fig, ax


dist_draw_sgdw = sgdw.DebiesedDistributionDrawSGDW(
    distr_sampler=pi_n,
    step_scheduler=Gamma(**params),
    max_iter=100,
    # max_time=10,
    report_every=10,
).set_geodesic_params(
    reg=0.01,
    stop_thr=1e-3,
)

plotter = PlotterComparison(
    dist_draw_sgdw, plot_every=30, n_cols=12, n_rows=2, cmap="binary_r"
)

bar, hist = plotter.run(
    # distr_hist=True,
    # distr_samp_hist=True,
    include_dict=dict(total_time=True),
    # include_time=True,
)
dist_hist, samples_hist = hist.distr, hist.distr_samp

In [None]:
fig, _ = plotter.plot(10)

In [None]:
dist_draw_sgdw.set_geodesic_params(
    reg=0.01,
    stop_thr=1e-3,
)

plotter = PlotterComparison(
    dist_draw_sgdw, plot_every=30, n_rows=2, cmap="binary_r"
)

bar, hist = plotter.run(
    distr_hist=True,
    distr_samp_hist=True,
    include_dict=dict(total_time=True),
    # include_time=True,
)
dist_hist, samples_hist = hist.distr, hist.distr_samp

In [None]:
Gamma(**params)(10000)

In [None]:
import ot

ot.__version__

In [None]:
bar

In [None]:
max_images = 22 * 4
utils.plot_list_of_draws(dist_hist, max_images=max_images)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=max_images)

In [None]:
utils.plot_list_of_draws(dist_hist[-max_images:], max_images=max_images)

In [None]:
utils.plot_list_of_draws(
    [x[0] for x in samples_hist[-max_images:]], max_images=max_images
)

In [None]:
# X_k, m, pos_hist, samples_hist = sgdw.compute_bwb_discrete_distribution(
#     transport=transport,
#     distrib_sampler=pi_n,
#     learning_rate=Gamma(**params),
#     batch_size=1,
#     alpha=1e-6,
#     tol=0.,
#     max_iter=500,
#     max_time=30,
#     position_history=True,
#     distrib_sampler_history=True,
#     report_every=20
# )

### Primeras iteraciones

In [None]:
# utils.plot_list_of_draws([DistributionDraw(m, shape, X) for X in pos_hist[:36]], max_images=36)

In [None]:
# utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=33)

# Baricentro usando métodos convolucionales, con 50 caras, 3 datos

In [None]:
mu_k, dist_hist, samples_hist = sgdw.compute_bwb_distribution_draw(
    distrib_sampler=pi_n,
    learning_rate=Gamma(**params),
    reg=3e-3,
    max_iter=500,
    max_time=30,
    distribution_history=True,
    distrib_sampler_history=True,
)

In [None]:
max_images = 38
utils.plot_list_of_draws(dist_hist, max_images=max_images)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=max_images)

In [None]:
utils.plot_list_of_draws(dist_hist[-max_images:], max_images=max_images)

In [None]:
utils.plot_list_of_draws(
    [x[0] for x in samples_hist[-max_images:]], max_images=max_images
)

# Baricentro usando métodos convolucionales y proyección sobre la variedad, con 50 caras, 3 datos

In [None]:
import torch
from wgan_gp.wgan_gp_vae.model import Generator, Encoder
from wgan_gp.wgan_gp_vae.utils import load_checkpoint
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

LATENT_DIM = 100
CHANNELS_IMG = 1
NUM_FILTERS = [256, 128, 64, 32]

G = Generator(LATENT_DIM, CHANNELS_IMG, NUM_FILTERS).to(device)
E = Encoder(LATENT_DIM, CHANNELS_IMG, NUM_FILTERS[::-1]).to(device)

CURR_PATH = Path(".")
NETS_PATH = CURR_PATH / "wgan_gp" / "networks"
FACE_PATH = NETS_PATH / "face"

DATA_PATH = CURR_PATH / "data" / "face.npy"

load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transforms.Compose([
        # From pdf to grayscale
        transforms.Lambda(lambda x: x / torch.max(x)),
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(1)],
            [0.5 for _ in range(1)],
        ),
    ]),
    transform_out=transforms.Compose([
        # Ensure the range is in [0, 1]
        transforms.Lambda(lambda x: x - torch.min(x)),
        transforms.Lambda(lambda x: x / torch.max(x)),
        transforms.ToPILImage(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x / torch.sum(x)),
        transforms.Lambda(lambda x: x.squeeze(0)),
    ]),
)

In [None]:
mu_k, dist_hist, samples_hist = sgdw.compute_bwb_distribution_draw_projected(
    distrib_sampler=pi_n,
    projector=proj,
    learning_rate=Gamma(**params),
    reg=2e-3,
    max_iter=500,
    max_time=30,
    distribution_history=True,
    distrib_sampler_history=True,
)

In [None]:
max_images = 17 * 3
utils.plot_list_of_draws(dist_hist, max_images=max_images)

In [None]:
utils.plot_list_of_draws([x[0] for x in samples_hist], max_images=max_images)

In [None]:
utils.plot_list_of_draws(dist_hist[-max_images:], max_images=max_images)

In [None]:
utils.plot_list_of_draws(
    [x[0] for x in samples_hist[-max_images:]], max_images=max_images
)

In [None]:
%%time

geod = BarycentricProjGeodesic(EMDTransport(max_iter=250_000))
eps = 0.0
max_time_iter = 10
max_iter = float("inf")
alpha = 1

# Path of barycenter iterates (one positions tensor per iteration)
position_historial = []

# Step 1: sample mu_0
mu_0 = pi_n.draw()

# Compute the masses and locations
X_k, m = utils.partition(
    X=mu_0.enumerate_nz_support_(), mu=mu_0.nz_probs, alpha=alpha
)
position_historial.append(X_k)

_log.debug(f"{len(X_k) = }")

k = 0
tic, toc = time.time(), time.time()

# Iterate until either the iteration budget or the time budget is exhausted,
# or until the update falls below the tolerance `eps`.
while k < max_iter and toc - tic < max_time_iter:
    _log.debug("=" * 10 + f" {k = }, Δt = {toc - tic:.4f} [seg] " + "=" * 10)
    # Step 2: sample \tilde\mu_k
    t_mu_k: DistributionDraw = pi_n.draw()
    t_X_k, t_m_k = t_mu_k.enumerate_nz_support_(), t_mu_k.nz_probs

    # Compute the optimal transport from the current iterate to the sample
    geod.fit(
        Xs=X_k,
        mu_s=m,
        Xt=t_X_k,
        mu_t=t_m_k,
    )
    T_X_k = geod.transport.transform(X_k)
    _log.debug(f"{T_X_k.shape = }")

    # Compute the support of mu_{k+1} by moving along the geodesic
    gamma_k = gamma(k)
    _log.debug(f"{gamma_k = :.6f}")
    X_kp1, _ = geod.interpolate(gamma_k)
    _log.debug(f"{X_kp1.shape = }")

    # Size of the update:
    # gamma_k**2 * sum_i m_i * ||X_k[i] - T(X_k)[i]||**2
    diff = X_k - T_X_k
    dist = float((gamma_k**2) * torch.sum(m * LA.norm(diff, dim=1) ** 2))
    _log.debug(f"{dist = :.8f}")

    # Stopping criterion. This used to be `while dist < eps: break`, which
    # only left the inner `while` and therefore never stopped the iteration.
    if dist < eps:
        break

    position_historial.append(X_kp1)

    k += 1
    X_k = X_kp1
    toc = time.time()

In [None]:
eps = 0.0
alpha = 1.0 / 10
transp = EMDTransport()
max_time_iter = 10  # 60 * 60  # one hour
max_iter = float("inf")

# Path of barycenter iterates (one positions tensor per iteration)
position_historial = []

# Step 1: sample mu_0
mu_0 = pi_n.draw()

# Compute the masses and locations: every support point is replicated
# ceil(alpha * w / min_w) times and its weight is split evenly between the
# copies.
X, m_ = [], []
min_w, max_w = mu_0.nz_probs.min(), mu_0.nz_probs.max()

for x, w, n in zip(
    mu_0.enumerate_nz_support_(),
    mu_0.nz_probs,
    torch.ceil(alpha * mu_0.nz_probs / min_w).to(int),
):
    for _ in range(n):
        X.append(x.reshape(1, -1))
        m_.append(w / n)

X_k = torch.cat(X, 0)
m = torch.as_tensor(m_, dtype=mu_0.dtype, device=mu_0.device)
position_historial.append(X_k)

_log.debug(f"{len(X_k) = }")

k = 0

tic, toc = time.time(), time.time()

# Iterate until either the iteration budget or the time budget is exhausted,
# or until the update falls below the tolerance `eps`.
while k < max_iter and toc - tic < max_time_iter:
    _log.debug(
        "\n" + "=" * 10 + f" {k = }, Δt = {toc - tic:.4f} [seg] " + "=" * 10
    )
    # Step 2: sample \tilde\mu_k
    t_mu_k: DistributionDraw = pi_n.draw()
    t_X_k, t_m_k = t_mu_k.enumerate_nz_support_(), t_mu_k.nz_probs

    # Compute the optimal transport from the current iterate to the sample
    transp.fit(
        Xs=X_k,
        mu_s=m,
        Xt=t_X_k,
        mu_t=t_m_k,
    )
    T_X_k = transp.transform(X_k)
    _log.debug(f"{T_X_k.shape = }")

    # Compute the support of mu_{k+1}. Moving the positions is enough,
    # because the weights stay the same.
    gamma_k = gamma(k)
    _log.debug(f"{gamma_k = :.6f}")
    X_kp1 = (1 - gamma_k) * X_k + gamma_k * T_X_k
    _log.debug(f"{X_kp1.shape = }")

    # Size of the update:
    # gamma_k**2 * sum_i m_i * ||X_k[i] - T(X_k)[i]||**2
    diff = X_k - T_X_k
    dist = float((gamma_k**2) * torch.sum(m * LA.norm(diff, dim=1) ** 2))
    _log.debug(f"{dist = :.8f}")

    # Stopping criterion. This used to be `while dist < eps: break`, which
    # only left the inner `while` and therefore never stopped the iteration.
    if dist < eps:
        break

    position_historial.append(X_kp1)

    k += 1
    X_k = X_kp1
    toc = time.time()

In [None]:
reshaped_positions = []
n_coord, n_dim = position_historial[0].shape
for p in position_historial:
    p_reshape = p.reshape(n_coord, 1, n_dim)
    reshaped_positions.append(p_reshape)

In [None]:
positions_historial_batch = torch.cat(reshaped_positions, dim=1)

In [None]:
positions_historial_batch.shape

In [None]:
import pickle
from pathlib import Path
import os

# Flip to True to (over)write the pickled history and weights on disk.
SAVE = False

PICKLES_PATH = Path("./pickles")

# NOTE: the "postion" spelling is kept on purpose; the loading cell reads
# files that follow the same naming.
POSITION_HISTORIAL_PATH = PICKLES_PATH / "postion_historial.dat"
WEIGHTS_PATH = PICKLES_PATH / "weights.dat"

# Create the output folder if needed, without a separate exists() check.
PICKLES_PATH.mkdir(parents=True, exist_ok=True)

if SAVE:
    # Opening with "wb" truncates an existing file, so there is no need to
    # remove the previous one first.
    with POSITION_HISTORIAL_PATH.open("wb") as f:
        pickle.dump(positions_historial_batch, f)

    with WEIGHTS_PATH.open("wb") as f:
        pickle.dump(m, f)

PICKLES_PATH.exists()

In [None]:
import pickle
from pathlib import Path

d = 2
PICKLES_PATH = Path("./pickles")
POSITION_HISTORIAL_PATH = PICKLES_PATH / f"postion_historial{d}.dat"
WEIGHTS_PATH = PICKLES_PATH / f"weights{d}.dat"

with POSITION_HISTORIAL_PATH.open("rb") as f:
    new_position_historial = pickle.load(f)

with WEIGHTS_PATH.open("rb") as f:
    m = pickle.load(f)

new_position_historial, m

In [None]:
len(new_position_historial)

In [None]:
new_position_historial[0].shape, new_position_historial[
    1
].shape, new_position_historial[2].shape

In [None]:
new_position_historial = new_position_historial[1:]

In [None]:
%%time

reshaped_positions = []
n_coord, n_dim = new_position_historial[0].shape
for p in new_position_historial:
    p_reshape = p.reshape(n_coord, 1, n_dim)
    reshaped_positions.append(p_reshape)

In [None]:
%%time

positions_historial_batch = torch.cat(reshaped_positions, dim=1)

In [None]:
positions_historial_batch.shape

In [None]:
m.shape

In [None]:
# position_historial = new_position_historial
# position_historial[:, 0, :].shape

In [None]:
# pos_hist_rounded = torch.round(positions_historial_batch)
# pos_hist_ind = (pos_hist_rounded[:, :, 0] * m + pos_hist_rounded[:, :, 1]).to(int)
# pos_hist_ind

In [None]:
# pos_hist_ind.shape

In [None]:
# m.shape

In [None]:
# to_return = torch.zeros((13, 28*28))
#
# for w, p in zip(m, pos_hist_ind):
#     for i, pp in enumerate(p):
#         to_return[i, pp] += w

In [None]:
# to_return.sum(1)

In [None]:
def position_to_weights(positions, weights, shape):
    """Rasterise weighted point positions into flattened image histograms.

    Every point is rounded to the nearest integer coordinate ``(row, col)``
    and its weight is accumulated at the flat index ``row * shape[1] + col``
    of the corresponding snapshot. Only the first two coordinates of each
    point are used.

    :param positions: Array-like of shape ``(n_coord, n_pos, n_dim)`` holding
        ``n_pos`` snapshots of ``n_coord`` points, or of shape
        ``(n_coord, n_dim)`` for a single snapshot.
    :param weights: 1-D array-like with one weight per point (``n_coord``).
    :param shape: Image shape ``(n, m)``.
    :return: Tensor of shape ``(n_pos, n * m)`` (default dtype, on CPU) whose
        row ``i`` is the flattened histogram of snapshot ``i``.
    :raises ValueError: If ``weights`` is not 1-D or its length differs from
        the number of points.
    """
    positions, weights = torch.as_tensor(positions), torch.as_tensor(weights)

    # Accept a single snapshot (n_coord, n_dim) as a batch of size one.
    if positions.dim() == 2:
        positions = positions.unsqueeze(1)

    n_rows, n_cols = shape
    n_coord, n_pos, _ = positions.shape

    if weights.dim() != 1 or weights.shape[0] != n_coord:
        raise ValueError(
            "'weights' must be 1-D with one entry per point."
            f" Currently: {tuple(weights.shape) = }, {n_coord = }"
        )

    to_return = torch.zeros((n_pos, n_rows * n_cols))

    # Flat pixel index of every point in every snapshot: (n_coord, n_pos).
    pos_rounded = torch.round(positions)
    pos_ind = (pos_rounded[:, :, 0] * n_cols + pos_rounded[:, :, 1]).to(
        device=to_return.device, dtype=torch.int64
    )

    # Row (snapshot) index and weight for each entry of `pos_ind`.
    snapshot_ind = torch.arange(n_pos).expand(n_coord, n_pos)
    values = (
        weights.to(device=to_return.device, dtype=to_return.dtype)
        .reshape(n_coord, 1)
        .expand(n_coord, n_pos)
    )

    # accumulate=True sums the weights of points that land on the same pixel,
    # which is what the former element-by-element `+=` loop did.
    to_return.index_put_((snapshot_ind, pos_ind), values, accumulate=True)

    return to_return


# ptw = position_to_weights([[0, 0], [0, 0], [1, 0], [1, 1]], [0.2, 0.3, 0.1, 0.4], (2, 2))
# ptw

ptw = position_to_weights(positions_historial_batch[:, :10, :], m, (28, 28))

In [None]:
def position_to_list_of_dd(positions, weights, shape, i, j):
    """Build one ``DistributionDraw`` per snapshot in ``positions[:, i:j, :]``.

    The selected snapshots are converted to flattened weight vectors with
    ``position_to_weights`` and each vector is wrapped, together with
    ``shape``, in a ``DistributionDraw``.
    """
    selected = positions[:, i:j, :]
    draws = []
    for flat_weights in position_to_weights(selected, weights, shape):
        draws.append(DistributionDraw(flat_weights, shape))
    return draws


list_dd = position_to_list_of_dd(
    positions_historial_batch[:, :10, :], m, (28, 28), 0, 10
)
list_dd

In [None]:
from bwb.utils import plot_list_of_draws

list_dd = position_to_list_of_dd(positions_historial_batch, m, (28, 28), 0, 36)
plot_list_of_draws(list_dd)

In [None]:
list_dd = position_to_list_of_dd(positions_historial_batch, m, (28, 28), 36, 72)
plot_list_of_draws(list_dd)

In [None]:
list_dd = position_to_list_of_dd(
    positions_historial_batch, m, (28, 28), -37, -1
)
plot_list_of_draws(list_dd)

In [None]:
n_coord, n_pos, n_dim = positions_historial_batch.shape
to_return = torch.zeros((n_pos, 28 * 28))
pos_rounded = torch.round(positions_historial_batch)
pos_ind = (pos_rounded[:, :, 0] * 28 + pos_rounded[:, :, 1]).to(int)

gen = zip(m, pos_ind)
w, p = next(gen)
w.shape, p.shape

In [None]:
to_return[:, 0].shape

In [None]:
ptw

In [None]:
%%time

historial_bsgd = torch.concat(
    [
        position_to_weights(p, m, (28, 28)).reshape((1, -1))
        for p in new_position_historial
    ],
    0,
)

In [None]:
dddl = DistributionDrawDataLoader()

In [None]:
arr = torch.arange(28 * 28)
# arr = torch.reshape(arr, (28, 28))

In [None]:
pos = torch.round(new_position_historial[-1][:10]).to(int)
pos_ind = pos[:, 0] * 28 + pos[:, 1]
# for p in map(tuple, pos):
#     print(arr[p])
# pos_ind

arr[pos_ind]