parametric NDArray (#331)
* ndarray: add outer constructor for AbstractArray

* ndarray: refine copy

* ndarray: refine copy!

* ndarray: refine convert

* ndarray: refine add_to!

* ndarray: refine sub_from!

* ndarray: refine mul_to!

* ndarray: refine div_from!

* ndarray: refine rdiv_from!

* ndarray: refine _wait_to_read/_wait_to_write

* ndarray: refine is_shared

* ndarray: refine save

* ndarray: refine dot

* ndarray: VecOfNDArray

* executor: refine backward

* ndarray: refine empty

* executor: refine bind
iblislin committed Dec 1, 2017
1 parent 2a5a284 commit cb042fd
Showing 10 changed files with 390 additions and 386 deletions.
31 changes: 17 additions & 14 deletions src/callback.jl
@@ -48,7 +48,7 @@ end
See also [`every_n_epoch`](@ref) and [`speedometer`](@ref).
"""
function every_n_batch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
function every_n_batch(callback::Function, n::Int; call_on_0::Bool = false)
BatchCallback(n, call_on_0, callback)
end
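
A minimal usage sketch of `every_n_batch` (illustrative, not part of the diff; the returned callback is typically handed to the trainer, e.g. via `mx.fit`'s `callbacks` keyword):

    using MXNet

    # Print the batch counter every 10 mini-batches; `state` is the
    # OptimizationState the training loop passes in.
    batch_cb = mx.every_n_batch(10) do state
      println("processed batch ", state.curr_batch)
    end
    # mx.fit(model, optimizer, train_data, callbacks = [batch_cb])
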
function (cb :: BatchCallback)(state :: OptimizationState)
@@ -62,7 +62,7 @@ function (cb :: BatchCallback)(state :: OptimizationState)
end

"""
speedometer(; frequency=50)
speedometer(;frequency=50)
Create an `AbstractBatchCallback` that measures the training speed
(number of samples processed per second) every k mini-batches.
@@ -71,9 +71,9 @@ Create an `AbstractBatchCallback` that measures the training speed
* `frequency::Int`: keyword argument, default 50. The frequency (number of
  mini-batches) to measure and report the speed.
"""
function speedometer(;frequency::Int=50)
function speedometer(;frequency::Int = 50)
cl_tic = 0
every_n_batch(frequency, call_on_0=true) do state :: OptimizationState
every_n_batch(frequency, call_on_0 = true) do state::OptimizationState
if state.curr_batch == 0
# reset timer
cl_tic = time()
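
For comparison, a typical `speedometer` call (a sketch, not part of the diff):

    # Report samples/second every 100 mini-batches instead of the default 50.
    speed_cb = mx.speedometer(frequency = 100)
    # mx.fit(model, optimizer, train_data, callbacks = [speed_cb])
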
@@ -104,10 +104,11 @@ A convenient function to construct a callback that runs every `n` full data-passes.
See also [`every_n_batch`](@ref).
"""
function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
every_n_epoch(callback::Function, n::Int; call_on_0::Bool = false) =
EpochCallback(n, call_on_0, callback)
end
function (cb :: EpochCallback)(model :: Any, state :: OptimizationState, metric :: Vector{Tuple{Base.Symbol, T}}) where T<:Real

function (cb::EpochCallback)(model::Any, state::OptimizationState,
metric::Vector{Tuple{Symbol, T}}) where T<:Real
if state.curr_epoch == 0
if cb.call_on_0
cb.callback(model, state, metric)
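
A usage sketch for the epoch-level variant (illustrative only):

    # Inspect metrics every 2 epochs; the callback receives the model, the
    # OptimizationState, and that epoch's Vector{Tuple{Symbol,<:Real}} metrics.
    eval_cb = mx.every_n_epoch(2) do model, state, metric
      println("epoch ", state.curr_epoch, ": ", metric)
    end
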
@@ -124,15 +125,17 @@ Create an `AbstractEpochCallback` that saves checkpoints of the model to disk.
The checkpoints can be loaded back later on.
# Arguments
* `prefix::AbstractString`: the prefix of the filenames to save the model. The model
architecture will be saved to prefix-symbol.json, while the weights will be saved
to prefix-0012.params, for example, for the 12-th epoch.
* `frequency::Int`: keyword argument, default 1. The frequency (measured in epochs) to
save checkpoints.
* `prefix::AbstractString`: the prefix of the filenames to save the model.
The model architecture will be saved to prefix-symbol.json,
while the weights will be saved to prefix-0012.params,
for example, for the 12-th epoch.
* `frequency::Int`: keyword argument, default is 1.
The frequency (measured in epochs) to save checkpoints.
* `save_epoch_0::Bool`: keyword argument, default false. Whether we should save a
  checkpoint for epoch 0 (model initialized but has not yet seen any data).
   checkpoint for epoch 0 (model initialized but has not yet seen any data).
"""
function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false)
function do_checkpoint(prefix::AbstractString;
frequency::Int = 1, save_epoch_0::Bool = false)
mkpath(dirname(prefix))
every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric
save_checkpoint(model, prefix, state)
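
For example (a sketch; the prefix is hypothetical):

    # Writes checkpoints/mnist-symbol.json once, then weights such as
    # checkpoints/mnist-0005.params after each epoch; the mkpath call above
    # creates the directory if needed.
    ckpt_cb = mx.do_checkpoint("checkpoints/mnist", frequency = 1)
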
113 changes: 52 additions & 61 deletions src/executor.jl
@@ -8,46 +8,49 @@ be carried out with an executor.
mutable struct Executor
handle :: MX_ExecutorHandle
symbol :: SymbolicNode
arg_arrays :: Vector{NDArray}
grad_arrays :: Vector{Union{Void,NDArray}}
aux_arrays :: Vector{NDArray}
outputs :: Vector{NDArray}
arg_dict :: Dict{Base.Symbol, NDArray}
aux_dict :: Dict{Base.Symbol, NDArray}
arg_arrays :: VecOfNDArray
grad_arrays :: Vector{Union{Void,<:NDArray}}
aux_arrays :: VecOfNDArray
outputs :: VecOfNDArray
arg_dict :: Dict{Symbol}
aux_dict :: Dict{Symbol}
end
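
The `VecOfNDArray` alias comes from the src/ndarray.jl part of this commit (not shown in this excerpt); an assumed definition consistent with its use here would be:

    # Assumed alias: any vector whose elements are (possibly parametric)
    # NDArrays, e.g. Vector{NDArray{Float32,2}}.
    const VecOfNDArray = AbstractVector{<:NDArray}
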
function Executor(hdr :: MX_ExecutorHandle, symbol :: SymbolicNode,
arg_arrays :: Vector{NDArray}, grad_arrays :: Vector{Union{Void,NDArray}},
aux_arrays :: Vector{NDArray})

function Executor(hdl::MX_ExecutorHandle, sym::SymbolicNode,
arg_arrays::VecOfNDArray, grad_arrays::AbstractVector,
aux_arrays::VecOfNDArray)
# get output arrays
ref_size = Ref{MX_uint}(0)
ref_hdrs = Ref{Ptr{MX_handle}}(0)
ref_hdls = Ref{Ptr{MX_handle}}(C_NULL)
@mxcall(:MXExecutorOutputs, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_handle}}),
hdr, ref_size, ref_hdrs)
out_hdrs = unsafe_wrap(Array, ref_hdrs[], ref_size[])
hdl, ref_size, ref_hdls)
out_hdrs = unsafe_wrap(Array, ref_hdls[], ref_size[])
out_arrays = [NDArray(MX_NDArrayHandle(x)) for x in out_hdrs]

arg_names = list_arguments(symbol)
arg_names = list_arguments(sym)
@assert(length(arg_names) == length(unique(arg_names)), "Duplicated names in arguments: $arg_names")
arg_dict = Dict{Base.Symbol,NDArray}(zip(arg_names, arg_arrays))
arg_dict = Dict(zip(arg_names, arg_arrays))

aux_names = list_auxiliary_states(symbol)
aux_names = list_auxiliary_states(sym)
@assert(length(aux_names) == length(unique(aux_names)), "Duplicated names in auxiliary states: $aux_names")
aux_dict = Dict{Base.Symbol,NDArray}(zip(aux_names, aux_arrays))
aux_dict = Dict(zip(aux_names, aux_arrays))

Executor(hdr, symbol, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict)
Executor(hdl, sym, arg_arrays, grad_arrays, aux_arrays, out_arrays, arg_dict, aux_dict)
end

function Base.unsafe_convert(::Type{MX_handle}, obj::Executor)
Base.unsafe_convert(::Type{MX_handle}, obj::Executor) =
Base.unsafe_convert(MX_handle, obj.handle)
end
Base.convert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)
Base.cconvert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)

function _get_ndarray_inputs(arg_key::AbstractString, args::Vector{NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)
function _get_ndarray_inputs(arg_key::AbstractString, args::VecOfNDArray,
arg_names::Vector{Symbol}, allow_missing::Bool)
@assert(length(args) == length(arg_names), "Length of $arg_key does not match number of arguments")
return (MX_handle[args...], args)
end
function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Base.Symbol,NDArray}, arg_names::Vector{Base.Symbol}, allow_missing::Bool)

function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol},
arg_names::Vector{Symbol}, allow_missing::Bool)
args_vec = map(arg_names) do name
arr = get(args, name, nothing)
if !allow_missing
@@ -75,16 +78,16 @@ Create an `Executor` by binding a `SymbolicNode` to concrete `NDArray`.
* `ctx::Context`: the context on which the computation should run.
* `args`: either a list of `NDArray` or a dictionary of name-array pairs. Concrete
arrays for all the inputs in the network architecture. The inputs typically include
network parameters (weights, bias, filters, etc.), data and labels. See [`list_arguments`](@ref)
and [`infer_shape`](@ref).
* `args_grad`:
* `aux_states`:
* `grad_req`:
network parameters (weights, bias, filters, etc.), data and labels.
See [`list_arguments`](@ref) and [`infer_shape`](@ref).
* `args_grad`: a `Vector` of `NDArray` or a `Dict` mapping names to `NDArray`
* `aux_states`: a `Vector` of `NDArray` or a `Dict` mapping names to `NDArray`
* `grad_req`: a single `GRAD_REQ` value, a `Vector` of `GRAD_REQ`, or a `Dict{Symbol,GRAD_REQ}`
"""
function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}};
args_grad :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
aux_states :: Union{Vector{NDArray},Dict{Base.Symbol,NDArray}} = Dict{Base.Symbol,NDArray}(),
grad_req :: Union{GRAD_REQ,Vector{GRAD_REQ},Dict{Base.Symbol,GRAD_REQ}} = GRAD_WRITE)
function bind(self::SymbolicNode, ctx::Context, args;
args_grad = Dict{Symbol,NDArray}(),
aux_states = Dict{Symbol,NDArray}(),
grad_req = GRAD_WRITE)

arg_names = list_arguments(self)

@@ -97,7 +100,7 @@ function bind(self :: SymbolicNode, ctx :: Context, args :: Union{Vector{NDArray
elseif isa(grad_req, Vector{GRAD_REQ})
@assert(length(grad_req) == length(args))
reqs = MX_uint[grad_req...]
elseif isa(grad_req, Dict{Base.Symbol, GRAD_REQ})
elseif isa(grad_req, Dict{Symbol, GRAD_REQ})
reqs = MX_uint[get(grad_req, name, GRAD_NOP) for name in arg_names]
end

@@ -111,20 +114,16 @@ function bind(self :: SymbolicNode; kwargs...)
executor = Executor(MX_ExecutorHandle(ref_hdr[]), self,
args, args_grad, aux_states)
end
function bind(self :: SymbolicNode; kwargs...)

function bind(x::SymbolicNode; context::Context = cpu(), kwargs...)
kwargs = Dict(kwargs)
@assert(haskey(kwargs, :args), "Must specify args")
args = pop!(kwargs, :args)
if haskey(kwargs, :context)
context = pop!(kwargs, :context)
else
context = cpu()
end
bind(self, context, args; kwargs...)
bind(x, context, args; kwargs...)
end
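
A usage sketch for the keyword form of `bind` (adapted from the package's documented style; the names are illustrative):

    A = mx.Variable(:A)
    B = mx.Variable(:B)
    C = A .* B

    a = mx.ones(3) * 4  # NDArray of 4s
    b = mx.ones(3) * 2  # NDArray of 2s

    exec = mx.bind(C, context = mx.cpu(), args = Dict(:A => a, :B => b))
    mx.forward(exec)
    copy(exec.outputs[1])  # 3-element Array{Float32,1}: 8.0, 8.0, 8.0
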

function simple_bind(self :: SymbolicNode, ctx :: Context;
grad_req :: Union{GRAD_REQ, Dict{Symbol, GRAD_REQ}}=GRAD_WRITE,
function simple_bind(self::SymbolicNode, ctx::Context;
grad_req::Union{GRAD_REQ,Dict{Symbol,GRAD_REQ}} = GRAD_WRITE,
kwargs...)
arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...)
@assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference")
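
`simple_bind` lets shape inference allocate the argument, gradient, and auxiliary arrays itself; a sketch (network and shapes are illustrative):

    net  = mx.Variable(:data)
    net  = mx.FullyConnected(net, name = :fc1, num_hidden = 10)
    # Input shapes are passed as keyword arguments and forwarded to
    # infer_shape; here: 784 features, batch size 64.
    exec = mx.simple_bind(net, mx.cpu(), data = (784, 64))
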
@@ -168,21 +167,15 @@ function forward(self::Executor; is_train::Bool = false, kwargs...)
self.outputs
end

function backward(self :: Executor)
backward(self, NDArray[])
end
function backward(self :: Executor, out_grad :: NDArray)
backward(self, [out_grad])
end
function backward(self :: Executor, out_grads :: Vector{NDArray})
out_grads = MX_handle[out_grads...]
@mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads)
end
backward(x::Executor) = backward(x, NDArray[])
backward(x::Executor, out_grad::NDArray) = backward(x, [out_grad])
backward(x::Executor, out_grads::VecOfNDArray) =
@mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}),
x, length(out_grads), MX_handle[out_grads...])
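
Typical use (a sketch; assumes `exec` was bound with `args_grad` arrays under the default `GRAD_WRITE` request):

    mx.forward(exec, is_train = true)  # backward requires a training-mode pass
    mx.backward(exec)                  # gradients land in exec.grad_arrays
    # With explicit head gradients:
    # mx.backward(exec, mx.ones(size(exec.outputs[1])))
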


function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray},
aux_params::Union{Void,Dict{Base.Symbol,NDArray}}=nothing;
allow_extra_params::Bool=false)
function copy_params_from(self::Executor, arg_params::Dict{Symbol},
aux_params::Dict{Symbol} = Dict{Symbol,Any}();
allow_extra_params::Bool = false)
for (name, array) in arg_params
if haskey(self.arg_dict, name)
copy!(self.arg_dict[name], array)
@@ -191,13 +184,11 @@ function copy_params_from(self::Executor, arg_params::Dict{Base.Symbol,NDArray},
end
end

if !isa(aux_params, Void)
for (name, array) in aux_params
if haskey(self.aux_dict, name)
copy!(self.aux_dict[name], array)
else
@assert(allow_extra_params, "Extra auxiliary state $name not recognized")
end
for (name, array) in aux_params
if haskey(self.aux_dict, name)
copy!(self.aux_dict[name], array)
else
@assert(allow_extra_params, "Extra auxiliary state $name not recognized")
end
end
end
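
Typical use (a sketch): seed a freshly bound executor with previously trained parameters, e.g. ones obtained from `mx.load_checkpoint`:

    # arg_params and aux_params are Dicts of Symbol => NDArray trained weights.
    mx.copy_params_from(exec, arg_params, aux_params)
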
