In [None]:
# Print GPU status. Informational only: no later cell visibly reads this output.
!nvidia-smi

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

import ot
import ot.plot

import math
from tqdm.notebook import tqdm

from tempfile import TemporaryDirectory

import imageio
import scipy.sparse as sparse
import pickle

# Force a white axes background; the dots are drawn in black by display_scatter.
plt.rcParams['axes.facecolor'] = 'white'

# Reload edited Python modules automatically before each cell runs.
# NOTE(review): no local module is imported in the visible cells, so this looks unused here.
%load_ext autoreload
%autoreload 2

In [None]:
# Number of dots sampled from each portrait.
n_points = 60000
# Optional (width, height) that both images are resized to before sampling, e.g. (600, 800);
# None keeps each image's native size.
size = None
# Seed NumPy's global RNG (used by the dot sampling) so a fresh "Run All" is reproducible.
np.random.seed(0)

In [None]:
def convert_to_black_probabilities(im, exponent=3.25):
    """Turn a grayscale image into a probability distribution over its pixels.

    Darker pixels get more mass. Each pixel's weight is ``(255 - value) ** exponent``
    (0 is black, 255 is white), and the weights are normalised to sum to 1. A larger
    exponent concentrates the mass more strongly on the darkest pixels.

    Parameters
    ----------
    im : array_like
        Grayscale pixel values expected in ``[0, 255]``, any shape. It is flattened
        in C order, matching ``np.unravel_index`` downstream.
    exponent : float
        Sharpening exponent applied to the darkness values (default 3.25).

    Returns
    -------
    np.ndarray
        1-D float64 array of length ``im.size`` that sums to 1.

    Raises
    ------
    ValueError
        If the weights do not have a positive finite sum, e.g. an all-white image
        (every weight is 0), so no distribution can be formed.
    """
    # Work in float64 so ``255 - value`` can never wrap around for unsigned inputs.
    flattened = np.asarray(im, dtype=np.float64).reshape(-1)
    blackness = (255.0 - flattened) ** exponent
    total = blackness.sum()
    # ``not total > 0`` also catches NaN (e.g. pixel values above 255).
    if not total > 0 or not np.isfinite(total):
        raise ValueError(
            "image must contain at least one pixel darker than 255 (values in [0, 255]) "
            "to build a probability distribution"
        )
    return blackness / total

In [None]:
def convert_pic_to_scatter_plot(black_probabilities, initial_shape, n_points=10000, rng=None):
    """Sample jittered pixel coordinates from a per-pixel probability distribution.

    Parameters
    ----------
    black_probabilities : np.ndarray
        1-D probabilities over the flattened pixels (must sum to 1).
    initial_shape : tuple
        Shape of the image the probabilities were flattened from, used to turn
        flat indices back into ``(row, col)`` positions.
    n_points : int
        Number of samples to draw (with replacement).
    rng : np.random.Generator or None
        Source of randomness. When None, NumPy's global RNG (``np.random``) is
        used, so existing calls and ``np.random.seed`` behave as before.

    Returns
    -------
    tuple of np.ndarray
        ``(rows, cols)`` float arrays of length ``n_points``. Each sample is its
        integer pixel position plus a uniform offset in ``[0, 1)`` on each axis.
    """
    sampler = np.random if rng is None else rng
    flat_indices = sampler.choice(
        black_probabilities.shape[0], replace=True, size=(n_points,), p=black_probabilities
    )
    rows, cols = np.unravel_index(flat_indices, initial_shape)
    # Spread each sample uniformly inside its pixel so repeated picks don't overlap.
    jitter = sampler.random((2, n_points))
    return rows + jitter[0], cols + jitter[1]

In [None]:
def display_scatter(source, show=True, size=None):
    """Draw a point cloud as black dots, in image orientation, on the current axes.

    Parameters
    ----------
    source : np.ndarray
        Array whose last axis holds ``(row, col)`` coordinates: ``source[..., 0]`` is
        plotted vertically (negated, so row 0 is at the top) and ``source[..., 1]``
        horizontally.
    show : bool
        Call ``plt.show()`` after drawing.
    size : tuple or None
        ``(width, height)`` of the drawing area. When falsy, it is derived from the
        largest column and row so that every point is inside the limits.
    """
    rows, cols = source[..., 0], source[..., 1]
    if not size:
        # The limits below are anchored at 0, so the default extent must be measured
        # from 0 (the maximum). Using ``max - min`` clipped points beyond that value.
        size = (cols.max(), rows.max())
    n_points = rows.shape[0]
    # Marker area shrinks as the number of dots grows, keeping the overall ink roughly constant.
    plt.scatter(cols, -rows, s=4000. / n_points, c='k', marker=".")
    plt.xlim(0, size[0] + 1)
    plt.ylim(-size[1], 0)
    plt.axis('off')
    if show:
        plt.show()

In [None]:
def change_contrast(img, level):
    """Return a copy of ``img`` with its contrast adjusted by ``level``.

    Uses the contrast-correction factor
    ``F = 259 * (level + 255) / (255 * (259 - level))`` and maps each channel value
    ``c`` to ``128 + F * (c - 128)``. Positive levels increase contrast, negative
    levels reduce it, and 0 leaves the values unchanged. The mapping is applied
    through ``img.point``.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to adjust (this notebook passes mode ``"L"``).
    level : float
        Contrast level in ``[-255, 255]``.

    Raises
    ------
    ValueError
        If ``level`` is outside ``[-255, 255]``. Outside that domain the factor is
        undefined (division by zero at 259) or negative, which inverts the image.
    """
    if not -255 <= level <= 255:
        raise ValueError(f"contrast level must be in [-255, 255], got {level!r}")
    factor = (259 * (level + 255)) / (255 * (259 - level))

    def contrast(c):
        # Kept as a plain linear expression so Pillow can also use it for "I"/"F" modes.
        return 128 + factor * (c - 128)

    return img.point(contrast)

In [None]:
def im2dots(original_img, n_points, size, contrast=70):
    """Convert a PIL image into a cloud of dots concentrated on its dark regions.

    Steps:
    1. Composite transparent regions onto a white canvas.
    2. Optionally resize.
    3. Convert to grayscale and boost the contrast.
    4. Sample ``n_points`` dots, more densely where pixels are darker.

    Parameters
    ----------
    original_img : PIL.Image.Image
        Source image. Any mode is accepted; it is converted to RGBA when needed.
    n_points : int
        Number of dots to sample.
    size : tuple or None
        ``(width, height)`` passed to ``Image.resize``, or None to keep the size.
    contrast : float
        Level forwarded to ``change_contrast`` (default 70).

    Returns
    -------
    np.ndarray
        ``(n_points, 2)`` array of ``(row, col)`` coordinates with sub-pixel jitter.
    """
    # Pasting the image as its own mask requires an alpha band. Convert so RGB/L
    # inputs work instead of failing with "bad transparency mask".
    rgba = original_img if original_img.mode == "RGBA" else original_img.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, "WHITE")
    canvas.paste(rgba, (0, 0), rgba)
    if size is not None:
        canvas = canvas.resize(size)
    gray = change_contrast(canvas.convert(mode="L"), contrast)
    pixels = np.array(gray)
    probabilities = convert_to_black_probabilities(pixels)
    rows, cols = convert_pic_to_scatter_plot(probabilities, pixels.shape, n_points)
    return np.stack([rows, cols], axis=-1)

In [None]:
# Load both portraits as RGBA so transparent regions can be composited onto white later.
image_source, image_target = (
    Image.open(path).convert("RGBA")
    for path in ("Examples/Monge-Kanto/Kantorovich.png", "Examples/Monge-Kanto/Monge.jpg")
)

In [None]:
# Sample the same number of dots from each portrait (source first, then target).
xs, xt = (im2dots(portrait, n_points, size) for portrait in (image_source, image_target))

In [None]:
# Side-by-side preview of the two dot clouds before computing the transport plan.
for panel, (cloud, title) in enumerate(
    ((xs, "Source (Kantorovich)"), (xt, "Target (Monge)")), start=1
):
    plt.subplot(1, 2, panel)
    display_scatter(cloud, show=False, size=size)
    plt.title(title)
plt.show()

In [None]:
def compute_transport_plan(xs, xt, num_threads=24, num_iter_max=10_000_000):
    """Solve the exact discrete optimal-transport problem between two point clouds.

    Both clouds carry uniform weights and may have different sizes. The ground cost
    is ``ot.dist(xs, xt)`` with its default metric.

    Parameters
    ----------
    xs, xt : np.ndarray
        Source and target points, one point per row.
    num_threads : int
        Passed to ``ot.emd`` as ``numThreads`` (default 24).
    num_iter_max : int
        Passed to ``ot.emd`` as ``numItermax`` (default 10_000_000).

    Returns
    -------
    G : scipy.sparse.csr_matrix
        ``(len(xs), len(xt))`` plan multiplied by ``len(xs)``. With converged uniform
        marginals, each row therefore sums to 1, and ``G @ xt`` is a convex
        combination of target points for every source point.
    log : dict
        The log dictionary returned by ``ot.emd(..., log=True)``.
    """
    n_source, n_target = xs.shape[0], xt.shape[0]
    # Uniform marginals. The target weights use the target's own size, so clouds of
    # different sizes no longer break the solver.
    a = np.ones((n_source,)) / n_source
    b = np.ones((n_target,)) / n_target
    # NOTE(review): M and the plan returned by ot.emd are dense n_source x n_target
    # float arrays, so memory grows quadratically with the number of dots.
    M = ot.dist(xs, xt)
    plan, log = ot.emd(a, b, M, numThreads=num_threads, numItermax=num_iter_max, log=True)
    # The optimal plan is typically very sparse; keep only its nonzeros.
    return sparse.csr_matrix(plan) * n_source, log

In [None]:
# Exact optimal transport between the two dot clouds (the slow step of the notebook).
(G, log) = compute_transport_plan(xs, xt)

In [None]:
# Sanity check: compute_transport_plan rescales the plan by the number of source dots,
# so each row should sum to 1. get_points_at_t relies on this to move each dot to a
# weighted average of target dots.
assert np.allclose(np.asarray(G.sum(axis=1)).ravel(), 1.0), "transport plan rows do not sum to 1"
G

In [None]:
# Show only the scalar solver diagnostics. The dual potentials log["u"] / log["v"] hold
# one value per dot and would dominate the output (keys as documented for ot.emd(log=True)).
{key: log.get(key) for key in ("cost", "warning", "result_code")}

In [None]:
def get_points_at_t(xs, xt, G, t):
    """Move every source point a fraction ``t`` of the way to its transported target.

    The destination of the source points is ``G @ xt``, a weighted average of target
    points when G's rows sum to 1. The result equals ``xs`` at ``t = 0`` and
    ``G @ xt`` at ``t = 1``, with straight-line motion in between.
    """
    scaled_plan = t * G
    destinations = scaled_plan @ xt
    return (1 - t) * xs + destinations

In [None]:
def make_time_dimension(n_frames):
    """Build a symmetric 0 -> 1 -> 0 sequence of evenly spaced interpolation times.

    An even ``n_frames`` is bumped to the next odd number (with a printed notice).
    For ``n_frames >= 3`` the middle entry is then exactly 1 and both ends are 0.

    Returns
    -------
    np.ndarray
        The time values, ``n_frames`` long after the odd adjustment (for positive input).
    """
    if not n_frames % 2:
        print("Adding a frame for symmetry purposes")
        n_frames += 1
    ramp_up = np.linspace(0, 1, math.ceil(n_frames / 2))
    # Drop the last rising value so the peak (t = 1) appears only once.
    return np.concatenate((ramp_up[:-1], ramp_up[::-1]), axis=0)

In [None]:
def make_gif(xs, xt, G, n_frames, gif_name="mygif", size=None, fps=60):
    """Render the source -> target -> source morph of a dot cloud as an animated GIF.

    Each time value from ``make_time_dimension(n_frames)`` becomes one frame:
    1. The points from ``get_points_at_t`` are drawn with ``display_scatter``.
    2. The frame is saved as a PNG in a temporary directory.
    3. All PNGs are merged into ``f"{gif_name}.gif"``.

    Parameters
    ----------
    xs, xt : np.ndarray
        Source and target clouds of ``(row, col)`` coordinates, one point per row.
    G : matrix-like
        Transport plan consumed by ``get_points_at_t``.
    n_frames : int
        Requested number of frames (an even value is bumped by one).
    gif_name : str
        Output file name without the ``.gif`` extension.
    size : tuple or None
        ``(width, height)`` drawing area shared by all frames. When None, it is
        computed once from the largest column/row in ``xs`` and ``xt``. This gives
        every frame identical axis limits; per-frame automatic limits would make
        the framing jump between frames.
    fps : int
        Frame rate handed to the GIF writer (default 60).
    """
    time_dim = make_time_dimension(n_frames)
    n = len(time_dim)

    if size is None:
        # Interpolated points stay inside the bounds of xs and xt when G's rows are
        # non-negative and sum to 1, so these limits frame every frame.
        max_row, max_col = np.maximum(xs.max(axis=0), xt.max(axis=0))
        size = (max_col, max_row)

    with TemporaryDirectory() as tmpdirname:
        fig = plt.figure(figsize=(6.4, 9.6))
        try:
            for index in tqdm(range(n), desc="Making pictures"):
                points = get_points_at_t(xs, xt, G, time_dim[index])
                display_scatter(points, show=False, size=size)
                fig.savefig(f"{tmpdirname}/{index}.png", transparent=False, bbox_inches='tight')
                fig.clf()
        finally:
            # Close exactly this figure. Calling plt.cla()/plt.clf() after plt.close()
            # would create a new, empty figure that lingers in the notebook.
            plt.close(fig)

        with imageio.get_writer(f"{gif_name}.gif", mode="I", fps=fps) as writer:
            for i in tqdm(range(n), desc="Merging pictures"):
                writer.append_data(imageio.imread(f"{tmpdirname}/{i}.png"))
            print("The merger is complete")


In [None]:
# 240 requested frames; make_time_dimension bumps an even count to 241 so the morph peaks exactly at the target.
make_gif(xs, xt, G, n_frames=240)