In [259]:
import torch
import numpy as np
import pandas as pd
from MPCCA.py import micca_model
from MPCCA.py import simulate_model_gaussian
torch.set_printoptions(linewidth = 200)

In [150]:
def sqrtm(X):
    L, V = torch.linalg.eigh(X)
    print(L)
    return(V * torch.sqrt(abs(L)))

In [3]:
n = 1000
p = [3, 3, 3]
psum = np.concatenate([[0], np.cumsum(p, 0)]) 
d = 2
k = None

In [185]:
model = simulate_model_gaussian.generate_model(p, k, d)
data = model.simulate(n)
Y_all = torch.cat(data.Y, dim = 1)
Y_all = Y_all - torch.mean(Y_all, 0)
Sigma_tilde = (Y_all.T @ Y_all) / (n - 1)
W = torch.cat(model.W, dim = 1)
Phi = torch.block_diag(*model.Phi)

In [186]:
Y = [Y_m - torch.mean(Y_m, 0) for Y_m in data.Y]
S11 = (Y[0].T @ Y[0])/(n-1)
S22 = (Y[1].T @ Y[1])/(n-1)
S33 = (Y[2].T @ Y[2])/(n-1)
S12 = (Y[0].T @ Y[1])/(n-1)
S13 = (Y[0].T @ Y[2])/(n-1)
S23 = (Y[1].T @ Y[2])/(n-1)
S11_inv2 = sqrtm(torch.linalg.pinv(S11, rcond = 1e-06))
S22_inv2 = sqrtm(torch.linalg.pinv(S22, rcond = 1e-06))
S33_inv2 = sqrtm(torch.linalg.pinv(S33, rcond = 1e-06))
C12_tilde = S11_inv2 @ S12 @ S22_inv2
C13_tilde = S11_inv2 @ S13 @ S33_inv2

tensor([0.1593, 0.3018, 0.5869])
tensor([0.1654, 0.4159, 0.8997])
tensor([0.1507, 0.4626, 1.0055])


In [223]:
# TODO(brielin): do a simulation comparing the tiny details
# TODO(brielin): is Phi_init REALLY not psd using this?
def init_ppca(Y, d, sqrt_sets = True, orth_cca = False, ppca = False):
  # Y list of tensors
  # d int
  # return W, Phi
  n_sets = len(Y)
  n = Y[0].shape[0]
  ps = [y.shape[1] for y in Y]
  p = sum(ps)
  Sigmas = [y.T @ y / n for y in Y]
  Y_svds = [torch.linalg.svd(y, full_matrices=False) for y in Y]
  Us, Ss, Vts = list(map(list, zip(*Y_svds)))
  U_all = torch.cat(Us, dim = 1)
  Sig_ppca = U_all.T @ U_all
  L, V = torch.linalg.eigh(Sig_ppca)
  L = torch.abs(torch.flip(L, dims = [0]))
  if sqrt_sets:
    V = np.sqrt(n_sets) * torch.flip(V, dims=[1])
  else:
    V = torch.flip(V, dims=[1])
  psum = np.concatenate([[0], np.cumsum(ps, 0)])
  V_splits = [V[i:j, 0:d] for i, j in zip(psum[:-1], psum[1:])]
  if ppca:
    sigma2_ml = torch.sum(L[d:])/(p - d)
  else:
    sigma2_ml = 0
  Vds = [vd * torch.sqrt(L[0:d] - sigma2_ml) for vd in V_splits]
  if orth_cca:
    Pmats, _ = micca_model._make_P_mats(Vds)
    Vds = [torch.linalg.eigh(mat).eigenvectors[:, (mat.shape[0]-d):] for mat in Pmats]
  W_init = [(np.sqrt(n) * vt.T / s) @ vd for vt, s, vd in zip(Vts, Ss, Vds)]
  Phi_init = [Sigma - W @ W.T for Sigma, W in zip(Sigmas, W_init)]
  return torch.cat(W_init, 0), torch.block_diag(*Phi_init)

In [288]:
def fit_sequence(Y, d, niter, W0, Phi0):
  p = [ds.shape[1] for ds in Y]
  Y = torch.cat(Y, 1).float()
  n = Y.shape[0]
  Sigma_tilde = Y.T @ Y / (n -1)
  W_seq = [W0]
  Phi_seq = [Phi0]
  cov_delta_seq = [torch.mean((W0.T @ W0 + Phi0 - Sigma_tilde)**2)]
  l_seq = [micca_model.loglik(W0.T @ W0 + Phi0, Sigma_tilde, n)]
  for i in range(niter):
    W1, Phi1 = micca_model.EM_step_stable(W0, Phi0, Y, Sigma_tilde, p)
    l1 = micca_model.loglik(W1.T @ W1 + Phi1, Sigma_tilde, n)
    cd = torch.mean((W1.T @ W1 + Phi1 - Sigma_tilde)**2)
    l_seq = l_seq + [l1]
    cov_delta_seq = cov_delta_seq + [cd]
    W0 = W1
    Phi0 = Phi1
  return l_seq, cov_delta_seq

def sim_one(p, k, d, n, niter = 100):
  model = simulate_model_gaussian.generate_model(p, k, d)
  data = model.simulate(n)
  Y_all = torch.cat(data.Y, dim = 1)
  Y_all = Y_all - torch.mean(Y_all, 0)
  Sigma_tilde = (Y_all.T @ Y_all) / (n - 1)
  W = torch.cat(model.W, dim = 1)
  Phi = torch.block_diag(*model.Phi)

  W0, Phi0 = init_ppca(data.Y, d, True, False, False)
  l_seq, cd_seq = fit_sequence(data.Y, d, niter, W0=W0.T, Phi0=Phi0.T)
  res_df_TFF = pd.DataFrame({'method': ['TFF']*(niter+1), 'it': range(niter+1), 'l': l_seq, 'cd': cd_seq})

  std_normal = torch.distributions.Normal(0, 1)
  W0 = std_normal.sample([d, sum(p)])
  Phi0 = torch.eye(sum(p))
  l_seq, cd_seq = fit_sequence(data.Y, d, niter, W0=W0, Phi0=Phi0)
  res_df_rand = pd.DataFrame({'method': ['rand']*(niter+1), 'it': range(niter+1), 'l': l_seq, 'cd': cd_seq})

  res_df = pd.concat([res_df_TFF, res_df_rand])
  return(res_df)


def sim_many(p, k, d, n, niter = 100, nsims = 10):
  all_res = []
  for i in range(nsims):
    sim_res = sim_one(p, k, d, n, niter)
    sim_res['sim'] = i
    all_res = all_res + [sim_res]
  return(pd.concat(all_res))

In [289]:
sim_many([3, 3, 3], None, 2, 1000, niter=10, nsims=2)

11
11


Unnamed: 0,method,it,l,cd,sim
0,TFF,0,tensor(18280.4531),tensor(0.9509),0
1,TFF,1,tensor(18126.9883),tensor(0.6011),0
2,TFF,2,tensor(17993.2871),tensor(0.3283),0
3,TFF,3,tensor(17912.0977),tensor(0.1854),0
4,TFF,4,tensor(17866.8320),tensor(0.1090),0
5,TFF,5,tensor(17841.3789),tensor(0.0659),0
6,TFF,6,tensor(17827.7344),tensor(0.0409),0
7,TFF,7,tensor(17820.7754),tensor(0.0260),0
8,TFF,8,tensor(17817.2363),tensor(0.0168),0
9,TFF,9,tensor(17815.3496),tensor(0.0111),0


In [242]:
W_curr = W_init.T
Phi_curr = Phi_init
for i in range(50):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1 == 0): 
        print(i, delta_W, delta_Phi, micca_model.loglik(W_curr.T @ W_curr + Phi_curr, Sigma_tilde, n))
W_hat = W_curr
Phi_hat = Phi_curr

0 tensor(4.8433) tensor(0.5829) tensor(16818.5996)
1 tensor(0.7213) tensor(3.9082) tensor(16466.1191)
2 tensor(0.2684) tensor(3.4604) tensor(16177.3623)
3 tensor(0.0862) tensor(1.2911) tensor(16016.0254)
4 tensor(0.0301) tensor(0.3145) tensor(15953.2148)
5 tensor(0.0129) tensor(0.0774) tensor(15930.6514)
6 tensor(0.0065) tensor(0.0249) tensor(15920.9766)
7 tensor(0.0035) tensor(0.0106) tensor(15915.8799)
8 tensor(0.0021) tensor(0.0053) tensor(15912.8262)
9 tensor(0.0013) tensor(0.0028) tensor(15910.8584)
10 tensor(0.0008) tensor(0.0016) tensor(15909.5215)
11 tensor(0.0005) tensor(0.0009) tensor(15908.5850)
12 tensor(0.0004) tensor(0.0006) tensor(15907.9072)
13 tensor(0.0003) tensor(0.0004) tensor(15907.4111)
14 tensor(0.0002) tensor(0.0002) tensor(15907.0420)
15 tensor(0.0001) tensor(0.0002) tensor(15906.7607)
16 tensor(9.9971e-05) tensor(0.0001) tensor(15906.5488)
17 tensor(7.4115e-05) tensor(8.6282e-05) tensor(15906.3857)
18 tensor(5.5320e-05) tensor(6.3969e-05) tensor(15906.2578)
19

In [234]:
std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([d, sum(p)])
Phi_curr = torch.eye(sum(p))

print(micca_model.loglik(Sigma_tilde, W_curr.T @ W_curr + Phi_curr, n))

for i in range(50):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1 == 0): 
        print(i, delta_W, delta_Phi, micca_model.loglik(W_curr.T @ W_curr + Phi_curr, Sigma_tilde, n))
W_hat = W_curr
Phi_hat = Phi_curr

tensor(18212.1719)
0 tensor(3.6747) tensor(24.4191) tensor(16615.0605)
1 tensor(0.3235) tensor(1.6082) tensor(16318.5830)
2 tensor(0.1831) tensor(0.9433) tensor(16127.7695)
3 tensor(0.0927) tensor(0.4386) tensor(16025.1221)
4 tensor(0.0400) tensor(0.1851) tensor(15977.3301)
5 tensor(0.0172) tensor(0.0833) tensor(15953.7529)
6 tensor(0.0086) tensor(0.0456) tensor(15940.0674)
7 tensor(0.0052) tensor(0.0302) tensor(15931.1562)
8 tensor(0.0036) tensor(0.0220) tensor(15925.0459)
9 tensor(0.0026) tensor(0.0165) tensor(15920.7588)
10 tensor(0.0019) tensor(0.0125) tensor(15917.7021)
11 tensor(0.0014) tensor(0.0095) tensor(15915.4775)
12 tensor(0.0010) tensor(0.0072) tensor(15913.8203)
13 tensor(0.0008) tensor(0.0055) tensor(15912.5498)
14 tensor(0.0006) tensor(0.0043) tensor(15911.5391)
15 tensor(0.0005) tensor(0.0035) tensor(15910.7168)
16 tensor(0.0004) tensor(0.0029) tensor(15910.0225)
17 tensor(0.0003) tensor(0.0024) tensor(15909.4277)
18 tensor(0.0003) tensor(0.0021) tensor(15908.9072)
19

In [None]:
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]

W1_tilde = W1 @ S11_inv2
W2_tilde = W2 @ S22_inv2
W3_tilde = W3 @ S33_inv2

# The original conjecture in terms of SVD of W_tilde
B1, S1, A1T = torch.linalg.svd(W1_tilde, full_matrices=False)
B2, S2, A2T = torch.linalg.svd(W2_tilde, full_matrices=False)
B3, S3, A3T = torch.linalg.svd(W3_tilde, full_matrices=False)
F1 = S11_inv2 @ A1T.T
F2 = S22_inv2 @ A2T.T
F3 = S33_inv2 @ A3T.T

model_cov_1 = torch.cat([torch.cat([S11, W1.T @ W2, W1.T @ W3], axis = 1),
                         torch.cat([W2.T @ W1, S22, W2.T @ W3], axis = 1),
                         torch.cat([W3.T @ W1, W3.T @ W2, S33], axis = 1)], axis = 0)
model_cov_2 = torch.cat([torch.cat([torch.eye(p[0]), W1_tilde.T @ W2_tilde, W1_tilde.T @ W3_tilde], axis = 1),
                         torch.cat([W2_tilde.T @ W1_tilde, torch.eye(p[1]), W2_tilde.T @ W3_tilde], axis = 1),
                         torch.cat([W3_tilde.T @ W1_tilde, W3_tilde.T @ W2_tilde, torch.eye(p[2])], axis = 1)], axis = 0)
gen_var = torch.cat([torch.cat([F1.T @ S11 @ F1, F1.T @ S12 @ F2, F1.T @ S13 @ F3], axis = 1),
                     torch.cat([(F1.T @ S12 @ F2).T, F2.T @ S22 @ F2, F2.T @ S23 @ F3], axis = 1),
                     torch.cat([(F1.T @ S13 @ F3).T, (F2.T @ S23 @ F3).T, F3.T @ S33 @ F3], axis = 1)], axis = 0)
gen_var_2 = torch.cat([torch.cat([torch.eye(p[0]), A1T.T @ F1.T @ S12 @ F2 @ A2T, A1T.T @ F1.T @ S13 @ F3 @ A3T], axis = 1),
                       torch.cat([(A1T.T @ F1.T @ S12 @ F2 @ A2T).T, torch.eye(p[1]), A2T.T @ F2.T @ S23 @ F3 @ A3T], axis = 1),
                       torch.cat([(A1T.T @ F1.T @ S13 @ F3 @ A3T).T, (A2T.T @ F2.T @ S23 @ F3 @ A3T).T, torch.eye(p[1])], axis = 1)], axis = 0)

D_half = torch.block_diag(sqrtm(S11), sqrtm(S22), sqrtm(S33))

Sigma_alt = D_half @ gen_var_2 @ D_half
print(micca_model.loglik(model_cov_1, Sigma_tilde, n))
print(micca_model.loglik(Sigma_alt, Sigma_tilde, n))
# print(micca_model.loglik(W_ml.T @ W_ml + Phi_ml, Sigma_tilde, n))

# a1 = (W1_tilde.T/torch.sqrt(torch.sum(W1_tilde**2, 1)))
# a2 = (W2_tilde.T/torch.sqrt(torch.sum(W2_tilde**2, 1)))
# a3 = (W3_tilde.T/torch.sqrt(torch.sum(W3_tilde**2, 1)))
# f1 = S11_inv2 @ a1
# f2 = S22_inv2 @ a2
# f3 = S33_inv2 @ a3

# Phi_test = torch.cat([torch.cat([f1.T @ S11 @ f1, f1.T @ S12 @ f2, f1.T @ S13 @ f3], axis = 1),
#                  torch.cat([(f1.T @ S12 @ f2).T, f2.T @ S22 @ f2, f2.T @ S23 @ f3], axis = 1),
#                  torch.cat([(f1.T @ S13 @ f3).T, (f2.T @ S23 @ f3).T, f3.T @ S33 @ f3], axis = 1)], axis = 0)
# print(torch.logdet(Phi_test))

# f1_1 = F1[:, 0:1]
# f2_1 = F2[:, 0:1]
# f3_1 = F3[:, 0:1]
# f1_2 = F1[:, 1:2]
# f2_2 = F2[:, 1:2]
# f3_2 = F3[:, 0:1]

# Phi_1 = torch.tensor([[f1_1.T @ S11 @ f1_1, f1_1.T @ S12 @ f2_1, f1_1.T @ S13 @ f3_1],
#                       [f1_1.T @ S12 @ f2_1, f2_1.T @ S22 @ f2_1, f2_1.T @ S23 @ f3_1],
#                       [f1_1.T @ S13 @ f3_1, f2_1.T @ S23 @ f3_1, f3_1.T @ S33 @ f3_1]])
# Phi_2 = torch.tensor([[f1_2.T @ S11 @ f1_2, f1_2.T @ S12 @ f2_2, f1_2.T @ S13 @ f3_2],
#                       [f1_2.T @ S12 @ f2_2, f2_2.T @ S22 @ f2_2, f2_2.T @ S23 @ f3_2],
#                       [f1_2.T @ S13 @ f3_2, f2_2.T @ S23 @ f3_2, f3_2.T @ S33 @ f3_2]])
# print(torch.logdet(Phi_1) + torch.logdet(Phi_2))

tensor(16560.0234)
tensor(16559.8652)


In [None]:
X12 = F1.T @ S12 @ F2
X13 = F1.T @ S13 @ F3
X23 = F2.T @ S23 @ F3

print(X13 @ torch.linalg.inv(X23) @ X12.T)
print(S13 @ torch.linalg.inv(S23) @ S12.T)

tensor([[ 0.7558,  0.0086],
        [-0.0226,  0.0446]])
tensor([[ 0.6551,  2.1239, -0.5474],
        [ 1.6988,  6.5813, -2.0663],
        [-0.3463, -2.2200,  1.1871]])


In [None]:
W1_tilde.T @ W2_tilde

tensor([[-0.1758, -0.0060,  0.0444],
        [-0.6352, -0.0407,  0.0311],
        [ 0.0767,  0.0270,  0.1450]])

In [None]:
A1T.T @ F1.T @ S12 @ F2 @ A2T

tensor([[-0.1768, -0.0059,  0.0453],
        [-0.6349, -0.0408,  0.0308],
        [ 0.0723,  0.0273,  0.1490]])

In [None]:
X12 = (B1*S1).T @ (B2*S2)
X13 = (B1*S1).T @ (B3*S3)
X23 = (B2*S2).T @ (B3*S3)
print(X13 @ torch.linalg.inv(X23) @ X12.T)

tensor([[ 7.6278e-01, -3.7253e-08],
        [-2.9802e-08,  3.8986e-02]])


In [None]:
(C12_tilde - W1_tilde.T @ W2_tilde) @ W2_tilde.T @ torch.linalg.inv(1 - W2_tilde @ W2_tilde.T) + (C13_tilde - W1_tilde.T @ W3_tilde) @ W3_tilde.T @ torch.linalg.inv(1 - W3_tilde @ W3_tilde.T) 

tensor([[-2.1039e-02,  1.9578e-03],
        [ 5.2834e-03,  4.2273e-06],
        [-5.4943e-03,  7.0994e-03]])

In [None]:
# Fit 1-D once
std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(10000):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p, rcond=None)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1000 == 0): print(delta_W, delta_Phi)
W_hat = W_curr
Phi_hat = Phi_curr
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]

# Project out learned space and fit again
W1_tilde = W1 @ S11_inv2
W2_tilde = W2 @ S22_inv2
W3_tilde = W3 @ S33_inv2
a1 = (W1_tilde / torch.sqrt(torch.sum(W1_tilde**2))).T
a2 = (W2_tilde / torch.sqrt(torch.sum(W2_tilde**2))).T
a3 = (W3_tilde / torch.sqrt(torch.sum(W3_tilde**2))).T
f1_1 = S11_inv2 @ a1
f2_1 = S22_inv2 @ a2
f3_1 = S33_inv2 @ a3
P1 = torch.eye(p[0]) - f1_1 @ f1_1.T @ S11
P2 = torch.eye(p[1]) - f2_1 @ f2_1.T @ S22
P3 = torch.eye(p[2]) - f3_1 @ f3_1.T @ S33
Y_mod = [Y[0] @ P1, Y[1] @ P2, Y[2] @ P3]
Y_mod = [Y - torch.mean(Y, 0) for Y in Y_mod]
Y_mod_all = torch.cat(Y_mod, dim = 1)
S11_mod = (Y_mod[0].T @ Y_mod[0])/(n-1)
S22_mod = (Y_mod[1].T @ Y_mod[1])/(n-1)
S33_mod = (Y_mod[2].T @ Y_mod[2])/(n-1)
S12_mod = (Y_mod[0].T @ Y_mod[1])/(n-1)
S13_mod = (Y_mod[0].T @ Y_mod[2])/(n-1)
S23_mod = (Y_mod[1].T @ Y_mod[2])/(n-1)
S11_inv2_mod = sqrtm(torch.linalg.pinv(S11_mod, rcond = 1e-06))
S22_inv2_mod = sqrtm(torch.linalg.pinv(S22_mod, rcond = 1e-06))
S33_inv2_mod = sqrtm(torch.linalg.pinv(S33_mod, rcond = 1e-06))
Sigma_tilde_mod = (Y_mod_all.T @ Y_mod_all) / (n - 1)

std_normal = torch.distributions.Normal(0, 1)
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(10000):
    W_next, Phi_next = micca_model.EM_step_stable(W_curr, Phi_curr, Y_mod_all, Sigma_tilde, p = p, rcond = None)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%1000 == 0): print(delta_W, delta_Phi)
W_hat = W_curr
Phi_hat = Phi_curr
W1 = W_hat[:, psum[0]:psum[1]]
W2 = W_hat[:, psum[1]:psum[2]]
W3 = W_hat[:, psum[2]:psum[3]]
W1_tilde = W1 @ S11_inv2_mod
W2_tilde = W2 @ S22_inv2_mod
W3_tilde = W3 @ S33_inv2_mod
a1 = (W1_tilde / torch.sqrt(torch.sum(W1_tilde**2))).T
a2 = (W2_tilde / torch.sqrt(torch.sum(W2_tilde**2))).T
a3 = (W3_tilde / torch.sqrt(torch.sum(W3_tilde**2))).T
f1_2 = S11_inv2_mod @ a1
f2_2 = S22_inv2_mod @ a2
f3_2 = S33_inv2_mod @ a3

tensor(5.6074) tensor(198.7069)
tensor(2.5940e-13) tensor(2.6059e-12)
tensor(4.0307e-13) tensor(3.2738e-12)
tensor(1.4144e-13) tensor(2.1565e-12)
tensor(2.6096e-13) tensor(3.6895e-12)
tensor(1.7221e-13) tensor(3.4053e-12)
tensor(2.7124e-13) tensor(2.6361e-12)
tensor(5.7988e-13) tensor(3.0287e-12)
tensor(8.5998e-13) tensor(6.4349e-12)
tensor(1.6333e-13) tensor(1.0418e-12)
tensor(1.8658) tensor(190.4414)
tensor(1.0515e-13) tensor(5.3042e-12)
tensor(7.8618e-14) tensor(1.2399e-12)
tensor(1.1841e-13) tensor(1.7657e-12)
tensor(3.3973e-13) tensor(3.0784e-12)
tensor(1.2258e-13) tensor(2.1245e-12)
tensor(1.9851e-13) tensor(1.7568e-12)
tensor(7.8143e-13) tensor(1.6378e-12)
tensor(1.2834e-13) tensor(3.3129e-12)
tensor(1.0626e-13) tensor(7.1765e-13)


In [None]:
Phi_1 = torch.tensor([[f1_1.T @ S11 @ f1_1, f1_1.T @ S12 @ f2_1, f1_1.T @ S13 @ f3_1],
                      [f1_1.T @ S12 @ f2_1, f2_1.T @ S22 @ f2_1, f2_1.T @ S23 @ f3_1],
                      [f1_1.T @ S13 @ f3_1, f2_1.T @ S23 @ f3_1, f3_1.T @ S33 @ f3_1]])
Phi_2 = torch.tensor([[f1_2.T @ S11_mod @ f1_2, f1_2.T @ S12_mod @ f2_2, f1_2.T @ S13_mod @ f3_2],
                      [f1_2.T @ S12_mod @ f2_2, f2_2.T @ S22_mod @ f2_2, f2_2.T @ S23_mod @ f3_2],
                      [f1_2.T @ S13_mod @ f3_2, f2_2.T @ S23_mod @ f3_2, f3_2.T @ S33_mod @ f3_2]])
print(torch.logdet(Phi_1) + torch.logdet(Phi_2))

tensor(-2.6776)


In [None]:
print(f1_2.T @ S11 @ f1_2)
print(f2_2.T @ S22 @ f2_2)
print(f3_2.T @ S33 @ f3_2)

tensor([[1.0003]])
tensor([[1.0276]])
tensor([[1.0462]])


In [None]:
print(f1_2.T @ S12 @ f2_2)
print(f1_2.T @ S13 @ f3_2)
print(f2_2.T @ S23 @ f3_2)

tensor([[0.9428]])
tensor([[1.0859]])
tensor([[1.0297]])


In [None]:
print(F1.T @ S12 @ F2)
print(F1.T @ S13 @ F3)
print(F2.T @ S23 @ F3)

tensor([[-0.3595, -0.2429],
        [ 0.2392, -0.1712]])
tensor([[ 0.0625, -0.5011],
        [ 0.5813,  0.0235]])
tensor([[ 0.3042,  0.2207],
        [-0.2623,  0.1182]])


In [None]:
W1_bar = W_em[0] @ S11_inv2
W2_bar = W_em[1] @ S22_inv2
W3_bar = W_em[2] @ S33_inv2

B1, S1, A1 = torch.linalg.svd(W1_bar, full_matrices = False)
B2, S2, A2 = torch.linalg.svd(W2_bar, full_matrices = False)
B3, S3, A3 = torch.linalg.svd(W3_bar, full_matrices = False)

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
M2 = torch.cat([torch.cat([torch.eye(3), P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, torch.eye(3), P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, torch.eye(3)], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.5966)
tensor(0.5966)


In [None]:
# Nope!
W_bar = torch.cat([W1_bar, W2_bar, W3_bar], 1)
B, S, A = torch.linalg.svd(W_bar, full_matrices = False)

A1 = A[:, psum[0]:psum[1]]
A2 = A[:, psum[1]:psum[2]]
A3 = A[:, psum[2]:psum[3]]

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

M2 = torch.cat([torch.cat([torch.eye(3), P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, torch.eye(3), P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, torch.eye(3)], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.5966)
tensor(0.9919)


In [None]:
M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
A, S, _ = torch.linalg.svd(M1, full_matrices = False)

A1 = A[psum[0]:psum[1], 0:2].T
A2 = A[psum[1]:psum[2], 0:2].T
A3 = A[psum[2]:psum[3], 0:2].T

P1 = A1.T @ A1
P2 = A2.T @ A2
P3 = A3.T @ A3

P1C12P2 = P1 @ C12_tilde @ P2
P1C13P3 = P1 @ C13_tilde @ P3
P2C23P3 = P2 @ C23_tilde @ P3

In [None]:

# Fit 1-d one time
W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(1000):
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step_alt(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): print(delta_W, delta_Phi)

# Project out learned space and fit again
a1 = W_curr[:, psum[0]:psum[1]] / torch.sqrt(torch.sum(W_curr[:, psum[0]:psum[1]]**2))
a2 = W_curr[:, psum[1]:psum[2]] / torch.sqrt(torch.sum(W_curr[:, psum[1]:psum[2]]**2))
f1 = a1 @ S11_inv2
f2 = a2 @ S22_inv2

Y_mod = [data.Y[0] @ (torch.eye(3) - f1.T @ f1), data.Y[1] @ (torch.eye(3) - f2.T @ f2)]
Y_mod_all = torch.cat(Y_mod, dim = 1)
Y_mod_all = Y_mod_all - torch.mean(Y_mod_all, 0)
Sigma_mod_tilde = (Y_mod_all.T @ Y_mod_all) / (n - 1)

W_curr = std_normal.sample([1, sum(p)])
Phi_curr = torch.eye(sum(p))
for i in range(1000):
    W_next, Phi_next = micca_model.EM_step(W_curr, Phi_curr, Sigma_mod_tilde, p = p)
    # W_next, Phi_next = micca_model.EM_step_alt(W_curr, Phi_curr, Y_all, Sigma_tilde, p = p)
    delta_W = torch.sum((W_curr - W_next)**2)
    delta_Phi = torch.sum((Phi_curr - Phi_next)**2)
    W_curr = W_next
    Phi_curr = Phi_next
    if (i%100 == 0): print(delta_W, delta_Phi)
    
# Project out learned space and check residual covariance
a1 = W_curr[:, psum[0]:psum[1]] / torch.sqrt(torch.sum(W_curr[:, psum[0]:psum[1]]**2))
a2 = W_curr[:, psum[1]:psum[2]] / torch.sqrt(torch.sum(W_curr[:, psum[1]:psum[2]]**2))
f1 = a1 @ S11_inv2
f2 = a2 @ S22_inv2

Y_mod_2 = [Y_mod[0] @ (torch.eye(3) - f1.T @ f1), Y_mod[1] @ (torch.eye(3) - f2.T @ f2)]
Y_mod_2_all = torch.cat(Y_mod_2, dim = 1)
Y_mod_2_all = Y_mod_2_all - torch.mean(Y_mod_2_all, 0)
Sigma_mod_2_tilde = (Y_mod_2_all.T @ Y_mod_2_all) / (n - 1)

print(Sigma_mod_2_tilde)

tensor([[-0.2094,  0.5825],
        [ 0.3363,  0.1146],
        [ 0.4213,  0.2283],
        [-0.5645,  0.0326],
        [-0.0425,  0.4586],
        [ 0.1896,  0.0066],
        [ 0.3092,  0.5180],
        [-0.1159, -0.0901],
        [ 0.4474, -0.3280]])

In [None]:
A

tensor([[-0.1435,  0.3387,  0.4332, -0.6028, -0.0526,  0.2022,  0.3386, -0.1160,
          0.3767],
        [ 0.6568,  0.1127,  0.2353,  0.1272,  0.3418, -0.0299,  0.5002, -0.0857,
         -0.3304]])

In [None]:
A1 @ C13_tilde @ A3.T

tensor([[0.2900, 0.0547],
        [0.0324, 0.0781]])

In [None]:
W1_bar @ (C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T

tensor([[-0.0003,  0.0012],
        [-0.0005, -0.0002]])

In [None]:
M1 = torch.cat([torch.cat([torch.eye(3), W1_bar.T @ W2_bar, W1_bar.T @ W3_bar], 1),
                torch.cat([W2_bar.T @ W1_bar, torch.eye(3), W2_bar.T @ W3_bar], 1),
                torch.cat([W3_bar.T @ W1_bar, W3_bar.T @ W2_bar, torch.eye(3)], 1)], 0)
M2 = torch.cat([torch.cat([P1, P1C12P2, P1C13P3], 1),
                torch.cat([P1C12P2.T, P2, P2C23P3], 1),
                torch.cat([P1C13P3.T, P2C23P3.T, P3], 1)], 0)
print(torch.linalg.det(M1))
print(torch.linalg.det(M2))

tensor(0.1111)
tensor(-8.0517e-27)


In [None]:
print(P1C12P2)
print(P1C13P3)
print(P2C23P3)

tensor([[ 0.3023,  0.3841, -0.0876],
        [-0.3758,  0.0207,  0.1281],
        [-0.4620,  0.0773,  0.1595]])
tensor([[ 0.4323, -0.0576, -0.4560],
        [ 0.3167, -0.0927,  0.1888],
        [ 0.4783, -0.1311,  0.1927]])
tensor([[-0.2974,  0.1198, -0.5165],
        [ 0.2443, -0.0379, -0.2027],
        [ 0.1102, -0.0421,  0.1672]])


In [None]:
print(W1_bar.T @ W2_bar)
print(W1_bar.T @ W3_bar)
print(W2_bar.T @ W3_bar)

tensor([[ 0.3040,  0.3764, -0.0885],
        [-0.3769,  0.0271,  0.1287],
        [-0.4633,  0.0849,  0.1602]])
tensor([[ 0.4341, -0.0580, -0.4556],
        [ 0.3155, -0.0925,  0.1900],
        [ 0.4770, -0.1309,  0.1944]])
tensor([[-0.2974,  0.1195, -0.5125],
        [ 0.2405, -0.0352, -0.2208],
        [ 0.1100, -0.0418,  0.1651]])


In [None]:
print(torch.mean(abs(P1C12P2 - W1_bar.T @ W2_bar)))
print(torch.mean(abs(P1C13P3 - W1_bar.T @ W3_bar)))
print(torch.mean(abs(P2C23P3 - W2_bar.T @ W3_bar)))

tensor(0.0031)
tensor(0.0009)
tensor(0.0035)


In [None]:
print(torch.mean(abs((C12_tilde - W1_bar.T @ W2_bar) @  W2_bar.T)))
print(torch.mean(abs((C13_tilde - W1_bar.T @ W3_bar) @  W3_bar.T)))
print(torch.mean(abs((C23_tilde - W2_bar.T @ W3_bar) @  W3_bar.T)))

tensor(0.0028)
tensor(0.0027)
tensor(0.0006)


In [None]:
(C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T @ torch.linalg.inv(torch.eye(1) - W3_bar @ W3_bar.T) 

tensor([[-0.0293],
        [ 0.0047],
        [-0.0501]])

In [None]:
(C13_tilde - W1_bar.T @ W3_bar) @  torch.linalg.inv(torch.eye(3) - W3_bar.T @ W3_bar) @ W3_bar.T

tensor([[-0.0293],
        [ 0.0047],
        [-0.0501]])

In [None]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T

tensor([[ 0.0073],
        [-0.0012],
        [ 0.0126]])

In [None]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T @ torch.linalg.inv(1 - W2_bar @ W2_bar.T)

tensor([[ 0.0293],
        [-0.0047],
        [ 0.0501]])

In [None]:
(C12_tilde - W1_bar.T @ W2_bar) @ W2_bar.T @ torch.linalg.inv(1 - W2_bar @ W2_bar.T) + (C13_tilde - W1_bar.T @ W3_bar) @ W3_bar.T @ torch.linalg.inv(1 - W3_bar @ W3_bar.T) 

tensor([[ 4.5940e-03, -2.8103e-04],
        [ 4.4791e-03,  2.6003e-03],
        [-9.0168e-03, -5.8655e-05]])

In [None]:
[S11 - W_em[0].T @ W_em[0], S22 - W_em[1].T @ W_em[1], S33 - W_em[2].T @ W_em[2]]

[tensor([[4.6497, 0.2221, 0.0900],
         [0.2221, 1.9237, 0.3744],
         [0.0900, 0.3744, 1.7076]]),
 tensor([[ 1.9754,  1.5128, -1.7825],
         [ 1.5128,  2.2075, -2.4684],
         [-1.7825, -2.4684,  3.6764]]),
 tensor([[ 1.5073, -1.3006,  0.8230],
         [-1.3006,  8.2849, -3.3705],
         [ 0.8230, -3.3705,  5.0412]])]

In [None]:
W_curr.T

tensor([[ 0.0027, -0.6311],
        [-0.0130, -1.2672],
        [ 1.4927,  0.2291],
        [ 0.0977, -0.7140],
        [-0.0575, -1.3654],
        [-0.1034,  0.0930],
        [-0.1427, -0.4694],
        [-0.5685, -0.8861],
        [ 1.7051,  1.6871]])

In [None]:
# Lemma 5???
torch.mean(abs((Sigma_tilde - W_curr.T @ W_curr) @ torch.linalg.inv(torch.block_diag(S11 - W_em[0].T @ W_em[0], S22 - W_em[1].T @ W_em[1], S33 - W_em[2].T @ W_em[2])) @ W_curr.T - W_curr.T))

NameError: name 'W_em' is not defined

In [None]:
std_normal = torch.distributions.Normal(0, 1)
f = std_normal.sample([3, 3])
f = f @ torch.linalg.inv(sqrtm(f.T @ f))
f.T @ f
f1 = f[:,0, None]
f2 = f[:, 1, None]
f3 = f[:, 2, None]

f1f1T = f1 @ f1.T
f2f2T = f2 @ f2.T
f3f3T = f3 @ f3.T

mat1 = torch.cat([torch.cat([torch.eye(3), f1f1T @ C12_tilde @ f2f2T, f1f1T @ C13_tilde @ f3f3T], 1),
                  torch.cat([f2f2T @ C12_tilde.T @ f1f1T, torch.eye(3), f2f2T @ C23_tilde @ f3f3T], 1),
                  torch.cat([f3f3T @ C13_tilde.T @ f1f1T, f3f3T @ C23_tilde.T @ f2f2T, torch.eye(3)], 1)], 0)
mat2 = torch.cat([torch.cat([torch.eye(1), f1.T @ C12_tilde @ f2, f1.T @ C13_tilde @ f3], 1),
                  torch.cat([f2.T @ C12_tilde.T @ f1, torch.eye(1), f2.T @ C23_tilde @ f3], 1),
                  torch.cat([f3.T @ C13_tilde.T @ f1, f3.T @ C23_tilde.T @ f2, torch.eye(1)], 1)], 0)
print(torch.linalg.svd(mat1).S)
print(torch.linalg.svd(mat2).S)

tensor([1.2103, 1.0694, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.7203])
tensor([1.2103, 1.0694, 0.7203])


In [None]:
S11_inv = torch.linalg.inv(S11)
PU1d = W_ml[0] @ S11_inv
print(PU1d @ S11 @ PU1d.T)

S22_inv = torch.linalg.inv(S22)
PU2d = W_ml[1] @ S22_inv
print(PU2d @ S22 @ PU2d.T)

S33_inv = torch.linalg.inv(S33)
PU3d = W_ml[1] @ S33_inv
print(PU3d @ S33 @ PU3d.T)

print(PU1d @ S11 @ PU1d.T + PU2d @ S22 @ PU2d.T + PU3d @ S33 @ PU3d.T)

# print(PU1d @ S12 @ PU2d.T)
# print(PU1d.T @ PU2d)  # Should = next
# print( (U1d * p_mat[0:d]) @ U2d.T )

tensor([[0.8872, 0.0354],
        [0.0354, 0.8952]])
tensor([[1.3133, 0.0112],
        [0.0112, 0.9161]])
tensor([[1.3942, 0.4744],
        [0.4744, 0.3550]])
tensor([[3.5946, 0.5210],
        [0.5210, 2.1664]])
