In [5]:
from time import time
import numpy as np
from numba import jit, njit
from numpy.testing import assert_array_equal
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext line_profiler

In [25]:
# create k matrices of shape m x m
k, m = 128, 200

rng = np.random.RandomState(42) 

diagonals = rng.uniform(size=(k, m))
B = rng.randn(m, m)  # mixing matrix
M = np.array([B.dot(d[:, None] * B.T) for d in diagonals])  # dataset


In [26]:
def mean_rotation(C):
    C_mean = np.mean(C, axis=0)
    vals, vecs = np.linalg.eigh(C_mean)
    B = vecs.T / np.sqrt(vals[:, None])
    C = B[None,:,:] @ M @ B.T[None,:,:]
    return B, C

@jit(nopython=True)
def rotmat(C,i,j):
    '''
    compute update matrix according to phams method see:
    D. T. Pham, “Joint Approximate Diagonalization of Positive Definite Hermitian Matrices,”
    SIAM Journal on Matrix Analysis and Applications, vol. 22, no. 4, pp. 1136–1152, Jan. 2001.
    '''
    C_ii = C[:, i, i] 
    C_jj = C[:, j, j]
    C_ij = C[:, i, j]

    # find g_ij (2.04)
    g_ij = np.mean(C_ij / C_ii)
    g_ji = np.mean(C_ij / C_jj)

    # find w_ij (2.07) with w_ii, w_jj = 1, 1
    w_ij = np.mean(C_jj / C_ii)
    w_ji = np.mean(C_ii / C_jj)

    # solve 2.10, that is find h such that W @ h = g
    w_tilde_ji = np.sqrt(w_ji / w_ij)
    w_prod = np.sqrt(w_ij * w_ji)
    tmp1 = (w_tilde_ji * g_ij + g_ji) / (w_prod + 1)
    tmp2 = (w_tilde_ji * g_ij - g_ji) / max(w_prod - 1, 1e-9) 
    h12 = tmp1 + tmp2 # (2.10)
    h21 = np.conj((tmp1 - tmp2) / w_tilde_ji)

    # decrease in current step 
    decrease = k * (g_ij * np.conj(h12) + g_ji * h21) / 2.0

    # construct T by 2.08
    tmp = 1 + 1.j * 0.5 * np.imag(h12 * h21)
    tmp = np.real(tmp + np.sqrt(tmp ** 2 - h12 * h21))
    T = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) # stacks all scalar values
    return T, decrease

def phams(Gamma, threshold=1e-50, mean_initialize=False):
    '''
    find approximate joint diagonalization of set of square matrices Gamma,
    returns joint basis B and corresponding set of approximate diagonals C
    '''
    C = np.copy(Gamma)
    k, m, _ = C.shape
    B = np.eye(m)

    # precompute B
    if mean_initialize:
        B, C = mean_rotation(C)
 
    active = 1
    while active == 1:
        cum_decrease = 0
        for i in range(0, m):
            for j in range(0, i):
                # computation of rotations           
                T, decrease = rotmat(C,i,j)
                cum_decrease += decrease

                # update of C and B matrices
                pair = np.array((i,j))
                C[:,:,pair] = C[:,:,pair] @ T.T[None,:,:]
                C[:,pair,:] = T[None,:,:] @ C[:,pair,:]
                # C[:,:,pair] = np.einsum('ij,klj->kli',T,C[:,:,pair]) einsum alternative
                # C[:,pair,:] = np.einsum('ij,kjl->kil',T,C[:,pair,:]) einsum alternative
                B[pair,:] = T @ B[pair,:]

        active = np.abs(decrease) > threshold

    return B, C

In [27]:
%lprun -f phams phams(M)

In [None]:
Timer unit: 1e-06 s

Total time: 271.475 s
File: <ipython-input-26-9a492039a17e>
Function: phams at line 44

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    44                                           def phams(Gamma, threshold=1e-50, mean_initialize=False):
    45                                               '''
    46                                               find approximate joint diagonalization of set of square matrices Gamma,
    47                                               returns joint basis B and corresponding set of approximate diagonals C
    48                                               '''
    49         1      37241.0  37241.0      0.0      C = np.copy(Gamma)
    50         1          6.0      6.0      0.0      k, m, _ = C.shape
    51         1         63.0     63.0      0.0      B = np.eye(m)
    52                                           
    53                                               # precompute B
    54         1          1.0      1.0      0.0      if mean_initialize:
    55                                                   B, C = mean_rotation(C)
    56                                            
    57         1          1.0      1.0      0.0      active = 1
    58        12        116.0      9.7      0.0      while active == 1:
    59        11         11.0      1.0      0.0          cum_decrease = 0
    60      2211       1682.0      0.8      0.0          for i in range(0, m):
    61    221100     261688.0      1.2      0.1              for j in range(0, i):
    62                                                           # computation of rotations           
    63    218900    3807326.0     17.4      1.4                  T, decrease = rotmat(C,i,j)
    64    218900     251108.0      1.1      0.1                  cum_decrease += decrease
    65                                           
    66                                                           # update of C and B matrices
    67    218900    1798990.0      8.2      0.7                  pair = np.array((i,j))
    68    218900  213735813.0    976.4     78.7                  C[:,:,pair] = C[:,:,pair] @ T.T[None,:,:]
    69    218900   48408551.0    221.1     17.8                  C[:,pair,:] = T[None,:,:] @ C[:,pair,:]
    70                                                           # C[:,:,pair] = np.einsum('ij,klj->kli',T,C[:,:,pair]) einsum alternative
    71                                                           # C[:,pair,:] = np.einsum('ij,kjl->kil',T,C[:,pair,:]) einsum alternative
    72    218900    3172459.0     14.5      1.2                  B[pair,:] = T @ B[pair,:]
    73                                           
    74        11        115.0     10.5      0.0          active = np.abs(decrease) > threshold
    75                                           
    76         1          1.0      1.0      0.0      return B, C

In [None]:
# check if B and Bhat are identical up to permutation and scaling
BA = np.abs(Bhat.dot(B))  # undo negative scaling 
BA /= np.max(BA, axis=1, keepdims=True) # normalize to 1
BA[np.abs(BA) < 1e-12] = 0. # numerical tolerance
print(BA)

plt.imshow(BA @ BA.T)
plt.show()
assert_array_equal(BA[np.lexsort(BA)], np.eye(m))