diff --git a/src/diffpy/snmf/snmf_class.py b/src/diffpy/snmf/snmf_class.py index a62d021..658eeeb 100644 --- a/src/diffpy/snmf/snmf_class.py +++ b/src/diffpy/snmf/snmf_class.py @@ -370,13 +370,33 @@ def get_objective_function(self, residuals=None, stretch=None): function = residual_term + regularization_term + sparsity_term return function - def apply_interpolation_matrix(self, components=None, weights=None, stretch=None): + def compute_stretched_components(self, components=None, weights=None, stretch=None): """ - Applies an interpolation-based transformation to the 'components' using `stretch`, - weighted by `weights`. Optionally computes first (`d_stretched_components`) and - second (`dd_stretched_components`) derivatives. + Interpolates each component along its sample axis according to per-(component, signal) + stretch factors, then applies per-(component, signal) weights. Also computes the + first and second derivatives with respect to stretch. Left and right, respectively, + refer to the sample prior to and subsequent to the interpolated sample's position. + + Inputs + ------ + components : array, shape (signal_len, n_components) + Each column is a component with signal_len samples. + weights : array, shape (n_components, n_signals) + Per-(component, signal) weights. + stretch : array, shape (n_components, n_signals) + Per-(component, signal) stretch factors. + + Outputs + ------- + stretched_components : array, shape (signal_len, n_components * n_signals) + Interpolated and weighted components. + d_stretched_components : array, shape (signal_len, n_components * n_signals) + First derivatives with respect to stretch. + dd_stretched_components : array, shape (signal_len, n_components * n_signals) + Second derivatives with respect to stretch. """ + # --- Defaults --- if components is None: components = self.components_ if weights is None: @@ -384,72 +404,55 @@ def apply_interpolation_matrix(self, components=None, weights=None, stretch=None if stretch is None: stretch = self.stretch_ - # Compute scaled indices - stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1 - stretch_tiled = np.tile(stretch_flat, (self.signal_length, 1)) - - # Compute `fractional_indices` - fractional_indices = ( - np.tile(np.arange(self.signal_length)[:, None], (1, self.n_signals * self.n_components)) - * stretch_tiled + # Dimensions + signal_len = components.shape[0] # number of samples + n_components = components.shape[1] # number of components + n_signals = weights.shape[1] # number of signals + + # Guard stretches + eps = 1e-8 + stretch = np.clip(stretch, eps, None) + stretch_inv = 1.0 / stretch + + # Apply stretching to the original sample indices, represented as a "time-stretch" + t = np.arange(signal_len, dtype=float)[:, None, None] * stretch_inv[None, :, :] + # has shape (signal_len, n_components, n_signals) + + # For each stretched coordinate, find its prior integer (original) index and their difference + i0 = np.floor(t).astype(np.int64) # prior original index + alpha = t - i0.astype(float) # fractional distance between left/right + + # Clip indices to valid range (0, signal_len - 1) to maintain original size + max_idx = signal_len - 1 + i0 = np.clip(i0, 0, max_idx) + i1 = np.clip(i0 + 1, 0, max_idx) + + # Gather sample values + comps_3d = components[:, :, None] # expand components by a dimension for broadcasting across n_signals + c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values + c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values + + # Linear interpolation to determine stretched sample values + interp = c0 * (1.0 - alpha) + c1 * alpha + interp_weighted = interp * weights[None, :, :] + + # Derivatives + di = -t * stretch_inv[None, :, :] # first-derivative coefficient + ddi = -di * stretch_inv[None, :, :] * 2.0 # second-derivative coefficient + + d_unweighted = c0 * (-di) + c1 * di + dd_unweighted = c0 * (-ddi) + c1 * ddi + + d_weighted = d_unweighted * weights[None, :, :] + dd_weighted = dd_unweighted * weights[None, :, :] + + # Flatten back to expected shape (signal_len, n_components * n_signals) + return ( + interp_weighted.reshape(signal_len, n_components * n_signals), + d_weighted.reshape(signal_len, n_components * n_signals), + dd_weighted.reshape(signal_len, n_components * n_signals), ) - # Weighting matrix - weights_flat = weights.reshape(1, self.n_signals * self.n_components) - weights_tiled = np.tile(weights_flat, (self.signal_length, 1)) - - # Bias for indexing into reshaped components - # TODO break this up or describe what it does better - bias = np.kron( - np.arange(self.n_components) * (self.signal_length + 1), - np.ones((self.signal_length, self.n_signals), dtype=int), - ).reshape(self.signal_length, self.n_components * self.n_signals) - - # Handle boundary conditions for interpolation - components_bounded = np.vstack( - [components, components[-1, :]] - ) # Duplicate last row (like MATLAB, not sure why) - - # Compute floor indices - floor_indices = np.floor(fractional_indices).astype(int) - - floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length) - floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length) - - # Compute fractional part - fractional_floor_indices = fractional_indices - floor_indices - - # Compute offset indices - offset_indices_1 = floor_indices_1 + bias - offset_indices_2 = floor_indices_2 + bias - - # Extract values - # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line - comp_values_1 = components_bounded.flatten(order="F")[(offset_indices_1 - 1).ravel(order="F")].reshape( - self.signal_length, self.n_components * self.n_signals, order="F" - ) # order = F uses FORTRAN, column major order - comp_values_2 = components_bounded.flatten(order="F")[(offset_indices_2 - 1).ravel(order="F")].reshape( - self.signal_length, self.n_components * self.n_signals, order="F" - ) - - # Interpolation - unweighted_stretched_comps = ( - comp_values_1 * (1 - fractional_floor_indices) + comp_values_2 * fractional_floor_indices - ) - stretched_components = unweighted_stretched_comps * weights_tiled # Apply weighting - - # Compute first derivative - di = -fractional_indices * stretch_tiled - d_comps_unweighted = comp_values_1 * (-di) + comp_values_2 * di - d_stretched_components = d_comps_unweighted * weights_tiled - - # Compute second derivative - ddi = -di * stretch_tiled * 2 - dd_comps_unweighted = comp_values_1 * (-ddi) + comp_values_2 * ddi - dd_stretched_components = dd_comps_unweighted * weights_tiled - - return stretched_components, d_stretched_components, dd_stretched_components - def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None): """ Computes the transformation matrix `stretch_transformed` for residuals, @@ -560,7 +563,7 @@ def update_components(self): Updates `components` using gradient-based optimization with adaptive step size. """ # Compute stretched components using the interpolation function - stretched_components, _, _ = self.apply_interpolation_matrix() # Discard the derivatives + stretched_components, _, _ = self.compute_stretched_components() # Discard the derivatives # Compute reshaped_stretched_components and component_residuals intermediate_reshaped = stretched_components.flatten(order="F").reshape( (self.signal_length * self.n_signals, self.n_components), order="F" @@ -648,7 +651,9 @@ def regularize_function(self, stretch=None): if stretch is None: stretch = self.stretch_ - stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch) + stretched_components, d_stretch_comps, dd_stretch_comps = self.compute_stretched_components( + stretch=stretch + ) intermediate = stretched_components.flatten(order="F").reshape( (self.signal_length * self.n_signals, self.n_components), order="F" ) @@ -751,8 +756,8 @@ def reconstruct_matrix(components, weights, stretch): """ signal_len = components.shape[0] - n_signals = weights.shape[1] n_components = components.shape[1] + n_signals = weights.shape[1] reconstructed_matrix = np.zeros((signal_len, n_signals)) sample_indices = np.arange(signal_len)