Skip to content

Commit

Permalink
Fixing KeyIterator used for JuMPArray to work also when index sets ar…
Browse files Browse the repository at this point in the history
…e non indexable. (#836)

* keys() work as expected now for JuMPArray when index set is not indexable

* minor comment modification

* fixed indentation to be 4 spaces

* moved the initialization of KeyIterator state from constructor to Base.start

* improved KeyIterator of JuMPArrays in two ways:
- now the functions start next done do not modify arguments
- less allocation thanks to caching and @generated functions

* added JuMPKey, a tuple wrapper, to be used by keys(::JuMPArray)

* added specialization of the iterator to keep good performance
when having indexable sets

* - removed JuMPKey (to be added in separate PR)
- removed traits for indexability of indexsets
- added a test with sets indexsets of JuMPArray
  • Loading branch information
IssamT authored and joehuchette committed Sep 27, 2016
1 parent 9865cd3 commit b5225ba
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/JuMPArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
end

Base.getindex(d::JuMPArray, ::Colon) = d.innerArray[:]

@generated function Base.getindex{T,N,NT<:NTuple}(d::JuMPArray{T,N,NT}, idx...)
if N != length(idx)
error("Indexed into a JuMPArray with $(length(idx)) indices (expected $N indices)")
Expand Down
81 changes: 75 additions & 6 deletions src/JuMPContainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,10 @@ Base.ndims{T,N}(x::JuMPDict{T,N}) = N
Base.abs(x::JuMPDict) = map(abs, x)
# avoid dangerous behavior with "end" (#730)
Base.endof(x::JuMPArray) = error("endof() (and \"end\" syntax) not implemented for JuMPArray objects.")
Base.size(x::JuMPArray) = error("size (and \"end\" syntax) not implemented for JuMPArray objects. Use JuMP.size if you want to access the dimensions.")
Base.size(x::JuMPArray,k) = error("size (and \"end\" syntax) not implemented for JuMPArray objects. Use JuMP.size if you want to access the dimensions.")
Base.size(x::JuMPArray) = error(string("size (and \"end\" syntax) not implemented for JuMPArray objects.",
"Use JuMP.size if you want to access the dimensions."))
Base.size(x::JuMPArray,k) = error(string("size (and \"end\" syntax) not implemented for JuMPArray objects.",
" Use JuMP.size if you want to access the dimensions."))
size(x::JuMPArray) = size(x.innerArray)
size(x::JuMPArray,k) = size(x.innerArray,k)
# for uses of size() within JuMP
Expand Down Expand Up @@ -183,13 +185,80 @@ Base.length(it::ValueIterator) = length(it.x)

type KeyIterator{JA<:JuMPArray}
x::JA
dim::Int
next_k_cache::Array{Any,1}
function KeyIterator(d)
n = ndims(d.innerArray)
new(d, n, Array(Any, n+1))
end
end

KeyIterator{JA}(d::JA) = KeyIterator{JA}(d)

function indexability(x::JuMPArray)
for i in 1:length(x.indexsets)
if !method_exists(getindex, (typeof(x.indexsets[i]),))
return false
end
end

return true
end

function Base.start(it::KeyIterator)
if indexability(it.x)
return start(it.x.innerArray)
else
return notindexable_start(it.x)
end
end

@generated function notindexable_start{T,N,NT}(x::JuMPArray{T,N,NT})
quote
$(Expr(:tuple, 0, [:(start(x.indexsets[$i])) for i in 1:N]...))
end
end
Base.start(it::KeyIterator) = start(it.x.innerArray)
@generated __next{T,N,NT}(x::JuMPArray{T,N,NT}, k) =

@generated function _next{T,N,NT}(x::JuMPArray{T,N,NT}, k::Tuple)
quote
$(Expr(:tuple, [:(next(x.indexsets[$i], k[$i+1])[1]) for i in 1:N]...))
end
end

function Base.next(it::KeyIterator, k::Tuple)
cartesian_key = _next(it.x, k)
pos = -1
for i in 1:it.dim
if !done(it.x.indexsets[i], next(it.x.indexsets[i], k[i+1])[2] )
pos = i
break
end
end
if pos == - 1
it.next_k_cache[1] = 1
return cartesian_key, tuple(it.next_k_cache...)
end
it.next_k_cache[1] = 0
for i in 1:it.dim
if i < pos
it.next_k_cache[i+1] = start(it.x.indexsets[i])
elseif i == pos
it.next_k_cache[i+1] = next(it.x.indexsets[i], k[i+1])[2]
else
it.next_k_cache[i+1] = k[i+1]
end
end
cartesian_key, tuple(it.next_k_cache...)
end

Base.done(it::KeyIterator, k::Tuple) = (k[1] == 1)

@generated __next{T,N,NT}(x::JuMPArray{T,N,NT}, k::Integer) =
quote
subidx = ind2sub(size(x),k)
$(Expr(:tuple, [:(x.indexsets[$i][subidx[$i]]) for i in 1:N]...)), next(x.innerArray,k)[2]
end
Base.next(it::KeyIterator, k) = __next(it.x,k)
Base.done(it::KeyIterator, k) = done(it.x.innerArray, k)
Base.next(it::KeyIterator, k) = __next(it.x,k::Integer)
Base.done(it::KeyIterator, k) = done(it.x.innerArray, k::Integer)

Base.length(it::KeyIterator) = length(it.x.innerArray)
21 changes: 21 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,24 @@ facts("[model] Nonliteral exponents in @constraint") do
@fact m.quadconstr[3].terms --> x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 - 1
@fact m.quadconstr[4].terms --> QuadExpr(x + x + x - 1)
end

facts("[model] sets used as indexsets in JuMPArray") do
set = IntSet()
for i in 4:5
push!(set, i)
end
set2 = IntSet()
for i in 21:23
push!(set2, i)
end
m = Model()
@variable(m, x[set, set2], Bin)
@objective(m , Max, sum{sum{x[e,p], e in set}, p in set2})
solve(m)
sol = getvalue(x)
checked_objval = 0
for i in keys(sol)
checked_objval += sol[i...]
end
@fact checked_objval --> 6
end

0 comments on commit b5225ba

Please sign in to comment.