In [None]:
import numpy as np
import torch

from bwb import transports as tpt
from bwb.distributions import data_loaders as dl
from bwb.geodesics import *
from bwb.distributions import *

In [None]:
from bwb import logging

# Notebook-level logger (from the project's `bwb.logging` wrapper); it is used
# by the `log.debug(...)` calls inside the animation callbacks below.
log = logging.get_logger(__name__)

# NOTE(review): called without a logger name, presumably this sets the default
# bwb level to DEBUG — confirm which loggers `set_level` targets in this form.
logging.set_level(logging.DEBUG)

In [None]:
from pathlib import Path

# Face drawings dataset (`face.npy`), loaded with NumPy. The path is relative to
# the notebook's working directory and is built with pathlib so it resolves on
# every OS; the previous raw string r"..\data\face.npy" only worked on Windows.
FACE_DATA_PATH = Path("..") / "data" / "face.npy"
arr = np.load(FACE_DATA_PATH)
arr.shape

In [None]:
# Wrap the raw array in the project's data loader; (28, 28) is presumably the
# drawing shape. Items 0 and 2 are the source / target faces used throughout.
faces = dl.DistributionDrawDataLoader(arr, (28, 28))
ddraw0, ddraw1 = faces[0], faces[2]

In [None]:
# Rich display of the first selected face (passed as `dd_s` in the fits below).
ddraw0

In [None]:
# Rich display of the second selected face (passed as `dd_t` in the fits below).
ddraw1

In [None]:
%%time

# Sinkhorn transport from ddraw0 to ddraw1, allowing up to 10_000 iterations.
sinkhorn_tpt = tpt.SinkhornTransport(max_iter=10_000)
mst = sinkhorn_tpt.fit_wd(dd_s=ddraw0, dd_t=ddraw1)
mst

In [None]:
%%time

# Exact (EMD) transport between the same two faces, for comparison.
emd_tpt = tpt.EMDTransport()
memdt = emd_tpt.fit_wd(dd_s=ddraw0, dd_t=ddraw1)
memdt

In [None]:
# Number of atoms (nonzero-probability pixels) in each face.
n_src, n_tgt = len(ddraw0.nz_probs), len(ddraw1.nz_probs)
print(f"len(ddraw0.nz_probs) = {n_src!r}, len(ddraw1.nz_probs) = {n_tgt!r}")

In [None]:
# Count of nonzero entries in the EMD coupling: `nonzero()` returns one row per
# nonzero entry (torch API, cf. `as_tuple=` used below), so `len` counts them.
len(memdt.coupling_.nonzero())

In [None]:
src_support = ddraw0.enumerate_nz_support_()
tgt_support = ddraw1.enumerate_nz_support_()

# Broadcastable layout: source points along axis 0, target points along axis 1,
# so `coord[i, j]` is the interpolation between source i and target j.
X0 = src_support.reshape(-1, 1, 2)  # (n, 1, 2)
X1 = tgt_support.reshape(1, -1, 2)  # (1, m, 2)

t = 0.5  # interpolation time; 0.5 gives every pairwise midpoint
coord = (1 - t) * X0 + t * X1  # (n, m, 2)

coord.shape, X0.shape, X1.shape

In [None]:
# Tuple of index tensors (one per coupling axis) for the (source, target) pairs
# that carry mass in the EMD plan.
nz_coord = memdt.coupling_.nonzero(as_tuple=True)


# Keep only the interpolated points of coupled pairs; shape is (k, 2).
coord[nz_coord].shape

In [None]:
# Upper bound on the number of nonzeros of the EMD coupling: assuming
# `EMDTransport` returns a basic (vertex) solution, as network-simplex solvers
# do, it has at most n + m - 1 nonzeros (n, m = support sizes). Derive it from
# the data instead of hard-coding 296 + 264 - 1 (presumably the sizes printed
# earlier), which goes stale as soon as another pair of faces is selected.
n_src, n_tgt = len(ddraw0.nz_probs), len(ddraw1.nz_probs)
max_nnz_emd = n_src + n_tgt - 1
max_nnz_emd

In [None]:
# Re-display the source face next to the interpolation checks above.
ddraw0

In [None]:
# Re-display the target face next to the interpolation checks above.
ddraw1

In [None]:
from bwb.distributions import *

# Matplotlib colormap name passed to `pl.imshow` in the animation cells.
cm = "Blues"

In [None]:
# %%time
# import matplotlib.animation as animation
# import matplotlib.pylab as pl
# from IPython.display import HTML
# import time
#
# pl.figure(3)
# n_iter_max = 25
# t_list = np.linspace(0, 1, n_iter_max)
# transform = mst.transform(memdt.xs_)
#
# X0 = ddraw0.enumerate_nz_support_().reshape(-1, 1, 2)
# X1 = ddraw1.enumerate_nz_support_().reshape(1, -1, 2)
#
# nz_coords = n, m = mst.coupling_.nonzero(as_tuple=True)
#
# def _update_plot(i):
#     tic = time.time()
#     pl.clf()
#     t = t_list[i]
#     coords = (1-t) * X0 + t * X1
#     geod = coords[n, m, :]
#     weights = mst.coupling_[nz_coords]
#     # pl.scatter(geod[:, 0], geod[:, 1], c="b", alpha=weights / weights.max())
#     dd = DistributionDraw(weights, geod, (28, 28))
#     pl.imshow(dd.grayscale, cmap=cm)
#     pl.axis("off")
#     pl.title(f"$t = {t:.2f}$")
#     toc = time.time()
#     print(f"{i = }, Δt = {toc - tic:.4f} [seg]")
#     return 1
#
# ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter_max, interval=100, repeat_delay=2000)
# anim_html = HTML(ani.to_jshtml())
# pl.close(pl.gcf())
# anim_html

# McCann Interpolation with Sinkhorn

In [None]:
# Raise the `bwb.geodesics` logger to WARNING so the animations below are not
# flooded with its debug output.
logging.set_level(logging.WARNING, "bwb.geodesics")
# Display the logger configuration to check the change.
logging.log_config.loggers

In [None]:
%%time

import matplotlib.animation as animation
import matplotlib.pylab as pl
from IPython.display import HTML

pl.figure(3)
n_iter_max = 51  # number of animation frames
t_list = np.linspace(0, 1, n_iter_max)  # frame times from 0 to 1, both included

# McCann geodesic ddraw0 -> ddraw1 built on a Sinkhorn plan (reg_e=1e-3).
sinkhorn_geod_tpt = tpt.SinkhornTransport(reg_e=1e-3, norm="max", max_iter=5_000)
geodesic = McCannGeodesic(sinkhorn_geod_tpt).fit_wd(dd_s=ddraw0, dd_t=ddraw1)


def _update_plot(i):
    """Redraw the current figure as frame ``i`` of the Sinkhorn McCann animation.

    Interpolates ``geodesic`` at time ``t_list[i]``, wraps the returned points
    and weights in a ``DistributionDraw`` of shape (28, 28) and shows its
    grayscale image with colormap ``cm``. Returns 1 (the value is ignored by
    ``FuncAnimation`` since blitting is not enabled).
    """
    log.debug(f"plot {i = }")
    pl.clf()
    t = t_list[i]
    points, masses = geodesic.interpolate(t)
    frame = DistributionDraw(points, masses, (28, 28))
    pl.imshow(frame.grayscale, cmap=cm)
    pl.axis("off")
    pl.title(f"$t = {t:.2f}$")
    return 1


fig = pl.gcf()
ani = animation.FuncAnimation(
    fig, _update_plot, n_iter_max, interval=100, repeat_delay=2000
)
# Render all frames into an embeddable JS/HTML player, then close the figure so
# its last frame is not shown a second time as a static image.
anim_html = HTML(ani.to_jshtml())
pl.close(fig)
anim_html

# McCann Interpolation with EMD

In [None]:
%%time

import matplotlib.animation as animation
import matplotlib.pylab as pl
from IPython.display import HTML

pl.figure(3)
n_iter_max = 51  # number of animation frames
t_list = np.linspace(0, 1, n_iter_max)  # frame times from 0 to 1, both included

# McCann geodesic ddraw0 -> ddraw1 built on an exact EMD transport plan.
emd_geod_tpt = tpt.EMDTransport(norm="max", max_iter=5_000)
geodesic = McCannGeodesic(emd_geod_tpt).fit_wd(dd_s=ddraw0, dd_t=ddraw1)


def _update_plot(i):
    """Redraw the current figure as frame ``i`` of the EMD McCann animation.

    Scatters the points of ``geodesic`` interpolated at ``t_list[i]`` inside a
    fixed [0, 28] x [0, 28] window; each point's opacity is its weight divided
    by the largest weight of the frame. Returns 1.
    """
    log.debug(f"plot {i = }")
    pl.clf()
    t = t_list[i]
    points, masses = geodesic.interpolate(t)
    opacity = masses / masses.max()
    pl.scatter(points[:, 0], points[:, 1], c="b", alpha=opacity)
    pl.axis("off")
    pl.xlim((0, 28))
    pl.ylim((0, 28))
    pl.title(f"$t = {t:.2f}$")
    return 1


fig = pl.gcf()
ani = animation.FuncAnimation(
    fig, _update_plot, n_iter_max, interval=100, repeat_delay=2000
)
# Render all frames into an embeddable JS/HTML player, then close the figure so
# its last frame is not shown a second time as a static image.
anim_html = HTML(ani.to_jshtml())
pl.close(fig)
anim_html

In [None]:
# Split every atom of ddraw0 (mass w) into k = ceil(w / min_w) copies of mass
# w / k at the same location, so no copy weighs more than
# min_w = 3 * (smallest atom mass). max_w (largest original mass) is kept to
# normalise scatter opacities later; every copy weighs <= its atom's mass, so
# mu_s / max_w stays within [0, 1].
min_w = ddraw0.nz_probs.min() * 3
max_w = ddraw0.nz_probs.max()

support0 = ddraw0.enumerate_nz_support_()
probs0 = ddraw0.nz_probs
n_copies = torch.ceil(probs0 / min_w).to(torch.int)
repeats = n_copies.to(torch.long)

# Vectorised form of "append each support row / split mass k times, in order".
# NOTE(review): assumes one support row per entry of nz_probs.
Xs = torch.repeat_interleave(support0.reshape(support0.shape[0], -1), repeats, dim=0)
mu_s = torch.repeat_interleave(probs0 / n_copies, repeats)

In [None]:
%%time

# EMD transport from the split source atoms (Xs, mu_s) to the atoms of ddraw1.
Xt = ddraw1.enumerate_nz_support_()
mu_t = ddraw1.nz_probs
memdt = tpt.EMDTransport().fit_wm(Xs=Xs, mu_s=mu_s, Xt=Xt, mu_t=mu_t)
memdt

In [None]:
import matplotlib.animation as animation
import matplotlib.pylab as pl
from IPython.display import HTML

pl.figure(3)
n_iter_max = 100  # number of animation frames
t_list = np.linspace(0, 1, n_iter_max)  # frame times from 0 to 1, both included

# Positions that `memdt.transform` assigns to the (split) source atoms; used as
# the end points of the straight-line motion animated below.
transform = memdt.transform(memdt.xs_)


def _update_plot(i):
    """Draw frame ``i``: straight-line motion of the split source atoms.

    At time ``t = t_list[i]`` each atom sits at ``(1 - t) * start + t * end``,
    with ``start = memdt.xs_`` and ``end = transform``. Point opacity is the
    atom's mass ``memdt.mu_s`` divided by ``max_w``. Returns 1.
    """
    t = t_list[i]
    pl.clf()
    start, end = memdt.xs_, transform
    geod = (1 - t) * start + t * end
    opacity = memdt.mu_s / max_w
    pl.scatter(geod[:, 0], geod[:, 1], c="b", alpha=opacity)
    pl.axis("equal")
    pl.title(f"$t = {t:.2f}$")
    return 1


fig = pl.gcf()
ani = animation.FuncAnimation(
    fig, _update_plot, n_iter_max, interval=100, repeat_delay=2000
)
# Render all frames into an embeddable JS/HTML player, then close the figure so
# its last frame is not shown a second time as a static image.
anim_html = HTML(ani.to_jshtml())
pl.close(fig)
anim_html