diff --git a/src/macros.jl b/src/macros.jl index 68e08aea3f6..18ff69d76dd 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -83,6 +83,14 @@ Helper function for macros to construct container objects. Takes an `Expr` that """ _build_ref_sets(c) = _build_ref_sets(c, _get_name(c)) +function _expr_contains_splat(ex::Expr) + if ex.head == :(...) + return true + end + return any(_expr_contains_splat.(ex.args)) +end +_expr_contains_splat(::Any) = false + """ JuMP._get_looped_code(varname, code, condition, idxvars, idxsets, sym, requestedcontainer::Symbol; lowertri=false) @@ -593,6 +601,10 @@ function _constraint_macro(args, macro_name::Symbol, parsefun::Function) # we will wrap in loops to assign to the ConstraintRefs refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_contains_splat.(idxsets)) + _error("cannot use splatting operator `...`.") + end + vectorized, parsecode, buildcall = parsefun(_error, x.args...) _add_kw_args(buildcall, kw_args) if vectorized @@ -1008,7 +1020,7 @@ expr = @expression(m, [i=1:3], i*sum(x[j] for j=1:3)) ``` """ macro expression(args...) - + macro_error(str...) = _macro_error(:expression, args, str...) args, kw_args, requestedcontainer = _extract_kw_args(args) if length(args) == 3 m = esc(args[1]) @@ -1019,14 +1031,19 @@ macro expression(args...) c = gensym() x = args[2] else - error("@expression: needs at least two arguments.") + macro_error("needs at least two arguments.") end - length(kw_args) == 0 || error("@expression: unrecognized keyword argument") + length(kw_args) == 0 || macro_error("unrecognized keyword argument") anonvar = isexpr(c, :vect) || isexpr(c, :vcat) || length(args) == 2 variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + + if any(_expr_contains_splat.(idxsets)) + macro_error("cannot use splatting operator `...`.") + end + newaff, parsecode = _parse_expr_toplevel(x, :q) code = quote q = Val{false}() @@ -1397,6 +1414,10 @@ macro variable(args...) # We now build the code to generate the variables (and possibly the # SparseAxisArray to contain them) refcall, idxvars, idxsets, condition = _build_ref_sets(var, variable) + if any(_expr_contains_splat.(idxsets)) + _error("cannot use splatting operator `...`.") + end + clear_dependencies(i) = (Containers.is_dependent(idxvars,idxsets[i],i) ? () : idxsets[i]) # Code to be used to create each variable of the container. @@ -1510,6 +1531,10 @@ macro NLconstraint(m, x, extra...) # Strategy: build up the code for non-macro add_constraint, and if needed # we will wrap in loops to assign to the ConstraintRefs refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_contains_splat.(idxsets)) + error("@NLconstraint: cannot use splatting operator `...`.") + end + # Build the constraint if isexpr(x, :call) # one-sided constraint # Simple comparison - move everything to the LHS @@ -1606,6 +1631,10 @@ macro NLexpression(args...) variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_contains_splat.(idxsets)) + error("@NLexpression: cannot use splatting operator `...`.") + end + code = quote $(refcall) = NonlinearExpression($(esc(m)), $(_process_NL_expr(m, x))) end @@ -1672,6 +1701,10 @@ macro NLparameter(m, ex, extra...) variable = gensym() refcall, idxvars, idxsets, condition = _build_ref_sets(c, variable) + if any(_expr_contains_splat.(idxsets)) + error("@NLparameter: cannot use splatting operator `...`.") + end + code = quote if !isa($(esc(x)), Number) error(string("in @NLparameter (", $(string(ex)), "): expected ", diff --git a/test/macros.jl b/test/macros.jl index 46a2a8db503..01cbe56c554 100644 --- a/test/macros.jl +++ b/test/macros.jl @@ -481,6 +481,33 @@ end c = @NLconstraint(model, x == sum(1.0 for i in 1:0)) @test sprint(show, c) == "x - 0 = 0" || sprint(show, c) == "x - 0 == 0" end + + @testset "Splatting error" begin + model = Model() + A = [1 0; 0 1] + @variable(model, x) + + @test_macro_throws ErrorException( + "In `@variable(model, y[axes(A)...])`: cannot use splatting operator `...`." + ) @variable(model, y[axes(A)...]) + + @test_macro_throws ErrorException( + "In `@constraint(model, [i = [axes(A)...]], x >= i)`: cannot use splatting operator `...`." + ) @constraint(model, [i=[axes(A)...]], x >= i) + + @test_macro_throws ErrorException( + "@NLconstraint: cannot use splatting operator `...`." + ) @NLconstraint(model, [i=[axes(A)...]], x >= i) + + @test_macro_throws ErrorException( + "In `@expression(model, [i = [axes(A)...]], i * x)`: cannot use splatting operator `...`." + ) @expression(model, [i=[axes(A)...]], i * x) + + @test_macro_throws ErrorException( + "@NLexpression: cannot use splatting operator `...`." + ) @NLexpression(model, [i=[axes(A)...]], i * x) + end + end @testset "Macros for JuMPExtension.MyModel" begin