Skip to content

Commit

Permalink
Apply Shuttercount Normalisation Correction when shutterCounts Available
Browse files Browse the repository at this point in the history
  • Loading branch information
JackEAllen committed Jun 3, 2024
1 parent 4d2ef3e commit 086db26
Showing 1 changed file with 40 additions and 4 deletions.
44 changes: 40 additions & 4 deletions mantidimaging/gui/windows/spectrum_viewer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mantidimaging.core.data import ImageStack
from mantidimaging.core.io.csv_output import CSVOutput
from mantidimaging.core.io import saver
from mantidimaging.core.io.instrument_log import LogColumn
from mantidimaging.core.io.instrument_log import LogColumn, ShutterCountColumn
from mantidimaging.core.utility.sensible_roi import SensibleROI
from mantidimaging.core.utility.progress_reporting import Progress
from mantidimaging.core.utility.unit_conversion import UnitConversion
Expand Down Expand Up @@ -228,7 +228,27 @@ def get_spectrum(self, roi: str | SensibleROI, mode: SpecType) -> np.ndarray:
return np.array([])
roi_spectrum = self.get_stack_spectrum(self._stack, roi)
roi_norm_spectrum = self.get_stack_spectrum(self._normalise_stack, roi)
return np.divide(roi_spectrum, roi_norm_spectrum, out=np.zeros_like(roi_spectrum), where=roi_norm_spectrum != 0)
sample_shuttercount = self.get_stack_shuttercounts(self._stack)
flat_before_shuttercount = self.get_stack_shuttercounts(self._normalise_stack)
average_shuttercount = self.get_normalized_spectrum(sample_shuttercount, flat_before_shuttercount)
normalized_spectrum = np.divide(roi_spectrum,
roi_norm_spectrum,
out=np.zeros_like(roi_spectrum),
where=roi_norm_spectrum != 0)
final_spectrum = normalized_spectrum / average_shuttercount
return final_spectrum

def get_normalized_spectrum(self, sample_shuttercount: np.ndarray | None,
open_shuttercount: np.ndarray | None) -> float:
"""
Normalizes the first value of shutter values be dividing the sample by the flat field counts
loading whole stack to future proof normalising against all available shuttercount values
"""
if sample_shuttercount is None or open_shuttercount is None:
return 1
normalised_init_shuttercount = sample_shuttercount[0] / open_shuttercount[0]
print(f"Normalising using shutter count: {normalised_init_shuttercount}")
return normalised_init_shuttercount

def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
"""
Expand All @@ -241,7 +261,11 @@ def get_transmission_error_standard_dev(self, roi: SensibleROI) -> np.ndarray:
left, top, right, bottom = roi
sample = self._stack.data[:, top:bottom, left:right]
open_beam = self._normalise_stack.data[:, top:bottom, left:right]
safe_divide = np.divide(sample, open_beam, out=np.zeros_like(sample), where=open_beam != 0)
sample_shuttercount = self.get_stack_shuttercounts(self._stack)
flat_before_shuttercount = self.get_stack_shuttercounts(self._normalise_stack)
average_shuttercount = self.get_normalized_spectrum(sample_shuttercount, flat_before_shuttercount)
safe_divide = np.divide(sample, open_beam, out=np.zeros_like(sample), where=open_beam
!= 0) / average_shuttercount
return np.std(safe_divide, axis=(1, 2))

def get_transmission_error_propagated(self, roi: SensibleROI) -> np.ndarray:
Expand All @@ -254,7 +278,10 @@ def get_transmission_error_propagated(self, roi: SensibleROI) -> np.ndarray:
raise RuntimeError("Sample and open beam must be selected")
sample = self.get_stack_spectrum_summed(self._stack, roi)
open_beam = self.get_stack_spectrum_summed(self._normalise_stack, roi)
error = np.sqrt(sample / open_beam**2 + sample**2 / open_beam**3)
sample_shuttercount = self.get_stack_shuttercounts(self._stack)
flat_before_shuttercount = self.get_stack_shuttercounts(self._normalise_stack)
average_shuttercount = self.get_normalized_spectrum(sample_shuttercount, flat_before_shuttercount)
error = np.sqrt(sample / open_beam**2 + sample**2 / open_beam**3) / average_shuttercount
return error

def get_image_shape(self) -> tuple[int, int]:
Expand Down Expand Up @@ -414,6 +441,15 @@ def get_stack_time_of_flight(self) -> np.ndarray | None:
return None
return np.array(time_of_flights)

def get_stack_shuttercounts(self, stack: ImageStack | None) -> np.ndarray | None:
if stack is None or stack.shutter_count_file is None:
return None
try:
shutter_counts = stack.shutter_count_file.get_column(ShutterCountColumn.SHUTTER_COUNT)
except KeyError:
return None
return np.array(shutter_counts)

def get_roi_coords_filename(self, path: Path) -> Path:
"""
Get the path to save the ROI coordinates to.
Expand Down

0 comments on commit 086db26

Please sign in to comment.