Skip to content

Commit

Permalink
Merge pull request #36 from mlubin/ml/userscope
Browse files Browse the repository at this point in the history
introduce a scope for user-defined functions
  • Loading branch information
mlubin committed Dec 24, 2016
2 parents 72b2a30 + 34ef558 commit c3e5493
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 45 deletions.
22 changes: 12 additions & 10 deletions src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@
# convert from Julia expression into NodeData form


function expr_to_nodedata(ex::Expr)
function expr_to_nodedata(ex::Expr,r::UserOperatorRegistry=UserOperatorRegistry())
nd = NodeData[]
values = Float64[]
expr_to_nodedata(ex,nd,values,-1)
expr_to_nodedata(ex,nd,values,-1,r)
return nd,values
end

function expr_to_nodedata(ex::Expr,nd::Vector{NodeData},values::Vector{Float64},parentid)
function expr_to_nodedata(ex::Expr,nd::Vector{NodeData},values::Vector{Float64},parentid,r::UserOperatorRegistry)

myid = length(nd) + 1
if isexpr(ex,:call)
op = ex.args[1]
if length(ex.args) == 2
push!(nd,NodeData(CALLUNIVAR, univariate_operator_to_id[op], parentid))
id = haskey(univariate_operator_to_id,op) ? univariate_operator_to_id[op] : r.univariate_operator_to_id[op] + USER_UNIVAR_OPERATOR_ID_START - 1
push!(nd,NodeData(CALLUNIVAR, id, parentid))
elseif op in comparison_operators
push!(nd,NodeData(COMPARISON, comparison_operator_to_id[op], parentid))
else
push!(nd,NodeData(CALL, operator_to_id[op], parentid))
id = haskey(operator_to_id,op) ? operator_to_id[op] : r.multivariate_operator_to_id[op] + USER_OPERATOR_ID_START - 1
push!(nd,NodeData(CALL, id, parentid))
end
for k in 2:length(ex.args)
expr_to_nodedata(ex.args[k],nd,values,myid)
expr_to_nodedata(ex.args[k],nd,values,myid,r)
end
elseif isexpr(ex, :ref)
@assert ex.args[1] == :x
Expand All @@ -35,22 +37,22 @@ function expr_to_nodedata(ex::Expr,nd::Vector{NodeData},values::Vector{Float64},
end
push!(nd, NodeData(COMPARISON, opid, parentid))
for k in 1:2:length(ex.args)
expr_to_nodedata(ex.args[k],nd,values,myid)
expr_to_nodedata(ex.args[k],nd,values,myid,r)
end
elseif isexpr(ex,:&&) || isexpr(ex,:||)
@assert length(ex.args) == 2
op = ex.head
opid = logic_operator_to_id[op]
push!(nd, NodeData(LOGIC, opid, parentid))
expr_to_nodedata(ex.args[1],nd,values,myid)
expr_to_nodedata(ex.args[2],nd,values,myid)
expr_to_nodedata(ex.args[1],nd,values,myid,r)
expr_to_nodedata(ex.args[2],nd,values,myid,r)
else
error("Unrecognized expression $ex: $(ex.head)")
end
nothing
end

function expr_to_nodedata(ex::Number,nd::Vector{NodeData},values::Vector{Float64},parentid)
function expr_to_nodedata(ex::Number,nd::Vector{NodeData},values::Vector{Float64},parentid,r::UserOperatorRegistry)
valueidx = length(values)+1
push!(values,ex)
push!(nd, NodeData(VALUE, valueidx, parentid))
Expand Down
14 changes: 8 additions & 6 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# a general DAG. If we have a DAG, then need to associate storage with each edge of the DAG.
# user_input_buffer and user_output_buffer are used as temporary storage
# when handling user-defined functions
function forward_eval{T}(storage::Vector{T},partials_storage::Vector{T},nd::Vector{NodeData},adj,const_values,parameter_values,x_values::Vector{T},subexpression_values,user_input_buffer=[],user_output_buffer=[])
function forward_eval{T}(storage::Vector{T},partials_storage::Vector{T},nd::Vector{NodeData},adj,const_values,parameter_values,x_values::Vector{T},subexpression_values,user_input_buffer=[],user_output_buffer=[];user_operators::UserOperatorRegistry=UserOperatorRegistry())

@assert length(storage) >= length(nd)
@assert length(partials_storage) >= length(nd)
Expand Down Expand Up @@ -115,7 +115,7 @@ function forward_eval{T}(storage::Vector{T},partials_storage::Vector{T},nd::Vect
@inbounds partials_storage[children_arr[idx1+2]] = !(condition == 1)
storage[k] = ifelse(condition == 1, lhs, rhs)
elseif op >= USER_OPERATOR_ID_START
evaluator = user_operator_map[op]
evaluator = user_operators.multivariate_operator_evaluator[op - USER_OPERATOR_ID_START+1]
f_input = view(user_input_buffer, 1:n_children)
grad_output = view(user_output_buffer, 1:n_children)
r = 1
Expand Down Expand Up @@ -143,8 +143,9 @@ function forward_eval{T}(storage::Vector{T},partials_storage::Vector{T},nd::Vect
#@assert child_idx == children_arr[first(nzrange(adj,k))]
child_val = storage[child_idx]
if op >= USER_UNIVAR_OPERATOR_ID_START
f = user_univariate_operator_f[op]
fprime = user_univariate_operator_fprime[op]
userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
f = user_operators.univariate_operator_f[userop]
fprime = user_operators.univariate_operator_fprime[userop]
fval = f(child_val)::T
fprimeval = fprime(child_val)::T
else
Expand Down Expand Up @@ -206,7 +207,7 @@ export forward_eval
# need to recompute the real components.
# Computes partials_storage_ϵ as well
# We assume that forward_eval has already been called.
function forward_eval_ϵ{N,T}(storage::Vector{T},storage_ϵ::DenseVector{ForwardDiff.Partials{N,T}},partials_storage::Vector{T},partials_storage_ϵ::DenseVector{ForwardDiff.Partials{N,T}},nd::Vector{NodeData},adj,x_values_ϵ,subexpression_values_ϵ)
function forward_eval_ϵ{N,T}(storage::Vector{T},storage_ϵ::DenseVector{ForwardDiff.Partials{N,T}},partials_storage::Vector{T},partials_storage_ϵ::DenseVector{ForwardDiff.Partials{N,T}},nd::Vector{NodeData},adj,x_values_ϵ,subexpression_values_ϵ;user_operators::UserOperatorRegistry=UserOperatorRegistry())

@assert length(storage_ϵ) >= length(nd)
@assert length(partials_storage_ϵ) >= length(nd)
Expand Down Expand Up @@ -317,7 +318,8 @@ function forward_eval_ϵ{N,T}(storage::Vector{T},storage_ϵ::DenseVector{Forward
@inbounds child_idx = children_arr[adj.colptr[k]]
child_val = storage[child_idx]
if op >= USER_UNIVAR_OPERATOR_ID_START
fprimeprime = user_univariate_operator_fprimeprime[op](child_val)::T
userop = op - USER_UNIVAR_OPERATOR_ID_START + 1
fprimeprime = user_operators.univariate_operator_fprimeprime[userop](child_val)::T
else
fprimeprime = eval_univariate_2nd_deriv(op, child_val,storage[k])
end
Expand Down
6 changes: 4 additions & 2 deletions src/linearity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,16 @@ function classify_linearity(nd::Vector{NodeData},adj,subexpression_linearity,fix
# if operator is nonlinear, then we're nonlinear
op = nod.index
if nod.nodetype == CALLUNIVAR
if univariate_operators[op] == :+ || univariate_operators[op] == :-
if op < USER_UNIVAR_OPERATOR_ID_START && (univariate_operators[op] == :+ || univariate_operators[op] == :-)
linearity[k] = LINEAR
else
linearity[k] = NONLINEAR
end
elseif nod.nodetype == CALL
# operator with more than 1 argument
if operators[op] == :+ || operators[op] == :-
if op >= USER_OPERATOR_ID_START
linearity[k] = NONLINEAR
elseif operators[op] == :+ || operators[op] == :-
linearity[k] = LINEAR
elseif operators[op] == :* && num_constant_children == length(children_idx) - 1
linearity[k] = LINEAR
Expand Down
6 changes: 4 additions & 2 deletions src/sparsity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ function compute_hessian_sparsity(nd::Vector{NodeData},adj,input_linearity::Vect
par = nd[nod.parent]
if par.nodetype == CALLUNIVAR
op = par.index
if univariate_operators[op] != :+ && univariate_operators[op] != :-
if op >= USER_UNIVAR_OPERATOR_ID_START || univariate_operators[op] != :+ && univariate_operators[op] != :-
nonlinear_wrt_output[k] = true
end
elseif par.nodetype == CALL
op = par.index
if operators[op] == :+ || operators[op] == :- || operators[op] == :ifelse
if op >= USER_OPERATOR_ID_START
nonlinear_wrt_output[k] = true
elseif operators[op] == :+ || operators[op] == :- || operators[op] == :ifelse
# pass
elseif operators[op] == :*
# check if all siblings are constant
Expand Down
46 changes: 25 additions & 21 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,40 +60,44 @@ export comparison_operator_to_id, comparison_operators


# user-provided operators
type UserOperatorRegistry
multivariate_operator_to_id::Dict{Symbol,Int}
multivariate_operator_evaluator::Vector{MathProgBase.AbstractNLPEvaluator}
univariate_operator_to_id::Dict{Symbol,Int}
univariate_operator_f::Vector{Any}
univariate_operator_fprime::Vector{Any}
univariate_operator_fprimeprime::Vector{Any}
end

const user_operator_map = Dict{Int,MathProgBase.AbstractNLPEvaluator}()
UserOperatorRegistry() = UserOperatorRegistry(Dict{Symbol,Int}(),Vector{MathProgBase.AbstractNLPEvaluator}(0),Dict{Symbol,Int}(),[],[],[])

# we use the MathProgBase NLPEvaluator interface, where the
# operator takes the place of the objective function.
# users should implement eval_f and eval_grad_f for now.
# we will eventually support hessians too
function register_multivariate_operator(s::Symbol,f::MathProgBase.AbstractNLPEvaluator)
!haskey(operator_to_id, s) || error("Operator $s has already been defined")
id = length(operators)+1
push!(operators,s)
operator_to_id[s] = id
user_operator_map[id] = f
function register_multivariate_operator!(r::UserOperatorRegistry,s::Symbol,f::MathProgBase.AbstractNLPEvaluator)
haskey(r.multivariate_operator_to_id, s) && error("Operator $s has already been defined")
id = length(r.multivariate_operator_evaluator)+1
r.multivariate_operator_to_id[s] = id
push!(r.multivariate_operator_evaluator,f)
return
end

export register_multivariate_operator

const user_univariate_operator_f = Dict{Int,Any}()
const user_univariate_operator_fprime = Dict{Int,Any}()
const user_univariate_operator_fprimeprime = Dict{Int,Any}()
export register_multivariate_operator!

# for univariate operators, just take in functions to evaluate
# zeroth, first, and second order derivatives
function register_univariate_operator(s::Symbol,f,fprime,fprimeprime)
!haskey(univariate_operator_to_id, s) || error("Operator $s has already been defined")
id = length(univariate_operators)+1
push!(univariate_operators,s)
univariate_operator_to_id[s] = id
user_univariate_operator_f[id] = f
user_univariate_operator_fprime[id] = fprime
user_univariate_operator_fprimeprime[id] = fprimeprime
function register_univariate_operator!(r::UserOperatorRegistry,s::Symbol,f,fprime,fprimeprime)
haskey(r.univariate_operator_to_id, s) && error("Operator $s has already been defined")
id = length(r.univariate_operator_f)+1
r.univariate_operator_to_id[s] = id
push!(r.univariate_operator_f,f)
push!(r.univariate_operator_fprime,fprime)
push!(r.univariate_operator_fprimeprime,fprimeprime)
return
end

export register_univariate_operator
export register_univariate_operator!


function has_user_multivariate_operators(nd::Vector{NodeData})
Expand Down
9 changes: 5 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,19 @@ function MathProgBase.eval_grad_f(::CDFEvaluator,grad,x)
grad[1] = -pdf(Normal(0,1),x[1])
grad[2] = pdf(Normal(0,1),x[2])
end
register_multivariate_operator(,CDFEvaluator())
register_univariate_operator(:c,cos,x->-sin(x),x->-cos(x))
r = ReverseDiffSparse.UserOperatorRegistry()
register_multivariate_operator!(r,,CDFEvaluator())
register_univariate_operator!(r,:c,cos,x->-sin(x),x->-cos(x))
Φ(x,y) = MathProgBase.eval_f(CDFEvaluator(),[x,y])
ex = :(Φ(x[2],x[1]-1)*c(x[3]))
nd,const_values = expr_to_nodedata(ex)
nd,const_values = expr_to_nodedata(ex,r)
@test ReverseDiffSparse.has_user_multivariate_operators(nd)
adj = adjmat(nd)
storage = zeros(length(nd))
partials_storage = zeros(length(nd))
reverse_storage = zeros(length(nd))
x = [2.0,3.0,4.0]
fval = forward_eval(storage,partials_storage,nd,adj,const_values,[],x,[],zeros(2),zeros(2))
fval = forward_eval(storage,partials_storage,nd,adj,const_values,[],x,[],zeros(2),zeros(2),user_operators=r)
true_val = Φ(x[2],x[1]-1)*cos(x[3])
@test isapprox(fval,true_val)
grad = zeros(3)
Expand Down

0 comments on commit c3e5493

Please sign in to comment.