In [1]:
import torch
import numpy as np
import micca_model
import simulate_model_gaussian
torch.set_printoptions(linewidth = 200)

In [2]:
def sqrtm(X):
    U, l, VT = torch.linalg.svd(X)
    return((U * torch.sqrt(l)) @ VT)

In [8]:
n = 1000
p = [3, 3, 3]
psum = np.concatenate([[0], np.cumsum(p, 0)]) 
d = 2
k = None

In [9]:
model = simulate_model_gaussian.generate_model(p, k, d)
data = model.simulate(n)
Y_all = torch.cat(data.Y, dim = 1)
Y_all = Y_all - torch.mean(Y_all, 0)
Sigma_tilde = (Y_all.T @ Y_all) / (n - 1)
W = torch.cat(model.W, dim = 1)
Phi = torch.block_diag(*model.Phi)

In [27]:
Y = [Y_m - torch.mean(Y_m, 0) for Y_m in data.Y]
S11 = (Y[0].T @ Y[0])/(n-1)
S22 = (Y[1].T @ Y[1])/(n-1)
S33 = (Y[2].T @ Y[2])/(n-1)
S12 = (Y[0].T @ Y[1])/(n-1)
S13 = (Y[0].T @ Y[2])/(n-1)
S23 = (Y[1].T @ Y[2])/(n-1)
S11_inv2 = sqrtm(torch.linalg.pinv(S11, rcond = 1e-06))
S22_inv2 = sqrtm(torch.linalg.pinv(S22, rcond = 1e-06))
S33_inv2 = sqrtm(torch.linalg.pinv(S33, rcond = 1e-06))
C12_tilde = S11_inv2 @ S12 @ S22_inv2
C13_tilde = S11_inv2 @ S13 @ S33_inv2

In [12]:
# Fit 2-D
std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([d, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(5000):
    # W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): 
        print(delta_W, delta_Phi, micca_model.loglik(W_curr.T @ W_curr + Phi_curr, Sigma_tilde, n))
W_hat = W_curr
Phi_hat = Phi_curr

tensor(4.6638) tensor(239.7270) tensor(17477.3555)
tensor(4.9947e-07) tensor(9.5019e-06) tensor(16560.2324)
tensor(8.6580e-08) tensor(1.6055e-06) tensor(16560.0742)
tensor(1.8795e-08) tensor(3.4518e-07) tensor(16560.0449)
tensor(5.2162e-09) tensor(8.9491e-08) tensor(16560.0352)
tensor(1.8878e-09) tensor(2.9314e-08) tensor(16560.0312)
tensor(9.1882e-10) tensor(1.2139e-08) tensor(16560.0293)
tensor(5.6278e-10) tensor(6.1705e-09) tensor(16560.0273)
tensor(4.0663e-10) tensor(3.6414e-09) tensor(16560.0273)
tensor(3.1441e-10) tensor(2.6928e-09) tensor(16560.0273)
tensor(2.5285e-10) tensor(1.8450e-09) tensor(16560.0254)
tensor(2.0248e-10) tensor(1.4001e-09) tensor(16560.0254)
tensor(1.5084e-10) tensor(1.1128e-09) tensor(16560.0254)
tensor(1.2883e-10) tensor(9.7860e-10) tensor(16560.0273)
tensor(1.0562e-10) tensor(6.8894e-10) tensor(16560.0254)
tensor(8.9485e-11) tensor(6.4739e-10) tensor(16560.0234)
tensor(6.8341e-11) tensor(5.8229e-10) tensor(16560.0234)
tensor(5.6238e-11) tensor(3.6079e-10)

In [13]:
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]

W1_tilde = W1 @ S11_inv2
W2_tilde = W2 @ S22_inv2
W3_tilde = W3 @ S33_inv2

# The original conjecture in terms of SVD of W_tilde
B1, S1, A1T = torch.linalg.svd(W1_tilde, full_matrices=False)
B2, S2, A2T = torch.linalg.svd(W2_tilde, full_matrices=False)
B3, S3, A3T = torch.linalg.svd(W3_tilde, full_matrices=False)
F1 = S11_inv2 @ A1T.T
F2 = S22_inv2 @ A2T.T
F3 = S33_inv2 @ A3T.T

model_cov_1 = torch.cat([torch.cat([S11, W1.T @ W2, W1.T @ W3], axis = 1),
                         torch.cat([W2.T @ W1, S22, W2.T @ W3], axis = 1),
                         torch.cat([W3.T @ W1, W3.T @ W2, S33], axis = 1)], axis = 0)
model_cov_2 = torch.cat([torch.cat([torch.eye(p[0]), W1_tilde.T @ W2_tilde, W1_tilde.T @ W3_tilde], axis = 1),
                         torch.cat([W2_tilde.T @ W1_tilde, torch.eye(p[1]), W2_tilde.T @ W3_tilde], axis = 1),
                         torch.cat([W3_tilde.T @ W1_tilde, W3_tilde.T @ W2_tilde, torch.eye(p[2])], axis = 1)], axis = 0)
gen_var = torch.cat([torch.cat([F1.T @ S11 @ F1, F1.T @ S12 @ F2, F1.T @ S13 @ F3], axis = 1),
                     torch.cat([(F1.T @ S12 @ F2).T, F2.T @ S22 @ F2, F2.T @ S23 @ F3], axis = 1),
                     torch.cat([(F1.T @ S13 @ F3).T, (F2.T @ S23 @ F3).T, F3.T @ S33 @ F3], axis = 1)], axis = 0)
gen_var_2 = torch.cat([torch.cat([torch.eye(p[0]), A1T.T @ F1.T @ S12 @ F2 @ A2T, A1T.T @ F1.T @ S13 @ F3 @ A3T], axis = 1),
                       torch.cat([(A1T.T @ F1.T @ S12 @ F2 @ A2T).T, torch.eye(p[1]), A2T.T @ F2.T @ S23 @ F3 @ A3T], axis = 1),
                       torch.cat([(A1T.T @ F1.T @ S13 @ F3 @ A3T).T, (A2T.T @ F2.T @ S23 @ F3 @ A3T).T, torch.eye(p[1])], axis = 1)], axis = 0)

D_half = torch.block_diag(sqrtm(S11), sqrtm(S22), sqrtm(S33))

Sigma_alt = D_half @ gen_var_2 @ D_half
print(micca_model.loglik(model_cov_1, Sigma_tilde, n))
print(micca_model.loglik(Sigma_alt, Sigma_tilde, n))
# print(micca_model.loglik(W_ml.T @ W_ml + Phi_ml, Sigma_tilde, n))

# a1 = (W1_tilde.T/torch.sqrt(torch.sum(W1_tilde**2, 1)))
# a2 = (W2_tilde.T/torch.sqrt(torch.sum(W2_tilde**2, 1)))
# a3 = (W3_tilde.T/torch.sqrt(torch.sum(W3_tilde**2, 1)))
# f1 = S11_inv2 @ a1
# f2 = S22_inv2 @ a2
# f3 = S33_inv2 @ a3

# Phi_test = torch.cat([torch.cat([f1.T @ S11 @ f1, f1.T @ S12 @ f2, f1.T @ S13 @ f3], axis = 1),
#                  torch.cat([(f1.T @ S12 @ f2).T, f2.T @ S22 @ f2, f2.T @ S23 @ f3], axis = 1),
#                  torch.cat([(f1.T @ S13 @ f3).T, (f2.T @ S23 @ f3).T, f3.T @ S33 @ f3], axis = 1)], axis = 0)
# print(torch.logdet(Phi_test))

# f1_1 = F1[:, 0:1]
# f2_1 = F2[:, 0:1]
# f3_1 = F3[:, 0:1]
# f1_2 = F1[:, 1:2]
# f2_2 = F2[:, 1:2]
# f3_2 = F3[:, 0:1]

# Phi_1 = torch.tensor([[f1_1.T @ S11 @ f1_1, f1_1.T @ S12 @ f2_1, f1_1.T @ S13 @ f3_1],
#                       [f1_1.T @ S12 @ f2_1, f2_1.T @ S22 @ f2_1, f2_1.T @ S23 @ f3_1],
#                       [f1_1.T @ S13 @ f3_1, f2_1.T @ S23 @ f3_1, f3_1.T @ S33 @ f3_1]])
# Phi_2 = torch.tensor([[f1_2.T @ S11 @ f1_2, f1_2.T @ S12 @ f2_2, f1_2.T @ S13 @ f3_2],
#                       [f1_2.T @ S12 @ f2_2, f2_2.T @ S22 @ f2_2, f2_2.T @ S23 @ f3_2],
#                       [f1_2.T @ S13 @ f3_2, f2_2.T @ S23 @ f3_2, f3_2.T @ S33 @ f3_2]])
# print(torch.logdet(Phi_1) + torch.logdet(Phi_2))

tensor(16560.0234)
tensor(16559.8652)


In [14]:
X12 = F1.T @ S12 @ F2
X13 = F1.T @ S13 @ F3
X23 = F2.T @ S23 @ F3

print(X13 @ torch.linalg.inv(X23) @ X12.T)
print(S13 @ torch.linalg.inv(S23) @ S12.T)

tensor([[ 0.7558,  0.0086],
        [-0.0226,  0.0446]])
tensor([[ 0.6551,  2.1239, -0.5474],
        [ 1.6988,  6.5813, -2.0663],
        [-0.3463, -2.2200,  1.1871]])


In [15]:
W1_tilde.T @ W2_tilde

tensor([[-0.1758, -0.0060,  0.0444],
        [-0.6352, -0.0407,  0.0311],
        [ 0.0767,  0.0270,  0.1450]])

In [19]:
A1T.T @ F1.T @ S12 @ F2 @ A2T

tensor([[-0.1768, -0.0059,  0.0453],
        [-0.6349, -0.0408,  0.0308],
        [ 0.0723,  0.0273,  0.1490]])

In [25]:
X12 = (B1*S1).T @ (B2*S2)
X13 = (B1*S1).T @ (B3*S3)
X23 = (B2*S2).T @ (B3*S3)
print(X13 @ torch.linalg.inv(X23) @ X12.T)

tensor([[ 7.6278e-01, -3.7253e-08],
        [-2.9802e-08,  3.8986e-02]])


In [30]:
(C12_tilde - W1_tilde.T @ W2_tilde) @ W2_tilde.T @ torch.linalg.inv(1 - W2_tilde @ W2_tilde.T) + (C13_tilde - W1_tilde.T @ W3_tilde) @ W3_tilde.T @ torch.linalg.inv(1 - W3_tilde @ W3_tilde.T) 

tensor([[-2.1039e-02,  1.9578e-03],
        [ 5.2834e-03,  4.2273e-06],
        [-5.4943e-03,  7.0994e-03]])

In [13]:
# Fit 1-D once
std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(10000):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p, rcond=None)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1000 == 0): print(delta_W, delta_Phi)
W_hat = W_curr
Phi_hat = Phi_curr
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]

# Project out learned space and fit again
W1_tilde = W1 @ S11_inv2
W2_tilde = W2 @ S22_inv2
W3_tilde = W3 @ S33_inv2
a1 = (W1_tilde / torch.sqrt(torch.sum(W1_tilde**2))).T
a2 = (W2_tilde / torch.sqrt(torch.sum(W2_tilde**2))).T
a3 = (W3_tilde / torch.sqrt(torch.sum(W3_tilde**2))).T
f1_1 = S11_inv2 @ a1
f2_1 = S22_inv2 @ a2
f3_1 = S33_inv2 @ a3
P1 = torch.eye(p[0]) - f1_1 @ f1_1.T @ S11
P2 = torch.eye(p[1]) - f2_1 @ f2_1.T @ S22
P3 = torch.eye(p[2]) - f3_1 @ f3_1.T @ S33
Y_mod = [Y[0] @ P1, Y[1] @ P2, Y[2] @ P3]
Y_mod = [Y - torch.mean(Y, 0) for Y in Y_mod]
Y_mod_all = torch.cat(Y_mod, dim = 1)
S11_mod = (Y_mod[0].T @ Y_mod[0])/(n-1)
S22_mod = (Y_mod[1].T @ Y_mod[1])/(n-1)
S33_mod = (Y_mod[2].T @ Y_mod[2])/(n-1)
S12_mod = (Y_mod[0].T @ Y_mod[1])/(n-1)
S13_mod = (Y_mod[0].T @ Y_mod[2])/(n-1)
S23_mod = (Y_mod[1].T @ Y_mod[2])/(n-1)
S11_inv2_mod = sqrtm(torch.linalg.pinv(S11_mod, rcond = 1e-06))
S22_inv2_mod = sqrtm(torch.linalg.pinv(S22_mod, rcond = 1e-06))
S33_inv2_mod = sqrtm(torch.linalg.pinv(S33_mod, rcond = 1e-06))
Sigma_tilde_mod = (Y_mod_all.T @ Y_mod_all) / (n - 1)

std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(10000):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_mod_all, Sigma_tilde, p = p, rcond = None)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1000 == 0): print(delta_W, delta_Phi)
W_hat = W_curr
Phi_hat = Phi_curr
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]
W1_tilde = W1 @ S11_inv2_mod
W2_tilde = W2 @ S22_inv2_mod
W3_tilde = W3 @ S33_inv2_mod
a1 = (W1_tilde / torch.sqrt(torch.sum(W1_tilde**2))).T
a2 = (W2_tilde / torch.sqrt(torch.sum(W2_tilde**2))).T
a3 = (W3_tilde / torch.sqrt(torch.sum(W3_tilde**2))).T
f1_2 = S11_inv2_mod @ a1
f2_2 = S22_inv2_mod @ a2
f3_2 = S33_inv2_mod @ a3

tensor(5.6074) tensor(198.7069)
tensor(2.5940e-13) tensor(2.6059e-12)
tensor(4.0307e-13) tensor(3.2738e-12)
tensor(1.4144e-13) tensor(2.1565e-12)
tensor(2.6096e-13) tensor(3.6895e-12)
tensor(1.7221e-13) tensor(3.4053e-12)
tensor(2.7124e-13) tensor(2.6361e-12)
tensor(5.7988e-13) tensor(3.0287e-12)
tensor(8.5998e-13) tensor(6.4349e-12)
tensor(1.6333e-13) tensor(1.0418e-12)
tensor(1.8658) tensor(190.4414)
tensor(1.0515e-13) tensor(5.3042e-12)
tensor(7.8618e-14) tensor(1.2399e-12)
tensor(1.1841e-13) tensor(1.7657e-12)
tensor(3.3973e-13) tensor(3.0784e-12)
tensor(1.2258e-13) tensor(2.1245e-12)
tensor(1.9851e-13) tensor(1.7568e-12)
tensor(7.8143e-13) tensor(1.6378e-12)
tensor(1.2834e-13) tensor(3.3129e-12)
tensor(1.0626e-13) tensor(7.1765e-13)


In [14]:
Phi_1 = torch.tensor([[f1_1.T @ S11 @ f1_1, f1_1.T @ S12 @ f2_1, f1_1.T @ S13 @ f3_1],
                      [f1_1.T @ S12 @ f2_1, f2_1.T @ S22 @ f2_1, f2_1.T @ S23 @ f3_1],
                      [f1_1.T @ S13 @ f3_1, f2_1.T @ S23 @ f3_1, f3_1.T @ S33 @ f3_1]])
Phi_2 = torch.tensor([[f1_2.T @ S11_mod @ f1_2, f1_2.T @ S12_mod @ f2_2, f1_2.T @ S13_mod @ f3_2],
                      [f1_2.T @ S12_mod @ f2_2, f2_2.T @ S22_mod @ f2_2, f2_2.T @ S23_mod @ f3_2],
                      [f1_2.T @ S13_mod @ f3_2, f2_2.T @ S23_mod @ f3_2, f3_2.T @ S33_mod @ f3_2]])
print(torch.logdet(Phi_1) + torch.logdet(Phi_2))

tensor(-2.6776)


In [15]:
print(f1_2.T @ S11 @ f1_2)
print(f2_2.T @ S22 @ f2_2)
print(f3_2.T @ S33 @ f3_2)

tensor([[1.0003]])
tensor([[1.0276]])
tensor([[1.0462]])


In [41]:
print(f1_2.T @ S12 @ f2_2)
print(f1_2.T @ S13 @ f3_2)
print(f2_2.T @ S23 @ f3_2)

tensor([[0.9428]])
tensor([[1.0859]])
tensor([[1.0297]])


In [25]:
print(F1.T @ S12 @ F2)
print(F1.T @ S13 @ F3)
print(F2.T @ S23 @ F3)

tensor([[-0.3595, -0.2429],
        [ 0.2392, -0.1712]])
tensor([[ 0.0625, -0.5011],
        [ 0.5813,  0.0235]])
tensor([[ 0.3042,  0.2207],
        [-0.2623,  0.1182]])


In [72]:
W1_bar = W_em[0] @ S11_inv2
W2_bar = W_em[1] @ S22_inv2
W3_bar = W_em[2] @ S33_inv2

B1, S1, A1 = torch.linalg.svd(W1_bar, full_matrices = False)
B2, S2, A2 = torch.linalg.svd(W2_bar, full_matrices = False)
B3, S3, A3 = torch.linalg.svd(W3_bar, full_matrices = False)

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
M2 = torch.cat([torch.cat([torch.eye(3), P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, torch.eye(3), P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, torch.eye(3)], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.5966)
tensor(0.5966)


In [73]:
# Nope!
W_bar = torch.cat([W1_bar, W2_bar, W3_bar], 1)
B, S, A = torch.linalg.svd(W_bar, full_matrices = False)

A1 = A[:, psum[0]:psum[1]]
A2 = A[:, psum[1]:psum[2]]
A3 = A[:, psum[2]:psum[3]]

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

M2 = torch.cat([torch.cat([torch.eye(3), P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, torch.eye(3), P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, torch.eye(3)], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.5966)
tensor(0.9919)


In [27]:
M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
A, S, _ = torch.linalg.svd(M1, full_matrices = False)

A1 = A[psum[0]:psum[1], 0:2].T
A2 = A[psum[1]:psum[2], 0:2].T
A3 = A[psum[2]:psum[3], 0:2].T

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

In [20]:

# Fit 1-d one time
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(1000):
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step_alt(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): print(delta_W, delta_Phi)

# Project out learned space and fit again
a1 = W_curr[:, psum[0]:psum[1]] / torch.sqrt(torch.sum(W_curr[:, psum[0]:psum[1]]**2))
a2 = W_curr[:, psum[1]:psum[2]] / torch.sqrt(torch.sum(W_curr[:, psum[1]:psum[2]]**2))
f1 = a1 @ S11_inv2
f2 = a2 @ S22_inv2

Y_mod = [data.Y[0] @ (torch.eye(3) - f1.T @ f1), data.Y[1] @ (torch.eye(3) - f2.T @ f2)]
Y_mod_all = torch.cat(Y_mod, dim = 1)
Y_mod_all = Y_mod_all - torch.mean(Y_mod_all, 0)
Sigma_mod_tilde = (Y_mod_all.T @ Y_mod_all) / (n - 1)

W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(1000):
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_mod_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step_alt(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): print(delta_W, delta_Phi)
    
# Project out learned space and check residual covariance
a1 = W_curr[:, psum[0]:psum[1]] / torch.sqrt(torch.sum(W_curr[:, psum[0]:psum[1]]**2))
a2 = W_curr[:, psum[1]:psum[2]] / torch.sqrt(torch.sum(W_curr[:, psum[1]:psum[2]]**2))
f1 = a1 @ S11_inv2
f2 = a2 @ S22_inv2

Y_mod_2 = [Y_mod[0] @ (torch.eye(3) - f1.T @ f1), Y_mod[1] @ (torch.eye(3) - f2.T @ f2)]
Y_mod_2_all = torch.cat(Y_mod_2, dim = 1)
Y_mod_2_all = Y_mod_2_all - torch.mean(Y_mod_2_all, 0)
Sigma_mod_2_tilde = (Y_mod_2_all.T @ Y_mod_2_all) / (n - 1)

print(Sigma_mod_2_tilde)

tensor([[-0.2094,  0.5825],
        [ 0.3363,  0.1146],
        [ 0.4213,  0.2283],
        [-0.5645,  0.0326],
        [-0.0425,  0.4586],
        [ 0.1896,  0.0066],
        [ 0.3092,  0.5180],
        [-0.1159, -0.0901],
        [ 0.4474, -0.3280]])

In [24]:
A

tensor([[-0.1435,  0.3387,  0.4332, -0.6028, -0.0526,  0.2022,  0.3386, -0.1160,
          0.3767],
        [ 0.6568,  0.1127,  0.2353,  0.1272,  0.3418, -0.0299,  0.5002, -0.0857,
         -0.3304]])

In [39]:
A1 @ C13_tilde @ A3.T

tensor([[0.2900, 0.0547],
        [0.0324, 0.0781]])

In [29]:
W1_bar @ (C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T

tensor([[-0.0003,  0.0012],
        [-0.0005, -0.0002]])

In [31]:
M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
M2 = torch.cat([torch.cat([P1, P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, P2, P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, P3], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.1111)
tensor(-8.0517e-27)


In [8]:
print(P1C12P2)
print(P1C13P3)
print(P2C23P3)

tensor([[ 0.3023,  0.3841, -0.0876],
        [-0.3758,  0.0207,  0.1281],
        [-0.4620,  0.0773,  0.1595]])
tensor([[ 0.4323, -0.0576, -0.4560],
        [ 0.3167, -0.0927,  0.1888],
        [ 0.4783, -0.1311,  0.1927]])
tensor([[-0.2974,  0.1198, -0.5165],
        [ 0.2443, -0.0379, -0.2027],
        [ 0.1102, -0.0421,  0.1672]])


In [9]:
print(W1_bar.T @ W2_bar)
print(W1_bar.T @ W3_bar)
print(W2_bar.T @ W3_bar)

tensor([[ 0.3040,  0.3764, -0.0885],
        [-0.3769,  0.0271,  0.1287],
        [-0.4633,  0.0849,  0.1602]])
tensor([[ 0.4341, -0.0580, -0.4556],
        [ 0.3155, -0.0925,  0.1900],
        [ 0.4770, -0.1309,  0.1944]])
tensor([[-0.2974,  0.1195, -0.5125],
        [ 0.2405, -0.0352, -0.2208],
        [ 0.1100, -0.0418,  0.1651]])


In [10]:
print(torch.mean(abs(P1C12P2 - W1_bar.T @ W2_bar)))
print(torch.mean(abs(P1C13P3 - W1_bar.T @ W3_bar)))
print(torch.mean(abs(P2C23P3 - W2_bar.T @ W3_bar)))

tensor(0.0031)
tensor(0.0009)
tensor(0.0035)


In [106]:
print(torch.mean(abs((C12_tilde - W1_bar.T @ W2_bar) @  W2_bar.T)))
print(torch.mean(abs((C13_tilde - W1_bar.T @ W3_bar) @  W3_bar.T)))
print(torch.mean(abs((C23_tilde - W2_bar.T @ W3_bar) @  W3_bar.T)))

tensor(0.0028)
tensor(0.0027)
tensor(0.0006)


In [141]:
(C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T @ torch.linalg.inv(torch.eye(1) - W3_bar @ W3_bar.T) 

tensor([[-0.0293],
        [ 0.0047],
        [-0.0501]])

In [142]:
(C13_tilde - W1_bar.T @ W3_bar) @  torch.linalg.inv(torch.eye(3) - W3_bar.T @ W3_bar) @ W3_bar.T

tensor([[-0.0293],
        [ 0.0047],
        [-0.0501]])

In [143]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T

tensor([[ 0.0073],
        [-0.0012],
        [ 0.0126]])

In [145]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T @ torch.linalg.inv(1 - W2_bar @ W2_bar.T)

tensor([[ 0.0293],
        [-0.0047],
        [ 0.0501]])

In [12]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T @ torch.linalg.inv(1 - W2_bar @ W2_bar.T) + (C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T @ torch.linalg.inv(1 - W3_bar @ W3_bar.T) 

tensor([[ 4.5940e-03, -2.8103e-04],
        [ 4.4791e-03,  2.6003e-03],
        [-9.0168e-03, -5.8655e-05]])

In [51]:
[S11 - W_em[0].T @ W_em[0], S22 - W_em[1].T @ W_em[1], S33 - W_em[2].T @ W_em[2]]

[tensor([[4.6497, 0.2221, 0.0900],
         [0.2221, 1.9237, 0.3744],
         [0.0900, 0.3744, 1.7076]]),
 tensor([[ 1.9754,  1.5128, -1.7825],
         [ 1.5128,  2.2075, -2.4684],
         [-1.7825, -2.4684,  3.6764]]),
 tensor([[ 1.5073, -1.3006,  0.8230],
         [-1.3006,  8.2849, -3.3705],
         [ 0.8230, -3.3705,  5.0412]])]

In [54]:
W_curr.T

tensor([[ 0.0027, -0.6311],
        [-0.0130, -1.2672],
        [ 1.4927,  0.2291],
        [ 0.0977, -0.7140],
        [-0.0575, -1.3654],
        [-0.1034,  0.0930],
        [-0.1427, -0.4694],
        [-0.5685, -0.8861],
        [ 1.7051,  1.6871]])

In [29]:
# Lemma 5???
torch.mean(abs((Sigma_tilde - W_curr.T @ W_curr) @ torch.linalg.inv(torch.block_diag(S11 - W_em[0].T @ W_em[0], S22 - W_em[1].T @ W_em[1], S33 - W_em[2].T @ W_em[2])) @ W_curr.T - W_curr.T))

NameError: name 'W_em' is not defined

In [11]:
std_normal = torch.distributions.Normal(0, 1)
f = std_normal.sample([3, 3])
f = f @ torch.linalg.inv(sqrtm(f.T @ f))
f.T @ f
f1 = f[:,0, None]
f2 = f[:, 1, None]
f3 = f[:, 2, None]

f1f1T = f1 @ f1.T
f2f2T = f2 @ f2.T
f3f3T = f3 @ f3.T

mat1 = torch.cat([torch.cat([torch.eye(3), f1f1T @ C12_tilde @ f2f2T, f1f1T @ C13_tilde @ f3f3T], 1),
                  torch.cat([f2f2T @ C12_tilde.T @ f1f1T, torch.eye(3), f2f2T @ C23_tilde @ f3f3T], 1),
                  torch.cat([f3f3T @ C13_tilde.T @ f1f1T, f3f3T @ C23_tilde.T @ f2f2T, torch.eye(3)], 1)], 0)
mat2 = torch.cat([torch.cat([torch.eye(1), f1.T @ C12_tilde @ f2, f1.T @ C13_tilde @ f3], 1),
                  torch.cat([f2.T @ C12_tilde.T @ f1, torch.eye(1), f2.T @ C23_tilde @ f3], 1),
                  torch.cat([f3.T @ C13_tilde.T @ f1, f3.T @ C23_tilde.T @ f2, torch.eye(1)], 1)], 0)
print(torch.linalg.svd(mat1).S)
print(torch.linalg.svd(mat2).S)

tensor([1.2103, 1.0694, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.7203])
tensor([1.2103, 1.0694, 0.7203])


In [19]:
S11_inv = torch.linalg.inv(S11)
PU1d = W_ml[0] @ S11_inv
print(PU1d @ S11 @ PU1d.T)

S22_inv = torch.linalg.inv(S22)
PU2d = W_ml[1] @ S22_inv
print(PU2d @ S22 @ PU2d.T)

S33_inv = torch.linalg.inv(S33)
PU3d = W_ml[1] @ S33_inv
print(PU3d @ S33 @ PU3d.T)

print(PU1d @ S11 @ PU1d.T + PU2d @ S22 @ PU2d.T + PU3d @ S33 @ PU3d.T)

# print(PU1d @ S12 @ PU2d.T)
# print(PU1d.T @ PU2d)  # Should = next
# print( (U1d * p_mat[0:d]) @ U2d.T )

tensor([[0.8872, 0.0354],
        [0.0354, 0.8952]])
tensor([[1.3133, 0.0112],
        [0.0112, 0.9161]])
tensor([[1.3942, 0.4744],
        [0.4744, 0.3550]])
tensor([[3.5946, 0.5210],
        [0.5210, 2.1664]])
