-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
random: Base-like APIs for rand, rand!, randn, randn! (#383)
* random: Base-like APIs for rand, rand!, randn, randn! and deprecate the original APIs ```julia julia> mx.rand(2, 3) 2×3 mx.NDArray{Float32,2} @ CPU0: 0.631961 0.324175 0.0762663 0.285366 0.395292 0.074995 julia> mx.rand(2, 3, low = low, high = high) 2×3 mx.NDArray{Float32,2} @ CPU0: 7.83884 7.85793 7.64791 7.68646 8.56082 8.42189 ``` ```julia julia> mx.randn(2, 3) 2×3 mx.NDArray{Float32,2} @ CPU0: 0.962853 0.424535 -0.320123 0.478113 1.72886 1.72287 julia> mx.randn(2, 3, μ = 100) 2×3 mx.NDArray{Float32,2} @ CPU0: 99.5635 100.483 99.888 99.9889 100.533 100.072 ``` * fix depwarn
- Loading branch information
Showing
5 changed files
with
93 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,66 @@ | ||
""" | ||
rand!(low, high, arr::NDArray) | ||
rand!(x::NDArray; low = 0, high = 1) | ||
Draw random samples from a uniform distribution. | ||
Samples are uniformly distributed over the half-open interval [low, high) | ||
(includes low, but excludes high). | ||
# Examples | ||
```julia | ||
julia> mx.rand(0, 1, mx.zeros(2, 2)) |> copy | ||
2×2 Array{Float32,2}: | ||
0.405374 0.321043 | ||
0.281153 0.713927 | ||
julia> mx.rand!(empty(2, 3)) | ||
2×3 mx.NDArray{Float32,2} @ CPU0: | ||
0.385748 0.839275 0.444536 | ||
0.0879585 0.215928 0.104636 | ||
julia> mx.rand!(empty(2, 3), low = 1, high = 10) | ||
2×3 mx.NDArray{Float32,2} @ CPU0: | ||
6.6385 4.18888 2.07505 | ||
8.97283 2.5636 1.95586 | ||
``` | ||
""" | ||
function rand!(low::Real, high::Real, out::NDArray) | ||
_random_uniform(NDArray, low=low, high=high, shape=size(out), out=out) | ||
end | ||
rand!(x::NDArray; low = 0, high = 1) = | ||
_random_uniform(NDArray, low = low, high = high, shape = size(x), out = x) | ||
|
||
""" | ||
rand(low, high, shape, context=cpu()) | ||
rand(dims...; low = 0, high = 1, context = cpu()) | ||
Draw random samples from a uniform distribution. | ||
Samples are uniformly distributed over the half-open interval [low, high) | ||
(includes low, but excludes high). | ||
# Examples | ||
```julia | ||
julia> mx.rand(0, 1, (2, 2)) |> copy | ||
2×2 Array{Float32,2}: | ||
0.405374 0.321043 | ||
0.281153 0.713927 | ||
julia> mx.rand(2, 2) | ||
2×2 mx.NDArray{Float32,2} @ CPU0: | ||
0.487866 0.825691 | ||
0.0234245 0.794797 | ||
julia> mx.rand(2, 2; low = 1, high = 10) | ||
2×2 mx.NDArray{Float32,2} @ CPU0: | ||
5.5944 5.74281 | ||
9.81258 3.58068 | ||
``` | ||
""" | ||
function rand(low::Real, high::Real, shape::NTuple{N, Int}, ctx::Context=cpu()) where N | ||
out = empty(shape, ctx) | ||
rand!(low, high, out) | ||
end | ||
rand(dims::Int...; low = 0, high = 1, context = cpu()) = | ||
rand!(empty(dims, context), low = low, high = high) | ||
|
||
""" | ||
randn!(mean, std, arr::NDArray) | ||
randn!(x::NDArray; μ = 0, σ = 1) | ||
Draw random samples from a normal (Gaussian) distribution. | ||
""" | ||
function randn!(mean::Real, stdvar::Real, out::NDArray) | ||
_random_normal(NDArray, loc=mean, scale=stdvar, shape=size(out), out=out) | ||
end | ||
randn!(x::NDArray; μ = 0, σ = 1) = | ||
_random_normal(NDArray, loc = μ, scale = σ, shape = size(x), out = x) | ||
|
||
""" | ||
randn(mean, std, shape, context=cpu()) | ||
randn(dims...; μ = 0, σ = 1, context = cpu()) | ||
Draw random samples from a normal (Gaussian) distribution. | ||
""" | ||
function randn(mean::Real, stdvar::Real, shape::NTuple{N,Int}, ctx::Context=cpu()) where N | ||
out = empty(shape, ctx) | ||
randn!(mean, stdvar, out) | ||
end | ||
randn(dims::Int...; μ = 0, σ = 1, context = cpu()) = | ||
randn!(empty(dims, context), μ = μ, σ = σ) | ||
|
||
""" | ||
srand(seed::Int) | ||
Set the random seed of libmxnet | ||
""" | ||
function srand(seed_state::Int) | ||
@mxcall(:MXRandomSeed, (Cint,), seed_state) | ||
end | ||
srand(seed_state::Int) = @mxcall(:MXRandomSeed, (Cint,), seed_state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters