In [None]:
# To reload modified python modules
%load_ext autoreload
%autoreload 2

**Distances:**
Objective:
- Construct a family of graphs that the heat kernel at a fixed time cannot distinguish, but that a sequence of heat kernels can.

Try:
- Cycle of cycle of cliques, but each cycle has different length.

In [None]:
import numpy as np
from scipy.linalg import expm
from matplotlib import pyplot as plt
import matplotlib

# Optimal transport
from ot.gromov import gromov_wasserstein
from utils.gw_ms import gromov_wasserstein_ms

# Graph functions
import networkx as nx
import pygsp as pg

from itertools import product, chain

from time import time
from utils.utils import display_time

# Path to save figures
from pathlib import Path
folder_figs = Path("figures")

# Functions

In [None]:
def from_networkx(G):
    """Build a pygsp graph from the dense adjacency matrix of a networkx graph.

    Parameters
    ----------
    G : networkx graph
        Source graph; only its adjacency matrix is carried over.

    Returns
    -------
    pygsp.graphs.Graph
        Graph constructed from ``nx.adjacency_matrix(G).todense()``.
    """
    dense_adjacency = nx.adjacency_matrix(G).todense()
    return pg.graphs.Graph(dense_adjacency)


def cast_pg(G):
    """Convert a networkx graph to pygsp and prepare it for spectral work.

    The graph is converted with ``from_networkx``; afterwards
    ``compute_fourier_basis()`` and ``set_coordinates()`` are called on the
    converted object (their return values are not used) and that same object
    is returned.
    """
    converted = from_networkx(G)
    converted.compute_fourier_basis()
    converted.set_coordinates()
    return converted


def heat_kernel(G, t):
    """Return the heat kernel ``expm(-t * L)`` of a networkx graph.

    Parameters
    ----------
    G : networkx graph
        Graph whose Laplacian ``L`` is taken from ``nx.laplacian_matrix``.
    t : scalar
        Diffusion time multiplying the Laplacian.

    Returns
    -------
    Matrix exponential of ``-t * L`` (dense).
    """
    laplacian = nx.laplacian_matrix(G).todense()
    return expm(-t * laplacian)

def group_size_cliques(i):
    """Map an index to a group size: 0 maps to 1, any other ``i`` to ``i + 2``."""
    return 1 if i == 0 else i + 2


def relative_change(base, new, round=2):
    """Percentage change of ``new`` relative to ``base``.

    Parameters
    ----------
    base : scalar or array
        Reference value(s); used as the denominator.
    new : scalar or array
        Value(s) compared against ``base``.
    round : int or None, default 2
        Number of decimals passed to ``np.round``. ``None`` returns the
        unrounded change. (The name shadows the builtin inside this function;
        it is kept so existing keyword callers keep working.)

    Returns
    -------
    ``100 * (new - base) / base``, rounded to ``round`` decimals and cast to
    float unless ``round`` is None.
    """
    change = 100 * (new - base) / base

    if round is None:
        return change

    # Honour the requested precision (it used to be hard-coded to 2 decimals)
    return np.round(change, round).astype(float)

In [None]:
def sampling_simplex_log(exp_0, d, n_samples=None):
    """
    Sample points of the simplex from a logarithmically spaced grid.

    A 1-D set of values ``np.logspace(1, exp_0, num)`` is built, where ``num``
    is ``n_samples`` or, when that is None, ``exp_0`` itself. All ``d``-tuples
    of those values are formed, each tuple is divided by its sum so that it
    adds up to 1, and duplicate rows are removed.

    NOTE: exp_0 is different from np.min(np.log10(X)).
    Currently, I don't know how to predict the latter number.

    Parameters
    ----------
    exp_0 : int
        Upper exponent of the log-spaced values (and their count when
        ``n_samples`` is None).
    d : int
        Number of coordinates of every sample (d >= 1).
    n_samples : int, optional
        Number of log-spaced values per coordinate.

    Returns
    -------
    X : ndarray of shape (n_points, d)
        Unique rows, each summing to 1, in the sorted order of ``np.unique``.
    """
    # Uniformly sample from logspace
    num = exp_0 if n_samples is None else n_samples
    E0 = np.logspace(1, exp_0, num)

    # Create a d-dim grid. Matrix ("ij") indexing makes the k-th output vary
    # along axis k, which is what the manual swap of the first two outputs
    # achieved before -- except that the swap failed with IndexError for d=1.
    E_mesh = np.meshgrid(*[E0] * d, indexing="ij")

    # One column per coordinate, one row per grid point
    E = np.stack([M.ravel() for M in E_mesh], axis=1)

    # Normalize the sum of each row
    X = E / E.sum(axis=1, keepdims=True)

    # Remove duplicate points
    return np.unique(X, axis=0)


def sample_simplex_unif(N, d, rng=None):
    """Draw ``N`` random points with ``d`` coordinates that each sum to 1.

    Uniform draws ``U`` are mapped to ``-log10(U)`` and every row is divided
    by its sum.

    Parameters
    ----------
    N : int
        Number of points (rows).
    d : int
        Number of coordinates (columns).
    rng : None, int or generator, optional
        ``None`` creates a fresh ``np.random.default_rng()``; an int is used
        as its seed; anything else is used as-is and must provide
        ``uniform(size=...)``.

    Returns
    -------
    ndarray of shape (N, d)
    """
    if isinstance(rng, int):
        rng = np.random.default_rng(seed=rng)
    elif rng is None:
        rng = np.random.default_rng()

    weights = -np.log10(rng.uniform(size=(N, d)))
    row_totals = np.sum(weights, axis=1)

    return weights / np.repeat(row_totals[:, np.newaxis], d, axis=1)

## Graph generators

In [None]:
def cycle_of_cliques(num_groups, gp_size):
    """Cycle of ``num_groups`` cliques with ``gp_size`` vertices each.

    * ``gp_size == 1`` returns ``nx.cycle_graph(num_groups)``.
    * ``gp_size == 2`` is rejected with ``NotImplementedError`` (the original
      author noted it doesn't work as intended).
    * any other size returns ``nx.ring_of_cliques(num_groups, gp_size)``.
    """
    if gp_size == 2:
        raise NotImplementedError

    if gp_size == 1:
        # Groups of a single vertex: a plain cycle
        return nx.cycle_graph(num_groups)

    # Larger communities
    return nx.ring_of_cliques(num_groups, gp_size)


def cycle_of_generators(num_groups, generator, *args):
    """Cycle of ``num_groups`` graphs, all built as ``generator(*args)``.

    Thin wrapper around ``cycle_of_generators_variable`` that repeats the same
    argument list for every group.
    """
    shared_args = list(args)
    # Every entry refers to the same list object; it is only read downstream
    repeated = [shared_args for _ in range(num_groups)]
    return cycle_of_generators_variable(num_groups, generator, repeated)


def cycle_of_generators_variable(num_groups, generator, arglist):
    """Arrange several generated graphs in a cycle joined at their 0-th vertices.

    For ``k`` in ``range(num_groups)`` the graph ``generator(*arglist[k])`` is
    created. Its vertices ``0..n_k-1`` are shifted past the labels already in
    use, the graph is merged into the result, and its (shifted) vertex 0 is
    connected to the vertex 0 of the previous graph. Finally the last
    basepoint is connected back to vertex 0 of the first graph.

    Parameters
    ----------
    num_groups : int
        Number of graphs to generate.
    generator : callable
        Returns a networkx graph; the relabelling assumes its vertices are
        ``0..n-1``.
    arglist : sequence
        ``arglist[k]`` holds the positional arguments of the k-th call.

    Returns
    -------
    networkx graph containing all generated graphs plus the cycle edges.
    """
    offset = 0      # first label not yet used in the combined graph
    last_base = 0   # label of the basepoint of the most recent subgraph

    for k in range(num_groups):
        piece = generator(*arglist[k])
        size = piece.number_of_nodes()

        if k > 0:
            # Shift the labels of the new piece, merge it, and link its
            # basepoint to the previous basepoint
            shift = {v: v + offset for v in range(size)}
            G = nx.compose(G, nx.relabel_nodes(piece, shift))
            G.add_edge(last_base, offset)
            last_base = offset
        else:
            # The first graph is the seed of the combined graph
            G = piece

        offset += size

    # Close the cycle
    G.add_edge(last_base, 0)

    return G


def cycle_of_generators_fun(generator):
    """Bind ``generator`` into ``cycle_of_generators``.

    Returns a function ``f(num_groups, *args)`` equal to
    ``cycle_of_generators(num_groups, generator, *args)``, so the result can
    itself be used as a generator for a further level of nesting.
    """
    def bound_cycle(num_groups, *args):
        return cycle_of_generators(num_groups, generator, *args)

    return bound_cycle

def nested_cycles(*args):
    """Build nested cycles whose innermost pieces are complete graphs.

    Starting from ``nx.complete_graph``, the generator is wrapped with
    ``cycle_of_generators_fun`` once per argument beyond the first, and the
    resulting generator is called with all of ``args``. Each wrapping level
    consumes the leading argument as its number of groups, so the last
    argument is the one handed to ``nx.complete_graph``.
    """
    generator = nx.complete_graph

    # One level of nesting per argument after the first
    for _ in args[1:]:
        generator = cycle_of_generators_fun(generator)

    return generator(*args)

def wedge_of_generators(generator, arglist):
    """Wedge sum of generated graphs: all of them glued at their 0-th vertex.

    For every ``args`` in ``arglist`` the graph ``generator(*args)`` is
    created. Vertex 0 of each graph keeps label 0, so all graphs share it;
    the remaining vertices ``1..n_i-1`` are shifted to the next unused labels.
    The result therefore has contiguous labels
    ``0 .. sum(n_i) - (len(arglist) - 1) - 1``.

    Parameters
    ----------
    generator : callable
        Returns a networkx graph; the relabelling assumes its vertices are
        ``0..n-1``.
    arglist : sequence
        One positional-argument sequence per graph; must not be empty.

    Returns
    -------
    networkx graph

    Raises
    ------
    ValueError
        If ``arglist`` is empty (there is no graph to return).
    """
    if len(arglist) == 0:
        raise ValueError("arglist must contain at least one set of arguments")

    G = None
    n_nodes = 0  # number of vertices currently in G

    for gen_args in arglist:
        # Create graph
        G_i = generator(*gen_args)
        n_i = G_i.number_of_nodes()

        # The first graph seeds the wedge
        if G is None:
            G = G_i
            n_nodes = n_i
            continue

        # Vertex 0 is shared, so only vertices 1..n_i-1 receive new labels,
        # starting at n_nodes. A relabelled copy is used so the graph
        # returned by the generator is left untouched.
        mapping = {t: t + n_nodes - 1 for t in range(1, n_i)}
        G = nx.compose(G, nx.relabel_nodes(G_i, mapping))

        # Only n_i - 1 vertices are new (vertex 0 was already present).
        # Adding n_i here used to leave one unused label per attached graph.
        n_nodes += n_i - 1

    return G


def wedge_of_generators_fun(generator):
    """Bind ``generator`` into ``wedge_of_generators``.

    Returns a function ``f(arglist)`` equal to
    ``wedge_of_generators(generator, arglist)``.
    """
    def bound_wedge(arglist):
        return wedge_of_generators(generator, arglist)

    return bound_wedge

## Plotting graphs

In [None]:
def pos_cycle(m, R=1, c=[0,0], t0=np.pi/2):
    """Positions of ``m`` points equally spaced on a circle.

    Parameters
    ----------
    m : int
        Number of points.
    R : float, default 1
        Radius of the circle.
    c : sequence of 2 floats, default [0, 0]
        Centre of the circle. It is only read, never modified, so the list
        default is safe. (It used to be ignored.)
    t0 : float, default pi/2
        Angle of the first point, in radians.

    Returns
    -------
    pos : ndarray of shape (m, 2)
        Row k is ``c + R * (cos(t0 + 2*pi*k/m), sin(t0 + 2*pi*k/m))``.
    """
    # m+1 samples of [0, 2*pi] and drop the last, which repeats the first
    tt = t0 + np.linspace(0, 2*np.pi, m+1)[:-1]

    pos = np.column_stack([R*np.cos(tt), R*np.sin(tt)])

    # Translate to the requested centre
    return pos + np.asarray(c, dtype=float)

def rotation_matrix(t):
    """2x2 matrix of the counter-clockwise rotation by angle ``t`` (radians)."""
    cos_t = np.cos(t)
    sin_t = np.sin(t)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])

In [None]:
def pos_cycle_of_graphs(n, pos_list, R=4, t0=np.pi):
    """Lay out ``n`` subgraphs around an outer circle.

    Parameters
    ----------
    n : int
        Number of subgraphs, i.e. vertices of the outer cycle.
    pos_list : sequence of arrays
        ``pos_list[k]`` holds the default positions (rows of 2 coordinates)
        of the k-th subgraph.
    R : float, default 4
        Radius of the outer circle, passed to ``pos_cycle``.
    t0 : float, default pi
        Extra rotation applied to every subgraph on top of the angle
        ``2*pi*k/n`` of its slot.

    Returns
    -------
    ndarray with 2 columns: the positions of all subgraphs stacked in order.
    """
    # Angle of every slot on the outer cycle
    slot_angles = np.linspace(0, 2*np.pi, n+1)[:-1]

    # Location of every slot
    anchors = pos_cycle(n, R=R)

    # Rotate each subgraph by its slot angle (plus t0), then translate it to
    # its slot. The empty seed keeps the result well defined for n = 0.
    placed = [np.zeros((0, 2))]
    for k in range(n):
        rot = rotation_matrix(slot_angles[k] + t0)
        placed.append(anchors[k] + pos_list[k] @ rot.T)

    return np.concatenate(placed, axis=0)

def pos_cycle_uniform(n, pos_0, R=4, t0=np.pi):
    """Call ``pos_cycle_of_graphs`` with the same layout ``pos_0`` for all ``n`` subgraphs."""
    same_layout = [pos_0 for _ in range(n)]
    return pos_cycle_of_graphs(n, same_layout, R=R, t0=t0)

def pos_nested_cycles(*args, scale=4):
    """Positions for nested cycles, with sizes given from outer to inner.

    The innermost size (last argument) is laid out with ``pos_cycle``. Each
    enclosing level ``k`` (1 for the level just outside the innermost one)
    replicates the current layout with ``pos_cycle_uniform`` on a circle of
    radius ``scale**k``.
    """
    # Arguments arrive outer -> inner, but the layout is built inner -> outer
    inner_first = args[::-1]

    pos = pos_cycle(inner_first[0])
    for depth, count in enumerate(inner_first[1:], start=1):
        pos = pos_cycle_uniform(count, pos, R=scale**depth)

    return pos

## Block structure

In [None]:
def blocks(A, n, m):
    """Split the last two dimensions of ``A`` into blocks of size n-by-m.

    Assumes the number of rows is divisible by ``n`` and the number of
    columns by ``m``.

    Returns an array of shape ``(..., N, M, n, m)`` with ``N = rows // n`` and
    ``M = cols // m``; entry ``[..., I, J]`` is the block
    ``A[..., I*n:(I+1)*n, J*m:(J+1)*m]``.
    """
    shape = A.shape
    rows, cols = shape[-2], shape[-1]
    leading = shape[:-2]

    n_row_blocks = rows // n
    n_col_blocks = cols // m

    # Row-major reshape: the row index splits into (block row, row in block)
    # and the column index into (block column, column in block), giving axes
    # [N, n, M, m].
    split = A.reshape(*leading, n_row_blocks, n, n_col_blocks, m)

    # Exchange the two middle axes so the block indices come first: [N, M, n, m]
    return split.swapaxes(-3, -2)


def sum_blocks(A):
    """Sum ``A`` over its last two dimensions."""
    last_two_axes = (-2, -1)
    return np.sum(A, axis=last_two_axes)

In [None]:
def is_cyclic_permutation(T, eps=1e-15):
    """Check whether the support of a coupling is a cyclic permutation.

    Entries not larger than ``eps`` are ignored. The remaining entries at
    positions ``(r, c)`` are accepted when ``T`` is square with ``n`` rows,
    there are exactly ``n`` of them, and either ``(r - c) mod n`` or
    ``(r + c) mod n`` takes a single value. The second alternative means
    that order-reversing patterns are accepted as well as shifts.

    Returns
    -------
    bool
    """
    n = T.shape[0]

    # Only square matrices can be permutations
    if n != T.shape[1]:
        return False

    # Support of T after discarding small entries
    rows, cols = np.where(T > eps)

    # A permutation has exactly one entry per row
    if len(rows) != n:
        return False

    shifts = np.unique(np.mod(rows - cols, n))
    reflections = np.unique(np.mod(rows + cols, n))

    return len(shifts) == 1 or len(reflections) == 1


def cyclic_by_levels(tree, dtype=int):
    """Summarise, per depth level, whether every node of the tree is cyclic.

    Depth is the shortest-path distance from node 0. Levels
    ``0 .. max_depth - 1`` are evaluated; the deepest level (the leaves) is
    skipped.

    Parameters
    ----------
    tree : networkx graph
        Every evaluated node must carry a ``"cyclic"`` attribute.
    dtype : type, default int
        Type of the returned flags.

    Returns
    -------
    ndarray of length ``max_depth``: entry ``k`` tells whether all nodes at
    depth ``k`` have a truthy ``"cyclic"`` attribute.
    """
    depth = nx.shortest_path_length(tree, 0)
    max_depth = max(depth.values())

    # Don't evaluate leaves -- they are not cyclic
    flags = np.zeros(max_depth, dtype=bool)
    for level in range(max_depth):
        members = [v for v in tree.nodes if depth[v] == level]
        flags[level] = all([tree.nodes[v]["cyclic"] for v in members])

    return flags.astype(dtype)


def all_cyclic_permutations(tree):
    """Tell whether every non-leaf node of the tree is a cyclic permutation.

    ``tree`` is a networkx tree generated by ``extract_multiscale_info``.
    A child of a node is the sub-coupling induced by one block of the parent
    coupling, so only nodes with children (out-degree > 0) are inspected,
    through their ``"cyclic"`` attribute.

    Returns
    -------
    bool
        True when all inspected nodes are cyclic (also when there are none).
    """
    internal_nodes = (v for v in tree.nodes() if tree.out_degree(v) > 0)
    return all(tree.nodes[v]["cyclic"] for v in internal_nodes)


def extract_multiscale_info(T, scales1, scales2, eps=1e-15):
    """Build a tree describing a coupling at every scale of a nested structure.

    The rows of ``T`` are grouped according to ``scales1`` and the columns
    according to ``scales2`` (outermost scale first). At each level ``T`` is
    cut into ``scales1[0]``-by-``scales2[0]`` blocks; the block sums form a
    compressed coupling stored in the node, and every block whose sum exceeds
    ``eps`` becomes a child that is analysed recursively with the remaining
    scales.

    Node attributes
    ---------------
    coup : the (compressed) coupling of that node.
    cyclic : result of ``is_cyclic_permutation`` on ``coup``.
    loc : (block row, block column) of the node inside its parent
        (children only).
    all_cyclic : result of ``all_cyclic_permutations`` on the subtree rooted
        at the node (set only on nodes that were split further).

    Parameters
    ----------
    T : 2-D array
        Coupling of shape ``(prod(scales1), prod(scales2))``.
    scales1, scales2 : sequences of int of equal length
    eps : float
        Threshold below which entries and block sums are treated as zero.

    Returns
    -------
    networkx.DiGraph rooted at node 0.
    """
    # Check that the size of T matches the scales
    assert T.shape[0] == np.prod(
        scales1
    ), "T.shape[0] doesn't match product of elements of scales1"
    assert T.shape[1] == np.prod(
        scales2
    ), "T.shape[1] doesn't match product of elements of scales2"
    assert len(scales1) == len(scales2), "scales1 and scales2 must have the same length"

    # The tree starts as a single node for this block
    tree = nx.DiGraph()
    tree.add_node(0)

    # Innermost scale: store the coupling itself and stop
    if len(scales1) <= 1:
        is_cyclic = is_cyclic_permutation(T, eps=eps)
        tree.nodes[0]["coup"] = T
        tree.nodes[0]["cyclic"] = is_cyclic
        return tree

    # Cut T into blocks at the outermost scale and add up every block
    inner1 = scales1[1:]
    inner2 = scales2[1:]
    B = blocks(T, np.prod(inner1), np.prod(inner2))
    compressed = sum_blocks(B)

    # Store the compressed coupling and whether it is a cyclic permutation
    is_cyclic = is_cyclic_permutation(compressed, eps=eps)
    tree.nodes[0]["coup"] = compressed
    tree.nodes[0]["cyclic"] = is_cyclic

    # Recurse into every block that carries mass
    block_rows, block_cols = np.where(compressed > eps)
    next_label = 1
    for bi, bj in zip(block_rows, block_cols):
        child = extract_multiscale_info(B[bi, bj], inner1, inner2, eps=eps)

        # Remember where this block sits inside its parent
        child.nodes[0]["loc"] = (bi, bj)

        # Shift the labels of the child past the ones in use, then attach its
        # root below node 0
        size = child.number_of_nodes()
        shift = {t: t + next_label for t in range(size)}
        child = nx.relabel_nodes(child, shift)

        tree = nx.union(tree, child)
        tree.add_edge(0, next_label)

        next_label += size

    # Indicate if all non-leaf nodes have a cyclic permutation
    tree.nodes[0]["all_cyclic"] = all_cyclic_permutations(tree)

    return tree

# Examples

In [None]:
# Cycle whose vertices are themselves cycle graphs of different lengths
sizes = [[3],[4],[5],[6],[7]]
n = len(sizes)
G = cycle_of_generators_variable(n, nx.cycle_graph, sizes)

# Layout: every inner cycle on its own circle, placed around the outer cycle
pos_list = [pos_cycle(length) for (length,) in sizes]
pos = pos_cycle_of_graphs(n, pos_list)

# The basepoint of every inner cycle is its first vertex, i.e. the running
# total of the sizes of the cycles before it
basepoints = set()
first_label = 0
for (length,) in sizes:
    basepoints.add(first_label)
    first_label += length

# Basepoints in red, all other vertices in blue
color_map = [
    'red' if vertex in basepoints else 'blue'
    for vertex in range(G.number_of_nodes())
]

fig1, ax1 = plt.subplots()
fig1.tight_layout(pad=0)
nx.draw(G, pos, node_size=50, ax=ax1, node_color=color_map)

In [None]:
# Nested cycles; params are ordered from the outer cycle inwards
params = [5, 4, 7]
G = nested_cycles(*params)
pos = pos_nested_cycles(*params, scale=3)

# Every outer group spans ds consecutive vertices, so the basepoints of the
# outer groups are the multiples of ds
ds = np.prod(params[1:])

# Basepoints in red, all other vertices in blue
color_map = [
    'red' if vertex % ds == 0 else 'blue'
    for vertex in range(G.number_of_nodes())
]

fig2, ax2 = plt.subplots()
fig2.tight_layout()
nx.draw(G, pos, node_size=25, ax=ax2, node_color=color_map)

In [None]:
# Save for publication
fig_list = [(fig1, ax1), (fig2, ax2)]
names = ['Cycle_of_graphs.pdf', 'Cycle_of_cliques.pdf']

# savefig does not create missing directories, so make sure the target
# folder exists before writing (otherwise a fresh checkout fails here)
folder_figs.mkdir(parents=True, exist_ok=True)

for (fig, ax), name in zip(fig_list, names):
    # Remove whitespace around NetworkX plots
    # (it didn't work, but I'm leaving it here for reference)
    ax.set_axis_off()
    fig.subplots_adjust(top=1, bottom=0, right=1, left=0,
                        hspace=0, wspace=0)
    ax.margins(0,0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    fig.savefig(folder_figs / name, bbox_inches="tight", pad_inches=0)

# Heat kernels of 2-nested cycles

## Generate graphs

In [None]:
# Sizes of the cycles: [outer_cycle, inner_cycle]
cycles = [10, 5]
# Clique sizes to compare; one graph per size
gp_sizes = [5, 20]

# One parameter tuple (outer, inner, clique size) per graph
params2 = [(*cycles, sz) for sz in gp_sizes]

print(params2)
nGraphs2 = len(params2)

# Generator function, built one nesting level at a time
# f args: num_groups, group_size
f = cycle_of_generators_fun(nx.complete_graph)
# generator args: num_groups_outer, num_groups_inner, group_size
generator = cycle_of_generators_fun(f)

In [None]:
# Set to True to draw every graph (with vertex labels) next to each other
plot = False
sp = 0

# One nested-cycle graph per parameter tuple
Graphs2 = [generator(*param) for param in params2]

if plot:
    fig, axes = plt.subplots(1, nGraphs2, figsize=(20, 5))
    for idx, G in enumerate(Graphs2):
        nx.draw(G, ax=axes[idx], with_labels=True)

In [None]:
fig, axes = plt.subplots(1, nGraphs2, figsize=(20, 5))

# Show the adjacency matrix of every graph as an image
for ax, graph in zip(axes, Graphs2):
    adjacency = nx.adjacency_matrix(graph).todense().astype(np.uint8)
    ax.imshow(adjacency)

## Generate heat kernels

In [None]:
# Options for the heat kernels and the distance computations below
normalize = False          # divide every kernel by its norm
scale_0 = False            # divide the kernel at t = 0 by 100
compute_diagonal = False   # also compare every graph with itself

# Time grid: nSteps_all values starting at step0_all, spaced by step_all
step0_all = 10
step_all = 10
nSteps_all = 50
Times_all = step0_all + step_all * np.arange(nSteps_all)
print(Times_all)

In [None]:
# Heat kernels of every graph at every time: Hs_all[idg][idt]
Hs_all = []

for idg, G in zip(range(nGraphs2), Graphs2):
    # One matrix per time step
    H_G = []
    for idt, t in zip(range(nSteps_all), Times_all):
        time_start = time()

        H = heat_kernel(G, t)
        if normalize:
            # Scale by the norm of the whole matrix
            H /= np.linalg.norm(H)
        if scale_0 and t == 0:
            H /= 100

        H_G.append(H)

        # Progress report: (graph, step)/(number of graphs, number of steps)
        elapsed = display_time(time() - time_start)
        print(
            "(%i, %i)/(%i, %i): %s"
            % (idg + 1, idt + 1, nGraphs2, nSteps_all, elapsed)
        )

    Hs_all.append(H_G)

### Heat kernel graphs

Show ranges of heat kernels

In [None]:
# Summary statistics of every heat kernel; the last axis holds
# 0: min
# 1: max
# 2: average
H_ranges_all = np.zeros((nGraphs2, nSteps_all, 3))
H_ranges_normalized = np.zeros((nGraphs2, nSteps_all, 3))

for idg, idt in product(range(nGraphs2), range(nSteps_all)):
    # Kernel as stored
    H = Hs_all[idg][idt]
    H_ranges_all[idg, idt] = (np.min(H), np.max(H), np.mean(H))

    # Kernel divided by its norm
    H = H / np.linalg.norm(H)
    H_ranges_normalized[idg, idt] = (np.min(H), np.max(H), np.mean(H))

In [None]:
colors = ["red", "blue"]

fig, axes = plt.subplots(2, 2, sharey=True, figsize=(10, 10))

# Left column: kernels as stored; right column: kernels divided by their norm
panels = [
    (H_ranges_all, "Unnormalized"),
    (H_ranges_normalized, "Normalized"),
]

for col, (ranges, title) in enumerate(panels):
    ax_range = axes[0, col]
    ax_spread = axes[1, col]

    for idg in range(nGraphs2):
        shade = colors[idg]
        tag = f"Graph {idg}"

        # Top row: min and max (solid) and average (dashed) over time
        ax_range.plot(Times_all, ranges[idg, :, 0], shade)
        ax_range.plot(Times_all, ranges[idg, :, 1], shade, label=tag)
        ax_range.plot(Times_all, ranges[idg, :, 2], shade, linestyle="--")

        # Bottom row: max minus min over time
        ax_spread.plot(
            Times_all,
            ranges[idg, :, 1] - ranges[idg, :, 0],
            color=shade,
            label=tag,
        )

    ax_range.set_title(title)
    ax_range.legend()
    ax_range.set_yscale("log")

    ax_spread.set_title("Spread")
    ax_spread.legend()
    ax_spread.set_yscale("log")

## GW distance

In [None]:
# Compute dGW between every pair of graphs at each time step.
# dGW_all[idt, i, j] is the distance at step idt; Ts_all[idt, i, j] is the
# optimal coupling (stored for i <= j only).
dGW_all = np.zeros((nSteps_all, nGraphs2, nGraphs2))
Ts_all = np.zeros((nSteps_all, nGraphs2, nGraphs2), dtype=object)

# Block trees, appended in (idt, i, j) order. Later cells read
# Trees_all[idt], which matches this order only while exactly one pair is
# computed per step (two graphs and compute_diagonal == False).
Trees_all = []

for idt in range(nSteps_all):
    print("Step: %i/%i" % (idt + 1, nSteps_all))
    for i in range(nGraphs2):
        H1 = Hs_all[i][idt]

        for j in range(i, nGraphs2):
            if j == i and not compute_diagonal:
                continue

            print((i, j))
            H2 = Hs_all[j][idt]

            time_start = time()
            T, log = gromov_wasserstein(
                H1,
                H2,
                log=True,
                verbose=0,
                tol_abs=1e-15,
                tol_rel=1e-15,
            )
            time_end = time()
            print("Time: " + display_time(time_end - time_start))
            print()

            # The time index is required here: dGW_all[i, j] would address
            # dGW_all[step=i, graph=j, :] and overwrite a whole row.
            dist = 0.5 * np.sqrt(log["gw_dist"])
            dGW_all[idt, i, j] = dist
            dGW_all[idt, j, i] = dist
            Ts_all[idt, i, j] = T

            # Block tree, with the scales of the two graphs being compared
            Trees_all.append(extract_multiscale_info(T, params2[i], params2[j]))

In [None]:
# Grid of subplots: one panel per time step, at most 5 per row
cols = np.min([nSteps_all, 5])
ps = [int(np.ceil(nSteps_all / cols)), cols]
scale = 4
fig, axes = plt.subplots(ps[0], ps[1], figsize=(scale * ps[1], scale * ps[0]))

# Row-major flat view of the panels, so that step idt goes to panel idt
# whether the grid has one row or several. This replaces axes[*sp_idx],
# which is a syntax error before Python 3.11.
axes_flat = np.ravel(axes)

# Plot all couplings
for idt in range(nSteps_all):
    ax = axes_flat[idt]

    # Plot all couplings with the same color range
    for i in range(nGraphs2):
        for j in range(i + 1, nGraphs2):
            T = Ts_all[idt, i, j]
            cyclic = cyclic_by_levels(Trees_all[idt])

            im = ax.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Let the image fill the (square) panel
                aspect="auto",
                # Pixels are enlarged to improve visibility
                interpolation="auto",
                interpolation_stage="data",
            )
            ax.set_title("t=%i, cyclic=%s" % (Times_all[idt], str(cyclic)))

### Format graphs for paper

In [None]:
# Times shown in the publication figure (each must appear in Times_all)
Times_plot = [10, 50, 100, 150, 200]
num_plots = len(Times_plot)

with matplotlib.rc_context({"font.size": 25}):
    scale = 5
    fig, axes = plt.subplots(
        1, num_plots, figsize=(scale * num_plots, scale + 0.5), sharey=False
    )
    for ax, t_plot in zip(axes, Times_plot):
        # Position of this time in the global list of times
        idt = np.where(Times_all == t_plot)[0][0]

        # Coupling between graphs 0 and 1, and the levels at which it is cyclic
        T = Ts_all[idt, 0, 1]
        cyclic = cyclic_by_levels(Trees_all[idt])

        im = ax.imshow(
            T,
            vmin=0,
            vmax=0.0001,
            # Let the image fill the panel
            aspect="auto",
            # Pixels are enlarged to improve visibility
            interpolation="auto",
            interpolation_stage="data",
        )
        ax.set_title("t=%i\ncyclic=%s" % (Times_all[idt], str(cyclic)))

    # Save for publication
    plt.tight_layout()
    plt.savefig(folder_figs / "Nested_2_cycles_GW_small.pdf", bbox_inches="tight")

## Multiscale distance, 2 levels

In [None]:
# Times whose heat kernels enter the multiscale distance, and their
# positions in Times_all
Times_nu = [50, 200]
I_nu = np.flatnonzero(np.isin(Times_all, Times_nu))

# Search over nu: the first weight runs over nNus log-spaced values from
# 10**nu_log_0 up to (but excluding) 10**0
nu_log_0 = -10
nNus = 20
first_weight = np.logspace(nu_log_0, 0, nNus + 1)[:-1]

# One row per nu; the second weight makes each row sum to 1
nus = np.column_stack((first_weight, 1 - first_weight))

print(nus)
print("nNus:", nNus)

In [None]:
# dMS -- vary nu
# dMS_nu[idv, i, j] is the multiscale distance for the idv-th weight vector;
# Ts_nu[idv, i, j] is the corresponding coupling.
dMS_nu = np.zeros((nNus, nGraphs2, nGraphs2))
Ts_nu = np.zeros((nNus, nGraphs2, nGraphs2), dtype=object)

# Block trees, appended in (idv, i, j) order. Later cells read Trees_nu[idv],
# which matches this order only while exactly one pair is computed per nu.
Trees_nu = []

for idv in range(nNus):
    print("nu %i/%i" % (idv + 1, nNus))
    with np.printoptions(precision=3):
        print(nus[idv, :])

    for i in range(nGraphs2):
        for j in range(i, nGraphs2):
            if j == i and not compute_diagonal:
                continue

            print((i, j))

            # Load cost matrices and nu
            nu = nus[idv, :]
            H1 = [Hs_all[i][idt] for idt in I_nu]
            H2 = [Hs_all[j][idt] for idt in I_nu]

            # dMS
            time_start = time()
            T, log = gromov_wasserstein_ms(
                H1, H2, nu=nu, log=True, verbose=0, tol_abs=1e-30, tol_rel=1e-30
            )
            time_end = time()
            print("dMS_nu:" + display_time(time_end - time_start))

            dist = 0.5 * np.sqrt(log["gw_dist"])
            dMS_nu[idv, i, j] = dist
            dMS_nu[idv, j, i] = dist
            Ts_nu[idv, i, j] = T
            Ts_nu[idv, j, i] = T.T

            # Block tree of the coupling just computed, with the scales of
            # the two graphs involved (this used to be hard-coded to the
            # pair (0, 1))
            Trees_nu.append(extract_multiscale_info(T, params2[i], params2[j]))

            print()

In [None]:
# Grid of subplots: one panel per nu, at most 5 per row
cols = np.min([nNus, 5])
ps = [int(np.ceil(nNus / cols)), cols]
fig, axes = plt.subplots(ps[0], ps[1], figsize=(3.5 * ps[1], 3.5 * ps[0]))

# Row-major flat view of the panels, so that nu number idv goes to panel idv
# whether the grid has one row or several. The single-row case used to index
# with the unrelated variable idt, and axes[*sp_idx] is a syntax error
# before Python 3.11.
axes_flat = np.ravel(axes)

# Plot all couplings
for idv in range(nNus):
    ax = axes_flat[idv]

    # Plot all couplings with the same color range
    for i in range(nGraphs2):
        for j in range(i + 1, nGraphs2):
            T = Ts_nu[idv, i, j]
            cyclic = cyclic_by_levels(Trees_nu[idv])

            im = ax.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Let the image fill the (square) panel
                aspect="auto",
                # Pixels are enlarged to improve visibility
                interpolation="auto",
                interpolation_stage="data",
            )

            ax.set_title("nu[0]=%0.2E, cyclic: %s" % (nus[idv, 0], str(cyclic)))

### Format graphs for paper

In [None]:
# Rows of nus shown in the publication figure
nus_plot = [4, 12, 14, 16, 18]
num_plots = len(nus_plot)

with matplotlib.rc_context({"font.size": 25}):
    scale = 5
    fig, axes = plt.subplots(
        1, num_plots, figsize=(scale * num_plots, scale + 0.5), sharey=False
    )
    for ax, idv in zip(axes, nus_plot):
        # Coupling between graphs 0 and 1, and the levels at which it is cyclic
        T = Ts_nu[idv, 0, 1]
        cyclic = cyclic_by_levels(Trees_nu[idv])

        im = ax.imshow(
            T,
            vmin=0,
            vmax=0.0001,
            # Let the image fill the panel
            aspect="auto",
            # Pixels are enlarged to improve visibility
            interpolation="auto",
            interpolation_stage="data",
        )

        ax.set_title("nu[0]=%0.2E\ncyclic=%s" % (nus[idv, 0], str(cyclic)))

    # Save for publication
    plt.tight_layout()
    plt.savefig(
        folder_figs / "Nested_2_cycles_MS_2_small.pdf", bbox_inches="tight"
    )

## Multiscale distance, 3 levels

In [None]:
# Times whose heat kernels enter the multiscale distance, and their
# positions in Times_all
Times_nu = [50, 100, 200]
I_nu = np.flatnonzero(np.isin(Times_all, Times_nu))

# Search over nu: log-spaced small weights from 10**nu_log_0 up to
# (but excluding) 10**0
nu_log_0 = -5
nNus_0 = 5
nus_0 = np.logspace(nu_log_0, 0, nNus_0 + 1)[:-1]

print(nus_0)
print()

# All pairs of small weights as column vectors; the third weight completes
# every row to 1
nus_1, nus_2 = (grid.reshape(-1, 1) for grid in np.meshgrid(nus_0, nus_0))
nus_3 = 1 - (nus_1 + nus_2)

# Vary which of the 3 positions receives the large weight
nus_v3 = np.vstack(
    [
        np.hstack((nus_1, nus_2, nus_3)),
        np.hstack((nus_1, nus_3, nus_2)),
        np.hstack((nus_3, nus_1, nus_2)),
    ]
)

# NOTE: this rebinds nNus, which the 2-level section above also uses
nNus = nus_v3.shape[0]
print("nNus:", nNus)

In [None]:
# dMS -- vary nu
# dMS_nu_v3[idv, i, j] is the multiscale distance for the idv-th weight
# vector; Ts_nu_v3[idv, i, j] is the corresponding coupling.
dMS_nu_v3 = np.zeros((nNus, nGraphs2, nGraphs2))
Ts_nu_v3 = np.zeros((nNus, nGraphs2, nGraphs2), dtype=object)

# Block trees, appended in (idv, i, j) order. Later cells read
# Trees_nu_v3[idv], which matches this order only while exactly one pair is
# computed per nu.
Trees_nu_v3 = []

for idv in range(nNus):
    print("nu %i/%i" % (idv + 1, nNus))
    with np.printoptions(precision=3):
        print(nus_v3[idv, :])

    for i in range(nGraphs2):
        for j in range(i, nGraphs2):
            if j == i and not compute_diagonal:
                continue

            print((i, j))

            # Load cost matrices and nu
            nu = nus_v3[idv, :]
            H1 = [Hs_all[i][idt] for idt in I_nu]
            H2 = [Hs_all[j][idt] for idt in I_nu]

            # dMS
            time_start = time()
            T, log = gromov_wasserstein_ms(
                H1, H2, nu=nu, log=True, verbose=0, tol_abs=1e-30, tol_rel=1e-30
            )
            time_end = time()
            print("dMS_nu:" + display_time(time_end - time_start))

            dist = 0.5 * np.sqrt(log["gw_dist"])
            dMS_nu_v3[idv, i, j] = dist
            dMS_nu_v3[idv, j, i] = dist
            Ts_nu_v3[idv, i, j] = T
            Ts_nu_v3[idv, j, i] = T.T

            # Block tree, with the scales of the two graphs being compared
            # (these used to be hard-coded to graphs 0 and 1)
            Trees_nu_v3.append(extract_multiscale_info(T, params2[i], params2[j]))
            print()

In [None]:
# Grid of subplots: one panel per nu, at most 5 per row
cols = np.min([nNus, 5])
ps = [int(np.ceil(nNus / cols)), cols]
fig, axes = plt.subplots(ps[0], ps[1], figsize=(3.5 * ps[1], 3.5 * ps[0]))
plt.subplots_adjust(hspace=0.35)

# Row-major flat view of the panels, so that nu number idv goes to panel idv
# whether the grid has one row or several. The single-row case used to index
# with the unrelated variable idt, and axes[*sp_idx] is a syntax error
# before Python 3.11.
axes_flat = np.ravel(axes)

# Plot all couplings
for idv in range(nNus):
    ax = axes_flat[idv]

    # Plot all couplings with the same color range
    for i in range(nGraphs2):
        for j in range(i + 1, nGraphs2):
            T = Ts_nu_v3[idv, i, j]
            cyclic = cyclic_by_levels(Trees_nu_v3[idv])

            im = ax.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Let the image fill the (square) panel
                aspect="auto",
                # Pixels are enlarged to improve visibility
                interpolation="auto",
                interpolation_stage="data",
            )

            # Only the first two weights are shown; the third completes the sum to 1
            ax.set_title(
                "nu=[%0.2E, %0.2E]\ncyclic=%s" % (*nus_v3[idv, :-1], str(cyclic))
            )

### Format graphs for paper

In [None]:
# Rows of nus_v3 shown in the publication-style figure
nus_plot = [4, 12, 14, 16, 18]
num_plots = len(nus_plot)

with matplotlib.rc_context({"font.size": 22}):
    scale = 6
    fig, axes = plt.subplots(1, num_plots, figsize=(scale * num_plots + 5, scale))
    for idx in range(num_plots):
        # Find index in global nu list
        idv = nus_plot[idx]

        # Load coupling and find at which levels it's cyclic
        T = Ts_nu_v3[idv, 0, 1]
        cyclic = cyclic_by_levels(Trees_nu_v3[idv])

        # Plot
        im = axes[idx].imshow(
            T,
            vmin=0,
            vmax=0.0001,
            # Let the image fill the panel
            aspect="auto",
            # Pixels are enlarged to improve visibility
            interpolation="auto",
            interpolation_stage="data",
        )

        # Label with the weights that were actually used (nus_v3). The title
        # used to read nus, the weight grid of the 2-level experiment.
        axes[idx].set_title(
            "nu=[%0.2E, %0.2E]\ncyclic=%s" % (*nus_v3[idv, :-1], str(cyclic))
        )

    # Save for publication
    # (disabled; note that the file name below is the one already used by the
    # 2-level figure)
    # plt.tight_layout()
    # plt.savefig(
    #     Path(folder_figs, 'Nested_2_cycles_MS_2_small.pdf'),
    #     bbox_inches='tight'
    #     )

# Heat kernels of 3-nested cycles

## Generate graphs

In [None]:
# Sizes of the cycles: [outer_cycle, middle_cycle, inner_cycle]
cycles = [4, 4, 4]
# Clique sizes to compare; one graph per size
gp_sizes = [5, 20]

# One parameter tuple (outer, middle, inner, clique size) per graph
params3 = [(*cycles, sz) for sz in gp_sizes]
print(params3)

nGraphs3 = len(params3)

# Generator function, built one nesting level at a time
# f args: num_groups_outer, num_groups_inner, group_size
f = cycle_of_generators_fun(cycle_of_generators_fun(nx.complete_graph))
# generator args: num_groups_outer, num_groups_middle, num_groups_inner, group_size
generator = cycle_of_generators_fun(f)


# Sources for heat kernel: 0 for the first graph, 1 for all others
sources = (np.arange(nGraphs3) > 0).astype(int)

In [None]:
# Set to True to draw every graph (with vertex labels) next to each other
plot = False
sp = 0

# One nested-cycle graph per parameter tuple
Graphs3 = [generator(*param) for param in params3]

if plot:
    fig, axes = plt.subplots(1, nGraphs3, figsize=(20, 5))
    for idx, G in enumerate(Graphs3):
        nx.draw(G, ax=axes[idx], with_labels=True)

In [None]:
# Report the number of vertices of every 3-nested graph, one per line
for G in Graphs3:
    n_vertices = G.number_of_nodes()
    print(n_vertices)

In [None]:
fig, axes = plt.subplots(1, nGraphs3, figsize=(20, 5))

# Show the adjacency matrix of every graph as an image
for ax, graph in zip(axes, Graphs3):
    adjacency = nx.adjacency_matrix(graph).todense().astype(np.uint8)
    ax.imshow(adjacency)

## Generate heat kernels

In [None]:
# Options for the heat kernels and the distance computations below
normalize = True           # divide every kernel by its norm
scale_0 = False            # divide the kernel at t = 0 by 100
compute_diagonal = False   # also compare every graph with itself

# Time grid: nSteps_all values starting at step0_all, spaced by step_all
step0_all = 10
step_all = 10
nSteps_all = 50
Times_all = step0_all + step_all * np.arange(nSteps_all)
print(Times_all)

In [None]:
# Hs_all[idg][idt] is the heat kernel of graph idg at time Times_all[idt]
Hs_all = []

for idg, G in enumerate(Graphs3):
    # We have one matrix per time step
    H_G = []
    for idt, t in enumerate(Times_all):
        time_start = time()

        H = heat_kernel(G, t)
        if normalize:
            # Alternative normalization: H /= np.max(H)
            H /= np.linalg.norm(H)

        if scale_0 and t == 0:
            H /= 100

        H_G.append(H)

        time_end = time()
        progress = (idg + 1, idt + 1, nGraphs3, nSteps_all)
        print("(%i, %i)/(%i, %i): %s" % (*progress, display_time(time_end - time_start)))

    # if idg == 1:
    # H_G = H_G[::-1]
    Hs_all.append(H_G)

## GW distance

In [None]:
# Compute dGW between every pair of graphs at each time step.
# dGW_all[idt, i, j]: distance between graphs i and j at time Times_all[idt]
# Ts_all[idt, i, j]:  coupling returned for that pair (only i <= j is filled)
# Trees_all:          one block tree per computed pair, appended in loop order
dGW_all = np.zeros((nSteps_all, nGraphs3, nGraphs3))
Ts_all = np.zeros((nSteps_all, nGraphs3, nGraphs3), dtype=object)
# logs_ss = np.zeros((nGraphs3, nGraphs3), dtype=object)
Trees_all = []

for idt in range(nSteps_all):
    print("Step: %i/%i" % (idt + 1, nSteps_all))
    for i in range(nGraphs3):
        H1 = Hs_all[i][idt]

        for j in range(i, nGraphs3):
            # Self-comparisons are skipped unless explicitly requested
            if j == i and not compute_diagonal:
                continue

            print((i, j))
            H2 = Hs_all[j][idt]

            time_start = time()
            T, log = gromov_wasserstein(
                H1, H2, log=True, verbose=0, tol_abs=1e-30, tol_rel=1e-30
            )
            time_end = time()
            print("Time: " + display_time(time_end - time_start))
            print()

            # Store the distance in the slice of the current time step.
            # (Indexing without idt would broadcast the value over rows of
            # unrelated time slices and lose the per-step distances.)
            dGW_ij = 0.5 * np.sqrt(log["gw_dist"])
            dGW_all[idt, i, j] = dGW_ij
            dGW_all[idt, j, i] = dGW_ij

            Ts_all[idt, i, j] = T
            # logs_ss[i, j] = log
            # Ts_ss[idt, j, i] = T.T
            # logs_ss[j, i] = log

            # Block tree
            # NOTE(review): the parameters of graphs 0 and 1 are hard-coded, and
            # the plotting cells read Trees_all[idt]; both presumably rely on
            # (0, 1) being the only computed pair -- confirm before adding
            # graphs or enabling compute_diagonal.
            Trees_all.append(extract_multiscale_info(T, params3[0], params3[1]))

In [None]:
# Grid layout: at most 5 columns and as many rows as needed for all time steps
cols = np.min([nSteps_all, 5])
ps = [int(np.ceil(nSteps_all / cols)), cols]
fig, axes = plt.subplots(ps[0], ps[1], figsize=(3.5 * ps[1], 3.5 * ps[0]))

# Plot all couplings, one panel per time step
for idt in range(nSteps_all):
    # fig, axes = plt.subplots(nGraphs, nGraphs, figsize=(12, 12))

    # Panel of this time step: (row, col) in a 2D grid, a single index
    # when the grid has only one row
    if ps[0] > 1:
        [r, c] = np.unravel_index(idt, ps)
        sp_idx = [r, c]
    else:
        sp_idx = [idt]

    # Plot all couplings with the same color scale
    for i in range(nGraphs3):
        for j in range(i + 1, nGraphs3):
            T = Ts_all[idt, i, j]
            cyclic = cyclic_by_levels(Trees_all[idt])
            panel = axes[tuple(sp_idx)]

            # im = axes[i, j].imshow(T, vmin=0, vmax=0.0001, aspect="auto")
            im = panel.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Stretch the image to fill the panel
                aspect="auto",
                # Interpolate on the data so that small entries stay visible
                interpolation="auto",
                interpolation_stage="data",
            )
            panel.set_title("t=%i, cyclic=%s" % (Times_all[idt], str(cyclic)))

    # # Delete unused plots
    # for i in range(nGraphs):
    #     for j in range(i, nGraphs):
    #         fig.delaxes(axes[j][i])

    # plt.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5)

### Format graphs for paper

In [None]:
# Time steps shown in the publication figure
Times_plot = [10, 30, 100, 200, 410]
num_plots = len(Times_plot)

with matplotlib.rc_context({"font.size": 25}):
    scale = 5
    fig_size = (scale * num_plots, scale + 1)
    fig, axes = plt.subplots(1, num_plots, figsize=fig_size, sharey=True)

    for idx in range(num_plots):
        # Position of the requested time in the global time grid
        idt = np.flatnonzero(Times_all == Times_plot[idx])[0]

        # Coupling between graphs 0 and 1, and the levels at which it is cyclic
        T = Ts_all[idt, 0, 1]
        cyclic = cyclic_by_levels(Trees_all[idt])

        ax = axes[idx]
        im = ax.imshow(
            T,
            vmin=0,
            vmax=0.0001,
            # Stretch the image to fill the panel
            aspect="auto",
            # Interpolate on the data so that small entries stay visible
            interpolation="auto",
            interpolation_stage="data",
        )
        ax.set_title("t=%i\ncyclic=%s" % (Times_all[idt], str(cyclic)))

    # Save for publication
    plt.tight_layout()
    plt.savefig(Path(folder_figs, "Nested_3_cycles_GW.pdf"), bbox_inches="tight")

## Multiscale distance, 2 levels

In [None]:
# Sample probability measures spaced exponentially
nus_v2 = sampling_simplex_log(10, 2, n_samples=13)
nNus_v2 = nus_v2.shape[0]

# Select times (one heat kernel per entry of nu)
Times_nu_v2 = [30, 200]
# Indices of the selected times within Times_all, in the order of Times_all
I_nu_v2 = np.where(np.isin(Times_all, Times_nu_v2))[0]
# np.isin silently ignores requested times that are missing from Times_all,
# which would leave fewer cost matrices than entries in nu; fail early instead.
assert len(I_nu_v2) == len(Times_nu_v2), (
    "Some of Times_nu_v2=%s are not in Times_all" % str(Times_nu_v2)
)

# With:
# - cycles = [4, 4, 4]
# - gp_sizes = [5, 20]
# Tried: (didn't work)
# Times_nu_v2 = [30, 300]

# With:
# - cycles = [5, 5, 5]
# - gp_sizes = [5, 10]
# Tried:
# Times_nu_v2 = [30, 150]
# Times_nu_v2 = [30, 200]
# Times_nu_v2 = [30, 300]
# Times_nu_v2 = [30, 450]

# Times_nu_v2 = [50, 390]
# Times_nu_v2 = [50, 400]
# Times_nu_v2 = [50, 450]

# print(nus_v2)
print("nNus:", nNus_v2)

In [None]:
# dMS -- vary nu
# dMS_nu_v2[idv, i, j]: multiscale distance between graphs i and j for nus_v2[idv]
# Ts_nu_v2[idv, i, j]:  corresponding coupling (transposed copy stored at [j, i])
dMS_nu_v2 = np.zeros((nNus_v2, nGraphs3, nGraphs3))
Ts_nu_v2 = np.zeros((nNus_v2, nGraphs3, nGraphs3), dtype=object)
Trees_nu_v2 = []

for idv, nu in enumerate(nus_v2):
    print("nu %i/%i" % (idv + 1, nNus_v2))
    with np.printoptions(precision=3, suppress=False):
        print(nu)

    for i in range(nGraphs3):
        # Self-comparisons are included only when compute_diagonal is set
        j_start = i if compute_diagonal else i + 1
        for j in range(j_start, nGraphs3):
            print((i, j))

            # Cost matrices: heat kernels of both graphs at the selected times
            H1 = [Hs_all[i][idt] for idt in I_nu_v2]
            H2 = [Hs_all[j][idt] for idt in I_nu_v2]

            # dMS
            time_start = time()
            T, log = gromov_wasserstein_ms(
                H1, H2, nu=nu, log=True, verbose=0, tol_abs=1e-30, tol_rel=1e-30
            )
            time_end = time()
            print("dMS_nu:" + display_time(time_end - time_start))

            # Symmetric storage of the distance and of the coupling
            dMS_nu_v2[idv, i, j] = dMS_nu_v2[idv, j, i] = 0.5 * np.sqrt(log["gw_dist"])
            Ts_nu_v2[idv, i, j] = T
            Ts_nu_v2[idv, j, i] = T.T

            # Compute block tree
            Trees_nu_v2.append(extract_multiscale_info(T, params3[0], params3[1]))
            print()

In [None]:
# Grid layout: at most 5 columns and as many rows as needed for all nus
cols = np.min([nNus_v2, 5])
ps = [int(np.ceil(nNus_v2 / cols)), cols]
fig, axes = plt.subplots(ps[0], ps[1], figsize=(3.5 * ps[1], 3.5 * ps[0]))
plt.subplots_adjust(hspace=0.35)

# Plot all couplings, one panel per nu
for idv in range(nNus_v2):
    # fig, axes = plt.subplots(nGraphs, nGraphs, figsize=(12, 12))

    # Plot all couplings with the same color scale
    for i in range(nGraphs3):
        for j in range(i + 1, nGraphs3):
            T = Ts_nu_v2[idv, i, j]
            cyclic = cyclic_by_levels(Trees_nu_v2[idv])

            # im = axes[i, j].imshow(T, vmin=0, vmax=0.0001, aspect="auto")
            # Panel of this nu: (row, col) in a 2D grid, a single index when
            # the grid has only one row
            if ps[0] > 1:
                [r, c] = np.unravel_index(idv, ps)
                sp_idx = [r, c]
            else:
                # Index by the current nu (not by a time index left over from
                # a previous cell)
                sp_idx = [idv]
            panel = axes[tuple(sp_idx)]

            im = panel.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Stretch the image to fill the panel
                aspect="auto",
                # Interpolate on the data so that small entries stay visible
                interpolation="auto",
                interpolation_stage="data",
            )

            # axes[*sp_idx].set_title('nus=[%0.2E, %0.2E, %0.2E]' % tuple(nus_v2[idv,:]) )
            panel.set_title(
                "nu=[%0.2E, %0.2E]\ncyclic: %s (id=%i)"
                % (*nus_v2[idv, :], str(cyclic), idv)
            )

## Multiscale distance, 3 levels

In [None]:
# Sample probability measures spaced exponentially
nus_v3 = sampling_simplex_log(6, 3, n_samples=10)
# nus_v3 = sampling_simplex_log(5, 3)

# Subset only a specific region: keep rows with 1e-4 <= nu[0] < 10^(-3.5)
subset = np.logical_and(1e-4 <= nus_v3[:, 0], nus_v3[:, 0] < 10 ** (-3.5))
nus_v3 = nus_v3[subset, :]

nNus_v3 = nus_v3.shape[0]

# Select times (one heat kernel per entry of nu)
Times_nu_v3 = [30, 100, 410]

# Old:
# Times_nu_v3 = [50, 100, 200]
# Times_nu_v3 = [30, 150, 300]

# Bad:
# Times_nu_v3 = [50, 200, 470]

# Indices of the selected times within Times_all, in the order of Times_all
I_nu_v3 = np.where(np.isin(Times_all, Times_nu_v3))[0]
# np.isin silently ignores requested times that are missing from Times_all,
# which would leave fewer cost matrices than entries in nu; fail early instead.
assert len(I_nu_v3) == len(Times_nu_v3), (
    "Some of Times_nu_v3=%s are not in Times_all" % str(Times_nu_v3)
)

print(nus_v3)
print("nNus:", nNus_v3)

In [None]:
# Dot plot of log10 of the selected nus, one series per column of nus_v3;
# the legend labels 0-2 are the column indices
plt.plot(np.log10(nus_v3), ".")
plt.legend(range(3))

In [None]:
# dMS -- vary nu
# dMS_nu_v3[idv, i, j]: multiscale distance between graphs i and j for nus_v3[idv]
# Ts_nu_v3[idv, i, j]:  corresponding coupling (transposed copy stored at [j, i])
dMS_nu_v3 = np.zeros((nNus_v3, nGraphs3, nGraphs3))
Ts_nu_v3 = np.zeros((nNus_v3, nGraphs3, nGraphs3), dtype=object)
Trees_nu_v3 = []

for idv, nu in enumerate(nus_v3):
    print("nu %i/%i" % (idv + 1, nNus_v3))
    with np.printoptions(precision=3, suppress=False):
        print(nu)

    for i in range(nGraphs3):
        # Self-comparisons are included only when compute_diagonal is set
        j_start = i if compute_diagonal else i + 1
        for j in range(j_start, nGraphs3):
            print((i, j))

            # Cost matrices: heat kernels of both graphs at the selected times
            H1 = [Hs_all[i][idt] for idt in I_nu_v3]
            H2 = [Hs_all[j][idt] for idt in I_nu_v3]

            # dMS
            time_start = time()
            T, log = gromov_wasserstein_ms(
                H1,
                H2,
                nu=nu,
                log=True,
                verbose=0,
                max_iter=1e3,
                numItermaxEmd=1e5,
                tol_abs=1e-30,
                tol_rel=1e-30,
            )
            time_end = time()
            print("dMS_nu:" + display_time(time_end - time_start))

            # Remove noise from T
            # T = T - T*(T<1e-15)

            # Symmetric storage of the distance and of the coupling
            dMS_nu_v3[idv, i, j] = dMS_nu_v3[idv, j, i] = 0.5 * np.sqrt(log["gw_dist"])
            Ts_nu_v3[idv, i, j] = T
            Ts_nu_v3[idv, j, i] = T.T

            # Compute block tree
            Trees_nu_v3.append(extract_multiscale_info(T, params3[0], params3[1]))
            print()

In [None]:
# Grid layout: at most 5 columns and as many rows as needed for all nus
cols = np.min([nNus_v3, 5])
ps = [int(np.ceil(nNus_v3 / cols)), cols]
fig, axes = plt.subplots(ps[0], ps[1], figsize=(3.5 * ps[1], 3.5 * ps[0]))
plt.subplots_adjust(hspace=0.35)

# Plot all couplings, one panel per nu
for idv in range(nNus_v3):
    # fig, axes = plt.subplots(nGraphs, nGraphs, figsize=(12, 12))

    # Plot all couplings with the same color scale
    for i in range(nGraphs3):
        for j in range(i + 1, nGraphs3):
            T = Ts_nu_v3[idv, i, j]
            T = T - T * (T < 1e-15)  # Remove noise: zero out entries below 1e-15
            cyclic = cyclic_by_levels(Trees_nu_v3[idv])

            # Panel of this nu: (row, col) in a 2D grid, a single index when
            # the grid has only one row
            if ps[0] > 1:
                [r, c] = np.unravel_index(idv, ps)
                sp_idx = [r, c]
            else:
                # Index by the current nu (not by a time index left over from
                # a previous cell)
                sp_idx = [idv]
            panel = axes[tuple(sp_idx)]

            # Plot
            im = panel.imshow(
                T,
                vmin=0,
                vmax=0.0001,
                # Stretch the image to fill the panel
                aspect="auto",
                # Interpolate on the data so that small entries stay visible
                interpolation="auto",
                interpolation_stage="data",
            )

            # Only the first two entries of nu are shown in the title
            panel.set_title(
                "nu=[%0.2E, %0.2E]\ncyclic: %s (id=%i)"
                % (*nus_v3[idv, :-1], str(cyclic), idv)
            )

            print(
                "{:3d}/{} Cyclic: {} (num={})".format(
                    idv, nNus_v3, str(cyclic), np.sum(cyclic)
                )
            )
            if all(cyclic):
                print(" -- All cyclic! --")

### Format graphs for paper

In [None]:
# Compile examples of nus to put on the paper

# First version of the experiment
nus_pre = sampling_simplex_log(5, 3)

nus_paper = np.array(
    [
        nus_pre[58, :],  # 58 -- good on smallest scale
        nus_pre[8, :],  # 8 -- good on intermediate scale
        nus_v3[42, :],  # Winning coupling
    ]
)
nNus_paper = nus_paper.shape[0]
print(nus_paper)

In [None]:
# dMS -- vary nu
# Ts_paper[idv]: coupling between graphs 0 and 1 for nus_paper[idv]
Ts_paper = np.zeros(nNus_paper, dtype=object)
Trees_paper = []

# Cost matrices are the same for every nu: heat kernels of graphs 0 and 1
# at the times selected for the 3-level experiment
H1 = [Hs_all[0][idt] for idt in I_nu_v3]
H2 = [Hs_all[1][idt] for idt in I_nu_v3]

for idv, nu in enumerate(nus_paper):
    print("nu %i/%i" % (idv + 1, nNus_paper))
    with np.printoptions(precision=3, suppress=False):
        print(nu)

    # dMS
    time_start = time()
    T, log = gromov_wasserstein_ms(
        H1,
        H2,
        nu=nu,
        log=True,
        verbose=0,
        max_iter=1e3,
        numItermaxEmd=1e5,
        tol_abs=1e-30,
        tol_rel=1e-30,
    )
    time_end = time()
    print("dMS_nu:" + display_time(time_end - time_start))

    # Remove noise from T
    # T = T - T*(T<1e-15)

    Ts_paper[idv] = T

    # Compute block tree
    Trees_paper.append(extract_multiscale_info(T, params3[0], params3[1]))
    print()

In [None]:
with matplotlib.rc_context({"font.size": 22}):
    scale = 6
    fig_size = (scale * nNus_paper, scale + 0.5)
    fig, axes = plt.subplots(1, nNus_paper, figsize=fig_size, sharey=False)

    for idx in range(nNus_paper):
        # Coupling for this nu and the levels at which it is cyclic
        T = Ts_paper[idx]
        cyclic = cyclic_by_levels(Trees_paper[idx])

        ax = axes[idx]
        im = ax.imshow(
            T,
            vmin=0,
            vmax=0.0001,
            # Stretch the image to fill the panel
            aspect="auto",
            # Interpolate on the data so that small entries stay visible
            interpolation="auto",
            interpolation_stage="data",
        )

        # Only the first two entries of nu are shown in the title
        title = "nu[0:1]=[%0.2E, %0.2E]\ncyclic: %s" % (
            *nus_paper[idx, :-1],
            str(cyclic),
        )
        ax.set_title(title)

    # Save for publication
    plt.tight_layout()
    # plt.savefig(Path(folder_figs, "Nested_3_cycles_MS_3.pdf"), bbox_inches="tight")