In [None]:
# Reload modified python modules (e.g. the local utils.* helpers) automatically,
# so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

# Optimal transport
from ot.gromov import gromov_wasserstein
from utils.gw_ms import gromov_wasserstein_ms, cost_gw, cost_ms

# Graph functions
import networkx as nx
from utils.panda_functions import *

# Output folder passed to save_pandas / save_couplings when exporting figures.
# Create it up front so saving does not fail on a fresh checkout, in case those
# helpers do not create the folder themselves.
folder_figs = "figures"
os.makedirs(folder_figs, exist_ok=True)

# Testing

## Test position functions

In [None]:
# Compare the individual position helpers with panda_position on a small panda.
# nh / ne1 / ne2 are presumably the head and ear vertex counts.
nh = 15
ne1 = ne2 = 5
ear_dist = 2
paste_edge = True
push_ears = False

# Presumably the head indices where ear 1 (i1_*) and ear 2 (i2_*) are attached
i1_1, i1_2, i2_1, i2_2 = paste_to_vertices(nh, ear_dist=ear_dist, paste_edge=paste_edge)

X_all = panda_position(
    nh, ne1, ne2, ear_dist=ear_dist, paste_edge=paste_edge, push_ears=push_ears
)
X_head, tt = head_position(nh, i1_end=i1_2, i2_start=i2_1)
X_ear1 = ear_position(ne1, X_head, tt, i1=i1_1, i2=i1_2, paste_edge=paste_edge)
# The second ear uses its own size ne2 (previously ne1, which only worked because ne1 == ne2)
X_ear2 = ear_position(ne2, X_head, tt, i1=i2_1, i2=i2_2, paste_edge=paste_edge)

In [None]:
# Overlay the head (larger markers) and both ears. Equal aspect keeps the layout
# geometry from being stretched by the axes, so shape problems are visible.
fig, ax = plt.subplots()
ax.scatter(X_head[:, 0], X_head[:, 1], s=100, label="Head")
ax.scatter(X_ear1[:, 0], X_ear1[:, 1], label="Ear 1")
ax.scatter(X_ear2[:, 0], X_ear2[:, 1], label="Ear 2")
ax.set_aspect("equal")
ax.legend();

In [None]:
# Full layout returned by panda_position, plus its array shape
x_all, y_all = X_all[:, 0], X_all[:, 1]
plt.scatter(x_all, y_all)
print(X_all.shape)

## Check panda generation

In [None]:
# Build two small pandas with paste_edge=True and std=0, then display them.
# Every panda after the first is built with push_ears=True.
params = [[15, 5, 5], [20, 7, 7]]
nPandas = len(params)
nSteps = len(params[0])

rng = np.random.default_rng(seed=304)
std = 0
ear_dist = 2

Ns, Pandas, Pandas_pos, dm_pandas, lCs_pandas = [], [], [], [], []
for idx, (nh, ne1, ne2) in enumerate(params):
    push_ears = idx != 0
    N, G, dm, lC = create_panda(
        nh, ne1, ne2,
        ear_dist=ear_dist, paste_edge=True, push_ears=push_ears, rng=rng, std=std,
    )
    pos = panda_position(
        nh, ne1, ne2, ear_dist=ear_dist, paste_edge=True, push_ears=push_ears
    )

    Ns.append(N)
    Pandas.append(G)
    Pandas_pos.append(pos)
    dm_pandas.append(dm)
    lCs_pandas.append(lC)

_ = display_ms_pandas(Pandas, Pandas_pos, dm_pandas, lCs_pandas)

In [None]:
# Build two equal-sized small pandas with paste_edge=False and std=0, then
# display them. Every panda after the first is built with push_ears=True.
params = [[15, 5, 5], [15, 5, 5]]
nPandas = len(params)
nSteps = len(params[0])

rng = np.random.default_rng(seed=304)
std = 0
ear_dist = 2
# One switch shared by create_panda and panda_position, so the drawn layout
# cannot disagree with the graph (panda_position previously used its default).
paste_edge = False

Ns = []
Pandas = []
Pandas_pos = []
dm_pandas = []
lCs_pandas = []
for idx, param in enumerate(params):
    nh, ne1, ne2 = param[0], param[1], param[2]

    push_ears = idx > 0
    N, G, dm, lC = create_panda(
        nh,
        ne1,
        ne2,
        ear_dist=ear_dist,
        paste_edge=paste_edge,
        push_ears=push_ears,
        rng=rng,
        std=std,
    )
    pos = panda_position(
        nh, ne1, ne2, ear_dist=ear_dist, paste_edge=paste_edge, push_ears=push_ears
    )

    Ns.append(N)
    Pandas.append(G)
    Pandas_pos.append(pos)
    dm_pandas.append(dm)
    lCs_pandas.append(lC)

_ = display_ms_pandas(Pandas, Pandas_pos, dm_pandas, lCs_pandas)

# Experiments

## Pandas with ears pasted at vertices

In [None]:
# Experiment: two pandas with paste_edge=False and std=0.05 passed to
# create_panda. The second panda has larger head/ear sizes and push_ears=True.
params = [[25, 10, 10], [30, 12, 12]]
nPandas = len(params)
nSteps = len(params[0])

rng = np.random.default_rng(seed=5500)
std = 0.05
ear_dist = 6
# One switch shared by create_panda and panda_position, so the drawn layout
# cannot disagree with the graph (panda_position previously used its default).
paste_edge = False

Ns = []
Pandas = []
Pandas_pos = []
dm_pandas = []
lCs_pandas = []
for idx, param in enumerate(params):
    nh, ne1, ne2 = param[0], param[1], param[2]

    # Every panda after the first gets push_ears=True
    push_ears = idx > 0
    N, G, dm, lC = create_panda(
        nh,
        ne1,
        ne2,
        paste_edge=paste_edge,
        ear_dist=ear_dist,
        push_ears=push_ears,
        add_neighbors=False,
        rng=rng,
        std=std,
    )
    pos = panda_position(
        nh, ne1, ne2, paste_edge=paste_edge, ear_dist=ear_dist, push_ears=push_ears
    )

    Ns.append(N)
    Pandas.append(G)
    Pandas_pos.append(pos)
    dm_pandas.append(dm)
    lCs_pandas.append(lC)

In [None]:
# Assigning to `_` keeps the helper's return value from being echoed as cell output
_ = display_ms_pandas(Pandas, Pandas_pos, dm_pandas, lCs_pandas)

In [None]:
# Presumably writes the panda figures into folder_figs; the return value is discarded
_ = save_pandas(Pandas, Pandas_pos, dm_pandas, lCs_pandas, folder=folder_figs)

In [None]:
# Baseline: one GW problem on the two full distance matrices
T_gw0, log = gromov_wasserstein(dm_pandas[0], dm_pandas[1], log=True)
dGW0 = log["gw_dist"]

# One independent GW problem per scale idt
Ts_gw, dGWs = [], []
for idt in range(nSteps):
    T, log = gromov_wasserstein(lCs_pandas[0][idt], lCs_pandas[1][idt], log=True)
    Ts_gw.append(T)
    dGWs.append(log["gw_dist"])

# Multiscale problem over all per-scale matrices jointly
T_ms, log_ms = gromov_wasserstein_ms(lCs_pandas[0], lCs_pandas[1], log=True)
dMS = log_ms["gw_dist"]

In [None]:
# Per-scale GW couplings side by side, each titled with its GW value
fig1, axes1 = plt.subplots(1, nSteps, figsize=(10, 3))
for ax_t, T_t, d_t in zip(axes1, Ts_gw, dGWs):
    ax_t.imshow(T_t, aspect="auto")
    ax_t.set_title("dGW = %0.2f" % d_t)

# Multiscale coupling
fig2, axes2 = plt.subplots(figsize=(3, 3))
axes2.imshow(T_ms, aspect="auto")
axes2.set_title("dMS = %0.2f" % dMS)

# Coupling from GW on the full distance matrices
fig4, axes4 = plt.subplots(figsize=(3, 3))
axes4.imshow(T_gw0, aspect="auto")
axes4.set_title("dGW_0 = %0.2f" % dGW0)

In [None]:
# Check which coupling is best for GW_0. Each coupling is scored with the GW cost
# on the full distance matrices; then GW_0 and MS are scored with the multiscale cost.
D_a, D_b = dm_pandas[0], dm_pandas[1]

print("GW cost of couplings")
print("GW:", cost_gw(T_gw0, D_a, D_b))
for idt in range(nSteps):
    print("GW_{}: {}".format(idt, cost_gw(Ts_gw[idt], D_a, D_b)))
print("MS:", cost_gw(T_ms, D_a, D_b))
print()

print("Multiscale cost")
for label, coupling in (("GW:", T_gw0), ("MS:", T_ms)):
    print(label, cost_ms(coupling, lCs_pandas[0], lCs_pandas[1]))

In [None]:
# Export the MS, per-scale GW and GW_0 couplings into folder_figs
# (fs=4 is forwarded to save_couplings; presumably a figure size — confirm)
save_couplings(T_ms, Ts_gw, T_gw0, fs=4, folder=folder_figs)

## Repeating a panda experiment to get average values

In [None]:
# Repeat the panda experiment nReps times. rng is seeded once and never reset,
# so each repetition draws new pandas (assuming create_panda consumes rng).
# For every coupling, store its GW cost on the full distance matrices.
nReps = 50

params = [[25, 10, 10], [30, 12, 12]]
nPandas = len(params)
nSteps = len(params[0])

rng = np.random.default_rng(seed=304)
std = 0.05
ear_dist = 6
# One switch shared by create_panda and panda_position, so the layout cannot
# disagree with the graph (panda_position previously used its default).
paste_edge = False

# Row 0: GW_0 coupling; row idt + 1: coupling of the GW problem at scale idt
dists_gw = np.zeros((nSteps + 1, nReps))
# Single row: multiscale (MS) coupling
dists_ms = np.zeros((1, nReps))

for t in range(nReps):
    Ns = []
    Pandas = []
    Pandas_pos = []
    dm_pandas = []
    lCs_pandas = []
    for idx, param in enumerate(params):
        nh, ne1, ne2 = param[0], param[1], param[2]

        # Every panda after the first gets push_ears=True
        push_ears = idx > 0
        # NOTE(review): the single-run experiment uses add_neighbors=False;
        # confirm that add_neighbors=True is intended here.
        N, G, dm, lC = create_panda(
            nh,
            ne1,
            ne2,
            paste_edge=paste_edge,
            ear_dist=ear_dist,
            push_ears=push_ears,
            add_neighbors=True,
            rng=rng,
            std=std,
        )
        pos = panda_position(
            nh, ne1, ne2, paste_edge=paste_edge, ear_dist=ear_dist, push_ears=push_ears
        )

        Ns.append(N)
        Pandas.append(G)
        Pandas_pos.append(pos)
        dm_pandas.append(dm)
        lCs_pandas.append(lC)

    # Compute GW on distance matrix
    T_gw0, log = gromov_wasserstein(dm_pandas[0], dm_pandas[1], log=True)
    dGW0 = log["gw_dist"]

    # Compute GW at each time
    Ts_gw = []
    dGWs = []
    for idt in range(nSteps):
        C1 = lCs_pandas[0][idt]
        C2 = lCs_pandas[1][idt]
        T, log = gromov_wasserstein(C1, C2, log=True)

        Ts_gw.append(T)
        dGWs.append(log["gw_dist"])

    # Compute MS distance
    T_ms, log_ms = gromov_wasserstein_ms(lCs_pandas[0], lCs_pandas[1], log=True)
    dMS = log_ms["gw_dist"]

    # Store the GW cost (on the full distance matrices) of every optimal coupling
    # GW_0
    dists_gw[0, t] = cost_gw(T_gw0, dm_pandas[0], dm_pandas[1])

    # GW_t
    for idt in range(nSteps):
        dists_gw[idt + 1, t] = cost_gw(Ts_gw[idt], dm_pandas[0], dm_pandas[1])

    # MS
    dists_ms[0, t] = cost_gw(T_ms, dm_pandas[0], dm_pandas[1])

In [None]:
# One curve per coupling: its GW cost on the full distance matrices at each repetition.
# Row 0 of dists_gw is the GW_0 coupling; rows 1.. are the per-scale couplings.
gw_labels = ["Full panda", "Head", "Ear 1", "Ear 2"]
# The labels are hard-coded, so fail loudly instead of mislabeling if nSteps changes
assert len(gw_labels) == dists_gw.shape[0], "legend labels do not match dists_gw rows"

fig, ax = plt.subplots()
ax.plot(dists_gw.T)
ax.legend(gw_labels)
ax.set_xlabel("Repetition")
ax.set_ylabel("GW cost")
ax.set_title("GW cost at each coupling");

In [None]:
# GW cost (on the full distance matrices) of the multiscale coupling at each repetition
fig, ax = plt.subplots()
ax.plot(dists_ms.T)
ax.set_xlabel("Repetition")
ax.set_ylabel("GW cost")
ax.set_title("GW cost of MS coupling");

In [None]:
# Compute averages: mean and standard deviation (numpy default ddof=0) across
# repetitions, one value per coupling row
gw_mean, gw_std = dists_gw.mean(axis=1), dists_gw.std(axis=1)
ms_mean, ms_std = dists_ms.mean(axis=1), dists_ms.std(axis=1)

print("GW averages, std dev:")
print(gw_mean)
print(gw_std)
print()

print("MS average, std dev:")
print(ms_mean)
print(ms_std)

## Pandas with ears pasted at an edge

In [None]:
# Build two pandas with paste_edge=True and std=0.05 passed to create_panda.
# Every panda after the first is built with push_ears=True.
params = [[25, 10, 10], [30, 12, 12]]
nPandas = len(params)
nSteps = len(params[0])

rng = np.random.default_rng(seed=304)
std = 0.05
ear_dist = 5

Ns, Pandas, Pandas_pos, dm_pandas, lCs_pandas = [], [], [], [], []
for idx, (nh, ne1, ne2) in enumerate(params):
    push_ears = idx != 0
    N, G, dm, lC = create_panda(
        nh, ne1, ne2,
        paste_edge=True, ear_dist=ear_dist, push_ears=push_ears, rng=rng, std=std,
    )
    pos = panda_position(
        nh, ne1, ne2, paste_edge=True, ear_dist=ear_dist, push_ears=push_ears
    )

    Ns.append(N)
    Pandas.append(G)
    Pandas_pos.append(pos)
    dm_pandas.append(dm)
    lCs_pandas.append(lC)

In [None]:
# Assigning to `_` keeps the helper's return value from being echoed as cell output
_ = display_ms_pandas(Pandas, Pandas_pos, dm_pandas, lCs_pandas)

In [None]:
# One independent GW problem per scale idt
Ts_gw = []
dGWs = []
for idt in range(nSteps):
    T, log = gromov_wasserstein(lCs_pandas[0][idt], lCs_pandas[1][idt], log=True)
    Ts_gw += [T]
    dGWs += [log["gw_dist"]]

# Multiscale problem over all per-scale matrices jointly
T_ms, log_ms = gromov_wasserstein_ms(lCs_pandas[0], lCs_pandas[1], log=True)
dMS = log_ms["gw_dist"]

In [None]:
# Per-scale GW couplings side by side, each titled with its GW value
fig1, axes1 = plt.subplots(1, nSteps, figsize=(10, 3))
for ax_t, T_t, d_t in zip(axes1, Ts_gw, dGWs):
    ax_t.imshow(T_t, aspect="auto")
    ax_t.set_title("dGW = %0.2f" % d_t)

# Multiscale coupling
fig2, axes2 = plt.subplots(figsize=(3, 3))
axes2.imshow(T_ms, aspect="auto")
axes2.set_title("dMS = %0.2f" % dMS)