Skip to content

Commit

Permalink
ndarray: getindex/setindex! linear indexing (#294)
Browse files Browse the repository at this point in the history
* ndarray: getindex/setindex! linear indexing

```julia
x = mx.zeros(2, 5)
x[5] = 42
```

* ndarray: implement first
  • Loading branch information
iblislin authored and pluskid committed Nov 9, 2017
1 parent f839be1 commit 4919273
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 14 deletions.
19 changes: 19 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@
2.0 4.0
```

* `NDArray` `getindex`/`setindex!` linear indexing support and `first` for extracting scalar value. (#TBD)

```julia
julia> x = mx.zeros(2, 5)

julia> x[5] = 42 # do synchronization and set the value
```

```julia
julia> y = x[5] # actually, getindex won't do synchronization, but REPL's showing did it for you
1 mx.NDArray{Float32} @ CPU0:
42.0

julia> first(y) # do sync and get the value
42.0f0

julia> y[] # this is available, also
42.0f0
```
* Elementwise power of `NDArray`. (#293)
* `x.^2`
* `2.^x`
Expand Down
77 changes: 63 additions & 14 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ function eltype(arr :: T) where T <: Union{NDArray, MX_NDArrayHandle}
end
end

@inline _first(arr::NDArray) = try_get_shared(arr, sync = :read) |> first

Base.first(arr::NDArray) = _first(arr)

"""
slice(arr :: NDArray, start:stop)
Expand Down Expand Up @@ -341,37 +344,58 @@ function slice(arr :: NDArray, slice::UnitRange{Int})
return NDArray(MX_NDArrayHandle(hdr_ref[]), arr.writable)
end

function _at(handle::Union{MX_NDArrayHandle, MX_handle}, idx::Integer)
h_ref = Ref{MX_handle}(C_NULL)
@mxcall(:MXNDArrayAt, (MX_handle, MX_uint, Ref{MX_handle}),
handle, idx, h_ref)
h_ref[]
end

import Base: setindex!

"""
setindex!(arr :: NDArray, val, idx)
setindex!(arr::NDArray, val, idx)
Assign values to an `NDArray`. Elementwise assignment is not implemented, only the following
scenarios are supported
Assign values to an `NDArray`.
The following scenarios are supported
* single value assignment via linear indexing: `arr[42] = 24`
* `arr[:] = val`: whole array assignment, `val` could be a scalar or an array (Julia `Array`
or `NDArray`) of the same shape.
* `arr[start:stop] = val`: assignment to a *slice*, `val` could be a scalar or an array of
the same shape to the slice. See also [`slice`](@ref).
"""
function setindex!(arr :: NDArray, val :: Real, ::Colon)
@assert(arr.writable)
function setindex!(arr::NDArray, val::Real, idx::Integer)
# linear indexing
@assert arr.writable
_set_value(out=arr[idx], src=val)
end

function setindex!(arr::NDArray, val::Real, ::Colon)
@assert arr.writable
_set_value(out=arr, src=convert(eltype(arr), val))
return arr
end
function setindex!(arr :: NDArray, val :: Array{T}, ::Colon) where T<:Real

function setindex!(arr::NDArray, val::Array{T}, ::Colon) where T<:Real
@assert arr.writable
copy!(arr, val)
end
function setindex!(arr :: NDArray, val :: NDArray, ::Colon)

function setindex!(arr::NDArray, val::NDArray, ::Colon)
@assert arr.writable
copy!(arr, val)
end
function setindex!(arr :: NDArray, val :: Union{T,Array{T},NDArray}, idx::UnitRange{Int}) where T<:Real

function setindex!(arr::NDArray, val::Union{T,Array{T},NDArray},
idx::UnitRange{Int}) where T<:Real
@assert arr.writable
setindex!(slice(arr, idx), val, Colon())
end

import Base: getindex
"""
getindex(arr :: NDArray, idx)
getindex(arr::NDArray, idx)
Shortcut for [`slice`](@ref). A typical use is to write
Expand All @@ -396,18 +420,43 @@ which furthur translates into
create a **copy** of the sub-array for Julia `Array`, while for `NDArray`, this is
a *slice* that shares the memory.
"""
function getindex(arr :: NDArray, ::Colon)
function getindex(arr::NDArray, ::Colon)
return arr
end

"""
Shortcut for [`slice`](@ref). **NOTE** the behavior for Julia's built-in index slicing is to create a
copy of the sub-array, while here we simply call `slice`, which shares the underlying memory.
Shortcut for [`slice`](@ref).
**NOTE** the behavior for Julia's built-in index slicing is to create a
copy of the sub-array, while here we simply call `slice`,
which shares the underlying memory.
"""
function getindex(arr :: NDArray, idx::UnitRange{Int})
function getindex(arr::NDArray, idx::UnitRange{Int})
slice(arr, idx)
end

getindex(arr::NDArray) = _first(arr)

function getindex(arr::NDArray, idx::Integer)
# linear indexing
len = length(arr)
size_ = size(arr)

if idx <= 0 || idx > len
throw(BoundsError(
"attempt to access $(join(size_, 'x')) NDArray at index $(idx)"))
end

idx -= 1
offsets = size_[1:end-1] |> reverse cumprod collect
handle = arr.handle
for offset offsets
handle = _at(handle, idx ÷ offset)
idx %= offset
end

_at(handle, idx) |> MX_NDArrayHandle |> x -> NDArray(x, arr.writable)
end

import Base: copy!, copy, convert, deepcopy

"""
Expand Down
64 changes: 64 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,68 @@ function test_slice()
@test copy(mx.slice(array, 2:3)) == [1 1; 1 1]
end

function test_linear_idx()
info("NDArray::getindex::linear indexing")
let A = reshape(collect(1:30), 3, 10)
x = mx.NDArray(A)

@test copy(x) == A
@test copy(x[1]) == [1]
@test copy(x[2]) == [2]
@test copy(x[3]) == [3]
@test copy(x[12]) == [12]
@test copy(x[13]) == [13]
@test copy(x[14]) == [14]

@test_throws BoundsError x[-1]
@test_throws BoundsError x[0]
@test_throws BoundsError x[31]
@test_throws BoundsError x[42]
end

let A = reshape(collect(1:24), 3, 2, 4)
x = mx.NDArray(A)

@test copy(x) == A
@test copy(x[1]) == [1]
@test copy(x[2]) == [2]
@test copy(x[3]) == [3]
@test copy(x[11]) == [11]
@test copy(x[12]) == [12]
@test copy(x[13]) == [13]
@test copy(x[14]) == [14]
end

info("NDArray::setindex!::linear indexing")
let A = reshape(collect(1:24), 3, 2, 4)
x = mx.NDArray(A)

@test copy(x) == A

x[4] = -4
@test copy(x[4]) == [-4]

x[11] = -11
@test copy(x[11]) == [-11]

x[24] = 42
@test copy(x[24]) == [42]
end
end # function test_linear_idx

function test_first()
info("NDArray::first")
let A = reshape(collect(1:30), 3, 10)
x = mx.NDArray(A)

@test x[] == 1
@test x[5][] == 5

@test first(x) == 1
@test first(x[5]) == 5
end
end # function test_first

function test_plus()
dims = rand_dims()
t1, a1 = rand_tensors(dims)
Expand Down Expand Up @@ -668,6 +730,8 @@ end
test_assign()
test_copy()
test_slice()
test_linear_idx()
test_first()
test_plus()
test_minus()
test_mul()
Expand Down

0 comments on commit 4919273

Please sign in to comment.