Skip to content

Commit

Permalink
Set a numpy array for a spectrum in one go. Refs #6156
Browse files Browse the repository at this point in the history
  • Loading branch information
martyngigg committed Nov 16, 2012
1 parent 679b3a7 commit e1ba46f
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <boost/python/overloads.hpp>
#include <boost/python/copy_const_reference.hpp>
#include <boost/python/implicit.hpp>
#include <boost/python/numeric.hpp>

using namespace Mantid::API;
using Mantid::Geometry::IDetector_sptr;
Expand All @@ -25,6 +26,7 @@ namespace
{
/// Typedef for data access, i.e. dataX,Y,E members
typedef Mantid::MantidVec&(MatrixWorkspace::*data_modifier)(const std::size_t);

/// return_value_policy for read-only numpy array
typedef return_value_policy<Policies::VectorRefToNumpy<Converters::WrapReadOnly> > return_readonly_numpy;
/// return_value_policy for read-write numpy array
Expand All @@ -34,6 +36,72 @@ namespace
// Overloads for binIndexOf function which has 1 optional argument
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(MatrixWorkspace_binIndexOfOverloads,
MatrixWorkspace::binIndexOf, 1, 2)

/**
* Set the values from an python array-style object into the given spectrum in the workspace
* @param self :: A reference to the calling object
* @param accessor :: A member-function pointer to the data{X,Y,E} member that will extract the writable values.
* @param wsIndex :: The workspace index for the spectrum to set
* @param values :: A numpy array. The length must match the size of the
*/
void setSpectrumFromPyObject(MatrixWorkspace & self, data_modifier accessor,
const size_t wsIndex, numeric::array values)
{
boost::python::tuple shape(values.attr("shape"));
if( boost::python::len(shape) != 1 )
{
throw std::invalid_argument("Invalid shape for setting 1D spectrum array, array is "
+ boost::lexical_cast<std::string>(boost::python::len(shape)) + "D");
}
const size_t pyArrayLength = boost::python::extract<size_t>(shape[0]);
Mantid::MantidVec & wsArrayRef = (self.*accessor)(wsIndex);
const size_t wsArrayLength = wsArrayRef.size();

if(pyArrayLength != wsArrayLength)
{
throw std::invalid_argument("Length mismatch between workspace array & python array. ws="
+ boost::lexical_cast<std::string>(wsArrayLength) + ", python=" + boost::lexical_cast<std::string>(pyArrayLength));
}
for(size_t i = 0; i < wsArrayLength; ++i)
{
wsArrayRef[i] = extract<double>(values[i]);
}
}


/**
* Set the X values from an python array-style object
* @param self :: A reference to the calling object
* @param wsIndex :: The workspace index for the spectrum to set
* @param values :: A numpy array. The length must match the size of the
*/
void setXFromPyObject(MatrixWorkspace & self, const size_t wsIndex, numeric::array values)
{
setSpectrumFromPyObject(self, &MatrixWorkspace::dataX, wsIndex, values);
}

/**
* Set the Y values from an python array-style object
* @param self :: A reference to the calling object
* @param wsIndex :: The workspace index for the spectrum to set
* @param values :: A numpy array. The length must match the size of the
*/
void setYFromPyObject(MatrixWorkspace & self, const size_t wsIndex, numeric::array values)
{
setSpectrumFromPyObject(self, &MatrixWorkspace::dataY, wsIndex, values);
}

/**
* Set the E values from an python array-style object
* @param self :: A reference to the calling object
* @param wsIndex :: The workspace index for the spectrum to set
* @param values :: A numpy array. The length must match the size of the
*/
void setEFromPyObject(MatrixWorkspace & self, const size_t wsIndex, numeric::array values)
{
setSpectrumFromPyObject(self, &MatrixWorkspace::dataE, wsIndex, values);
}

}

void export_MatrixWorkspace()
Expand Down Expand Up @@ -61,13 +129,15 @@ void export_MatrixWorkspace()
return_value_policy<copy_const_reference>(), "Returns the status of the distribution flag")
.def("YUnit", &MatrixWorkspace::YUnit, "Returns the current Y unit for the data (Y axis) in the workspace")
.def("YUnitLabel", &MatrixWorkspace::YUnitLabel, "Returns the caption for the Y axis")

//--------------------------------------- Setters -------------------------------------------------------------------------------
.def("setYUnitLabel", &MatrixWorkspace::setYUnitLabel, "Sets a new caption for the data (Y axis) in the workspace")
.def("setYUnit", &MatrixWorkspace::setYUnit, "Sets a new unit for the data (Y axis) in the workspace")
.def("setDistribution", (bool& (MatrixWorkspace::*)(const bool))&MatrixWorkspace::isDistribution,
return_value_policy<return_by_value>(), "Set distribution flag. If True the workspace has been divided by the bin-width.")
.def("replaceAxis", &MatrixWorkspace::replaceAxis)
//--------------------------------------- Data access ---------------------------------------------------------------------------

//--------------------------------------- Read spectrum data ---------------------------------------------------------------------------
.def("readX", &MatrixWorkspace::readX, return_readonly_numpy(),
"Creates a read-only numpy wrapper around the original X data at the given index")
.def("readY", &MatrixWorkspace::readY, return_readonly_numpy(),
Expand All @@ -76,6 +146,8 @@ void export_MatrixWorkspace()
"Creates a read-only numpy wrapper around the original E data at the given index")
.def("readDx", &MatrixWorkspace::readDx, return_readonly_numpy(),
"Creates a read-only numpy wrapper around the original Dx data at the given index")

//--------------------------------------- Write spectrum data ---------------------------------------------------------------------------
.def("dataX", (data_modifier)&MatrixWorkspace::dataX, return_readwrite_numpy(),
"Creates a writable numpy wrapper around the original X data at the given index")
.def("dataY", (data_modifier)&MatrixWorkspace::dataY, return_readwrite_numpy(),
Expand All @@ -84,6 +156,11 @@ void export_MatrixWorkspace()
"Creates a writable numpy wrapper around the original E data at the given index")
.def("dataDx", (data_modifier)&MatrixWorkspace::dataDx, return_readwrite_numpy(),
"Creates a writable numpy wrapper around the original Dx data at the given index")
.def("setX", &setXFromPyObject, "Set X values from a python list or numpy array. It performs a simple copy into the array.")
.def("setY", &setYFromPyObject, "Set Y values from a python list or numpy array. It performs a simple copy into the array.")
.def("setE", &setEFromPyObject, "Set E values from a python list or numpy array. It performs a simple copy into the array.")

// --------------------------------------- Extract data ---------------------------------------------------------------------------------
.def("extractX", Mantid::PythonInterface::cloneX,
"Extracts (copies) the X data from the workspace into a 2D numpy array. "
"Note: This can fail for large workspaces as numpy will require a block "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
from testhelpers import run_algorithm, can_be_instantiated
from mantid.api import (MatrixWorkspace, WorkspaceProperty, Workspace,
ExperimentInfo, AnalysisDataService)
ExperimentInfo, AnalysisDataService, WorkspaceFactory)
from mantid.geometry import Detector
from mantid.kernel import V3D

Expand Down Expand Up @@ -106,6 +106,61 @@ def do_numpy_test(arr):
for attr in [x,y,e,dx]:
do_numpy_test(attr)

def test_setting_spectra_from_array_of_incorrect_length_raises_error(self):
nvectors = 2
xlength = 11
ylength = 10
test_ws = WorkspaceFactory.create("Workspace2D", nvectors, xlength, ylength)

values = np.arange(xlength + 1)
self.assertRaises(ValueError, test_ws.setX, 0, values)
self.assertRaises(ValueError, test_ws.setY, 0, values)
self.assertRaises(ValueError, test_ws.setE, 0, values)

def test_setting_spectra_from_array_of_incorrect_shape_raises_error(self):
nvectors = 2
xlength = 11
ylength = 10
test_ws = WorkspaceFactory.create("Workspace2D", nvectors, xlength, ylength)

values = np.linspace(0,1,num=xlength-1)
values = values.reshape(5,2)
self.assertRaises(ValueError, test_ws.setX, 0, values)
self.assertRaises(ValueError, test_ws.setY, 0, values)
self.assertRaises(ValueError, test_ws.setE, 0, values)

def test_setting_spectra_from_array_using_incorrect_index_raises_error(self):
nvectors = 2
xlength = 11
ylength = 10

test_ws = WorkspaceFactory.create("Workspace2D", nvectors, xlength, ylength)
xvalues = np.arange(xlength)
self.assertRaises(RuntimeError, test_ws.setX, 3, xvalues)

def test_setting_spectra_from_array_sets_expected_values(self):
nvectors = 2
xlength = 11
ylength = 10

test_ws = WorkspaceFactory.create("Workspace2D", nvectors, xlength, ylength)
ws_index = 1

values = np.linspace(0,1,xlength)
test_ws.setX(ws_index, values)
ws_values = test_ws.readX(ws_index)
self.assertTrue(np.array_equal(values, ws_values))

values = np.ones(ylength)
test_ws.setY(ws_index, values)
ws_values = test_ws.readY(ws_index)
self.assertTrue(np.array_equal(values, ws_values))

values = np.sqrt(values)
test_ws.setE(ws_index, values)
ws_values = test_ws.readE(ws_index)
self.assertTrue(np.array_equal(values, ws_values))

def test_data_can_be_extracted_to_numpy_successfully(self):
x = self._test_ws.extractX()
y = self._test_ws.extractY()
Expand Down

0 comments on commit e1ba46f

Please sign in to comment.