diff --git a/include/array/array.h b/include/array/array.h index 19999f03..98f9853d 100644 --- a/include/array/array.h +++ b/include/array/array.h @@ -1679,6 +1679,10 @@ template ::type&>> NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order( const Shape& shape, Ptr base, Fn&& fn) { + if (!base) { + assert(shape.empty()); + return; + } // TODO: This is losing compile-time constant extents and strides info // (https://github.com/dsharlet/array/issues/1). auto base_and_stride = std::make_pair(base, shape.stride()); @@ -1698,6 +1702,10 @@ template ::type&>> NDARRAY_UNIQUE NDARRAY_HOST_DEVICE void for_each_value_in_order(const Shape& shape, const ShapeA& shape_a, PtrA base_a, const ShapeB& shape_b, PtrB base_b, Fn&& fn) { + if (!base_a || !base_b) { + assert(shape.empty()); + return; + } base_a += shape_a[shape.min()]; base_b += shape_b[shape.min()]; // TODO: This is losing compile-time constant extents and strides info