Skip to content

Commit

Permalink
allow keyword arguments in bind.
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Oct 21, 2015
1 parent df6613d commit 4ac5e7b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
21 changes: 21 additions & 0 deletions docs/user-guide/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,25 @@ end
# fc1_output => (10,64)
```

## Binding and Executing

In order to execute the computation graph specified a composed symbol, we will *bind* the free variables to concrete values, specified as `mx.NDArray`s. This will create an `mx.Executor` on a given `mx.Context`. A context describes the computation devices (CPUs, GPUs, etc.) and an executor will carry out the computation (forward/backward) specified in the corresponding symbolic composition.
```julia
A = mx.variable(:A)
B = mx.variable(:B)
C = A .* B
a = mx.ones(3) * 4
b = mx.ones(3) * 2
c_exec = mx.bind(C, context=mx.cpu(), args=Dict(:A => a, :B => b))

mx.forward(c_exec)
copy(c_exec.outputs[1]) # copy turns NDArray into Julia Array
# =>
# 3-element Array{Float32,1}:
# 8.0
# 8.0
# 8.0
```
**TODO** Provide pointers to further details.

# Low Level Interface
32 changes: 21 additions & 11 deletions src/executor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,26 @@ function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Base.Symbol,NDA
end
arr
end
# help the type inference
if allow_missing
args_vec = Union{NDArray,Void}[args_vec...]
else
args_vec = NDArray[args_vec...]
end
args_hdr = MX_handle[(isa(x,Void) ? MX_handle(0) : x) for x in args_vec]
return (args_hdr, args_vec)
end

@enum GRAD_REQ GRAD_NOP=0 GRAD_WRITE=1 GRAD_ADD=3
function bind(self :: Symbol, ctx :: Context, args :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}};
args_grad :: Union{Void,Vector{NDArray},Dict{Base.Symbol,NDArray}} = nothing,
aux_states :: Union{Void,Vector{NDArray},Dict{Base.Symbol,NDArray}} = nothing,
args_grad :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
aux_states :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
grad_req :: Union{GRAD_REQ,Vector{GRAD_REQ},Dict{Base.Symbol,GRAD_REQ}} = GRAD_WRITE)

arg_names = list_arguments(self)

args_hdr, args = _get_ndarray_inputs("args", args, arg_names, false)
if isa(args_grad, Void)
args_grad = [nothing for i=1:length(args)]
args_grad_hdr = MX_handle[Ptr{Void}(0) for i=1:length(args)]
else
args_grad_hdr, args_grad = _get_ndarray_inputs("args_grad", args_grad, arg_names, true)
end

if isa(aux_states, Void); aux_states = NDArray[]; end
args_hdr, args = _get_ndarray_inputs("args", args, arg_names, false)
args_grad_hdr, args_grad = _get_ndarray_inputs("args_grad", args_grad, arg_names, true)
aux_args_hdr, aux_states = _get_ndarray_inputs("aux_states", aux_states, list_auxiliary_states(self), false)

if isa(grad_req, GRAD_REQ)
Expand All @@ -90,6 +89,17 @@ function bind(self :: Symbol, ctx :: Context, args :: Union{Vector{NDArray},Dict
executor = Executor(MX_ExecutorHandle(ref_hdr[]), self,
args, args_grad, aux_states)
end
function bind(self :: Symbol; kwargs...)
kwargs = Dict(kwargs)
@assert(haskey(kwargs, :args), "Must specify args")
args = pop!(kwargs, :args)
if haskey(kwargs, :context)
context = pop!(kwargs, :context)
else
context = cpu()
end
bind(self, context, args; kwargs...)
end

function simple_bind(self :: Symbol, ctx :: Context; grad_req :: GRAD_REQ=GRAD_WRITE, kwargs...)
arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...)
Expand Down

0 comments on commit 4ac5e7b

Please sign in to comment.