Skip to content

Commit 9b0c509

Browse files
john-halloranJohn Halloran
andauthored
refactor: change get_objective_function into a static method and getter (#174)
* refactor: compute objective function in a static method and retrieve via getter * refactor: change get_objective_function into a static method and getter * fix: update location for _tests-on-pr.yml * fix: use expected name for requirements/tests.txt * fix: add cvxpy to requirements/conda.txt * fix: add matplotlib to requirements/conda.txt --------- Co-authored-by: John Halloran <jhalloran@oxy.edu>
1 parent d83619e commit 9b0c509

File tree

6 files changed

+176
-15
lines changed

6 files changed

+176
-15
lines changed

.github/workflows/tests-on-pr.yml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
name: Tests on PR
22

33
on:
4-
push:
5-
branches:
6-
- main
74
pull_request:
85
workflow_dispatch:
96

107
jobs:
11-
validate:
12-
uses: Billingegroup/release-scripts/.github/workflows/_tests-on-pr.yml@v0
8+
tests-on-pr:
9+
uses: scikit-package/release-scripts/.github/workflows/_tests-on-pr.yml@v0
1310
with:
1411
project: diffpy.snmf
1512
c_extension: false

news/declass-obj.rst

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
**Added:**
2+
3+
* Implement tests for ``compute_objective_function()``
4+
5+
**Changed:**
6+
7+
* Refactor ``get_objective_function()`` into a static method and getter
8+
9+
**Deprecated:**
10+
11+
* <news item>
12+
13+
**Removed:**
14+
15+
* <news item>
16+
17+
**Fixed:**
18+
19+
* <news item>
20+
21+
**Security:**
22+
23+
* <news item>

requirements/conda.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
numpy
2+
matplotlib
23
scipy
4+
cvxpy
35
diffpy.utils
46
numdifftools
File renamed without changes.

src/diffpy/snmf/snmf_class.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -359,16 +359,30 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
359359
return residuals
360360

361361
def get_objective_function(self, residuals=None, stretch=None):
362-
if residuals is None:
363-
residuals = self.residuals
364-
if stretch is None:
365-
stretch = self.stretch_
366-
residual_term = 0.5 * np.linalg.norm(residuals, "fro") ** 2
367-
regularization_term = 0.5 * self.rho * np.linalg.norm(self._spline_smooth_operator @ stretch.T, "fro") ** 2
368-
sparsity_term = self.eta * np.sum(np.sqrt(self.components_)) # Square root penalty
369-
# Final objective function value
370-
function = residual_term + regularization_term + sparsity_term
371-
return function
362+
"""
363+
Return the objective value, passing stored attributes or overrides
364+
to _compute_objective_function().
365+
366+
Parameters
367+
----------
368+
residuals : ndarray, optional
369+
Residual matrix to use instead of self.residuals.
370+
stretch : ndarray, optional
371+
Stretch matrix to use instead of self.stretch_.
372+
373+
Returns
374+
-------
375+
float
376+
Current objective function value.
377+
"""
378+
return SNMFOptimizer._compute_objective_function(
379+
components=self.components_,
380+
residuals=self.residuals if residuals is None else residuals,
381+
stretch=self.stretch_ if stretch is None else stretch,
382+
rho=self.rho,
383+
eta=self.eta,
384+
spline_smooth_operator=self._spline_smooth_operator,
385+
)
372386

373387
def compute_stretched_components(self, components=None, weights=None, stretch=None):
374388
"""
@@ -702,6 +716,59 @@ def objective(stretch_vec):
702716
# Update stretch with the optimized values
703717
self.stretch_ = result.x.reshape(self.stretch_.shape)
704718

719+
@staticmethod
720+
def _compute_objective_function(components, residuals, stretch, rho, eta, spline_smooth_operator):
721+
r"""
722+
Computes the objective function used in stretched non-negative matrix factorization.
723+
724+
Parameters
725+
----------
726+
components : ndarray
727+
Non-negative matrix of component signals :math:`X`.
728+
residuals : ndarray
729+
Difference between reconstructed and observed data.
730+
stretch : ndarray
731+
Stretching factors :math:`A` applied to each component across samples.
732+
rho : float
733+
Regularization parameter enforcing smooth variation in :math:`A`.
734+
eta : float
735+
Sparsity-promoting regularization parameter applied to :math:`X`.
736+
spline_smooth_operator : ndarray
737+
Linear operator :math:`L` penalizing non-smooth changes in :math:`A`.
738+
739+
Returns
740+
-------
741+
float
742+
Value of the stretched-NMF objective function.
743+
744+
Notes
745+
-----
746+
The stretched-NMF objective function :math:`J` is
747+
748+
.. math::
749+
750+
J(X, Y, A) =
751+
\tfrac{1}{2} \lVert Z - Y\,S(A)X \rVert_F^2
752+
+ \tfrac{\rho}{2} \lVert L A \rVert_F^2
753+
+ \eta \sum_{i,j} \sqrt{X_{ij}} \,,
754+
755+
where :math:`Z` is the data matrix, :math:`Y` contains the non-negative
756+
weights, :math:`S(A)` denotes the spline-interpolated stretching operator,
757+
and :math:`\lVert \cdot \rVert_F` is the Frobenius norm.
758+
759+
Special cases
760+
-------------
761+
- :math:`\rho = 0` — no smoothness regularization on stretching factors.
762+
- :math:`\eta = 0` — no sparsity promotion on components.
763+
- :math:`\rho = \eta = 0` — reduces to the classical NMF least-squares
764+
objective :math:`\tfrac{1}{2} \lVert Z - YX \rVert_F^2`.
765+
766+
"""
767+
residual_term = 0.5 * np.linalg.norm(residuals, "fro") ** 2
768+
regularization_term = 0.5 * rho * np.linalg.norm(spline_smooth_operator @ stretch.T, "fro") ** 2
769+
sparsity_term = eta * np.sum(np.sqrt(components))
770+
return residual_term + regularization_term + sparsity_term
771+
705772

706773
def cubic_largest_real_root(p, q):
707774
"""

tests/test_snmf_optimizer.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,75 @@ def test_final_objective_below_threshold(inputs):
4545
# Basic sanity check and the actual assertion
4646
assert np.isfinite(model.objective_function)
4747
assert model.objective_function < 5e6
48+
49+
50+
@pytest.mark.parametrize(
51+
"inputs, expected",
52+
# inputs tuple: (components, residuals, stretch, rho, eta, spline smoothness operator)
53+
[
54+
# Case 0: No smoothness or sparsity penalty, reduces to standard NMF objective
55+
# residual Frobenius norm^2 = 3^2 + 4^2 = 25 -> 0.5 * 25 = 12.5
56+
(
57+
(
58+
np.array([[0.0, 0.0], [3.0, 4.0]]),
59+
np.array([[0.0, 0.0], [3.0, 4.0]]),
60+
np.ones((2, 2)),
61+
0.0,
62+
0.0,
63+
np.zeros((2, 2)),
64+
),
65+
12.5,
66+
),
67+
# Case 1: rho = 0, sparsity penalty only
68+
# sqrt components sum = 1 + 2 + 3 + 4 = 10 -> eta * 10 = 5
69+
# residual term remains 12.5 -> total = 17.5
70+
(
71+
(
72+
np.array([[1.0, 4.0], [9.0, 16.0]]),
73+
np.array([[3.0, 4.0], [0.0, 0.0]]),
74+
np.ones((2, 2)),
75+
0.0,
76+
0.5,
77+
np.zeros((2, 2)),
78+
),
79+
17.5,
80+
),
81+
# Case 2: eta = 0, smoothness penalty only
82+
# residual = 12.5, smoothing = 0.5 * 1 * 1 = 0.5 -> total = 13.0
83+
(
84+
(
85+
np.array([[1.0, 2.0], [3.0, 4.0]]),
86+
np.array([[3.0, 4.0], [0.0, 0.0]]),
87+
np.array([[1.0, 2.0]]),
88+
1.0,
89+
0.0,
90+
np.array([[1.0, -1.0]]),
91+
),
92+
13.0,
93+
),
94+
# Case 3: penalty for smoothness and sparsity
95+
# residual = 2.5, sparsity = 1.5, smoothing = 9 -> total = 13.0
96+
(
97+
(
98+
np.array([[1.0, 4.0]]),
99+
np.array([[1.0, 2.0]]),
100+
np.array([[1.0, 4.0]]),
101+
2.0,
102+
0.5,
103+
np.array([[3.0, 0.0]]),
104+
),
105+
13.0,
106+
),
107+
],
108+
)
109+
def test_compute_objective_function(inputs, expected):
110+
components, residuals, stretch, rho, eta, operator = inputs
111+
result = SNMFOptimizer._compute_objective_function(
112+
components=components,
113+
residuals=residuals,
114+
stretch=stretch,
115+
rho=rho,
116+
eta=eta,
117+
spline_smooth_operator=operator,
118+
)
119+
assert np.isclose(result, expected)

0 commit comments

Comments
 (0)