diff --git a/src/nlpmacros.jl b/src/nlpmacros.jl index ee547d5fd7c..56769efa7e4 100644 --- a/src/nlpmacros.jl +++ b/src/nlpmacros.jl @@ -16,9 +16,18 @@ function parseNLExpr(m, x, tapevar, parent, values) code = :(let; end) block = code.args[1] @assert isexpr(block, :block) - haskey(univariate_operator_to_id,x.args[1]) || error("Unrecognized function $(x.args[1]) used in nonlinear expression.") - operatorid = univariate_operator_to_id[x.args[1]] - push!(block.args, :(push!($tapevar, NodeData(CALLUNIVAR, $operatorid, $parent)))) + if haskey(univariate_operator_to_id,x.args[1]) + operatorid = univariate_operator_to_id[x.args[1]] + push!(block.args, :(push!($tapevar, NodeData(CALLUNIVAR, $operatorid, $parent)))) + else + opname = quot(x.args[1]) + errorstring = "Unrecognized function $opname used in nonlinear expression." + lookupcode = quote + haskey(univariate_operator_to_id,$opname) || error($errorstring) + operatorid = univariate_operator_to_id[$opname] + end + push!(block.args, :($lookupcode; push!($tapevar, NodeData(CALLUNIVAR, operatorid, $parent)))) + end parentvar = gensym() push!(block.args, :($parentvar = length($tapevar))) push!(block.args, parseNLExpr(m, x.args[2], tapevar, parentvar, values)) @@ -27,10 +36,19 @@ function parseNLExpr(m, x, tapevar, parent, values) code = :(let; end) block = code.args[1] @assert isexpr(block, :block) - haskey(operator_to_id,x.args[1]) || error("Unrecognized function $(x.args[1]) used in nonlinear expression.") - operatorid = operator_to_id[x.args[1]] + if haskey(operator_to_id,x.args[1]) # fast compile-time lookup + operatorid = operator_to_id[x.args[1]] + push!(block.args, :(push!($tapevar, NodeData(CALL, $operatorid, $parent)))) + else # could be user defined + opname = quot(x.args[1]) + errorstring = "Unrecognized function $opname used in nonlinear expression." + lookupcode = quote + haskey(operator_to_id,$opname) || error($errorstring) + operatorid = operator_to_id[$opname] + end + push!(block.args, :($lookupcode; push!($tapevar, NodeData(CALL, operatorid, $parent)))) + end parentvar = gensym() - push!(block.args, :(push!($tapevar, NodeData(CALL, $operatorid, $parent)))) push!(block.args, :($parentvar = length($tapevar))) for i in 1:length(x.args)-1 push!(block.args, parseNLExpr(m, x.args[i+1], tapevar, parentvar, values)) diff --git a/test/nonlinear.jl b/test/nonlinear.jl index 96264b583b0..01a7d8bca91 100644 --- a/test/nonlinear.jl +++ b/test/nonlinear.jl @@ -646,14 +646,16 @@ mysquare(x) = x^2 function myf(x,y) return (x-1)^2+(y-2)^2 end -registerNLFunction(:myf, 2, myf, autodiff=true) -registerNLFunction(:myf_2, 2, myf, (g,x,y) -> (g[1] = 2(x-1); g[2] = 2(y-2))) -registerNLFunction(:mysquare, 1, mysquare, autodiff=true) -registerNLFunction(:mysquare_2, 1, mysquare, x-> 2x, autodiff=true) -registerNLFunction(:mysquare_3, 1, mysquare, x-> 2x, x -> 2.0) + if length(convex_nlp_solvers) > 0 facts("[nonlinear] User-defined functions") do + registerNLFunction(:myf, 2, myf, autodiff=true) + registerNLFunction(:myf_2, 2, myf, (g,x,y) -> (g[1] = 2(x-1); g[2] = 2(y-2))) + registerNLFunction(:mysquare, 1, mysquare, autodiff=true) + registerNLFunction(:mysquare_2, 1, mysquare, x-> 2x, autodiff=true) + registerNLFunction(:mysquare_3, 1, mysquare, x-> 2x, x -> 2.0) + m = Model(solver=convex_nlp_solvers[1]) @defVar(m, x[1:2] >= 0.5)