In [15]:
import gcn.graph as graph
import numpy as np
import scipy.sparse


In [16]:
def chebyshev(x, L, K=None):
    """Apply ``graph.chebyshev`` to the signals ``x`` with operator ``L``.

    Parameters
    ----------
    x : array
        Signals; passed to ``graph.chebyshev`` as its second argument.
    L : matrix
        Graph operator; passed to ``graph.chebyshev`` as its first argument.
    K : int, optional
        Number of Chebyshev terms. When omitted, ``hyper['K']`` is used, so
        existing two-argument calls behave exactly as before.

    Notes
    -----
    NOTE(review): ``hyper`` is not defined in any visible cell. Confirm it
    exists before calling this without ``K``. A later cell also redefines
    ``chebyshev`` with argument order ``(L, X, K)``, which shadows this wrapper.
    """
    if K is None:
        # Global lookup deferred until needed, so passing K avoids the dependency.
        K = hyper['K']
    return graph.chebyshev(L, x, K)

In [17]:
def create_sq_mesh(M, N):
    """Return the dense adjacency matrix of an M x N 4-connected grid graph.

    Nodes are numbered column by column: the node at row ``i`` (0 <= i < M)
    and column ``j`` (0 <= j < N) has id ``k = j * M + i``. Each node is
    linked to its north/south neighbours (same column, rows i-1 / i+1) and
    its west/east neighbours (same row, columns j-1 / j+1). The result is
    symmetric with a zero diagonal.

    Parameters
    ----------
    M : int
        Number of rows (nodes per column).
    N : int
        Number of columns.

    Returns
    -------
    numpy.ndarray
        Float array of shape (M * N, M * N) holding 1 where two nodes share
        an edge and 0 elsewhere.
    """
    A = np.zeros((M * N, M * N))
    for j in range(N):
        for i in range(M):
            # 0-based node id; the earlier version used 1-based (MATLAB-style)
            # offsets, which produced negative ids and off-by-one bounds.
            k = j * M + i
            # edge north
            if i > 0:
                A[k, k - 1] = 1
            # edge south
            if i < M - 1:
                A[k, k + 1] = 1
            # edge west
            if j > 0:
                A[k, k - M] = 1
            # edge east
            if j < N - 1:
                A[k, k + M] = 1
    return A

In [78]:
n = 4   # number of graph nodes (the square mesh below needs a perfect square)
T = 10  # number of time samples per node
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces x
x = rng.random((n, T))
# Fourier transform along time: row i of w is the spectrum of node i.
w = np.fft.fft(x, axis=1)
# For comparison only: transform across the node axis (not used below).
np.fft.fft(x, axis=0)

array([[ 1.93672752+0.j        ,  2.31442165+0.j        ,
         2.35448531+0.j        ,  2.89988299+0.j        ,
         1.52610896+0.j        ,  1.81831128+0.j        ,
         2.70117192+0.j        ,  2.42189312+0.j        ,
         2.50448281+0.j        ,  3.51942746+0.j        ],
       [-0.05660359+0.17043682j,  0.04685478-0.28573158j,
        -0.12883701+0.26598821j, -0.49822713-0.05582733j,
        -0.17389621+0.23576897j,  0.48263433-0.2732881j ,
         0.57826148+0.03567676j,  0.00665949+0.36340146j,
        -0.18894578+0.60706819j, -0.15116315-0.00421662j],
       [-0.96253345+0.j        ,  1.41447018+0.j        ,
         0.33661879+0.j        , -0.19931130+0.j        ,
         0.60908459+0.j        ,  1.05668180+0.j        ,
        -0.60526816+0.j        , -0.02643999+0.j        ,
         0.64317910+0.j        ,  0.09876681+0.j        ],
       [-0.05660359-0.17043682j,  0.04685478+0.28573158j,
        -0.12883701-0.26598821j, -0.49822713+0.05582733j,
        -0.

In [79]:
# The mesh is square, so n must be a perfect square; int(np.sqrt(n)) alone
# would silently truncate and build a graph with fewer than n nodes.
side = int(round(np.sqrt(n)))
assert side * side == n, f"n={n} is not a perfect square"
A = scipy.sparse.csr_matrix(create_sq_mesh(side, side))

In [80]:
# Graph Laplacian of the mesh, via the project helper `graph.laplacian`.
# NOTE(review): the normalization (and hence the spectrum) of this Laplacian is
# not visible here. The Chebyshev recurrences in this notebook use L as-is,
# without rescaling it to [-1, 1]; confirm that is intended.
L = graph.laplacian(A)

In [81]:
# Frequency bin 0 across all nodes: expect one entry per node.
np.shape(w[:, 0])

(4,)

In [82]:
def chebyshev(L, X, K):
    """Return the stack [T_0(L) X, T_1(L) X, ..., T_{K-1}(L) X].

    T_k are the Chebyshev polynomials of the first kind, built with the
    recurrence T_k = 2 L T_{k-1} - T_{k-2}. Only products ``L.dot(...)`` are
    needed, K - 1 of them in total.

    Parameters
    ----------
    L : (M, M) sparse matrix or ndarray
        Operator applied to the signals; anything with a ``.dot`` method.
    X : (M, N) array
        Signals to expand; may be complex.
    K : int
        Number of terms returned (polynomial orders 0 .. K-1). ``K == 0``
        gives an empty array.

    Returns
    -------
    numpy.ndarray
        Complex array of shape (K, M, N); entry k is T_k(L) X.

    Raises
    ------
    ValueError
        If ``K`` is negative.

    Notes
    -----
    L is used as given; no rescaling is applied. Chebyshev polynomials are
    only bounded on [-1, 1]. NOTE(review): confirm L's spectrum lies there.
    """
    M, N = X.shape
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")
    # Complex output so complex spectra (e.g. FFT coefficients) fit.
    Xt = np.empty((K, M, N), dtype='complex')
    if K == 0:
        return Xt
    # T_0 X = I X = X.
    Xt[0] = X
    # T_1 X = L X.
    if K > 1:
        Xt[1] = L.dot(X)
    # T_k X = 2 L T_{k-1} X - T_{k-2} X.
    for k in range(2, K):
        Xt[k] = 2 * L.dot(Xt[k - 1]) - Xt[k - 2]
    return Xt

In [83]:
def _dot_split_complex(L, Y):
    """Return L @ Y by applying L to the real and imaginary parts of Y separately."""
    # Equal to L.dot(Y) by linearity; L only ever multiplies real-valued arrays.
    return L.dot(np.real(Y)) + 1j * L.dot(np.imag(Y))


def chebyshev2(L, X, K):
    """Return the stack [T_0(L) X, ..., T_{K-1}(L) X], splitting complex X.

    This computes the same recurrence as ``chebyshev``
    (T_k = 2 L T_{k-1} - T_{k-2}). The difference is that every product with
    L is evaluated on the real and imaginary parts of the operand separately
    and then recombined.

    Parameters
    ----------
    L : (M, M) sparse matrix or ndarray
        Operator; anything with a ``.dot`` method.
    X : (M, N) array
        Signals to expand; may be complex.
    K : int
        Number of terms returned (orders 0 .. K-1). ``K == 0`` gives an
        empty array.

    Returns
    -------
    numpy.ndarray
        Complex array of shape (K, M, N); entry k is T_k(L) X.

    Raises
    ------
    ValueError
        If ``K`` is negative.

    Notes
    -----
    L is used as given, without rescaling to [-1, 1].
    """
    M, N = X.shape
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")
    Xt = np.empty((K, M, N), dtype='complex')
    if K == 0:
        return Xt
    # T_0 X = X.
    Xt[0] = X
    # T_1 X = L X.
    if K > 1:
        Xt[1] = _dot_split_complex(L, X)
    # T_k X = 2 L T_{k-1} X - T_{k-2} X.
    for k in range(2, K):
        Xt[k] = 2 * _dot_split_complex(L, Xt[k - 1]) - Xt[k - 2]
    return Xt

In [84]:
def chebyshev_einsum(L, X, K):
    """Return the stack [T_0(L) X, ..., T_{K-1}(L) X] using dense ``np.einsum``.

    This computes the same recurrence as ``chebyshev``
    (T_k = 2 L T_{k-1} - T_{k-2}). L is densified once and each product is
    written as ``einsum('ik,kj->ij', L, Y)``.

    Parameters
    ----------
    L : (M, M) scipy sparse matrix or array-like
        Operator; densified once with ``toarray()`` / ``np.asarray``.
    X : (M, N) array
        Signals to expand; may be complex.
    K : int
        Number of terms returned (orders 0 .. K-1). ``K == 0`` gives an
        empty array.

    Returns
    -------
    numpy.ndarray
        Array of shape (K, M, N) whose dtype can hold both L and X. A real L
        with a complex X gives a complex result.

    Raises
    ------
    ValueError
        If ``K`` is negative.

    Notes
    -----
    L is used as given, without rescaling to [-1, 1].
    """
    M, N = X.shape
    if K < 0:
        raise ValueError(f"K must be non-negative, got {K}")
    # Densify once, outside the loop. toarray() gives an ndarray; todense()
    # would give an np.matrix.
    L_dense = L.toarray() if scipy.sparse.issparse(L) else np.asarray(L)
    # Allocating with L.dtype alone would drop the imaginary part of a
    # complex X when L is real.
    Xt = np.empty((K, M, N), dtype=np.result_type(L_dense.dtype, X.dtype))
    if K == 0:
        return Xt
    # T_0 X = X.
    Xt[0] = X
    # T_1 X = L X.
    if K > 1:
        Xt[1] = np.einsum('ik,kj->ij', L_dense, X)
    # T_k X = 2 L T_{k-1} X - T_{k-2} X.
    for k in range(2, K):
        Xt[k] = 2 * np.einsum('ik,kj->ij', L_dense, Xt[k - 1]) - Xt[k - 2]
    return Xt

In [85]:
# The Laplacian should be square, one row/column per node.
np.shape(L)

(4, 4)

In [86]:
# Shape check: dense L times the conjugated spectra should stay (nodes, T).
np.einsum('ik,kj->ij', L.todense(), w.conj()).shape

(4, 10)

In [88]:
# Rebuild w from its real and imaginary parts (should reproduce w).
w.real + w.imag * 1j

array([[ 6.54890895 +0.00000000e+00j, -0.23612840 +3.88328758e-01j,
         0.23191857 -5.22936011e-01j, -1.12746769 +7.48771350e-02j,
        -0.56809267 -2.02170767e-01j, -0.99690136 +2.08166817e-16j,
        -0.56809267 +2.02170767e-01j, -1.12746769 -7.48771350e-02j,
         0.23191857 +5.22936011e-01j, -0.23612840 -3.88328758e-01j],
       [ 4.87827777 +0.00000000e+00j,  0.34302306 +1.44070441e-01j,
         0.20055870 -1.02198250e-01j,  0.78447492 +1.47082987e+00j,
        -0.22307695 -2.76936096e-01j, -0.69226890 +1.88737914e-15j,
        -0.22307695 +2.76936096e-01j,  0.78447492 -1.47082987e+00j,
         0.20055870 +1.02198250e-01j,  0.34302306 -1.44070441e-01j],
       [ 6.63217174 +0.00000000e+00j,  0.66093267 -3.60395293e-01j,
        -0.94161309 +7.12830663e-01j, -0.60482368 -1.90707231e-01j,
        -0.50126923 -1.81014092e-01j, -1.14012194 +5.55111512e-16j,
        -0.50126923 +1.81014092e-01j, -0.60482368 +1.90707231e-01j,
        -0.94161309 -7.12830663e-01j,  0.66093

In [89]:
# Number of Chebyshev terms (orders 0 .. K-1). Defined here so this cell
# does not depend on a K left in the kernel from another cell.
K = 3
xh = chebyshev(L, w, K)

In [90]:
K = 3  # number of Chebyshev terms, orders 0 .. K-1
# Cross-check: the direct product and the real/imaginary split must agree.
cheb_gap = np.linalg.norm(chebyshev(L, w, K) - chebyshev2(L, w, K))
assert np.isclose(cheb_gap, 0.0), f"chebyshev and chebyshev2 disagree: {cheb_gap}"
cheb_gap

0.0

In [91]:
xh.shape

(3, 4, 10)

In [102]:
# Random filter weights W[k, node], one per Chebyshev order and node.
# Seeded so that re-running the notebook reproduces the same weights.
weight_rng = np.random.default_rng(1)
W = weight_rng.random((K, n))

In [104]:
# Weight each order per node and sum over orders; result is (nodes, T).
(W[:, :, np.newaxis] * xh).sum(axis=0).shape

(4, 10)