In [None]:
! pip install pot



In [None]:
import numpy as np
import ot
from scipy.optimize.linesearch import scalar_search_armijo
from ot.lp import emd
import numpy as np
import scipy as sp
from scipy.spatial.distance import cdist
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import floyd_warshall
from collections import deque
import copy

  from scipy.optimize.linesearch import scalar_search_armijo


In [None]:
################################################################################
# FINAL GFI CODE
################################################################################


class GraphIntegrator():
  """Base class for integrators of fields defined on the nodes of a graph.

  Attributes:
    f_func: function R -> R applied to shortest-path distances. Either a
      callable (when ``is_lambda`` is True) or a pair of coefficient lists
      ``([a_0, ..., a_{t-1}], [b_0, ..., b_{r-1}])`` encoding the rational
      function

          f(x) = (a_0 + a_1 * x + ... + a_{t-1} * x^{t-1}) /
                 (b_0 + b_1 * x + ... + b_{r-1} * x^{r-1})

    is_lambda: True if ``f_func`` is a callable, False if it is the pair of
      coefficient lists described above.
    graph_adj_lists: adjacency lists of a weighted undirected graph;
      ``graph_adj_lists[i][j]`` is a pair ``(k, w)`` where ``k`` is the id of
      the j-th neighbour of node ``i`` (counting from 0) and ``w`` is the
      weight of the edge (i, k). Nodes are identified by 0, 1, ..., N - 1.
    N: number of nodes of the graph.
  """

  def __init__(self, f_func, is_lambda, graph_adj_lists):
    self.f_func, self.is_lambda = f_func, is_lambda
    self.graph_adj_lists = graph_adj_lists
    self.N = len(self.graph_adj_lists)

  def integrate(self, X_tensor):
    """Integrate ``X_tensor`` over the graph.

    ``X_tensor`` has shape N x b_1 x ... x b_s (any number of batch dims).
    Subclasses return ``einsum("mn,n...->m...", M, X_tensor)``, where M is
    the N x N matrix with ``M[i][j] ~ f_func(dist(i, j))`` and dist is the
    shortest-path distance. The base implementation returns None.
    """
    return None

# Auxiliary functions for the brute-force integrator.
def compute_shortest_path_distances(graph_adj_lists):
  """Return the N x N matrix of all-pairs shortest-path distances.

  Args:
    graph_adj_lists: adjacency lists; ``graph_adj_lists[i]`` holds pairs
      ``(j, w)`` of neighbour id and edge weight. Every edge is treated as
      undirected (both (i, j) and (j, i) are set).

  Returns:
    Dense ``(N, N)`` float array produced by ``floyd_warshall``; pairs that
    are not connected get ``inf``.
  """
  from scipy.sparse.csgraph import csgraph_from_dense

  N = len(graph_adj_lists)
  # "No edge" is encoded as inf, not 0: converting a zero-initialised dense
  # matrix to sparse would silently drop genuine zero-weight edges.
  edges = np.full((N, N), np.inf)
  for i in range(N):
    for j, w in graph_adj_lists[i]:
      edges[i, j] = w
      edges[j, i] = w
  # Explicit zeros survive this conversion and are treated as edges.
  sparse_adjacency = csgraph_from_dense(edges, null_value=np.inf)
  return floyd_warshall(csgraph=sparse_adjacency, directed=False)

def poly(x, coeff_list):
  """Evaluate ``sum_i coeff_list[i] * x**i`` (coefficients by increasing degree).

  Powers of ``x`` are built by repeated multiplication. An empty coefficient
  list evaluates to 0.
  """
  total = 0
  power = 1
  for coeff in coeff_list:
    total += power * coeff
    power *= x
  return total

class BruteForceGraphIntegrator(GraphIntegrator):
  """Reference integrator that materialises the full N x N kernel matrix.

  ``M[i][j] = f(dist(i, j))``, where dist is the all-pairs shortest-path
  distance. ``integrate`` then multiplies ``M`` with the input tensor.
  """

  def __init__(self, f_func, is_lambda, graph_adj_lists):
    super().__init__(f_func, is_lambda, graph_adj_lists)
    kernel = compute_shortest_path_distances(self.graph_adj_lists)
    self.M = kernel
    if self.is_lambda:
      evaluate = self.f_func
    else:
      def evaluate(dist):
        # Rational function: numerator coefficients over denominator ones.
        return poly(dist, self.f_func[0]) / poly(dist, self.f_func[1])
    # Overwrite each distance in place with the kernel value.
    for row in range(self.N):
      for col in range(self.N):
        kernel[row][col] = evaluate(kernel[row][col])

  def integrate(self, X_tensor):
    """Return ``einsum("mn,n...->m...", M, X_tensor)``."""
    return np.einsum("mn,n...->m...", self.M, X_tensor)

  def get_m_matrix(self):
    """Return the precomputed kernel matrix ``M``."""
    return self.M

# Low-level auxiliary functions for main auxiliary functions:
#
# 1. integrate_on_tree,
# 2. preprocess_tree,
# 3. partition_tree,
# 4. compute_struct_for_merge.
#
class CompTree():
  """Node of the recursive decomposition built by ``preprocess_tree``.

  A leaf carries a brute-force integrator in ``bfgi`` and None everywhere
  else. An internal node carries its two children plus the data needed to
  add the cross terms between the two halves.
  """

  def __init__(self, left_child, right_child, left_id_sets, right_id_sets,
               left_distances, right_distances, left_ids, right_ids, bfgi):
    # Sub-decompositions of the two halves (None at a leaf).
    self.left_child, self.right_child = left_child, right_child
    # Vertex groups, one per distance level, and the matching distances.
    self.left_id_sets, self.right_id_sets = left_id_sets, right_id_sets
    self.left_distances, self.right_distances = left_distances, right_distances
    # Ids of each half's vertices in the numbering of the tree being split.
    self.left_ids, self.right_ids = left_ids, right_ids
    # Brute-force integrator for leaves, None for internal nodes.
    self.bfgi = bfgi

def find_vertices(tree, root):
  """Return every vertex reachable from ``root`` in ``tree``, excluding root.

  The work list is popped from its end (depth-first order), and vertices are
  listed in the order they are discovered. The walk only avoids stepping
  back to a vertex's parent (there is no visited set), so ``tree`` must be
  acyclic.
  """
  parent = np.zeros(len(tree))
  parent[root] = -1
  reached = []
  stack = [root]
  while stack:
    current = stack.pop()
    for nxt, _weight in tree[current]:
      if nxt == parent[current]:
        continue
      reached.append(nxt)
      stack.append(nxt)
      parent[nxt] = current
  return reached

def bfs(graph, start):
  """Return the weighted distance from ``start`` to every vertex of ``graph``.

  Despite the name, the frontier is popped from its end (depth-first order).
  On a tree the path to each vertex is unique, so the result is the
  shortest-path distance. Vertices not reachable from ``start`` keep 0.
  """
  visited = np.zeros(len(graph))
  distances = np.zeros(len(graph))
  visited[start] = 1
  frontier = deque([start])
  while frontier:
    node = frontier.pop()
    for neighbour, weight in graph[node]:
      if not visited[neighbour]:
        visited[neighbour] = 1
        distances[neighbour] = distances[node] + weight
        frontier.append(neighbour)
  return distances

def dfs_subtree_sizes(tree, root):
  """Return the subtree size of every vertex of ``tree`` rooted at ``root``.

  Iterative post-order DFS: a vertex is closed (popped and folded into its
  parent's count) once all of its children are closed.

  Args:
    tree: adjacency lists; entry ``[0]`` of each item is a neighbour id.
    root: id of the root vertex.

  Returns:
    Float array ``sizes`` with ``sizes[v]`` = number of vertices in the
    subtree of ``v`` (``sizes[root] == len(tree)`` for a connected tree).
  """
  n = len(tree)
  sizes = np.zeros(n)
  discovered = np.zeros(n, dtype=bool)
  parent = np.zeros(n, dtype=int)
  parent[root] = root
  stack = [root]
  while stack:
    vertex = stack[-1]
    if not discovered[vertex]:
      discovered[vertex] = True
      sizes[vertex] += 1
      pushed_child = False
      for entry in tree[vertex]:
        neighbor = entry[0]
        if neighbor != parent[vertex]:
          pushed_child = True
          stack.append(neighbor)
          parent[neighbor] = vertex
      if pushed_child:
        continue  # close the children before closing this vertex
    # Leaf on first visit, or all children already closed.
    stack.pop()
    # The root has no parent to fold into (parent[root] == root); folding
    # there would double-count a single-vertex tree.
    if vertex != root:
      sizes[parent[vertex]] += sizes[vertex]
  return sizes

class Level():
  """Bucket of vertices that share one distance from a pivot.

  ``nodes`` starts empty. ``tf_value`` (a zero array of shape ``tf_shape``)
  is only created when ``tf_shape`` is not None.
  """

  def __init__(self, tf_shape):
    self.nodes = []
    if tf_shape is None:
      return
    self.tf_value = np.zeros(tf_shape)

def compute_cross_contribs(left_distances, left_tf_vals, right_distances,
                           right_tf_vals, f_func, is_lambda):
  """Compute the interactions between the two halves of a split tree.

  Left level ``i`` lies at ``left_distances[i]`` from the pivot and right
  level ``j`` at ``right_distances[j]``; their tree distance is the sum.
  With ``F[i, j] = f(left_distances[i] + right_distances[j])``:

    cross_vals_for_left[i]  = sum_j F[i, j] * right_tf_vals[j]
    cross_vals_for_right[j] = sum_i F[i, j] * left_tf_vals[i]

  Args:
    left_distances, right_distances: distance of each level from the pivot.
    left_tf_vals, right_tf_vals: per-level summed field values, one entry per
      distance (trailing batch dims are arbitrary).
    f_func: callable (``is_lambda`` True) or ``(num_coeffs, den_coeffs)``
      encoding ``poly(num)/poly(den)`` with coefficients by increasing degree.
    is_lambda: selects the interpretation of ``f_func``.

  Returns:
    ``(cross_vals_for_left, cross_vals_for_right)``.

  A polynomial with a constant non-zero denominator is expanded binomially,
  ``(l + r)^k = sum_j C(k, j) l^j r^(k-j)``, so the cost is linear in the
  number of levels. Any other ``f`` falls back to the dense ``F`` matrix.
  """
  import math

  left_vals = np.array(left_tf_vals)
  right_vals = np.array(right_tf_vals)

  if is_lambda:
    f_scalar = f_func
  else:
    num_coeffs, den_coeffs = f_func
    if len(den_coeffs) == 1 and den_coeffs[0] != 0:
      # Fold the constant denominator into the numerator coefficients.
      coeffs = [c / den_coeffs[0] for c in num_coeffs]
      left_d = np.asarray(left_distances, dtype=float)
      right_d = np.asarray(right_distances, dtype=float)
      # Powers are loop-invariant: compute each one once.
      left_pows = [np.power(left_d, p) for p in range(len(coeffs))]
      right_pows = [np.power(right_d, p) for p in range(len(coeffs))]
      cross_vals_for_left = np.zeros((len(left_d),) + right_vals.shape[1:])
      cross_vals_for_right = np.zeros((len(right_d),) + left_vals.shape[1:])
      for k, a_k in enumerate(coeffs):
        for j in range(k + 1):
          renorm = a_k * math.comb(k, j)
          cross_vals_for_left += renorm * np.einsum(
              "n,m,m...->n...", left_pows[j], right_pows[k - j], right_vals)
          cross_vals_for_right += renorm * np.einsum(
              "m,n,n...->m...", right_pows[k - j], left_pows[j], left_vals)
      return cross_vals_for_left, cross_vals_for_right

    def _eval_poly(x, coeffs):
      # Same power-by-multiplication evaluation as the brute-force integrator.
      total, power = 0, 1
      for c in coeffs:
        total += power * c
        power *= x
      return total

    def f_scalar(x):
      return _eval_poly(x, num_coeffs) / _eval_poly(x, den_coeffs)

  # Dense fallback: evaluate f on every (left level, right level) pair.
  n_left, n_right = len(left_distances), len(right_distances)
  struct = np.zeros((n_left, n_right))
  for i in range(n_left):
    for j in range(n_right):
      struct[i][j] = f_scalar(left_distances[i] + right_distances[j])
  cross_vals_for_left = np.einsum("kl,l...->k...", struct, right_vals)
  cross_vals_for_right = np.einsum("lk,k...->l...", np.transpose(struct),
                                   left_vals)
  return cross_vals_for_left, cross_vals_for_right

# Main auxiliary functions.
def partition_tree(original_tree):
  """Split a tree at a balanced pivot into two subtrees that share the pivot.

  The pivot is the deepest vertex (rooting the tree at 0) whose subtree holds
  more than half of the vertices. The pivot's neighbours are assigned, in
  adjacency order, to the left part until it has collected more than a
  quarter of the vertices. The remaining neighbours go to the right part.

  Args:
    original_tree: adjacency lists of a tree; entry ``(k, w)`` in
      ``original_tree[i]`` is an edge (i, k) of weight w. Not modified.

  Returns:
    ``[left_tree, left_ids, right_tree, right_ids]``. ``left_ids[i]`` is the
    original id of vertex ``i`` of ``left_tree`` (likewise for the right
    part). Index 0 of both parts is the pivot. Adjacency entries of the parts
    are fresh ``[local_id, w]`` lists.
  """
  n = len(original_tree)
  root = 0
  # Copy of the outer list only: just the pivot's slot is replaced below.
  tree = list(original_tree)
  sizes = dfs_subtree_sizes(tree, root)

  # Vertices with subtree size > n/2 form a path from the root. The walk
  # visits ancestors before descendants, so the last one seen is the deepest.
  parent = np.zeros(n, dtype=int)
  parent[root] = -1
  pivot_point = root
  stack = [root]
  while stack:
    m = stack.pop()
    if sizes[m] > 0.5 * n:
      pivot_point = m
    for neighbour, _ in tree[m]:
      if neighbour != parent[m]:
        stack.append(neighbour)
        parent[neighbour] = m

  # Seen from the pivot, its parent side holds everything outside its subtree.
  # The root has no parent (parent[root] == -1); writing through index -1
  # would corrupt the size of the last vertex.
  if pivot_point != root:
    sizes[parent[pivot_point]] = n - sizes[pivot_point]

  # Greedily give neighbours to the left part until it exceeds n/4 vertices.
  acc = 0
  split = 0
  for neighbour, _ in tree[pivot_point]:
    if acc > 0.25 * n:
      break
    acc += sizes[neighbour]
    split += 1
  left_neighbors = list(tree[pivot_point][:split])
  right_neighbors = list(tree[pivot_point][split:])

  # Cut the pivot's right edges; whatever is still reachable is the left part.
  tree[pivot_point] = left_neighbors
  is_left = np.zeros(n, dtype=bool)
  for v in find_vertices(tree, pivot_point):
    is_left[v] = True

  left_ids = [pivot_point]
  right_ids = [pivot_point]
  for v in range(n):
    if v == pivot_point:
      continue
    if is_left[v]:
      left_ids.append(v)
    else:
      right_ids.append(v)

  def _relabel(ids, pivot_adjacency):
    # Rewrite the part's adjacency lists in its local numbering. In a tree a
    # vertex reaches the other part only through the pivot, so every id
    # looked up belongs to this part (ids outside it would map to 0).
    local = np.zeros(n, dtype=int)
    for i, v in enumerate(ids):
      local[v] = i
    adjacency = [pivot_adjacency] + [tree[v] for v in ids[1:]]
    return [[[int(local[k]), w] for k, w in adj] for adj in adjacency]

  left_tree = _relabel(left_ids, left_neighbors)
  right_tree = _relabel(right_ids, right_neighbors)
  return [left_tree, left_ids, right_tree, right_ids]


def integrate_cross_terms(left_id_sets, right_id_sets, left_distances,
                          right_distances, f_func, is_lambda, X_tensor,
                          Y_tensor):
  """Add the interactions between the two halves of a split tree to Y_tensor.

  Field values are first summed per level (``left_id_sets[i]`` and
  ``right_id_sets[j]`` are the vertex ids of each level). The level-to-level
  contributions come from ``compute_cross_contribs``, and each level's
  contribution is then added to every vertex of that level.

  Args:
    left_id_sets, right_id_sets: lists of vertex-id lists, one per level.
    left_distances, right_distances: pivot distance of each level.
    f_func, is_lambda: kernel specification (see ``compute_cross_contribs``).
    X_tensor: input field, indexed by vertex along axis 0.
    Y_tensor: output accumulator with the same leading axis; updated in place.

  Returns:
    ``Y_tensor`` (the same object, after the in-place update).
  """
  left_tf_vals = [np.sum(X_tensor[ids], axis=0) for ids in left_id_sets]
  right_tf_vals = [np.sum(X_tensor[ids], axis=0) for ids in right_id_sets]
  cross_vals_for_left, cross_vals_for_right = compute_cross_contribs(
      left_distances, left_tf_vals, right_distances, right_tf_vals, f_func,
      is_lambda)
  # Broadcasting each level's value over the level's rows replaces the
  # explicit np.tile copies. Indexing with [ids] instead of [ids, :] does not
  # require a batch axis.
  for ids, contrib in zip(left_id_sets, cross_vals_for_left):
    Y_tensor[ids] += contrib
  for ids, contrib in zip(right_id_sets, cross_vals_for_right):
    Y_tensor[ids] += contrib
  return Y_tensor

def compute_struct_for_merge(left_tree, left_ids, right_tree, right_ids):
  """Group the non-pivot vertices of each half by their distance to the pivot.

  Both trees have the pivot at local index 0, as produced by
  ``partition_tree``.

  Returns:
    ``(left_keys, left_nodes, right_keys, right_nodes)``. ``left_keys[i]`` is
    a distance from the pivot and ``left_nodes[i]`` lists the ids (taken from
    ``left_ids``) of the left vertices at exactly that distance; likewise for
    the right half. Levels appear in order of first occurrence.
  """
  def _levels(tree, ids):
    distances = bfs(tree, 0)
    groups = {}
    # Skip index 0 (the pivot) explicitly rather than every zero distance:
    # vertices joined to the pivot by zero-weight edges still need cross terms.
    for i in range(1, len(distances)):
      groups.setdefault(distances[i], []).append(ids[i])
    return list(groups.keys()), list(groups.values())

  left_keys, left_nodes = _levels(left_tree, left_ids)
  right_keys, right_nodes = _levels(right_tree, right_ids)
  return left_keys, left_nodes, right_keys, right_nodes

def preprocess_tree(tree, f_func, is_lambda, threshold=6):
  """Recursively build the ``CompTree`` decomposition of ``tree``.

  Trees with fewer than ``threshold`` vertices become leaves holding a
  ``BruteForceGraphIntegrator``. Larger ones are split by ``partition_tree``,
  both halves are processed recursively, and the per-level data for the
  cross terms is attached to the node.
  """
  if len(tree) >= threshold:
    left_tree, left_ids, right_tree, right_ids = partition_tree(tree)
    left_child = preprocess_tree(left_tree, f_func, is_lambda, threshold)
    right_child = preprocess_tree(right_tree, f_func, is_lambda, threshold)
    left_dists, left_sets, right_dists, right_sets = compute_struct_for_merge(
        left_tree, left_ids, right_tree, right_ids)
    return CompTree(left_child=left_child, right_child=right_child,
                    left_id_sets=left_sets, right_id_sets=right_sets,
                    left_distances=left_dists, right_distances=right_dists,
                    left_ids=left_ids, right_ids=right_ids, bfgi=None)
  leaf_integrator = BruteForceGraphIntegrator(f_func, is_lambda, tree)
  return CompTree(None, None, None, None, None, None, None, None,
                  leaf_integrator)

def integrate_on_tree(comp_tree, X_tensor, f_func, is_lambda):
  """Integrate ``X_tensor`` over a tree using its ``CompTree`` decomposition.

  Leaves delegate to their brute-force integrator. An internal node
  integrates both halves recursively, adds their results, removes the copy
  of the pivot's self term that both halves counted, and adds the cross
  terms between the halves.

  Args:
    comp_tree: node produced by ``preprocess_tree``.
    X_tensor: field indexed by vertex along axis 0 (any trailing dims).
    f_func, is_lambda: kernel specification (a callable, or
      ``(num_coeffs, den_coeffs)`` of a rational function).

  Returns:
    Array shaped like ``X_tensor`` with a floating dtype.
  """
  if comp_tree.bfgi is not None:
    return comp_tree.bfgi.integrate(X_tensor)

  left_ids, right_ids = comp_tree.left_ids, comp_tree.right_ids
  left_result = integrate_on_tree(comp_tree.left_child, X_tensor[left_ids],
                                  f_func, is_lambda)
  right_result = integrate_on_tree(comp_tree.right_child, X_tensor[right_ids],
                                   f_func, is_lambda)
  # A floating accumulator: zeros_like on an integer input would make the
  # in-place additions of float results fail.
  Y_tensor = np.zeros(np.shape(X_tensor), dtype=np.result_type(X_tensor, 1.0))
  Y_tensor[left_ids] += left_result
  Y_tensor[right_ids] += right_result

  # The pivot (index 0 of both halves) belongs to both halves, so its own
  # term f(0) * X[pivot] was added twice above; remove one copy.
  pivot = left_ids[0]
  if is_lambda:
    f_zero = f_func(0.0)
  else:
    num_coeffs, den_coeffs = f_func
    # A polynomial evaluated at 0 is its constant coefficient (0 if empty).
    numerator = num_coeffs[0] if len(num_coeffs) > 0 else 0
    denominator = den_coeffs[0] if len(den_coeffs) > 0 else 0
    f_zero = numerator / denominator
  Y_tensor[pivot] -= f_zero * X_tensor[pivot]

  integrate_cross_terms(comp_tree.left_id_sets, comp_tree.right_id_sets,
                        comp_tree.left_distances, comp_tree.right_distances,
                        f_func, is_lambda, X_tensor, Y_tensor)
  return Y_tensor

# Abstract class for the tree maker.

class TreeConstructor():
  """Interface for objects that turn a graph into a tree.

  ``construct_tree`` receives adjacency lists. Subclasses return the tree as
  adjacency lists; this base implementation returns None.
  """

  def __init__(self):
    pass

  def construct_tree(self, graph_adj_lists):
    """Return a tree built from ``graph_adj_lists`` (None in the base class).

    The original signature lacked ``self``, so calling it on an instance
    raised TypeError.
    """
    return None

# Minimum spanning tree functions.

class DisjointSet:
  """Union-find over the integers 0..n-1 (union by size, path compression).

  ``parent`` and ``size`` are per-instance dicts, so separate DisjointSet
  objects never share state. As class attributes they would be one dict
  shared by every instance.
  """

  def __init__(self):
    self.parent = {}
    self.size = {}

  def makeSet(self, n):
    """Put each of 0..n-1 into its own singleton set."""
    for i in range(n):
      self.parent[i] = i
      self.size[i] = 1

  def find(self, k):
    """Return the representative of the set containing ``k``."""
    root = k
    while self.parent[root] != root:
      root = self.parent[root]
    # Path compression: point every vertex on the walk directly at the root.
    while k != root:
      nxt = self.parent[k]
      self.parent[k] = root
      k = nxt
    return root

  def union(self, a, b):
    """Merge the sets containing ``a`` and ``b`` (no-op if already joined)."""
    x = self.find(a)
    y = self.find(b)
    if x == y:
      # Without this guard the size of the set would be doubled.
      return
    if self.size[x] > self.size[y]:
      x, y = y, x
    # Attach the smaller tree (x) below the larger one (y).
    self.parent[x] = y
    self.size[y] += self.size[x]

def kruskal_algo(graph_adj_lists):
  """Return a minimum spanning tree (or forest) of an undirected graph.

  Args:
    graph_adj_lists: adjacency lists; ``graph_adj_lists[i]`` holds pairs
      ``(k, w)``. Each undirected edge must be listed from both endpoints;
      it is read once, from the endpoint with the larger id.

  Returns:
    Adjacency lists with the same node ids: ``tree[i]`` holds ``[k, w]``
    lists. Edges are considered by increasing weight (ties in input order).
    For a disconnected graph a spanning forest is returned instead of
    running past the end of the edge list.
  """
  n = len(graph_adj_lists)
  tree = [[] for _ in range(n)]
  edges = [(i, k, w) for i, adj in enumerate(graph_adj_lists)
           for k, w in adj if k < i]
  edges.sort(key=lambda e: e[2])

  # Local union-find (union by size, path halving).
  parent = list(range(n))
  size = [1] * n

  def find(v):
    while parent[v] != v:
      parent[v] = parent[parent[v]]
      v = parent[v]
    return v

  added = 0
  for src, dest, weight in edges:
    if added == n - 1:
      break  # spanning tree complete
    x, y = find(src), find(dest)
    if x == y:
      continue  # edge would close a cycle
    tree[src].append([dest, weight])
    tree[dest].append([src, weight])
    if size[x] > size[y]:
      x, y = y, x
    parent[x] = y
    size[y] += size[x]
    added += 1
  return tree

class MinimumSpanningTreeConstructor(TreeConstructor):
  """Tree constructor that returns a minimum spanning tree (Kruskal)."""

  def __init__(self):
    super().__init__()

  def construct_tree(self, graph_adj_lists):
    """Return ``kruskal_algo(graph_adj_lists)``."""
    spanning_tree = kruskal_algo(graph_adj_lists)
    return spanning_tree

class TreeBasedGraphIntegrator(GraphIntegrator):
  """Graph integrator that works on a tree derived from the graph.

  ``tree_constructor.construct_tree`` turns the graph into a tree, which is
  decomposed once by ``preprocess_tree``. ``integrate`` then uses distances
  in that tree, not in the original graph.
  """

  def __init__(self, f_func, is_lambda, graph_adj_lists, tree_constructor,
               threshold=6):
    super().__init__(f_func, is_lambda, graph_adj_lists)
    self.tree = tree_constructor.construct_tree(graph_adj_lists)
    # Subtrees with fewer than `threshold` vertices use brute force.
    self.comp_tree = preprocess_tree(self.tree, f_func, is_lambda, threshold)

  def integrate(self, X_tensor, threshold=6):
    """Integrate ``X_tensor`` with the precomputed decomposition.

    ``threshold`` is accepted but unused; the leaf size was fixed at
    construction time.
    """
    return integrate_on_tree(comp_tree=self.comp_tree, X_tensor=X_tensor,
                             f_func=self.f_func, is_lambda=self.is_lambda)

In [None]:
class StopError(Exception):
    """Exception type for stopping a solver (not raised in this section)."""


class NonConvergenceError(Exception):
    """Raised by ``cg`` when the line search yields no step or a NaN step."""


def solve_1d_linesearch_quad_funct(a, b, c):
    """Minimise f(x) = a * x**2 + b * x + c over the interval [0, 1].

    For a > 0 the parabola is convex and the vertex ``-b / (2a)`` clipped to
    [0, 1] is returned. Otherwise (concave or linear) the minimum lies at an
    endpoint: 1 if f(0) > f(1), else 0.
    """
    if not a > 0:
        # Compare the values at the two endpoints: f(0) = c, f(1) = a + c + b.
        value_at_0 = c
        value_at_1 = a + value_at_0 + b
        return 1 if value_at_0 > value_at_1 else 0
    vertex = -b / (2 * a)
    return min(1, max(0, vertex))


def line_search_armijo(
    f,
    xk,
    pk,
    gfk,
    old_fval,
    args=(),
    c1=1e-4,
    alpha0=0.99,
    alpha_min=None,
    alpha_max=None,
):
    """
    Armijo line search for matrix-valued iterates.

    Looks for a step ``alpha`` such that ``f(xk + alpha * pk)`` satisfies the
    Armijo sufficient-decrease condition, delegating the search itself to
    ``scalar_search_armijo``.

    Parameters
    ----------
    f : function
        loss function
    xk : np.ndarray
        initial position
    pk : np.ndarray
        descent direction
    gfk : np.ndarray
        gradient of f at xk
    old_fval : float or None
        loss value at xk (evaluated here when None)
    args : tuple, optional
        extra arguments given to f
    c1 : float, optional
        c1 constant of the Armijo rule (>0)
    alpha0 : float, optional
        initial step (>0)
    alpha_min : float, optional
        minimum value for alpha
    alpha_max : float, optional
        maximum value for alpha

    Returns
    -------
    alpha : float
        accepted step (0.0 if the search fails), clipped to
        [alpha_min, alpha_max] when either bound is given
    fc : int
        number of calls to f
    fa : float
        loss value returned by the search (the value at xk on failure). When
        clipping changes alpha, this is still the value at the unclipped step.
    """
    xk = np.atleast_1d(xk)
    calls = [0]

    def phi(step):
        calls[0] += 1
        return f(xk + step * pk, *args)

    phi0 = phi(0.0) if old_fval is None else old_fval
    # Directional derivative <pk, gfk>: summing the elementwise product works
    # for matrices as well as vectors.
    derphi0 = np.sum(pk * gfk)
    alpha, phi1 = scalar_search_armijo(phi, phi0, derphi0, c1=c1, alpha0=alpha0)

    if alpha is None:
        return 0.0, calls[0], phi0
    if not (alpha_min is None and alpha_max is None):
        alpha = np.clip(alpha, alpha_min, alpha_max)
    return float(alpha), calls[0], phi1


def do_linesearch(
    cost,
    G,
    deltaG,
    Mi,
    f_val,
    armijo=True,
    C1=None,
    C2=None,
    reg=None,
    Gc=None,
    constC=None,
    M=None,
    alpha_min=None,
    alpha_max=None,
    method_type=None,
    source_integrator=None,
    target_integrator=None,
):
    """
    Solve the line search in the Frank-Wolfe iterations.

    Parameters
    ----------
    cost : method
        The FGW cost
    G : ndarray, shape (ns,nt)
        The transport map at a given iteration of the FW
    deltaG : ndarray (ns,nt)
        Difference between the optimal map found by linearization in the FW
        algorithm and the value at a given iteration
    Mi : ndarray (ns,nt)
        Cost matrix of the linearized transport problem. Corresponds to the
        gradient of the cost
    f_val : float
        Value of the cost at G
    armijo : bool, optional
        If True the step of the line search is found via an Armijo search.
        Else the closed form is used. If there are convergence issues use False.
    C1 : ndarray (ns,ns), optional
        Structure matrix in the source domain. Only used when armijo=False
        (and no source integrator is used)
    C2 : ndarray (nt,nt), optional
        Structure matrix in the target domain. Only used when armijo=False
        (and no target integrator is used)
    reg : float, optional
        Regularization parameter. Corresponds to the alpha parameter of FGW.
        Only used when armijo=False
    Gc : ndarray (ns,nt)
        Optimal map found by linearization in the FW algorithm. Not used by
        the computation; kept for interface compatibility.
    constC : ndarray (ns,nt)
        Constant for the Gromov cost. See [3]. Only used when armijo=False
    M : ndarray (ns,nt), optional
        Cost matrix between the features. Only used when armijo=False
    alpha_min, alpha_max : float, optional
        Bounds used to clip the step.
    method_type : str or None
        None uses dense products with C1 and C2. Otherwise each integrator
        that is given replaces the corresponding dense product; missing
        integrators fall back to the dense matrices.
    source_integrator : object whose ``integrate`` multiplies by the source
        structure matrix from the left
    target_integrator : same for the target graph

    Returns
    -------
    alpha : float
        The optimal step size of the FW
    fc : int or None
        Number of cost evaluations (Armijo) or None (closed form)
    f_val : float
        The value of the cost for the next iteration

    References
    ----------
    .. [3] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
          and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """
    if armijo:
        alpha, fc, f_val = line_search_armijo(
            cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
        )
        return alpha, fc, f_val

    # Closed form: cost(G + t * deltaG) = a t^2 + b t + c (needs symmetric
    # structure matrices).
    use_source = method_type is not None and source_integrator is not None
    use_target = method_type is not None and target_integrator is not None

    def left_multiply(X):
        # C1 @ X
        if use_source:
            return source_integrator.integrate(X)
        return np.dot(C1, X)

    def right_multiply(X):
        # X @ C2. The integrator multiplies from the left, so (K @ X.T).T is
        # used, which equals X @ K for a symmetric target matrix K.
        if use_target:
            return target_integrator.integrate(X.T).T
        return np.dot(X, C2)

    dot = right_multiply(left_multiply(deltaG))  # C1 deltaG C2
    a = -2 * reg * np.sum(dot * deltaG)  # -2 * reg * <C1 dG C2, dG>
    b1 = right_multiply(left_multiply(G))  # C1 G C2
    b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (
        np.sum(dot * G) + np.sum(b1 * deltaG)
    )
    del b1
    c = cost(G)

    alpha = solve_1d_linesearch_quad_funct(a, b, c)
    if alpha_min is not None or alpha_max is not None:
        alpha = np.clip(alpha, alpha_min, alpha_max)
    fc = None
    f_val = cost(G + alpha * deltaG)
    return alpha, fc, f_val


def cg(
    a,
    b,
    M,
    reg,
    f,
    df,
    G0=None,
    numItermax=500,
    numItermaxEmd=100000,
    stopThr=1e-09,
    stopThr2=1e-9,
    verbose=False,
    log=False,
    armijo=True,
    C1=None,
    C2=None,
    constC=None,
    alpha_min=0.0,
    alpha_max=1.0,
    method_type=None,
    source_integrator=None,
    target_integrator=None,
):
    r"""
    Solve the general regularized OT problem with conditional gradient.

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg*f(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0

    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`f` is the regularization term (and df is its gradient)
    - a and b are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as
    discussed in [1]_.

    Parameters
    ----------
    a : np.ndarray (ns,)
        sample weights in the source domain
    b : np.ndarray (nt,)
        sample weights in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    f : callable
        Regularization function, called as f(G)
    df : callable
        Gradient of f, called as df(G)
    G0 : np.ndarray (ns,nt), optional
        initial guess (default is the independent joint density)
    numItermax : int, optional
        Max number of iterations
    numItermaxEmd : int, optional
        Max number of iterations of each inner EMD solve
    stopThr : float, optional
        Stop threshold on the relative variation of the loss (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation of the loss (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    armijo, C1, C2, constC, alpha_min, alpha_max :
        Forwarded to ``do_linesearch``
    method_type : str or None
        None uses dense matrix products (brute force)
    source_integrator : object whose ``integrate`` performs fast products
        with the source structure matrix
    target_integrator : same for the target graph

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary, returned only if log==True in parameters

    Raises
    ------
    NonConvergenceError
        If the line search returns no step or a NaN step.

    References
    ----------
    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
       Regularized discrete optimal transport. SIAM Journal on Imaging
       Sciences, 7(3), 1853-1882.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport
    """

    loop = 1

    if log:
        log = {"loss": [], "delta_fval": []}

    G = np.outer(a, b) if G0 is None else G0

    def cost(G):
        return np.sum(M * G) + reg * f(G)

    f_val = cost(G)  # f(xt)

    if log:
        log["loss"].append(f_val)

    it = 0

    if verbose:
        print(
            "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss") + "\n" + "-" * 32
        )
        print("{:5d}|{:8e}|{:8e}".format(it, f_val, 0))

    while loop:

        it += 1
        old_fval = f_val

        # Linearize the problem at the current iterate: Mi is the gradient.
        Mi = M + reg * df(G)
        # Constant shift kept from POT. Every feasible plan has the same total
        # mass, so a constant shift does not change the EMD solution.
        Mi += Mi.min()

        # Solve the linear program.
        Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True)

        deltaG = Gc - G  # descent direction

        # argmin_alpha f(xt + alpha * dt)
        alpha, fc, f_val = do_linesearch(
            cost=cost,
            G=G,
            deltaG=deltaG,
            Mi=Mi,
            f_val=f_val,
            armijo=armijo,
            constC=constC,
            C1=C1,
            C2=C2,
            reg=reg,
            Gc=Gc,
            M=M,
            alpha_min=alpha_min,
            alpha_max=alpha_max,
            method_type=method_type,
            source_integrator=source_integrator,
            target_integrator=target_integrator,
        )

        if alpha is None or np.isnan(alpha):
            raise NonConvergenceError("Alpha is not found")
        G = G + alpha * deltaG  # xt+1 = xt + alpha * dt

        # Convergence tests.
        if it >= numItermax:
            loop = 0

        delta_fval = f_val - old_fval
        abs_delta_fval = abs(f_val - old_fval)
        # Avoid dividing by zero when the loss reaches exactly 0. The relative
        # test then cannot fire, just as with the nan/inf the division gave.
        if f_val != 0:
            relative_delta_fval = abs_delta_fval / abs(f_val)
        else:
            relative_delta_fval = np.inf
        if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
            loop = 0

        if log:
            log["loss"].append(f_val)
            log["delta_fval"].append(delta_fval)

        if verbose:
            if it % 20 == 0:
                print(
                    "{:5s}|{:12s}|{:8s}".format("It.", "Loss", "Delta loss")
                    + "\n"
                    + "-" * 32
                )
            print("{:5d}|{:8e}|{:8e}|{:5e}".format(it, f_val, delta_fval, alpha))

    if log:
        log.update(logemd)
        return G, log
    else:
        return G

In [None]:
def fast_multiply_matrix_square(integrator, field):
    """Multiply the Hadamard (elementwise) square of a matrix with a vector.

    With ``integrator.integrate(X) == K @ X``, returns ``(K * K) @ field`` as
    a column vector. It is computed as ``diag(K @ diag(field) @ K.T)``, using
    only calls to ``integrate``.

    Args:
        integrator: object whose ``integrate`` multiplies by K from the left.
        field: array of shape (n, 1).
    """
    assert field.shape[1] == 1
    scaled = integrator.integrate(np.diag(field.squeeze()))  # K @ diag(field)
    # K @ (K diag(field)).T = K diag(field) K.T; its diagonal is (K*K) field.
    both_sides = integrator.integrate(scaled.T)
    return np.diag(both_sides).reshape(-1, 1)


def init_matrix(
    C1,
    C2,
    p,
    q,
    loss_fun="square_loss",
    method_type=None,
    source_integrator=None,
    target_integrator=None,
):
    r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation.

    Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss
    function as the loss function of the Gromov-Wasserstein discrepancy.
    The matrices are computed as described in Proposition 1 in [1].

    Where :
        * C1 : Metric cost matrix in the source space
        * C2 : Metric cost matrix in the target space
        * T : A coupling between those two spaces

    The square-loss function L(a,b)=|a-b|^2 is read as :
        L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with :
            * f1(a)=(a^2)
            * f2(b)=(b^2)
            * h1(a)=a
            * h2(b)=2b

    The KL loss uses f1(a)=a*log(a)-a, f2(b)=b, h1(a)=a, h2(b)=log(b), with
    1e-15 added inside the logarithms.

    Parameters
    ----------
    C1 : ndarray, shape (ns, ns) or None
        Metric cost matrix in the source space. May be None for the square
        loss when ``method_type`` is set and ``source_integrator`` is given.
    C2 : ndarray, shape (nt, nt) or None
        Metric cost matrix in the target space. May be None for the square
        loss when ``method_type`` is set and ``target_integrator`` is given.
    p : ndarray, shape (ns,)
        Distribution in the source space
    q : ndarray, shape (nt,)
        Distribution in the target space
    loss_fun : str
        "square_loss" or "kl_loss"
    method_type : None or str (e.g. "diffusion", "separator")
        None uses the dense matrices only. Otherwise each integrator that is
        given replaces the corresponding dense product; missing integrators
        fall back to the dense matrices.
    source_integrator : object with an ``integrate`` method, fast graph field
        integrator for source points
    target_integrator : same for target points

    Returns
    -------
    constC : ndarray, shape (ns, nt)
           Constant C matrix in Eq. (6)
    hC1 : ndarray, shape (ns, ns) or None
           h1(C1) matrix in Eq. (6)
    hC2 : ndarray, shape (nt, nt) or None
           h2(C2) matrix in Eq. (6)

    Raises
    ------
    ValueError
        If ``loss_fun`` is not supported.

    References
    ----------
    .. [1] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
    "Gromov-Wasserstein averaging of kernel and distance matrices."
    International Conference on Machine Learning (ICML). 2016.
    """

    if loss_fun == "square_loss":

        def f1(a):
            return a**2

        def f2(b):
            return b**2

        def h1(a):
            return a

        def h2(b):
            return 2 * b

    elif loss_fun == "kl_loss":

        def f1(a):
            return a * np.log(a + 1e-15) - a

        def f2(b):
            return b

        def h1(a):
            return a

        def h2(b):
            return np.log(b + 1e-15)

    else:
        raise ValueError("Unsupported combination of loss and methods")

    fast = method_type is not None
    ones_q_row = np.ones(len(q)).reshape(1, -1)
    ones_p_col = np.ones(len(p)).reshape(-1, 1)

    # Source term: column f1(C1) p repeated over the nt columns.
    if fast and loss_fun == "square_loss" and source_integrator is not None:
        constC1 = np.dot(
            fast_multiply_matrix_square(source_integrator, p.reshape(-1, 1)),
            ones_q_row,
        )
    else:
        constC1 = np.dot(np.dot(f1(C1), p.reshape(-1, 1)), ones_q_row)

    # Target term: row q^T f2(C2)^T repeated over the ns rows.
    if fast and loss_fun == "square_loss" and target_integrator is not None:
        constC2 = np.dot(
            ones_p_col,
            fast_multiply_matrix_square(target_integrator, q.reshape(-1, 1)).T,
        )
    elif fast and loss_fun == "kl_loss" and target_integrator is not None:
        # f2 is the identity for the KL loss, so only a plain product is needed.
        constC2_partial = (target_integrator.integrate(q.reshape(1, -1).T)).T
        constC2 = np.dot(ones_p_col, constC2_partial)
    else:
        constC2 = np.dot(ones_p_col, np.dot(q.reshape(1, -1), f2(C2).T))

    constC = constC1 + constC2

    if method_type is None:
        hC1 = h1(C1)
        hC2 = h2(C2)
    elif loss_fun == "square_loss":
        # Each h matrix is returned whenever its cost matrix is available.
        hC1 = None if C1 is None else h1(C1)
        hC2 = None if C2 is None else h2(C2)
    else:
        # KL fast path: only hC2 is returned.
        hC1 = None
        hC2 = None if C2 is None else h2(C2)

    return constC, hC1, hC2


def tensor_product(
    constC,
    hC1,
    hC2,
    T,
    method_type=None,
    loss_fun=None,
    source_integrator=None,
    target_integrator=None,
):
    r"""Return the tensor-matrix product used by the Gromov-Wasserstein solver.

    The tensor is computed as described in Proposition 1 Eq. (6) in [1]:
    ``constC - hC1 @ T @ hC2.T``. When ``method_type`` is not None, the
    matrix products are delegated to graph integrators where available, and
    fall back to the dense matrices ``hC1`` / ``hC2`` otherwise.

    Parameters
    ----------
    constC : ndarray, shape (ns, nt)
        Constant C matrix in Eq. (6).
    hC1 : ndarray, shape (ns, ns) or None
        h1(C1) matrix in Eq. (6). Only needed when the product ``hC1 @ T`` is
        not delegated to ``source_integrator``.
    hC2 : ndarray, shape (nt, nt) or None
        h2(C2) matrix in Eq. (6). Only needed when the right product is not
        delegated to ``target_integrator``.
    T : ndarray, shape (ns, nt)
        Coupling matrix between source and target.
    method_type : str or None, optional
        None selects the dense computation; any other value enables the
        integrator-based path.
    loss_fun : {"square_loss", "kl_loss"}, optional
        Loss used on the integrator-based path.
    source_integrator : object with ``integrate(X)``, optional
        Used in place of the dense product ``hC1 @ X``.
    target_integrator : object with ``integrate(X)``, optional
        Square loss only. ``-2 * target_integrator.integrate(X.T).T`` is used
        in place of ``-X @ hC2.T``. This presumably relies on ``hC2 = 2 * C2``
        for the square loss; confirm against ``init_matrix``.

    Returns
    -------
    tens : ndarray, shape (ns, nt)
        \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result.

    Raises
    ------
    NotImplementedError
        If ``method_type`` is set and ``loss_fun`` is not supported.

    References
    ----------
    .. [1] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
    "Gromov-Wasserstein averaging of kernel and distance matrices."
    International Conference on Machine Learning (ICML). 2016.
    """

    if method_type is None:
        A = -np.dot(np.dot(hC1, T), hC2.T)
    elif loss_fun in ("square_loss", "kl_loss"):
        # Left factor hC1 @ T: implicit through the source integrator when one
        # is available, dense otherwise (previously the square-loss path left
        # `A` unbound when no integrator was given).
        if source_integrator is not None:
            left = source_integrator.integrate(T)
        else:
            left = np.dot(hC1, T)

        if loss_fun == "square_loss" and target_integrator is not None:
            # Right factor via the target integrator. The factor 2 replaces the
            # h2 scaling that hC2 would otherwise carry.
            A = -2 * (target_integrator.integrate(left.T)).T
        else:
            A = -np.dot(left, hC2.T)
    else:
        raise NotImplementedError(
            "Other types of losses are not currently supported."
        )
    tens = constC + A

    return tens


def gwloss(
    constC,
    hC1,
    hC2,
    T,
    method_type=None,
    loss_fun=None,
    source_integrator=None,
    target_integrator=None,
):
    r"""Evaluate the Gromov-Wasserstein objective at the coupling ``T``.

    The objective is the Frobenius inner product of ``T`` with the tensor
    product returned by ``tensor_product`` (Proposition 1, Eq. (6) in [1]).

    Parameters
    ----------
    constC : ndarray, shape (ns, nt)
        Constant C matrix in Eq. (6).
    hC1 : ndarray, shape (ns, ns) or None
        h1(C1) matrix in Eq. (6).
    hC2 : ndarray, shape (nt, nt) or None
        h2(C2) matrix in Eq. (6).
    T : ndarray, shape (ns, nt)
        Current value of the transport matrix.
    method_type : str or None, optional
        None for the dense computation; otherwise the integrator-based path.
    loss_fun : str, optional
        Loss name forwarded to ``tensor_product``.
    source_integrator, target_integrator : objects with ``integrate``, optional
        Fast matrix-multiplication helpers forwarded to ``tensor_product``.

    Returns
    -------
    loss : float
        Gromov-Wasserstein loss ``sum(tensor_product(...) * T)``.

    References
    ----------
    .. [1] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
    "Gromov-Wasserstein averaging of kernel and distance matrices."
    International Conference on Machine Learning (ICML). 2016.
    """
    fast_path_options = dict(
        method_type=method_type,
        loss_fun=loss_fun,
        source_integrator=source_integrator,
        target_integrator=target_integrator,
    )
    product = tensor_product(constC, hC1, hC2, T, **fast_path_options)
    weighted = T * product
    return np.sum(weighted)


def gwggrad(
    constC,
    hC1,
    hC2,
    T,
    method_type=None,
    loss_fun=None,
    source_integrator=None,
    target_integrator=None,
):
    r"""Return the gradient of the Gromov-Wasserstein objective at ``T``.

    Following Proposition 2 in [1], the gradient is twice the tensor product
    computed by ``tensor_product``.

    Parameters
    ----------
    constC : ndarray, shape (ns, nt)
        Constant C matrix in Eq. (6).
    hC1 : ndarray, shape (ns, ns) or None
        h1(C1) matrix in Eq. (6).
    hC2 : ndarray, shape (nt, nt) or None
        h2(C2) matrix in Eq. (6).
    T : ndarray, shape (ns, nt)
        Current value of the transport matrix.
    method_type : str or None, optional
        None for the dense computation; otherwise the integrator-based path.
    loss_fun : str, optional
        Loss name forwarded to ``tensor_product``.
    source_integrator, target_integrator : objects with ``integrate``, optional
        Fast matrix-multiplication helpers forwarded to ``tensor_product``.

    Returns
    -------
    grad : ndarray, shape (ns, nt)
        Gromov-Wasserstein gradient.

    References
    ----------
    .. [1] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
    "Gromov-Wasserstein averaging of kernel and distance matrices."
    International Conference on Machine Learning (ICML). 2016.
    """
    tens = tensor_product(
        constC,
        hC1,
        hC2,
        T,
        method_type=method_type,
        loss_fun=loss_fun,
        source_integrator=source_integrator,
        target_integrator=target_integrator,
    )
    grad = 2 * tens
    return grad


def _gw_lp_uniform_if_missing(weights, cost, adjacency_lists, name):
    """Return ``weights``, or a uniform histogram sized from the graph when it is None."""
    if weights is not None:
        return weights
    if cost is not None:
        n = cost.shape[0]
    elif adjacency_lists is not None:
        n = len(adjacency_lists)
    else:
        raise ValueError(
            f"{name} is None and its size cannot be inferred: pass {name} "
            "explicitly, or a cost matrix / adjacency lists for that side."
        )
    return np.ones(n) / n


def _gw_lp_default_integrator(func, adjacency_lists, threshold):
    """Build the tree-based integrator gw_lp uses when none is supplied."""
    return TreeBasedGraphIntegrator(
        f_func=func,
        is_lambda=True,
        graph_adj_lists=adjacency_lists,
        # The tree construction strategy is hardcoded to minimum spanning trees.
        tree_constructor=MinimumSpanningTreeConstructor(),
        threshold=threshold,
    )


def gw_lp(
    C1=None,
    C2=None,
    p=None,
    q=None,
    loss_fun="square_loss",
    alpha=1,
    armijo=True,
    G0=None,
    log=True,
    method_type=None,
    dim: int = None,
    source_adjacency_lists=None,
    threshold=None,
    target_adjacency_lists=None,
    verbose=False,
    source_integrator=None,
    target_integrator=None,
    max_iter=1000,
    stopThr=1e-9,
    func=None,
):
    r"""Return the Gromov-Wasserstein transport between (C1, p) and (C2, q).

    Solves

    .. math::
        GW = \min_T \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l}

    with the conditional-gradient solver ``cg``.

    With ``method_type=None`` the dense matrices C1 and C2 are used. With any
    other ``method_type``, the products with the structure matrices are
    delegated to graph integrators (objects exposing ``integrate``). These are
    either supplied or built from the adjacency lists.

    Parameters
    ----------
    C1 : ndarray, shape (ns, ns), optional
        Cost matrix in the source space (e.g. shortest-path distances).
        Required for the dense path. It may be None when integrators replace
        it.
    C2 : ndarray, shape (nt, nt), optional
        Cost matrix in the target space. Same remarks as for C1.
    p : ndarray, shape (ns,), optional
        Distribution in the source space. Defaults to uniform, sized from C1
        or ``source_adjacency_lists``.
    q : ndarray, shape (nt,), optional
        Distribution in the target space. Defaults to uniform, sized from C2
        or ``target_adjacency_lists``.
    loss_fun : str
        Loss function. On the integrator path only "square_loss" and
        "kl_loss" are accepted.
    alpha : float
        Forwarded to ``cg`` as ``reg``.
    armijo : bool
        Forwarded to ``cg``. If True, an Armijo line search is used for the
        step; otherwise a closed form is used. Use False if there are
        convergence issues.
    G0 : ndarray, shape (ns, nt), optional
        Initial transport plan. If None, ``p q^T`` is used. Otherwise its
        marginals must equal p and q.
    log : bool
        If True, also return the ``cg`` log dict, augmented with "gw_dist".
    method_type : str or None
        None for the dense computation. Any other value (e.g. "diffusion")
        selects the integrator path.
    dim : int, optional
        Unused; kept for backward compatibility.
    source_adjacency_lists, target_adjacency_lists : list, optional
        Graphs in the GraphIntegrator adjacency-list format. Used to build
        default integrators when neither integrator is supplied.
    threshold : optional
        Forwarded to the default ``TreeBasedGraphIntegrator`` instances.
    verbose : bool
        Forwarded to ``cg``.
    source_integrator, target_integrator : objects with ``integrate``, optional
        Pre-built integrators. If either one is given, both are used as-is.
    max_iter : int
        Forwarded to ``cg`` as ``numItermax``.
    stopThr : float
        Stopping threshold forwarded to ``cg``.
    func : callable, optional
        ``f_func`` for the default integrators (passed with ``is_lambda=True``).

    Returns
    -------
    T : ndarray, shape (ns, nt)
        Coupling that minimizes the objective above.
    log : dict
        Only when ``log`` is True. The ``cg`` log dict plus "gw_dist", the
        loss at the returned coupling.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - an unsupported loss is used on the integrator path;
        - p or q cannot be inferred;
        - default integrators must be built but adjacency lists are missing.
    AssertionError
        If G0 does not satisfy the marginal constraints.

    References
    ----------
    .. [1] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
        "Gromov-Wasserstein averaging of kernel and distance matrices."
        International Conference on Machine Learning (ICML). 2016.
    .. [2] Mémoli, Facundo. Gromov–Wasserstein distances and the
        metric approach to object matching. Foundations of computational
        mathematics 11.4 (2011): 417-487.
    """
    p = _gw_lp_uniform_if_missing(p, C1, source_adjacency_lists, "p")
    q = _gw_lp_uniform_if_missing(q, C2, target_adjacency_lists, "q")

    if method_type is None:
        # Dense path: init_matrix works directly on C1 and C2.
        s_integrator = None
        t_integrator = None
        constC, hC1, hC2 = init_matrix(
            C1,
            C2,
            p,
            q,
            loss_fun,
            method_type=method_type,
            source_integrator=s_integrator,
            target_integrator=t_integrator,
        )
    else:
        if source_integrator is None and target_integrator is None:
            if source_adjacency_lists is None or target_adjacency_lists is None:
                raise ValueError(
                    "method_type requires either graph integrators or both "
                    "source_adjacency_lists and target_adjacency_lists"
                )
            s_integrator = _gw_lp_default_integrator(
                func, source_adjacency_lists, threshold
            )
            t_integrator = _gw_lp_default_integrator(
                func, target_adjacency_lists, threshold
            )
        else:
            s_integrator = source_integrator
            t_integrator = target_integrator

        if loss_fun == "square_loss":
            init_source_integrator = s_integrator
        elif loss_fun == "kl_loss":
            # init_matrix builds the KL source term densely from C1 (shortest
            # path matrices in our case), so only the target integrator is
            # passed there.
            init_source_integrator = None
        else:
            raise ValueError("incorrect loss function used")

        constC, hC1, hC2 = init_matrix(
            C1,
            C2,
            p,
            q,
            loss_fun,
            method_type=method_type,
            source_integrator=init_source_integrator,
            target_integrator=t_integrator,
        )

    # Linear term of the conditional-gradient problem is zero; shape (ns, nt).
    M = np.zeros((p.shape[0], q.shape[0]))

    if G0 is None:
        G0 = p[:, None] * q[None, :]
    else:  # check marginals
        np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08)
        np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08)

    def f(G):
        return gwloss(
            constC,
            hC1,
            hC2,
            G,
            method_type=method_type,
            loss_fun=loss_fun,
            source_integrator=s_integrator,
            target_integrator=t_integrator,
        )

    def df(G):
        return gwggrad(
            constC,
            hC1,
            hC2,
            G,
            method_type=method_type,
            loss_fun=loss_fun,
            source_integrator=s_integrator,
            target_integrator=t_integrator,
        )

    result = cg(
        a=p,
        b=q,
        M=M,
        reg=alpha,
        f=f,
        df=df,
        G0=G0,
        armijo=armijo,
        C1=C1,
        C2=C2,
        constC=constC,
        log=log,
        alpha_min=0,
        alpha_max=1,
        method_type=method_type,
        source_integrator=s_integrator,
        target_integrator=t_integrator,
        verbose=verbose,
        numItermax=max_iter,
        stopThr=stopThr,
    )
    if not log:
        return result

    res, log0 = result
    # f(res) is the GW loss at the final coupling.
    log0["gw_dist"] = f(res)
    return res, log0

In [None]:
import matplotlib.pylab as pl
# from ot.gromov import semirelaxed_gromov_wasserstein, semirelaxed_fused_gromov_wasserstein, gromov_wasserstein, fused_gromov_wasserstein
import networkx
from networkx.generators.community import stochastic_block_model as sbm

In [None]:
def construct_edge_lists(graph, weights=False):
  """Convert a networkx-style graph into GraphIntegrator adjacency lists.

  Each node yielded by ``graph.adjacency()`` contributes one list of
  ``[neighbor, weight]`` pairs. The weight is the numeric ``"weight"`` edge
  attribute, or 1 when that attribute is absent. This mirrors networkx's
  default ``weight="weight"`` convention.

  Args:
    graph: object whose ``adjacency()`` yields
      ``(node, {neighbor: edge_attr_dict})`` pairs (e.g. a networkx Graph).
    weights: unused; kept for backward compatibility. Edge weights are always
      taken from the ``"weight"`` attribute when present.

  Returns:
    A list whose i-th entry holds the ``[neighbor, weight]`` pairs of the
    i-th node yielded by ``graph.adjacency()``.
  """
  # NOTE(review): GraphIntegrator assumes nodes are labelled 0..N-1; the list
  # index only matches the label if adjacency() yields nodes in that order.
  # Confirm this for graphs that are not built by networkx generators.
  adj_lists = []
  for _node, neighbours in graph.adjacency():
    # Store the numeric weight, not the whole edge-attribute dict.
    adj_lists.append(
        [[nbr, attrs.get("weight", 1)] for nbr, attrs in neighbours.items()]
    )
  return adj_lists

In [None]:
# construct_edge_lists(G2)

# RUN THE ALGORITHM ON TREES

In [None]:
import time
from networkx.generators.trees import random_tree

# Baseline: dense Gromov-Wasserstein. Floyd-Warshall distance matrices are
# fed to gw_lp on its default dense path (method_type=None). The target is the
# single-node tree produced by random_tree(n=1, ...).
for i in [500, 1000, 2000, 2500]:
  source_tree = random_tree(n=i, seed=0)
  target_tree = random_tree(n=1, seed=42)
  start = time.time()
  D_src = networkx.floyd_warshall_numpy(source_tree)
  D_tgt = networkx.floyd_warshall_numpy(target_tree)
  # Uniform marginals over the nodes of each tree.
  w_src = np.ones(D_src.shape[0]) / D_src.shape[0]
  w_tgt = np.ones(D_tgt.shape[0]) / D_tgt.shape[0]
  _, run_log = gw_lp(C1=D_src, C2=D_tgt, p=w_src, q=w_tgt, log=True)
  end = time.time()
  gw = run_log['gw_dist']
  print(f"Time taken for baseline gw on {i} nodes is {end-start} and the distance is {gw}")
  del source_tree, target_tree, D_src, D_tgt, run_log, w_src, w_tgt

In [None]:
# Baseline on bigger trees: the same dense pipeline as the previous cell.
for i in (3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000):
  source_tree = random_tree(n=i, seed=0)
  target_tree = random_tree(n=1, seed=42)
  start = time.time()
  D_src = networkx.floyd_warshall_numpy(source_tree)
  D_tgt = networkx.floyd_warshall_numpy(target_tree)
  # Uniform marginals over the nodes of each tree.
  w_src = np.ones(D_src.shape[0]) / D_src.shape[0]
  w_tgt = np.ones(D_tgt.shape[0]) / D_tgt.shape[0]
  _, run_log = gw_lp(C1=D_src, C2=D_tgt, p=w_src, q=w_tgt, log=True)
  end = time.time()
  gw = run_log['gw_dist']
  print(f"Time taken for baseline gw on {i} nodes is {end-start} and the distance is {gw}")
  del source_tree, target_tree, D_src, D_tgt, run_log, w_src, w_tgt

Time taken for baseline gw on 3000 nodes is 194.95167183876038 and the distance is 6394.569884222222
Time taken for baseline gw on 4000 nodes is 401.9335091114044 and the distance is 6877.147011999999
Time taken for baseline gw on 5000 nodes is 769.647983789444 and the distance is 10339.689725039996
Time taken for baseline gw on 6000 nodes is 1328.2349543571472 and the distance is 13561.547033722221
Time taken for baseline gw on 7000 nodes is 2102.0649557113647 and the distance is 17092.65742036735


In [None]:
# Tree-based integrators (method_type="diffusion") on the same random trees
# as the baseline cells.
sizes = [500, 1000, 2000]
thresholds = [50, 100, 150]  # integrator threshold used for each size
func = lambda x: x  # identity f passed to the integrators
for i, thr in zip(sizes, thresholds):
  G2 = random_tree(n=i, seed=0)
  # NOTE(review): the target is a single-node tree, as in the baseline cells;
  # confirm this is the intended comparison.
  G3 = random_tree(n=1, seed=42)
  s_graph = construct_edge_lists(G2)
  t_graph = construct_edge_lists(G3)
  # Uniform marginals sized to each graph. Previously h3 had i entries even
  # though the target graph has len(t_graph) nodes.
  h2 = np.ones(len(s_graph)) / len(s_graph)
  h3 = np.ones(len(t_graph)) / len(t_graph)
  start_time = time.time()
  s_tbgi = TreeBasedGraphIntegrator(f_func=func, is_lambda=True, graph_adj_lists=s_graph,
                                    tree_constructor=MinimumSpanningTreeConstructor(),
                                    threshold=thr)

  t_tbgi = TreeBasedGraphIntegrator(f_func=func, is_lambda=True, graph_adj_lists=t_graph,
                                    tree_constructor=MinimumSpanningTreeConstructor(),
                                    threshold=thr)

  _, log2 = gw_lp(
      C1=None,
      C2=None,
      p=h2,
      q=h3,
      loss_fun="square_loss",
      alpha=.5,
      armijo=False,
      G0=None,
      log=True,
      method_type="diffusion",
      source_integrator=s_tbgi,
      target_integrator=t_tbgi,
      max_iter=200000,
  )
  end = time.time()
  gw = log2['gw_dist']
  print(f"Time taken for tree-based gw on {i} nodes is {end-start_time} and the distance is {gw}")
  del G2, G3, s_graph, t_graph, s_tbgi, t_tbgi, log2, h2, h3

Time taken for baseline gw on 500 nodes is 0.6170721054077148 and the distance is 1585.8247039999999
Time taken for baseline gw on 1000 nodes is 2.304067611694336 and the distance is 2088.9115020000013
Time taken for baseline gw on 2000 nodes is 6.896076917648315 and the distance is 3826.6408794999966


In [None]:
# Our method (tree-based integrators) on bigger graphs.
sizes = [3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000]
thresholds = [500, 600, 700, 800, 900, 900, 900, 1000]  # one per size
func = lambda x: x  # identity f passed to the integrators
for i, thr in zip(sizes, thresholds):
  G2 = random_tree(n=i, seed=0)
  # NOTE(review): the target is a single-node tree, as in the baseline cells;
  # confirm this is the intended comparison.
  G3 = random_tree(n=1, seed=42)
  s_graph = construct_edge_lists(G2)
  t_graph = construct_edge_lists(G3)
  # Uniform marginals sized to each graph. Previously h3 had i entries even
  # though the target graph has len(t_graph) nodes.
  h2 = np.ones(len(s_graph)) / len(s_graph)
  h3 = np.ones(len(t_graph)) / len(t_graph)
  start_time = time.time()
  s_tbgi = TreeBasedGraphIntegrator(f_func=func, is_lambda=True, graph_adj_lists=s_graph,
                                    tree_constructor=MinimumSpanningTreeConstructor(),
                                    threshold=thr)

  t_tbgi = TreeBasedGraphIntegrator(f_func=func, is_lambda=True, graph_adj_lists=t_graph,
                                    tree_constructor=MinimumSpanningTreeConstructor(),
                                    threshold=thr)

  _, log2 = gw_lp(
      C1=None,
      C2=None,
      p=h2,
      q=h3,
      loss_fun="square_loss",
      alpha=.5,
      armijo=False,
      G0=None,
      log=True,
      method_type="diffusion",
      source_integrator=s_tbgi,
      target_integrator=t_tbgi,
      max_iter=200000,
  )
  end = time.time()
  gw = log2['gw_dist']
  # Time relative to this iteration's start_time (the old code used a stale
  # `start` left over from the baseline cells).
  print(f"Time taken for tree-based gw on {i} nodes is {end-start_time} and the distance is {gw}")
  del G2, G3, s_graph, t_graph, s_tbgi, t_tbgi, log2, h2, h3