Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
abda5be
commit 6e0996b
Showing
3 changed files
with
232 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 110 additions & 0 deletions
110
Code/Mantid/scripts/Reflectometry/isis_reflectometry/convert_to_wavelength.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import mantid.simpleapi as msi | ||
import mantid.api | ||
from mantid.kernel import logger | ||
|
||
class ConvertToWavelength(object): | ||
|
||
# List of workspaces to process. | ||
__ws_list = [] | ||
|
||
@classmethod | ||
def get_monitors_mask(cls, ws): | ||
""" | ||
Get the monitor indexes as a mask. | ||
Arguments: | ||
ws -- Workspace to determine the monitor masks for. | ||
""" | ||
monitor_masks = list() | ||
for i in range(ws.getNumberHistograms()): | ||
ismonitor = False | ||
try: | ||
det = ws.getDetector(i) | ||
ismonitor = det.isMonitor() | ||
except RuntimeError: | ||
pass | ||
monitor_masks.append(ismonitor) | ||
return monitor_masks | ||
|
||
@classmethod | ||
def sum_workspaces(cls, workspaces): | ||
""" | ||
Sum together all workspaces. return the result. | ||
Returns: | ||
Result of sum ( a workspace) | ||
""" | ||
return sum(workspaces) | ||
|
||
def __to_workspace_list(self, source_list): | ||
temp=[] | ||
for item in source_list: | ||
if isinstance(item, mantid.api.MatrixWorkspace): | ||
temp.append(item) | ||
elif isinstance(item, str): | ||
if not mtd.doesExist(item): | ||
raise ValueError("Unknown source item %s" % item) | ||
temp.append(mtd[item]) | ||
else: | ||
raise ValueError("Expects a list of workspace or workspace names.") | ||
return temp | ||
|
||
|
||
def __init__(self, source): | ||
""" | ||
Constructor | ||
Arguments: | ||
list -- source workspace or workspaces. | ||
Convert inputs into a list of workspace objects. | ||
""" | ||
if not isinstance(source, list): | ||
source_list = [source] | ||
self.__ws_list = source_list | ||
else: | ||
self.__ws_list = source | ||
|
||
def convert(self, wavelength_min, wavelength_max, monitors_to_correct=None, bg_min=None, bg_max=None): | ||
""" | ||
Run the conversion | ||
Arguments: | ||
bg_min: x min background in wavelength | ||
bg_max: x max background in wavelength | ||
wavelength_min: min wavelength in x for monitor workspace | ||
wavelength_max: max wavelength in x for detector workspace | ||
Returns: | ||
monitor_ws: A workspace of monitors | ||
""" | ||
# Sanity check inputs. | ||
if(wavelength_min >= wavelength_max): | ||
raise ValueError("Wavelength_min must be < wavelength_max min: %s, max: %s" % (wavelength_min, wavelength_max)) | ||
|
||
if any((monitors_to_correct, bg_min, bg_max)) and not all((monitors_to_correct, bg_min, bg_max)): | ||
raise ValueError("Either provide ALL, monitors_to_correct, bg_min, bg_max or none of them") | ||
|
||
if all((bg_min, bg_max)) and bg_min >= bg_max: | ||
raise ValueError("Background min must be < Background max") | ||
|
||
sum = ConvertToWavelength.sum_workspaces(self.__ws_list) | ||
sum_wavelength= msi.ConvertUnits(InputWorkspace=sum, Target="Wavelength") | ||
monitor_masks = ConvertToWavelength.get_monitors_mask(sum) | ||
|
||
# Assuming that the monitors are in a block start-end. | ||
first_detector_index = monitor_masks.index(False) | ||
logger.debug("First detector index %s" % str(first_detector_index)) | ||
|
||
# Crop out the monitor workspace | ||
monitor_ws = msi.CropWorkspace(InputWorkspace=sum_wavelength, StartWorkspaceIndex=0,EndWorkspaceIndex=first_detector_index-1) | ||
# Crop out the detector workspace | ||
detector_ws = msi.CropWorkspace(InputWorkspace=sum_wavelength, XMin=wavelength_min,XMax=wavelength_max,StartWorkspaceIndex=first_detector_index) | ||
# Apply a flat background | ||
if all((monitors_to_correct, bg_min, bg_max)): | ||
monitor_ws = msi.CalculateFlatBackground(InputWorkspace=monitor_ws,WorkspaceIndexList=monitors_to_correct,StartX=bg_min, EndX=bg_max) | ||
|
||
return (monitor_ws, detector_ws) | ||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import unittest | ||
from mantid.simpleapi import * | ||
from isis_reflectometry.convert_to_wavelength import ConvertToWavelength | ||
|
||
class ConvertToWavelengthTest(unittest.TestCase): | ||
""" | ||
Test the convert to wavelength type. | ||
""" | ||
def test_construction_from_single_ws(self): | ||
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength(ws) | ||
self.assertIsNotNone(converter, "Should have been able to make a valid converter from a single workspace") | ||
DeleteWorkspace(ws) | ||
|
||
def test_construction_from_single_ws_name(self): | ||
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
|
||
converter = ConvertToWavelength(ws.getName()) | ||
self.assertIsNotNone(converter, "Should have been able to make a valid converter from a single workspace name") | ||
DeleteWorkspace(ws) | ||
|
||
def test_construction_from_many_workspaces(self): | ||
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
ws2 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength([ws1, ws2]) | ||
self.assertIsNotNone(converter, "Should have been able to make a valid converter from many workspace objects") | ||
DeleteWorkspace(ws1) | ||
DeleteWorkspace(ws2) | ||
|
||
def test_construction_from_many_workspace_names(self): | ||
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
ws2 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength([ws1.getName(), ws2.getName()]) | ||
self.assertIsNotNone(converter, "Should have been able to make a valid converter from many workspace objects") | ||
DeleteWorkspace(ws1) | ||
DeleteWorkspace(ws2) | ||
|
||
def test_get_monitors_mask(self): | ||
ws = Load(Filename='INTER00013460') | ||
masks = ConvertToWavelength.get_monitors_mask(ws) | ||
self.assertTrue(isinstance(masks, list), "Should have returned a list of masks") | ||
self.assertEqual(len(masks), ws.getNumberHistograms()) | ||
self.assertEqual(masks, [True, True, True, False, False], "Monitor masks did not match expected") | ||
DeleteWorkspace(ws) | ||
|
||
def test_sum_workspaces(self): | ||
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
ws2 = CloneWorkspace(ws1) | ||
ws3 = CloneWorkspace(ws1) | ||
sum = ConvertToWavelength.sum_workspaces([ws1, ws2, ws3]) | ||
self.assertEqual(set([3,6,9]), set(sum.readY(0)), "Fail to sum workspaces correctly") | ||
DeleteWorkspace(ws1) | ||
DeleteWorkspace(ws2) | ||
DeleteWorkspace(ws3) | ||
DeleteWorkspace(sum) | ||
|
||
def test_conversion_throws_with_min_wavelength_greater_or_equal_to_max_wavelength(self): | ||
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength(ws) | ||
self.assertRaises(ValueError, converter.convert, 1, 0) | ||
self.assertRaises(ValueError, converter.convert, 1, 1) | ||
DeleteWorkspace(ws) | ||
|
||
def test_conversion_throws_with_some_flat_background_params_but_not_all(self): | ||
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength(ws) | ||
self.assertRaises(ValueError, converter.convert, 0, 1, []) | ||
DeleteWorkspace(ws) | ||
|
||
def test_conversion_throws_with_min_background_greater_than_or_equal_to_max_background(self): | ||
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3]) | ||
converter = ConvertToWavelength(ws) | ||
self.assertRaises(ValueError, converter.convert, 0, 1, [], 0, 1) | ||
DeleteWorkspace(ws) | ||
|
||
|
||
def test_crop_range(self): | ||
original_ws = Load(Filename='INTER00013460') | ||
|
||
# Crop out one spectra | ||
temp_ws = ConvertToWavelength.crop_range(original_ws, (0, original_ws.getNumberHistograms()-2)) | ||
self.assertEqual(original_ws.getNumberHistograms()-1, temp_ws.getNumberHistograms()) | ||
|
||
# Crop out all but 2 spectra from start and end. | ||
temp_ws = ConvertToWavelength.crop_range(original_ws, ( (0, 1), (3, 4) ) ) | ||
self.assertEqual(2, temp_ws.getNumberHistograms()) | ||
|
||
# Crop out all but 2 spectra from start and end. Exactly the same as above, but slightly different tuple syntax | ||
temp_ws = ConvertToWavelength.crop_range(original_ws, ( ( (0, 1), (3, 4) ) )) | ||
self.assertEqual(2, temp_ws.getNumberHistograms()) | ||
|
||
# Test resilience to junk | ||
self.assertRaises(ValueError, ConvertToWavelength.crop_range, original_ws, 'a') | ||
self.assertRaises(ValueError, ConvertToWavelength.crop_range, original_ws, (1,2,3)) | ||
|
||
@classmethod | ||
def cropped_x_range(cls, ws, index): | ||
det_ws_x = ws.readX(index) | ||
mask = ws.readY(index) != 0 # CropWorkspace will only zero out y values! so we need to translate those to an x range | ||
cropped_x = det_ws_x[mask] | ||
return cropped_x[0], cropped_x[-1] | ||
|
||
def test_convert(self): | ||
ws = Load(Filename='INTER00013460') | ||
converter = ConvertToWavelength(ws) | ||
|
||
monitor_ws, detector_ws = converter.convert(wavelength_min=0, wavelength_max=10, monitors_to_correct=[0],bg_min=2, bg_max=8) | ||
|
||
masks = ConvertToWavelength.get_monitors_mask(ws) | ||
|
||
self.assertEqual(masks.count(True), monitor_ws.getNumberHistograms(), "Wrong number of spectra in monitor workspace") | ||
self.assertEqual(masks.count(False), detector_ws.getNumberHistograms(), "Wrong number of spectra in detector workspace") | ||
self.assertEqual("Wavelength", detector_ws.getAxis(0).getUnit().unitID()) | ||
self.assertEqual("Wavelength", monitor_ws.getAxis(0).getUnit().unitID()) | ||
x_min, x_max = ConvertToWavelengthTest.cropped_x_range(detector_ws, 0) | ||
|
||
self.assertGreaterEqual(x_min, 0) | ||
self.assertLessEqual(x_max, 10) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |