# Check triangle inequality

In [1]:
import os.path as osp
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from ot.gromov import gromov_wasserstein
from torch_geometric.datasets import TUDataset
from tqdm import tqdm

from ogw.gromov_prox import projection_matrix, quad_solver
from ogw.gw_lb import flb, slb, tlb
from ogw.ogw_dist import eval_ogw_lb, ogw_lb, ogw_ub
from ogw.utils import random_perturb
from scipy.linalg import eigvalsh, svdvals

from time import time
np.set_printoptions(3)
np.random.seed(0)


  from .autonotebook import tqdm as notebook_tqdm


## Synthetic dataset
* Generate 15 Erdos-Renyi random graph with n=20, p=0.6
* No disconnected graphs

In [2]:
Cs = []
n = 15
for _ in range(n):
    G = nx.erdos_renyi_graph(20, 0.6)
    C = nx.floyd_warshall_numpy(G)
    Cs.append(C)


In [3]:
dist = np.zeros((n, n))
for i in range(n):
    for j in range(i+1, n):
        dist[i, j] = ogw_lb(Cs[i], Cs[j])
dist += dist.T



In [4]:
dist_sqrt = np.power(dist, 0.5)
counter = 0
for i in range(n):
    for j in range(i+1, n):
        for k in range(j+1, n):
            if dist_sqrt[i, j] > dist_sqrt[i, k] + dist_sqrt[k, j]:
                counter += 1

total = n * (n - 1) * (n - 2) / 6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality under ogw_lb")



0.00 % of tuples (i, j, k) violate the triangle inequality under ogw_lb


In [5]:
dist = np.zeros((n, n))
p = np.ones(20) / 20
for i in range(n):
    for j in range(i + 1, n):
        _, gw_log = gromov_wasserstein(Cs[i], Cs[j], p, p, "square_loss", log=True)
        dist[i, j] = gw_log['gw_dist']
dist += dist.T


In [6]:
dist_sqrt = np.power(dist, 0.5)
counter = 0
for i in range(n):
    for j in range(i + 1, n):
        for k in range(j + 1, n):
            if dist_sqrt[i, j] > dist_sqrt[i, k] + dist_sqrt[k, j]:
                counter += 1

total = n * (n - 1) * (n - 2) / 6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality under gw")


0.00 % of tuples (i, j, k) violate the triangle inequality under gw


## Real dataset

In [7]:
from ogw.utils import load_pyg_data
import logging
import os
from joblib.parallel import Parallel, delayed
import pickle

ROOT = osp.join(osp.expanduser("~"), 'tmp', 'data', "TUDataset")
dsname = "MUTAG"

# prepare dataset
Gs, ys = load_pyg_data(dsname)
Cs = [nx.floyd_warshall_numpy(g) for g in Gs]
Ns = [C.shape[0] for C in Cs]
ps = [np.ones(n) / n for n in Ns]

SAVED_PATH = osp.join(ROOT, dsname, "saved")
if not osp.isdir(SAVED_PATH):
    logging.info("creating folder")
    os.makedirs(SAVED_PATH)

N = len(Gs)


In [8]:
def calc_D_OGW_lb(i, j, D):
    D[i, j] = ogw_lb(Cs[i], Cs[j])

fn_mm = osp.join(ROOT, dsname, "D_OGW_lb")
D_OGW_lb = np.memmap(fn_mm, mode="w+", shape=(N, N), dtype=float)

logging.info(f"calcualte OGW_lb")
Parallel(n_jobs=-1, backend="multiprocessing")(
    delayed(calc_D_OGW_lb)(i, j, D_OGW_lb) for i in range(N) for j in range(i + 1, N))
D_OGW_lb += D_OGW_lb.T

pickle.dump(D_OGW_lb, open(osp.join(SAVED_PATH, "D_OGW_lb.pkl"), "wb"))


In [9]:
# Check non-negativity
print(f"min value from pairwise dissimilarities {D_OGW_lb.min():.2e}")

min value from pairwise dissimilarities -1.42e-14


In [10]:
# fix some numerical issue. e.g., -1 e-15
D_OGW_lb = np.clip(D_OGW_lb, 0, np.inf)

# compare in square root, not in ||.||^2
D_OGW_lb_sqrt = np.power(D_OGW_lb, 0.5)

counter = 0
for i in range(N):
    for j in range(i+1, N):
        for k in range(j+1, N):
            if D_OGW_lb_sqrt[i, j] - D_OGW_lb_sqrt[i, k] - D_OGW_lb_sqrt[k, j] > 1e-13:
                print("violate triangle inequality", i, j, k)
                counter += 1

total = N*(N-1) * (N-2)/6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality")

0.00 % of tuples (i, j, k) violate the triangle inequality


In [11]:
D_OGW_lb

array([[0.   , 1.046, 1.046, ..., 0.917, 0.314, 0.1  ],
       [1.046, 0.   , 0.   , ..., 0.086, 1.388, 0.511],
       [1.046, 0.   , 0.   , ..., 0.086, 1.388, 0.511],
       ...,
       [0.917, 0.086, 0.086, ..., 0.   , 1.53 , 0.447],
       [0.314, 1.388, 1.388, ..., 1.53 , 0.   , 0.401],
       [0.1  , 0.511, 0.511, ..., 0.447, 0.401, 0.   ]])

In [12]:
def calc_D_GW(i, j, D):
    T, gw_log = gromov_wasserstein(Cs[i], Cs[j], ps[i], ps[j], loss_fun="square_loss", log=True)
    D[i, j] = gw_log['gw_dist']


# GW
fn_mm = osp.join(ROOT, dsname, "D_GW")
D_GW = np.memmap(fn_mm, mode="w+", shape=(N, N), dtype=float)

logging.info(f"calcualte GW")
Parallel(n_jobs=-1, backend="multiprocessing")(
    delayed(calc_D_GW)(i, j, D_GW) for i in range(N) for j in range(i + 1, N))
D_GW += D_GW.T

pickle.dump(D_GW, open(osp.join(SAVED_PATH, "D_GW.pkl"), "wb"))


  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)
  relative_delta_fval = abs_delta_fval / abs(f_val)


In [13]:
print(f"min value from pairwise dissimilarities {D_GW.min():.2e}")

min value from pairwise dissimilarities -3.11e-15


In [14]:
D_GW = np.clip(D_GW, 0, np.inf)
D_GW_sqrt = np.power(D_GW, 0.5)

counter = 0
for i in range(N):
    for j in range(i + 1, N):
        for k in range(j + 1, N):
            if D_GW_sqrt[i, j] - D_GW_sqrt[i, k] - D_GW_sqrt[k, j] > 1e-13:
                # print("violate triangle inequality", i, j, k)
                counter += 1


total = N * (N - 1) * (N - 2) / 6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality")


0.43 % of tuples (i, j, k) violate the triangle inequality


## Check with GW_flb

In [15]:
def calc_D_GW_lb(i, j, D):
    D[i, j] = flb(Cs[i], Cs[j])
    
# GW_lb
fn_mm = osp.join(ROOT, dsname, "D_GW_lb")
D_GW_lb = np.memmap(fn_mm, mode="w+", shape=(N, N), dtype=float)

logging.info(f"calcualte GW_lb")
Parallel(n_jobs=-1, backend="multiprocessing")(
    delayed(calc_D_GW_lb)(i, j, D_GW_lb) for i in range(N) for j in range(i + 1, N))
D_GW_lb += D_GW_lb.T

pickle.dump(D_GW_lb, open(osp.join(SAVED_PATH, "D_GW_lb.pkl"), "wb"))

In [16]:
D_GW_lb = np.clip(D_GW_lb, 0, np.inf)
D_GW_lb_sqrt = np.power(D_GW_lb, 0.5)

counter = 0
for i in range(N):
    for j in range(i + 1, N):
        for k in range(j + 1, N):
            if D_GW_lb_sqrt[i, j] - D_GW_lb_sqrt[i, k] - D_GW_lb_sqrt[k, j] > 1e-13:
                # print("violate triangle inequality", i, j, k)
                counter += 1


total = N * (N - 1) * (N - 2) / 6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality")


0.00 % of tuples (i, j, k) violate the triangle inequality


In [17]:
def calc_D_OGW_ub(i, j, D):
    D[i, j] = ogw_ub(Cs[i], Cs[j])
    
# OGW_ub
fn_mm = osp.join(ROOT, dsname, "D_OGW_ub")
D_OGW_ub = np.memmap(fn_mm, mode="w+", shape=(N, N), dtype=float)

logging.info(f"calcualte OGW_ub")
Parallel(n_jobs=-1, backend="multiprocessing")(
    delayed(calc_D_OGW_ub)(i, j, D_OGW_ub) for i in range(N) for j in range(i + 1, N))
D_OGW_ub += D_OGW_ub.T

pickle.dump(D_OGW_ub, open(osp.join(SAVED_PATH, "D_OGW_ub.pkl"), "wb"))


In [18]:
# fix some numerical issue. e.g., -1 e-15
D_OGW_ub = np.clip(D_OGW_ub, 0, np.inf)

# compare in square root, not in ||.||^2
D_OGW_ub_sqrt = np.power(D_OGW_ub, 0.5)

counter = 0
for i in range(N):
    for j in range(i + 1, N):
        for k in range(j + 1, N):
            if D_OGW_ub_sqrt[i, j] - D_OGW_ub_sqrt[i, k] - D_OGW_ub_sqrt[k, j] > 1e-13:
                counter += 1

total = N * (N - 1) * (N - 2) / 6
print(f"{counter/total * 100:.2f} % of tuples (i, j, k) violate the triangle inequality")


0.08 % of tuples (i, j, k) violate the triangle inequality
