Skip to content

Commit

Permalink
Test ShutterCount Normalisation and Transmission Methods
Browse files Browse the repository at this point in the history
  • Loading branch information
JackEAllen committed Jun 13, 2024
1 parent 15e9e6a commit 7bbb210
Showing 1 changed file with 120 additions and 2 deletions.
122 changes: 120 additions & 2 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy.testing as npt
from parameterized import parameterized

from mantidimaging.core.io.instrument_log import InstrumentLog
from mantidimaging.core.io.instrument_log import InstrumentLog, ShutterCount, ShutterCountColumn
from mantidimaging.gui.windows.spectrum_viewer import SpectrumViewerWindowPresenter, SpectrumViewerWindowModel
from mantidimaging.gui.windows.spectrum_viewer.model import SpecType, ErrorMode
from mantidimaging.test_helpers.unit_test_helper import generate_images
Expand All @@ -36,17 +36,31 @@ def setUp(self) -> None:
self.presenter = mock.create_autospec(SpectrumViewerWindowPresenter)
self.model = SpectrumViewerWindowModel(self.presenter)

def _set_sample_stack(self, with_tof=False):
def _set_sample_stack(self, with_tof=False, with_shuttercount=False):
spectrum = np.arange(0, 10)
stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
if with_tof:
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="")
mock_inst_log.get_column.return_value = np.arange(0, 10) * 0.1
stack.log_file = mock_inst_log
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(5, 15)
stack._shutter_count_file = mock_shuttercounts
self.model.set_stack(stack)
self.model.set_new_roi("roi")
return stack, spectrum

def _set_normalise_stack(self, with_shuttercount=False):
spectrum = np.arange(10, 20)
normalise_stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
self.model.set_normalise_stack(normalise_stack)
if with_shuttercount:
mock_shuttercounts = mock.create_autospec(ShutterCount, source_file="")
mock_shuttercounts.get_column.return_value = np.arange(10, 20)
normalise_stack._shutter_count_file = mock_shuttercounts
return normalise_stack

def _make_mock_path_stream(self):
mock_stream = CloseCheckStream()
mock_path = mock.create_autospec(Path)
Expand Down Expand Up @@ -521,3 +535,107 @@ def test_save_rits_correct_transmision(self, mock_save_rits_roi):
transmission = call[0][2]
expected_transmission = spectrum * expected_mean
npt.assert_array_equal(expected_transmission, transmission)

def test_get_stack_shuttercounts_returns_none_if_no_stack(self):
self.assertEqual(self.model.get_stack_shuttercounts(stack=None), None)

def test_get_stack_shuttercounts_returns_shutter_count_if_stack(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(stack),
stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))
self.assertTrue(
np.array_equal(self.model.get_stack_shuttercounts(normalise_stack),
normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)))

def test_get_shuttercount_normalised_correction_parameter_returns_one_if_no_stack(self):
with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.return_value = None
self.assertEqual(self.model.get_shuttercount_normalised_correction_parameter(), 1.0)

def test_get_shuttercount_normalised_correction_parameter_with_none_values(self):
self.model.get_stack_shuttercounts = mock.MagicMock(return_value=None)
expected_result = 1.0
result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_shuttercount_normalised_correction_parameter_with_apply_form_false(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
stack_column = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
normalise_stack_column = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)

with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.side_effect = [stack_column, normalise_stack_column]
result = self.model.get_shuttercount_normalised_correction_parameter(apply_normalisation=False)
self.assertEqual(result, 1.0)

def test_get_shuttercount_normalised_correction_parameter_with_values(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)

stack_column = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
normalise_stack_column = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)

with mock.patch.object(self.model, "get_stack_shuttercounts") as mock_get_stack_shuttercounts:
mock_get_stack_shuttercounts.side_effect = [stack_column, normalise_stack_column]
expected_result = stack_column[0] / normalise_stack_column[0]

result = self.model.get_shuttercount_normalised_correction_parameter()
self.assertEqual(result, expected_result)

def test_get_transmission_error_standard_dev(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
left, top, right, bottom = roi
sample = stack.data[:, top:bottom, left:right]
open = normalise_stack.data[:, top:bottom, left:right]

expected = np.divide(sample, open, out=np.zeros_like(sample), where=open != 0) / average_shutter_counts
expected = np.std(expected, axis=(1, 2))

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_standard_dev(roi)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_standard_dev_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_standard_dev("roi")

def test_get_transmission_error_propogated(self):
stack, _ = self._set_sample_stack(with_tof=True, with_shuttercount=True)
normalise_stack = self._set_normalise_stack(with_shuttercount=True)
sample_shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
open_shutter_counts = normalise_stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
average_shutter_counts = sample_shutter_counts[0] / open_shutter_counts[0]

roi = self.model.get_roi("roi")
sample = self.model.get_stack_spectrum_summed(stack, roi)
open = self.model.get_stack_spectrum_summed(normalise_stack, roi)

expected = np.sqrt(sample / open**2 + sample**2 / open**3) / average_shutter_counts

with mock.patch.object(
self.model, "get_shuttercount_normalised_correction_parameter",
return_value=average_shutter_counts) as mock_get_shuttercount_normalised_correction_parameter:
result = self.model.get_transmission_error_propagated(roi)
mock_get_shuttercount_normalised_correction_parameter.assert_called_once()

self.assertEqual(len(expected), len(result))
np.testing.assert_allclose(expected, result)

def test_get_transmission_error_propogated_raises_runtimeerror_if_no_stack(self):
with self.assertRaises(RuntimeError):
self.model.get_transmission_error_propagated("roi")

0 comments on commit 7bbb210

Please sign in to comment.