# Geodésicas
Se buscará obtener las geodésicas y los baricentros de la misma forma que en el ejemplo [Convolutional Wasserstein Barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html#sphx-glr-auto-examples-barycenters-plot-convolutional-barycenter-py), para luego reproducir esos resultados con las geodésicas implementadas en esta librería.

In [24]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import ot
import torch
from PIL import Image
# noinspection PyUnresolvedReferences
from PIL.Image import Resampling

# noinspection PyProtectedMember
from bwb import logging
from bwb import transports as tpt
from bwb.distributions import *
from bwb.geodesics import *

# Notebook-level logger; the plotting cells emit their per-cell timing
# messages through it at DEBUG level.
_log = logging.get_logger("notebook")
# Enable DEBUG output so that those timing messages are shown.
# NOTE(review): `logging` is `bwb.logging`, not the stdlib module; whether
# `set_level` is global or per-logger is not visible from this file.
logging.set_level(logging.DEBUG)
# Bare expression: the notebook displays the logger as the cell output.
_log

<Logger notebook (DEBUG)>

In [25]:
# Base directory: one level above the directory the notebook is run from.
main_path = Path("..")

# Input images are read from <main_path>/data/images/<collection>/.
data_path = main_path / "data"
data_images_path = data_path / "images"
shapes_path = data_images_path / "shapes"
pot_shapes_path = data_images_path / "pot_shapes"

# Output directory for the figures written by the plotting cells.
img_path = Path("img")
# `savefig` does not create missing directories, so make sure the target
# exists up front; this is a no-op when the directory is already there.
img_path.mkdir(parents=True, exist_ok=True)

In [26]:
# Every input image is resized to a `resolution` x `resolution` square using
# Lanczos resampling.
resolution = 128
resample = Resampling.LANCZOS
size = (resolution,) * 2

In [27]:
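# Alternative inputs (disabled): the shapes stored under `pot_shapes_path`.
# Uncomment this cell to use them instead of the images loaded in the next one.
# # noinspection PyTypeChecker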
# f1 = 1 - np.asarray(Image.open(pot_shapes_path / 'redcross.png').resize(size, resample))[:, :, 2] / 255
# # noinspection PyTypeChecker
# f2 = 1 - np.asarray(Image.open(pot_shapes_path / 'tooth.png').resize(size, resample))[:, :, 2] / 255
# # noinspection PyTypeChecker
# f3 = 1 - np.asarray(Image.open(pot_shapes_path / 'heart.png').resize(size, resample))[:, :, 2] / 255
# # noinspection PyTypeChecker
# f4 = 1 - np.asarray(Image.open(pot_shapes_path / 'duck.png').resize(size, resample))[:, :, 2] / 255

In [28]:
def _load_shape(filename):
    """Load one shape image as a 2-D float array.

    The image is read from ``shapes_path``, resized to ``size`` with the
    ``resample`` filter, and channel index 2 is mapped through
    ``1 - value / 255``.

    NOTE(review): channel 2 is presumably the blue channel of an 8-bit
    RGB(A) image, which would make dark pixels carry the mass -- confirm
    against the image files.
    """
    # noinspection PyTypeChecker
    pixels = np.asarray(Image.open(shapes_path / filename).resize(size, resample))
    return 1 - pixels[:, :, 2] / 255


f1 = _load_shape('shape1filled.png')
f2 = _load_shape('shape2filled.png')
f3 = _load_shape('shape3filled.png')
f4 = _load_shape('shape4filled.png')

In [29]:
# Normalise each image so that its entries sum to one.
f1, f2, f3, f4 = (f / np.sum(f) for f in (f1, f2, f3, f4))
# Stack the four normalised images along a new leading axis.
A = np.array([f1, f2, f3, f4])

# Number of rows and columns of the interpolation grids plotted below.
nb_images = 5

# One-hot weight vectors, one per input image; they are the corner values
# that the plotting cell interpolates bilinearly.
v1, v2, v3, v4 = (
    np.array(corner)
    for corner in ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))
)

In [30]:
# Prefix shared by every saved figure's file name; it records the image
# resolution and the grid size used for the run.
additional_info = f"resol-{resolution}-nb-images-{nb_images}"

In [None]:
%%time
import time

# Reproduce the POT example: an nb_images x nb_images grid whose four corners
# show the input images and whose remaining cells show the convolutional
# Wasserstein barycenter for bilinearly interpolated weights.
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
cm = 'Blues'
# regularization parameter
reg = 4e-3
# reg = 2e-3
last = nb_images - 1
# Grid positions that display an input image directly.
corner_images = {(0, 0): f1, (0, last): f3, (last, 0): f2, (last, last): f4}
tic_ = time.time()
for i in range(nb_images):
    for j in range(nb_images):
        tic = time.time()

        # Position of the cell inside the grid, rescaled to [0, 1].
        tx = float(i) / last
        ty = float(j) / last

        # Bilinear interpolation of the corner weights: first along x on the
        # pairs (v1, v2) and (v3, v4), then along y between both results.
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        if (i, j) in corner_images:
            image = corner_images[i, j]
        else:
            # call to barycenter computation
            image = ot.bregman.convolutional_barycenter2d(A, reg, weights)
        axes[i, j].imshow(image, cmap=cm)
        axes[i, j].axis('off')

        toc = time.time()
        _log.debug(f"{i = }, {j = } ==> Total time: {toc - tic:.4f} [seg]")
toc_ = time.time()
# Total wall-clock time of the loop, appended to the figure title.
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"

plt.suptitle(f'Convolutional Wasserstein Barycenters in POT. {d_time}')

plt.tight_layout()
plt.savefig(img_path / f"{additional_info}-conv-wasserstein-bar.png", dpi=400)
plt.show()

2023-05-11 09:08:50,937: DEBUG [notebook:38]
> i = 0, j = 0 ==> Total time: 0.0020 [seg]
2023-05-11 09:08:55,315: DEBUG [notebook:38]
> i = 0, j = 1 ==> Total time: 4.3740 [seg]
2023-05-11 09:08:59,050: DEBUG [notebook:38]
> i = 0, j = 2 ==> Total time: 3.7342 [seg]
2023-05-11 09:09:03,028: DEBUG [notebook:38]
> i = 0, j = 3 ==> Total time: 3.9784 [seg]
2023-05-11 09:09:03,030: DEBUG [notebook:38]
> i = 0, j = 4 ==> Total time: 0.0010 [seg]
2023-05-11 09:09:06,708: DEBUG [notebook:38]
> i = 1, j = 0 ==> Total time: 3.6766 [seg]


# Utilizando geodésicas
Ahora que se replicaron los resultados del notebook de ejemplo, se procederá a reproducirlos utilizando las clases creadas en esta librería, calculando únicamente las geodésicas entre pares de imágenes (los bordes de la grilla).

In [None]:
# All four images share the shape of `f1`; it is passed along with the
# flattened weights so each distribution can be drawn as an image again.
original_shape = f1.shape

dd1, dd2, dd3, dd4 = (
    DistributionDraw.from_weights(image.reshape(-1), original_shape)
    for image in (f1, f2, f3, f4)
)

Ahora se realizarán las matrices que serán graficadas

In [None]:
%%time
# Fit one McCann geodesic per border of the grid, each with its own EMD
# transport; `geodesicXY` goes from distribution X to distribution Y.
geodesic12, geodesic13, geodesic34, geodesic24 = (
    McCannGeodesic(tpt.EMDTransport(max_iter=250_000)).fit_wd(dd_s=source, dd_t=target)
    for source, target in ((dd1, dd2), (dd1, dd3), (dd3, dd4), (dd2, dd4))
)

In [None]:
%%time
import time

# Plot an nb_images x nb_images grid: the corners show the four input images,
# each border shows the geodesic fitted between its two corners evaluated at
# evenly spaced times, and the interior is left blank because only pairwise
# geodesics are available here.
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
cm = 'Blues'
last = nb_images - 1
# Grid positions that display an input image directly.
corner_images = {(0, 0): f1, (0, last): f3, (last, 0): f2, (last, last): f4}
tic_ = time.time()
for i in range(nb_images):
    for j in range(nb_images):
        tic = time.time()

        # Interpolation times in [0, 1] along the rows (tx) and columns (ty).
        tx = float(i) / last
        ty = float(j) / last

        axes_ij = axes[i, j]

        if (i, j) in corner_images:
            image = corner_images[i, j]
        elif i in (0, last):
            # First / last row: travel along the columns, so use ty.
            geodesic = geodesic13 if i == 0 else geodesic24
            dd_t = DistributionDraw(*geodesic.interpolate(ty), original_shape)
            image = dd_t.grayscale
        elif j in (0, last):
            # First / last column: travel along the rows, so use tx.
            geodesic = geodesic12 if j == 0 else geodesic34
            dd_t = DistributionDraw(*geodesic.interpolate(tx), original_shape)
            image = dd_t.grayscale
        else:
            # Interior cell: nothing to interpolate, draw an empty image.
            image = np.zeros(original_shape)
        axes_ij.imshow(image, cmap=cm)
        axes_ij.axis('off')

        toc = time.time()
        _log.debug(f"{i = }, {j = } ==> Total time: {toc - tic:.4f} [seg]")
toc_ = time.time()
# Total wall-clock time of the loop, appended to the figure title.
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"

plt.suptitle(f'McCann Interpolation with EMD Transport. {d_time}')

plt.tight_layout()
# NOTE(review): the file name spells "mccaan"; kept unchanged because other
# material may already reference the saved file.
plt.savefig(img_path / f"{additional_info}-mccaan-interpolation-emd.png", dpi=800)
plt.show()

# Interpolación con la proyección baricéntrica

In [None]:
%%time
# Refit the four border geodesics with the barycentric-projection variant,
# each with its own EMD transport; `geodesicXY` goes from X to Y.
geodesic12, geodesic13, geodesic34, geodesic24 = (
    BarycentricProjGeodesic(tpt.EMDTransport(max_iter=250_000)).fit_wd(dd_s=source, dd_t=target)
    for source, target in ((dd1, dd2), (dd1, dd3), (dd3, dd4), (dd2, dd4))
)

In [None]:
%%time
import time

# Plot an nb_images x nb_images grid: the corners show the four input images,
# each border shows the geodesic fitted between its two corners evaluated at
# evenly spaced times, and the interior is left blank because only pairwise
# geodesics are available here.
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
cm = 'Blues'
last = nb_images - 1
# Grid positions that display an input image directly.
corner_images = {(0, 0): f1, (0, last): f3, (last, 0): f2, (last, last): f4}
tic_ = time.time()
for i in range(nb_images):
    for j in range(nb_images):
        tic = time.time()

        # Interpolation times in [0, 1] along the rows (tx) and columns (ty).
        tx = float(i) / last
        ty = float(j) / last

        axes_ij = axes[i, j]

        if (i, j) in corner_images:
            image = corner_images[i, j]
        elif i in (0, last):
            # First / last row: travel along the columns, so use ty.
            geodesic = geodesic13 if i == 0 else geodesic24
            dd_t = DistributionDraw(*geodesic.interpolate(ty), original_shape)
            image = dd_t.grayscale
        elif j in (0, last):
            # First / last column: travel along the rows, so use tx.
            geodesic = geodesic12 if j == 0 else geodesic34
            dd_t = DistributionDraw(*geodesic.interpolate(tx), original_shape)
            image = dd_t.grayscale
        else:
            # Interior cell: nothing to interpolate, draw an empty image.
            image = np.zeros(original_shape)
        axes_ij.imshow(image, cmap=cm)
        axes_ij.axis('off')

        toc = time.time()
        _log.debug(f"{i = }, {j = } ==> Total time: {toc - tic:.4f} [seg]")
toc_ = time.time()
# Total wall-clock time of the loop, appended to the figure title.
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"
plt.suptitle(f'Barycentric Projection Interpolation with EMD Transport. {d_time}')
plt.tight_layout()
plt.savefig(img_path / f"{additional_info}-barycentric-proj-interpolation-emd.png", dpi=800)
plt.show()

# Interpolación con la proyección baricéntrica particionada

In [None]:
%%time
# Refit the four border geodesics with the partitioned barycentric-projection
# variant (alpha=0.1), each with its own EMD transport; `geodesicXY` goes
# from X to Y.
geodesic12, geodesic13, geodesic34, geodesic24 = (
    PartitionedBarycentricProjGeodesic(
        tpt.EMDTransport(max_iter=250_000), alpha=0.1
    ).fit_wd(dd_s=source, dd_t=target)
    for source, target in ((dd1, dd2), (dd1, dd3), (dd3, dd4), (dd2, dd4))
)

In [None]:
%%time
import time

# Plot an nb_images x nb_images grid: the corners show the four input images,
# each border shows the geodesic fitted between its two corners evaluated at
# evenly spaced times, and the interior is left blank because only pairwise
# geodesics are available here.
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
cm = 'Blues'
last = nb_images - 1
# Grid positions that display an input image directly.
corner_images = {(0, 0): f1, (0, last): f3, (last, 0): f2, (last, last): f4}
tic_ = time.time()
for i in range(nb_images):
    for j in range(nb_images):
        tic = time.time()

        # Interpolation times in [0, 1] along the rows (tx) and columns (ty).
        tx = float(i) / last
        ty = float(j) / last

        axes_ij = axes[i, j]

        if (i, j) in corner_images:
            image = corner_images[i, j]
        elif i in (0, last):
            # First / last row: travel along the columns, so use ty.
            geodesic = geodesic13 if i == 0 else geodesic24
            dd_t = DistributionDraw(*geodesic.interpolate(ty), original_shape)
            image = dd_t.grayscale
        elif j in (0, last):
            # First / last column: travel along the rows, so use tx.
            geodesic = geodesic12 if j == 0 else geodesic34
            dd_t = DistributionDraw(*geodesic.interpolate(tx), original_shape)
            image = dd_t.grayscale
        else:
            # Interior cell: nothing to interpolate, draw an empty image.
            image = np.zeros(original_shape)
        axes_ij.imshow(image, cmap=cm)
        axes_ij.axis('off')

        toc = time.time()
        _log.debug(f"{i = }, {j = } ==> Total time: {toc - tic:.4f} [seg]")
toc_ = time.time()
# Total wall-clock time of the loop, appended to the figure title.
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"
plt.suptitle(f'Partitioned Barycentric Projection Interpolation with EMD Transport. {d_time}')
plt.tight_layout()
# Use a file name of its own so that this figure does not overwrite the one
# saved for the plain barycentric projection.
plt.savefig(
    img_path / f"{additional_info}-partitioned-barycentric-proj-interpolation-emd.png",
    dpi=800,
)
plt.show()

# Interpolación con Sinkhorn

In [None]:
%%time
# Keyword arguments shared by every `SinkhornTransport` built below.
kwargs = dict(max_iter=250_000, reg_e=1e-3, norm="max")

In [None]:
%%time
# McCann geodesic from dd1 to dd2, fitted with a Sinkhorn transport.
transport_12 = tpt.SinkhornTransport(**kwargs)
geodesic12 = McCannGeodesic(transport_12).fit_wd(dd_s=dd1, dd_t=dd2)

In [None]:
# Extra keyword arguments forwarded to every `interpolate` call below; empty
# means the library defaults are used.
interp_param = dict()
# Disabled alternative with explicit tolerances (the `rtol` value was left
# blank in the original draft):
# interp_param = {
#     "rtol": ,
#     "atol": 0,
# }

# Quick visual check: show the midpoint (t = 0.5) of the dd1 -> dd2 geodesic
# as the cell output.
midpoint = geodesic12.interpolate(0.5, **interp_param)
DistributionDraw(*midpoint, original_shape)

In [None]:
%%time
# McCann geodesic from dd1 to dd3, fitted with a Sinkhorn transport.
transport_13 = tpt.SinkhornTransport(**kwargs)
geodesic13 = McCannGeodesic(transport_13).fit_wd(dd_s=dd1, dd_t=dd3)

In [None]:
%%time
# McCann geodesic from dd3 to dd4, fitted with a Sinkhorn transport.
transport_34 = tpt.SinkhornTransport(**kwargs)
geodesic34 = McCannGeodesic(transport_34).fit_wd(dd_s=dd3, dd_t=dd4)

In [None]:
%%time
# McCann geodesic from dd2 to dd4, fitted with a Sinkhorn transport.
transport_24 = tpt.SinkhornTransport(**kwargs)
geodesic24 = McCannGeodesic(transport_24).fit_wd(dd_s=dd2, dd_t=dd4)

In [None]:
%%time
import time

# Plot an nb_images x nb_images grid: the corners show the four input images,
# each border shows the geodesic fitted between its two corners evaluated at
# evenly spaced times, and the interior is left blank because only pairwise
# geodesics are available here.
fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
cm = 'Blues'
last = nb_images - 1
# Grid positions that display an input image directly.
corner_images = {(0, 0): f1, (0, last): f3, (last, 0): f2, (last, last): f4}
tic_ = time.time()
for i in range(nb_images):
    for j in range(nb_images):
        tic = time.time()

        # Interpolation times in [0, 1] along the rows (tx) and columns (ty).
        tx = float(i) / last
        ty = float(j) / last

        axes_ij = axes[i, j]

        if (i, j) in corner_images:
            image = corner_images[i, j]
        elif i in (0, last):
            # First / last row: travel along the columns, so use ty.
            geodesic = geodesic13 if i == 0 else geodesic24
            dd_t = DistributionDraw(*geodesic.interpolate(ty, **interp_param), original_shape)
            image = dd_t.grayscale
        elif j in (0, last):
            # First / last column: travel along the rows, so use tx.
            geodesic = geodesic12 if j == 0 else geodesic34
            dd_t = DistributionDraw(*geodesic.interpolate(tx, **interp_param), original_shape)
            image = dd_t.grayscale
        else:
            # Interior cell: nothing to interpolate, draw an empty image.
            image = np.zeros(original_shape)
        axes_ij.imshow(image, cmap=cm)
        axes_ij.axis('off')

        toc = time.time()
        _log.debug(f"{i = }, {j = } ==> Total time: {toc - tic:.4f} [seg]")
toc_ = time.time()
# Total wall-clock time of the loop, appended to the figure title.
d_time = f"\nΔt={toc_-tic_:.1f}[seg]"
plt.suptitle(f'McCann Interpolation with Sinkhorn Transport. {d_time}')
plt.tight_layout()
# NOTE(review): the file name spells "mccaan"; kept unchanged because other
# material may already reference the saved file.
plt.savefig(img_path / f"{additional_info}-mccaan-interpolation-sinkhorn.png", dpi=800)
plt.show()