-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SECSI solver as initialization alignment step (#15)
* SECSI solving * Working * Add SECSI solver * Offer both fitting approaches * Possibly done * Done * Fixes * Fix
- Loading branch information
Showing
7 changed files
with
331 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from itertools import combinations | ||
import numpy as np | ||
import tensorly as tl | ||
from tensorly.tenalg import khatri_rao, mode_dot | ||
from .jointdiag import jointdiag | ||
|
||
|
||
def SECSI(X, d: int, maxIter: int = 50, tolerance=1e-12, verbose=True):
    """
    Computes Semi-Algebraic CP factorization using a joint diagonalization algorithm.

    For every pair of modes (k, l) whose lengths are both at least ``d``, two
    sets of matrices are built from the truncated Tucker core (a "right-hand"
    and a "left-hand" version relative to a pivot slice), each set is jointly
    diagonalized with ``jointdiag``, and one CP estimate is derived per set.

    Args:
        X: 3-mode tensor to be factorized
        d: estimated rank of factorization
        maxIter: number of jointdiag iterations to compute per estimate
        tolerance: passed to ``jointdiag`` as its ``threshold`` argument
        verbose: passed through to ``jointdiag``

    Returns:
        A pair ``(norm_est, f_estimates)``: ``norm_est`` is a list holding the
        summed least-squares residual of each estimate and ``f_estimates`` is
        the list of corresponding ``CPTensor`` objects, in the same order.
        The estimates are not sorted.

    Raises:
        Warning: if no mode pair qualified, so no estimate was produced.
    """

    R = X.ndim  # Number of modes

    # Compute Tucker decomposition truncated to d (rank) components in every mode
    core_trunc, factors_trunc = tl.decomposition.tucker(X, rank=[d] * len(X.shape))

    # Accumulators for the estimates and their residuals
    f_estimates = []
    norm_est = []

    # Loop finds all valid Simultaneous Matrix Diagonalizations (SMDs)
    for k_mode, l_mode in combinations(range(R), 2):
        # Only use combinations of modes (k,l) longer than or equal to the rank d
        if X.shape[k_mode] < d or X.shape[l_mode] < d:
            continue

        # Find 3rd mode (takes the first remaining mode, so this assumes R == 3)
        mode_not_kl = list(set(range(R)) - {k_mode, l_mode})[0]

        # Compute n-mode product between core and factor matrices for 3rd mode
        Skl = mode_dot(core_trunc, factors_trunc[mode_not_kl], mode_not_kl)

        # Rearranges tensor so that k,l are first 2 modes
        SMD = Skl.transpose(k_mode, l_mode, mode_not_kl)

        # Compute 2 norm condition number for each matrix slice, save values
        conds = np.linalg.cond(SMD.T)

        ###
        # Using computed SMDs, we now generate factor matrices through joint diagonalization

        # Save matrix slice with minimal condition number
        optimal_slice = SMD[:, :, np.argmin(conds)]  # Pivot Slice

        # Sets up left and right hand side SMDs using pivot slice
        # Solves matrix equation optimal * X = n-th slice --> n-th slice / optimal
        # Solves lhs version, optimal * X = n-th slice ^T --> optimal / n-th slice
        SMD_rhs = np.linalg.solve(optimal_slice.T, SMD.T).T
        SMD_lhs = np.linalg.solve(optimal_slice, np.moveaxis(SMD, 2, 0)).T

        # The rhs set estimates mode k directly (mode l by least squares);
        # the lhs set does the reverse.
        for SMD_sel, first_mode, second_mode in zip(
            [SMD_rhs, SMD_lhs], (k_mode, l_mode), (l_mode, k_mode)
        ):
            # Compute joint diagonalization of all matrix slices in SMD
            Diags, Transform = jointdiag(
                SMD_sel,
                MaxIter=maxIter,
                threshold=tolerance,
                verbose=verbose,
            )

            # Placeholder CP tensor; every factor is overwritten below
            cp_tensor = tl.cp_tensor.CPTensor(
                (None, [np.zeros_like(f) for f in factors_trunc])
            )

            # Now compute two estimates of all three factor matrices...
            # First estimate based on factor * transform matrix
            cp_tensor.factors[first_mode] = factors_trunc[first_mode] @ Transform

            # Picks out diagonal of n-th slice of SMD, saves n-th row of krp matrix
            cp_tensor.factors[mode_not_kl] = np.diagonal(Diags)

            # Khatri rao product of other two estimates, lower mode index first
            if first_mode < mode_not_kl:
                krp = khatri_rao(
                    (cp_tensor.factors[first_mode], cp_tensor.factors[mode_not_kl])
                )
            else:
                krp = khatri_rao(
                    (cp_tensor.factors[mode_not_kl], cp_tensor.factors[first_mode])
                )

            # Estimates final factor matrix as the least squares solution of
            # krp * F^T = unfolding^T
            cp_tensor.factors[second_mode], resid, _, _ = np.linalg.lstsq(
                krp, tl.unfold(X, second_mode).T, rcond=None
            )
            cp_tensor.factors[second_mode] = cp_tensor.factors[second_mode].T

            f_estimates.append(cp_tensor)
            norm_est.append(resid.sum())

    # Stops if previous loop found nothing
    if len(norm_est) == 0:
        raise Warning("No SMDs found, too many rank deficiencies")

    # TODO: Sort f_ests based on error
    return norm_est, f_estimates
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
from itertools import combinations | ||
import numpy as np | ||
|
||
|
||
def _off_diagonal_energy(slices):
    """Return the summed squared off-diagonal entries over all slices.

    ``slices`` is laid out as (D, D, n): slice ``i`` is ``slices[:, :, i]``,
    so its diagonal lives on axes 0 and 1.
    """
    diagonals = np.diagonal(slices, axis1=0, axis2=1)
    return np.linalg.norm(slices) ** 2.0 - np.linalg.norm(diagonals) ** 2.0


def jointdiag(
    SMD: np.ndarray,
    MaxIter: int = 50,
    threshold: float = 1e-10,
    verbose=False,
):
    """
    Jointly diagonalizes n matrices, organized in tensor of dimension (k,k,n).

    Each sweep visits every index pair (p, q) and applies a hyperbolic shear
    followed by a Givens rotation to all slices at once, accumulating the
    combined transform.

    Args:
        SMD: real-valued array of shape (k, k, n); slice i is SMD[:, :, i].
            The input is copied and not modified.
        MaxIter: maximum number of sweeps over all (p, q) pairs.
        threshold: sweeping stops once the decrease of the off-diagonal
            energy between two sweeps drops below this value (only checked
            after the fourth sweep).
        verbose: if True, prints the off-diagonal energy after every sweep.

    Returns:
        A pair ``(X, Q_total)``: the transformed (approximately diagonal)
        slices, same shape as ``SMD``, and the accumulated (k, k) transform.
        The pair is returned in every case, including degenerate inputs.
    """

    X = SMD.copy()
    D = X.shape[0]  # Dimension of square matrix slices
    assert X.ndim == 3, "Input must be a 3D tensor"
    assert D == X.shape[1], "All slices must be square"
    assert np.all(np.isreal(X)), "Must be real-valued"

    # Initial error: energy of everything that is not on a slice diagonal
    e = _off_diagonal_energy(X)

    if verbose:
        print(f"Sweep # 0: e = {e:.3e}")

    # Accumulated transform
    Q_total = np.eye(D)

    for k in range(MaxIter):
        # loop over all pairs of rows/columns
        for p, q in combinations(range(D), 2):
            # Finds matrix slice with greatest variability among diagonal elements
            d_ = X[p, p, :] - X[q, q, :]
            h = np.argmax(np.abs(d_))

            # List of indices other than p and q
            all_but_pq = [i for i in range(D) if i != p and i != q]

            # Quantities of the selected slice that define the shear
            dh = d_[h]
            Xh = X[:, :, h]
            Kh = np.dot(Xh[p, all_but_pq], Xh[q, all_but_pq]) - np.dot(
                Xh[all_but_pq, p], Xh[all_but_pq, q]
            )
            Gh = (
                np.linalg.norm(Xh[p, all_but_pq]) ** 2
                + np.linalg.norm(Xh[q, all_but_pq]) ** 2
                + np.linalg.norm(Xh[all_but_pq, p]) ** 2
                + np.linalg.norm(Xh[all_but_pq, q]) ** 2
            )
            xih = Xh[p, q] - Xh[q, p]

            # Build shearing parameter out of these quantities. A zero
            # denominator means the numerator is zero as well (0/0), in
            # which case no shear is needed; dividing would poison X with NaN.
            shear_den = 2 * (dh**2 + xih**2) + Gh
            if shear_den == 0:
                yk = 0.0
            else:
                yk = np.arctanh((Kh - xih * dh) / shear_den)
            ch = np.cosh(yk)
            sh = np.sinh(yk)

            # Inverse of Sk on left side
            pvec = X[p, :, :].copy()
            X[p, :, :] = X[p, :, :] * ch - X[q, :, :] * sh
            X[q, :, :] = -pvec * sh + X[q, :, :] * ch

            # Sk on right side
            pvec = X[:, p, :].copy()
            X[:, p, :] = X[:, p, :] * ch + X[:, q, :] * sh
            X[:, q, :] = pvec * sh + X[:, q, :] * ch

            # Update Q_total
            pvec = Q_total[:, p].copy()
            Q_total[:, p] = Q_total[:, p] * ch + Q_total[:, q] * sh
            Q_total[:, q] = pvec * sh + Q_total[:, q] * ch

            # Defines array of off-diagonal element sums (after the shear)
            # NOTE(review): d_ below is still the pre-shear diagonal
            # difference, as in the original code -- confirm this is intended.
            xi_ = -X[q, p, :] - X[p, q, :]

            # Quantities that define the rotation angle
            Esum = 2 * np.dot(xi_, d_)
            Dsum = np.dot(d_, d_) - np.dot(xi_, xi_)
            with np.errstate(divide="ignore", invalid="ignore"):
                th1 = np.arctan(Esum / Dsum)
            angle_selection = np.cos(th1) * Dsum + np.sin(th1) * Esum

            # Defines 1 of 2 possible angles
            if angle_selection > 0.0:
                theta_k = th1 / 4
            elif angle_selection < 0.0:
                theta_k = (th1 + np.pi) / 4
            else:
                # Angle is undetermined (e.g. Esum == Dsum == 0): leave this
                # pair unrotated and keep going so the return shape stays
                # (X, Q_total).
                if verbose:
                    print("No solution found -- Jointdiag")
                continue

            ct = np.cos(theta_k)
            st = np.sin(theta_k)

            # Givens rotation, this will minimize norm of off-diagonal elements only
            pvec = X[p, :, :].copy()
            X[p, :, :] = X[p, :, :] * ct - X[q, :, :] * st
            X[q, :, :] = pvec * st + X[q, :, :] * ct

            pvec = X[:, p, :].copy()
            X[:, p, :] = X[:, p, :] * ct - X[:, q, :] * st
            X[:, q, :] = pvec * st + X[:, q, :] * ct

            # Update Q_total
            pvec = Q_total[:, p].copy()
            Q_total[:, p] = Q_total[:, p] * ct - Q_total[:, q] * st
            Q_total[:, q] = pvec * st + Q_total[:, q] * ct

        # Error computation, check if another sweep is needed...
        old_e = e
        e = _off_diagonal_energy(X)

        if verbose:
            print(f"Sweep # {k + 1}: e = {e:.3e}")

        # The first sweeps are always performed; the shear step is not
        # orthogonal, so the error is not guaranteed to drop right away.
        if old_e - e < threshold and k > 2:
            break

    return X, Q_total
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import numpy as np | ||
import tensorly as tl | ||
from tensorly.random import random_cp | ||
from tlviz.factor_tools import factor_match_score | ||
from ..SECSI import SECSI | ||
|
||
|
||
def SECSItest(dim, true_rank, est_rank, noise=0.0, verbose=True):
    """
    Exercise SECSI on a random CP tensor and check every returned estimate.

    Draws a random CP tensor of the requested shape and rank, optionally
    perturbs it with Gaussian noise, runs SECSI on the result, and asserts
    that each estimate has a factor match score above 0.95 against the
    generating factors.

    Args:
        dim: tuple with desired tensor dimensions
        true_rank: rank of the randomly generated CP tensor
        est_rank: rank passed to SECSI
        noise: standard deviation of the Gaussian noise added to the tensor
        verbose: if True, prints the residual of every estimate

    Returns:
        The smallest residual among all estimates.
    """
    ground_truth = random_cp(dim, true_rank, full=False)
    dense = tl.cp_to_tensor(ground_truth)

    # Perturb the dense tensor (a no-op when noise == 0.0)
    perturbation = np.random.normal(size=dim, scale=noise)
    dense = dense + perturbation

    norm_est, cp_estimates = SECSI(dense, est_rank, 50, verbose=False)

    for candidate in cp_estimates:
        assert factor_match_score(ground_truth, candidate) > 0.95

    if verbose:
        best_idx = np.argmin(norm_est)
        for idx, err in enumerate(norm_est):
            if idx == best_idx:
                template = "Best estimate, {0}, has error: {1:.3e}"
            else:
                template = "Estimate #{0} has error: {1:.3e}"
            print(template.format(idx, err))

    return np.min(norm_est)
|
||
|
||
def test_SECSI():
    """Run the SECSI check on a noiseless rank-12 tensor of shape (120, 90, 80)."""
    SECSItest(dim=(120, 90, 80), true_rank=12, est_rank=12, noise=0.0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import numpy as np | ||
from ..jointdiag import jointdiag | ||
|
||
|
||
def SyntheticData(k: int, d: int):
    """
    Build a jointly diagonalizable test problem and run jointdiag on it.

    Generates k random d x d diagonal matrices and one random d x d mixing
    matrix, forms every synthetic slice as inv(mixing) @ diagonal @ mixing,
    and passes the stacked slices to jointdiag.

    Returns:
        (diags, diag_est, mixing, mixing_est, synthetic): the true diagonal
        slices, the two outputs of jointdiag, the true mixing matrix, and the
        mixed input slices. The tensors have shape (d, d, k).
    """
    rng = np.random.RandomState()
    mixing = rng.randn(d, d)
    mixing_inv = np.linalg.inv(mixing)

    diags = np.zeros((d, d, k))
    synthetic = np.zeros((d, d, k))

    for slice_idx in range(k):
        slice_diag = np.diag(rng.randn(d))
        diags[:, :, slice_idx] = slice_diag
        synthetic[:, :, slice_idx] = mixing_inv @ slice_diag @ mixing

    diag_est, mixing_est = jointdiag(synthetic, verbose=False)
    return diags, diag_est, mixing, mixing_est, synthetic
|
||
|
||
def test_jointdiag():
    """Check that jointdiag recovers the diagonals of 40 mixed 22 x 22 slices."""
    true_diags, est_diags, *_ = SyntheticData(40, 22)

    # The recovered ordering is arbitrary, so align both sets by sorting on
    # the diagonal of the first slice.
    order_true = np.argsort(np.diag(true_diags[:, :, 0]))
    order_est = np.argsort(np.diag(est_diags[:, :, 0]))

    # Indexing with the same index array twice extracts the (sorted) diagonals
    sorted_true = true_diags[order_true, order_true, :]
    sorted_est = est_diags[order_est, order_est, :]

    np.testing.assert_allclose(sorted_true, sorted_est, atol=1e-10)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.