diff --git a/diffpy/snmf/subroutines.py b/diffpy/snmf/subroutines.py index 17b32918..6b55728e 100644 --- a/diffpy/snmf/subroutines.py +++ b/diffpy/snmf/subroutines.py @@ -1,9 +1,35 @@ import numpy as np from diffpy.snmf.optimizers import get_weights from diffpy.snmf.factorizers import lsqnonneg +from diffpy.snmf.containers import ComponentSignal import numdifftools +def initialize_components(number_of_components, number_of_signals, grid_vector): + """Initializes ComponentSignals for each of the components in the decomposition + + Parameters + ---------- + number_of_components: int + The number of component signals in the NMF decomposition + number_of_signals: int + grid_vector: 1d array + The grid of the user provided signals. + + Returns + ------- + tuple of ComponentSignal objects + The tuple containing `number_of_components` of initialized ComponentSignal objects. + """ + if number_of_components <= 0: + raise ValueError(f"Number of components = {number_of_components}. Number_of_components must be >= 1.") + components = list() + for component in range(number_of_components): + component = ComponentSignal(grid_vector,number_of_signals,component) + components.append(component) + return tuple(components) + + def lift_data(data_input, lift=1): """Lifts values of data_input diff --git a/diffpy/snmf/tests/test_subroutines.py b/diffpy/snmf/tests/test_subroutines.py index 91f6f7d1..f9bad750 100644 --- a/diffpy/snmf/tests/test_subroutines.py +++ b/diffpy/snmf/tests/test_subroutines.py @@ -1,7 +1,7 @@ import pytest import numpy as np from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \ - update_weights_matrix, initialize_arrays, lift_data + update_weights_matrix, initialize_arrays, lift_data, initialize_components to = [ ([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14), @@ -144,8 +144,20 @@ def test_reconstruct_data(trd): (([[[1.5, 2], [10.5, 1], [0.5, 2]], 1]), ([[2, 2.5], [11, 1.5], [1, 2.5]])), (([[[-10, -10.5], [-12.2, -12.2], [0, 0]], 1]), ([[2.2, 1.7], [0, 0], [12.2, 12.2]])), ] + + @pytest.mark.parametrize('tld', tld) def test_lift_data(tld): actual = lift_data(tld[0][0], tld[0][1]) expected = tld[1] np.testing.assert_allclose(actual, expected) + +tcc = [(2, 3,[0, .5, 1, 1.5]), # Regular usage + #(0, 3,[0, .5, 1, 1.5]), # Zero components raise an exception. Not tested + ] +@pytest.mark.parametrize('tcc', tcc) +def test_initialize_components(tcc): + actual = initialize_components(tcc[0], tcc[1], tcc[2]) + assert len(actual) == tcc[0] + assert len(actual[0].weights) == tcc[1] + assert (actual[0].grid == np.array(tcc[2])).all()