# Fused Gromov-Wasserstein Barycenter example

In [None]:
import torch
from conan_fgw.src.model.fgw.barycenter import fgw_barycenters, fgw_barycenters_BAPG
import time
import ot
from pathlib import Path

# Directory that holds this example. `__file__` names the file itself, so its
# parent is the folder containing `data/`. Notebook kernels do not define
# `__file__`, so fall back to the current working directory in that case.
try:
    path = Path(__file__).resolve().parent
except NameError:
    path = Path.cwd()

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
debug_dict = torch.load((path / 'data' / 'cfm_log.pt').as_posix())

# Inputs for the barycenter solvers, read from the saved dictionary.
N = debug_dict["N"]
Ys = debug_dict["Ys"]
Cs = debug_dict["Cs"]
ps = debug_dict["ps"]
lambdas = debug_dict["lambdas"]

# Ys, Cs and ps are stored as sequences of tensors; stack each into one tensor.
Ys = torch.stack(Ys)
Cs = torch.stack(Cs)
ps = torch.stack(ps)

# Uniform weight vector of length N, placed on the same device as Ys.
p = torch.ones(N) / N
p = p.to(Ys.device)

Next, we compare the results and computation time of the Sinkhorn-based and BAPG solvers against the conjugate gradient (CG) method.

In [None]:
def mse(A, B):
    """Return the mean of the squared element-wise difference between A and B.

    The result is extracted from the resulting tensor with ``.item()``.
    """
    residual = A - B
    squared = residual ** 2
    return torch.mean(squared).item()

def _now():
    """Return a monotonic timestamp suitable for measuring elapsed time.

    If the inputs live on a CUDA device, wait for pending kernels first so the
    reading reflects finished work rather than only the kernel launches.
    """
    if Ys.is_cuda:
        torch.cuda.synchronize(Ys.device)
    return time.perf_counter()


# Reference solution: POT's conjugate-gradient FGW barycenter solver.
start = _now()
F_bary_ref, C_bary_ref, log = ot.gromov.fgw_barycenters(
    N=N, Ys=Ys, Cs=Cs, ps=ps, lambdas=lambdas, p=p, warmstartT=True, symmetric=True,
    alpha=0.5, fixed_structure=False, fixed_features=False, loss_fun='kl_loss', max_iter=50, tol=1e-5,
    verbose=False, log=True, init_C=None, init_X=None, random_state=None,
)
print("FGW CG Time elapsed: ", _now() - start)

# Sinkhorn-based solver (PGD with the log-domain Sinkhorn method), compared to the reference.
start = _now()
F_bary1, C_bary1, log = fgw_barycenters(
    N=N, Ys=Ys, Cs=Cs, ps=ps, lambdas=lambdas, p=p, warmstartT=True, symmetric=False, method='sinkhorn_log',
    alpha=0.5, solver='PGD', fixed_structure=False, fixed_features=False, epsilon=0.05, loss_fun='kl_loss', max_iter=50, tol=1e-5,
    numItermax=50, stopThr=5e-3, verbose=False, log=True, init_C=None, init_X=None, random_state=None,
)
print("FGW Sinkhorn Time elapsed: ", _now() - start)
print("FGW Sinkhorn Feature matrix difference: ", mse(F_bary1, F_bary_ref))
print("FGW Sinkhorn Structure matrix difference: ", mse(C_bary1, C_bary_ref))

# BAPG solver, compared to the reference.
start = _now()
F_bary2, C_bary2, log = fgw_barycenters_BAPG(
    N=N, Ys=Ys, Cs=Cs, ps=ps, lambdas=lambdas, p=p, warmstartT=True,
    alpha=0.5, fixed_structure=False, fixed_features=False, epsilon=0.025, loss_fun='kl_loss', max_iter=50, toly=1e-3, tolc=1e-6, rho=22,
    verbose=False, log=True, init_C=None, init_X=None, random_state=None,
)
print("FGW BAPG Time elapsed: ", _now() - start)
print("FGW BAPG Feature matrix difference: ", mse(F_bary2, F_bary_ref))
print("FGW BAPG Structure matrix difference: ", mse(C_bary2, C_bary_ref))