Skip to content

Commit

Permalink
refs #8372. Module with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OwenArnold committed Nov 21, 2013
1 parent abda5be commit 6e0996b
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 0 deletions.
1 change: 1 addition & 0 deletions Code/Mantid/scripts/CMakeLists.txt
@@ -1,5 +1,6 @@

set ( TEST_PY_FILES
test/ConvertToWavelengthTest.py
test/ReducerTest.py
test/SettingsTest.py
test/DgreduceTest.py
Expand Down
@@ -0,0 +1,110 @@
import mantid.simpleapi as msi
import mantid.api
from mantid.kernel import logger

class ConvertToWavelength(object):

# List of workspaces to process.
__ws_list = []

@classmethod
def get_monitors_mask(cls, ws):
"""
Get the monitor indexes as a mask.
Arguments:
ws -- Workspace to determine the monitor masks for.
"""
monitor_masks = list()
for i in range(ws.getNumberHistograms()):
ismonitor = False
try:
det = ws.getDetector(i)
ismonitor = det.isMonitor()
except RuntimeError:
pass
monitor_masks.append(ismonitor)
return monitor_masks

@classmethod
def sum_workspaces(cls, workspaces):
"""
Sum together all workspaces. return the result.
Returns:
Result of sum ( a workspace)
"""
return sum(workspaces)

def __to_workspace_list(self, source_list):
temp=[]
for item in source_list:
if isinstance(item, mantid.api.MatrixWorkspace):
temp.append(item)
elif isinstance(item, str):
if not mtd.doesExist(item):
raise ValueError("Unknown source item %s" % item)
temp.append(mtd[item])
else:
raise ValueError("Expects a list of workspace or workspace names.")
return temp


def __init__(self, source):
"""
Constructor
Arguments:
list -- source workspace or workspaces.
Convert inputs into a list of workspace objects.
"""
if not isinstance(source, list):
source_list = [source]
self.__ws_list = source_list
else:
self.__ws_list = source

def convert(self, wavelength_min, wavelength_max, monitors_to_correct=None, bg_min=None, bg_max=None):
"""
Run the conversion
Arguments:
bg_min: x min background in wavelength
bg_max: x max background in wavelength
wavelength_min: min wavelength in x for monitor workspace
wavelength_max: max wavelength in x for detector workspace
Returns:
monitor_ws: A workspace of monitors
"""
# Sanity check inputs.
if(wavelength_min >= wavelength_max):
raise ValueError("Wavelength_min must be < wavelength_max min: %s, max: %s" % (wavelength_min, wavelength_max))

if any((monitors_to_correct, bg_min, bg_max)) and not all((monitors_to_correct, bg_min, bg_max)):
raise ValueError("Either provide ALL, monitors_to_correct, bg_min, bg_max or none of them")

if all((bg_min, bg_max)) and bg_min >= bg_max:
raise ValueError("Background min must be < Background max")

sum = ConvertToWavelength.sum_workspaces(self.__ws_list)
sum_wavelength= msi.ConvertUnits(InputWorkspace=sum, Target="Wavelength")
monitor_masks = ConvertToWavelength.get_monitors_mask(sum)

# Assuming that the monitors are in a block start-end.
first_detector_index = monitor_masks.index(False)
logger.debug("First detector index %s" % str(first_detector_index))

# Crop out the monitor workspace
monitor_ws = msi.CropWorkspace(InputWorkspace=sum_wavelength, StartWorkspaceIndex=0,EndWorkspaceIndex=first_detector_index-1)
# Crop out the detector workspace
detector_ws = msi.CropWorkspace(InputWorkspace=sum_wavelength, XMin=wavelength_min,XMax=wavelength_max,StartWorkspaceIndex=first_detector_index)
# Apply a flat background
if all((monitors_to_correct, bg_min, bg_max)):
monitor_ws = msi.CalculateFlatBackground(InputWorkspace=monitor_ws,WorkspaceIndexList=monitors_to_correct,StartX=bg_min, EndX=bg_max)

return (monitor_ws, detector_ws)





121 changes: 121 additions & 0 deletions Code/Mantid/scripts/test/ConvertToWavelengthTest.py
@@ -0,0 +1,121 @@
import unittest
from mantid.simpleapi import *
from isis_reflectometry.convert_to_wavelength import ConvertToWavelength

class ConvertToWavelengthTest(unittest.TestCase):
"""
Test the convert to wavelength type.
"""
def test_construction_from_single_ws(self):
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength(ws)
self.assertIsNotNone(converter, "Should have been able to make a valid converter from a single workspace")
DeleteWorkspace(ws)

def test_construction_from_single_ws_name(self):
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])

converter = ConvertToWavelength(ws.getName())
self.assertIsNotNone(converter, "Should have been able to make a valid converter from a single workspace name")
DeleteWorkspace(ws)

def test_construction_from_many_workspaces(self):
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
ws2 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength([ws1, ws2])
self.assertIsNotNone(converter, "Should have been able to make a valid converter from many workspace objects")
DeleteWorkspace(ws1)
DeleteWorkspace(ws2)

def test_construction_from_many_workspace_names(self):
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
ws2 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength([ws1.getName(), ws2.getName()])
self.assertIsNotNone(converter, "Should have been able to make a valid converter from many workspace objects")
DeleteWorkspace(ws1)
DeleteWorkspace(ws2)

def test_get_monitors_mask(self):
ws = Load(Filename='INTER00013460')
masks = ConvertToWavelength.get_monitors_mask(ws)
self.assertTrue(isinstance(masks, list), "Should have returned a list of masks")
self.assertEqual(len(masks), ws.getNumberHistograms())
self.assertEqual(masks, [True, True, True, False, False], "Monitor masks did not match expected")
DeleteWorkspace(ws)

def test_sum_workspaces(self):
ws1 = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
ws2 = CloneWorkspace(ws1)
ws3 = CloneWorkspace(ws1)
sum = ConvertToWavelength.sum_workspaces([ws1, ws2, ws3])
self.assertEqual(set([3,6,9]), set(sum.readY(0)), "Fail to sum workspaces correctly")
DeleteWorkspace(ws1)
DeleteWorkspace(ws2)
DeleteWorkspace(ws3)
DeleteWorkspace(sum)

def test_conversion_throws_with_min_wavelength_greater_or_equal_to_max_wavelength(self):
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength(ws)
self.assertRaises(ValueError, converter.convert, 1, 0)
self.assertRaises(ValueError, converter.convert, 1, 1)
DeleteWorkspace(ws)

def test_conversion_throws_with_some_flat_background_params_but_not_all(self):
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength(ws)
self.assertRaises(ValueError, converter.convert, 0, 1, [])
DeleteWorkspace(ws)

def test_conversion_throws_with_min_background_greater_than_or_equal_to_max_background(self):
ws = CreateWorkspace(DataY=[1,2,3], DataX=[1,2,3])
converter = ConvertToWavelength(ws)
self.assertRaises(ValueError, converter.convert, 0, 1, [], 0, 1)
DeleteWorkspace(ws)


def test_crop_range(self):
original_ws = Load(Filename='INTER00013460')

# Crop out one spectra
temp_ws = ConvertToWavelength.crop_range(original_ws, (0, original_ws.getNumberHistograms()-2))
self.assertEqual(original_ws.getNumberHistograms()-1, temp_ws.getNumberHistograms())

# Crop out all but 2 spectra from start and end.
temp_ws = ConvertToWavelength.crop_range(original_ws, ( (0, 1), (3, 4) ) )
self.assertEqual(2, temp_ws.getNumberHistograms())

# Crop out all but 2 spectra from start and end. Exactly the same as above, but slightly different tuple syntax
temp_ws = ConvertToWavelength.crop_range(original_ws, ( ( (0, 1), (3, 4) ) ))
self.assertEqual(2, temp_ws.getNumberHistograms())

# Test resilience to junk
self.assertRaises(ValueError, ConvertToWavelength.crop_range, original_ws, 'a')
self.assertRaises(ValueError, ConvertToWavelength.crop_range, original_ws, (1,2,3))

@classmethod
def cropped_x_range(cls, ws, index):
det_ws_x = ws.readX(index)
mask = ws.readY(index) != 0 # CropWorkspace will only zero out y values! so we need to translate those to an x range
cropped_x = det_ws_x[mask]
return cropped_x[0], cropped_x[-1]

def test_convert(self):
ws = Load(Filename='INTER00013460')
converter = ConvertToWavelength(ws)

monitor_ws, detector_ws = converter.convert(wavelength_min=0, wavelength_max=10, monitors_to_correct=[0],bg_min=2, bg_max=8)

masks = ConvertToWavelength.get_monitors_mask(ws)

self.assertEqual(masks.count(True), monitor_ws.getNumberHistograms(), "Wrong number of spectra in monitor workspace")
self.assertEqual(masks.count(False), detector_ws.getNumberHistograms(), "Wrong number of spectra in detector workspace")
self.assertEqual("Wavelength", detector_ws.getAxis(0).getUnit().unitID())
self.assertEqual("Wavelength", monitor_ws.getAxis(0).getUnit().unitID())
x_min, x_max = ConvertToWavelengthTest.cropped_x_range(detector_ws, 0)

self.assertGreaterEqual(x_min, 0)
self.assertLessEqual(x_max, 10)

if __name__ == '__main__':
unittest.main()

0 comments on commit 6e0996b

Please sign in to comment.