Skip to content

Commit

Permalink
Apply ShutterCount Normalisation Correction when ShutterCounts Availa…
Browse files Browse the repository at this point in the history
…ble (#2214)
  • Loading branch information
samtygier-stfc authored Jun 19, 2024
2 parents 86644da + 10749f1 commit b866326
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#2094: Apply ShutterCount Normalisation within Spectrum Viewer
74 changes: 61 additions & 13 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mantidimaging.core.data import ImageStack
from mantidimaging.core.io.csv_output import CSVOutput
from mantidimaging.core.io import saver
from mantidimaging.core.io.instrument_log import LogColumn
from mantidimaging.core.io.instrument_log import LogColumn, ShutterCountColumn
from mantidimaging.core.utility.sensible_roi import SensibleROI
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.utility.unit_conversion import UnitConversion
Expand Down Expand Up @@ -208,7 +208,10 @@ def normalise_issue(self) -> str:
return "Stack shapes must match"
return ""

def get_spectrum(self, roi: str | SensibleROI, mode: SpecType) -> np.ndarray:
def get_spectrum(self,
roi: str | SensibleROI,
mode: SpecType,
normalise_with_shuttercount: bool = False) -> np.ndarray:
if self._stack is None:
return np.array([])

Expand All @@ -228,12 +231,44 @@ def get_spectrum(self, roi: str | SensibleROI, mode: SpecType) -> np.ndarray:
return np.array([])
roi_spectrum = self.get_stack_spectrum(self._stack, roi)
roi_norm_spectrum = self.get_stack_spectrum(self._normalise_stack, roi)
return np.divide(roi_spectrum, roi_norm_spectrum, out=np.zeros_like(roi_spectrum), where=roi_norm_spectrum != 0)
spectrum = np.divide(roi_spectrum,
roi_norm_spectrum,
out=np.zeros_like(roi_spectrum),
where=roi_norm_spectrum != 0)
if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
spectrum = spectrum / average_shuttercount
return spectrum

def get_shuttercount_normalised_correction_parameter(self) -> float:
"""
Normalize ShutterCount values and return only the initial normalized value.
We normalise all values to future proof against normalizing against all available ShutterCount
values should we find the initial value is not sufficient.
"""
sample_shuttercount = self.get_stack_shuttercounts(self._stack)
open_shuttercount = self.get_stack_shuttercounts(self._normalise_stack)
if sample_shuttercount is None or open_shuttercount is None:
return 1.0 # No shutter count data available so no correction needed
normalised_shuttercounts = sample_shuttercount / open_shuttercount
return normalised_shuttercounts[0]

def get_stack_shuttercounts(self, stack: ImageStack | None) -> np.ndarray | None:
if stack is None or stack.shutter_count_file is None:
return None
try:
shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
except KeyError:
return None
return np.array(shutter_counts)

def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
def get_transmission_error_standard_dev(self,
roi: SensibleROI,
normalise_with_shuttercount: bool = False) -> np.ndarray:
"""
Get the transmission error standard deviation for a given roi
@param: roi_name The roi name
@param: normalised Default is True. If False, the normalization is not applied
@return: a numpy array representing the standard deviation of the transmission
"""
if self._stack is None or self._normalise_stack is None:
Expand All @@ -242,19 +277,30 @@ def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
sample = self._stack.data[:, top:bottom, left:right]
open_beam = self._normalise_stack.data[:, top:bottom, left:right]
safe_divide = np.divide(sample, open_beam, out=np.zeros_like(sample), where=open_beam != 0)
if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
safe_divide = safe_divide / average_shuttercount

return np.std(safe_divide, axis=(1, 2))

def get_transmission_error_propagated(self, roi: SensibleROI) -> np.ndarray:
def get_transmission_error_propagated(self,
roi: SensibleROI,
normalise_with_shuttercount: bool = False) -> np.ndarray:
"""
Get the transmission error using propagation of sqrt(n) error for a given roi
@param: roi_name The roi name
@param: normalised Default is True. If False, the normalization is not applied
@return: a numpy array representing the error of the transmission
"""
if self._stack is None or self._normalise_stack is None:
raise RuntimeError("Sample and open beam must be selected")
sample = self.get_stack_spectrum_summed(self._stack, roi)
open_beam = self.get_stack_spectrum_summed(self._normalise_stack, roi)
error = np.sqrt(sample / open_beam**2 + sample**2 / open_beam**3)

if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
error = error / average_shuttercount
return error

def get_image_shape(self) -> tuple[int, int]:
Expand All @@ -271,7 +317,7 @@ def has_stack(self) -> bool:
"""
return self._stack is not None

def save_csv(self, path: Path, normalized: bool) -> None:
def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: bool = False) -> None:
"""
Iterates over all ROIs and saves the spectrum for each one to a CSV file.
Expand All @@ -291,8 +337,9 @@ def save_csv(self, path: Path, normalized: bool) -> None:
csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV")

for roi_name in self.get_list_of_roi_names():
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE), "Counts")
if normalized:
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE, normalise_with_shuttercount),
"Counts")
if normalise:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts")
Expand All @@ -312,7 +359,7 @@ def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None:
"""
self.save_rits_roi(path, error_mode, self.get_roi(ROI_RITS))

def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) -> None:
def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI, normalise: bool = False) -> None:
"""
Saves the spectrum for one ROI to a RITS file.
Expand All @@ -329,12 +376,12 @@ def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) ->
raise ValueError("No Time of Flights for sample. Make sure spectra log has been loaded")

tof *= 1e6 # RITS expects ToF in μs
transmission = self.get_spectrum(roi, SpecType.SAMPLE_NORMED)
transmission = self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise)

if error_mode == ErrorMode.STANDARD_DEVIATION:
transmission_error = self.get_transmission_error_standard_dev(roi)
transmission_error = self.get_transmission_error_standard_dev(roi, normalise)
elif error_mode == ErrorMode.PROPAGATED:
transmission_error = self.get_transmission_error_propagated(roi)
transmission_error = self.get_transmission_error_propagated(roi, normalise)
else:
raise ValueError("Invalid error_mode given")

Expand Down Expand Up @@ -367,6 +414,7 @@ def save_rits_images(self,
error_mode: ErrorMode,
bin_size: int,
step: int,
normalise: bool = False,
progress: Progress | None = None) -> None:
"""
Saves multiple Region of Interest (ROI) images to RITS files.
Expand Down Expand Up @@ -404,7 +452,7 @@ def save_rits_images(self,
sub_right = min(sub_left + bin_size, right)
sub_roi = SensibleROI.from_list([sub_left, sub_top, sub_right, sub_bottom])
path = directory / f"rits_image_{x}_{y}.dat"
self.save_rits_roi(path, error_mode, sub_roi)
self.save_rits_roi(path, error_mode, sub_roi, normalise)
progress.update()
if sub_right == right:
break
Expand Down
10 changes: 7 additions & 3 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def handle_export_csv(self) -> None:
if path.suffix != ".csv":
path = path.with_suffix(".csv")

self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED)
self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, self.view.shuttercount_norm_enabled())

def handle_rits_export(self) -> None:
"""
Expand All @@ -247,8 +247,12 @@ def handle_rits_export(self) -> None:
if path is None:
LOG.debug("No path selected, aborting export")
return
run_function = partial(self.model.save_rits_images, path, error_mode, self.view.bin_size,
self.view.bin_step)
run_function = partial(self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled())

start_async_task_view(self.view, run_function, self._async_save_done)

Expand Down
111 changes: 109 additions & 2 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy.testing as npt
from parameterized import parameterized

from mantidimaging.core.io.instrument_log import InstrumentLog
from mantidimaging.core.io.instrument_log import InstrumentLog, ShutterCount, ShutterCountColumn
from mantidimaging.gui.windows.spectrum_viewer import SpectrumViewerWindowPresenter, SpectrumViewerWindowModel
from mantidimaging.gui.windows.spectrum_viewer.model import SpecType, ErrorMode
from mantidimaging.test_helpers.unit_test_helper import generate_images
Expand All @@ -36,17 +36,31 @@ def setUp(self) -> None:
self.presenter = mock.create_autospec(SpectrumViewerWindowPresenter)
self.model = SpectrumViewerWindowModel(self.presenter)

def _set_sample_stack(self, with_tof=False):
def _set_sample_stack(self, with_tof=False, with_shuttercount=False):
spectrum = np.arange(0, 10)
stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
if with_tof:
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="")
mock_inst_log.get_column.return_value = np.arange(0, 10) * 0.1
stack.log_file = mock_inst_log
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(5, 15)
stack._shutter_count_file = mock_shuttercounts
self.model.set_stack(stack)
self.model.set_new_roi("roi")
return stack, spectrum

def _set_normalise_stack(self, with_shuttercount=False):
spectrum = np.arange(10, 20)
normalise_stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
self.model.set_normalise_stack(normalise_stack)
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(10, 20)
normalise_stack._shutter_count_file = mock_shuttercounts
return normalise_stack

def _make_mock_path_stream(self):
mock_stream = CloseCheckStream()
mock_path = mock.create_autospec(Path)
Expand Down Expand Up @@ -521,3 +535,96 @@ def test_save_rits_correct_transmision(self, mock_save_rits_roi):
transmission = call[0][2]
expected_transmission = spectrum * expected_mean
npt.assert_array_equal(expected_transmission, transmission)

def test_get_stack_shuttercounts_returns_none_if_no_stack(self):
self.assertEqual(self.model.get_stack_shuttercounts(stack=None), None)

def test_get_stack_shuttercounts_returns_shutter_count_if_stack(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(stack),
stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))
self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(normalise_stack),
normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))

def test_get_shuttercount_normalised_correction_parameter_returns_one_if_no_stack(self):
with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.return_value = None
self.assertEqual(self.model.get_shuttercount_normalised_correction_parameter(), 1.0)

def test_get_shuttercount_normalised_correction_parameter_with_none_values(self):
self.model.get_stack_shuttercounts = mock.MagicMock(return_value=None)
expected_result = 1.0
result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_shuttercount_normalised_correction_parameter_with_values(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

stack_column = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
normalise_stack_column = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)

with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.side_effect = [stack_column, normalise_stack_column]
expected_result = stack_column[0] / normalise_stack_column[0]

result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_transmission_error_standard_dev(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
left, top, right, bottom = roi
sample = stack.data[:, top:bottom, left:right]
open = normalise_stack.data[:, top:bottom, left:right]

expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts
expected = np.std(expected, axis=(1, 2))

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_standard_dev(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_standard_dev_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_standard_dev("roi")

def test_get_transmission_error_propogated(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
sample = self.model.get_stack_spectrum_summed(stack, roi)
open = self.model.get_stack_spectrum_summed(normalise_stack, roi)

expected = np.sqrt(sample / open**2 + sample**2 / open**3) / average_shutter_counts

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_propagated(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_propogated_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_propagated("roi")
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,14 @@ def test_handle_export_csv_none(self, mock_save_csv: mock.Mock):
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_csv")
def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock):
self.view.get_csv_filename = mock.Mock(return_value=Path(path_name))
self.view.shuttercount_norm_enabled.return_value = False

self.presenter.model.set_stack(generate_images())

self.presenter.handle_export_csv()

self.view.get_csv_filename.assert_called_once()
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False)
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False, False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
Expand Down
4 changes: 4 additions & 0 deletions mantidimaging/gui/windows/spectrum_viewer/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class SpectrumViewerWindowView(BaseMainWindowView):
normaliseStackSelector: DatasetSelectorWidgetView

normaliseCheckBox: QCheckBox
normalise_ShutterCount_CheckBox: QCheckBox
imageLayout: QVBoxLayout
exportButton: QPushButton
exportTabs: QTabWidget
Expand Down Expand Up @@ -351,6 +352,9 @@ def display_normalise_error(self):
def normalisation_enabled(self):
return self.normaliseCheckBox.isChecked()

def shuttercount_norm_enabled(self) -> bool:
return self.normalise_ShutterCount_CheckBox.isChecked()

def set_new_roi(self) -> None:
"""
Set a new ROI on the image
Expand Down

0 comments on commit b866326

Please sign in to comment.