Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 76 additions & 71 deletions src/diffpy/snmf/snmf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,86 +370,89 @@ def get_objective_function(self, residuals=None, stretch=None):
function = residual_term + regularization_term + sparsity_term
return function

def apply_interpolation_matrix(self, components=None, weights=None, stretch=None):
def compute_stretched_components(self, components=None, weights=None, stretch=None):
"""
Applies an interpolation-based transformation to the 'components' using `stretch`,
weighted by `weights`. Optionally computes first (`d_stretched_components`) and
second (`dd_stretched_components`) derivatives.
Interpolates each component along its sample axis according to per-(component, signal)
stretch factors, then applies per-(component, signal) weights. Also computes the
first and second derivatives with respect to stretch. Left and right, respectively,
refer to the sample prior to and subsequent to the interpolated sample's position.

Inputs
------
components : array, shape (signal_len, n_components)
Each column is a component with signal_len samples.
weights : array, shape (n_components, n_signals)
Per-(component, signal) weights.
stretch : array, shape (n_components, n_signals)
Per-(component, signal) stretch factors.

Outputs
-------
stretched_components : array, shape (signal_len, n_components * n_signals)
Interpolated and weighted components.
d_stretched_components : array, shape (signal_len, n_components * n_signals)
First derivatives with respect to stretch.
dd_stretched_components : array, shape (signal_len, n_components * n_signals)
Second derivatives with respect to stretch.
"""

# --- Defaults ---
if components is None:
components = self.components_
if weights is None:
weights = self.weights_
if stretch is None:
stretch = self.stretch_

# Compute scaled indices
stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1
stretch_tiled = np.tile(stretch_flat, (self.signal_length, 1))

# Compute `fractional_indices`
fractional_indices = (
np.tile(np.arange(self.signal_length)[:, None], (1, self.n_signals * self.n_components))
* stretch_tiled
# Dimensions
signal_len = components.shape[0] # number of samples
n_components = components.shape[1] # number of components
n_signals = weights.shape[1] # number of signals

# Guard stretches
eps = 1e-8
stretch = np.clip(stretch, eps, None)
stretch_inv = 1.0 / stretch

# Apply stretching to the original sample indices, represented as a "time-stretch"
t = np.arange(signal_len, dtype=float)[:, None, None] * stretch_inv[None, :, :]
# has shape (signal_len, n_components, n_signals)

# For each stretched coordinate, find its prior integer (original) index and their difference
i0 = np.floor(t).astype(np.int64) # prior original index
alpha = t - i0.astype(float) # fractional distance between left/right

# Clip indices to valid range (0, signal_len - 1) to maintain original size
max_idx = signal_len - 1
i0 = np.clip(i0, 0, max_idx)
i1 = np.clip(i0 + 1, 0, max_idx)

# Gather sample values
comps_3d = components[:, :, None] # expand components by a dimension for broadcasting across n_signals
c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values
c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values

# Linear interpolation to determine stretched sample values
interp = c0 * (1.0 - alpha) + c1 * alpha
interp_weighted = interp * weights[None, :, :]

# Derivatives
di = -t * stretch_inv[None, :, :] # first-derivative coefficient
ddi = -di * stretch_inv[None, :, :] * 2.0 # second-derivative coefficient

d_unweighted = c0 * (-di) + c1 * di
dd_unweighted = c0 * (-ddi) + c1 * ddi

d_weighted = d_unweighted * weights[None, :, :]
dd_weighted = dd_unweighted * weights[None, :, :]

# Flatten back to expected shape (signal_len, n_components * n_signals)
return (
interp_weighted.reshape(signal_len, n_components * n_signals),
d_weighted.reshape(signal_len, n_components * n_signals),
dd_weighted.reshape(signal_len, n_components * n_signals),
)

# Weighting matrix
weights_flat = weights.reshape(1, self.n_signals * self.n_components)
weights_tiled = np.tile(weights_flat, (self.signal_length, 1))

# Bias for indexing into reshaped components
# TODO break this up or describe what it does better
bias = np.kron(
np.arange(self.n_components) * (self.signal_length + 1),
np.ones((self.signal_length, self.n_signals), dtype=int),
).reshape(self.signal_length, self.n_components * self.n_signals)

# Handle boundary conditions for interpolation
components_bounded = np.vstack(
[components, components[-1, :]]
) # Duplicate last row (like MATLAB, not sure why)

# Compute floor indices
floor_indices = np.floor(fractional_indices).astype(int)

floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length)
floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length)

# Compute fractional part
fractional_floor_indices = fractional_indices - floor_indices

# Compute offset indices
offset_indices_1 = floor_indices_1 + bias
offset_indices_2 = floor_indices_2 + bias

# Extract values
# Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
comp_values_1 = components_bounded.flatten(order="F")[(offset_indices_1 - 1).ravel(order="F")].reshape(
self.signal_length, self.n_components * self.n_signals, order="F"
) # order = F uses FORTRAN, column major order
comp_values_2 = components_bounded.flatten(order="F")[(offset_indices_2 - 1).ravel(order="F")].reshape(
self.signal_length, self.n_components * self.n_signals, order="F"
)

# Interpolation
unweighted_stretched_comps = (
comp_values_1 * (1 - fractional_floor_indices) + comp_values_2 * fractional_floor_indices
)
stretched_components = unweighted_stretched_comps * weights_tiled # Apply weighting

# Compute first derivative
di = -fractional_indices * stretch_tiled
d_comps_unweighted = comp_values_1 * (-di) + comp_values_2 * di
d_stretched_components = d_comps_unweighted * weights_tiled

# Compute second derivative
ddi = -di * stretch_tiled * 2
dd_comps_unweighted = comp_values_1 * (-ddi) + comp_values_2 * ddi
dd_stretched_components = dd_comps_unweighted * weights_tiled

return stretched_components, d_stretched_components, dd_stretched_components

def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None):
"""
Computes the transformation matrix `stretch_transformed` for residuals,
Expand Down Expand Up @@ -560,7 +563,7 @@ def update_components(self):
Updates `components` using gradient-based optimization with adaptive step size.
"""
# Compute stretched components using the interpolation function
stretched_components, _, _ = self.apply_interpolation_matrix() # Discard the derivatives
stretched_components, _, _ = self.compute_stretched_components() # Discard the derivatives
# Compute reshaped_stretched_components and component_residuals
intermediate_reshaped = stretched_components.flatten(order="F").reshape(
(self.signal_length * self.n_signals, self.n_components), order="F"
Expand Down Expand Up @@ -648,7 +651,9 @@ def regularize_function(self, stretch=None):
if stretch is None:
stretch = self.stretch_

stretched_components, d_stretch_comps, dd_stretch_comps = self.apply_interpolation_matrix(stretch=stretch)
stretched_components, d_stretch_comps, dd_stretch_comps = self.compute_stretched_components(
stretch=stretch
)
intermediate = stretched_components.flatten(order="F").reshape(
(self.signal_length * self.n_signals, self.n_components), order="F"
)
Expand Down Expand Up @@ -751,8 +756,8 @@ def reconstruct_matrix(components, weights, stretch):
"""

signal_len = components.shape[0]
n_signals = weights.shape[1]
n_components = components.shape[1]
n_signals = weights.shape[1]

reconstructed_matrix = np.zeros((signal_len, n_signals))
sample_indices = np.arange(signal_len)
Expand Down