Skip to content

Commit 1d6429f

Browse files
wolfvSylvainCorlay
authored andcommitted
support char arrays and complicated dtypes (structs) (xtensor-stack#149)
* add support for char arrays and structs * ..
1 parent 504b75b commit 1d6429f

File tree

5 files changed

+158
-26
lines changed

5 files changed

+158
-26
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,7 @@ namespace pybind11
4949
{
5050
if (!convert)
5151
{
52-
if (!PyArray_Check(src.ptr()))
53-
{
54-
return false;
55-
}
56-
int type_num = xt::detail::numpy_traits<T>::type_num;
57-
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
52+
if (!xt::detail::check_array<T>(src))
5853
{
5954
return false;
6055
}
@@ -477,7 +472,7 @@ namespace xt
477472
shape_type shape = xtl::make_sequence<shape_type>(0, size_type(1));
478473
strides_type strides = xtl::make_sequence<strides_type>(0, size_type(0));
479474
init_array(shape, strides);
480-
m_storage[0] = T();
475+
detail::default_initialize(m_storage);
481476
}
482477

483478
/**
@@ -705,13 +700,15 @@ namespace xt
705700
{
706701
flags |= NPY_ARRAY_WRITEABLE;
707702
}
708-
int type_num = detail::numpy_traits<T>::type_num;
703+
704+
auto dtype = pybind11::detail::npy_format_descriptor<T>::dtype();
709705

710706
npy_intp* shape_data = reinterpret_cast<npy_intp*>(const_cast<size_type*>(shape.data()));
711707
npy_intp* strides_data = reinterpret_cast<npy_intp*>(adapted_strides.data());
708+
712709
auto tmp = pybind11::reinterpret_steal<pybind11::object>(
713-
PyArray_New(&PyArray_Type, static_cast<int>(shape.size()), shape_data, type_num, strides_data,
714-
nullptr, static_cast<int>(sizeof(T)), flags, nullptr));
710+
PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.release().ptr(), static_cast<int>(shape.size()), shape_data, strides_data,
711+
nullptr, flags, nullptr));
715712

716713
if (!tmp)
717714
{

include/xtensor-python/pycontainer.hpp

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "pybind11/complex.h"
1717
#include "pybind11/pybind11.h"
18+
#include "pybind11/numpy.h"
1819

1920
#ifndef FORCE_IMPORT_ARRAY
2021
#define NO_IMPORT_ARRAY
@@ -129,8 +130,11 @@ namespace xt
129130

130131
namespace detail
131132
{
133+
template <class T, class E = void>
134+
struct numpy_traits;
135+
132136
template <class T>
133-
struct numpy_traits
137+
struct numpy_traits<T, std::enable_if_t<pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value>>
134138
{
135139
private:
136140

@@ -184,6 +188,47 @@ namespace xt
184188
{
185189
return numpy_enum_adjuster<NPY_LONGLONG != NPY_INT64>::pyarray_type(obj);
186190
}
191+
192+
template <class T>
193+
void default_initialize_impl(T& storage, std::false_type)
194+
{
195+
}
196+
197+
template <class T>
198+
void default_initialize_impl(T& storage, std::true_type)
199+
{
200+
using value_type = typename T::value_type;
201+
storage[0] = value_type{};
202+
}
203+
204+
template <class T>
205+
void default_initialize(T& storage)
206+
{
207+
using value_type = typename T::value_type;
208+
default_initialize_impl(storage, std::is_copy_assignable<value_type>());
209+
}
210+
211+
template <class T>
212+
bool check_array_type(const pybind11::handle& src, std::true_type)
213+
{
214+
int type_num = xt::detail::numpy_traits<T>::type_num;
215+
return xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) == type_num;
216+
}
217+
218+
template <class T>
219+
bool check_array_type(const pybind11::handle& src, std::false_type)
220+
{
221+
return PyArray_EquivTypes((PyArray_Descr*) pybind11::detail::array_proxy(src.ptr())->descr,
222+
(PyArray_Descr*) pybind11::dtype::of<T>().ptr());
223+
}
224+
225+
template <class T>
226+
bool check_array(const pybind11::handle& src)
227+
{
228+
using is_arithmetic_type = std::integral_constant<bool, bool(pybind11::detail::satisfies_any_of<T, std::is_arithmetic, xtl::is_complex>::value)>;
229+
return PyArray_Check(src.ptr()) &&
230+
check_array_type<T>(src, is_arithmetic_type{});
231+
}
187232
}
188233

189234
/******************************
@@ -232,9 +277,9 @@ namespace xt
232277
template <class D>
233278
inline bool pycontainer<D>::check_(pybind11::handle h)
234279
{
235-
int type_num = detail::numpy_traits<value_type>::type_num;
280+
auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype();
236281
return PyArray_Check(h.ptr()) &&
237-
PyArray_EquivTypenums(PyArray_TYPE(reinterpret_cast<PyArrayObject*>(h.ptr())), type_num);
282+
PyArray_EquivTypes_(PyArray_TYPE(reinterpret_cast<PyArrayObject*>(h.ptr())), dtype.ptr());
238283
}
239284

240285
template <class D>
@@ -244,8 +289,9 @@ namespace xt
244289
{
245290
return nullptr;
246291
}
247-
int type_num = detail::numpy_traits<value_type>::type_num;
248-
auto res = PyArray_FromAny(ptr, PyArray_DescrFromType(type_num), 0, 0,
292+
293+
auto dtype = pybind11::detail::npy_format_descriptor<value_type>::dtype();
294+
auto res = PyArray_FromAny(ptr, (PyArray_Descr *) dtype.release().ptr(), 0, 0,
249295
NPY_ARRAY_ENSUREARRAY | NPY_ARRAY_FORCECAST, nullptr);
250296
return res;
251297
}

include/xtensor-python/pytensor.hpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,7 @@ namespace pybind11
5050
{
5151
if (!convert)
5252
{
53-
if (!PyArray_Check(src.ptr()))
54-
{
55-
return false;
56-
}
57-
int type_num = xt::detail::numpy_traits<T>::type_num;
58-
if(xt::detail::pyarray_type(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
53+
if (!xt::detail::check_array<T>(src))
5954
{
6055
return false;
6156
}
@@ -228,7 +223,7 @@ namespace xt
228223
m_shape = xtl::make_sequence<shape_type>(N, size_type(1));
229224
m_strides = xtl::make_sequence<strides_type>(N, size_type(0));
230225
init_tensor(m_shape, m_strides);
231-
m_storage[0] = T();
226+
detail::default_initialize(m_storage);
232227
}
233228

234229
/**
@@ -402,11 +397,12 @@ namespace xt
402397
{
403398
flags |= NPY_ARRAY_WRITEABLE;
404399
}
405-
int type_num = detail::numpy_traits<T>::type_num;
400+
auto dtype = pybind11::detail::npy_format_descriptor<T>::dtype();
406401

407402
auto tmp = pybind11::reinterpret_steal<pybind11::object>(
408-
PyArray_New(&PyArray_Type, N, const_cast<npy_intp*>(shape.data()),
409-
type_num, python_strides, nullptr, sizeof(T), flags, nullptr));
403+
PyArray_NewFromDescr(&PyArray_Type, (PyArray_Descr*) dtype.ptr(), static_cast<int>(shape.size()),
404+
const_cast<npy_intp*>(shape.data()), python_strides,
405+
nullptr, flags, nullptr));
410406

411407
if (!tmp)
412408
{

test_python/main.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,53 @@ void dump_numpy_constant()
107107
std::cout << "NPY_UINT64 = " << NPY_UINT64 << std::endl;
108108
}
109109

110+
struct A
111+
{
112+
double a;
113+
int b;
114+
char c;
115+
std::array<double, 3> x;
116+
};
117+
118+
struct B
119+
{
120+
double a;
121+
int b;
122+
};
123+
124+
xt::pyarray<A> dtype_to_python()
125+
{
126+
A a1{123, 321, 'a', {1, 2, 3}};
127+
A a2{111, 222, 'x', {5, 5, 5}};
128+
129+
return xt::pyarray<A>({a1, a2});
130+
}
131+
132+
xt::pyarray<B> dtype_from_python(xt::pyarray<B>& b)
133+
{
134+
if (b(0).a != 1 || b(0).b != 'p' || b(1).a != 123 || b(1).b != 'c')
135+
{
136+
throw std::runtime_error("FAIL");
137+
}
138+
139+
b(0).a = 123.;
140+
b(0).b = 'w';
141+
return b;
142+
}
143+
144+
void char_array(xt::pyarray<char[20]>& carr)
145+
{
146+
if (strcmp(carr(2), "python"))
147+
{
148+
throw std::runtime_error("TEST FAILED!");
149+
}
150+
std::fill(&carr(2)[0], &carr(2)[0] + 20, 0);
151+
carr(2)[0] = 'c';
152+
carr(2)[1] = '+';
153+
carr(2)[2] = '+';
154+
carr(2)[3] = '\0';
155+
}
156+
110157
PYBIND11_MODULE(xtensor_python_test, m)
111158
{
112159
xt::import_numpy();
@@ -142,4 +189,12 @@ PYBIND11_MODULE(xtensor_python_test, m)
142189
m.def("int_overload", int_overload<int64_t>);
143190

144191
m.def("dump_numpy_constant", dump_numpy_constant);
192+
193+
// Register additional dtypes
194+
PYBIND11_NUMPY_DTYPE(A, a, b, c, x);
195+
PYBIND11_NUMPY_DTYPE(B, a, b);
196+
197+
m.def("dtype_to_python", dtype_to_python);
198+
m.def("dtype_from_python", dtype_from_python);
199+
m.def("char_array", char_array);
145200
}

test_python/test_pyarray.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import xtensor_python_test as xt
2222
import numpy as np
2323

24-
class ExampleTest(TestCase):
24+
class XtensorTest(TestCase):
2525

2626
def test_example1(self):
2727
self.assertEqual(4, xt.example1([4, 5, 6]))
@@ -87,3 +87,41 @@ def test_int_overload(self):
8787
b = xt.int_overload(np.ones((10), dtype))
8888
self.assertEqual(str(dtype.__name__), b)
8989

90+
def test_dtype(self):
91+
var = xt.dtype_to_python()
92+
self.assertEqual(var.dtype.names, ('a', 'b', 'c', 'x'))
93+
94+
exp_dtype = {
95+
'a': (np.dtype('float64'), 0),
96+
'b': (np.dtype('int32'), 8),
97+
'c': (np.dtype('int8'), 12),
98+
'x': (np.dtype(('<f8', (3,))), 16)
99+
}
100+
101+
self.assertEqual(var.dtype.fields, exp_dtype)
102+
103+
self.assertEqual(var[0]['a'], 123)
104+
self.assertEqual(var[0]['b'], 321)
105+
self.assertEqual(var[0]['c'], ord('a'))
106+
self.assertTrue(np.all(var[0]['x'] == [1, 2, 3]))
107+
108+
self.assertEqual(var[1]['a'], 111)
109+
self.assertEqual(var[1]['b'], 222)
110+
self.assertEqual(var[1]['c'], ord('x'))
111+
self.assertTrue(np.all(var[1]['x'] == [5, 5, 5]))
112+
113+
d_dtype = np.dtype({'names':['a','b'], 'formats':['<f8','<i4'], 'offsets':[0,8], 'itemsize':16})
114+
115+
darr = np.array([(1, ord('p')), (123, ord('c'))], dtype=d_dtype)
116+
self.assertEqual(darr[0]['a'], 1)
117+
res = xt.dtype_from_python(darr)
118+
self.assertEqual(res[0]['a'], 123.)
119+
self.assertEqual(darr[0]['a'], 123.)
120+
121+
def test_char_array(self):
122+
var = np.array(['hello', 'from', 'python'], dtype=np.dtype('|S20'));
123+
self.assertEqual(var[0], b'hello')
124+
xt.char_array(var)
125+
self.assertEqual(var[0], b'hello')
126+
self.assertEqual(var[1], b'from')
127+
self.assertEqual(var[2], b'c++')

0 commit comments

Comments
 (0)