diff --git a/diffpy/snmf/subroutines.py b/diffpy/snmf/subroutines.py index cdaa1035..58990b7b 100644 --- a/diffpy/snmf/subroutines.py +++ b/diffpy/snmf/subroutines.py @@ -171,7 +171,7 @@ def update_weights(components, data_input, method=None): for i, component in enumerate(components): stretched_components[:, i] = component.apply_stretch(signal)[0] if method == 'align': - weights = lsqnonneg(stretched_components, data_input[:,signal]) + weights = lsqnonneg(stretched_components, data_input[:, signal]) else: weights = get_weights(stretched_components.T @ stretched_components, -stretched_components.T @ data_input[:, signal], 0, 1) @@ -179,6 +179,34 @@ def update_weights(components, data_input, method=None): return weight_matrix +def reconstruct_signal(components, signal_idx): + """Reconstructs a specific signal from its weighted and stretched components. + + Calculates the linear combination of stretched components where each term is the stretched component multiplied + by its weight factor. + + Parameters + ---------- + components: tuple of ComponentSignal objects + The tuple containing the ComponentSignal objects + signal_idx: int + The index of the specific signal in the input data to be reconstructed + + Returns + ------- + 1d array like + The reconstruction of a signal from calculated weights, stretching factors, and iq values. + + """ + signal_length = len(components[0].grid) + reconstruction = np.zeros(signal_length) + for component in components: + stretched = component.apply_stretch(signal_idx)[0] + stretched_and_weighted = component.apply_weight(signal_idx, stretched) + reconstruction += stretched_and_weighted + return reconstruction + + def initialize_arrays(number_of_components, number_of_moments, signal_length): """Generates the initial guesses for the weight, stretching, and component matrices diff --git a/diffpy/snmf/tests/test_subroutines.py b/diffpy/snmf/tests/test_subroutines.py index 06e11f30..aa55ddea 100644 --- a/diffpy/snmf/tests/test_subroutines.py +++ b/diffpy/snmf/tests/test_subroutines.py @@ -3,7 +3,7 @@ from diffpy.snmf.containers import ComponentSignal from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \ update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \ - construct_component_matrix, construct_weight_matrix, update_weights + construct_component_matrix, construct_weight_matrix, update_weights, reconstruct_signal to = [ ([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14), @@ -252,3 +252,17 @@ def test_construct_weight_matrix(tcwm): def test_update_weights(tuw): actual = update_weights(tuw[0], tuw[1], tuw[2]) assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights)) + +trs = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 1), + ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 0), + ([ComponentSignal([0, .25, .5, .75, 1], 3, 0), ComponentSignal([0, .25, .5, .75, 1], 3, 1), + ComponentSignal([0, .25, .5, .75, 1], 3, 2)], 2), + # ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1), + # ComponentSignal([0, .25, .5, .75, 1], 2, 2)], -1), +] +@pytest.mark.parametrize('trs',trs) +def test_reconstruct_signal(trs): + actual = reconstruct_signal(trs[0], trs[1]) + assert len(actual) == len(trs[0][0].grid)