Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1385 from lscheinkman/1380
Browse files Browse the repository at this point in the history
 Issue #1380: Validate python parameters passed to SP compute method
  • Loading branch information
lscheinkman committed Jan 16, 2018
2 parents d608108 + 6a152d6 commit c33a6e8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/nupic/bindings/algorithms.i
Original file line number Diff line number Diff line change
Expand Up @@ -1163,11 +1163,11 @@ void forceRetentionOfImageSensorLiteLibrary(void) {
self._initFromCapnpPyBytes(proto.as_builder().to_bytes()) # copy * 2
%}

inline void compute(PyObject *py_x, bool learn, PyObject *py_y)
inline void compute(PyObject *py_inputArray, bool learn, PyObject *py_activeArray)
{
PyArrayObject* x = (PyArrayObject*) py_x;
PyArrayObject* y = (PyArrayObject*) py_y;
self->compute((nupic::UInt*) PyArray_DATA(x), (bool)learn, (nupic::UInt*) PyArray_DATA(y));
nupic::CheckedNumpyVectorWeakRefT<nupic::UInt> inputArray(py_inputArray);
nupic::CheckedNumpyVectorWeakRefT<nupic::UInt> activeArray(py_activeArray);
self->compute(inputArray.begin(), learn, activeArray.begin());
}

inline void stripUnlearnedColumns(PyObject *py_x)
Expand Down
26 changes: 26 additions & 0 deletions src/nupic/py_support/NumpyVector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <nupic/types/Types.hpp> // For nupic::Real.
#include <nupic/utils/Log.hpp> // For NTA_ASSERT
#include <algorithm> // For std::copy.
#include <boost/type_index/stl_type_index.hpp> // for 'type_id'

namespace nupic {

Expand Down Expand Up @@ -438,6 +439,31 @@ namespace nupic {
PyArrayObject* pyArray_;
};

/**
* Similar to NumpyVectorWeakRefT but also provides extra type checking
*/
template<typename T>
class CheckedNumpyVectorWeakRefT : public NumpyVectorWeakRefT<T>
{
public:
CheckedNumpyVectorWeakRefT(PyObject* pyArray)
: NumpyVectorWeakRefT<T>(pyArray)
{
if (PyArray_NDIM(this->pyArray_) != 1)
{
NTA_THROW << "Expecting 1D array "
<< "but got " << PyArray_NDIM(this->pyArray_) << "D array";
}
if (!PyArray_EquivTypenums(
PyArray_TYPE(this->pyArray_), LookupNumpyDType((const T *) 0)))
{
boost::typeindex::stl_type_index expectedType =
boost::typeindex::stl_type_index::type_id<T>();
NTA_THROW << "Expecting '" << expectedType.pretty_name() << "' "
<< "but got '" << PyArray_DTYPE(this->pyArray_)->type << "'";
}
}
};
} // End namespace nupic.

#endif

0 comments on commit c33a6e8

Please sign in to comment.