Skip to content

Commit a95cde3

Browse files
committed
Add gtest testing
1 parent 7ea59cc commit a95cde3

File tree

18 files changed

+1348
-130
lines changed

18 files changed

+1348
-130
lines changed

.appveyor.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ install:
2323
- conda update -q conda
2424
- conda info -a
2525
- conda install pytest -c conda-forge
26-
- cd test
2726
- conda install xtensor==0.8.0 pytest numpy pybind11==2.1.0 -c conda-forge
2827
- xcopy /S %APPVEYOR_BUILD_FOLDER%\include %MINICONDA%\include
2928

3029
build_script:
31-
- py.test .
30+
- py.test -s

.gitignore

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@
3030
# Vim tmp files
3131
*.swp
3232

33+
# Build directory
34+
build/
35+
36+
# Test build artefacts
37+
test/test_xtensor_python
38+
test/CMakeCache.txt
39+
test/Makefile
40+
test/CMakeFiles/
41+
test/cmake_install.cmake
42+
3343
# Documentation build artefacts
3444
docs/CMakeCache.txt
3545
docs/xml/

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ install:
6767
- conda update -q conda
6868
# Useful for debugging any issues with conda
6969
- conda info -a
70-
- cd test
7170
- conda install xtensor==0.8.0 pytest numpy pybind11==2.1.0 -c conda-forge
7271
- cp -r $TRAVIS_BUILD_DIR/include/* $HOME/miniconda/include/
7372

@@ -84,4 +83,5 @@ script:
8483
elif [[ "$TRAVIS_OS_NAME" == "osx" ]]; then
8584
export CXX=clang++ CC=clang;
8685
fi
87-
- py.test .
86+
- py.test -s
87+

CMakeLists.txt

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,25 +43,22 @@ set(XTENSOR_PYTHON_HEADERS
4343
)
4444

4545
if(BUILD_TESTS)
46-
4746
include_directories(${XTENSOR_PYTHON_INCLUDE_DIR})
4847
find_package(xtensor REQUIRED)
4948
include_directories(${xtensor_INCLUDE_DIR})
5049
find_package(NumPy REQUIRED)
5150
include_directories(${NUMPY_INCLUDE_DIRS})
5251

53-
#TODO replace this with a find_package(pybind11 REQUIRED)
54-
# in parent CMakeLists when pybind11 CMakeLists has been fixed.
55-
find_package(PythonLibs REQUIRED)
56-
include_directories(${PYTHON_INCLUDE_DIRS})
52+
find_package(pybind11 REQUIRED)
53+
include_directories(${pybind11_INCLUDE_DIRS})
5754

58-
if(MSVC)
59-
set(PYTHON_MODULE_EXTENSION ".pyd")
60-
else()
61-
set(PYTHON_MODULE_EXTENSION ".so")
62-
endif()
55+
if(MSVC)
56+
set(PYTHON_MODULE_EXTENSION ".pyd")
57+
else()
58+
set(PYTHON_MODULE_EXTENSION ".so")
59+
endif()
6360

64-
#add_subdirectory(test)
61+
add_subdirectory(test)
6562
add_subdirectory(benchmark)
6663
endif()
6764

benchmark/main.cpp

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,37 @@ PYBIND11_PLUGIN(benchmark_xtensor_python)
2323
py::module m("benchmark_xtensor_python", "Benchmark module for xtensor python bindings");
2424

2525
m.def("sum_array", [](xt::pyarray<double> const& x) {
26-
double sum = 0;
27-
for(auto e : x)
28-
sum += e;
29-
return sum;
30-
}
31-
);
26+
double sum = 0;
27+
for(auto e : x)
28+
sum += e;
29+
return sum;
30+
});
3231

3332
m.def("sum_tensor", [](xt::pytensor<double, 1> const& x) {
34-
double sum = 0;
35-
for(auto e : x)
36-
sum += e;
37-
return sum;
38-
}
39-
);
33+
double sum = 0;
34+
for(auto e : x)
35+
sum += e;
36+
return sum;
37+
});
4038

4139
m.def("pybind_sum_array", [](py::array_t<double> const& x) {
42-
double sum = 0;
43-
size_t size = x.size();
44-
const double* data = x.data(0);
45-
for(size_t i = 0; i < size; ++i)
46-
sum += data[i];
47-
return sum;
48-
}
49-
);
40+
double sum = 0;
41+
size_t size = x.size();
42+
const double* data = x.data(0);
43+
for(size_t i = 0; i < size; ++i)
44+
sum += data[i];
45+
return sum;
46+
});
5047

5148
m.def("rect_to_polar", [](xt::pyarray<complex_t> const& a) {
52-
return py::make_tuple(xt::pyvectorize([](complex_t x) { return std::abs(x); })(a),
53-
xt::pyvectorize([](complex_t x) { return std::arg(x); })(a));
49+
return py::vectorize([](complex_t x) { return std::abs(x); })(a);
5450
});
5551

5652
m.def("pybind_rect_to_polar", [](py::array a) {
57-
if (py::isinstance<py::array_t<complex_t>>(a))
58-
return py::make_tuple(py::vectorize([](complex_t x) { return std::abs(x); })(a),
59-
py::vectorize([](complex_t x) { return std::arg(x); })(a));
60-
else
61-
throw py::type_error("rect_to_polar unhandled type");
53+
if (py::isinstance<py::array_t<complex_t>>(a))
54+
return py::vectorize([](complex_t x) { return std::abs(x); })(a);
55+
else
56+
throw py::type_error("rect_to_polar unhandled type");
6257
});
6358

6459
return m.ptr();

include/xtensor-python/pyarray.hpp

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ namespace xt
7676

7777
value_type operator[](size_type i) const;
7878

79+
size_type size() const;
80+
7981
private:
8082

8183
const array_type* p_a;
@@ -108,14 +110,15 @@ namespace xt
108110
class pyarray : public pycontainer<pyarray<T>>,
109111
public xcontainer_semantic<pyarray<T>>
110112
{
111-
112113
public:
113114

114115
using self_type = pyarray<T>;
115116
using semantic_base = xcontainer_semantic<self_type>;
116117
using base_type = pycontainer<self_type>;
117118
using container_type = typename base_type::container_type;
118119
using value_type = typename base_type::value_type;
120+
using reference = typename base_type::reference;
121+
using const_reference = typename base_type::const_reference;
119122
using pointer = typename base_type::pointer;
120123
using size_type = typename base_type::size_type;
121124
using shape_type = typename base_type::shape_type;
@@ -125,7 +128,9 @@ namespace xt
125128
using inner_strides_type = typename base_type::inner_strides_type;
126129
using inner_backstrides_type = typename base_type::inner_backstrides_type;
127130

128-
pyarray() = default;
131+
pyarray();
132+
pyarray(const self_type&) = default;
133+
pyarray(self_type&&) = default;
129134
pyarray(const value_type& t);
130135
pyarray(nested_initializer_list_t<T, 1> t);
131136
pyarray(nested_initializer_list_t<T, 2> t);
@@ -138,11 +143,16 @@ namespace xt
138143
pyarray(const pybind11::object &o);
139144

140145
explicit pyarray(const shape_type& shape, layout l = layout::row_major);
141-
pyarray(const shape_type& shape, const strides_type& strides);
146+
explicit pyarray(const shape_type& shape, const_reference value, layout l = layout::row_major);
147+
explicit pyarray(const shape_type& shape, const strides_type& strides, const_reference value);
148+
explicit pyarray(const shape_type& shape, const strides_type& strides);
142149

143150
template <class E>
144151
pyarray(const xexpression<E>& e);
145152

153+
self_type& operator=(const self_type& e) = default;
154+
self_type& operator=(self_type&& e) = default;
155+
146156
template <class E>
147157
self_type& operator=(const xexpression<E>& e);
148158

@@ -182,6 +192,12 @@ namespace xt
182192
{
183193
}
184194

195+
template <class A>
196+
inline auto pyarray_backstrides<A>::size() const -> size_type
197+
{
198+
return p_a->dimension();
199+
}
200+
185201
template <class A>
186202
inline auto pyarray_backstrides<A>::operator[](size_type i) const -> value_type
187203
{
@@ -194,6 +210,16 @@ namespace xt
194210
* pyarray implementation *
195211
**************************/
196212

213+
template <class T>
214+
inline pyarray<T>::pyarray()
215+
{
216+
// TODO: avoid allocation
217+
shape_type shape = make_sequence<shape_type>(0, size_type(1));
218+
strides_type strides = make_sequence<strides_type>(0, size_type(0));
219+
init_array(shape, strides);
220+
m_data[0] = T();
221+
}
222+
197223
template <class T>
198224
inline pyarray<T>::pyarray(const value_type& t)
199225
{
@@ -260,9 +286,25 @@ namespace xt
260286
template <class T>
261287
inline pyarray<T>::pyarray(const shape_type& shape, layout l)
262288
{
263-
strides_type strides;
289+
strides_type strides(shape.size());
290+
compute_strides(shape, l, strides);
291+
init_array(shape, strides);
292+
}
293+
294+
template <class T>
295+
inline pyarray<T>::pyarray(const shape_type& shape, const_reference value, layout l)
296+
{
297+
strides_type strides(shape.size());
264298
compute_strides(shape, l, strides);
265299
init_array(shape, strides);
300+
std::fill(m_data.begin(), m_data.end(), value);
301+
}
302+
303+
template <class T>
304+
inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
305+
{
306+
init_array(shape, strides);
307+
std::fill(m_data.begin(), m_data.end(), value);
266308
}
267309

268310
template <class T>
@@ -306,7 +348,7 @@ namespace xt
306348
[](auto v) { return sizeof(T) * v; });
307349

308350
int flags = NPY_ARRAY_ALIGNED;
309-
if(!std::is_const<T>::value)
351+
if (!std::is_const<T>::value)
310352
{
311353
flags |= NPY_ARRAY_WRITEABLE;
312354
}
@@ -319,8 +361,10 @@ namespace xt
319361
nullptr, static_cast<int>(sizeof(T)), flags, nullptr)
320362
);
321363

322-
if(!tmp)
364+
if (!tmp)
365+
{
323366
throw std::runtime_error("NumPy: unable to create ndarray");
367+
}
324368

325369
this->m_ptr = tmp.release().ptr();
326370
init_from_python();
@@ -370,7 +414,6 @@ namespace xt
370414
{
371415
return m_data;
372416
}
373-
374417
}
375418

376419
#endif

include/xtensor-python/pycontainer.hpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,39 @@
1717
#include "pybind11/common.h"
1818
#include "pybind11/complex.h"
1919

20+
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
21+
#include "numpy/arrayobject.h"
22+
2023
#include "xtensor/xcontainer.hpp"
2124

2225
namespace xt
2326
{
27+
static bool is_numpy_imported = false;
28+
29+
class numpy_import
30+
{
31+
protected:
32+
33+
inline numpy_import()
34+
{
35+
if (!is_numpy_imported)
36+
{
37+
_import_array();
38+
is_numpy_imported = true;
39+
}
40+
}
41+
42+
numpy_import(const numpy_import&) = default;
43+
numpy_import(numpy_import&&) = default;
44+
numpy_import& operator=(const numpy_import&) = default;
45+
numpy_import& operator=(numpy_import&&) = default;
46+
};
2447

2548
template <class D>
2649
class pycontainer : public pybind11::object,
27-
public xcontainer<D>
50+
public xcontainer<D>,
51+
private numpy_import
2852
{
29-
3053
public:
3154

3255
using derived_type = D;
@@ -64,12 +87,13 @@ namespace xt
6487
void reshape(const shape_type& shape, const strides_type& strides);
6588

6689
using base_type::operator();
90+
using base_type::operator[];
6791
using base_type::begin;
6892
using base_type::end;
6993

7094
protected:
7195

72-
pycontainer() = default;
96+
pycontainer();
7397
~pycontainer() = default;
7498

7599
pycontainer(pybind11::handle h, borrowed_t);
@@ -116,6 +140,12 @@ namespace xt
116140
* pycontainer implementation *
117141
******************************/
118142

143+
template <class D>
144+
inline pycontainer<D>::pycontainer()
145+
: pybind11::object()
146+
{
147+
}
148+
119149
template <class D>
120150
inline pycontainer<D>::pycontainer(pybind11::handle h, borrowed_t)
121151
: pybind11::object(h, borrowed)
@@ -132,16 +162,20 @@ namespace xt
132162
inline pycontainer<D>::pycontainer(const pybind11::object& o)
133163
: pybind11::object(raw_array_t(o.ptr()), pybind11::object::stolen)
134164
{
135-
if(!this->m_ptr)
165+
if (!this->m_ptr)
166+
{
136167
throw pybind11::error_already_set();
168+
}
137169
}
138170

139171
template <class D>
140172
inline auto pycontainer<D>::ensure(pybind11::handle h) -> derived_type
141173
{
142174
auto result = pybind11::reinterpret_steal<derived_type>(raw_array_t(h.ptr()));
143-
if(result.ptr() == nullptr)
175+
if (result.ptr() == nullptr)
176+
{
144177
PyErr_Clear();
178+
}
145179
return result;
146180
}
147181

@@ -156,9 +190,10 @@ namespace xt
156190
template <class D>
157191
inline PyObject* pycontainer<D>::raw_array_t(PyObject* ptr)
158192
{
159-
if(ptr == nullptr)
193+
if (ptr == nullptr)
194+
{
160195
return nullptr;
161-
196+
}
162197
int type_num = detail::numpy_traits<value_type>::type_num;
163198
auto res = PyArray_FromAny(ptr, PyArray_DescrFromType(type_num), 0, 0,
164199
NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, nullptr);
@@ -174,7 +209,7 @@ namespace xt
174209
template <class D>
175210
inline void pycontainer<D>::reshape(const shape_type& shape)
176211
{
177-
if(shape.size() != this->dimension() || !std::equal(shape.begin(), shape.end(), this->shape().begin()))
212+
if (shape.size() != this->dimension() || !std::equal(shape.begin(), shape.end(), this->shape().begin()))
178213
{
179214
reshape(shape, layout::row_major);
180215
}

0 commit comments

Comments
 (0)