Skip to content

Commit

Permalink
refs #8951. Some refactoring and input type checking.
Browse files Browse the repository at this point in the history
Tests for this have also been added.
  • Loading branch information
OwenArnold authored and RussellTaylor committed Feb 25, 2014
1 parent 571171a commit 7c4aa3d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 21 deletions.
Expand Up @@ -44,7 +44,7 @@ def __workspace_from_split_name(self, list_of_names, index):
def __workspaces_from_split_name(self, list_of_names):
workspaces = list()
for name in list_of_names:
workspaces.append(mtd[name])
workspaces.append(mtd[name.strip()])
return workspaces

'''
Expand All @@ -65,6 +65,15 @@ def __do_stitch_workspace(self, lhs_ws, rhs_ws, start_overlap, end_overlap, para
out_ws, scale_factor = Stitch1D(LHSWorkspace=lhs_ws, RHSWorkspace=rhs_ws, StartOverlap=start_overlap, EndOverlap=end_overlap,
Params=params, ScaleRHSWorkspace=scale_rhs_ws, UseManualScaleFactor=use_manual_scale_factor, ManualScaleFactor=manual_scale_factor, OutputWorkspace=out_name)
return (out_ws, scale_factor)

def __check_workspaces_are_common(self, input_workspace_names):
workspaces = self.__workspaces_from_split_name(input_workspace_names)
exemplar = workspaces[0]
for i in range(1, len(workspaces)):
test_ws = workspaces[i]
if type(exemplar) != type(test_ws):
raise RuntimeError("Input Workspaces must all be of the same type.")


def PyExec(self):

Expand All @@ -88,56 +97,52 @@ def PyExec(self):
raise ValueError("StartOverlaps and EndOverlaps are different lengths")
if not (len(startOverlaps) == (numberOfWorkspaces- 1)):
raise ValueError("Wrong number of StartOverlaps, should be %i not %i" % (numberOfWorkspaces - 1, startOverlaps))
self.__check_workspaces_are_common(inputWorkspaces)

scaleFactor = None
comma_separator = ","
no_separator = str()

# Iterate forward through the workspaces
if scaleRHSWorkspace:
lhsWS = self.__workspace_from_split_name(inputWorkspaces, 0)
print "SCALER RHS"

if isinstance(lhsWS, WorkspaceGroup):
print "WORKSPACE GROUP"

workspace_groups = self.__workspaces_from_split_name(inputWorkspaces)

group_separator = ""
group_workspaces = ""
out_group_separator = no_separator
out_group_workspaces = str()
# TODO. VERIFY THAT ALL INPUT WORKSPACES ARE GROUP WORKSPACES
print "TOTAL GROUP SIZE", lhsWS.size()

for i in range(lhsWS.size()):
print "i", i


to_process = ""
out_name = ""
separator = ""
print "NUMBER OF WORKSPACES", numberOfWorkspaces
to_process = str()
out_name = str()
separator = no_separator

for j in range(0, numberOfWorkspaces, 1):
print "j", j

to_process += separator + workspace_groups[j][i].name()
out_name += workspace_groups[j][i].name()
separator=","
separator=comma_separator
out_name += ("_" + str(i+1))


startOverlaps = self.getProperty("StartOverlaps").value
endOverlaps = self.getProperty("EndOverlaps").value
stitched, scaleFactor = Stitch1DMany(InputWorkspaces=to_process, OutputWorkspace=out_name, StartOverlaps=startOverlaps, EndOverlaps=endOverlaps,
Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor,
ManualScaleFactor=manualScaleFactor)

#lhsWS, scaleFactor = self.__do_stitch_workspace(lhsWS, rhsWS, startOverlaps[j-1], endOverlaps[j-1], params, scaleRHSWorkspace, useManualScaleFactor, manualScaleFactor)

group_workspaces += group_separator + out_name
group_separator = ","
out_group_workspaces += out_group_separator + out_name
out_group_separator = comma_separator

out_group = GroupWorkspaces(InputWorkspaces=group_workspaces)
out_group = GroupWorkspaces(InputWorkspaces=out_group_workspaces)
self.setProperty('OutputWorkspace', out_group)
else:
# TODO. VERIFY THAT ALL INPUT WORKSPACES ARE NOT GROUP WORKSPACES
for i in range(1, numberOfWorkspaces, 1):
rhsWS = self.__workspace_from_split_name(inputWorkspaces, i)
#lhsWS, scaleFactor = Stitch1D(LHSWorkspace=lhsWS, RHSWorkspace=rhsWS, StartOverlap=startOverlaps[i-1], EndOverlap=endOverlaps[i-1], Params=params, ScaleRHSWorkspace=scaleRHSWorkspace, UseManualScaleFactor=useManualScaleFactor, ManualScaleFactor=manualScaleFactor)
lhsWS, scaleFactor = self.__do_stitch_workspace(lhsWS, rhsWS, startOverlaps[i-1], endOverlaps[i-1], params, scaleRHSWorkspace, useManualScaleFactor, manualScaleFactor)
self.setProperty('OutputWorkspace', lhsWS)
DeleteWorkspace(lhsWS)
Expand Down
Expand Up @@ -66,6 +66,16 @@ def test_stich_throws_if_no_params(self):
except RuntimeError:
pass

def test_workspace_types_differ_throws(self):
tbl = CreateEmptyTableWorkspace()
input_workspaces = "%s, %s" % (self.a.name(), tbl.name()) # One table workspace, one matrix workspace
try:
stitchedViaStitchMany, scaleFactorMany = Stitch1DMany(InputWorkspaces=input_workspaces, Params=0.2)
self.fail("Input workspace type mis-match. Should have thrown.")
except RuntimeError:
pass
finally:
DeleteWorkspace(tbl)
#Cross-check that the result of using Stitch1DMany with two workspaces is the same as using Stitch1D.

def test_stitches_two(self):
Expand Down Expand Up @@ -123,6 +133,7 @@ def test_process_group_workspaces(self):
self.assertEqual(stitched.size(), 3, "Output should contain 3 workspaces")
self.assertEqual(stitched.name(), "stitched", "Output not named correctly")




if __name__ == '__main__':
Expand Down

0 comments on commit 7c4aa3d

Please sign in to comment.