Skip to content

Commit

Permalink
Add SECSI solver as initialization alignment step (#15)
Browse files Browse the repository at this point in the history
* SECSI solving

* Working

* Add SECSI solver

* Offer both fitting approaches

* Possibly done

* Done

* Fixes

* Fix
  • Loading branch information
aarmey committed Jun 26, 2024
1 parent 793f8ad commit 9e02976
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 6 deletions.
103 changes: 103 additions & 0 deletions parafac2/SECSI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from itertools import combinations
import numpy as np
import tensorly as tl
from tensorly.tenalg import khatri_rao, mode_dot
from .jointdiag import jointdiag


def SECSI(X, d: int, maxIter: int = 50, tolerance=1e-12, verbose=True):
    """
    Computes Semi-Algebraic CP factorization using a joint diagonalization algorithm:
    Args:
        X: 3-mode tensor to be factorized
        d: estimated rank of factorization
        maxIter: number of jointdiag iterations to compute per estimate
        tolerance: convergence threshold forwarded to jointdiag
        verbose: forwarded to jointdiag; prints the error after each sweep
    Returns:
        norm_est: list with the summed squared least-squares residual of each estimate
        f_estimates: list of CPTensor estimates, in the same order as norm_est
    Raises:
        Warning: if no pair of modes is at least as long as the rank d
    """

    R = X.ndim  # Number of modes

    # Compute truncated higher order SVD, cut off to d(rank) elements
    core_trunc, factors_trunc = tl.decomposition.tucker(X, rank=[d] * R)

    # Collected estimates and their residuals
    f_estimates = []
    norm_est = []

    # Loop finds all valid Simultaneous Matrix Diagonalizations (SMDs)
    for k_mode, l_mode in combinations(range(R), 2):
        # Only mode pairs (k,l) that are both at least as long as the rank d are usable
        if X.shape[k_mode] < d or X.shape[l_mode] < d:
            continue

        # Find 3rd mode
        mode_not_kl = list(set(range(R)) - {k_mode, l_mode})[0]

        # Compute n-mode product between core and factor matrix for 3rd mode
        Skl = mode_dot(core_trunc, factors_trunc[mode_not_kl], mode_not_kl)

        # Rearranges tensor so that k,l are first 2 modes
        SMD = Skl.transpose(k_mode, l_mode, mode_not_kl)

        # Compute 2-norm condition number for each matrix slice
        conds = np.linalg.cond(SMD.T)

        # The best-conditioned slice is used as the pivot that gets inverted
        optimal_slice = SMD[:, :, np.argmin(conds)]  # Pivot Slice

        # Sets up right and left hand side SMDs by solving every slice
        # against the pivot slice
        SMD_rhs = np.linalg.solve(optimal_slice.T, SMD.T).T
        SMD_lhs = np.linalg.solve(optimal_slice, np.moveaxis(SMD, 2, 0)).T

        for SMD_sel, first_mode, second_mode in zip(
            [SMD_rhs, SMD_lhs], (k_mode, l_mode), (l_mode, k_mode)
        ):
            # Compute joint diagonalization of all matrix slices in SMD
            Diags, Transform = jointdiag(
                SMD_sel,
                MaxIter=maxIter,
                threshold=tolerance,
                verbose=verbose,
            )

            cp_tensor = tl.cp_tensor.CPTensor(
                (None, [np.zeros_like(f) for f in factors_trunc])
            )

            # First factor estimate: truncated factor times transform matrix
            cp_tensor.factors[first_mode] = factors_trunc[first_mode] @ Transform

            # Diagonal of every diagonalized slice gives one row of the 3rd-mode factor
            cp_tensor.factors[mode_not_kl] = np.diagonal(Diags)

            # Khatri-Rao product of the two known factors, ordered by mode index
            # so that its rows line up with the columns of the unfolding
            if first_mode < mode_not_kl:
                krp = khatri_rao(
                    (cp_tensor.factors[first_mode], cp_tensor.factors[mode_not_kl])
                )
            else:
                krp = khatri_rao(
                    (cp_tensor.factors[mode_not_kl], cp_tensor.factors[first_mode])
                )

            # Remaining factor is the least-squares solution of krp @ F.T = unfolding.T
            target = tl.unfold(X, second_mode).T
            solution, resid, _, _ = np.linalg.lstsq(krp, target, rcond=None)
            cp_tensor.factors[second_mode] = solution.T

            if resid.size == 0:
                # lstsq returns an empty residual array when krp is rank-deficient
                # or the system is not over-determined. Summing it would give 0.0
                # and make a degenerate estimate look perfect, so compute the
                # squared residual explicitly instead.
                error = np.linalg.norm(krp @ solution - target) ** 2
            else:
                error = resid.sum()

            f_estimates.append(cp_tensor)
            norm_est.append(error)

    # Stops if previous loop found nothing
    if not norm_est:
        raise Warning("No SMDs found, too many rank deficiencies")

    # TODO: Sort f_ests based on error
    return norm_est, f_estimates
129 changes: 129 additions & 0 deletions parafac2/jointdiag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from itertools import combinations
import numpy as np


def _off_diagonal_energy(T):
    """Squared Frobenius norm of the off-diagonal entries of all slices T[:, :, i]."""
    # Slices live on axes 0 and 1, so that is where the diagonal must be taken.
    return (
        np.linalg.norm(T) ** 2.0
        - np.linalg.norm(np.diagonal(T, axis1=0, axis2=1)) ** 2.0
    )


def jointdiag(
    SMD: np.ndarray,
    MaxIter: int = 50,
    threshold: float = 1e-10,
    verbose=False,
):
    """
    Jointly diagonalizes n matrices, organized in tensor of dimension (k,k,n).
    Each sweep visits every index pair (p, q) and applies a hyperbolic shear
    followed by a Givens rotation as a similarity transform to every slice.
    Args:
        SMD: real-valued tensor of shape (k,k,n); it is copied, not modified
        MaxIter: maximum number of sweeps over all index pairs
        threshold: stop once the off-diagonal error decreases by less than this
            between two sweeps (only checked after the first few sweeps)
        verbose: print the off-diagonal error after every sweep
    Returns:
        X: transformed tensor of shape (k,k,n); slice i equals
            inv(Q_total) @ SMD[:, :, i] @ Q_total
        Q_total: accumulated (k,k) transform matrix
    """

    X = SMD.copy()
    D = X.shape[0]  # Dimension of square matrix slices
    assert X.ndim == 3, "Input must be a 3D tensor"
    assert D == X.shape[1], "All slices must be square"
    assert np.all(np.isreal(X)), "Must be real-valued"

    # Initial error calculation
    e = _off_diagonal_energy(X)

    if verbose:
        print(f"Sweep # 0: e = {e:.3e}")

    # Accumulated transform
    Q_total = np.eye(D)

    for k in range(MaxIter):
        # loop over all pairs of row/column indices
        for p, q in combinations(range(D), 2):
            # Finds matrix slice with greatest difference between diagonal elements p and q
            d_ = X[p, p, :] - X[q, q, :]
            h = np.argmax(np.abs(d_))

            # All indices except p and q
            all_but_pq = [i for i in range(D) if i != p and i != q]

            # Quantities of the selected slice used for the shear parameter
            dh = d_[h]
            Xh = X[:, :, h]
            Kh = np.dot(Xh[p, all_but_pq], Xh[q, all_but_pq]) - np.dot(
                Xh[all_but_pq, p], Xh[all_but_pq, q]
            )
            Gh = (
                np.linalg.norm(Xh[p, all_but_pq]) ** 2
                + np.linalg.norm(Xh[q, all_but_pq]) ** 2
                + np.linalg.norm(Xh[all_but_pq, p]) ** 2
                + np.linalg.norm(Xh[all_but_pq, q]) ** 2
            )
            xih = Xh[p, q] - Xh[q, p]

            # Build shearing parameter out of these quantities.
            # A zero denominator means the (p, q) sub-problem is already
            # degenerate; shearing by 0/0 would fill X with NaN, so skip it.
            shear_denom = 2 * (dh**2 + xih**2) + Gh
            if shear_denom == 0:
                yk = 0.0
            else:
                yk = np.arctanh((Kh - xih * dh) / shear_denom)
            ch, sh = np.cosh(yk), np.sinh(yk)

            # Inverse of Sk on left side
            pvec = X[p, :, :].copy()
            X[p, :, :] = pvec * ch - X[q, :, :] * sh
            X[q, :, :] = -pvec * sh + X[q, :, :] * ch

            # Sk on right side
            pvec = X[:, p, :].copy()
            X[:, p, :] = pvec * ch + X[:, q, :] * sh
            X[:, q, :] = pvec * sh + X[:, q, :] * ch

            # Update Q_total
            pvec = Q_total[:, p].copy()
            Q_total[:, p] = pvec * ch + Q_total[:, q] * sh
            Q_total[:, q] = pvec * sh + Q_total[:, q] * ch

            # Negated sum of the two off-diagonal elements, per slice
            xi_ = -X[q, p, :] - X[p, q, :]

            # Quantities that determine the rotation angle
            Esum = 2 * np.dot(xi_, d_)
            Dsum = np.dot(d_, d_) - np.dot(xi_, xi_)

            if Esum == 0 and Dsum == 0:
                # Nothing to rotate: the angle is undetermined (0/0)
                theta_k = 0.0
            else:
                th1 = np.arctan(Esum / Dsum)
                angle_selection = np.cos(th1) * Dsum + np.sin(th1) * Esum

                # Picks 1 of 2 possible angles
                if angle_selection > 0.0:
                    theta_k = th1 / 4
                elif angle_selection < 0.0:
                    theta_k = (th1 + np.pi) / 4
                else:
                    # No usable angle; skip the rotation rather than
                    # returning a result of a different shape
                    theta_k = 0.0
            ct, st = np.cos(theta_k), np.sin(theta_k)

            # Givens rotation applied on both sides
            pvec = X[p, :, :].copy()
            X[p, :, :] = pvec * ct - X[q, :, :] * st
            X[q, :, :] = pvec * st + X[q, :, :] * ct

            pvec = X[:, p, :].copy()
            X[:, p, :] = pvec * ct - X[:, q, :] * st
            X[:, q, :] = pvec * st + X[:, q, :] * ct

            # Update Q_total
            pvec = Q_total[:, p].copy()
            Q_total[:, p] = pvec * ct - Q_total[:, q] * st
            Q_total[:, q] = pvec * st + Q_total[:, q] * ct

        # Error computation, check if another sweep is needed
        old_e = e
        e = _off_diagonal_energy(X)

        if verbose:
            print(f"Sweep # {k + 1}: e = {e:.3e}")

        # The first sweeps are always run before the decrease is checked
        if old_e - e < threshold and k > 2:
            break

    return X, Q_total
7 changes: 7 additions & 0 deletions parafac2/parafac2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import cupy as cp
from tqdm import tqdm
import tensorly as tl
from .SECSI import SECSI
from tensorly.decomposition import parafac
from sklearn.utils.extmath import randomized_svd
from .utils import (
Expand Down Expand Up @@ -57,6 +58,7 @@ def parafac2_nd(
n_iter_max: int = 200,
tol: float = 1e-6,
random_state: Optional[int] = None,
SECSI_solver=False,
callback: Optional[Callable[[int, float, list, list], None]] = None,
) -> tuple[tuple, float]:
r"""The same interface as regular PARAFAC2."""
Expand Down Expand Up @@ -84,6 +86,11 @@ def parafac2_nd(
err = reconstruction_error(factors, projections, projected_X, norm_tensor)
errs = [err]

if SECSI_solver:
SECSerror, factorOuts = SECSI(projected_X, rank, verbose=False)
factors = factorOuts[np.argmin(SECSerror)].factors

print("")
tq = tqdm(range(n_iter_max), disable=(not verbose))
for iteration in tq:
jump = beta_i + 1.0
Expand Down
45 changes: 45 additions & 0 deletions parafac2/tests/test_SECSI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import numpy as np
import tensorly as tl
from tensorly.random import random_cp
from tlviz.factor_tools import factor_match_score
from ..SECSI import SECSI


def SECSItest(dim, true_rank, est_rank, noise=0.0, verbose=True):
    """
    Exercises SECSI on a synthetic CP tensor.
    A random CP tensor of the requested dimensions and rank is generated,
    Gaussian noise is optionally added, and the result is handed to SECSI.
    Every returned estimate must match the generating factors with a
    factor match score above 0.95.
    Args:
        dim: tuple with desired tensor dimensions
        true_rank: rank used to generate the random CP tensor
        est_rank: rank passed to SECSI
        noise: standard deviation of the gaussian noise added to the tensor
        verbose: print the residual of every estimate
    Returns:
        The smallest residual among all estimates.
    """
    reference_cp = random_cp(dim, true_rank, full=False)

    # Dense tensor plus optional noise
    noisy = tl.cp_to_tensor(reference_cp) + np.random.normal(size=dim, scale=noise)

    errors, candidates = SECSI(noisy, est_rank, 50, verbose=False)

    for candidate in candidates:
        assert factor_match_score(reference_cp, candidate) > 0.95

    if verbose:
        best = np.argmin(errors)
        for idx, err in enumerate(errors):
            if idx == best:
                template = "Best estimate, {0}, has error: {1:.3e}"
            else:
                template = "Estimate #{0} has error: {1:.3e}"
            print(template.format(idx, err))

    return np.min(errors)


def test_SECSI():
    """Runs the SECSI check on a noise-free rank-12 tensor of shape (120, 90, 80)."""
    # The assertions live inside SECSItest; its return value is not needed here.
    SECSItest((120, 90, 80), 12, 12, noise=0.0)
35 changes: 35 additions & 0 deletions parafac2/tests/test_jointdiag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
from ..jointdiag import jointdiag


def SyntheticData(k: int, d: int, random_state=None):
    """
    Generates random diagonal tensor, and random mixing matrix S
    Multiplies every diagonal slice D as S^-1 @ D @ S to build the 'synthetic' tensor
    Sends the synthetic tensor into jointdiag function
    Args:
        k: number of slices
        d: dimension of the square slices
        random_state: optional seed forwarded to np.random.RandomState for
            reproducible data; None keeps the previous unseeded behavior
    Returns:
        diags: true diagonal tensor of shape (d, d, k)
        diag_est: diagonal tensor estimate returned by jointdiag
        mixing: true mixing matrix
        mixing_est: transform matrix returned by jointdiag
        synthetic: the mixed tensor that was passed to jointdiag
    """
    rng = np.random.RandomState(random_state)
    mixing = rng.randn(d, d)
    # The inverse does not depend on the slice, so compute it once
    mixing_inv = np.linalg.inv(mixing)

    diags = np.zeros((d, d, k))
    synthetic = np.zeros((d, d, k))

    for i in range(k):
        temp_diag = np.diag(rng.randn(d))
        diags[:, :, i] = temp_diag
        synthetic[:, :, i] = mixing_inv @ temp_diag @ mixing

    diag_est, mixing_est = jointdiag(synthetic, verbose=False)
    return diags, diag_est, mixing, mixing_est, synthetic


def test_jointdiag():
    """Checks that jointdiag recovers the diagonals of synthetically mixed slices."""
    expected, estimated, *_ = SyntheticData(40, 22)

    # The recovered diagonal entries come back in arbitrary order, so both
    # sets are ordered by the values found in the first slice before comparing.
    order_est = np.argsort(np.diag(estimated[:, :, 0]))
    order_true = np.argsort(np.diag(expected[:, :, 0]))

    estimated_sorted = estimated[order_est, order_est, :]
    expected_sorted = expected[order_true, order_true, :]

    np.testing.assert_allclose(expected_sorted, estimated_sorted, atol=1e-10)
16 changes: 11 additions & 5 deletions parafac2/tests/test_parafac2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,24 @@ def test_init_reprod(sparse: bool):
cp.testing.assert_array_equal(proj1[ii], proj2[ii])


@pytest.mark.parametrize("SECSI_solver", [False, True])
@pytest.mark.parametrize("sparse", [False, True])
def test_parafac2(sparse: bool):
def test_parafac2(sparse: bool, SECSI_solver: bool):
"""Test for equivalence to TensorLy's PARAFAC2."""
X_ann = pf2_to_anndata(X, sparse=sparse)

(w1, f1, p1), e1 = parafac2_nd(X_ann, rank=3, random_state=1)
(w1, f1, p1), e1 = parafac2_nd(
X_ann, rank=3, random_state=1, SECSI_solver=SECSI_solver
)

# Test that the model still matches the data
err = _parafac2_reconstruction_error(X, (w1, f1, p1)) ** 2
np.testing.assert_allclose(1.0 - err / norm_tensor, e1, rtol=1e-5)

# Test reproducibility
(w2, f2, p2), e2 = parafac2_nd(X_ann, rank=3, random_state=1)
(w2, f2, p2), e2 = parafac2_nd(
X_ann, rank=3, random_state=1, SECSI_solver=SECSI_solver
)
# Compare to TensorLy
wT, fT, pT = parafac2(
X,
Expand Down Expand Up @@ -115,16 +120,17 @@ def test_pf2_r2x():
np.testing.assert_allclose(err, errCMF, rtol=1e-8)


@pytest.mark.parametrize("SECSI_solver", [False, True])
@pytest.mark.parametrize("sparse", [False, True])
def test_performance(sparse: bool):
def test_performance(sparse: bool, SECSI_solver: bool):
"""Test for equivalence to TensorLy's PARAFAC2."""
# 5000 by 2000 by 300 is roughly the lupus data
pf2shape = [(5_000, 2_000)] * 60
X = random_parafac2(pf2shape, rank=12, full=True, random_state=2)

X = pf2_to_anndata(X, sparse=sparse)

(w1, f1, p1), e1 = parafac2_nd(X, rank=9)
(w1, f1, p1), e1 = parafac2_nd(X, rank=9, SECSI_solver=SECSI_solver)


def test_total_norm():
Expand Down
Loading

0 comments on commit 9e02976

Please sign in to comment.