@@ -76,6 +76,8 @@ namespace xt
7676
7777 value_type operator [](size_type i) const ;
7878
79+ size_type size () const ;
80+
7981 private:
8082
8183 const array_type* p_a;
@@ -108,14 +110,15 @@ namespace xt
108110 class pyarray : public pycontainer <pyarray<T>>,
109111 public xcontainer_semantic<pyarray<T>>
110112 {
111-
112113 public:
113114
114115 using self_type = pyarray<T>;
115116 using semantic_base = xcontainer_semantic<self_type>;
116117 using base_type = pycontainer<self_type>;
117118 using container_type = typename base_type::container_type;
118119 using value_type = typename base_type::value_type;
120+ using reference = typename base_type::reference;
121+ using const_reference = typename base_type::const_reference;
119122 using pointer = typename base_type::pointer;
120123 using size_type = typename base_type::size_type;
121124 using shape_type = typename base_type::shape_type;
@@ -125,7 +128,9 @@ namespace xt
125128 using inner_strides_type = typename base_type::inner_strides_type;
126129 using inner_backstrides_type = typename base_type::inner_backstrides_type;
127130
128- pyarray () = default ;
131+ pyarray ();
132+ pyarray (const self_type&) = default ;
133+ pyarray (self_type&&) = default ;
129134 pyarray (const value_type& t);
130135 pyarray (nested_initializer_list_t <T, 1 > t);
131136 pyarray (nested_initializer_list_t <T, 2 > t);
@@ -138,11 +143,16 @@ namespace xt
138143 pyarray (const pybind11::object &o);
139144
140145 explicit pyarray (const shape_type& shape, layout l = layout::row_major);
141- pyarray (const shape_type& shape, const strides_type& strides);
146+ explicit pyarray (const shape_type& shape, const_reference value, layout l = layout::row_major);
147+ explicit pyarray (const shape_type& shape, const strides_type& strides, const_reference value);
148+ explicit pyarray (const shape_type& shape, const strides_type& strides);
142149
143150 template <class E >
144151 pyarray (const xexpression<E>& e);
145152
153+ self_type& operator =(const self_type& e) = default ;
154+ self_type& operator =(self_type&& e) = default ;
155+
146156 template <class E >
147157 self_type& operator =(const xexpression<E>& e);
148158
@@ -182,6 +192,12 @@ namespace xt
182192 {
183193 }
184194
195+ template <class A >
196+ inline auto pyarray_backstrides<A>::size() const -> size_type
197+ {
198+ return p_a->dimension ();
199+ }
200+
185201 template <class A >
186202 inline auto pyarray_backstrides<A>::operator [](size_type i) const -> value_type
187203 {
@@ -194,6 +210,16 @@ namespace xt
194210 * pyarray implementation *
195211 **************************/
196212
213+ template <class T >
214+ inline pyarray<T>::pyarray()
215+ {
216+ // TODO: avoid allocation
217+ shape_type shape = make_sequence<shape_type>(0 , size_type (1 ));
218+ strides_type strides = make_sequence<strides_type>(0 , size_type (0 ));
219+ init_array (shape, strides);
220+ m_data[0 ] = T ();
221+ }
222+
197223 template <class T >
198224 inline pyarray<T>::pyarray(const value_type& t)
199225 {
@@ -260,9 +286,25 @@ namespace xt
260286 template <class T >
261287 inline pyarray<T>::pyarray(const shape_type& shape, layout l)
262288 {
263- strides_type strides;
289+ strides_type strides (shape.size ());
290+ compute_strides (shape, l, strides);
291+ init_array (shape, strides);
292+ }
293+
294+ template <class T >
295+ inline pyarray<T>::pyarray(const shape_type& shape, const_reference value, layout l)
296+ {
297+ strides_type strides (shape.size ());
264298 compute_strides (shape, l, strides);
265299 init_array (shape, strides);
300+ std::fill (m_data.begin (), m_data.end (), value);
301+ }
302+
303+ template <class T >
304+ inline pyarray<T>::pyarray(const shape_type& shape, const strides_type& strides, const_reference value)
305+ {
306+ init_array (shape, strides);
307+ std::fill (m_data.begin (), m_data.end (), value);
266308 }
267309
268310 template <class T >
@@ -306,7 +348,7 @@ namespace xt
306348 [](auto v) { return sizeof (T) * v; });
307349
308350 int flags = NPY_ARRAY_ALIGNED;
309- if (!std::is_const<T>::value)
351+ if (!std::is_const<T>::value)
310352 {
311353 flags |= NPY_ARRAY_WRITEABLE;
312354 }
@@ -319,8 +361,10 @@ namespace xt
319361 nullptr , static_cast <int >(sizeof (T)), flags, nullptr )
320362 );
321363
322- if (!tmp)
364+ if (!tmp)
365+ {
323366 throw std::runtime_error (" NumPy: unable to create ndarray" );
367+ }
324368
325369 this ->m_ptr = tmp.release ().ptr ();
326370 init_from_python ();
@@ -370,7 +414,6 @@ namespace xt
370414 {
371415 return m_data;
372416 }
373-
374417}
375418
376419#endif
0 commit comments