Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply ShutterCount Normalisation Correction when ShutterCounts Available #2214

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#2094: Apply ShutterCount Normalisation within Spectrum Viewer
74 changes: 61 additions & 13 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mantidimaging.core.data import ImageStack
from mantidimaging.core.io.csv_output import CSVOutput
from mantidimaging.core.io import saver
from mantidimaging.core.io.instrument_log import LogColumn
from mantidimaging.core.io.instrument_log import LogColumn, ShutterCountColumn
from mantidimaging.core.utility.sensible_roi import SensibleROI
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.utility.unit_conversion import UnitConversion
Expand Down Expand Up @@ -208,7 +208,10 @@ def normalise_issue(self) -> str:
return "Stack shapes must match"
return ""

def get_spectrum(self, roi: str | SensibleROI, mode: SpecType) -> np.ndarray:
def get_spectrum(self,
roi: str | SensibleROI,
mode: SpecType,
normalise_with_shuttercount: bool = False) -> np.ndarray:
if self._stack is None:
return np.array([])

Expand All @@ -228,12 +231,44 @@ def get_spectrum(self, roi: str | SensibleROI, mode: SpecType) -> np.ndarray:
return np.array([])
roi_spectrum = self.get_stack_spectrum(self._stack, roi)
roi_norm_spectrum = self.get_stack_spectrum(self._normalise_stack, roi)
return np.divide(roi_spectrum, roi_norm_spectrum, out=np.zeros_like(roi_spectrum), where=roi_norm_spectrum != 0)
spectrum = np.divide(roi_spectrum,
roi_norm_spectrum,
out=np.zeros_like(roi_spectrum),
where=roi_norm_spectrum != 0)
if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
spectrum = spectrum / average_shuttercount
return spectrum

def get_shuttercount_normalised_correction_parameter(self) -> float:
"""
Normalize ShutterCount values and return only the initial normalized value.
We normalise all values to future proof against normalizing against all available ShutterCount
values should we find the initial value is not sufficient.
"""
sample_shuttercount = self.get_stack_shuttercounts(self._stack)
open_shuttercount = self.get_stack_shuttercounts(self._normalise_stack)
if sample_shuttercount is None or open_shuttercount is None:
return 1.0 # No shutter count data available so no correction needed
normalised_shuttercounts = sample_shuttercount / open_shuttercount
return normalised_shuttercounts[0]

def get_stack_shuttercounts(self, stack: ImageStack | None) -> np.ndarray | None:
if stack is None or stack.shutter_count_file is None:
return None
try:
shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
except KeyError:
return None
return np.array(shutter_counts)

def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
def get_transmission_error_standard_dev(self,
roi: SensibleROI,
normalise_with_shuttercount: bool = False) -> np.ndarray:
"""
Get the transmission error standard deviation for a given roi
@param: roi_name The roi name
@param: normalised Default is True. If False, the normalization is not applied
@return: a numpy array representing the standard deviation of the transmission
"""
if self._stack is None or self._normalise_stack is None:
Expand All @@ -242,19 +277,30 @@ def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
sample = self._stack.data[:, top:bottom, left:right]
open_beam = self._normalise_stack.data[:, top:bottom, left:right]
safe_divide = np.divide(sample, open_beam, out=np.zeros_like(sample), where=open_beam != 0)
if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
safe_divide = safe_divide / average_shuttercount

return np.std(safe_divide, axis=(1, 2))

def get_transmission_error_propagated(self, roi: SensibleROI) -> np.ndarray:
def get_transmission_error_propagated(self,
roi: SensibleROI,
normalise_with_shuttercount: bool = False) -> np.ndarray:
"""
Get the transmission error using propagation of sqrt(n) error for a given roi
@param: roi_name The roi name
@param: normalised Default is True. If False, the normalization is not applied
@return: a numpy array representing the error of the transmission
"""
if self._stack is None or self._normalise_stack is None:
raise RuntimeError("Sample and open beam must be selected")
sample = self.get_stack_spectrum_summed(self._stack, roi)
open_beam = self.get_stack_spectrum_summed(self._normalise_stack, roi)
error = np.sqrt(sample / open_beam**2 + sample**2 / open_beam**3)

if normalise_with_shuttercount:
average_shuttercount = self.get_shuttercount_normalised_correction_parameter()
error = error / average_shuttercount
return error

def get_image_shape(self) -> tuple[int, int]:
Expand All @@ -271,7 +317,7 @@ def has_stack(self) -> bool:
"""
return self._stack is not None

def save_csv(self, path: Path, normalized: bool) -> None:
def save_csv(self, path: Path, normalise: bool, normalise_with_shuttercount: bool = False) -> None:
"""
Iterates over all ROIs and saves the spectrum for each one to a CSV file.

Expand All @@ -291,8 +337,9 @@ def save_csv(self, path: Path, normalized: bool) -> None:
csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV")

for roi_name in self.get_list_of_roi_names():
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE), "Counts")
if normalized:
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE, normalise_with_shuttercount),
"Counts")
if normalise:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts")
Expand All @@ -312,7 +359,7 @@ def save_single_rits_spectrum(self, path: Path, error_mode: ErrorMode) -> None:
"""
self.save_rits_roi(path, error_mode, self.get_roi(ROI_RITS))

def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) -> None:
def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI, normalise: bool = False) -> None:
"""
Saves the spectrum for one ROI to a RITS file.

Expand All @@ -329,12 +376,12 @@ def save_rits_roi(self, path: Path, error_mode: ErrorMode, roi: SensibleROI) ->
raise ValueError("No Time of Flights for sample. Make sure spectra log has been loaded")

tof *= 1e6 # RITS expects ToF in μs
transmission = self.get_spectrum(roi, SpecType.SAMPLE_NORMED)
transmission = self.get_spectrum(roi, SpecType.SAMPLE_NORMED, normalise)

if error_mode == ErrorMode.STANDARD_DEVIATION:
transmission_error = self.get_transmission_error_standard_dev(roi)
transmission_error = self.get_transmission_error_standard_dev(roi, normalise)
elif error_mode == ErrorMode.PROPAGATED:
transmission_error = self.get_transmission_error_propagated(roi)
transmission_error = self.get_transmission_error_propagated(roi, normalise)
else:
raise ValueError("Invalid error_mode given")

Expand Down Expand Up @@ -367,6 +414,7 @@ def save_rits_images(self,
error_mode: ErrorMode,
bin_size: int,
step: int,
normalise: bool = False,
progress: Progress | None = None) -> None:
"""
Saves multiple Region of Interest (ROI) images to RITS files.
Expand Down Expand Up @@ -404,7 +452,7 @@ def save_rits_images(self,
sub_right = min(sub_left + bin_size, right)
sub_roi = SensibleROI.from_list([sub_left, sub_top, sub_right, sub_bottom])
path = directory / f"rits_image_{x}_{y}.dat"
self.save_rits_roi(path, error_mode, sub_roi)
self.save_rits_roi(path, error_mode, sub_roi, normalise)
progress.update()
if sub_right == right:
break
Expand Down
10 changes: 7 additions & 3 deletions mantidimaging/gui/windows/spectrum_viewer/presenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def handle_export_csv(self) -> None:
if path.suffix != ".csv":
path = path.with_suffix(".csv")

self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED)
self.model.save_csv(path, self.spectrum_mode == SpecType.SAMPLE_NORMED, self.view.shuttercount_norm_enabled())

def handle_rits_export(self) -> None:
"""
Expand All @@ -245,8 +245,12 @@ def handle_rits_export(self) -> None:
if path is None:
LOG.debug("No path selected, aborting export")
return
run_function = partial(self.model.save_rits_images, path, error_mode, self.view.bin_size,
self.view.bin_step)
run_function = partial(self.model.save_rits_images,
path,
error_mode,
self.view.bin_size,
self.view.bin_step,
normalise=self.view.shuttercount_norm_enabled())

start_async_task_view(self.view, run_function, self._async_save_done)

Expand Down
111 changes: 109 additions & 2 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy.testing as npt
from parameterized import parameterized

from mantidimaging.core.io.instrument_log import InstrumentLog
from mantidimaging.core.io.instrument_log import InstrumentLog, ShutterCount, ShutterCountColumn
from mantidimaging.gui.windows.spectrum_viewer import SpectrumViewerWindowPresenter, SpectrumViewerWindowModel
from mantidimaging.gui.windows.spectrum_viewer.model import SpecType, ErrorMode
from mantidimaging.test_helpers.unit_test_helper import generate_images
Expand All @@ -36,17 +36,31 @@ def setUp(self) -> None:
self.presenter = mock.create_autospec(SpectrumViewerWindowPresenter)
self.model = SpectrumViewerWindowModel(self.presenter)

def _set_sample_stack(self, with_tof=False):
def _set_sample_stack(self, with_tof=False, with_shuttercount=False):
spectrum = np.arange(0, 10)
stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
if with_tof:
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="")
mock_inst_log.get_column.return_value = np.arange(0, 10) * 0.1
stack.log_file = mock_inst_log
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(5, 15)
stack._shutter_count_file = mock_shuttercounts
self.model.set_stack(stack)
self.model.set_new_roi("roi")
return stack, spectrum

def _set_normalise_stack(self, with_shuttercount=False):
spectrum = np.arange(10, 20)
normalise_stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
self.model.set_normalise_stack(normalise_stack)
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(10, 20)
normalise_stack._shutter_count_file = mock_shuttercounts
return normalise_stack

def _make_mock_path_stream(self):
mock_stream = CloseCheckStream()
mock_path = mock.create_autospec(Path)
Expand Down Expand Up @@ -521,3 +535,96 @@ def test_save_rits_correct_transmision(self, mock_save_rits_roi):
transmission = call[0][2]
expected_transmission = spectrum * expected_mean
npt.assert_array_equal(expected_transmission, transmission)

def test_get_stack_shuttercounts_returns_none_if_no_stack(self):
self.assertEqual(self.model.get_stack_shuttercounts(stack=None), None)

def test_get_stack_shuttercounts_returns_shutter_count_if_stack(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(stack),
stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))
self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(normalise_stack),
normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))

def test_get_shuttercount_normalised_correction_parameter_returns_one_if_no_stack(self):
with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.return_value = None
self.assertEqual(self.model.get_shuttercount_normalised_correction_parameter(), 1.0)

def test_get_shuttercount_normalised_correction_parameter_with_none_values(self):
self.model.get_stack_shuttercounts = mock.MagicMock(return_value=None)
expected_result = 1.0
result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_shuttercount_normalised_correction_parameter_with_values(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

stack_column = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
normalise_stack_column = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)

with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.side_effect = [stack_column, normalise_stack_column]
expected_result = stack_column[0] / normalise_stack_column[0]

result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_transmission_error_standard_dev(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
left, top, right, bottom = roi
sample = stack.data[:, top:bottom, left:right]
open = normalise_stack.data[:, top:bottom, left:right]

expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts
expected = np.std(expected, axis=(1, 2))

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_standard_dev(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_standard_dev_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_standard_dev("roi")

def test_get_transmission_error_propogated(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
sample = self.model.get_stack_spectrum_summed(stack, roi)
open = self.model.get_stack_spectrum_summed(normalise_stack, roi)

expected = np.sqrt(sample / open**2 + sample**2 / open**3) / average_shutter_counts

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_propagated(roi, normalise_with_shuttercount=True)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_propogated_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_propagated("roi")
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,14 @@ def test_handle_export_csv_none(self, mock_save_csv: mock.Mock):
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_csv")
def test_handle_export_csv(self, path_name: str, mock_save_csv: mock.Mock):
self.view.get_csv_filename = mock.Mock(return_value=Path(path_name))
self.view.shuttercount_norm_enabled.return_value = False

self.presenter.model.set_stack(generate_images())

self.presenter.handle_export_csv()

self.view.get_csv_filename.assert_called_once()
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False)
mock_save_csv.assert_called_once_with(Path("/fake/path.csv"), False, False)

@parameterized.expand(["/fake/path", "/fake/path.dat"])
@mock.patch("mantidimaging.gui.windows.spectrum_viewer.model.SpectrumViewerWindowModel.save_rits_roi")
Expand Down
4 changes: 4 additions & 0 deletions mantidimaging/gui/windows/spectrum_viewer/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SpectrumViewerWindowView(BaseMainWindowView):
normaliseStackSelector: DatasetSelectorWidgetView

normaliseCheckBox: QCheckBox
normalise_ShutterCount_CheckBox: QCheckBox
imageLayout: QVBoxLayout
exportButton: QPushButton
exportTabs: QTabWidget
Expand Down Expand Up @@ -372,6 +373,9 @@ def display_normalise_error(self):
def normalisation_enabled(self):
return self.normaliseCheckBox.isChecked()

def shuttercount_norm_enabled(self) -> bool:
return self.normalise_ShutterCount_CheckBox.isChecked()

def set_new_roi(self) -> None:
"""
Set a new ROI on the image
Expand Down