In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random

import networkx as nx
import numpy as np
import scipy.optimize as optim
from fgw.gromov_prox import *
from fgw.bary_utils import commute_time


In [3]:
def laplacian(A):
    """Combinatorial graph Laplacian L = D - A.

    Args:
        A (np.ndarray): square adjacency matrix.

    Returns:
        np.ndarray: D - A, where D is the diagonal matrix of the row sums of A.
    """
    degrees = np.sum(A, axis=1)
    return np.diag(degrees) - A


In [4]:
N = 5   # number of nodes of the reference (barycenter) graph
N1 = 6  # number of nodes of the input graph G1

# Input graph G1 and its commute-time matrix D1.
# (A second, random binomial input graph G2 / D2 was explored earlier; only G1 is used below.)
G1 = nx.cycle_graph(N1)
A1 = nx.adjacency_matrix(G1).toarray()
D1 = commute_time(A1)

# Reference graph G: adjacency A and commute-time matrix C.
G = nx.cycle_graph(N)
A = nx.adjacency_matrix(G).toarray()
C = commute_time(A)

# All-ones N x N matrix (1 1^T); several helpers below read it as a global.
eet = np.ones((N, N))


In [5]:
def z(A):
    r"""Pseudo-inverse of the Laplacian of A via a rank-one shift.

    .. math::

        L^+ = (L + \mathbf{1}\mathbf{1}^\top / n)^{-1} - \mathbf{1}\mathbf{1}^\top / n

    Args:
        A (np.ndarray): square (n, n) adjacency matrix.

    Returns:
        np.ndarray: (n, n) pseudo-inverse of ``laplacian(A)``. This requires
        L + 11^T/n to be invertible, i.e. a connected graph.
    """
    n = A.shape[0]
    # Size the all-ones matrix from A itself. The global N x N ``eet`` only
    # works when n == N.
    ones = np.ones((n, n))
    return np.linalg.inv(laplacian(A) + ones / n) - ones / n


np.testing.assert_array_almost_equal(np.linalg.pinv(laplacian(A)), z(A))


In [6]:
# Coupling P1 between the reference graph (commute times C, N nodes) and the
# input graph G1 (commute times D1, N1 nodes), presumably the Gromov-Wasserstein
# transport plan with uniform marginals; f_C below uses it as an (N, N1) matrix.
# Alternative explored earlier (from fgwtil_ub):
#   fused_gromov_upper_bound_rec(np.zeros([N, N1]), C, D1, 1, 'OE', True)
P1 = gromov_wasserstein(C, D1, np.ones(N) / N, np.ones(N1) / N1, loss_fun="square_loss")


In [7]:
def rp(X, decimals=3):
    """Print a square matrix rounded to ``decimals`` places.

    Args:
        X (array-like): an (n, n) matrix or its flattened n**2 form (e.g. the
            flat vector returned by ``scipy.optimize.approx_fprime``).
        decimals (int): number of decimals kept by ``np.round``.
    """
    # Infer the side length instead of relying on the global N, so any square
    # matrix (or its flat form) can be printed.
    side = int(round(np.sqrt(np.size(X))))
    print(np.round(np.reshape(X, (side, side)), decimals=decimals))


## $\partial f / \partial C$

In [8]:
def f_C(C):
    r"""Barycenter objective as a function of the reference matrix C.

    Only the first input graph (D1 with coupling P1) is included:

    .. math::

        f(C) = \frac{\|C\|_F^2}{N^2} + \frac{\|D_1\|_F^2}{N_1^2}
               - 2 \operatorname{tr}(C P_1 D_1 P_1^\top)

    Args:
        C (array-like): (N, N) matrix, or its flattened N**2 form as passed by
            ``scipy.optimize.approx_fprime``.

    Returns:
        float: objective value.
    """
    C = np.reshape(C, (N, N))
    C_norm = np.linalg.norm(C)  # Frobenius norm for a 2-D array
    D1_norm = np.linalg.norm(D1)
    return C_norm**2 / N**2 + D1_norm**2 / N1**2 - 2 * np.trace(C @ P1 @ D1 @ P1.T)


def grad_f_C(C):
    r"""Gradient of ``f_C`` with respect to C.

    .. math::

        \frac{\partial f}{\partial C} = \frac{2 C}{N^2} - 2 P_1 D_1^\top P_1^\top

    Args:
        C (np.ndarray): (N, N) reference matrix. It must already be 2-D;
            unlike ``f_C``, this function does not reshape its input.

    Returns:
        np.ndarray: (N, N) gradient, laid out like C.
    """
    return 2 * C / N**2 - 2 * P1 @ D1.T @ P1.T


grad_approx = optim.approx_fprime(C.flatten(), f_C, np.sqrt(np.finfo(float).eps))
grad_true = grad_f_C(C)

# Forward differences carry ~1e-5 round-off here (see the printed outputs), so
# compare with an absolute tolerance. assert_array_almost_equal's default of
# 6 decimals is too strict.
np.testing.assert_allclose(grad_approx.reshape((N, N)), grad_true, atol=1e-4)
print("approx grad")
rp(grad_approx, 5)
print("true grad")
rp(grad_true, 5)


approx grad
[[-0.56889 -0.56889 -0.20444  0.15111 -0.27556]
 [-0.56889 -0.35556 -0.20444 -0.32889 -0.00889]
 [-0.20444 -0.20444 -0.22222 -0.47111 -0.36444]
 [ 0.15111 -0.32889 -0.47111 -0.22223 -0.59556]
 [-0.27556 -0.00889 -0.36444 -0.59556 -0.22222]]
true grad
[[-0.56889 -0.56889 -0.20444  0.15111 -0.27556]
 [-0.56889 -0.35556 -0.20444 -0.32889 -0.00889]
 [-0.20444 -0.20444 -0.22222 -0.47111 -0.36444]
 [ 0.15111 -0.32889 -0.47111 -0.22222 -0.59556]
 [-0.27556 -0.00889 -0.36444 -0.59556 -0.22222]]


## $\partial\, \mathrm{vol}(G) / \partial Z$
$Z = (L + \mathbf{1}\mathbf{1}^\top / n)^{-1}$

In [9]:
def L_pinv(A):
    r"""Moore-Penrose pseudo-inverse of the Laplacian of adjacency matrix A.

    .. math::

        L^+ = (L + \mathbf{1}\mathbf{1}^\top / n)^{-1} - \mathbf{1}\mathbf{1}^\top / n

    Args:
        A (np.ndarray): square (n, n) adjacency matrix.

    Returns:
        np.ndarray: (n, n) pseudo-inverse of ``laplacian(A)``.
    """
    n = A.shape[0]
    L = laplacian(A)
    # Build 1 1^T from A's own size. The global N x N ``eet`` only works when n == N.
    ones = np.ones([n, n])
    return np.linalg.inv(L + ones / n) - ones / n

def Z(L):
    r"""Regularized inverse of a graph Laplacian.

    .. math::

        Z = (L + \mathbf{1}\mathbf{1}^\top / n)^{-1}

    Z itself is not the pseudo-inverse. For a connected graph,
    Z - 11^T/n is the pseudo-inverse L^+.

    Args:
        L (np.ndarray): square (n, n) Laplacian.

    Returns:
        np.ndarray: (n, n) matrix Z.
    """
    n = np.shape(L)[0]
    return np.linalg.inv(L + np.ones((n, n)) / n)


def vol(A):
    """Graph volume vol(G) = 1^T A 1: the sum of all entries of adjacency matrix A."""
    return np.sum(A)


def grad_vol_A(n=None):
    r"""Gradient of vol(G) = 1^T A 1 with respect to the adjacency matrix A.

    .. math::

        \partial \mathrm{vol}(G) / \partial A = \mathbf{1}\mathbf{1}^\top

    Args:
        n (int, optional): number of nodes. Defaults to the notebook-level ``N``.

    Returns:
        np.ndarray: (n, n) all-ones matrix.
    """
    if n is None:
        n = N
    return np.ones([n, n])


def volZ(Z):
    r"""Graph volume vol(G) recovered from Z = (L + 11^T/n)^{-1}.

    Inverts Z to recover the Laplacian and returns its trace:

    .. math::

        \mathrm{vol}(Z) = \operatorname{tr}\left(Z^{-1} - \mathbf{1}\mathbf{1}^\top / n\right)

    For L = D - A with a zero-diagonal A, tr(L) is the sum of degrees, i.e. 1^T A 1.

    Args:
        Z (array-like): (n, n) matrix, or its flattened n**2 form (as passed by
            ``scipy.optimize.approx_fprime``).

    Returns:
        float: trace of the recovered Laplacian.
    """
    Z = np.asarray(Z)
    n = int(round(np.sqrt(Z.size)))
    Z = np.reshape(Z, [n, n])
    # Recover L by undoing the rank-one shift that made it invertible.
    L = np.linalg.inv(Z) - np.ones((n, n)) / n
    return np.trace(L)

def volZ_v2(Z):
    """Deprecated alternative to ``volZ``.

    Returns the negated sum of the off-diagonal entries of the recovered
    Laplacian Z^{-1} - 11^T/n. Prefer ``volZ``.

    Args:
        Z (array-like): (n, n) matrix or its flattened n**2 form.

    Returns:
        float: -sum of the off-diagonal entries of Z^{-1} - 11^T/n.
    """
    import warnings

    # Warn at call time. Decorating with ``@DeprecationWarning`` would replace the
    # function with an exception instance, so every call would raise TypeError.
    warnings.warn("volZ_v2 is deprecated; use volZ instead.", DeprecationWarning, stacklevel=2)
    Z = np.asarray(Z)
    n = int(round(np.sqrt(Z.size)))
    L = np.linalg.inv(Z.reshape((n, n))) - np.ones((n, n)) / n
    np.fill_diagonal(L, 0)
    return -L.sum()


def grad_volZ_Z(Z):
    r"""Gradient of ``volZ`` with respect to Z.

    Since vol(Z) = tr(Z^{-1}) - 1,

    .. math::

        \nabla_Z \mathrm{vol}(Z) = -\left(Z^{-1} Z^{-1}\right)^\top

    Ref: Cookbook p.26. For symmetric Z the transpose is a no-op.

    Args:
        Z (array-like): (n, n) matrix or its flattened n**2 form.

    Returns:
        np.ndarray: (n, n) gradient, laid out like Z.
    """
    Z = np.asarray(Z)
    n = int(round(np.sqrt(Z.size)))
    Z_inv = np.linalg.inv(Z.reshape((n, n)))
    # The transpose only matters for non-symmetric Z (e.g. arbitrary
    # finite-difference perturbations); it keeps the gradient exact there.
    return -(Z_inv @ Z_inv).T


L = laplacian(A)
Zm = Z(L)

# Z - 11^T/N must coincide with the Moore-Penrose pseudo-inverse of L.
np.testing.assert_array_almost_equal(Zm - np.ones((N, N)) / N, np.linalg.pinv(L))

# vol(G) = sum(A) = trace(Z^{-1} - 11^T/N). volZ goes through a matrix inverse,
# so compare with a tolerance; exact float equality only holds by luck.
assert np.isclose(volZ(Zm), vol(A))

grad_approx = optim.approx_fprime(Zm.flatten(), volZ, np.sqrt(np.finfo(float).eps)).reshape([N, N])
grad_true = grad_volZ_Z(Zm)
np.testing.assert_allclose(grad_approx, grad_true, atol=1e-4)

print("approx grad")
rp(grad_approx)
print("true grad")
rp(grad_true)


approx grad
[[-6.2  3.8 -1.2 -1.2  3.8]
 [ 3.8 -6.2  3.8 -1.2 -1.2]
 [-1.2  3.8 -6.2  3.8 -1.2]
 [-1.2 -1.2  3.8 -6.2  3.8]
 [ 3.8 -1.2 -1.2  3.8 -6.2]]
true grad
[[-6.2  3.8 -1.2 -1.2  3.8]
 [ 3.8 -6.2  3.8 -1.2 -1.2]
 [-1.2  3.8 -6.2  3.8 -1.2]
 [-1.2 -1.2  3.8 -6.2  3.8]
 [ 3.8 -1.2 -1.2  3.8 -6.2]]


## $\partial C / \partial Z$

In [10]:
def calc_T(Z):
    """Build T(Z) from a square matrix Z (see Eq. (67)).

    Off-diagonal entries are Z_ii + Z_jj - Z_ij - Z_ji; the diagonal keeps Z_ii.

    Args:
        Z (np.ndarray): square (n, n) matrix.

    Returns:
        np.ndarray: (n, n) float64 matrix T.
    """
    size = Z.shape[0]
    diag = np.diag(Z)
    T = np.empty((size, size))
    T[...] = diag[:, None] + diag[None, :] - Z - Z.T
    # Same arithmetic as (Z_ii + Z_jj) / 2 evaluated at i == j.
    T[np.diag_indices(size)] = (diag + diag) / 2
    return T


# def calc_Z(T):
#     n = T.shape[0]
#     Z = np.zeros([n, n])
#     for i in range(n):
#         for j in range(n):
#             Z[i, j] = 0.5 * (T[i, i] + T[j, j] - T[i, j]) if i != j else (T[i, i] + T[j, j]) / 2
#     return Z


# def vol_T(T):
#     return volZ(calc_Z(T))


# def vol_X(X):
#     return vol_Z()

# def T_star(T):
#     n = T.shape[0]
#     Z = np.zeros([n, n])
#     for i in range(n):
#         for j in range(n):
#             Z[i, j] = 0.5*(T[i, i] + T[j, j] - T[i, j]) if i!=j else (T[i, i]+T[j, j])/2
#     return Z


def T_star(X):
    """Adjoint of ``calc_T`` under the Frobenius inner product.

    Satisfies <X, calc_T(Z)> = <T_star(X), Z> for every square X and Z:

    * off-diagonal: T*(X)_ij = -X_ij - X_ji
    * diagonal:     T*(X)_ii = rowsum_i(X) + colsum_i(X) - X_ii
      (which equals 2 * rowsum_i(X) - X_ii when X is symmetric)

    Args:
        X (array-like): square (n, n) matrix.

    Returns:
        np.ndarray: (n, n) float matrix.
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    adj = -(X + X.T)
    # Row + column sums keep this the exact adjoint for non-symmetric X;
    # doubling the row sum is only correct when X is symmetric.
    adj[np.diag_indices(n)] = X.sum(axis=1) + X.sum(axis=0) - np.diag(X)
    return adj


# Adjoint check: <M, T(Zm)> must equal <T*(M), Zm> for a symmetric M.
rng = np.random.default_rng(0)  # seeded so the printed values are reproducible
M = rng.random((N, N)) * 10
M = M.T / 2 + M / 2
lhs = np.sum(M * calc_T(Zm))
rhs = np.sum(T_star(M) * Zm)
print(lhs)
print(rhs)
assert np.isclose(lhs, rhs)


def CT(Z):
    r"""Commute-time matrix expressed through Z.

    .. math::

        C = \mathrm{vol}(Z) \, T(Z) \circ (\mathbf{1}\mathbf{1}^\top - I)

    where :math:`\circ` is the element-wise (Hadamard) product, i.e. T(Z)
    scaled by vol(Z) with its diagonal zeroed.

    Args:
        Z (array-like): (N, N) matrix or its flattened N**2 form.

    Returns:
        np.ndarray: (N, N) matrix with zero diagonal.
    """
    Z = np.reshape(Z, [N, N])
    ct = volZ(Z) * calc_T(Z)
    # Zeroing the diagonal is exactly the (11^T - I) Hadamard mask, so the mask
    # matrix never needs to be built.
    np.fill_diagonal(ct, 0)
    return ct


def f_Z(Z):
    """Barycenter loss as a function of Z: maps Z to C = CT(Z), then evaluates f_C.

    Accepts Z as (N, N) or flattened, because CT reshapes its input.
    """
    commute = CT(Z)
    return f_C(commute)


# The finite-difference check of f_Z is done below, once grad_Z is defined.


def grad_XZ(Z):
    """Jacobian of T(Z) (see ``calc_T``) with respect to Z, off-diagonal part only.

    Returns a 4-D array J with J[a, b, p, q] = dT_pq / dZ_ab for p != q.
    Since T_pq = Z_pp + Z_qq - Z_pq - Z_qp, each slice J[:, :, p, q] holds
    -1 at (p, q) and (q, p) and +1 at (p, p) and (q, q). Slices with p == q
    stay zero.

    T is linear in Z, so only the size of Z is used.

    Args:
        Z (array-like): (n, n) matrix or its flattened n**2 form.

    Returns:
        np.ndarray: (n, n, n, n) float array.
    """
    n = int(round(np.sqrt(np.size(Z))))
    J = np.zeros([n, n, n, n])
    # Every ordered pair (p, q) is filled directly, so no symmetrization pass
    # over the last two axes is needed.
    for p in range(n):
        for q in range(n):
            if p == q:
                continue
            J[p, q, p, q] = -1
            J[q, p, p, q] = -1
            J[p, p, p, q] = 1
            J[q, q, p, q] = 1
    return J


def vec(X):
    """Flatten X (row-major) into a column vector of shape (X.size, 1)."""
    return np.reshape(X, (-1, 1))


def grad_CZ(Z):
    """Jacobian of the commute-time matrix C = CT(Z) with respect to Z.

    Applies the product rule to C = vol(Z) * T(Z) (off-diagonal):
        dC/dZ = vol(Z) * dT/dZ + grad vol(Z) (outer) T(Z),  with T(Z) = CT(Z) / vol(Z).

    Returns:
        np.ndarray: (N, N, N, N) array J with J[a, b, p, q] = dC_pq / dZ_ab.
    """
    jac_T = grad_XZ(Z)
    diag = np.arange(N)
    # Drop the p == q slices: C has a fixed zero diagonal.
    jac_T[:, :, diag, diag] *= 0
    volume = volZ(Z)
    jac_C = volume * jac_T.reshape([N**2, N**2]) + vec(grad_volZ_Z(Z)) @ vec(CT(Z)).T / volume
    return jac_C.reshape([N, N, N, N])


def grad_Z(Z):
    """Gradient of the barycenter loss f_Z with respect to Z.

    Chain rule through C = CT(Z): contracts dF/dC (N, N) with the Jacobian
    dC/dZ (N, N, N, N) from ``grad_CZ`` over the C indices (p, q).
    """
    dF_dC = grad_f_C(CT(Z))
    dC_dZ = grad_CZ(Z)
    return np.einsum('pq, ijpq -> ij', dF_dC, dC_dZ)


# This check uses a step of 100 * sqrt(eps), larger than the other checks;
# presumably chosen to limit round-off in f_Z. TODO confirm.
fd_step = np.sqrt(np.finfo(float).eps) * 100
grad_approx = optim.approx_fprime(Zm.flatten(), f_Z, fd_step).reshape([N, N])
grad_true = grad_Z(Zm)

# approx_fprime perturbs Z_ij and Z_ji separately; compare the symmetric part.
grad_approx_sym = grad_approx / 2 + grad_approx.T / 2
np.testing.assert_allclose(grad_approx_sym, grad_true, atol=1e-2)

rp(grad_approx_sym)
rp(grad_true)


127.67169704197076
127.67169704197076
[[ 14.273  -8.375  10.327   3.216 -14.242]
 [ -8.375  10.007 -15.664  12.816   6.416]
 [ 10.327 -15.664   7.34  -10.331  13.527]
 [  3.216  12.816 -10.331   7.34   -7.842]
 [-14.242   6.416  13.527  -7.842   7.34 ]]
[[ 14.273  -8.375  10.327   3.216 -14.242]
 [ -8.375  10.007 -15.664  12.816   6.416]
 [ 10.327 -15.664   7.34  -10.331  13.527]
 [  3.216  12.816 -10.331   7.34   -7.842]
 [-14.242   6.416  13.527  -7.842   7.34 ]]


## $\partial Z / \partial L$

In [11]:
# Laplacian of the reference adjacency matrix A; base point for the dZ/dL check in this cell.
L = laplacian(A)


def C(L):
    """Pseudo-inverse of a Laplacian via (L + 11^T)^{-1} - 11^T / n^2.

    For a connected graph, (L + 11^T)^{-1} = L^+ + 11^T / n^2, so the
    correction term must be 11^T / n^2 (not 11^T) to recover L^+.

    NOTE(review): defining ``C`` here rebinds the notebook-level commute-time
    matrix ``C`` from the setup cell. Re-running earlier cells that use ``C``
    as a matrix after this one will fail. Nothing in the visible cells calls it.

    Args:
        L (np.ndarray): (n, n) graph Laplacian.

    Returns:
        np.ndarray: (n, n) pseudo-inverse of L.
    """
    n = np.shape(L)[0]
    ones = np.ones((n, n))
    return np.linalg.inv(L + ones) - ones / n**2


def f_L(L):
    """Barycenter loss as a function of the Laplacian L (maps L -> Z(L) -> f_Z).

    Accepts L as (N, N) or flattened to N**2 entries (the approx_fprime form).
    """
    return f_Z(Z(np.reshape(L, (N, N))))


def grad_L(X, L):
    """Chain rule from dF/dZ to dF/dL for Z = (L + 11^T/N)^{-1}.

    Uses dZ = -Z dL Z, so dF/dL = -Z^T X Z^T. This equals -Z X Z when L
    (and hence Z) is symmetric, as for the Laplacians used here.

    Args:
        X (np.ndarray): (N, N) gradient dF/dZ.
        L (np.ndarray): (N, N) Laplacian at which Z is evaluated.

    Returns:
        np.ndarray: (N, N) gradient dF/dL.
    """
    Z_mat = Z(L)  # invert once instead of twice
    return -Z_mat @ X @ Z_mat


grad_approx = optim.approx_fprime(L.flatten(), f_L, np.sqrt(np.finfo(float).eps)).reshape([N, N])
grad_true = grad_L(grad_Z(Z(L)), L)

# Symmetrize the finite-difference estimate (approx_fprime perturbs L_ij and
# L_ji separately) before comparing.
grad_approx_sym = grad_approx / 2 + grad_approx.T / 2
np.testing.assert_allclose(grad_approx_sym, grad_true, atol=1e-3)

rp(grad_approx_sym)
rp(grad_true)


[[-0.917 -0.917 -3.698 -1.749  2.084]
 [-0.917  0.476  2.027 -3.328 -3.456]
 [-3.698  2.027  1.877 -0.107 -5.298]
 [-1.749 -3.328 -0.107  0.05  -0.064]
 [ 2.084 -3.456 -5.298 -0.064  1.536]]
[[-0.917 -0.917 -3.698 -1.749  2.084]
 [-0.917  0.476  2.027 -3.328 -3.456]
 [-3.698  2.027  1.877 -0.107 -5.298]
 [-1.749 -3.328 -0.107  0.05  -0.064]
 [ 2.084 -3.456 -5.298 -0.064  1.536]]


## $\partial L / \partial A$

In [12]:
def f_A(A):
    """Barycenter loss as a function of the adjacency matrix A (A -> L -> f_L).

    Accepts A as (N, N) or flattened to N**2 entries.
    """
    adjacency = np.reshape(A, (N, N))
    return f_L(laplacian(adjacency))


def A_star(X):
    """Map a gradient w.r.t. the Laplacian to a gradient w.r.t. the adjacency matrix.

    For i != j the result is -X_ij + (X_ii + X_jj) / 2; the diagonal is zero.

    Args:
        X (np.ndarray): square (n, n) gradient dF/dL.

    Returns:
        np.ndarray: (n, n) float gradient dF/dA with zero diagonal.
    """
    n = X.shape[0]
    diag = np.diag(X)
    pairwise = -X + 0.5 * diag[:, None] + 0.5 * diag[None, :]
    result = np.zeros([n, n])
    off_diagonal = ~np.eye(n, dtype=bool)
    result[off_diagonal] = pairwise[off_diagonal]
    return result


grad_A_approx = optim.approx_fprime(A.flatten(), f_A, np.sqrt(np.finfo(float).eps)).reshape([N, N])
grad_A_true = A_star(grad_L(grad_Z(Z(laplacian(A))), laplacian(A)))

# approx_fprime perturbs A_ij and A_ji separately, while A_star yields a
# symmetric gradient. Only the symmetric part of the finite-difference
# estimate is expected to match (the raw estimate, printed last, does not).
grad_A_approx_sym = grad_A_approx / 2 + grad_A_approx.T / 2
np.testing.assert_allclose(grad_A_approx_sym, grad_A_true, atol=1e-3)

rp(grad_A_approx_sym)
rp(grad_A_true)
rp(grad_A_approx)  # raw (non-symmetrized) estimate, for reference

[[ 0.     0.697  4.178  1.316 -1.774]
 [ 0.697  0.    -0.85   3.591  4.462]
 [ 4.178 -0.85   0.     1.07   7.004]
 [ 1.316  3.591  1.07   0.     0.857]
 [-1.774  4.462  7.004  0.857  0.   ]]
[[ 0.     0.697  4.178  1.316 -1.774]
 [ 0.697  0.    -0.85   3.591  4.462]
 [ 4.178 -0.85   0.     1.07   7.004]
 [ 1.316  3.591  1.07   0.     0.857]
 [-1.774  4.462  7.004  0.857  0.   ]]
[[ 0.    -0.     2.78   0.832 -3.001]
 [ 1.394  0.    -1.55   3.804  3.932]
 [ 5.575 -0.149  0.     1.984  7.175]
 [ 1.799  3.378  0.156  0.     0.114]
 [-0.548  4.992  6.834  1.6    0.   ]]


In [13]:
def grad_f_A(A):
    """Gradient of f_A with respect to the adjacency matrix A.

    Back-propagates A -> L -> Z -> C: first dF/dZ (grad_Z), then dF/dL
    (grad_L), then dF/dA (A_star).
    """
    laplacian_A = laplacian(A)
    dF_dZ = grad_Z(Z(laplacian_A))
    dF_dL = grad_L(dF_dZ, laplacian_A)
    return A_star(dF_dL)


grad_approx = optim.approx_fprime(A.flatten(), f_A, np.sqrt(np.finfo(float).eps)).reshape([N, N])
grad_totest = grad_f_A(A)

# Compare the symmetric part of the finite-difference estimate: A_ij and A_ji
# are perturbed separately by approx_fprime.
grad_approx_sym = grad_approx / 2 + grad_approx.T / 2
np.testing.assert_allclose(grad_approx_sym, grad_totest, atol=1e-3)

rp(grad_approx_sym)
rp(grad_totest)


[[ 0.     0.697  4.178  1.316 -1.774]
 [ 0.697  0.    -0.85   3.591  4.462]
 [ 4.178 -0.85   0.     1.07   7.004]
 [ 1.316  3.591  1.07   0.     0.857]
 [-1.774  4.462  7.004  0.857  0.   ]]
[[ 0.     0.697  4.178  1.316 -1.774]
 [ 0.697  0.    -0.85   3.591  4.462]
 [ 4.178 -0.85   0.     1.07   7.004]
 [ 1.316  3.591  1.07   0.     0.857]
 [-1.774  4.462  7.004  0.857  0.   ]]


In [14]:
ct = CT(Z(laplacian(A)))

grad_approx = optim.approx_fprime(ct.flatten(), f_C, np.sqrt(np.finfo(float).eps)).reshape((N, N))
# Compare the finite-difference estimate at ct with the analytic gradient.
np.testing.assert_allclose(grad_approx, grad_f_C(ct), atol=1e-3)

In [15]:
# Both routes must give the same objective value: directly from A, and from ct = CT(Z(L(A))).
(f_A(A), f_C(ct))

(37.484444444444506, 37.484444444444506)