diff --git a/python/bifrost/ndarray.py b/python/bifrost/ndarray.py index 573b4c09c..60584d54d 100644 --- a/python/bifrost/ndarray.py +++ b/python/bifrost/ndarray.py @@ -39,7 +39,7 @@ import ctypes import numpy as np from memory import raw_malloc, raw_free, raw_get_space, space_accessible -from bifrost.libbifrost import _bf, _check +from bifrost.libbifrost import _bf, _check, _fast_call import device from DataType import DataType from Space import Space @@ -78,7 +78,7 @@ def copy_array(dst, src): space_accessible(src_bf.bf.space, ['system'])): np.copyto(dst_bf, src_bf) else: - _check(_bf.ArrayCopy(dst_bf.as_BFarray(), src_bf.as_BFarray())) + _fast_call(_bf.ArrayCopy, dst_bf.as_BFarray(), src_bf.as_BFarray()) if dst_bf.bf.space != src_bf.bf.space: # TODO: Decide where/when these need to be called device.stream_synchronize() @@ -86,7 +86,8 @@ def copy_array(dst, src): def memset_array(dst, value): dst_bf = asarray(dst) - _check(_bf.ArrayMemset(dst_bf.as_BFarray(), value)) + #_check(_bf.ArrayMemset(dst_bf.as_BFarray(), value)) + _fast_call(_bf.ArrayMemset, dst_bf.as_BFarray(), value) return dst # Stores Bifrost-specific metadata that augments Numpy's metadata @@ -138,6 +139,7 @@ def __new__(cls, base=None, space=None, shape=None, dtype=None, # Allow conjugated to be redefined if conjugated is not None: obj.bf.conjugated = conjugated + obj._update_BFarray() else: if not isinstance(base, np.ndarray): # Convert base to np.ndarray @@ -222,6 +224,7 @@ def __new__(cls, base=None, space=None, shape=None, dtype=None, obj = np.ndarray.__new__(cls, shape, dtype_np, data_buffer, offset, strides) obj.bf = BFArrayInfo(space, dtype, native, conjugated, ownbuffer) + obj._update_BFarray() return obj def __array_finalize__(self, obj): if obj is None: @@ -241,10 +244,18 @@ def __array_finalize__(self, obj): native = obj.dtype.isnative conjugated = False self.bf = BFArrayInfo(space, dtype, native, conjugated) + self._update_BFarray() def __del__(self): if hasattr(self, 'bf') and self.bf.ownbuffer: raw_free(self.bf.ownbuffer, self.bf.space) + def _update_BFarray(self): + # (Re-)cache the BFarray structure + # Note: This must be called after any updates to self.bf.* + self._BFarray = None + self._BFarray = self.as_BFarray() def as_BFarray(self): + if self._BFarray is not None: + return self._BFarray a = _bf.BFarray() a.data = self.ctypes.data a.space = Space(self.bf.space).as_BFspace() @@ -282,6 +293,7 @@ def view(self, dtype=None, type_=None): dtype_np = np.dtype(dtype_bf.as_numpy_dtype()) v = super(ndarray, self).view(dtype_np) v.bf.dtype = dtype_bf + v._update_BFarray() return v #def astype(self, dtype): # dtype_bf = DataType(dtype) @@ -304,6 +316,7 @@ def tofile(self, fid, sep="", format="%s"): def byteswap(self, inplace=False): if inplace: self.bf.native = not self.bf.native + self._update_BFarray() return super(ndarray, self).byteswap(True) else: return ndarray(self).byteswap(True)