# Compare the wall time of the GW and OGW lower bounds

In [1]:
import pickle
import os.path as osp
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from torch_geometric.datasets import TUDataset
from tqdm import tqdm

from ogw.gromov_prox import projection_matrix, quad_solver
from ogw.gw_lb import flb, slb, tlb
from ogw.ogw_dist import ogw_lb_v2
from ogw.utils import random_perturb
from scipy.linalg import eigvalsh, svdvals

from time import time
# Print NumPy arrays with 3 digits of precision.
np.set_printoptions(precision=3)
# Fix NumPy's global (legacy) generator. This seeds NumPy only: libraries with
# their own RNG (e.g. networkx, which defaults to Python's `random`) are not
# affected and need their own `seed=` to be reproducible.
np.random.seed(seed=1)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import random

# Benchmark corpus: for every graph size n, `samples` Erdos-Renyi graphs
# G(n, EDGE_PROB), each stored as its all-pairs shortest-path matrix.
graphs = defaultdict(list)
sizes = [10 * i for i in range(1, 10)] + [100 * i for i in range(1, 11)]
samples = 20
EDGE_PROB = 0.6

# np.random.seed does not drive networkx's generators, so pass an explicit,
# seeded RNG to make the benchmark graphs reproducible across runs.
graph_rng = random.Random(1)
for n in sizes:
    for _ in range(samples):
        G = nx.erdos_renyi_graph(n, EDGE_PROB, seed=graph_rng)
        # A disconnected graph yields `inf` entries in its distance matrix,
        # which is not a meaningful input for the lower bounds; resample.
        while not nx.is_connected(G):
            G = nx.erdos_renyi_graph(n, EDGE_PROB, seed=graph_rng)
        graphs[n].append(nx.floyd_warshall_numpy(G))


In [None]:
from time import perf_counter


def _time_all_pairs(lower_bound, mats, **kwargs):
    """Return the seconds spent evaluating ``lower_bound`` on every ordered pair.

    ``lower_bound(A, B, **kwargs)`` is called for every ``A`` and ``B`` in
    ``mats`` (including ``A is B``), i.e. ``len(mats) ** 2`` times. Results are
    discarded: only the elapsed time is of interest.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    tic = perf_counter()
    for A in mats:
        for B in mats:
            lower_bound(A, B, **kwargs)
    return perf_counter() - tic


# Each tc_* list holds, per entry of `sizes`, the TOTAL time for samples**2
# evaluations of one bound. The printed mean/std are taken across sizes.
tc_flb = [_time_all_pairs(flb, graphs[s][:samples]) for s in sizes]
print(np.mean(tc_flb), np.std(tc_flb))

tc_slb = [_time_all_pairs(slb, graphs[s][:samples]) for s in sizes]
print(np.mean(tc_slb), np.std(tc_slb))

tc_tlb = [_time_all_pairs(tlb, graphs[s][:samples]) for s in sizes]
print(np.mean(tc_tlb), np.std(tc_tlb))

# The projection matrix depends only on the size, so it is built once per size
# and (as before) kept outside the timed region.
# NOTE(review): the GW bounds are timed end to end, so excluding this setup
# favours OGW in the comparison; confirm that is intended.
tc_ogw = []
for s in sizes:
    V = projection_matrix(s)
    tc_ogw.append(_time_all_pairs(ogw_lb_v2, graphs[s][:samples], V=V))
print(np.mean(tc_ogw), np.std(tc_ogw))


In [None]:
# Log-log plot of the per-size total timings: GW bounds dashed, OGW solid.
curves = (
    (tc_tlb, "--", r"$\mathsf{GW}_{tlb}$"),
    (tc_slb, "--", r"$\mathsf{GW}_{slb}$"),
    (tc_ogw, "-", r"$\mathsf{OGW}_{lb}$"),
    (tc_flb, "--", r"$\mathsf{GW}_{flb}$"),
)
fig, ax = plt.subplots(figsize=(5, 5))
for timings, linestyle, label in curves:
    ax.loglog(sizes, timings, linestyle, label=label)
ax.set_xlabel("number of nodes")
ax.set_ylabel("Wall time (s)")
ax.legend(ncol=2)
fig.tight_layout()