Skip to content

Commit

Permalink
clean up Context
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Oct 20, 2015
1 parent 4eff47d commit d13ddc6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
19 changes: 10 additions & 9 deletions src/context.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
@enum CONTEXT_TYPE CPU=1 GPU=2 CPU_PINNED=3

type Context
immutable Context
device_type :: CONTEXT_TYPE
device_id :: Cint

old_ctx :: Nullable{Context}
device_id :: Int
end
Context(dev_type :: Union{CONTEXT_TYPE, Integer}, dev_id :: Integer = 0) =
Context(convert(CONTEXT_TYPE, dev_type), convert(Cint, dev_id), Nullable{Context}())
Context(dev_type :: Union{CONTEXT_TYPE, Int}, dev_id :: Int = 0) =
Context(convert(CONTEXT_TYPE, dev_type), dev_id)

function Base.show(io :: IO, ctx :: Context)
print(io, "$(ctx.device_type)$(ctx.device_id)")
end


# global default context
DEFAULT_CONTEXT = Context(CPU)
function cpu(dev_id::Int=0)
return Context(CPU, dev_id)
end
function gpu(dev_id::Int=0)
return Context(GPU, dev_id)
end
17 changes: 13 additions & 4 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function Base.show(io :: IO, arr :: NDArray)
end

function NDArray{T<:Real}(data :: Array{T})
copy(data, mx.DEFAULT_CONTEXT)
copy(data, cpu())
end

function Base.unsafe_convert(::Type{MX_handle}, obj::NDArray)
Expand All @@ -63,7 +63,10 @@ function context(arr :: NDArray)
return Context(ref_typeid[], ref_devid[])
end

function empty{N}(shape :: NTuple{N, Int}, ctx :: Context = DEFAULT_CONTEXT)
function empty{N}(shape :: NTuple{N, Int})
empty(shape, cpu())
end
function empty{N}(shape :: NTuple{N, Int}, ctx :: Context)
NDArray(_ndarray_alloc(shape, ctx, false))
end
function empty(shape :: Int...)
Expand Down Expand Up @@ -99,7 +102,10 @@ function eltype(arr :: NDArray)
end

"Create zero-ed NDArray of specific shape"
function zeros{N}(shape :: NTuple{N, Int}, ctx :: Context = DEFAULT_CONTEXT)
function zeros{N}(shape :: NTuple{N, Int})
zeros(shape, cpu())
end
function zeros{N}(shape :: NTuple{N, Int}, ctx :: Context)
arr = empty(shape, ctx)
arr[:] = 0
return arr
Expand All @@ -109,7 +115,10 @@ function zeros(shape :: Int...)
end

"Create NDArray and initialize with 1"
function ones{N}(shape :: NTuple{N, Int}, ctx :: Context = DEFAULT_CONTEXT)
function ones{N}(shape :: NTuple{N, Int})
ones(shape, cpu())
end
function ones{N}(shape :: NTuple{N, Int}, ctx :: Context)
arr = empty(shape, ctx)
arr[:] = 1
return arr
Expand Down
10 changes: 8 additions & 2 deletions src/random.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
function rand!(low::Real, high::Real, out::NDArray)
_random_uniform(low, high, out)
end
function rand(low::Real, high::Real, shape::Tuple, ctx::Context=DEFAULT_CONTEXT)
function rand{N}(low::Real, high::Real, shape::NTuple{N, Int})
rand(low, high, shape, cpu())
end
function rand{N}(low::Real, high::Real, shape::NTuple{N, Int}, ctx::Context)
out = empty(shape, ctx)
rand!(low, high, out)
end

function randn!(mean::Real, stdvar::Real, out::NDArray)
_random_gaussian(mean, stdvar, out)
end
function randn(mean::Real, stdvar::Real, shape::Tuple, ctx::Context=DEFAULT_CONTEXT)
function randn{N}(mean::Real, stdvar::Real, shape::NTuple{N,Int})
randn(mean, stdvar, shape, cpu())
end
function randn{N}(mean::Real, stdvar::Real, shape::NTuple{N,Int}, ctx::Context)
out = empty(shape, ctx)
randn!(mean, stdvar, out)
end
Expand Down
6 changes: 3 additions & 3 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using ..Main: rand_dims, reldiff
################################################################################
function rand_tensors{N}(dims::NTuple{N, Int})
tensor = rand(mx.MX_float, dims)
array = copy(tensor, mx.DEFAULT_CONTEXT)
array = copy(tensor, mx.cpu())
return (tensor, array)
end

Expand All @@ -20,12 +20,12 @@ function test_copy()
info("NDArray::copy::dims = $dims")

# copy to NDArray and back
array = copy(tensor, mx.DEFAULT_CONTEXT)
array = copy(tensor, mx.cpu())
tensor2 = copy(array)
@test reldiff(tensor, tensor2) < 1e-6

# copy between NDArray
array2 = copy(array, mx.DEFAULT_CONTEXT)
array2 = copy(array, mx.cpu())
tensor2 = copy(array2)
@test reldiff(tensor, tensor2) < 1e-6
end
Expand Down

0 comments on commit d13ddc6

Please sign in to comment.