Skip to content

Commit

Permalink
Specview export tof wavelength (#2218)
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed Jun 10, 2024
2 parents f0def70 + fa39ff4 commit fb03173
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#2216: Add Wavelength, ToF and Energy unit conversions to Spectrum Viewer CSV Output
20 changes: 12 additions & 8 deletions mantidimaging/core/io/csv_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ class CSVOutput:
def __init__(self) -> None:
self.columns: dict[str, np.ndarray] = {}
self.num_rows: int | None = None
self.units: dict[str, str] = {}

def add_column(self, name: str, values: np.ndarray) -> None:
def add_column(self, name: str, values: np.ndarray, units: str) -> None:
as_column = values.reshape((-1, 1))
if self.num_rows is not None:
if as_column.size != self.num_rows:
raise ValueError("Column sizes must match")
else:
self.num_rows = as_column.size
if self.num_rows is not None and as_column.size != self.num_rows:
raise ValueError('Column sizes must match')

if units:
self.units[name] = units

self.num_rows = as_column.size if self.num_rows is None else self.num_rows
self.columns[name] = as_column

def write(self, outstream: IO[str]) -> None:
header = ",".join(self.columns.keys())
header = ','.join(self.columns.keys())
units = ','.join(self.units.values())
outstream.write('# ' + header + '\n' + '# ' + units + '\n')
data = np.hstack(list(self.columns.values()))
np.savetxt(outstream, data, header=header, fmt="%s", delimiter=",")
np.savetxt(outstream, data, fmt='%s', delimiter=',')
18 changes: 9 additions & 9 deletions mantidimaging/core/io/test/test_csv_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,37 @@ def test_add_column(self):
col_name = "new_col"
values = np.arange(10)

self.csv_output.add_column(col_name, values)
self.csv_output.add_column(col_name, values, "")

self.assertIn(col_name, self.csv_output.columns)
np.testing.assert_array_equal(self.csv_output.columns[col_name].flatten(), values.flatten())

def test_column_order(self):
col_names = "red orange yellow green blue indigo violet".split()
for col_name in col_names:
self.csv_output.add_column(col_name, np.arange(10))
self.csv_output.add_column(col_name, np.arange(10), "")

self.assertEqual(col_names, list(self.csv_output.columns.keys()))

def test_write_single(self):
col_name = "new_col"
values = np.arange(5, dtype=np.float32)
expected_out = "# new_col\n0.0\n1.0\n2.0\n3.0\n4.0\n"
self.csv_output.add_column(col_name, values)
expected_out = "# new_col\n# \n0.0\n1.0\n2.0\n3.0\n4.0\n"
self.csv_output.add_column(col_name, values, "")

self.csv_output.write(self.stream)
self.assertEqual(self.stream.getvalue(), expected_out)

def test_write_2_cols(self):
expected_out = "# col_1,col_2\n0.0,5.0\n1.0,6.0\n2.0,7.0\n3.0,8.0\n4.0,9.0\n"
self.csv_output.add_column("col_1", np.arange(5, dtype=np.float32))
self.csv_output.add_column("col_2", np.arange(5, 10, dtype=np.float32))
expected_out = "# col_1,col_2\n# \n0.0,5.0\n1.0,6.0\n2.0,7.0\n3.0,8.0\n4.0,9.0\n"
self.csv_output.add_column("col_1", np.arange(5, dtype=np.float32), "")
self.csv_output.add_column("col_2", np.arange(5, 10, dtype=np.float32), "")

self.csv_output.write(self.stream)
self.assertEqual(self.stream.getvalue(), expected_out)

def test_add_column_wrong_size(self):
self.csv_output.add_column("col_1", np.arange(5, dtype=np.float32))
self.csv_output.add_column("col_1", np.arange(5, dtype=np.float32), "")

with self.assertRaises(ValueError):
self.csv_output.add_column("col_2", np.arange(6, dtype=np.float32))
self.csv_output.add_column("col_2", np.arange(6, dtype=np.float32), "")
14 changes: 10 additions & 4 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,21 @@ def save_csv(self, path: Path, normalized: bool) -> None:
raise ValueError("No stack selected")

csv_output = CSVOutput()
csv_output.add_column("tof_index", np.arange(self._stack.data.shape[0]))
csv_output.add_column("ToF_index", np.arange(self._stack.data.shape[0]), "Index")
self.tof_data = self.get_stack_time_of_flight()
if self.tof_data is not None:
self.units.set_data_to_convert(self.tof_data)
csv_output.add_column("Wavelength", self.units.tof_seconds_to_wavelength_in_angstroms(), "Angstrom")
csv_output.add_column("ToF", self.units.tof_seconds_to_us(), "Microseconds")
csv_output.add_column("Energy", self.units.tof_seconds_to_energy(), "MeV")

for roi_name in self.get_list_of_roi_names():
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE))
csv_output.add_column(roi_name, self.get_spectrum(roi_name, SpecType.SAMPLE), "Counts")
if normalized:
if self._normalise_stack is None:
raise RuntimeError("No normalisation stack selected")
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN))
csv_output.add_column(roi_name + "_norm", self.get_spectrum(roi_name, SpecType.SAMPLE_NORMED))
csv_output.add_column(roi_name + "_open", self.get_spectrum(roi_name, SpecType.OPEN), "Counts")
csv_output.add_column(roi_name + "_norm", self.get_spectrum(roi_name, SpecType.SAMPLE_NORMED), "Counts")

with path.open("w") as outfile:
csv_output.write(outfile)
Expand Down
38 changes: 30 additions & 8 deletions mantidimaging/gui/windows/spectrum_viewer/test/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def setUp(self) -> None:
def _set_sample_stack(self, with_tof=False):
spectrum = np.arange(0, 10)
stack = ImageStack(np.ones([10, 11, 12]) * spectrum.reshape((10, 1, 1)))
self.model.set_stack(stack)
self.model.set_new_roi("roi")
if with_tof:
mock_inst_log = mock.create_autospec(InstrumentLog, source_file="")
mock_inst_log.get_column.return_value = np.arange(0, 10) * 0.1
stack.log_file = mock_inst_log
self.model.set_stack(stack)
self.model.set_new_roi("roi")
return stack, spectrum

def _make_mock_path_stream(self):
Expand Down Expand Up @@ -182,9 +182,10 @@ def test_save_csv(self):
self.model.save_csv(mock_path, False)

mock_path.open.assert_called_once_with("w")
self.assertIn("# tof_index,all,roi", mock_stream.captured[0])
self.assertIn("0.0,0.0,0.0", mock_stream.captured[1])
self.assertIn("1.0,2.0,2.0", mock_stream.captured[2])
self.assertIn("# ToF_index,all,roi", mock_stream.captured[0])
self.assertIn("# Index,Counts,Counts", mock_stream.captured[1])
self.assertIn("0.0,0.0,0.0", mock_stream.captured[2])
self.assertIn("1.0,2.0,2.0", mock_stream.captured[3])
self.assertTrue(mock_stream.is_closed)

def test_save_rits_dat(self):
Expand Down Expand Up @@ -319,9 +320,30 @@ def test_save_csv_norm(self):
self.model.save_csv(mock_path, True)

mock_path.open.assert_called_once_with("w")
self.assertIn("# tof_index,all,all_open,all_norm,roi,roi_open,roi_norm", mock_stream.captured[0])
self.assertIn("0.0,0.0,2.0,0.0,0.0,2.0,0.0", mock_stream.captured[1])
self.assertIn("1.0,1.0,2.0,0.5,1.0,2.0,0.5", mock_stream.captured[2])
self.assertIn("# ToF_index,all,all_open,all_norm,roi,roi_open,roi_norm", mock_stream.captured[0])
self.assertIn("# Index,Counts,Counts,Counts,Counts,Counts,Counts", mock_stream.captured[1])
self.assertIn("0.0,0.0,2.0,0.0,0.0,2.0,0.0", mock_stream.captured[2])
self.assertIn("1.0,1.0,2.0,0.5,1.0,2.0,0.5", mock_stream.captured[3])
self.assertTrue(mock_stream.is_closed)

def test_save_csv_norm_with_tof_loaded(self):
stack, _ = self._set_sample_stack(with_tof=True)
norm = ImageStack(np.full([10, 11, 12], 2))
stack.data[:, :, :5] *= 2
self.model.set_normalise_stack(norm)

mock_stream, mock_path = self._make_mock_path_stream()
with mock.patch.object(self.model, "save_roi_coords"):
self.model.save_csv(mock_path, True)

mock_path.open.assert_called_once_with("w")
self.assertIn("# ToF_index,Wavelength,ToF,Energy,all,all_open,all_norm,roi,roi_open,roi_norm",
mock_stream.captured[0])
self.assertIn("# Index,Angstrom,Microseconds,MeV,Counts,Counts,Counts", mock_stream.captured[1])
self.assertIn("0.0,0.0,0.0,inf,0.0,2.0,0.0,0.0,2.0,0.0", mock_stream.captured[2])
self.assertIn(
"1.0,7.064346392065392,100000.0,2.9271405738026552,1.4166666666666667,2.0,0.7083333333333334,1.4166666666666667,2.0,0.7083333333333334",
mock_stream.captured[3])
self.assertTrue(mock_stream.is_closed)

def test_WHEN_roi_name_generator_called_THEN_correct_names_returned_visible_to_model(self):
Expand Down

0 comments on commit fb03173

Please sign in to comment.