Skip to content

Commit

Permalink
Fixed the frameset line_spectra method with tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
canismarko committed Aug 15, 2019
1 parent 9afdcee commit ec45fef
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 63 deletions.
18 changes: 12 additions & 6 deletions tests/test_frameset.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,39 +443,45 @@ def test_edge_mask(self):
# Check that the new edge mask is a boolean array
self.assertEqual(fs.edge_mask().dtype, bool)

def test_line_spectra(self):
store = MockStore()
frames = np.multiply(*np.meshgrid(np.arange(0, 16), np.arange(0, 16)))
frames = np.broadcast_to(frames, (1, 4, 16, 16))
store.get_dataset = mock.MagicMock(return_value=frames)
store.intensities = frames
store.pixel_sizes = np.ones(shape=(*frames.shape[:2], 2))
frameset = self.create_frameset(store=store)
xy0, xy1 = ((1, 1), (1, 13))
result = frameset.line_spectra(xy0=xy0, xy1=xy1)
self.assertEqual(result.shape, (12, 4))

def test_spectrum(self):
store = MockStore()

# Prepare fake energy data
energies = np.linspace(8300, 8400, num=51)
store.energies = np.broadcast_to(energies, (10, 51))

# Prepare fake spectrum (absorbance) data
spectrum = np.sin((energies-8300)*4*np.pi/100)
frames = np.broadcast_to(spectrum, (10, 128, 128, 51))
frames = np.swapaxes(frames, 3, 1)
store.get_frames = mock.Mock(return_value=frames)
store.intensities = frames
fs = self.create_frameset(store=store)

# Check that the return value is correct
result = fs.spectrum()
np.testing.assert_equal(result.index, energies)
np.testing.assert_almost_equal(result.values, spectrum)

# Check that multiple spectra can be acquired simultaneously
result = fs.spectrum(index=slice(0, 2))
result = np.array([ser.values for ser in result])
spectras = np.array([fs.spectrum(index=0), fs.spectrum(index=0)])
np.testing.assert_equal(result, spectras)

# Check that the derivative is calculated correctly
derivative = 4*np.pi/100 * np.cos((energies-8300)*4*np.pi/100)
result = fs.spectrum(derivative=1)
np.testing.assert_almost_equal(result.values, derivative, decimal=3)

def test_nonenergy_spectrum(self):

"""If the frames aren't in energy order"""
store = MockStore()
# Prepare fake energy data
Expand Down
99 changes: 42 additions & 57 deletions xanespy/xanes_frameset.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,20 +809,20 @@ def spectra(self):
ODs = np.moveaxis(ODs, E_axis, -1)
spectra = np.reshape(ODs, (-1, self.num_energies))
return spectra

def line_spectra(self, xy0: Tuple[int, int], xy1: Tuple[int, int],
representation="optical_depths",
timestep=0, edge_filter=False, edge_filter_kw={}):
timestep=0, frame_filter=False, frame_filter_kw={}):
"""Return an array of spectra on a line between two points.
This is effectively nearest neighbor interpolation between two (x,
y) pairs on the frames.
Returns
-------
spectra : np.ndarray
An array of spectra, one for each point on the line.
A 2D array of spectra, one for each point on the line.
Parameters
----------
xy0 : 2-tuple
Expand All @@ -833,12 +833,12 @@ def line_spectra(self, xy0: Tuple[int, int], xy1: Tuple[int, int],
Which type of data to use for extracting line profiles.
timestep : int, optional
Which time step to use for extracting line profiles.
edge_filter : bool, optional
frame_filter : bool, str, optional
Whether to first apply an edge filter mask to the data
before calculating line profile.
edge_filter_kw : dict, optional
frame_filter_kw : dict, optional
Extra keyword arguments to pass to the edge_mask() method.
"""
xy0 = xycoord(*xy0)
xy1 = xycoord(*xy1)
Expand All @@ -851,22 +851,16 @@ def line_spectra(self, xy0: Tuple[int, int], xy1: Tuple[int, int],
x = np.linspace(px0.horizontal, px1.horizontal, length)
y = np.linspace(px0.vertical, px1.vertical, length)
# Check if an edge mask is needed
if edge_filter:
mask = self.edge_mask(**edge_filter_kw)
else:
mask = np.zeros(shape=self.frame_shape())
mask = self.frame_mask(mask_type=frame_filter, **frame_filter_kw)
# Extract the values along the line
with self.store(mode='r') as store:
frames = store.get_dataset(representation)[timestep]
frames = np.ma.array(frames, mask=mask)
if frames.ndim > 2:
spectra = frames[:, y.astype(np.int), x.astype(np.int)]
spectra = np.swapaxes(spectra, 0, 1)
else:
spectra = frames[y.astype(np.int), x.astype(np.int)]
frames = np.ma.array(frames, mask=np.broadcast_to(mask, frames.shape))
spectra = frames[:, y.astype(np.int), x.astype(np.int)]
spectra = np.swapaxes(spectra, 0, 1)
# And we're done
return spectra

def fitting_param_names(self, representation="fit_parameters"):
"""Get the human-readable names of the fit parameters."""
with self.store() as store:
Expand Down Expand Up @@ -1024,10 +1018,11 @@ def edge_mask(self, *args, **kwargs):
return self.frame_mask(mask_type='edge', *args, **kwargs)

@functools.lru_cache()
def frame_mask(self, mask_type=None, sensitivity: float = 1, min_size: int = 0, frame_idx='mean') -> np.ndarray:
def frame_mask(self, mask_type=None, sensitivity: float = 1,
min_size: int = 0, frame_idx='mean') -> np.ndarray:
"""Calculate a mask for what is likely active material based on either
the edge or the contrast of the first time index.
Parameters
----------
mask_type : str, bool
Expand All @@ -1046,36 +1041,35 @@ def frame_mask(self, mask_type=None, sensitivity: float = 1, min_size: int = 0,
xp.xanes_math.contrast_mask(). Allows User to create a
contrast map from an individual (timestep - energy) rather
than the mean image.
"""
with self.store() as store:
mask_is_possible = store.has_dataset('optical_depths') and self.edge is not None
if mask_is_possible:
# Check for complex values and convert to optical_depths only
ODs = np.real(store.optical_depths[()])

if mask_type == 'contrast':
# Create mask based on contrast maps
mask = xm.contrast_mask(frames=ODs,
sensitivity=sensitivity,
min_size=min_size,
frame_idx=frame_idx)



elif mask_type == 'edge':
# Create mask based on edge jump
mask = self.edge.mask(frames=ODs,
energies=store.energies,
sensitivity=sensitivity,
min_size=min_size)

elif not mask_type:
# Create blank mask array
mask = np.zeros(shape=store.intensities.shape[-2:], dtype='bool')
else:
# Show warning if all of these fail
warnings.warn('Incorrect User input or invalid frames() dimensions')

else:
# Store has no optical_depth data so just return a blank array
mask = np.zeros(shape=store.intensities.shape[-2:], dtype='bool')
Expand Down Expand Up @@ -1173,18 +1167,18 @@ def fit_kedge(self, quiet=False, ncore=None):
store.whiteline_fit.attrs['frame_source'] = 'optical_depths'
except AttributeError:
pass

def fit_spectra(self, func, p0=None, pnames=None, name=None,
frame_filter='edge', frame_filter_kw: Mapping={},
nonnegative=False, component='real',
representation='optical_depths', dtype=None,
quiet=False, ncore=None):
"""Fit a given function to the spectra at each pixel.
The fit parameters will be saved in the HDF dataset
"{name}_params" based on the parameter ``name``. RMS residuals
for each pixel will be saved in "{name}_residuals".
Parameters
----------
func : callable, optional
Expand Down Expand Up @@ -1234,21 +1228,21 @@ def fit_spectra(self, func, p0=None, pnames=None, name=None,
ncore : int, optional
How many processes to use in the pool. See
:func:`~xanespy.utilities.nproc` for more details.
Returns
-------
params : numpy.ndarray
The fit parameters (as frames) for each source.
residuals : numpy.ndarray
Residual error after fitting, as maps.
Raises
------
GuessParamsError
If the *func* callable doesn't have a *guess_params*
method. This can be solved by either using a callable with a
*guess_params()* method, or explicitly supplying *p0*.
"""
# Get data
with self.store() as store:
Expand All @@ -1267,11 +1261,8 @@ def fit_spectra(self, func, p0=None, pnames=None, name=None,
raise exceptions.GuessParamsError(
"Fitting function {} has no ``guess_params`` method. "
"Initial parameters *p0* is required.".format(func)) from None
print(p0.shape)
p0 = p0.reshape((self.num_timesteps, *self.frame_shape(), -1))
print(p0.shape)
p0 = np.moveaxis(p0, -1, 1)
print(p0.shape)
# Make sure p0 is the right shape
if p0.ndim < 4:
p0 = prepare_p0(p0, self.frame_shape(), self.num_timesteps)
Expand All @@ -1288,7 +1279,6 @@ def fit_spectra(self, func, p0=None, pnames=None, name=None,
spectra = spectra.astype(dtype)
p0 = p0.astype(dtype)
# Perform the actual fitting
print(spectra.shape, p0.shape)
params, residuals = fit_spectra(observations=spectra,
func=func, p0=p0,
nonnegative=nonnegative,
Expand Down Expand Up @@ -1526,13 +1516,13 @@ def plot_signal_map(self, ax=None, signals_idx=None, interpolation=None):
ax.set_xlabel(px_unit)
ax.set_ylabel(px_unit)
ax.set_title("Composite of signals {}".format(signals_idx))

def calculate_signals(self, n_components=2, method="nmf",
frame_source='optical_depths',
frame_filter='edge', frame_filter_kw: Mapping={}):
"""Extract signals and assign each pixel to a group, then save the
resulting RGB cluster map.
Parameters
==========
n_components : int, optional
Expand All @@ -1550,7 +1540,7 @@ def calculate_signals(self, n_components=2, method="nmf",
Additional arguments to be used for producing an frame_mask.
See :meth:`~xanespy.xanes_frameset.XanesFrameset.frame_mask`
for possible values.
Returns
=======
signals : np.
Expand All @@ -1570,7 +1560,6 @@ def calculate_signals(self, n_components=2, method="nmf",
spectra = np.moveaxis(As, 0, 1)
# Get the edge mask so only active material is included
dummy_mask = np.ones(frame_shape, dtype=np.bool)

# See if we need a mask
if frame_filter:
# Clear caches to make sure we don't use stale mask data
Expand All @@ -1585,20 +1574,16 @@ def calculate_signals(self, n_components=2, method="nmf",
mask = dummy_mask
else:
log.debug("Using edge mask for signal extraction")

else:
log.debug("No edge mask for signal extraction")
mask = dummy_mask

# Separate the data into signals
if method.lower() == 'nmf':
signals, weights = xm.extract_signals_nmf(
spectra=spectra[mask.flatten()], n_components=n_components)

elif method.lower() == 'pca':
signals, weights = xm.extract_signals_pca(
spectra=spectra[mask.flatten()], n_components=n_components)

else:
raise ValueError('Recieved Invalid Method : {method}'.format(method=method))
# Reshape weights into frame dimensions
Expand Down Expand Up @@ -1796,30 +1781,30 @@ def subtract_surroundings(self, sensitivity: float=1.):
bg = broadcast_reverse(bg, store.optical_depths.shape[1:])
# Save the resultant data to disk
store.optical_depths[timestep] = store.optical_depths[timestep] - bg

@functools.lru_cache(maxsize=64)
def extent(self, representation='intensities', idx=...):
"""Determine physical dimensions for axes values.
If an index is given, it will first be applied to the frames
array. For any remaining dimensions besides the last two, the
median will be taken. For an array of extents for each frame,
use the ``extent_array`` method.
Arguments
---------
representation : str, optional
Name for which dataset to use.
idx : int, optional
Index for choosing a frame. Any valid numpy index is
allowed, eg. ``...`` (default) uses all frame.
Returns
-------
extent : tuple
The spatial extent for the frame with order specified by
``utilities.Extent``
"""
pixel_size = self.pixel_size(representation=representation, timestep=idx)
imshape = self.frame_shape(representation)
Expand All @@ -1833,12 +1818,12 @@ def extent(self, representation='intensities', idx=...):
extent = Extent(left=left, right=right,
bottom=bottom, top=top)
return extent

def plot_frame(self, idx, ax=None, cmap="bone",
representation="optical_depths", component='modulus', *args,
**kwargs):
"""Plot the frame with given index as an image.
Parameters
==========
idx : 2-tuple(int)
Expand All @@ -1853,26 +1838,26 @@ def plot_frame(self, idx, ax=None, cmap="bone",
for plotting. Only applicable to complex-valued data.
*args, **kwargs :
Passed to the matplotlib ``imshow`` function.
Returns
=======
artist
The imshow ImageArtist.
"""
if len(idx) != 2:
raise ValueError("Index must be a 2-tuple in (timestep, energy) order.")
return self.plot_mean_frame(ax=ax, component=component,
representation=representation, cmap=cmap, timeidx=idx, *args,
**kwargs)
return artist

@property
def num_timesteps(self):
with self.store() as store:
val = store.optical_depths.shape[0]
return val

@property
def timestep_names(self):
with self.store() as store:
Expand Down

0 comments on commit ec45fef

Please sign in to comment.