Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion diffpy/snmf/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,42 @@ def update_weights(components, data_input, method=None):
for i, component in enumerate(components):
stretched_components[:, i] = component.apply_stretch(signal)[0]
if method == 'align':
weights = lsqnonneg(stretched_components, data_input[:,signal])
weights = lsqnonneg(stretched_components, data_input[:, signal])
else:
weights = get_weights(stretched_components.T @ stretched_components,
-stretched_components.T @ data_input[:, signal], 0, 1)
weight_matrix[:, signal] = weights
return weight_matrix


def reconstruct_signal(components, signal_idx):
"""Reconstructs a specific signal from its weighted and stretched components.

Calculates the linear combination of stretched components where each term is the stretched component multiplied
by its weight factor.

Parameters
----------
components: tuple of ComponentSignal objects
The tuple containing the ComponentSignal objects
signal_idx: int
The index of the specific signal in the input data to be reconstructed

Returns
-------
1d array like
The reconstruction of a signal from calculated weights, stretching factors, and iq values.

"""
signal_length = len(components[0].grid)
reconstruction = np.zeros(signal_length)
for component in components:
stretched = component.apply_stretch(signal_idx)[0]
stretched_and_weighted = component.apply_weight(signal_idx, stretched)
reconstruction += stretched_and_weighted
return reconstruction


def initialize_arrays(number_of_components, number_of_moments, signal_length):
"""Generates the initial guesses for the weight, stretching, and component matrices

Expand Down
16 changes: 15 additions & 1 deletion diffpy/snmf/tests/test_subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from diffpy.snmf.containers import ComponentSignal
from diffpy.snmf.subroutines import objective_function, get_stretched_component, reconstruct_data, get_residual_matrix, \
update_weights_matrix, initialize_arrays, lift_data, initialize_components, construct_stretching_matrix, \
construct_component_matrix, construct_weight_matrix, update_weights
construct_component_matrix, construct_weight_matrix, update_weights, reconstruct_signal

to = [
([[[1, 2], [3, 4]], [[5, 6], [7, 8]], 1e11, [[1, 2], [3, 4]], [[1, 2], [3, 4]], 1], 2.574e14),
Expand Down Expand Up @@ -252,3 +252,17 @@ def test_construct_weight_matrix(tcwm):
def test_update_weights(tuw):
actual = update_weights(tuw[0], tuw[1], tuw[2])
assert np.shape(actual) == (len(tuw[0]), len(tuw[0][0].weights))

trs = [([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 1),
([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
ComponentSignal([0, .25, .5, .75, 1], 2, 2)], 0),
([ComponentSignal([0, .25, .5, .75, 1], 3, 0), ComponentSignal([0, .25, .5, .75, 1], 3, 1),
ComponentSignal([0, .25, .5, .75, 1], 3, 2)], 2),
# ([ComponentSignal([0, .25, .5, .75, 1], 2, 0), ComponentSignal([0, .25, .5, .75, 1], 2, 1),
# ComponentSignal([0, .25, .5, .75, 1], 2, 2)], -1),
]
@pytest.mark.parametrize('trs',trs)
def test_reconstruct_signal(trs):
actual = reconstruct_signal(trs[0], trs[1])
assert len(actual) == len(trs[0][0].grid)