In [6]:
import numpy as np
np.random.seed(42)
np.set_printoptions(precision=3)
np.set_printoptions(suppress=True)

# Regularisation weight: the cells below use it as R = r * .5 * np.dot(A.T.ravel(), A.ravel()).
r = 2
# Step size of the single descent step that is verified below (Anew = A - alpha * dFdA).
alpha = 1e-6

def PSD(A, e=1e-10):
    """Return a positive semi-definite version of the symmetric matrix ``A``.

    A Cholesky factorisation is attempted first; when it succeeds ``A`` is
    already positive definite and is returned unchanged (same object, no copy).
    Otherwise ``A`` is eigendecomposed with ``np.linalg.eigh`` and rebuilt from
    the eigenpairs whose eigenvalue exceeds ``e`` only.

    Parameters
    ----------
    A : (n, n) array_like
        Matrix to project. ``eigh`` treats it as symmetric.
    e : float, optional
        Eigenvalues less than or equal to this threshold are discarded.

    Returns
    -------
    (n, n) ndarray
        ``A`` itself, the truncated reconstruction, or an all-zero matrix when
        no eigenvalue exceeds ``e``.
    """
    try:
        np.linalg.cholesky(A)
        return A
    except np.linalg.LinAlgError:
        # eigh returns the eigenvalues in ascending order, so the ones to keep
        # form a contiguous tail starting at the first value above e.
        W, V = np.linalg.eigh(A)
        if W[-1] > e:
            i = np.argmax(W > e)
            kept = V[:, i:]
            # Scale the columns instead of multiplying by a dense np.diag(...):
            # same result, one matrix product and one temporary less.
            return (kept * W[i:]) @ kept.T
        else:
            return np.zeros(A.shape)
        
def psd(A, e=1e-10):
    """Clip the spectrum of a symmetric matrix so the result is PSD.

    If ``A`` admits a Cholesky factorisation it is handed back untouched.
    Otherwise the matrix is rebuilt from the eigenpairs whose eigenvalue is
    larger than ``e``; if there are none, a zero matrix of the same shape is
    returned.
    """
    try:
        np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        eigvals, eigvecs = np.linalg.eigh(A)
        # Ascending eigenvalues: nothing to keep if even the largest is too small.
        if not eigvals[-1] > e:
            return np.zeros(A.shape)
        first_kept = np.argmax(eigvals > e)
        basis = eigvecs[:, first_kept:]
        return (basis * eigvals[first_kept:]) @ basis.T
    else:
        return A
        
def center(A):
    """Double-centre a 2-D array: remove row means, then column means.

    After the call every row and every column of the result sums to zero
    (up to rounding).

    Parameters
    ----------
    A : (m, n) ndarray
        Input matrix; it is not modified.

    Returns
    -------
    (m, n) ndarray
        The centred copy.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    # A row mean averages over the columns ...
    A = A - np.sum(A, axis=1, keepdims=True) / n_cols
    # ... and a column mean averages over the rows. The previous code divided
    # by the column count here as well, which is only right for square input.
    A = A - np.sum(A, axis=0, keepdims=True) / n_rows

    return A
    
def PSD_delta(A, e=1e-10):
    """Factor of the correction that cancels the negative spectrum of ``A``.

    Returns
    -------
    ndarray or None
        * ``np.zeros(1)`` when ``A`` has a Cholesky factorisation;
        * ``None`` when the largest eigenvalue of ``A`` does not exceed ``e``;
        * otherwise a matrix whose columns are the eigenvectors that come
          before the first eigenvalue greater than ``-e``, each scaled by the
          square root of the negated eigenvalue (it has zero columns when no
          eigenvalue is that negative).
    """
    try:
        np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        eigvals, eigvecs = np.linalg.eigh(A)
        if not eigvals[-1] > e:
            return None
        # eigh sorts ascending, so the eigenvalues <= -e form a leading prefix.
        n_negative = np.argmax(eigvals > -e)
        scale = np.sqrt(-eigvals[:n_negative])
        return eigvecs[:, :n_negative] * scale[np.newaxis, :]
    else:
        return np.zeros(1)
        
# Problem size: X below is drawn with shape (n, d).
n = 1000
d = 10
# Thresholds on uniform draws: a sample is flagged as an active row / column
# when its draw falls below the respective value.
active_row = .1
active_col = .2

# ---- First synthetic problem: its only lasting product is M ---------------
# Random active-row / active-column masks and the matching index arrays.
active_row_bool = np.random.rand(n) < active_row
active_col_bool = np.random.rand(n) < active_col
active_all_bool = np.logical_or(active_row_bool, active_col_bool)
active_row_indx = np.argwhere(active_row_bool).flatten()
active_col_indx = np.argwhere(active_col_bool).flatten()
active_all_indx = np.argwhere(active_all_bool).flatten()

# Random data X and its Gram matrix G.
X = np.random.randn(n, d) / np.sqrt(2 * d)
G = X @ X.T
# Random binary labels; T[i, j] is +1 when the labels of i and j agree, else -1.
Y = np.where(np.random.rand(n) < .5, 0, 1)
T = np.where(np.equal(Y.reshape(-1, 1), Y.reshape(1, -1)), 1, -1)
# Per sample, the number of other samples carrying the same label
# (not referenced again in the visible cells).
N = np.sum(T == 1, axis=1) - 1
# D[i, j] = G[i, i] + G[j, j] - 2 * G[i, j].
D = np.diag(G).reshape(-1, 1) + np.diag(G).reshape(1, -1) - 2 * G
# Per-row threshold applied to D.
E = np.ones(n) * np.sqrt(np.pi)
# V[i, j] = T[i, j] where D[i, j] > E[i], else 0; afterwards everything outside
# the active rows and active columns is zeroed.
V = np.where(D - E.reshape(-1, 1) > 0, 1.0, 0.0) * T
V[~active_row_bool, :] = 0
V[:, ~active_col_bool] = 0
# Laplacian-style matrix of the symmetrised V: diag(column sums) - (V + V.T).
U = np.diag(np.sum(V + V.T, axis=0)) - (V + V.T)
M = G.T @ U @ G
# Keep only the part of M with eigenvalues above the PSD() threshold.
M = PSD(M)

# ---- Second synthetic problem: fresh draws for everything except M --------
active_row_bool = np.random.rand(n) < active_row
active_col_bool = np.random.rand(n) < active_col
active_all_bool = np.logical_or(active_row_bool, active_col_bool)
active_row_indx = np.argwhere(active_row_bool).flatten()
active_col_indx = np.argwhere(active_col_bool).flatten()
active_all_indx = np.argwhere(active_all_bool).flatten()

X = np.random.randn(n, d) / np.sqrt(2 * d)
G = X @ X.T
Y = np.where(np.random.rand(n) < .5, 0, 1)
T = np.where(np.equal(Y.reshape(-1, 1), Y.reshape(1, -1)), 1, -1)
N = np.sum(T == 1, axis=1) - 1
D = np.diag(G).reshape(-1, 1) + np.diag(G).reshape(1, -1) - 2 * G
E = np.ones(n) * np.sqrt(np.pi)
V = np.where(D - E.reshape(-1, 1) > 0, 1.0, 0.0) * T
V[~active_row_bool, :] = 0
V[:, ~active_col_bool] = 0
# (d, d) matrix built from the M of the first problem and the X of the second.
W = X.T @ M @ X

# Re-derive the active sets from the rows / columns of V that actually hold a
# non-zero entry; these replace the randomly drawn masks above.
active_row_bool = np.any(V, axis=1)
active_col_bool = np.any(V, axis=0)
active_all_bool = np.logical_or(active_row_bool, active_col_bool)
active_row_indx = np.argwhere(active_row_bool).flatten()
active_col_indx = np.argwhere(active_col_bool).flatten()
active_all_indx = np.argwhere(active_all_bool).flatten()

# Fractions of samples that are active as row, as column, and as either.
print(active_row_indx.size / n, active_col_indx.size / n, active_all_indx.size / n)



0.104 0.154 0.252


In [27]:
# Generator seeded with 12345; PCG64 is the bit generator default_rng uses.
rng = np.random.default_rng(12345)

# Probe which one-element results np.isscalar accepts as scalars.
# 1) A Python float pulled out of a one-element array.
E = np.zeros(1).item()
print(np.isscalar(E))
# 2) A one-element ndarray.
E = rng.standard_normal(1)
print(np.isscalar(E))
# 3) The full reduction of a one-element ndarray.
E = np.ones(1)
print(np.isscalar(E.sum()))


True
False
True


In [2]:
# a_mode == 'WX'
#
# Consistency checks for one descent step when the optimised variable is the
# (d, d) matrix W and the data matrix is X. Each print() compares a reference
# quantity with a cheaper or restricted way of computing it.
# NOTE(review): alpha is 1e-6 while np.allclose defaults to rtol=1e-5, so checks
# on quantities that only differ by O(alpha) terms may pass even if those terms
# were wrong -- consider tighter tolerances or a larger alpha.

A = W
B = X

# np.dot(A.T.ravel(), A.ravel()) evaluates trace(A @ A); this equals the squared
# Frobenius norm when A is symmetric.
R = r * .5 * np.dot(A.T.ravel(), A.ravel())

# P = B A B^T and D[i, j] = P[i, i] + P[j, j] - 2 * P[i, j].
P = B @ A @ B.T
D = np.diag(P).reshape(-1, 1) + np.diag(P).reshape(1, -1) - 2 * P

# Laplacian-style matrix of the symmetrised V, restricted to the active samples.
Z = V + V.T
U = np.diag(np.sum(Z, axis=0)) - Z
U1 = U.take(active_all_indx, axis=0).take(active_all_indx, axis=1)

# Same matrix built from the restricted V directly. V has no non-zero entry
# outside active rows x active columns, so both constructions should agree.
V2 = V.take(active_all_indx, axis=0).take(active_all_indx, axis=1)
Z = V2 + V2.T
U2 = np.diag(np.sum(Z, axis=0)) - Z
print(np.allclose(U1, U2))

# g*: gradient terms, d*: the direction used in the update. In this mode the two
# are identical.
gRgA = r * A
dRdA = r * A
gLgA = B.T @ U @ B
dLdA = B.T @ U @ B

# The loss term only needs the active rows of B.
B2 = B.take(active_all_indx, axis=0)
gLgA2 = B2.T @ U2 @ B2
dLdA2 = B2.T @ U2 @ B2
print(np.allclose(gLgA, gLgA2))

gFgA = gRgA + gLgA
dFdA = dRdA + dLdA

# Negative inner product between gradient and update direction, computed two ways.
phiA = -np.dot(dFdA.T.ravel(), dFdA.ravel())
phiA2 = -np.dot(gFgA.ravel(), dFdA.ravel())
print(np.allclose(phiA, phiA2))

# One step of size alpha; Anew2 folds the regulariser into a rescaling of A.
Anew = A - alpha * dFdA
Anew2 = (1 - alpha * r) * A - alpha * dLdA2
print(np.allclose(Anew, Anew2))

Rnew = r * .5 * np.dot(Anew.T.ravel(), Anew.ravel())
Pnew = B @ Anew @ B.T

# Rnew expanded in terms of the old R and the loss direction.
Rnew2 = (1 - alpha * r) ** 2 * R - r * ((1 - alpha * r) * alpha * np.dot(A.T.ravel(), dLdA2.ravel())
                                        - .5 * alpha ** 2 * np.dot(dLdA2.T.ravel(), dLdA2.ravel()))
print(np.allclose(Rnew, Rnew2))

# Pnew is linear in Anew, so it can be updated from the old P.
Pnew2 = (1 - alpha * r) * P - alpha * B @ dLdA2 @ B.T
print(np.allclose(Pnew, Pnew2))

# Reference: full D matrix after the step, then the active rows x active columns block.
Dnew = np.diag(Pnew).reshape(-1, 1) + np.diag(Pnew).reshape(1, -1) - 2 * Pnew
Dnew1 = Dnew.take(active_row_indx, axis=0).take(active_col_indx, axis=1)

# Variant 2: only the needed block. `*` and `@` share precedence and associate
# left to right, so each `rows[:, repeat] * rows[:, tile]` product is formed
# first; it holds the flattened outer product b b^T of every selected row b, and
# multiplying by the flattened Anew yields the quadratic form b^T Anew b.
Dnew2 = (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
         B.take(active_row_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
         Anew.reshape(d ** 2, 1) +
         (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
          B.take(active_col_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
          Anew.reshape(d ** 2, 1)).T -
         2 * B.take(active_row_indx, axis=0) @ Anew @ B.take(active_col_indx, axis=0).T
        )
print(np.allclose(Dnew1, Dnew2))

# Variant 3: update the old D block, applying the same construction to dLdA2 only.
Dnew3 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
                  B.take(active_row_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
                  dLdA2.reshape(d ** 2, 1) +
                  (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
                   B.take(active_col_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
                   dLdA2.reshape(d ** 2, 1)).T -
                  2 * B.take(active_row_indx, axis=0) @ dLdA2 @ B.take(active_col_indx, axis=0).T)
        )
print(np.allclose(Dnew1, Dnew3))

# Positions of the active rows / columns inside active_all_indx.
active_all_dict = {k: v for v, k in enumerate(active_all_indx)}
select_row_indx = [active_all_dict[i] for i in active_row_indx]
select_col_indx = [active_all_dict[i] for i in active_col_indx]

# Variant 4: form P on the active samples only, then index into it.
Pnew4 = B2 @ Anew @ B2.T
Dnew4 = (np.diag(Pnew4).take(select_row_indx).reshape(-1, 1) + 
         np.diag(Pnew4).take(select_col_indx).reshape(1, -1) - 
         2 * Pnew4.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
print(np.allclose(Dnew1, Dnew4))

# Variant 5: combine variants 3 and 4 -- update the old D block using the change
# of P on the active samples.
dP = B2 @ dLdA2 @ B2.T
Dnew5 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (np.diag(dP).take(select_row_indx).reshape(-1, 1) +
                  np.diag(dP).take(select_col_indx).reshape(1, -1) -
                  2 * dP.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
        )
print(np.allclose(Dnew1, Dnew5))

True
True
True
True
True
True
True
True
True
True


In [3]:
# a_mode == 'MX'
#
# Same sequence of checks as the 'WX' cell, but the optimised variable is the
# (n, n) matrix M; the (d, d) matrix enters everywhere as B.T @ A @ B.
# NOTE(review): alpha is 1e-6 while np.allclose defaults to rtol=1e-5, so checks
# on quantities that only differ by O(alpha) terms may pass even if those terms
# were wrong.

A = M
B = X

# np.dot(Q.T.ravel(), Q.ravel()) evaluates trace(Q @ Q), here with Q = B^T A B.
R = r * .5 * np.dot((B.T @ A @ B).T.ravel(), (B.T @ A @ B).ravel())

P = B @ (B.T @ A @ B) @ B.T
D = np.diag(P).reshape(-1, 1) + np.diag(P).reshape(1, -1) - 2 * P

# Laplacian-style matrix of the symmetrised V, full and restricted to active samples.
Z = V + V.T
U = np.diag(np.sum(Z, axis=0)) - Z
U1 = U.take(active_all_indx, axis=0).take(active_all_indx, axis=1)

V2 = V.take(active_all_indx, axis=0).take(active_all_indx, axis=1)
Z = V2 + V2.T
U2 = np.diag(np.sum(Z, axis=0)) - Z
print(np.allclose(U1, U2))

# g*: gradient terms with respect to A; d*: the direction actually applied in
# the update below. Unlike in 'WX' mode the two differ here.
gRgA = r * B @ (B.T @ A @ B) @ B.T
dRdA = r * A
gLgA = B @ (B.T @ U @ B) @ B.T
dLdA = U

B2 = B.take(active_all_indx, axis=0)
gLgA2 = B @ (B2.T @ U2 @ B2) @ B.T
dLdA2 = U2
print(np.allclose(gLgA, gLgA2))

gFgA = gRgA + gLgA
dFdA = dRdA + dLdA

# Negative inner product between gradient and update direction, computed two ways.
phiA = -np.dot((B.T @ dFdA @ B).T.ravel(), (B.T @ dFdA @ B).ravel())
phiA2 = -np.dot(gFgA.ravel(), dFdA.ravel())
print(np.allclose(phiA, phiA2))

# One step of size alpha. The loss direction is only subtracted on the
# active x active block of Anew2.
Anew = A - alpha * dFdA
Anew2 = (1 - alpha * r) * A
Anew2[np.ix_(active_all_indx, active_all_indx)] -= alpha * dLdA2
print(np.allclose(Anew, Anew2))

Rnew = r * .5 * np.dot((B.T @ Anew @ B).T.ravel(), (B.T @ Anew @ B).ravel())
Pnew = B @ (B.T @ Anew @ B) @ B.T

# Rnew expanded in terms of the old R and the (d, d) image B2^T dLdA2 B2 of the
# loss direction.
Rnew2 = (1 - alpha * r) ** 2 * R - r * ((1 - alpha * r) * alpha * np.dot((B.T @ A @ B).T.ravel(),
                                                                         (B2.T @ dLdA2 @ B2).ravel())
                                        - .5 * alpha ** 2 * np.dot((B2.T @ dLdA2 @ B2).T.ravel(),
                                                                   (B2.T @ dLdA2 @ B2).ravel()))
print(np.allclose(Rnew, Rnew2))

Pnew2 = (1 - alpha * r) * P - alpha * B @ (B2.T @ dLdA2 @ B2) @ B.T
print(np.allclose(Pnew, Pnew2))

# Reference: full D matrix after the step, then the active rows x active columns block.
Dnew = np.diag(Pnew).reshape(-1, 1) + np.diag(Pnew).reshape(1, -1) - 2 * Pnew
Dnew1 = Dnew.take(active_row_indx, axis=0).take(active_col_indx, axis=1)

# Variant 2: only the needed block. `*` binds before the following `@` (equal
# precedence, left to right), so each `rows[:, repeat] * rows[:, tile]` product
# holds the flattened outer product b b^T of every selected row b.
Dnew2 = (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
         B.take(active_row_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
         (B.T @ Anew @ B).reshape(d ** 2, 1) +
         (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
          B.take(active_col_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
          (B.T @ Anew @ B).reshape(d ** 2, 1)).T -
         2 * B.take(active_row_indx, axis=0) @ (B.T @ Anew @ B) @ B.take(active_col_indx, axis=0).T
        )
print(np.allclose(Dnew1, Dnew2))

# Variant 3: update the old D block, applying the same construction to the loss
# direction only.
Dnew3 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
                  B.take(active_row_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
                  (B2.T @ dLdA2 @ B2).reshape(d ** 2, 1) +
                  (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(d), d), axis=1) *
                   B.take(active_col_indx, axis=0).take(np.tile(np.arange(d), d), axis=1) @
                   (B2.T @ dLdA2 @ B2).reshape(d ** 2, 1)).T -
                  2 * B.take(active_row_indx, axis=0) @ (B2.T @ dLdA2 @ B2) @ B.take(active_col_indx, axis=0).T)
        )
print(np.allclose(Dnew1, Dnew3))

# Positions of the active rows / columns inside active_all_indx.
active_all_dict = {k: v for v, k in enumerate(active_all_indx)}
select_row_indx = [active_all_dict[i] for i in active_row_indx]
select_col_indx = [active_all_dict[i] for i in active_col_indx]

# Variant 4: form P on the active samples only, then index into it.
Pnew4 = B2 @ (B.T @ Anew @ B) @ B2.T
Dnew4 = (np.diag(Pnew4).take(select_row_indx).reshape(-1, 1) + 
         np.diag(Pnew4).take(select_col_indx).reshape(1, -1) - 
         2 * Pnew4.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
print(np.allclose(Dnew1, Dnew4))

# Variant 5: update the old D block using the change of P on the active samples.
dP = B2 @ (B2.T @ dLdA2 @ B2) @ B2.T
Dnew5 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (np.diag(dP).take(select_row_indx).reshape(-1, 1) +
                  np.diag(dP).take(select_col_indx).reshape(1, -1) -
                  2 * dP.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
        )
print(np.allclose(Dnew1, Dnew5))

True
True
True
True
True
True
True
True
True
True


In [4]:
# a_mode == 'MXX'
#
# Same sequence of checks with the (n, n) matrix M as variable and the (n, n)
# Gram matrix G in the role of B.
# NOTE(review): alpha is 1e-6 while np.allclose defaults to rtol=1e-5, so checks
# on quantities that only differ by O(alpha) terms may pass even if those terms
# were wrong.
# NOTE(review): the Dnew2 expression below materialises arrays with n ** 2
# columns per selected row; with n = 1000 that is a large temporary.

A = M
B = G

# np.dot(Q.T.ravel(), Q.ravel()) evaluates trace(Q @ Q), here with Q = A B^T.
R = r * .5 * np.dot((A @ B.T).T.ravel(), (A @ B.T).ravel())

P = B @ A @ B.T
D = np.diag(P).reshape(-1, 1) + np.diag(P).reshape(1, -1) - 2 * P

# Laplacian-style matrix of the symmetrised V, full and restricted to active samples.
Z = V + V.T
U = np.diag(np.sum(Z, axis=0)) - Z
U1 = U.take(active_all_indx, axis=0).take(active_all_indx, axis=1)

V2 = V.take(active_all_indx, axis=0).take(active_all_indx, axis=1)
Z = V2 + V2.T
U2 = np.diag(np.sum(Z, axis=0)) - Z
print(np.allclose(U1, U2))

# g*: gradient terms with respect to A; d*: the direction actually applied in
# the update below.
gRgA = r * B.T @ A @ B
dRdA = r * A
gLgA = B.T @ U @ B
dLdA = U

# B2 keeps the active rows of B (shape: active count x n).
B2 = B.take(active_all_indx, axis=0)
gLgA2 = B2.T @ U2 @ B2
dLdA2 = U2
print(np.allclose(gLgA, gLgA2))

gFgA = gRgA + gLgA
dFdA = dRdA + dLdA

# Negative inner product between gradient and update direction, computed two ways.
phiA = -np.dot((dFdA @ B).T.ravel(), (dFdA @ B).ravel())
phiA2 = -np.dot(gFgA.ravel(), dFdA.ravel())
print(np.allclose(phiA, phiA2))

# One step of size alpha; the loss direction only touches the active x active block.
Anew = A - alpha * dFdA
Anew2 = (1 - alpha * r) * A
Anew2[np.ix_(active_all_indx, active_all_indx)] -= alpha * dLdA2
print(np.allclose(Anew, Anew2))

Rnew = r * .5 * np.dot((Anew @ B.T).T.ravel(), (Anew @ B.T).ravel())
Pnew = B @ Anew @ B.T

# Rnew expanded in terms of the old R; B2.take(active_all_indx, axis=1) is the
# active x active block of B.
Rnew2 = (1 - alpha * r) ** 2 * R - r * ((1 - alpha * r) * alpha * np.dot((A @ B2.T).T.ravel(),
                                                                         (dLdA2 @ B2).ravel())
                                        - .5 * alpha ** 2 * np.dot((dLdA2 @ B2.take(active_all_indx, axis=1)).T.ravel(),
                                                                   (dLdA2 @ B2.take(active_all_indx, axis=1)).ravel()))
print(np.allclose(Rnew, Rnew2))

# Uses B2.T (active columns of B.T) in place of the active columns of B.
Pnew2 = (1 - alpha * r) * P - alpha * B2.T @ dLdA2 @ B2
print(np.allclose(Pnew, Pnew2))

# Reference: full D matrix after the step, then the active rows x active columns block.
Dnew = np.diag(Pnew).reshape(-1, 1) + np.diag(Pnew).reshape(1, -1) - 2 * Pnew
Dnew1 = Dnew.take(active_row_indx, axis=0).take(active_col_indx, axis=1)

# Variant 2: only the needed block, via flattened row-wise outer products
# (`*` is applied before the following `@`).
Dnew2 = (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(n), n), axis=1) *
         B.take(active_row_indx, axis=0).take(np.tile(np.arange(n), n), axis=1) @
         Anew.reshape(n ** 2, 1) +
         (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(n), n), axis=1) *
          B.take(active_col_indx, axis=0).take(np.tile(np.arange(n), n), axis=1) @
          Anew.reshape(n ** 2, 1)).T -
         2 * B.take(active_row_indx, axis=0) @ Anew @ B.take(active_col_indx, axis=0).T
        )
print(np.allclose(Dnew1, Dnew2))

# Number of active samples; the loss direction dLdA2 is (m, m).
m = active_all_indx.size

# Variant 3: update the old D block; the outer products now only span the m
# active columns.
Dnew3 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (B2.T.take(active_row_indx, axis=0).take(np.repeat(np.arange(m), m), axis=1) *
                  B2.T.take(active_row_indx, axis=0).take(np.tile(np.arange(m), m), axis=1) @
                  dLdA2.reshape(m ** 2, 1) +
                  (B2.T.take(active_col_indx, axis=0).take(np.repeat(np.arange(m), m), axis=1) *
                   B2.T.take(active_col_indx, axis=0).take(np.tile(np.arange(m), m), axis=1) @
                   dLdA2.reshape(m ** 2, 1)).T -
                  2 * B2.T.take(active_row_indx, axis=0) @ dLdA2 @ B2.T.take(active_col_indx, axis=0).T)
        )
print(np.allclose(Dnew1, Dnew3))

# Positions of the active rows / columns inside active_all_indx.
active_all_dict = {k: v for v, k in enumerate(active_all_indx)}
select_row_indx = [active_all_dict[i] for i in active_row_indx]
select_col_indx = [active_all_dict[i] for i in active_col_indx]

# Variant 4: form P on the active samples only, then index into it.
Pnew4 = B2 @ Anew @ B2.T
Dnew4 = (np.diag(Pnew4).take(select_row_indx).reshape(-1, 1) + 
         np.diag(Pnew4).take(select_col_indx).reshape(1, -1) - 
         2 * Pnew4.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
print(np.allclose(Dnew1, Dnew4))

# Variant 5: update the old D block using the change of P computed from the
# active x active block of B.
dP = B2.take(active_all_indx, axis=1) @ dLdA2 @ B2.take(active_all_indx, axis=1).T
Dnew5 = ((1 - alpha * r) * D.take(active_row_indx, axis=0).take(active_col_indx, axis=1) -
         alpha * (np.diag(dP).take(select_row_indx).reshape(-1, 1) +
                  np.diag(dP).take(select_col_indx).reshape(1, -1) -
                  2 * dP.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
        )
print(np.allclose(Dnew1, Dnew5))

True
True
True
True
True
True
True
True
True
True


In [5]:
# a_mode == 'MG'
#
# Same setting as the 'MXX' cell (A = M, B = G), but here the update direction
# equals the gradient (d* == g*). Several of the "closed form" variants are not
# worked out in this mode: Rnew2, Pnew2, Dnew3 and Dnew5 are plain aliases of
# their reference values, so the corresponding checks are trivially True.
# NOTE(review): the Dnew2 expression below materialises arrays with n ** 2
# columns per selected row; with n = 1000 that is a large temporary.

A = M
B = G

# np.dot(Q.T.ravel(), Q.ravel()) evaluates trace(Q @ Q), here with Q = A B^T.
R = r * .5 * np.dot((A @ B.T).T.ravel(), (A @ B.T).ravel())

P = B @ A @ B.T
D = np.diag(P).reshape(-1, 1) + np.diag(P).reshape(1, -1) - 2 * P

# Laplacian-style matrix of the symmetrised V, full and restricted to active samples.
Z = V + V.T
U = np.diag(np.sum(Z, axis=0)) - Z
U1 = U.take(active_all_indx, axis=0).take(active_all_indx, axis=1)

V2 = V.take(active_all_indx, axis=0).take(active_all_indx, axis=1)
Z = V2 + V2.T
U2 = np.diag(np.sum(Z, axis=0)) - Z
print(np.allclose(U1, U2))

gRgA = r * B.T @ A @ B
dRdA = r * B.T @ A @ B
gLgA = B.T @ U @ B
dLdA = B.T @ U @ B

# The loss term only needs the active rows of B.
B2 = B.take(active_all_indx, axis=0)
gLgA2 = B2.T @ U2 @ B2
dLdA2 = B2.T @ U2 @ B2
print(np.allclose(gLgA, gLgA2))

gFgA = gRgA + gLgA
dFdA = dRdA + dLdA

# Negative inner product between gradient and update direction, computed two ways.
phiA = -np.dot((dFdA).T.ravel(), (dFdA).ravel())
phiA2 = -np.dot(gFgA.ravel(), dFdA.ravel())
print(np.allclose(phiA, phiA2))

# One step of size alpha; Anew2 uses the restricted loss direction.
Anew = A - alpha * dFdA
Anew2 = A - alpha * (dRdA + dLdA2)
print(np.allclose(Anew, Anew2))

Rnew = r * .5 * np.dot((Anew @ B.T).T.ravel(), (Anew @ B.T).ravel())
Pnew = B @ Anew @ B.T

# Placeholder: no closed form implemented for this mode.
Rnew2 = Rnew
print(np.allclose(Rnew, Rnew2))

# Placeholder: no closed form implemented for this mode.
Pnew2 = Pnew
print(np.allclose(Pnew, Pnew2))

# Reference: full D matrix after the step, then the active rows x active columns block.
Dnew = np.diag(Pnew).reshape(-1, 1) + np.diag(Pnew).reshape(1, -1) - 2 * Pnew
Dnew1 = Dnew.take(active_row_indx, axis=0).take(active_col_indx, axis=1)

# Variant 2: only the needed block, via flattened row-wise outer products
# (`*` is applied before the following `@`).
Dnew2 = (B.take(active_row_indx, axis=0).take(np.repeat(np.arange(n), n), axis=1) *
         B.take(active_row_indx, axis=0).take(np.tile(np.arange(n), n), axis=1) @
         Anew.reshape(n ** 2, 1) +
         (B.take(active_col_indx, axis=0).take(np.repeat(np.arange(n), n), axis=1) *
          B.take(active_col_indx, axis=0).take(np.tile(np.arange(n), n), axis=1) @
          Anew.reshape(n ** 2, 1)).T -
         2 * B.take(active_row_indx, axis=0) @ Anew @ B.take(active_col_indx, axis=0).T
        )
print(np.allclose(Dnew1, Dnew2))

# Placeholder: no incremental update implemented for this mode.
Dnew3 = Dnew1
print(np.allclose(Dnew1, Dnew3))

# Positions of the active rows / columns inside active_all_indx.
active_all_dict = {k: v for v, k in enumerate(active_all_indx)}
select_row_indx = [active_all_dict[i] for i in active_row_indx]
select_col_indx = [active_all_dict[i] for i in active_col_indx]

# Variant 4: form P on the active samples only, then index into it.
Pnew4 = B2 @ Anew @ B2.T
Dnew4 = (np.diag(Pnew4).take(select_row_indx).reshape(-1, 1) + 
         np.diag(Pnew4).take(select_col_indx).reshape(1, -1) - 
         2 * Pnew4.take(select_row_indx, axis=0).take(select_col_indx, axis=1))
print(np.allclose(Dnew1, Dnew4))

# Placeholder: no incremental update implemented for this mode.
Dnew5 = Dnew1
print(np.allclose(Dnew1, Dnew5))

True
True
True
True
True
True
True
True
True
True


In [None]:
d = 8 #4096


def _random_unit_psd(size):
    """Draw a random symmetric PSD matrix with unit Frobenius norm.

    A standard-normal square matrix is truncated towards zero to one decimal,
    multiplied by its own transpose, and divided by its Frobenius norm.
    """
    root = np.random.randn(size, size)
    root = (root * 10).astype(int) / 10
    gram = root @ root.T
    gram /= np.linalg.norm(gram)
    return gram


# Fixtures for the timing cells: a full matrix W, its diagonal w, and a second
# matrix G. W is drawn before G, so the global random stream is consumed in the
# same order as before.
W = _random_unit_psd(d)

w = np.diag(W)

G = _random_unit_psd(d)

In [None]:
# linear transform - 2-norm - full matrix - loss

# Five equivalent ways of evaluating half the squared Frobenius norm of W; the
# printed column should contain five identical values.
L1 = .5 * np.trace(W.T @ W)
L2 = .5 * np.sum(np.square(W))
L3 = .5 * np.square(np.linalg.norm(W))
L4 = .5 * np.dot(W.ravel(), W.ravel())
L5 = .5 * np.tensordot(W, W)
print(np.array([L1, L2, L3, L4, L5]).reshape((-1, 1)))

# Timing of each variant; the trailing remarks record the author's observations.
%timeit L1 = .5 * np.trace(W.T @ W)
%timeit L2 = .5 * np.sum(np.square(W))
%timeit L3 = .5 * np.square(np.linalg.norm(W)) # <- about the same
%timeit L4 = .5 * np.dot(W.ravel(), W.ravel()) # <- fastest
%timeit L5 = .5 * np.tensordot(W, W)           # <- about the same

In [None]:
# linear transform - 2-norm - full matrix - gradient

# The gradient of .5 * sum(W ** 2) with respect to W is W itself, so this only
# times a name binding.
%timeit D = W

In [None]:
# linear transform - 2-norm - diagonal - loss

# Four equivalent ways of evaluating half the squared Euclidean norm of the
# vector w; the printed column should contain four identical values.
L1 = .5 * np.sum(np.square(w))
L2 = .5 * np.square(np.linalg.norm(w))
L3 = .5 * np.dot(w, w)
L4 = .5 * np.tensordot(w, w, 1)
print(np.array([L1, L2, L3, L4]).reshape((-1, 1)))

# Timing of each variant; the trailing remark records the author's observation.
%timeit L1 = .5 * np.sum(np.square(w))
%timeit L2 = .5 * np.square(np.linalg.norm(w))
%timeit L3 = .5 * np.dot(w, w)                  # <- fastest
%timeit L4 = .5 * np.tensordot(w, w, 1)


In [None]:
# linear transform - 2-norm - diagonal - gradient

# The gradient of .5 * sum(w ** 2) with respect to w is w itself, so this only
# times a name binding.
%timeit D = w

In [None]:
# linear transform - 1-norm - full matrix - loss

# Entry-wise 1-norm (sum of absolute values). W is flattened before calling
# np.linalg.norm so that the vector 1-norm is used; on a 2-D array ord=1 would
# select the matrix 1-norm (maximum absolute column sum) instead.
L1 = np.sum(np.abs(W))
L2 = np.linalg.norm(W.ravel(), 1)
print(np.array([L1, L2]).reshape((-1, 1)))

%timeit L1 = np.sum(np.abs(W))            # <- same speed, more concise
%timeit L2 = np.linalg.norm(W.ravel(), 1)

In [None]:
# linear transform - 1-norm - full matrix - gradient

# Entry-wise sign of W: the derivative of sum(abs(W)) wherever W is non-zero
# (np.sign returns 0 at exact zeros).
%timeit D = np.sign(W)

In [None]:
# linear transform - 1-norm - diagonal - loss

# Sum of absolute values of the vector w, computed two ways.
L1 = np.sum(np.abs(w))
L2 = np.linalg.norm(w, 1)
print(np.array([L1, L2]).reshape((-1, 1)))

%timeit L1 = np.sum(np.abs(w))    # <- same speed, more concise
%timeit L2 = np.linalg.norm(w, 1) 

In [None]:
# linear transform - 1-norm - diagonal - gradient

# Entry-wise sign of w: the derivative of sum(abs(w)) wherever w is non-zero.
%timeit D = np.sign(w)

In [None]:
# nonlinear transform - 2-norm - full matrix - loss

# Nine ways of evaluating .5 * trace(W @ G @ W @ G).
# With A = W @ G, np.dot(A.ravel(), A.ravel('F')) and np.dot(A.ravel(), A.T.ravel())
# both sum A[i, j] * A[j, i], i.e. trace(A @ A). Variants 2-4 reuse A, variants
# 5-7 recompute the product, and variants 8-9 use G @ W in place of (W @ G).T,
# which is valid because W and G are both built as Q @ Q.T and hence symmetric.
L1 = .5 * np.trace(W @ G @ W @ G)
A = W @ G; L2 = .5 * np.dot(A.ravel(), A.ravel('F'))
A = W @ G; L3 = .5 * np.dot(A.ravel(), A.T.ravel())
A = W @ G; L4 = .5 * np.tensordot(A, A.T)
L5 = .5 * np.dot((W @ G).ravel(), (W @ G).ravel('F'))
L6 = .5 * np.dot((W @ G).ravel(), (W @ G).T.ravel())
L7 = .5 * np.tensordot((W @ G), (W @ G).T)
L8 = .5 * np.dot((W @ G).ravel(), (G @ W).ravel())
L9 = .5 * np.tensordot((W @ G), (G @ W))
print(np.array([L1, L2, L3, L4, L5, L6, L7, L8, L9]).reshape((-1, 1)))

# Timing of each variant; the trailing remarks record the author's observations.
%timeit L1 = .5 * np.trace(W @ G @ W @ G)
%timeit A = W @ G; L2 = .5 * np.dot(A.ravel(), A.ravel('F')) 
%timeit A = W @ G; L3 = .5 * np.dot(A.ravel(), A.T.ravel())  # <- fastest
%timeit A = W @ G; L4 = .5 * np.tensordot(A, A.T)            
%timeit L5 = .5 * np.dot((W @ G).ravel(), (W @ G).ravel('F'))
%timeit L6 = .5 * np.dot((W @ G).ravel(), (W @ G).T.ravel())
%timeit L7 = .5 * np.tensordot((W @ G), (W @ G).T)
%timeit L8 = .5 * np.dot((W @ G).ravel(), (G @ W).ravel())
%timeit L9 = .5 * np.tensordot((W @ G), (G @ W))             # <- slower, no assignment

In [None]:
# nonlinear transform - 2-norm - full matrix - gradient

# Gradient of .5 * trace(W @ G @ W @ G) with respect to W for symmetric W and G.
%timeit D = G @ W @ G

In [None]:
# nonlinear transform - 2-norm - diagonal - loss

# Ten ways of evaluating .5 * trace(diag(w) @ G @ diag(w) @ G), which equals
# .5 * sum_ij w[i] * w[j] * G[i, j] ** 2 for symmetric G.
# Broadcasting: `w * G` scales column j of G by w[j] (G @ diag(w)), while
# `w[:, None] * G` scales row i by w[i] (diag(w) @ G).
L1 = .5 * np.trace(np.diag(w) @ G @ np.diag(w) @ G)
A = w * G; L2 = .5 * np.dot(A.ravel(), A.ravel('F'))
A = w * G; L3 = .5 * np.dot(A.ravel(), A.T.ravel())
A = w * G; L4 = .5 * np.tensordot(A, A.T)
L5 = .5 * np.tensordot(w * G, w[:, None] * G)
L6 = .5 * np.dot((w * G).ravel(), (w * G).T.ravel())
L7 = .5 * np.sum(np.outer(w, w) * np.square(G))
L8 = .5 * np.sum(w * np.square(G) * w[:, None])
L9 = .5 * np.tensordot(np.outer(w, w), np.square(G))
L10 = .5 * np.dot(np.outer(w, w).ravel(), np.square(G).ravel())
print(np.array([L1, L2, L3, L4, L5, L6, L7, L8, L9, L10]).reshape((-1, 1)))

# Timing of each variant; the trailing remarks record the author's observations.
%timeit L1 = .5 * np.trace(np.diag(w) @ G @ np.diag(w) @ G)
%timeit A = w * G; L2 = .5 * np.dot(A.ravel(), A.ravel('F'))
%timeit A = w * G; L3 = .5 * np.dot(A.ravel(), A.T.ravel())
%timeit A = w * G; L4 = .5 * np.tensordot(A, A.T)
%timeit L5 = .5 * np.tensordot(w * G, w[:, None] * G)
%timeit L6 = .5 * np.dot((w * G).ravel(), (w * G).T.ravel())
%timeit L7 = .5 * np.sum(np.outer(w, w) * np.square(G))
%timeit L8 = .5 * np.sum(w * np.square(G) * w[:, None])
%timeit L9 = .5 * np.tensordot(np.outer(w, w), np.square(G))            # <- about the same
%timeit L10 = .5 * np.dot(np.outer(w, w).ravel(), np.square(G).ravel()) # <- fastest

In [None]:
# nonlinear transform - 2-norm - diagonal - gradient

# Diagonal of G @ diag(w) @ G. Entry i is sum_j G[i, j] * w[j] * G[j, i], which
# for symmetric G is row i of np.square(G) times w -- no n x n product needed.
D1 = np.diag(G @ np.diag(w) @ G)
D2 = np.square(G) @ w
print(np.stack((D1, D2)))

%timeit D1 = np.diag(G @ np.diag(w) @ G)
%timeit D2 = np.square(G) @ w            # <- fastest

In [None]:
# nonlinear transform - 1-norm - full matrix - loss

# Entry-wise 1-norm of G @ W @ G; the product is flattened for np.linalg.norm so
# that the vector 1-norm (not the matrix 1-norm) is used.
L1 = np.sum(np.abs(G @ W @ G))
L2 = np.linalg.norm((G @ W @ G).ravel(), 1)
print(np.array([L1, L2]).reshape((-1, 1)))

%timeit L1 = np.sum(np.abs(G @ W @ G))              # <- more concise
%timeit L2 = np.linalg.norm((G @ W @ G).ravel(), 1) # <- faster

In [None]:
# nonlinear transform - 1-norm - full matrix - gradient

# (Sub)gradient of sum(abs(G @ W @ G)) with respect to W for symmetric G: the
# sign pattern of the product, mapped back through G on both sides.
%timeit D = G @ np.sign(G @ W @ G) @ G

In [None]:
# nonlinear transform - 1-norm - diagonal - loss

# Entry-wise 1-norm of G @ diag(w) @ G. `w * G` scales the columns of G, i.e. it
# equals G @ diag(w) without building the diagonal matrix.
L1 = np.sum(np.abs(G @ np.diag(w) @ G))
L2 = np.sum(np.abs((w * G) @ G))
L3 = np.linalg.norm(((w * G) @ G).ravel(), 1)
print(np.array([L1, L2, L3]).reshape((-1, 1)))

%timeit L1 = np.sum(np.abs(G @ np.diag(w) @ G))
%timeit L2 = np.sum(np.abs((w * G) @ G))              # <- more concise
%timeit L3 = np.linalg.norm(((w * G) @ G).ravel(), 1) # <- faster

In [None]:
# nonlinear transform - 1-norm - diagonal - gradient

# Diagonal of G @ S @ G with S = sign(G @ diag(w) @ G), computed three ways.
# D2 takes row-wise sums of (G @ S) * G and D3 column-wise sums of G * (S @ G);
# both pick out the same diagonal because G is symmetric.
D1 = np.diag(G @ np.sign(G @ np.diag(w) @ G) @ G)
D2 = np.sum((G @ np.sign((w * G) @ G)) * G, axis=1)
D3 = np.sum(G * (np.sign((w * G) @ G) @ G), axis=0)
print(np.stack((D1, D2, D3)))

%timeit D1 = np.diag(G @ np.sign(G @ np.diag(w) @ G) @ G)
%timeit D2 = np.sum((G @ np.sign((w * G) @ G)) * G, axis=1)   # <- same speed
%timeit D3 = np.sum(G * (np.sign((w * G) @ G) @ G), axis=0)   # <- same speed



In [None]:
# nonlinear transform - 1-norm - non-convex - loss

# Entry-wise 1-norm of W @ G @ W.T; W enters twice, so the loss is no longer
# convex in W.
L1 = np.sum(np.abs(W @ G @ W.T))
L2 = np.linalg.norm((W @ G @ W.T).ravel(), 1)
print(np.array([L1, L2]).reshape((-1, 1)))

%timeit L1 = np.sum(np.abs(W @ G @ W.T))              # <- more concise
%timeit L2 = np.linalg.norm((W @ G @ W.T).ravel(), 1) # <- faster

In [None]:
# nonlinear transform - 1-norm - non-convex - gradient

# NOTE(review): differentiating sum(abs(W @ G @ W.T)) with respect to W for a
# symmetric G appears to give 2 * sign(W @ G @ W.T) @ W @ G; the expression
# below omits the factor 2. This does not affect the timing -- confirm whether
# the scale matters where this gradient is actually used.
%timeit D1 = np.sign(W @ G @ W.T) @ W @ G

In [None]:
%%timeit
# linear transform - 2-norm - full matrix
# Times loss and gradient together.

L = .5 * np.dot(W.ravel(), W.ravel())
D = W

In [None]:
%%timeit
# linear transform - 2-norm - diagonal
# Times loss and gradient together.

L = .5 * np.dot(w, w)
D = w

In [None]:
%%timeit
# linear transform - 1-norm - full matrix
# Times loss and gradient together; W is flattened so that the vector 1-norm
# (sum of absolute values) is computed.

L = np.linalg.norm(W.ravel(), 1)
D = np.sign(W)

In [None]:
%%timeit
# linear transform - 1-norm - diagonal
# Times loss and gradient together.

L = np.linalg.norm(w, 1)
D = np.sign(w)

In [None]:
%%timeit
# nonlinear transform - 2-norm - full matrix
# Times loss and gradient together; the product A = W @ G is shared, so the
# gradient G @ W @ G costs only one further matrix product.

A = W @ G
L = .5 * np.dot(A.ravel(), A.T.ravel())
D = G @ A

In [None]:
%%timeit
# nonlinear transform - 2-norm - diagonal
# Times loss and gradient together; the element-wise square of G is shared by both.

A = np.square(G)
L = .5 * np.dot(np.outer(w, w).ravel(), A.ravel())
D = A @ w

In [None]:
%%timeit
# nonlinear transform - 1-norm - full matrix
# Times loss and gradient together; A = G @ W @ G is shared by both.

A = G @ W @ G
L = np.linalg.norm(A.ravel(), 1)
D = G @ np.sign(A) @ G

In [None]:
%%timeit
# nonlinear transform - 1-norm - diagonal
# Times loss and gradient together. In the last line `@` and `*` share
# precedence and associate left to right, so it reads (G @ np.sign(A)) * G.

A = (w * G) @ G
L = np.linalg.norm(A.ravel(), 1)
D = np.sum(G @ np.sign(A) * G, axis=1)

In [None]:
%%timeit
# nonlinear transform - 1-norm - non-convex
# Times loss and gradient together.
# NOTE(review): the derivative of sum(abs(W @ G @ W.T)) with respect to W for a
# symmetric G appears to be 2 * np.sign(A) @ W @ G; D below omits the factor 2
# -- confirm whether that is intentional.

A = W @ G @ W.T
L = np.linalg.norm(A.ravel(), 1)
D = np.sign(A) @ W @ G

In [9]:
# Reshaping the outer product of the flattened A gives B[i, j, k, l] == A[i, j] * A[k, l].
A = np.arange(4).reshape((2, 2))
B = np.outer(A.ravel(), A.ravel()).reshape((2, 2, 2, 2))
A[1, 1] * A[1, 0], B[1, 1, 1, 0]

# Quadratic forms a_i^T B a_j for all pairs of rows of A, computed three ways.
A = np.arange(8).reshape((4, 2))
B = np.arange(4).reshape((2, 2))
B = B + B.T
# Reference: plain matrix products.
C = A @ B @ A.T
# Row 4 * i + j of E is the flattened outer product of rows i and j of A.
E = np.array([np.outer(A[i, :], A[j, :]).ravel() for i in range(4) for j in range(4)],
             dtype=float)
# The same table of outer products from one N-d np.dot call, applied to the flattened B.
D = (np.dot(A[:, :, None], A[:, None, :]).transpose(0, 2, 1, 3).reshape(16, 4)
     @ B.reshape(4, 1)).reshape(4, 4)
F = (E @ B.reshape(4, 1)).reshape(4, 4)
C, np.int_(D), np.int_(F)
# Displayed result: C next to the diagonal terms a_i^T B a_i of the first two rows.
C, A[0:2, np.repeat(np.arange(2), 2)] * A[0:2, np.tile(np.arange(2), 2)] @ B.reshape(4, 1)
#C, (np.dot(A[[0, 2], :].reshape(2,2,1), A[[0, 1, 3], :].reshape(3,1,2)).transpose(0,2,1,3).reshape(6,4) @ B.reshape(4,1)).reshape(2,3)



(array([[  6,  24,  42,  60],
        [ 24,  90, 156, 222],
        [ 42, 156, 270, 384],
        [ 60, 222, 384, 546]]),
 array([[ 6],
        [90]]))

In [8]:
# Toy sizes, plus the row (r) and column (c) subsets whose entries are wanted.
n, d = 7, 4
r = np.array([0, 2])
c = np.array([0, 2, 3])

# Small integer data matrix A and a symmetric integer matrix B.
A = np.int_(np.arange(n * d).reshape(n, d) ** (8 / 7))
B = np.arange(d ** 2).reshape(d, d)
B = B + B.T + 1
print(A)
print(B)

# Reference: R[i, j] = P[i, i] + P[j, j] - 2 * P[i, j] for every pair, then the
# r x c block of it.
P = A @ B @ A.T
D = P.diagonal()[:, None]
R = D + D.T - 2 * P
R1 = R[np.ix_(r, c)]
print(R)
print(R1)

# Variant 2: only the requested block. Each bracketed element-wise product holds
# the flattened outer product a a^T of every selected row a, so multiplying by
# the flattened B yields the quadratic form a^T B a.
R2 = (
    (A[r][:, np.repeat(np.arange(d), d)] * A[r][:, np.tile(np.arange(d), d)]) @ B.reshape(d ** 2, 1)
    + ((A[c][:, np.repeat(np.arange(d), d)] * A[c][:, np.tile(np.arange(d), d)]) @ B.reshape(d ** 2, 1)).T
    - 2 * A[r] @ B @ A[c].T
)
print(R2)

# Variant 3: as variant 2, but the cross term is also written as flattened
# outer products (of every r-row with every c-row) applied to the flattened B.
R3 = (
    (A[r][:, np.repeat(np.arange(d), d)] * A[r][:, np.tile(np.arange(d), d)]) @ B.reshape(d ** 2, 1)
    + ((A[c][:, np.repeat(np.arange(d), d)] * A[c][:, np.tile(np.arange(d), d)]) @ B.reshape(d ** 2, 1)).T
    - 2 * (np.dot(A[r].reshape(r.size, d, 1), A[c].reshape(c.size, 1, d))
           .transpose(0, 2, 1, 3)
           .reshape(r.size * c.size, d ** 2)
           @ B.reshape(d ** 2, 1)).reshape(r.size, c.size)
)
print(R3)

[[ 0  1  2  3]
 [ 4  6  7  9]
 [10 12 13 15]
 [17 18 20 22]
 [23 25 27 28]
 [30 32 34 35]
 [37 39 41 43]]
[[ 1  6 11 16]
 [ 6 11 16 21]
 [11 16 21 26]
 [16 21 26 31]]
[[     0   7000  32296  83141 153939 254375 387156]
 [  7000      0   9216  41871  95249 176925 289976]
 [ 32296   9216      0  11799  45209 105381 195800]
 [ 83141  41871  11799      0  10816  46656 111469]
 [153939  95249  45209  10816      0  12544  52839]
 [254375 176925 105381  46656  12544      0  13891]
 [387156 289976 195800 111469  52839  13891      0]]
[[    0 32296 83141]
 [32296     0 11799]]
[[    0 32296 83141]
 [32296     0 11799]]
[[    0 32296 83141]
 [32296     0 11799]]


In [7]:
# Toy sizes (r and c are defined for symmetry with the previous cell).
n, d = 7, 4
r = np.array([0, 2])
c = np.array([0, 2, 3])

# Small integer data matrix A and an (n, n) integer weight matrix B.
A = np.int_(np.arange(n * d).reshape(n, d) ** (8 / 7))
B = np.arange(n ** 2).reshape(n, n)
print(A)
print(B)

# Reference: accumulate B[i, j] * (a_i - a_j)(a_i - a_j)^T pair by pair.
R1 = np.zeros((d, d), dtype=int)
for i, j in np.ndindex(n, n):
    C = A[i] - A[j]
    R1 += B[i, j] * np.outer(C, C)
print(R1)

# Vectorised: all pairwise differences at once, their outer products stacked in
# an (n, n, d, d) array, contracted with B over the two leading axes.
C = A[:, np.newaxis, :] - A[np.newaxis, :, :]
R2 = np.tensordot(B, C[:, :, :, np.newaxis] @ C[:, :, np.newaxis, :])
print(R2)

# Laplacian form: A^T (diag(column sums of S) - S) A with S = B + B^T.
C = B + B.T
R3 = A.T @ (np.diag(C.sum(axis=0)) - C) @ A
print(R3)

[[ 0  1  2  3]
 [ 4  6  7  9]
 [10 12 13 15]
 [17 18 20 22]
 [23 25 27 28]
 [30 32 34 35]
 [37 39 41 43]]
[[ 0  1  2  3  4  5  6]
 [ 7  8  9 10 11 12 13]
 [14 15 16 17 18 19 20]
 [21 22 23 24 25 26 27]
 [28 29 30 31 32 33 34]
 [35 36 37 38 39 40 41]
 [42 43 44 45 46 47 48]]
[[385728 391544 403688 402880]
 [391544 397768 409960 409104]
 [403688 409960 422680 421704]
 [402880 409104 421704 421304]]
[[385728 391544 403688 402880]
 [391544 397768 409960 409104]
 [403688 409960 422680 421704]
 [402880 409104 421704 421304]]
[[385728 391544 403688 402880]
 [391544 397768 409960 409104]
 [403688 409960 422680 421704]
 [402880 409104 421704 421304]]


In [6]:
# Small-scale check of the expansion of .5 * trace(M B M B) for
# M = (1 + a) * A + U, where U is non-zero only on the active x active block
# selected by the mask I.
# NOTE: this cell rebinds the globals A, B, V, U and M used by earlier cells.
I = np.array([True, False, False, True, False])
# Random symmetric matrices with half-integer entries.
A = np.random.randint(0,10,(5,5)) - 4.5
A = A + A.T
B = np.random.randint(0,10,(5,5)) - 4.5
B = B + B.T
# V is zeroed outside the active rows and columns, so the Laplacian-style U
# built from it is zero outside the active x active block as well.
V = np.random.randint(0,10,(5,5)) - 4.5
V[~I, :] = 0
V[:, ~I] = 0
U = np.diag(np.sum(V + V.T, axis=0)) - (V + V.T)

a = .1

# .5 * trace(A B A B), directly and as a dot product of flattened matrices
# (np.dot(Q.T.ravel(), Q.ravel()) evaluates trace(Q @ Q)).
R1 = .5 * np.trace(A @ B @ A @ B)
R2 = .5 * np.dot((A @ B).T.ravel(), (A @ B).ravel())

print(R1, R2)

M = A + a * A + U

# Reference value for the updated matrix M, again in both forms.
F1 = .5 * np.trace(M @ B @ M @ B)
F2 = .5 * np.dot((M @ B).T.ravel(), (M @ B).ravel())

print(F1, F2)

# Expansion into the old value, a cross term and a pure-U term.
C1 = (1 + a) ** 2 * R1 + (1 + a) * np.trace(A @ B @ U @ B) + .5 * np.trace(U @ B @ U @ B)
C2 = (1 + a) ** 2 * R2 + (1 + a) * np.dot((A @ B).T.ravel(), (U @ B).ravel()) + .5 * np.dot((U @ B).T.ravel(), (U @ B).ravel())

print(C1, C2)

# Same expansion using only the active block of U and the matching rows /
# columns of B, with boolean-mask indexing.
D1 = ((1 + a) ** 2 * R2 +
     (1 + a) * np.dot((A @ B[:, I]).T.ravel(), (U[I, :][:, I] @ B[I, :]).ravel()) +
     .5 * np.dot((U[I, :][:, I] @ B[I, :][:, I]).T.ravel(), (U[I, :][:, I] @ B[I, :][:, I]).ravel()))

# Same again with the shared sub-expressions named once.
BS = B[I]
CS = A @ BS.T
US = U[I, :][:, I]
UB = US @ BS

D2 = ((1 + a) ** 2 * R2 +
      (1 + a) * np.dot(CS.T.ravel(), UB.ravel()) +
      .5 * np.dot(UB[:, I].T.ravel(), UB[:, I].ravel()))

# And with integer indices and take() instead of the boolean mask.
J = np.nonzero(I)[0]
BS = B.take(J, axis=0)
CS = A @ BS.T
US = U.take(J, axis=0).take(J, axis=1)
UB = US @ BS

D3 = ((1 + a) ** 2 * R2 +
      (1 + a) * np.dot(CS.T.ravel(), UB.ravel()) +
      .5 * np.dot(UB.take(J, axis=1).T.ravel(), UB.take(J, axis=1).ravel()))

print(D1, D2, D3)


40994.5 40994.5
49603.344999999994 49603.345
49603.34500000001 49603.34500000001
49603.34500000001 49603.34500000001 49603.34500000001
