/
recom.jl
304 lines (264 loc) · 11.1 KB
/
recom.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""
sample_subgraph(graph::BaseGraph,
partition::Partition,
rng::AbstractRNG)
Randomly sample two adjacent districts D₁ and D₂ and return a tuple
(D₁, D₂, edges, nodes) where D₁ and D₂ are Ints, `edges` holds the Int edges
of the induced subgraph (as returned by `induced_subgraph_edges`) and `nodes`
is a `BitSet` containing the Int nodes of the induced subgraph.
"""
function sample_subgraph(graph::BaseGraph,
                         partition::Partition,
                         rng::AbstractRNG)
    # Draw the pair of districts to merge (delegated to
    # `sample_adjacent_districts_randomly`, which consumes `rng`).
    first_district, second_district = sample_adjacent_districts_randomly(partition, rng)

    # Pool the nodes currently assigned to either district.
    merged_nodes = union(partition.dist_nodes[first_district],
                         partition.dist_nodes[second_district])

    # Edges of the subgraph induced by the pooled nodes; the helper is handed
    # the nodes as a materialised collection.
    node_list = collect(merged_nodes)
    subgraph_edges = induced_subgraph_edges(graph, node_list)

    # The node set is returned as a fresh BitSet.
    return first_district, second_district, subgraph_edges, BitSet(merged_nodes)
end
"""
build_mst(graph::BaseGraph,
nodes::BitSet,
edges::BitSet)::Dict{Int, Array{Int, 1}}
Builds a graph as an adjacency list from the given `nodes` and `edges`.
"""
function build_mst(graph::BaseGraph,
                   nodes::BitSet,
                   edges::BitSet)::Dict{Int, Array{Int, 1}}
    # Every node starts with its own, initially empty, neighbor list.
    adjacency = Dict{Int, Array{Int, 1}}(node => Array{Int, 1}() for node in nodes)

    # Register each edge in the neighbor lists of both of its endpoints.
    foreach(edge -> add_edge_to_mst!(graph, adjacency, edge), edges)

    return adjacency
end
"""
remove_edge_from_mst!(graph::BaseGraph,
mst::Dict{Int, Array{Int,1}},
edge::Int)
Removes an edge from the graph built by `build_mst()`.
"""
function remove_edge_from_mst!(graph::BaseGraph,
                               mst::Dict{Int, Array{Int,1}},
                               edge::Int)
    # Look up the two endpoints of `edge` once.
    src = graph.edge_src[edge]
    dst = graph.edge_dst[edge]

    # Drop each endpoint from the other's neighbor list (every occurrence).
    filter!(!=(dst), mst[src])
    # As before, the value of the call is the filtered neighbor list of `dst`.
    return filter!(!=(src), mst[dst])
end
"""
add_edge_to_mst!(graph::BaseGraph,
mst::Dict{Int, Array{Int,1}},
edge::Int)
Adds an edge to the graph built by `build_mst()`.
"""
function add_edge_to_mst!(graph::BaseGraph,
                          mst::Dict{Int, Array{Int,1}},
                          edge::Int)
    # Look up the two endpoints of `edge` once.
    src, dst = graph.edge_src[edge], graph.edge_dst[edge]

    # Record the edge in both directions so the adjacency list stays symmetric.
    push!(mst[src], dst)
    # As before, the value of the call is the updated neighbor list of `dst`.
    return push!(mst[dst], src)
end
"""
traverse_mst(mst::Dict{Int, Array{Int, 1}},
start_node::Int,
avoid_node::Int,
stack::Stack{Int},
traversed_nodes::BitSet)::BitSet
Returns the component of the MST `mst` that contains the vertex
`start_node`.
*Arguments:*
- mst: mst to traverse
- start_node: the node to start traversing from
- avoid_node: the node to avoid and which separates the mst into
two components
- stack: an empty Stack
- traversed_nodes: an empty BitSet that is to be populated.
`stack` and `traversed_nodes` are pre-allocated and passed in to
reduce the number of memory allocations and, consequently, the time taken.
In the course of calling this function multiple times, it is intended that
we pass in the same (empty) objects repeatedly.
"""
function traverse_mst(mst::Dict{Int, Array{Int, 1}},
                      start_node::Int,
                      avoid_node::Int,
                      stack::Stack{Int},
                      traversed_nodes::BitSet)::BitSet
    # The caller must hand in a drained stack; the result set is reset here.
    @assert isempty(stack)
    empty!(traversed_nodes)

    # Stack-driven traversal starting at `start_node`.
    push!(stack, start_node)
    while !isempty(stack)
        current = pop!(stack)
        push!(traversed_nodes, current)
        for adjacent in mst[current]
            # Never step onto an already-visited node or across `avoid_node`.
            (adjacent in traversed_nodes || adjacent == avoid_node) && continue
            push!(stack, adjacent)
        end
    end

    # The same (mutated) BitSet that was passed in is returned.
    return traversed_nodes
end
"""
get_balanced_proposal(graph::BaseGraph,
mst_edges::BitSet,
mst_nodes::BitSet,
partition::Partition,
pop_constraint::PopulationConstraint,
D₁::Int,
D₂::Int)
Tries to find a balanced cut on the subgraph induced by `mst_edges` and
`mst_nodes` such that the population is balanced according to
`pop_constraint`.
This subgraph was formed by the combination of districts `D₁` and `D₂`.
"""
function get_balanced_proposal(graph::BaseGraph,
                               mst_edges::BitSet,
                               mst_nodes::BitSet,
                               partition::Partition,
                               pop_constraint::PopulationConstraint,
                               D₁::Int,
                               D₂::Int)
    tree = build_mst(graph, mst_nodes, mst_edges)
    total_pop = partition.dist_populations[D₁] + partition.dist_populations[D₂]

    # Scratch structures shared by every traversal below to limit allocations.
    dfs_stack = Stack{Int}()
    visited = BitSet()

    # Try each tree edge as the cut: collect the side containing the edge's
    # source endpoint while refusing to cross over to its destination endpoint.
    for cut_edge in mst_edges
        src_side = traverse_mst(tree,
                                graph.edge_src[cut_edge],
                                graph.edge_dst[cut_edge],
                                dfs_stack,
                                visited)
        src_pop = get_subgraph_population(graph, src_side)
        dst_pop = total_pop - src_pop

        # Unbalanced cut: move on to the next candidate edge.
        satisfy_constraint(pop_constraint, src_pop, dst_pop) || continue

        # Everything not on the source side forms the other district.
        dst_side = setdiff(mst_nodes, src_side)
        return RecomProposal(D₁, D₂, src_pop, dst_pop, src_side, dst_side)
    end

    # No edge of this tree yields a cut accepted by `pop_constraint`.
    return DummyProposal("Could not find balanced cut.")
end
"""
get_valid_proposal(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
rng::AbstractRNG,
num_tries::Int=3)
*Returns* a population balanced proposal.
*Arguments:*
- graph: BaseGraph
- partition: Partition
- pop_constraint: PopulationConstraint to adhere to
- rng: random number generator used to sample the districts and
the random edge weights
- num_tries: num times to try getting a balanced cut from a subgraph
before giving up
"""
function get_valid_proposal(graph::BaseGraph,
                            partition::Partition,
                            pop_constraint::PopulationConstraint,
                            rng::AbstractRNG,
                            num_tries::Int=3)
    # With fewer than one try the inner loop below would never run and the
    # outer `while true` would spin forever without attempting a single cut,
    # so reject such values up front.
    # Throws: ArgumentError if `num_tries < 1`.
    num_tries >= 1 ||
        throw(ArgumentError("num_tries must be at least 1, got $num_tries"))

    # Keep sampling district pairs until one of them yields a balanced cut.
    # NOTE: this does not terminate if no sampled pair ever admits a cut that
    # satisfies `pop_constraint`.
    while true
        D₁, D₂, sg_edges, sg_nodes = sample_subgraph(graph, partition, rng)

        for _ in 1:num_tries
            # Fresh random weights on every try give a different spanning tree.
            weights = rand(rng, length(sg_edges))
            mst_edges = weighted_kruskal_mst(graph, sg_edges, collect(sg_nodes), weights)

            # See if we can get a population-balanced cut in this mst.
            proposal = get_balanced_proposal(graph, mst_edges, sg_nodes,
                                             partition, pop_constraint,
                                             D₁, D₂)
            if proposal isa RecomProposal
                return proposal
            end
        end
        # All tries on this pair failed: give up on it and sample a new pair.
    end
end
"""
update_partition!(partition::Partition,
graph::BaseGraph,
proposal::RecomProposal,
copy_parent::Bool=false)
Updates the `Partition` with the `RecomProposal`.
"""
function update_partition!(partition::Partition,
                           graph::BaseGraph,
                           proposal::RecomProposal,
                           copy_parent::Bool=false)
    if copy_parent
        # Clear the parent link before snapshotting, so the deep copy stored
        # as the new parent itself has `parent === nothing`.
        partition.parent = nothing
        partition.parent = deepcopy(partition)
    end

    # Apply the proposal to each of the two recombined districts in turn.
    changes = ((proposal.D₁, proposal.D₁_pop, proposal.D₁_nodes),
               (proposal.D₂, proposal.D₂_pop, proposal.D₂_nodes))
    for (district, population, members) in changes
        partition.dist_populations[district] = population
        for node in members
            partition.assignments[node] = district
        end
        # The proposal's node set is stored as-is (not copied).
        partition.dist_nodes[district] = members
    end

    # Refresh the district adjacency information for the new assignment.
    return update_partition_adjacency(partition, graph)
end
"""
recom_chain(graph::BaseGraph,
partition::Partition,
pop_constraint::PopulationConstraint,
num_steps::Int,
scores::Array{S, 1};
num_tries::Int=3,
acceptance_fn::F=always_accept,
rng::AbstractRNG=Random.default_rng(),
no_self_loops::Bool=false)::ChainScoreData where {F<:Function, S<:AbstractScore}
Runs a Markov Chain for `num_steps` steps using ReCom. Returns a `ChainScoreData`
object which can be queried to retrieve the values of every score at each
step of the chain.
*Arguments:*
- graph: `BaseGraph`
- partition: `Partition` with the plan information
- pop_constraint: `PopulationConstraint`
- num_steps: Number of steps to run the chain for
- scores: Array of `AbstractScore`s to capture at each step
- num_tries: num times to try getting a balanced cut from a subgraph
before giving up
- acceptance_fn: A function generating a probability in [0, 1]
representing the likelihood of accepting the
proposal. Should accept a `Partition` as input.
- rng: Random number generator. The user can pass in their
own; otherwise, we use the default RNG from Random.
- no\\_self\\_loops: If this is true, then a failure to accept a new state
is not considered a self-loop; rather, the chain
simply generates new proposals until the acceptance
function is satisfied. BEWARE - this can create
infinite loops if the acceptance function is never
satisfied!
"""
function recom_chain(graph::BaseGraph,
                     partition::Partition,
                     pop_constraint::PopulationConstraint,
                     num_steps::Int,
                     scores::Array{S, 1};
                     num_tries::Int=3,
                     acceptance_fn::F=always_accept,
                     rng::AbstractRNG=Random.default_rng(),
                     no_self_loops::Bool=false)::ChainScoreData where
                     {F<:Function, S<:AbstractScore}
    # The score record starts with the scores of the initial partition.
    initial_scores = score_initial_partition(graph, partition, scores)
    chain_scores = ChainScoreData(deepcopy(scores), [initial_scores])

    # A parent snapshot is only kept (and the acceptance check only run) when
    # the caller supplied something other than the default `always_accept`.
    uses_acceptance_fn = acceptance_fn !== always_accept

    completed_steps = 0
    while completed_steps < num_steps
        proposal = get_valid_proposal(graph, partition, pop_constraint, rng, num_tries)
        # Mutates `partition` in place to the proposed state.
        update_partition!(partition, graph, proposal, uses_acceptance_fn)

        rejected = uses_acceptance_fn &&
                   !satisfies_acceptance_fn(partition, acceptance_fn)
        if rejected
            # Go back to the snapshot taken before the proposal was applied.
            partition = partition.parent
            # When self-loops are disabled a rejection is neither scored nor
            # counted as a step; simply try again.
            no_self_loops && continue
        end

        step_scores = score_partition_from_proposal(graph, partition, proposal, scores)
        push!(chain_scores.step_values, step_scores)
        completed_steps += 1
    end

    return chain_scores
end