Skip to content

Commit

Permalink
ndarray: add Base.ones(::NDArray) and Base.zeros (#363)
Browse files Browse the repository at this point in the history
For creating NDArray with same type and dims
  • Loading branch information
iblislin committed Jan 4, 2018
1 parent 30852dd commit 0ef5966
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 23 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@
4.0
```

* `Base.ones(x::NDArray)` for creating an one-ed `NDArray`. (#TBD)

* `Base.zeros(x::NDArray)` for creating a zero-ed `NDArray`. (#TBD)

* Modulo operator. (#TBD)

```julia
Expand Down
38 changes: 15 additions & 23 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ Note that the returned `NDArray` is uninitialized.
Base.similar(x::NDArray{T}) where {T} = empty(T, size(x), context(x))

"""
zeros(DType, dims[, ctx::Context = cpu()])
zeros(DType, dims...)
zeros([DType], dims, [ctx::Context = cpu()])
zeros([DType], dims...)
zeros(x::NDArray)
Create zero-ed `NDArray` with specific shape and type.
"""
Expand All @@ -185,19 +186,17 @@ end

zeros(::Type{T}, dims::Int...) where {T<:DType} = zeros(T, dims)

"""
zeros(dims[, ctx::Context = cpu()])
zeros(dims...)
Create zero-ed `NDArray` with specific shape.
"""
zeros(dims::NTuple{N, Int}, ctx::Context = cpu()) where N =
zeros(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
zeros(MX_float, dims, ctx)
zeros(dims::Int...) = zeros(dims)

zeros(x::NDArray)::typeof(x) = zeros_like(x)
Base.zeros(x::NDArray)::typeof(x) = zeros_like(x)

"""
ones(DType, dims::Tuple[, ctx::Context = cpu()])
ones(DType, dim1, dim2...)
ones([DType], dims, [ctx::Context = cpu()])
ones([DType], dims...)
ones(x::NDArray)
Create an `NDArray` with specific shape & type, and initialize with 1.
"""
Expand All @@ -209,20 +208,13 @@ end

ones(::Type{T}, dims::Int...) where T<:DType = ones(T, dims)

"""
ones(dims::Tuple[, ctx::Context = cpu()])
ones(dim1, dim2, ...)
Create an `NDArray` with specific shape and initialize with 1.
"""
function ones(dims::NTuple{N,Int}, ctx::Context = cpu()) where N
arr = empty(dims, ctx)
arr[:] = 1
arr
end

ones(dims::NTuple{N,Int}, ctx::Context = cpu()) where N =
ones(MX_float, dims, ctx)
ones(dims::Int...) = ones(dims)

ones(x::NDArray)::typeof(x) = ones_like(x)
Base.ones(x::NDArray)::typeof(x) = ones_like(x)

import Base: size, length, ndims, eltype

"""
Expand Down
22 changes: 22 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@ function test_constructor()
end # function test_constructor


function test_ones_zeros_like()
info("NDArray::Base.zeros")
let x = mx.rand(1, 10, (1, 3, 2, 4))
y = zeros(x)
@test sum(copy(y)) == 0

y = mx.zeros(x)
@test sum(copy(y)) == 0
end

info("NDArray::Base.ones")
let x = mx.rand(1, 10, (1, 3, 2, 4))
y = ones(x)
@test sum(copy(y)) == 1 * 3 * 2 * 4

y = mx.ones(x)
@test sum(copy(y)) == 1 * 3 * 2 * 4
end
end # function test_ones_zeros_like


function test_copy()
dims = rand_dims()
tensor = rand(mx.MX_float, dims)
Expand Down Expand Up @@ -1167,6 +1188,7 @@ end # function test_act_funcs
################################################################################
@testset "NDArray Test" begin
test_constructor()
test_ones_zeros_like()
test_assign()
test_copy()
test_slice()
Expand Down

0 comments on commit 0ef5966

Please sign in to comment.