In [37]:
import numpy as np
from og_mha_backend import ProjectMax1, normalizeColumns, AupdateNonDiag, AupdateDiag, ProjectNonNegative
from MHA_backend import project_W, normalize_columns, update_A, project_non_negative

In [362]:
# Random benchmark inputs. No seed is set in the visible cells, so the recorded
# outputs below will not reproduce digit-for-digit on a re-run.
# `a`: matrix handed to the projection / normalisation / A-update helpers.
a = np.random.rand(10000, 100)
# `X`: data matrix; later cells call np.cov(X, rowvar=False), i.e. rows are observations.
X = np.random.rand(5000, 10000)

In [363]:
a1 = a.copy()
a2 = a.copy()

In [364]:
%%time
b1 = ProjectMax1(a1)

CPU times: user 52.1 ms, sys: 14.1 ms, total: 66.1 ms
Wall time: 56.5 ms


In [365]:
%%time
b2 = project_W(a2)

CPU times: user 5.43 ms, sys: 0 ns, total: 5.43 ms
Wall time: 4.87 ms


In [366]:
# a1 and a2 started as identical copies of `a`, so a non-zero norm here means that
# at least one of ProjectMax1 / project_W modified its argument in place.
np.linalg.norm(a1 - a2)  

568.4442685685634

In [367]:
np.linalg.norm(b1 - b2)  
# => result: 0.0

0.0

In [7]:
a1 = a.copy()
a2 = a.copy()

In [8]:
%%time
c1 = normalizeColumns(a1)

CPU times: user 25.9 ms, sys: 0 ns, total: 25.9 ms
Wall time: 24.2 ms


In [9]:
%%time
c2 = normalize_columns(a2)

CPU times: user 4.49 ms, sys: 17 µs, total: 4.51 ms
Wall time: 3.74 ms


In [10]:
np.linalg.norm(c1 - c2)
# => result: ~1e-14 (floating-point round-off only; the exact digits change with each random draw)

1.1866845793461653e-14

In [11]:
a1 = a.copy()
a2 = a.copy()

In [12]:
%%time
cov = np.cov(X, rowvar=False)
d1 = AupdateNonDiag(a1, cov)

CPU times: user 39.2 s, sys: 11.8 s, total: 51 s
Wall time: 3.76 s


In [13]:
%%time
d2 = update_A(a2, X)

CPU times: user 1.88 s, sys: 2.05 s, total: 3.92 s
Wall time: 279 ms


In [14]:
np.linalg.norm(d1 - d2)
# => result: 1.4832387638782345e-15

1.4832387638782345e-15

In [19]:
a1 = a.copy()
a2 = a.copy()

In [20]:
%%time
cov = np.cov(X, rowvar=False)
e1 = AupdateDiag(a1, cov)

CPU times: user 39 s, sys: 11.9 s, total: 50.9 s
Wall time: 3.75 s


In [21]:
%%time
e2 = update_A(a2, X, diag=True)

CPU times: user 2.37 s, sys: 4.49 s, total: 6.85 s
Wall time: 488 ms


In [22]:
np.linalg.norm(e1 - e2)
# => result: ~1e-16 (floating-point round-off only; the exact digits change with each random draw)

1.9229626863835638e-16

In [24]:
import numpy as np
def armijo_obj_slow(W, X, A):
    """Batched (single-einsum) evaluation of the Armijo objective.

    Computes ``sum_i trace(W.T @ cov(X[i]) @ W @ A[i])``, i.e. the same value
    as ``armijo_obj``, but for stacked inputs in one ``np.einsum`` call.

    Expected dims:
        W: (p, k)
        X: (N, n, p)
        A: (N, k, k)

    Returns:
        The scalar objective summed over all N subjects.
    """
    # Centre every subject's data over its n observations.
    centered = X - X.mean(1)[:, None, :]
    # Divide by sqrt(n - 1) so that tmp.T @ tmp matches np.cov's default
    # (unbiased) normalisation, as armijo_obj / armijo_obj_single do.
    tmp = centered.dot(W) / np.sqrt(X.shape[1] - 1)
    # 'Nnk,Nnl,Nlk->N' is the per-subject trace of (tmp.T @ tmp) @ A; the
    # previous '->Nkj' followed by .sum() added up every matrix entry instead.
    return np.einsum('Nnk,Nnl,Nlk->N', tmp, tmp, A, optimize='greedy').sum()
def armijo_obj(W, X, A):
    """Armijo objective summed over subjects.

    For every pair ``(x, a)`` drawn from ``zip(X, A)`` this adds
    ``trace(W.T @ cov(x) @ W @ a)``, where the covariance is formed implicitly
    from the centred data scaled by ``1 / sqrt(n - 1)``.

    Args:
        W: (p, k) matrix.
        X: iterable of (n, p) data matrices, one per subject.
        A: iterable of (k, k) matrices, paired with ``X``.

    Returns:
        The accumulated scalar (``0`` when ``X`` / ``A`` are empty).
    """
    def subject_term(x, a):
        # Projected, centred data; proj.T @ proj equals W.T @ cov(x) @ W.
        proj = (x - x.mean(0)).dot(W) / np.sqrt(x.shape[0] - 1)
        return np.trace(np.dot(np.dot(proj.T, proj), a))

    return sum(subject_term(x, a) for x, a in zip(X, A))
def armijo_obj_single(W, x, a):
    """Single-subject Armijo objective ``trace(W.T @ cov(x) @ W @ a)``.

    Args:
        W: (p, k) matrix.
        x: (n, p) data matrix (rows are observations).
        a: (k, k) matrix.

    Returns:
        The scalar trace, computed from the data without forming cov(x).
    """
    n_obs = x.shape[0]
    centered = x - x.mean(0)
    # proj.T @ proj reproduces W.T @ cov(x) @ W with the n - 1 normalisation.
    proj = centered.dot(W) / np.sqrt(n_obs - 1)
    gram = np.dot(proj.T, proj)
    return np.trace(np.dot(gram, a))
def armijo_old_single(W, c, a):
    """Covariance-based single-subject objective: sum of diag(W.T @ c @ W @ a).

    Args:
        W: (p, k) matrix.
        c: (p, p) matrix (the notebook passes a precomputed covariance).
        a: (k, k) matrix.

    Returns:
        The sum of the diagonal of ``W.T @ c @ W @ a``.
    """
    projected_cov = np.dot(np.dot(W.T, c), W)
    return np.diag(np.dot(projected_cov, a)).sum()

In [33]:
# Synthetic multi-subject inputs for the Armijo / initialisation / optimiser comparisons.
# NOTE: this rebinds `a`, which an earlier cell defined as a (10000, 100) matrix.
N = 10
x = np.random.rand(N, 500, 1000) # (N, n, p)
a = np.random.rand(N, 100, 100) # (N, k, k)
w = np.random.rand(1000, 100) # (p, k)
c = [np.cov(x[i], rowvar=False) for i in range(N)]  # per-subject covariances, passed to the *_old functions
grad = np.random.rand(1000, 100) # (p, k)

In [26]:
%%time
obj1 = 0
for i in range(N):
    obj1 += armijo_old_single(w, c[i], a[i])

CPU times: user 502 ms, sys: 347 ms, total: 848 ms
Wall time: 62.9 ms


In [27]:
%%time
obj2 = armijo_obj(w, x, a)

CPU times: user 163 ms, sys: 520 ms, total: 683 ms
Wall time: 52.2 ms


In [30]:
%%time
obj3 = 0
for i in range(N):
    obj3 += armijo_obj_single(w, x[i], a[i])

CPU times: user 236 ms, sys: 322 ms, total: 557 ms
Wall time: 41.2 ms


In [32]:
print(obj3 - obj1)
print(obj2 - obj1)
# result: 0.0

0.0
0.0


In [276]:
def armijo_old(W, Wgrad, Gtilde, Shat, alpha=0.5, c=0.001, maxIter=1000):
    """Legacy backtracking line search, kept as the reference implementation.

    Args:
        W: current matrix.
        Wgrad: gradient; passed through ``normalizeColumns`` before use.
        Gtilde: per-subject matrices, indexed like ``Shat``.
        Shat: per-subject matrices; ``len(Shat)`` is the number of subjects.
        alpha: initial step size, halved after every rejected trial step.
        c: factor applied to the decrease term.
        maxIter: maximum number of trial steps.

    Returns:
        The last trial point ``W - alpha * Wgrad``. It is returned even when
        the acceptance test never passed within ``maxIter`` trials.

    Side effects:
        Prints the number of trial steps taken.
    """
    Wgrad = normalizeColumns(Wgrad)
    nSub = len(Shat)

    def objective(M):
        # TODO (carried over from the original): what exactly is this objective?
        total = 0
        for i in range(nSub):
            total += np.diag(M.T.dot(Shat[i]).dot(M).dot(Gtilde[i])).sum()
        return total

    iterCount = 1
    while True:
        Wnew = W - alpha * Wgrad

        # Both objectives are re-evaluated on every trial, as in the original.
        currObj = objective(W)
        newObj = objective(Wnew)

        # TODO (carried over from the original): why the centre diag, the
        # 0.001 and the Wnew - W?
        m = c * np.diag(Wgrad.T.dot(Wnew - W)).sum()

        if newObj <= currObj + m:
            break

        alpha /= 2
        iterCount += 1
        if iterCount > maxIter:
            break

    print(iterCount)
    return Wnew

def armijo_new(W, grad, A, X, alpha=.5, c=.001, tau=.5, max_iter=1000):
    """Backtracking (Armijo) line search along the column-normalised gradient.

    Args:
        W: current matrix.
        grad: gradient; passed through ``normalize_columns`` before use.
        A: per-subject matrices forwarded to ``armijo_obj``.
        X: per-subject data forwarded to ``armijo_obj``.
        alpha: initial step size.
        c: sufficient-decrease constant.
        tau: factor by which ``alpha`` is multiplied after a rejected step.
        max_iter: maximum number of trial steps.

    Returns:
        The last trial point ``W - alpha * grad``. It is returned even when
        the sufficient-decrease test never passed within ``max_iter`` trials.

    Side effects:
        Prints the number of trial steps taken.
    """
    grad = normalize_columns(grad)

    # Neither the objective at W nor the squared gradient norm depends on
    # alpha, so both are computed once instead of on every trial step.
    obj = armijo_obj(W, X, A)
    grad_sq_norm = np.linalg.norm(grad) ** 2

    i = 1
    while True:
        W_new = W - alpha * grad
        obj_new = armijo_obj(W_new, X, A)
        # Required decrease for the current step size.
        m = -c * alpha * grad_sq_norm

        if obj_new <= obj + m:
            break
        alpha *= tau
        i += 1
        if i > max_iter:
            break
    print(i)
    return W_new

In [111]:
%%time
ar1 = armijo_old(w, grad, a, c)

obj 1018345.1648557438
new_obj 963518.2838584745
m -4.990005000000001
1
CPU times: user 912 ms, sys: 878 ms, total: 1.79 s
Wall time: 130 ms


In [112]:
%%time
ar2 = armijo_new(w, grad, a, x)

obj 1018345.1648557438
new_obj 963518.2838584747
m -4.990004999999999
1
CPU times: user 415 ms, sys: 960 ms, total: 1.38 s
Wall time: 94.9 ms


In [109]:
np.linalg.norm(ar1 - ar2)

5.260248426533897e-15

In [267]:
def init_W(X, k):
    """Initialise W with the top-k eigenvectors of the mean covariance.

    Args:
        X: sequence of per-subject data matrices of shape (n, p), with rows
            as observations (``np.cov`` is called with ``rowvar=False``).
        k: number of eigenvectors (columns of W) to keep.

    Returns:
        A real (p, k) array. Its columns are the eigenvectors belonging to the
        k largest eigenvalues of the mean covariance, each sign-flipped so
        that its entries sum to a non-negative value.
    """
    n_subjects = len(X)
    p = X[0].shape[1]
    mean_cov = np.zeros((p, p))  # mean covariance across all subjects
    for x in X:
        mean_cov += (1.0 / n_subjects) * np.cov(x, rowvar=False)

    # The mean covariance is symmetric, so use the symmetric solver. Unlike
    # np.linalg.eig it always returns real output, which matters when the
    # matrix is rank-deficient (p > n).
    eig_vals, eig_vecs = np.linalg.eigh(mean_cov)
    idx = eig_vals.argsort()[::-1][:k]
    W = eig_vecs[:, idx]

    # Eigenvectors are only defined up to sign; flip every column whose sum is
    # negative (same convention as init_old).
    # TODO (carried over from the original): is the sign flip necessary?
    W = W * (2 * (W.sum(0) >= 0) - 1)
    return W

def init_old(Shat, k):
    """Legacy initialisation of W from precomputed covariance matrices.

    Args:
        Shat: sequence of (p, p) per-subject matrices.
        k: number of eigenvectors (columns of W) to keep.

    Returns:
        A (p, k) array holding the eigenvectors of the mean of ``Shat`` that
        belong to the k largest eigenvalues, each sign-flipped so that its
        entries do not sum to a negative value.
    """
    nSub = len(Shat)
    p = Shat[0].shape[0]

    # mean covariance across all subjects
    ShatMean = np.zeros((p, p))
    for S in Shat:
        ShatMean += (1.0 / nSub) * S

    # initialise W to the leading eigenvectors of ShatMean
    eig_vals, eig_vecs = np.linalg.eig(ShatMean)
    top_k = eig_vals.argsort()[::-1][:k]
    W = eig_vecs[:, top_k]

    # Rows of W.T are views onto the columns of W, so the in-place flip below
    # updates W itself.
    for column in W.T:
        if np.sum(column) < 0:
            column *= -1
    return W
def efficient_gradJ(W, X, AA):
    """Evaluate ``cov(X).dot(W).dot(AA)`` without building the covariance.

    Args:
        W: (p, k) matrix.
        X: (n, p) data matrix (rows are observations).
        AA: (k, k) matrix.

    Returns:
        A (p, k) array equal to ``cov(X) @ W @ AA`` up to round-off.
    """
    n_obs = X.shape[0]
    # Scaled, centred data: X_tilde.T @ X_tilde is the sample covariance.
    X_tilde = (X - X.mean(0)) / np.sqrt(n_obs - 1)
    # Multiply right-to-left so no (p, p) intermediate is ever created.
    return X_tilde.T.dot(X_tilde.dot(W).dot(AA))
def efficient_WTcovW(W, X):
    """Evaluate ``W.T.dot(cov(X)).dot(W)`` without building the covariance.

    Args:
        W: (p, k) matrix.
        X: (n, p) data matrix (rows are observations).

    Returns:
        A (k, k) array equal to ``W.T @ cov(X) @ W`` up to round-off.
    """
    n_obs = X.shape[0]
    centered = X - X.mean(0)
    # Project first, then scale: the Gram matrix of the result is W.T cov(X) W.
    projected = centered.dot(W) / np.sqrt(n_obs - 1)
    return projected.T.dot(projected)
def update_G(W, X, diag=False):
    """Compute ``W.T @ cov(X) @ W - I`` for one subject.

    Args:
        W: (p, k) matrix.
        X: (n, p) data matrix, forwarded to ``efficient_WTcovW``.
        diag: when true, only the diagonal of ``W.T @ cov(X) @ W`` is kept
            (off-diagonal entries are zeroed) before the identity is
            subtracted.

    Returns:
        A (k, k) array.
    """
    G = efficient_WTcovW(W, X)
    if diag:
        # np.diag(np.diag(G)) zeroes the off-diagonal entries. The original
        # called the non-existent ``np.diage`` and raised AttributeError.
        G = np.diag(np.diag(G))
    return G - np.eye(W.shape[1])

In [229]:
wi1 = init_W(x, 5)

In [230]:
wi2 = init_old(c, 5)

In [231]:
np.linalg.norm(wi1 - wi2)

0.0

In [351]:
def optimize_new(X, k, diag=False, rho=1, tol=0.01, alpha=0.5, c=0.01, max_iter=1000):
    """Single-iteration debugging version of the rewritten optimiser.

    The outer iteration loop and the convergence check are commented out, so
    exactly one W / Lambda / A update is performed. As a result ``tol`` and
    ``max_iter`` are accepted but unused, and ``W_old`` is assigned but never
    read by the active code.

    Args:
        X: sequence of per-subject data matrices; ``len(X)`` is the number of
            subjects.
        k: number of columns of W.
        diag: forwarded to ``update_A`` and ``update_G``.
        rho: penalty parameter used in the gradient and the Lambda update.
        tol: unused while the convergence check is commented out.
        alpha: initial step size forwarded to ``armijo_new``.
        c: constant forwarded to ``armijo_new``.
        max_iter: unused while the loop is commented out.

    Returns:
        The tuple ``(AA, W, G, A, Lambda)`` (intermediate quantities exposed
        for comparison with ``optimize_old``) instead of the result dict in
        the commented-out ``return``.
    """
    N = len(X)
    # define initial parameters:
    Lambda = np.zeros((k, k))
    W = init_W(X, k)
    W_old = np.copy(W)
    W = project_non_negative(W)
    
    A = [update_A(W, X[i], diag=diag) for i in range(N)]
    
#     for n_iters in range(max_iter):
#         # -------- update W matrix --------
#         # first compute A
    # AA_i = 0.5 * A_i @ A_i - A_i for every subject
    AA = [0.5 * a.dot(a) - a for a in A]

#         # compute gradient of SM objective with respect to W
    grad = np.zeros(W.shape)
    for i in range(N):
        # TODO: why divide by N?
        grad += efficient_gradJ(W, X[i], AA[i]) / N
    # adds rho * W @ (W.T @ W - I) + W @ Lambda
    grad += rho * (W.dot(W.T.dot(W) - np.eye(k) + Lambda / rho))
    # compute armijo update:
    W = armijo_new(W, grad, AA, X, alpha=alpha, c=c)
    W = normalize_columns(project_non_negative(W))  # to ensure non-negativity

    Lambda = Lambda + rho * (W.T.dot(W) - np.eye(k))
    # A is recomputed here from the updated W; the corresponding step in
    # optimize_old is commented out, so the returned A differs between the two.
    A = [update_A(normalize_columns(project_W(W)), X[i], diag=diag) for i in range(N)]

#         if np.linalg.norm(W - W_old) < tol:
#             break
#         else:
#             W_old = np.copy(W)
            
    # compute final matrices
    W = normalize_columns(project_W(W))
    G = [update_G(W, X[i], diag=diag) for i in range(N)]

#     return {"W": W, "G": G, "n_iters": n_iters}
    return AA, W, G, A, Lambda

def optimize_old(Shat, k=2, diagG=False, lagParam=1, tol=0.01, alphaArmijo=0.5, maxIter=1000):
    """Single-iteration debugging version of the legacy optimiser.

    The outer iteration loop, the A update and the convergence check are
    commented out, so exactly one W / LagMult update is performed. As a
    result ``tol`` and ``maxIter`` are accepted but unused, ``Wold`` is
    assigned but never read by the active code, and the returned ``A`` is the
    one computed from the initial W.

    Args:
        Shat: sequence of (p, p) per-subject matrices (the notebook passes
            covariances); ``len(Shat)`` is the number of subjects.
        k: number of columns of W.
        diagG: selects ``AupdateDiag`` over ``AupdateNonDiag`` and keeps only
            the diagonal when building G.
        lagParam: penalty parameter used in the gradient and LagMult update.
        tol: unused while the convergence check is commented out.
        alphaArmijo: initial step size forwarded to ``armijo_old``.
        maxIter: unused while the loop is commented out.

    Returns:
        The tuple ``(AtildeAll, W, G, A, LagMult)`` instead of the result
        dict in the commented-out ``return``.
    """
    # define initial parameters:
    p = Shat[0].shape[0]
    nSub = len(Shat)
    ShatMean = np.zeros((p, p))  # mean covariance across all subjects
    for i in range(nSub):
        ShatMean += (1.0 / nSub) * Shat[i]
    LagMult = np.zeros((k, k))

    # initialize W to the top-k eigenvectors of ShatMean
    evdShat = np.linalg.eig(ShatMean)
    W = evdShat[1][:, evdShat[0].argsort()[::-1][:k]]
    # flip every column whose entries sum to a negative value
    for i in range(W.shape[1]):
        if np.sum(W[:, i]) < 0:
            W[:, i] *= -1

#     define convergence checks
    Wold = np.copy(W)
    W = ProjectNonNegative(W)

    # define A matrices (related to latent var covariances)
    if diagG:
        A = [AupdateDiag(W, Shat[i]) for i in range(nSub)]
    else:
        A = [AupdateNonDiag(W, Shat[i]) for i in range(nSub)]

    # Armijo constant is fixed here rather than exposed as a parameter
    cArmijo = 0.01

#     for iter_ in range(maxIter):
#         # -------- update W matrix --------
#         # first compute Atilde
    # Atilde_i = 0.5 * A_i @ A_i - A_i for every subject
    AtildeAll = [0.5 * Amat.dot(Amat) - Amat for Amat in A]
#         # compute gradient of SM objective with respect to W
    Wgrad = np.zeros(W.shape)
    for i in range(nSub):
        Wgrad += Shat[i].dot(W).dot(AtildeAll[i]) / float(
            nSub
        )  # TODO: why divide by N?
    # adds lagParam * (W @ W.T @ W - W) + W @ LagMult
    Wgrad += lagParam * (W.dot(W.T).dot(W) - W) + W.dot(LagMult)
    # compute armijo update:
    W = armijo_old(W=W, Wgrad=Wgrad, Gtilde=AtildeAll, Shat=Shat,
        alpha=alphaArmijo,
        c=cArmijo,)
    W = ProjectNonNegative(W)  # to ensure non-negativity
    W = normalizeColumns(W)

    # -------- update A matrices --------
    # NOTE: disabled, so the A returned below still comes from the initial W.
#     if diagG:
#         A = [
#             AupdateDiag(normalizeColumns(ProjectMax1(W)), Shat[i])
#             for i in range(nSub)
#         ]
#     else:
#         A = [
#             AupdateNonDiag(normalizeColumns(ProjectMax1(W)), Shat[i])
#             for i in range(nSub)
#         ]

    # -------- update Lagrange multiplier --------
    LagMult = LagMult + lagParam * (W.T.dot(W) - np.eye(k))

        # -------- check for convergence --------
#         if np.sum(np.abs(W - Wold)) < tol:
#             break
#         else:
#             Wold = np.copy(W)

    # compute final matrices
    W = normalizeColumns(ProjectMax1(W))
    # compute G (latent variable covariances)
    if diagG:
        G = [np.diag(np.diag(W.T.dot(Shat[i]).dot(W))) - np.eye(k) for i in range(nSub)]
    else:
        G = [W.T.dot(Shat[i]).dot(W) - np.eye(k) for i in range(nSub)]

#     return {"W": W, "G": G, "iter": iter_}
    return AtildeAll, W, G, A, LagMult


In [352]:
AA1, ww1, gg1, A1, L1 = optimize_new(X=x, k=5, diag=False, 
                         rho=1, tol=0.01, alpha=0.5, c=0.01, max_iter=1)

1


In [353]:
AA2, ww2, gg2, A2, L2 = optimize_old(Shat=c, k=5, diagG=False, 
                         lagParam=1, tol=0.01, alphaArmijo=0.5, maxIter=1)

1


In [354]:
np.linalg.norm(np.array(A1) - np.array(A2))

48.076075771672535

In [355]:
np.linalg.norm(np.array(AA1) - np.array(AA2))

7.75319223537933e-13

In [356]:
np.linalg.norm(ww1 - ww2)

1.3665861539295074e-15

In [357]:
np.linalg.norm(np.array(gg1) - np.array(gg2))

5.25821958036673e-16

In [358]:
np.linalg.norm(L1 - L2)

1.2281646194399523e-15