Skip to content

Commit

Permalink
Move spanning tree counter to updaters
Browse files Browse the repository at this point in the history
  • Loading branch information
pjrule committed Jul 6, 2021
1 parent 0a7c547 commit 6b88d36
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 24 deletions.
17 changes: 0 additions & 17 deletions gerrychain/partition/partition.py
Expand Up @@ -113,7 +113,6 @@ def __getitem__(self, key):

def __getattr__(self, key):
return self[key]

def keys(self):
return self.updaters.keys()

Expand Down Expand Up @@ -145,22 +144,6 @@ def plot(self, geometries=None, **kwargs):
)
return df.plot(column="assignment", **kwargs)

def get_num_spanning_trees(self, district):
'''
Given a district number, returns the number of spanning trees in the
subgraph of self corresponding to the district.
Uses Kirchoff's theorem to compute the number of spanning trees.
:param self: :class:`gerrychain.Partition`
:param district: A district in self
:return: The number of spanning trees in the subgraph of self
corresponding to district
'''
graph = self.subgraphs[district]
laplacian = networkx.laplacian_matrix(graph)
L = numpy.delete(numpy.delete(laplacian.todense(), 0, 0), 1, 1)
return math.exp(numpy.linalg.slogdet(L)[1])

@classmethod
def from_districtr_file(cls, graph, districtr_file, updaters=None):
"""Create a Partition from a districting plan created with `Districtr`_,
Expand Down
2 changes: 2 additions & 0 deletions gerrychain/updaters/__init__.py
Expand Up @@ -11,6 +11,7 @@
from .election import Election
from .flows import compute_edge_flows, flows_from_changes
from .tally import DataTally, Tally
from .spanning_trees import num_spanning_trees

__all__ = [
"flows_from_changes",
Expand All @@ -28,4 +29,5 @@
"CountySplit",
"compute_edge_flows",
"Election",
"num_spanning_trees"
]
29 changes: 29 additions & 0 deletions gerrychain/updaters/spanning_trees.py
@@ -0,0 +1,29 @@
"""Updaters that compute spanning tree statistics."""
import math
import numpy
import networkx


def _num_spanning_trees_in_district(partition, district):
"""Given a district ID, returns the number of spanning trees in the
subgraph of self corresponding to the district.
Uses Kirchoff's theorem to compute the number of spanning trees.
:param partition: :class:`gerrychain.Partition`
:param district: A district label (part) in the partition.
:return: The number of spanning trees in the subgraph of the
partition corresponding to district
"""
graph = partition.subgraphs[district]
laplacian = networkx.laplacian_matrix(graph)
L = numpy.delete(numpy.delete(laplacian.todense(), 0, 0), 1, 1)
return math.exp(numpy.linalg.slogdet(L)[1])


def num_spanning_trees(partition):
"""Returns the number of spanning trees in each part (district) of a partition."""
return {
part: _num_spanning_trees_in_district(partition, part)
for part in partition.parts
}
7 changes: 0 additions & 7 deletions tests/partition/test_partition.py
Expand Up @@ -42,13 +42,6 @@ def test_propose_random_flip_proposes_a_partition(example_partition):
assert isinstance(proposal, partition.__class__)


def test_get_num_spanning_trees(three_by_three_grid):
graph = three_by_three_grid
assignment = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1}
partition = Partition(graph, assignment, {"cut_edges": cut_edges})
assert 192 == round(partition.get_num_spanning_trees(1))


@pytest.fixture
def example_geographic_partition():
graph = networkx.complete_graph(3)
Expand Down
13 changes: 13 additions & 0 deletions tests/updaters/test_spanning_trees.py
@@ -0,0 +1,13 @@
from gerrychain import Partition
from gerrychain.updaters import num_spanning_trees


def test_get_num_spanning_trees(three_by_three_grid):
assignment = {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1}
partition = Partition(
three_by_three_grid,
assignment,
{"num_spanning_trees": num_spanning_trees}
)
assert 192 == round(partition["num_spanning_trees"][1])
assert [1] == list(partition["num_spanning_trees"].keys())

0 comments on commit 6b88d36

Please sign in to comment.