diff --git a/src/macros.jl b/src/macros.jl index 68e08aea3f6..814284ad8cf 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -83,6 +83,16 @@ Helper function for macros to construct container objects. Takes an `Expr` that """ _build_ref_sets(c) = _build_ref_sets(c, _get_name(c)) +function _expr_is_splat(ex::Expr) + if ex.head == :(...) + return true + elseif ex.head == :escape + return _expr_is_splat(ex.args[1]) + end + return false +end +_expr_is_splat(::Any) = false + """ JuMP._get_looped_code(varname, code, condition, idxvars, idxsets, sym, requestedcontainer::Symbol; lowertri=false) @@ -592,6 +602,9 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function) # Strategy: build up the code for add_constraint, and if needed # we will wrap in loops to assign to the ConstraintRefs refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_is_splat.(idxsets)) + _error("cannot use splatting operator `...`.") + end vectorized, parsecode, buildcall = parsefun(_error, x.args...) _add_kw_args(buildcall, kw_args) @@ -1008,7 +1021,7 @@ expr = @expression(m, [i=1:3], i*sum(x[j] for j=1:3)) ``` """ macro expression(args...) - + macro_error(str...) = _macro_error(:expression, args, str...) args, kw_args, requestedcontainer = _extract_kw_args(args) if length(args) == 3 m = esc(args[1]) @@ -1019,14 +1032,17 @@ macro expression(args...) c = gensym() x = args[2] else - error("@expression: needs at least two arguments.") + macro_error("needs at least two arguments.") end - length(kw_args) == 0 || error("@expression: unrecognized keyword argument") + length(kw_args) == 0 || macro_error("unrecognized keyword argument") anonvar = isexpr(c, :vect) || isexpr(c, :vcat) || length(args) == 2 variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_is_splat.(idxsets)) + macro_error("cannot use splatting operator `...`.") + end newaff, parsecode = _parse_expr_toplevel(x, :q) code = quote q = Val{false}() @@ -1393,10 +1409,12 @@ macro variable(args...) final_variable = variable else isa(var,Expr) || _error("Expected $var to be a variable name") - # We now build the code to generate the variables (and possibly the # SparseAxisArray to contain them) refcall, idxvars, idxsets, condition = _build_ref_sets(var, variable) + if any(_expr_is_splat.(idxsets)) + _error("cannot use splatting operator `...`.") + end clear_dependencies(i) = (Containers.is_dependent(idxvars,idxsets[i],i) ? () : idxsets[i]) # Code to be used to create each variable of the container. @@ -1510,6 +1528,9 @@ macro NLconstraint(m, x, extra...) # Strategy: build up the code for non-macro add_constraint, and if needed # we will wrap in loops to assign to the ConstraintRefs refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_is_splat.(idxsets)) + error("@NLconstraint: cannot use splatting operator `...`.") + end # Build the constraint if isexpr(x, :call) # one-sided constraint # Simple comparison - move everything to the LHS @@ -1606,6 +1627,9 @@ macro NLexpression(args...) variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_is_splat.(idxsets)) + error("@NLexpression: cannot use splatting operator `...`.") + end code = quote $(refcall) = NonlinearExpression($(esc(m)), $(_process_NL_expr(m, x))) end @@ -1663,7 +1687,6 @@ macro NLparameter(m, ex, extra...) end c = ex.args[2] x = ex.args[3] - anonvar = isexpr(c, :vect) || isexpr(c, :vcat) if anonvar error("In @NLparameter($m, $ex): Anonymous nonlinear parameter syntax is not currently supported") @@ -1672,6 +1695,9 @@ macro NLparameter(m, ex, extra...) variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_is_splat.(idxsets)) + error("@NLparameter: cannot use splatting operator `...`.") + end code = quote if !isa($(esc(x)), Number) error(string("in @NLparameter (", $(string(ex)), "): expected ", diff --git a/test/macros.jl b/test/macros.jl index 46a2a8db503..50847d2a8c6 100644 --- a/test/macros.jl +++ b/test/macros.jl @@ -481,6 +481,37 @@ end c = @NLconstraint(model, x == sum(1.0 for i in 1:0)) @test sprint(show, c) == "x - 0 = 0" || sprint(show, c) == "x - 0 == 0" end + + @testset "Splatting error" begin + model = Model() + A = [1 0; 0 1] + @variable(model, x) + + @test_macro_throws ErrorException( + "In `@variable(model, y[axes(A)...])`: cannot use splatting operator `...`." + ) @variable(model, y[axes(A)...]) + + f(a, b) = [a, b] + @variable(model, z[f((1, 2)...)]) + @test length(z) == 2 + + @test_macro_throws ErrorException( + "In `@constraint(model, [axes(A)...], x >= 1)`: cannot use splatting operator `...`." + ) @constraint(model, [axes(A)...], x >= 1) + + @test_macro_throws ErrorException( + "@NLconstraint: cannot use splatting operator `...`." + ) @NLconstraint(model, [axes(A)...], x >= 1) + + @test_macro_throws ErrorException( + "In `@expression(model, [axes(A)...], x)`: cannot use splatting operator `...`." + ) @expression(model, [axes(A)...], x) + + @test_macro_throws ErrorException( + "@NLexpression: cannot use splatting operator `...`." + ) @NLexpression(model, [axes(A)...], x) + end + end @testset "Macros for JuMPExtension.MyModel" begin