In [None]:
# Run from the repository root so the relative paths used below
# (Path(".") / "data" and Path(".") / "wgan_gp" / "networks") resolve.
# NOTE(review): this is a user-specific location; adjust it to your checkout.
%cd ~/codeProjects/pythonProjects/Bayesian-Learning-with-Wasserstein-Barycenters

In [None]:
import bwb.config as cfg

# cfg.use_single_precision()
# cfg.use_cpu()

In [None]:
import bwb
from pathlib import Path

# Dataset locations, relative to the repository root (see the %cd cell).
main_path = Path()  # same as Path(".")
data_path = main_path.joinpath("data")
face_path = data_path.joinpath("face_recognized.npy")
face_path

In [None]:
import numpy as np

# Load the stored face array from the .npy file and show its dimensions.
arr = np.load(file=face_path)
np.shape(arr)

In [None]:
from bwb.distributions.data_loaders import *

In [None]:
# Wrap the raw array in the bwb loader; each item is presumably a 28x28
# image distribution (it exposes `.grayscale_weights`, used below).
image_shape = (28, 28)
faces = DistributionDrawDataLoader(arr, image_shape)
faces

In [None]:
# Index of the face used as the right endpoint of the interpolations below
# (faces[0] is the left endpoint).
indx = 4
sample_face = faces[indx]
sample_face

In [None]:
# Endpoint distributions and their grayscale weight tensors.
f1_ = faces[0]
f2_ = faces[indx]
f1 = f1_.grayscale_weights
f2 = f2_.grayscale_weights

In [None]:
import torch
from bwb.config import conf

# Stack the two endpoint images along a new leading axis (one entry per input).
A = torch.stack((f1, f2))

# One-hot vectors of length 4 in the working dtype/device.
# NOTE(review): v1..v4 are not referenced by any other visible cell.
_tensor_opts = dict(dtype=conf.dtype, device=conf.device)
v1, v2, v3, v4 = (
    torch.tensor(tuple(int(k == pos) for k in range(4)), **_tensor_opts)
    for pos in range(4)
)
A.shape

In [None]:
import time

from matplotlib import pyplot as plt
from bwb.config import config
from bwb import bregman

# Interpolate between the faces `f1` and `f2` with bwb's convolutional
# Wasserstein barycenter and draw the path as a single row of panels.
nb_images = 9  # panels including both endpoints; also reused by a later cell
fig, axes = plt.subplots(1, nb_images, figsize=(7, 2))
cm = "Blues"
reg = 2e-3  # entropic regularization parameter
entrop_sharp = False  # only referenced by the commented-out export filename below

tic_ = time.time()
for i, ax in enumerate(axes):
    tx = i / (nb_images - 1)

    # Linear interpolation of the barycentric weights between (1, 0) and (0, 1).
    # Build them directly in the working dtype/device (instead of scaling
    # integer tensors, which yields float32) so they match `A`.
    weights = torch.tensor([1 - tx, tx], dtype=conf.dtype, device=conf.device)

    if i == 0:
        ax.imshow(f1.cpu(), cmap=cm)
    elif i == nb_images - 1:
        ax.imshow(f2.cpu(), cmap=cm)
    else:
        bar, log = bregman.convolutional_barycenter2d(
            A,
            reg,
            weights,
            numItermax=1_000,
            stopThr=1e-8,
            warn=False,
            log=True,
        )
        bar = bar.cpu()
        ax.imshow(bar, cmap=cm)
    ax.set_title(f"$t={weights[1].item():.2f}$")
    ax.axis("off")
toc_ = time.time()
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"  # total wall-clock time of the loop

fig.suptitle("Convolutional Wasserstein Barycenters.")

plt.tight_layout()
# NOTE(review): `img_path` and `additional_info` are not defined in this
# notebook; define them before re-enabling this export.
# plt.savefig(img_path / f"{additional_info}-entrop-sharp-{entrop_sharp}-conv-wasserstein-bar.png",
#             dpi=400)
plt.show()

In [None]:
# from bwb.geodesics import *
# from bwb.distributions import *
# from bwb import transports as tpt

# geodesic = PartitionedBarycentricProjGeodesic(
#     tpt.EMDTransport(norm="max", max_iter=5_000)
# ).fit_wd(
#     f1_, f2_,
# )
# geod, weights = geodesic.interpolate(0.5)
# weights

In [None]:
from wgan_gp.wgan_gp_vae.model_resnet import Generator, Encoder

In [None]:
import torch

# Device for the encoder/generator networks: GPU when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
LATENT_DIM = 128  # size of the latent code fed to G (see torch.randn below)
CHANNELS_IMG = 1  # single-channel (grayscale) images
NUM_FILTERS = [256, 128, 64, 32]  # NOTE(review): not referenced by any visible cell

# Build the generator first, then the encoder, and move each to `device`.
G, E = (
    net_cls(LATENT_DIM, CHANNELS_IMG).to(device)
    for net_cls in (Generator, Encoder)
)

In [None]:
# Location of the trained network checkpoints, relative to the repo root.
CURR_PATH = Path()  # same as Path(".")
NETS_PATH = CURR_PATH.joinpath("wgan_gp", "networks")
# Alternative checkpoint directory:
# NETS_PATH / f"_resnet_face_zDim{LATENT_DIM}_gauss_bs_128_recognized_augmented_WAE_WGAN_loss_l1_32p32"
FACE_PATH = NETS_PATH.joinpath("data_cleaned_principal")
FACE_PATH

In [None]:
from wgan_gp.wgan_gp_vae.utils import load_checkpoint

# Restore the trained generator and encoder weights from FACE_PATH.
load_checkpoint(G, FACE_PATH, "generator", device)
load_checkpoint(E, FACE_PATH, "encoder", device)

# Both networks are only used for inference afterwards (projection through
# `proj` and sampling with G), so switch them to evaluation mode. Any
# BatchNorm/Dropout layers then use their inference behaviour instead of
# per-batch statistics. This is harmless if load_checkpoint already did it.
G.eval()
E.eval();

In [None]:
from wgan_gp.wgan_gp_vae.utils import ProjectorOnManifold
import torchvision.transforms as T


def _scale_by_max(pdf):
    """Divide a non-negative density by its maximum so the peak becomes 1."""
    return pdf / pdf.max()


def _rescale_to_unit_range(img):
    """Shift an image so its minimum is 0, then divide by the new maximum."""
    shifted = img - img.min()
    return shifted / shifted.max()


# Barycenter (density) -> encoder input: scale to [0, 1], convert to an 8-bit
# PIL image, resize to 32x32, back to a tensor, then map [0, 1] to [-1, 1].
transform_in = T.Compose([
    T.Lambda(_scale_by_max),
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToTensor(),
    T.Normalize([0.5], [0.5]),
])

# Generator output -> density: rescale to [0, 1], go through PIL to resize to
# 28x28, then normalize to unit mass and drop the leading channel axis.
transform_out = T.Compose([
    T.Lambda(_rescale_to_unit_range),
    T.ToPILImage(),
    T.Resize((28, 28)),
    T.ToTensor(),
    T.Lambda(lambda img: img / img.sum()),
    T.Lambda(lambda img: img.squeeze(0)),
])

proj = ProjectorOnManifold(
    E,
    G,
    transform_in=transform_in,
    transform_out=transform_out,
)

In [None]:
# Inspect the subplot grid created by the first plotting cell.
np.shape(axes)

In [None]:
import time

from matplotlib import pyplot as plt
from bwb.config import config
import ot

# Same interpolation with POT's debiased convolutional barycenter.
# Row 0: raw barycenters. Row 1: the same barycenters mapped through `proj`.
# The first and last columns show the inputs f1 and f2 in both rows.
# `nb_images` is the value set in the earlier bwb plotting cell.
fig, axes = plt.subplots(2, nb_images, figsize=(7, 2))
cm = "Blues"
reg = 0.01  # entropic regularization parameter
stopThr = 5e-4  # stopping threshold forwarded to POT
entrop_sharp = False  # only referenced by the commented-out export filename below

tic_ = time.time()
for i in range(nb_images):
    tx = i / (nb_images - 1)

    # Linear interpolation of the barycentric weights between (1, 0) and (0, 1),
    # built directly in the working dtype/device so they match `A`.
    weights = torch.tensor([1 - tx, tx], dtype=conf.dtype, device=conf.device)

    if i == 0:
        column = (f1.cpu(), f1.cpu())
    elif i == nb_images - 1:
        column = (f2.cpu(), f2.cpu())
    else:
        # Only interior columns get a barycenter. Drawing one at the endpoints
        # would overwrite the input faces. Both rows share the same
        # barycenter, so it is computed once per column.
        bar, log = ot.bregman.convolutional_barycenter2d_debiased(
            A,
            reg,
            weights,
            stopThr=stopThr,
            warn=False,
            log=True,
        )
        column = (bar.cpu(), proj(bar).cpu())

    for j, image in enumerate(column):
        ax = axes[j, i]
        ax.imshow(image, cmap=cm)
        ax.set_title(f"$t={weights[1].item():.2f}$")
        ax.axis("off")
toc_ = time.time()
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"  # total wall-clock time of the loop

# plt.suptitle(f'Convolutional Wasserstein Barycenters in POT. {d_time}')

plt.tight_layout()
# NOTE(review): `img_path` and `additional_info` are not defined in this
# notebook; define them before re-enabling this export.
# plt.savefig(img_path / f"{additional_info}-entrop-sharp-{entrop_sharp}-conv-wasserstein-bar.png",
#             dpi=400)
plt.show()

In [None]:
from bwb.distributions import DistributionDraw

# Draw one random latent code, decode it with G, and turn the output into a
# DistributionDraw by shifting it to be non-negative and normalizing to sum 1.
with torch.no_grad():
    z = torch.randn(1, LATENT_DIM, 1, 1).to(device)
    raw = G(z).squeeze()
    shifted = raw - raw.min()
    img = shifted / shifted.sum()
    dd = DistributionDraw.from_grayscale_weights(img)
dd