Skip to content

Commit

Permalink
hotfix for Zanetta
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasaarholt committed Jul 5, 2018
1 parent 04962c0 commit 20af8d4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
6 changes: 4 additions & 2 deletions hyperspy/_components/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ def _separate_fixed_and_free_expression_elements(self):

def _compute_expression_part(self, function):
model = self.model
signal_shape = model.channel_switches.shape
# TODO: Need a better way of calculating the shape than this...
signal_shape = model.axes_manager.signal_shape[::-1]
#signal_shape = model.channel_switches.shape
if model.convolved and self.convolved:
data = self._convolve(function(model.convolution_axis), model=model)
else:
axes = [ax.axis for ax in model.axes_manager.signal_axes]
mesh = np.meshgrid(*axes)
data = np.ones(signal_shape)*function(*mesh)
data = function(*mesh)*np.ones(signal_shape)
return data[np.where(model.channel_switches)]

def check_parameter_linearity(expr, name):
Expand Down
8 changes: 5 additions & 3 deletions hyperspy/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,12 +1220,14 @@ def _compute_constant_term(self):
model = self.model
if model.convolved and self.convolved:
convolved = self._convolve(self.constant_term, model=model)
data = convolved[np.where(model.channel_switches)]
data = convolved
else:
signal_shape = np.prod(model.channel_switches.shape)
# TODO: Need a better way of calculating the shape than this...
signal_shape = model.axes_manager.signal_shape[::-1]
#signal_shape = np.prod(model.channel_switches.shape)
not_convolved = self.constant_term * np.ones(signal_shape)
data = not_convolved
return data
return data[np.where(model.channel_switches)]

def _convolve(self, to_convolve, model=None):
'''Convolve component with model convolution axis
Expand Down
16 changes: 11 additions & 5 deletions hyperspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,13 +889,17 @@ def _linear_fitting(self, bounded):
number_of_free_parameters = len(self.p0)
assert number_of_free_parameters > 0, \
'Model does not contain any free components!'
# TODO: Need a better way of calculating the shape than this...
axes = [ax.axis for ax in self.axes_manager.signal_axes]
mesh = np.meshgrid(*axes)
mesh = [me[np.where(self.channel_switches)] for me in mesh]
channels_signal_shape = mesh[0].shape
#channels_signal_shape = tuple((np.prod(self.channel_switches.shape),))

signal_shape = tuple((np.prod(self.channel_switches.shape),))

comp_data = np.zeros((number_of_free_parameters,) + signal_shape)
comp_data = np.zeros((number_of_free_parameters,) + channels_signal_shape)
comp_data_constant_values = np.zeros(
(number_of_free_parameters,) + signal_shape)
fixed_comp_data = np.zeros(signal_shape)
(number_of_free_parameters,) + channels_signal_shape)
fixed_comp_data = np.zeros(channels_signal_shape)

def p0_index_from_component(component):
return self.free_parameters.index(component.free_parameters[0])
Expand Down Expand Up @@ -931,6 +935,8 @@ def _append_component(component):
else:
# No free parameters, so component is a fixed.
# Entire value of fixed components
print(fixed_comp_data.shape)
print(component._compute_component().shape)
fixed_comp_data[:] += component._compute_component()

def get_parent_twin(parameter):
Expand Down

0 comments on commit 20af8d4

Please sign in to comment.