Skip to content

Commit

Permalink
ndarray: implement size(x, dims...) (#350)
Browse files Browse the repository at this point in the history
```julia
julia> x = mx.NDArray([1 2; 3 4; 5 6])
3×2 mx.NDArray{Int64,2} @ CPU0:
 1  2
 3  4
 5  6

julia> size(x, 1, 2, 3, 4)
(3, 2, 1, 1)
```
  • Loading branch information
iblislin authored and pluskid committed Dec 2, 2017
1 parent c43d0dd commit 09ee1f4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
16 changes: 16 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@
(MXNet.mx.SymbolicNode x, MXNet.mx.SymbolicNode y, MXNet.mx.SymbolicNode z)
```

### `NDArray`

* `size(x, dims...)` is supported now. (#TBD)

```julia
julia> x = mx.NDArray([1 2; 3 4; 5 6])
3×2 mx.NDArray{Int64,2} @ CPU0:
1 2
3 4
5 6

julia> size(x, 1, 2, 3, 4)
(3, 2, 1, 1)

```

# v0.3.0 (2017.11.16)

* Update `libmxnet` to
Expand Down
17 changes: 12 additions & 5 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,22 +226,29 @@ import Base: size, length, ndims, eltype

"""
size(x::NDArray)
size(x::NDArray, dim)
size(x::NDArray, dims...)
Get the shape of an `NDArray`. The shape is in Julia's column-major convention.
See also the notes on NDArray shapes [`NDArray`](@ref).
"""
function size(arr :: NDArray)
function size(x::NDArray)
ref_ndim = Ref{MX_uint}(0)
ref_shape = Ref{Ptr{MX_uint}}(0)
@mxcall(:MXNDArrayGetShape, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_uint}}),
arr, ref_ndim, ref_shape)
x, ref_ndim, ref_shape)
tuple(map(Int, flipdim(unsafe_wrap(Array, ref_shape[], ref_ndim[]),1))...)
end
function size(arr :: NDArray, dim :: Int)
size(arr)[dim]

function size(x::NDArray{T,N}, dim::Int) where {T,N}
if dim > N
1
else
size(x)[dim]
end
end

size(x::NDArray, dims::Int...) = map(d -> size(x, d), dims)

"""
length(x::NDArray)
Expand Down
10 changes: 10 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,15 @@ function test_show()
end
end

function test_size()
info("NDArray::size")
let A = [1 2; 3 4; 5 6], x = mx.NDArray(A)
@test size(A) == size(x)
@test size(A, 1, 2, 3, 4, 5) == size(x, 1, 2, 3, 4, 5)
@inferred size(x, 1, 2, 3, 4, 5)
end
end # function test_size()

################################################################################
# Run tests
################################################################################
Expand Down Expand Up @@ -802,6 +811,7 @@ end
test_fill()
test_transpose()
test_show()
test_size()
end

end

0 comments on commit 09ee1f4

Please sign in to comment.