In [2]:
import torch
import numpy as np
import micca_model
import simulate_model_gaussian
torch.set_printoptions(linewidth = 200)

In [9]:
def sqrtm(X):
    U, l, VT = torch.linalg.svd(X)
    return((U * torch.sqrt(l)) @ VT)

In [86]:
n = 1000
p = [3, 3, 3, 3, 3]
psum = np.concatenate([[0], np.cumsum(p, 0)]) 
d = 1
k = None

In [87]:
model = simulate_model_gaussian.generate_model(p, k, d)
data = model.simulate(n)
Y_all = torch.cat(data.Y, dim = 1)
Y_all = Y_all - torch.mean(Y_all, 0)
Sigma_tilde = (Y_all.T @ Y_all) / (n - 1)
W = torch.cat(model.W, dim = 1)
Phi = torch.block_diag(*model.Phi)

In [88]:
Y = [Y_m - torch.mean(Y_m, 0) for Y_m in data.Y]
S11 = (Y[0].T @ Y[0])/(n-1)
S22 = (Y[1].T @ Y[1])/(n-1)
S33 = (Y[2].T @ Y[2])/(n-1)
S12 = (Y[0].T @ Y[1])/(n-1)
S13 = (Y[0].T @ Y[2])/(n-1)
S23 = (Y[1].T @ Y[2])/(n-1)
S11_inv2 = sqrtm(torch.linalg.pinv(S11, rcond = 1e-06))
S22_inv2 = sqrtm(torch.linalg.pinv(S22, rcond = 1e-06))
S33_inv2 = sqrtm(torch.linalg.pinv(S33, rcond = 1e-06))
C12_tilde = S11_inv2 @ S12 @ S22_inv2
C13_tilde = S11_inv2 @ S13 @ S33_inv2
C23_tilde = S22_inv2 @ S23 @ S33_inv2

In [90]:
# Fit 2-D
std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([d, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(5000):
    # W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): 
        print(delta_W, delta_Phi, micca_model.loglik(W_curr.T @ W_curr + Phi_curr, Sigma_tilde, n))
W_hat = W_curr
Phi_hat = Phi_curr
Sigma_hat = W_hat.T @ W_hat + Phi_hat
Sigma_hat_inv = torch.linalg.inv(Sigma_hat)
W1_hat = W_hat[:, psum[0]:psum[1]]
W2_hat = W_hat[:, psum[1]:psum[2]]
W3_hat = W_hat[:, psum[2]:psum[3]]

tensor(15.2321) tensor(180.1042) tensor(26090.3359)
tensor(3.7083e-14) tensor(5.8087e-13) tensor(25386.8574)
tensor(6.1135e-14) tensor(3.2208e-12) tensor(25386.8594)
tensor(2.5019e-14) tensor(4.7362e-13) tensor(25386.8594)
tensor(1.3289e-13) tensor(1.1049e-12) tensor(25386.8594)
tensor(3.2711e-14) tensor(1.3780e-12) tensor(25386.8594)
tensor(2.4982e-14) tensor(4.5675e-13) tensor(25386.8594)
tensor(6.1446e-14) tensor(3.5016e-13) tensor(25386.8594)
tensor(1.4641e-13) tensor(1.2126e-12) tensor(25386.8594)
tensor(1.8723e-14) tensor(1.7206e-12) tensor(25386.8594)
tensor(1.4973e-13) tensor(1.3316e-12) tensor(25386.8594)
tensor(4.0142e-14) tensor(7.4851e-13) tensor(25386.8594)
tensor(2.1445e-14) tensor(7.0544e-13) tensor(25386.8594)
tensor(1.8677e-13) tensor(5.1625e-13) tensor(25386.8594)
tensor(2.5386e-14) tensor(1.2126e-12) tensor(25386.8594)
tensor(1.7851e-14) tensor(2.9265e-13) tensor(25386.8594)
tensor(3.0605e-13) tensor(6.3349e-13) tensor(25386.8594)
tensor(1.4694e-13) tensor(7.4785e-13

In [75]:
# Lemma 1
all(torch.linalg.svd(Sigma_tilde - W_hat.T @ W_hat).S > 0)

True

In [76]:
# Lemma 2
left = Sigma_tilde @ Sigma_hat_inv
right = (Sigma_tilde - W_hat.T @ W_hat) @ torch.linalg.inv(Phi_hat)
print(left - right)
print(torch.mean((left - right)**2))

tensor([[ 5.9605e-07, -1.1727e-07, -6.0203e-07, -1.7323e-07, -2.3469e-07,  2.7870e-07,  3.5306e-06,  5.1465e-06, -9.5740e-07,  1.2666e-07,  1.0803e-07,  1.0058e-07],
        [-1.0501e-06,  5.9605e-08,  9.2833e-07,  2.7381e-07,  2.2538e-07, -2.8126e-07, -8.9258e-06, -1.1533e-05,  1.8626e-06, -1.3411e-07, -1.1642e-07, -3.9488e-07],
        [-1.2577e-06,  2.8051e-07,  1.4901e-06,  3.5856e-07,  3.7858e-07, -4.4960e-07, -8.6324e-06, -1.1725e-05,  2.0321e-06, -1.8626e-07, -1.8196e-07, -3.4599e-07],
        [-2.8312e-07,  6.7987e-08,  2.8685e-07,  2.3842e-07,  1.0072e-07, -1.2166e-07, -1.9614e-06, -2.6803e-06,  4.6846e-07, -4.6566e-08, -4.4703e-08, -7.2177e-08],
        [-1.8775e-06,  5.2899e-07,  1.8589e-06,  4.4809e-07,  4.1723e-07, -5.3565e-07, -1.3143e-05, -1.7524e-05,  3.0249e-06, -3.0547e-07, -2.8126e-07, -5.4762e-07],
        [ 1.2666e-07, -3.7253e-08, -1.5087e-07, -1.1456e-07, -1.1699e-07, -1.1921e-07,  8.1956e-07,  1.2890e-06, -2.0117e-07,  3.3528e-08,  2.2352e-08,  3.1665e-08],
    

In [77]:
# Lemma 3
left = Phi_hat @ (Sigma_hat_inv - Sigma_hat_inv @ Sigma_tilde @ Sigma_hat_inv) @ Phi_hat
right = Phi_hat - (Sigma_tilde - W_hat.T @ W_hat)
print(left - right)
print(torch.mean((left - right)**2))

tensor([[-9.8837e-08,  4.4598e-07,  4.1443e-07,  1.0245e-07,  8.3447e-07, -5.2154e-08, -1.1716e-06, -5.0734e-07,  4.1351e-07, -4.4703e-08, -3.6880e-07, -1.7043e-07],
        [ 3.7422e-07, -5.1317e-07, -7.3437e-07, -1.3225e-07, -8.9407e-07,  1.3411e-07,  3.1106e-06,  1.0580e-06, -6.0722e-07,  1.0431e-07,  3.2037e-07,  7.4506e-07],
        [ 3.7027e-07, -8.4911e-07, -9.3627e-07, -1.9744e-07, -1.4231e-06,  1.0477e-07,  2.9616e-06,  1.1092e-06, -7.3947e-07,  1.0617e-07,  6.4122e-07,  6.2771e-07],
        [ 9.1270e-08, -1.9558e-07, -2.1607e-07,  1.0148e-08, -2.5128e-07,  2.3868e-08,  6.8918e-07,  2.4913e-07, -1.7975e-07,  2.5146e-08,  1.5646e-07,  1.4529e-07],
        [ 5.7369e-07, -1.4305e-06, -1.3579e-06, -4.1479e-07, -3.6030e-06,  9.3376e-07,  4.3660e-06,  1.6452e-06, -1.2815e-06,  1.6019e-07,  1.0058e-06,  9.0525e-07],
        [-4.4703e-08,  7.4506e-08,  1.2480e-07, -5.2384e-08, -2.4059e-07,  5.0116e-07, -3.8184e-07, -1.1548e-07, -1.1176e-08, -2.9802e-08, -5.9605e-08, -1.3411e-07],
    

In [78]:
# Lemma 4
print(all(torch.linalg.svd(Phi_hat[psum[0]:psum[1], psum[0]:psum[1]]).S > 0))
print(all(torch.linalg.svd(Phi_hat[psum[1]:psum[2], psum[1]:psum[2]]).S > 0))
print(all(torch.linalg.svd(Phi_hat[psum[2]:psum[3], psum[2]:psum[3]]).S > 0))

True
True
True


In [98]:
# Lemma 5
mid = Sigma_tilde - W_hat.T @ W_hat

for i_l, i_r in zip(psum[:-1], psum[1:]):                                                                                                                   
    mid[i_l:i_r, i_r:] = 0                                                                                                                             
    mid[i_r:, i_l:i_r] = 0 

right = (Sigma_tilde - W_hat.T @ W_hat) @ torch.linalg.inv(mid) @ W_hat.T
print(W_hat.T - right)
print(torch.mean((W_hat.T - right)**2))

tensor([[ 6.0797e-06],
        [-4.0531e-06],
        [-7.3910e-06],
        [-4.6194e-07],
        [ 6.0797e-06],
        [-3.7551e-06],
        [ 2.5146e-07],
        [ 1.2368e-06],
        [ 1.5367e-08],
        [-5.1409e-07],
        [-3.5763e-06],
        [ 3.7253e-08],
        [ 3.0398e-06],
        [ 4.8280e-06],
        [-1.4901e-06]])
tensor(1.3914e-11)


In [99]:
# Lemma 6 as written
W1_tilde = W1_hat @ S11_inv2
W2_tilde = W2_hat @ S22_inv2
W3_tilde = W3_hat @ S33_inv2

B1, S1, A1T = torch.linalg.svd(W1_tilde, full_matrices=False)
B2, S2, A2T = torch.linalg.svd(W2_tilde, full_matrices=False)
B3, S3, A3T = torch.linalg.svd(W3_tilde, full_matrices=False)
A1 = A1T.T
A2 = A2T.T
A3 = A3T.T

# print(C12_tilde @ A2 - (A1*S1) @ B1.T @ (B2*S2))
# print(torch.mean((C12_tilde @ A2 - (A1*S1) @ B1.T @ (B2*S2))**2))

# print(C12_tilde.T @ C12_tilde @ A2 - (A2*S2) @ B2.T @ (B1*(S1**2)) @ B1.T @ (B2*S2))
# print(torch.mean((C12_tilde.T @ C12_tilde @ A2 - (A2*S2) @ B2.T @ (B1*(S1**2)) @ B1.T @ (B2*S2))**2))

# A potential modification
# This is TRUE?
print(A1.T @ C12_tilde @ A2 - (B1*S1).T @ (B2*S2))
print(A1.T @ C13_tilde @ A3 - (B1*S1).T @ (B3*S3))
print(A2.T @ C23_tilde @ A3 - (B2*S2).T @ (B3*S3))

# This is False?
print(A2.T @ C12_tilde.T @ C12_tilde @ A2 - (B2@S2).T @ (B1*(S1**2)) @ B1.T @ (B2*S2))

tensor([[0.0016]])
tensor([[0.0100]])
tensor([[0.0002]])
tensor([[-0.0020]])


In [102]:
B1

tensor([[-1.]])

In [85]:
(B1*S1).T @ (B2*S2)

tensor([[-0.3958]])