diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py
index ec17ad6..3bccee5 100644
--- a/src/diffpy/snmf/snmf_class.py
+++ b/src/diffpy/snmf/snmf_class.py
@@ -1,6 +1,7 @@
+import cvxpy as cp
 import numpy as np
 from scipy.optimize import minimize
-from scipy.sparse import coo_matrix, csc_matrix, diags
+from scipy.sparse import coo_matrix, diags
 
 
 class SNMFOptimizer:
@@ -39,13 +40,15 @@ class SNMFOptimizer:
     max_iter : int
         The maximum number of times to update each of stretch, components, and weights before stopping
         the optimization.
+    min_iter : int
+        The minimum number of times to update each of stretch, components, and weights before terminating
+        the optimization due to low/no improvement.
     tol : float
         The convergence threshold. This is the minimum fractional improvement in the
-        objective function to allow without terminating the optimization. Note that
-        a minimum of 20 updates are run before this parameter is checked.
+        objective function to allow without terminating the optimization.
     n_components : int
         The number of components to extract from source_matrix. Must be provided when and only when
-        Y0 is not provided.
+        init_weights is not provided.
     random_state : int
         The seed for the initial guesses at the matrices (stretch, components, and weights) created by
         the decomposition.
@@ -66,6 +69,7 @@ def __init__(
         rho=0,
         eta=0,
         max_iter=500,
+        min_iter=20,
         tol=5e-7,
         n_components=None,
         random_state=None,
@@ -113,12 +117,14 @@ def __init__(
         self.source_matrix = source_matrix
         self.rho = rho
         self.eta = eta
+        self.tol = tol
+        self.max_iter = max_iter
         # Capture matrix dimensions
         self.signal_length, self.n_signals = source_matrix.shape
         self.num_updates = 0
         self._rng = np.random.default_rng(random_state)
 
-        # Enforce exclusive specification of n_components or Y0
+        # Enforce exclusive specification of n_components or init_weights
         if (n_components is None and init_weights is None) or (
             n_components is not None and init_weights is not None
         ):
@@ -157,161 +163,122 @@ def __init__(
         self._spline_smooth_operator = 0.25 * diags(
             [1, -2, 1], offsets=[0, 1, 2], shape=(self.n_signals - 2, self.n_signals)
         )
-        self._spline_smooth_penalty = self._spline_smooth_operator.T @ self._spline_smooth_operator
 
         # Set up residual matrix, objective function, and history
         self.residuals = self.get_residual_matrix()
-        self._objective_history = []
-        self.update_objective()
+        self.objective_function = self.get_objective_function()
+        self.best_objective = self.objective_function
+        self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
         self.objective_difference = None
+        self._objective_history = [self.objective_function]
 
         # Set up tracking variables for update_components()
         self._prev_components = None
-        self.grad_components = np.zeros_like(self.components)  # Gradient of X (zeros for now)
-        self._prev_grad_components = np.zeros_like(self.components)  # Previous gradient of X (zeros for now)
+        self._grad_components = np.zeros_like(self.components)
+        self._prev_grad_components = np.zeros_like(self.components)
 
         regularization_term = 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
         sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
         print(
-            f"Start, Objective function: {self._objective_history[-1]:.5e}"
-            f", Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}"
+            f"Start, Objective function: {self.objective_function:.5e}"
+            f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}"
         )
 
         # Main optimization loop
-        for iter in range(max_iter):
-            self.optimize_loop()
+        for outiter in range(self.max_iter):
+            self.outiter = outiter
+            self.outer_loop()
 
             # Print diagnostics
             regularization_term = (
                 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
             )
             sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
             print(
-                f"Num_updates: {self.num_updates}, "
-                f"Obj fun: {self._objective_history[-1]:.5e}, "
-                f"Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}, "
-                f"Iter: {iter}"
+                f"Obj fun: {self.objective_function:.5e}, "
+                f"Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}, "
+                f"Iter: {self.outiter}"
             )
 
-            # Convergence check: decide when to terminate for small/no improvement
-            if self.objective_difference < self._objective_history[-1] * tol and iter >= 20:
+            # Convergence check: stop if the improvement is small and at least min_iter iterations have passed
+            print("Checking if ", self.objective_difference, " < ", self.objective_function * tol)
+            if self.objective_difference < self.objective_function * tol and outiter >= min_iter:
                 break
-            print(self.objective_difference, " < ", self._objective_history[-1] * tol)
 
-        # Normalize our results
+        self.normalize_results()
+
+    def normalize_results(self):
+        # Select our best results for normalization
+        self.components = self.best_matrices[0]
+        self.weights = self.best_matrices[1]
+        self.stretch = self.best_matrices[2]
+
+        # Normalize weights/stretch first
         weights_row_max = np.max(self.weights, axis=1, keepdims=True)
-        stretch_row_max = np.max(self.stretch, axis=1, keepdims=True)
         self.weights = self.weights / weights_row_max
+        stretch_row_max = np.max(self.stretch, axis=1, keepdims=True)
         self.stretch = self.stretch / stretch_row_max
-        # loop to normalize components
-        # effectively just re-running class with non-normalized components, normalized wts/stretch as inputs,
-        # then only update components
-        self._prev_components = None
-        self.grad_components = np.zeros_like(self.components)
-        self._prev_grad_components = np.zeros_like(self.components)
+
+        # effectively just re-running with component updates only vs normalized weights/stretch
+        self._grad_components = np.zeros_like(self.components)  # Gradient of X (zeros for now)
+        self._prev_grad_components = np.zeros_like(self.components)  # Previous gradient of X (zeros for now)
        self.residuals = self.get_residual_matrix()
+        self.objective_function = self.get_objective_function()
         self.objective_difference = None
-        self._objective_history = []
-        self.update_objective()
-        for norm_iter in range(100):
+        self._objective_history = [self.objective_function]
+        self.outiter = 0
+        self.iter = 0
+        for outiter in range(self.max_iter):
+            if outiter == 1:
+                self.iter = 1  # So step size can adapt without an inner loop
             self.update_components()
             self.residuals = self.get_residual_matrix()
-            self.update_objective()
-            print(f"Objective function after normalize_components: {self._objective_history[-1]:.5e}")
-            self._objective_history.append(self._objective_history[-1])
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after normalize_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
             self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
-            if self.objective_difference < self._objective_history[-1] * tol and norm_iter >= 20:
+            if self.objective_difference < self.objective_function * self.tol and outiter >= 7:
                 break
-        # end of normalization (and program)
-        # note that objective function may not fully recover after normalization, this is okay
-        print("Finished optimization.")
-
-    def optimize_loop(self):
-        # Update components first
-        self._prev_grad_components = self.grad_components.copy()
-
-        self.update_components()
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.update_objective()
-        print(f"Objective function after update_components: {self._objective_history[-1]:.5e}")
-
-        if self.objective_difference is None:
+    def outer_loop(self):
+        for iter in range(4):
+            self.iter = iter
+            self._prev_grad_components = self._grad_components.copy()
+            self.update_components()
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_components: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
             self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
+            if self.objective_function < self.best_objective:
+                self.best_objective = self.objective_function
+                self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
 
-        # Now we update weights
-        self.update_weights()
+            self.update_weights()
+            self.residuals = self.get_residual_matrix()
+            self.objective_function = self.get_objective_function()
+            print(f"Objective function after update_weights: {self.objective_function:.5e}")
+            self._objective_history.append(self.objective_function)
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
+            if self.objective_function < self.best_objective:
+                self.best_objective = self.objective_function
+                self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
 
-        self.num_updates += 1
-        self.residuals = self.get_residual_matrix()
-        self.update_objective()
-        print(f"Objective function after update_weights: {self._objective_history[-1]:.5e}")
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
+            if self._objective_history[-3] - self.objective_function < self.objective_difference * 1e-3:
+                break
 
-        # Now we update stretch
         self.update_stretch()
-
-        self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.update_objective()
-        print(f"Objective function after update_stretch: {self._objective_history[-1]:.5e}")
-
+        self.objective_function = self.get_objective_function()
+        print(f"Objective function after update_stretch: {self.objective_function:.5e}")
+        self._objective_history.append(self.objective_function)
         self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
-
-    def apply_interpolation(self, a, x, return_derivatives=False):
-        """
-        Applies an interpolation-based transformation to `x` based on scaling `a`.
-        Also can compute first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
- """ - x_len = len(x) - - # Ensure `a` is an array and reshape for broadcasting - a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D - - # Compute fractional indices, broadcasting over `a` - fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M) - - integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M)) - valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds - - # Apply valid_mask to keep correct indices - idx_int = np.where( - valid_mask, integer_indices, x_len - 2 - ) # Prevent out-of-bounds indexing (previously "I") - idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i") - - # Ensure x is a 1D array - x = np.asarray(x).ravel() - - # Compute interpolated_x (linear interpolation) - interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * ( - idx_frac - idx_int - ) - - # Fill the tail with the last valid value - intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :]) - interpolated_x = np.vstack([interpolated_x, intr_x_tail]) - - if return_derivatives: - # Compute first derivative (d_intr_x) - di = -idx_frac / a - d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di - d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))]) - - # Compute second derivative (dd_intr_x) - ddi = -di / a + idx_frac * a**-2 - dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi - dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))]) - else: - # Make placeholders - d_intr_x = np.empty(interpolated_x.shape) - dd_intr_x = np.empty(interpolated_x.shape) - - return interpolated_x, d_intr_x, dd_intr_x + if self.objective_function < self.best_objective: + self.best_objective = self.objective_function + self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()] def get_residual_matrix(self, components=None, weights=None, stretch=None): # Initialize residual matrix as negative of source_matrix - # In MATLAB this is getR if components is None: components = self.components if weights is None: @@ -320,15 +287,12 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None): stretch = self.stretch residuals = -self.source_matrix.copy() # Compute transformed components for all (k, m) pairs - for k in range(weights.shape[0]): - stretched_components, _, _ = self.apply_interpolation( - stretch[k, :], components[:, k] - ) # Only calculate Ax + for k in range(weights.shape[0]): # K + stretched_components, _, _ = apply_interpolation(stretch[k, :], components[:, k]) # Only use Ax residuals += weights[k, :] * stretched_components # Element-wise scaling and sum return residuals - def update_objective(self, residuals=None, stretch=None): - to_return = not (residuals is None and stretch is None) + def get_objective_function(self, residuals=None, stretch=None): if residuals is None: residuals = self.residuals if stretch is None: @@ -338,17 +302,13 @@ def update_objective(self, residuals=None, stretch=None): sparsity_term = self.eta * np.sum(np.sqrt(self.components)) # Square root penalty # Final objective function value function = residual_term + regularization_term + sparsity_term + return function - if to_return: - return function # Get value directly for use - else: - self._objective_history.append(function) # Store value - - def apply_interpolation_matrix(self, components=None, 
+    def apply_interpolation_matrix(self, components=None, weights=None, stretch=None):
         """
-        Applies an interpolation-based transformation to the matrix `components` using `stretch`,
-        weighted by `weights`. Optionally computes first and second derivatives.
-        Equivalent to getAfun_matrix in MATLAB.
+        Applies an interpolation-based transformation to the `components` using `stretch`,
+        weighted by `weights`. Also computes first (`d_stretched_components`) and
+        second (`dd_stretched_components`) derivatives.
         """
 
         if components is None:
@@ -358,80 +318,76 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None
         if stretch is None:
             stretch = self.stretch
 
-        # Compute scaled indices (MATLAB: AA = repmat(reshape(A',1,M*K).^-1, N,1))
+        # Compute scaled indices
         stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1
         stretch_tiled = np.tile(stretch_flat, (self.signal_length, 1))
 
-        # Compute `ii` (MATLAB: ii = repmat((0:N-1)',1,K*M).*tiled_stretch)
+        # Compute `fractional_indices`
         fractional_indices = (
             np.tile(np.arange(self.signal_length)[:, None], (1, self.n_signals * self.n_components))
             * stretch_tiled
         )
 
-        # Weighting matrix (MATLAB: YY = repmat(reshape(Y',1,M*K), N,1))
+        # Weighting matrix
         weights_flat = weights.reshape(1, self.n_signals * self.n_components)
         weights_tiled = np.tile(weights_flat, (self.signal_length, 1))
 
-        # Bias for indexing into reshaped X (MATLAB: bias = kron((0:K-1)*(N+1),ones(N,M)))
+        # Bias for indexing into reshaped components
         # TODO break this up or describe what it does better
         bias = np.kron(
             np.arange(self.n_components) * (self.signal_length + 1),
             np.ones((self.signal_length, self.n_signals), dtype=int),
         ).reshape(self.signal_length, self.n_components * self.n_signals)
 
-        # Handle boundary conditions for interpolation (MATLAB: X1=[X;X(end,:)])
-        components_bounded = np.vstack([components, components[-1, :]])  # Duplicate last row (like MATLAB)
+        # Handle boundary conditions for interpolation
+        components_bounded = np.vstack(
+            [components, components[-1, :]]
+        )  # Duplicate last row (like MATLAB, not sure why)
 
-        # Compute floor indices (MATLAB: II = floor(ii); II1=min(II+1,N+1); II2=min(II1+1,N+1))
+        # Compute floor indices
         floor_indices = np.floor(fractional_indices).astype(int)
-        floor_ind_1 = np.minimum(floor_indices + 1, self.signal_length)
-        floor_ind_2 = np.minimum(floor_ind_1 + 1, self.signal_length)
+        floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length)
+        floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length)
 
-        # Compute fractional part (MATLAB: iI = ii - II)
+        # Compute fractional part
         fractional_floor_indices = fractional_indices - floor_indices
 
-        # Compute offset indices (MATLAB: II1_ = II1 + bias; II2_ = II2 + bias)
-        offset_floor_ind_1 = floor_ind_1 + bias
-        offset_floor_ind_2 = floor_ind_2 + bias
+        # Compute offset indices
+        offset_indices_1 = floor_indices_1 + bias
+        offset_indices_2 = floor_indices_2 + bias
 
-        # Extract values (MATLAB: XI1 = reshape(X1(II1_), N, K*M); XI2 = reshape(X1(II2_), N, K*M))
+        # Extract values
         # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
-        # order = F uses FORTRAN, column major order
-        components_val_1 = components_bounded.flatten(order="F")[(offset_floor_ind_1 - 1).ravel()].reshape(
-            self.signal_length, self.n_components * self.n_signals
-        )
-        components_val_2 = components_bounded.flatten(order="F")[(offset_floor_ind_2 - 1).ravel()].reshape(
-            self.signal_length, self.n_components * self.n_signals
+        comp_values_1 = components_bounded.flatten(order="F")[(offset_indices_1 - 1).ravel(order="F")].reshape(
+            self.signal_length, self.n_components * self.n_signals, order="F"
+        )  # order = F uses FORTRAN, column major order
+        comp_values_2 = components_bounded.flatten(order="F")[(offset_indices_2 - 1).ravel(order="F")].reshape(
+            self.signal_length, self.n_components * self.n_signals, order="F"
         )
 
-        # Interpolation (MATLAB: Ax2=XI1.*(1-iI)+XI2.*(iI); stretched_components=Ax2.*YY)
-        stretch_components2 = (
-            components_val_1 * (1 - fractional_floor_indices) + components_val_2 * fractional_floor_indices
+        # Interpolation
+        unweighted_stretched_comps = (
+            comp_values_1 * (1 - fractional_floor_indices) + comp_values_2 * fractional_floor_indices
         )
-        stretched_components = stretch_components2 * weights_tiled  # Apply weighting
-
-        if return_derivatives:
-            # Compute first derivative (MATLAB: Tx2=XI1.*(-di)+XI2.*di; d_str_cmps=Tx2.*YY)
-            di = -fractional_indices * stretch_tiled
-            d_components2 = components_val_1 * (-di) + components_val_2 * di
-            d_stretch_components = d_components2 * weights_tiled
-
-            # Compute second derivative (MATLAB: Hx2=XI1.*(-ddi)+XI2.*ddi; dd_str_components=Hx2.*YY)
-            ddi = -di * stretch_tiled * 2
-            dd_components2 = components_val_1 * (-ddi) + components_val_2 * ddi
-            dd_stretch_components = dd_components2 * weights_tiled
-        else:
-            shape = stretched_components.shape
-            d_stretch_components = np.empty(shape)
-            dd_stretch_components = np.empty(shape)
+        stretched_components = unweighted_stretched_comps * weights_tiled  # Apply weighting
+
+        # Compute first derivative
+        di = -fractional_indices * stretch_tiled
+        d_comps_unweighted = comp_values_1 * (-di) + comp_values_2 * di
+        d_stretched_components = d_comps_unweighted * weights_tiled
+
+        # Compute second derivative
+        ddi = -di * stretch_tiled * 2
+        dd_comps_unweighted = comp_values_1 * (-ddi) + comp_values_2 * ddi
+        dd_stretched_components = dd_comps_unweighted * weights_tiled
 
-        return stretched_components, d_stretch_components, dd_stretch_components
+        return stretched_components, d_stretched_components, dd_stretched_components
 
     def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None):
         """
         Computes the transformation matrix `stretch_transformed` for `residuals`,
-        using scaling matrix `stretch` and coefficients `weights`.
+        using scaling matrix `stretch` and weight coefficients `weights`.
""" if stretch is None: @@ -441,38 +397,36 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None if residuals is None: residuals = self.residuals - # Compute scaling matrix (MATLAB: AA = repmat(reshape(A,1,M*K).^-1,Nindex,1)) + # Compute scaling matrix stretch_tiled = np.tile( stretch.reshape(1, self.n_signals * self.n_components, order="F") ** -1, (self.signal_length, 1) ) - # Compute indices (MATLAB: ii = repmat((index-1)',1,K*M).*AA) - indices = np.arange(self.signal_length)[:, None] * stretch_tiled # Shape (N, M*K), replacing `index` + # Compute indices + indices = np.arange(self.signal_length)[:, None] * stretch_tiled - # Weighting coefficients (MATLAB: YY = repmat(reshape(Y,1,M*K),Nindex,1)) + # Weighting coefficients weights_tiled = np.tile( weights.reshape(1, self.n_signals * self.n_components, order="F"), (self.signal_length, 1) ) - # Compute floor indices (MATLAB: II = floor(ii); II1 = min(II+1,N+1); II2 = min(II1+1,N+1)) + # Compute floor indices floor_indices = np.floor(indices).astype(int) floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length) floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length) - # Compute fractional part (MATLAB: iI = ii - II) + # Compute fractional part fractional_indices = indices - floor_indices - # Expand row indices (MATLAB: repm = repmat(1:K, Nindex, M)) + # Expand row indices repm = np.tile(np.arange(self.n_components), (self.signal_length, self.n_signals)) - # Compute transformations (MATLAB: kro = kron(R(index,:), ones(1, K))) + # Compute transformations kron = np.kron(residuals, np.ones((1, self.n_components))) - - # (MATLAB: kroiI = kro .* (iI); iIYY = (iI-1) .* YY) fractional_kron = kron * fractional_indices fractional_weights = (fractional_indices - 1) * weights_tiled - # Construct sparse matrices (MATLAB: sparse(II1_,repm,kro.*-iIYY,(N+1),K)) + # Construct sparse matrices x2 = coo_matrix( ((-kron * fractional_weights).flatten(), (floor_indices_1.flatten() - 1, repm.flatten())), shape=(self.signal_length + 1, self.n_components), @@ -492,64 +446,56 @@ def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None return stretch_transformed - def solve_quadratic_program(self, t, m, alg="trust-constr"): + def solve_quadratic_program(self, t, m): """ - Solves the quadratic program for updating y in stretched NMF using scipy.optimize: + Solves the quadratic program for updating y in stretched NMF: - min J(y) = 0.5 * y^T Q y + d^T y + min J(y) = 0.5 * y^T q y + d^T y subject to: 0 ≤ y ≤ 1 - Uses the 'trust-constr' solver with the analytical gradient and Hessian. - Alternatively, can use scipy's L-BFGS-B algorithm, which supports bound - constraints. - Parameters: - - t: (N, K) ndarray - Matrix computed from getAfun(A(k, m), X[:, k]). - - m: int - Index of the current column in source_matrix. + - t: (N, k) ndarray + - source_matrix_col: (N,) column of source_matrix for the corresponding m Returns: - - y: (k,) ndarray - Optimal solution for y, clipped to ensure non-negativity. 
+        - y: (k,) optimal solution
         """
+
         source_matrix_col = self.source_matrix[:, m]
-        q = t.T @ t
-        d = -t.T @ source_matrix_col
-        k = q.shape[0]
-        reg_factor = 1e-8 * np.linalg.norm(q, ord="fro")
-        q += np.eye(k) * reg_factor
 
-        def objective(y):
-            return 0.5 * y @ q @ y + d @ y
+        # Compute q and d
+        q = t.T @ t  # Gram matrix (k x k)
+        d = -t.T @ source_matrix_col  # Linear term (k,)
 
-        def grad(y):
-            return q @ y + d
+        k = q.shape[0]  # Number of variables
 
-        if alg == "trust-constr":
+        # Regularize q to ensure positive semi-definiteness
+        reg_factor = 1e-8 * np.linalg.norm(q, ord="fro")  # Adaptive regularization, original was fixed
+        q += np.eye(k) * reg_factor
 
-            def hess(y):
-                return csc_matrix(q)  # sparse format for efficiency
+        # Define optimization variable
+        y = cp.Variable(k)
 
-            bounds = [(0, 1)] * k
-            y0 = np.clip(-np.linalg.solve(q + np.eye(k) * 1e-5, d), 0, 1)
-            result = minimize(
-                objective, y0, method="trust-constr", jac=grad, hess=hess, bounds=bounds, options={"verbose": 0}
-            )
-        elif alg == "L-BFGS-B":
-            bounds = [(0, 1) for _ in range(k)]  # per-variable bounds
-            y0 = np.clip(-np.linalg.solve(q + np.eye(k) * 1e-5, d), 0, 1)  # Initial guess
-            result = minimize(objective, y0, method="L-BFGS-B", jac=grad, bounds=bounds)
+        # Define quadratic objective
+        objective = cp.Minimize(0.5 * cp.quad_form(y, q) + d.T @ y)
+
+        # Define constraints (0 ≤ y ≤ 1)
+        constraints = [y >= 0, y <= 1]
+
+        # Solve using a QP solver
+        prob = cp.Problem(objective, constraints)
+        prob.solve(solver=cp.OSQP, verbose=False)
 
-        return np.maximum(result.x, 0)
+        # Get the solution
+        return np.maximum(y.value, 0)  # Ensure non-negative values in case of solver tolerance issues
 
     def update_components(self):
         """
-        Updates `components` using gradient-based optimization with adaptive step size step_size.
+        Updates `components` using gradient-based optimization with adaptive step size.
""" - # Compute `stretched_components` using the interpolation function - stretched_components, _, _ = self.apply_interpolation_matrix() # Skip the other two outputs (derivatives) - # Compute RA and RR + # Compute stretched components using the interpolation function + stretched_components, _, _ = self.apply_interpolation_matrix() # Discard the derivatives + # Compute reshaped_stretched_components and component_residuals intermediate_reshaped = stretched_components.flatten(order="F").reshape( (self.signal_length * self.n_signals, self.n_components), order="F" ) @@ -557,32 +503,32 @@ def update_components(self): (self.signal_length, self.n_signals), order="F" ) component_residuals = reshaped_stretched_components - self.source_matrix - # Compute gradient `GraX` - self.grad_components = self.apply_transformation_matrix( + # Compute gradient + self._grad_components = self.apply_transformation_matrix( residuals=component_residuals - ).toarray() # toarray equivalent of MATLAB "full", makes non-sparse + ).toarray() # toarray equivalent of full, make non-sparse # Compute initial step size `initial_step_size` initial_step_size = np.linalg.eigvalsh(self.weights.T @ self.weights).max() * np.max( [self.stretch.max(), 1 / self.stretch.min()] ) # Compute adaptive step size `step_size` - if self._prev_components is None: + if self.outiter == 0 and self.iter == 0: step_size = initial_step_size else: num = np.sum( - (self.grad_components - self._prev_grad_components) * (self.components - self._prev_components) - ) # Elem-wise multiply + (self._grad_components - self._prev_grad_components) * (self.components - self._prev_components) + ) # Element-wise multiplication denom = np.linalg.norm(self.components - self._prev_components, "fro") ** 2 # Frobenius norm squared step_size = num / denom if denom > 0 else initial_step_size if step_size <= 0: step_size = initial_step_size - # Store our old component matrix before updating because it is used in step selection + # Store our old X before updating because it is used in step selection self._prev_components = self.components.copy() while True: # iterate updating components - components_step = self._prev_components - self.grad_components / step_size + components_step = self._prev_components - self._grad_components / step_size # Solve x^3 + p*x + q = 0 for the largest real root self.components = np.square(cubic_largest_real_root(-components_step, self.eta / (2 * step_size))) # Mask values that should be set to zero @@ -594,7 +540,7 @@ def update_components(self): ) self.components = mask * self.components - objective_improvement = self._objective_history[-1] - self.update_objective( + objective_improvement = self._objective_history[-1] - self.get_objective_function( residuals=self.get_residual_matrix() ) @@ -608,17 +554,18 @@ def update_components(self): def update_weights(self): """ - Updates weights using matrix operations, solving a quadratic program via to do so. + Updates weights using matrix operations, solving a quadratic program to do so. 
""" - for m in range(self.n_signals): - t = np.zeros((self.signal_length, self.n_components)) + signal_length = self.signal_length + n_signals = self.n_signals + + for m in range(n_signals): + t = np.zeros((signal_length, self.n_components)) - # Populate T using apply_interpolation + # Populate t using apply_interpolation for k in range(self.n_components): - t[:, k] = self.apply_interpolation( - self.stretch[k, m], self.components[:, k], return_derivatives=True - )[0].squeeze() + t[:, k] = apply_interpolation(self.stretch[k, m], self.components[:, k])[0].squeeze() # Solve quadratic problem for y y = self.solve_quadratic_program(t=t, m=m) @@ -627,91 +574,132 @@ def update_weights(self): self.weights[:, m] = y def regularize_function(self, stretch=None): - """ - Computes the regularization function, gradient, and Hessian for optimization. - Returns: - - fun: Objective function value (scalar) - - gra: Gradient (same shape as stretch) - """ if stretch is None: stretch = self.stretch - # Compute interpolated matrices - stretched_components, d_stretch_components, dd_stretch_components = self.apply_interpolation_matrix( - stretch=stretch, return_derivatives=True - ) + K = self.n_components + M = self.n_signals + N = self.signal_length - # Compute residual - intermediate_diff = stretched_components.flatten(order="F").reshape( - (self.signal_length * self.n_signals, self.n_components), order="F" - ) - stretch_difference = intermediate_diff.sum(axis=1).reshape((self.signal_length, self.n_signals), order="F") - stretch_difference = stretch_difference - self.source_matrix + stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch) + intermediate = stretched_components.flatten(order="F").reshape((N * M, K), order="F") + residuals = intermediate.sum(axis=1).reshape((N, M), order="F") - self.source_matrix - # Compute objective function - reg_func = self.update_objective(stretch_difference, stretch) + fun = self.get_objective_function(residuals, stretch) - # Compute gradient - tiled_derivative = np.sum( - d_stretch_components * np.tile(stretch_difference, (1, self.n_components)), axis=0 - ) - der_reshaped = np.asarray(tiled_derivative).reshape((self.n_signals, self.n_components), order="F") - func_grad = ( - der_reshaped.T + self.rho * stretch @ self._spline_smooth_operator.T @ self._spline_smooth_operator - ) + tiled_res = np.tile(residuals, (1, K)) + grad_flat = np.sum(d_stretch_comps * tiled_res, axis=0) + gra = grad_flat.reshape((M, K), order="F").T + gra += self.rho * stretch @ (self._spline_smooth_operator.T @ self._spline_smooth_operator) + + # Hessian would go here - return reg_func, func_grad + return fun, gra def update_stretch(self): """ - Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB). + Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB). 
""" - # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input) - stretch_init_vec = self.stretch.flatten() + # Flatten A for compatibility with the optimizer (since SciPy expects 1D input) + stretch_flat_initial = self.stretch.flatten() # Define the optimization function def objective(stretch_vec): stretch_matrix = stretch_vec.reshape(self.stretch.shape) # Reshape back to matrix form - func, grad = self.regularize_function(stretch_matrix) - grad = grad.flatten() - return func, grad + fun, gra = self.regularize_function(stretch_matrix) + gra = gra.flatten() + return fun, gra # Optimization constraints: lower bound 0.1, no upper bound - bounds = [(0.1, None)] * stretch_init_vec.size # Equivalent to 0.1 * ones(K, M) + bounds = [(0.1, None)] * stretch_flat_initial.size # Equivalent to 0.1 * ones(K, M) # Solve optimization problem (equivalent to fmincon) result = minimize( - fun=lambda stretch_vec: objective(stretch_vec)[0], # Objective function - x0=stretch_init_vec, # Initial guess - method="trust-constr", # Equivalent to 'trust-region-reflective' + fun=lambda stretch_vec: objective(stretch_vec)[0], + x0=stretch_flat_initial, + method="trust-constr", # Substitute for 'trust-region-reflective' jac=lambda stretch_vec: objective(stretch_vec)[1], # Gradient - bounds=bounds, # Lower bounds on stretch - # TODO: A Hessian can be incorporated for better convergence. + bounds=bounds, ) - # Update stretch with the optimized values + # Update A with the optimized values self.stretch = result.x.reshape(self.stretch.shape) def cubic_largest_real_root(p, q): """ - Vectorized solver for x^3 + p*x + q = 0. - Returns the largest real root element-wise. + Solves x^3 + p*x + q = 0 element-wise for matrices, returning the largest real root. """ - # calculate the discriminant + # Handle special case where q == 0 + y = np.where(q == 0, np.maximum(0, -p) ** 0.5, np.zeros_like(p)) # q=0 case + + # Compute discriminant delta = (q / 2) ** 2 + (p / 3) ** 3 - sqrt_delta = np.sqrt(np.abs(delta)) - # When delta >= 0: one real root - a = np.cbrt(-q / 2 + sqrt_delta) - b = np.cbrt(-q / 2 - sqrt_delta) - root1 = a + b + # Compute square root of delta safely + d = np.where(delta >= 0, np.sqrt(delta), np.sqrt(np.abs(delta)) * 1j) + # TODO: this line causes a warning but results seem correct + + # Compute cube roots safely + a1 = (-q / 2 + d) ** (1 / 3) + a2 = (-q / 2 - d) ** (1 / 3) + + # Compute cube roots of unity + w = (np.sqrt(3) * 1j - 1) / 2 + + # Compute the three possible roots (element-wise) + y1 = a1 + a2 + y2 = w * a1 + w**2 * a2 + y3 = w**2 * a1 + w * a2 + + # Take the largest real root element-wise when delta < 0 + real_roots = np.stack([np.real(y1), np.real(y2), np.real(y3)], axis=0) + y = np.max(real_roots, axis=0) * (delta < 0) # Keep only real roots when delta < 0 + + return y + + +def apply_interpolation(a, x): + """ + Applies an interpolation-based transformation to `x` based on scaling `a`. + Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives. 
+ """ + x_len = len(x) + + # Ensure `a` is an array and reshape for broadcasting + a = np.atleast_1d(np.asarray(a)) # Ensures a is at least 1D + + # Compute fractional indices, broadcasting over `a` + fractional_indices = np.arange(x_len)[:, None] / a # Shape (N, M) + + integer_indices = np.floor(fractional_indices).astype(int) # Integer part (still (N, M)) + valid_mask = integer_indices < (x_len - 1) # Ensure indices are within bounds + + # Apply valid_mask to keep correct indices + idx_int = np.where(valid_mask, integer_indices, x_len - 2) # Prevent out-of-bounds indexing (previously "I") + idx_frac = np.where(valid_mask, fractional_indices, integer_indices) # Keep aligned (previously "i") + + # Ensure x is a 1D array + x = np.asarray(x).ravel() + + # Compute interpolated_x (linear interpolation) + interpolated_x = x[idx_int] * (1 - idx_frac + idx_int) + x[np.minimum(idx_int + 1, x_len - 1)] * ( + idx_frac - idx_int + ) + + # Fill the tail with the last valid value + intr_x_tail = np.full((x_len - len(idx_int), interpolated_x.shape[1]), interpolated_x[-1, :]) + interpolated_x = np.vstack([interpolated_x, intr_x_tail]) + + # Compute first derivative (d_intr_x) + di = -idx_frac / a + d_intr_x = x[idx_int] * (-di) + x[np.minimum(idx_int + 1, x_len - 1)] * di + d_intr_x = np.vstack([d_intr_x, np.zeros((x_len - len(idx_int), d_intr_x.shape[1]))]) - # When delta < 0: three real roots, use trigonometric method - phi = np.arccos(-q / (2 * np.sqrt(-((p / 3) ** 3) + 1e-12))) - r = 2 * np.sqrt(-p / 3) - root2 = r * np.cos(phi / 3) + # Compute second derivative (dd_intr_x) + ddi = -di / a + idx_frac * a**-2 + dd_intr_x = x[idx_int] * (-ddi) + x[np.minimum(idx_int + 1, x_len - 1)] * ddi + dd_intr_x = np.vstack([dd_intr_x, np.zeros((x_len - len(idx_int), dd_intr_x.shape[1]))]) - # Choose correct root depending on sign of delta - return np.where(delta >= 0, root1, root2) + return interpolated_x, d_intr_x, dd_intr_x