Skip to content

Commit

Permalink
Add doc strings and type hints to tree.py. Also optimize prime factor fn
Browse files Browse the repository at this point in the history
  • Loading branch information
peterrrock2 committed Jan 9, 2024
1 parent d5fa077 commit aa02ee4
Showing 1 changed file with 195 additions and 59 deletions.
254 changes: 195 additions & 59 deletions gerrychain/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import signature
from .random import random
from collections import deque, namedtuple
from typing import Any, Callable, Dict, List, Optional, Set, Union, Sequence
from typing import Any, Callable, Dict, List, Optional, Set, Union, Hashable, Sequence, Tuple


def predecessors(h: nx.Graph, root: Any) -> Dict:
Expand Down Expand Up @@ -45,11 +45,17 @@ def random_spanning_tree(graph: nx.Graph, weight_dict: Dict) -> nx.Graph:
return spanning_tree


def uniform_spanning_tree(graph: nx.Graph, choice: Callable = random.choice) -> nx.Graph:
""" Builds a spanning tree chosen uniformly from the space of all
spanning trees of the graph.
:param graph: Networkx Graph
:param choice: :func:`random.choice`
def uniform_spanning_tree(
graph: nx.Graph,
choice: Callable = random.choice
) -> nx.Graph:
"""
Builds a spanning tree chosen uniformly from the space of all
spanning trees of the graph. Uses Wilson's algorithm.
:param graph: Networkx Graph
:type graph: nx.Graph
:param choice: The function used to make random choices. Defaults to :func:`random.choice`.
:type choice: Callable, optional

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Consistent use of Optional

"""
root = choice(list(graph.node_indices))
tree_nodes = set([root])
Expand All @@ -75,6 +81,19 @@ def uniform_spanning_tree(graph: nx.Graph, choice: Callable = random.choice) ->


class PopulatedGraph:
"""
A class representing a graph with population information.
:param graph: The underlying graph structure.
:type graph: nx.Graph
:param populations: A dictionary mapping nodes to their populations.
:type populations: Dict
:param ideal_pop: The ideal population for each district.
:type ideal_pop: float

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Should ideal_pop be a float and an int?

:param epsilon: The tolerance for population deviation from the ideal population within each

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Match the other descriptions of this, I think they mention "as a percentage of the ideal..."

district.
:type epsilon: float
"""
def __init__(
self,
graph: nx.Graph,
Expand Down Expand Up @@ -107,12 +126,27 @@ def has_ideal_population(self, node) -> bool:
)



# Tuple that is used in the find_balanced_edge_cuts function
# Comment added to make this easier to find
Cut = namedtuple("Cut", "edge subset")


def find_balanced_edge_cuts_contraction(
h: PopulatedGraph, choice: Callable = random.choice) -> List[Cut]:
# this used to be greater than 2 but failed on small grids:(
h: PopulatedGraph,
choice: Callable = random.choice
) -> List[Cut]:
"""
Find balanced edge cuts using contraction.
:param h: The populated graph.
:type h: PopulatedGraph
:param choice: The function used to make random choices.
:type choice: Callable, optional

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Optional consistency

:return: A list of balanced edge cuts.
:rtype: List[Cut]
"""

root = choice([x for x in h if h.degree(x) > 1])
# BFS predecessors for iteratively contracting leaves
pred = predecessors(h.graph, root)
Expand All @@ -135,6 +169,21 @@ def find_balanced_edge_cuts_memoization(
h: PopulatedGraph,
choice: Callable = random.choice
) -> List[Any]:

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Does function also return the list as Cut tuples?

"""
Find balanced edge cuts using memoization.
This function takes a PopulatedGraph object and a choice function as input and returns a list of balanced edge cuts.
A balanced edge cut is defined as a cut that divides the graph into two subsets, such that the population of each subset
is close to the ideal population defined by the PopulatedGraph object.
:param h: The PopulatedGraph object representing the graph.
:type h: PopulatedGraph
:param choice: The choice function used to select the root node.
:type choice: Callable, optional

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

Optional

:return: A list of balanced edge cuts.
:rtype: List[Any]

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

See list of Cut comment above

"""

root = choice([x for x in h if h.degree(x) > 1])
pred = predecessors(h.graph, root)
succ = successors(h.graph, root)
Expand Down Expand Up @@ -193,7 +242,6 @@ def bipartition_tree(
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
choice: Callable = random.choice,
max_attempts: Optional[int] = 10000
max_attempts: Optional[int] = None
) -> Set:
"""
This function finds a balanced 2 partition of a graph by drawing a
Expand Down Expand Up @@ -280,8 +328,40 @@ def _bipartition_tree_random_all(
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
choice: Callable = random.choice,
max_attempts: Optional[int] = None
):
"""Randomly bipartitions a graph and returns all cuts."""
) -> List[Tuple[Hashable, Hashable]]:
"""
Randomly bipartitions a tree into two subgraphs until a valid bipartition is found.
:param graph: The input graph.
:type graph: nx.Graph
:param pop_col: The name of the column in the graph nodes that contains the population data.
:type pop_col: str
:param pop_target: The target population for each subgraph.
:type pop_target: Union[int, float]
:param epsilon: The allowed deviation from the target population.

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

...as a percentage of pop_target

:type epsilon: float
:param node_repeats: The number of times to repeat the bipartitioning process. Defaults to 1.
:type node_repeats: int, optional
:param repeat_until_valid: Whether to repeat the bipartitioning process until a valid bipartition is found. Defaults to True.
:type repeat_until_valid: bool, optional
:param spanning_tree: The spanning tree to use for bipartitioning. If None, a random spanning tree will be generated. Defaults to None.
:type spanning_tree: Optional[nx.Graph], optional
:param spanning_tree_fn: The function to generate a spanning tree. Defaults to random_spanning_tree.
:type spanning_tree_fn: Callable, optional
:param balance_edge_fn: The function to find balanced edge cuts. Defaults to find_balanced_edge_cuts_memoization.
:type balance_edge_fn: Callable, optional
:param choice: The function to choose a random element from a list. Defaults to random.choice.
:type choice: Callable, optional
:param max_attempts: The maximum number of attempts to find a valid bipartition. If None, there is no limit. Defaults to None.
:type max_attempts: Optional[int], optional
:returns: A list of possible cuts that bipartition the tree into two subgraphs.
:rtype: List[Tuple[Hashable, Hashable]]
:raises RuntimeError: If a valid bipartition cannot be found after the specified number of attempts.
"""


populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}

possible_cuts = []
Expand Down Expand Up @@ -321,8 +401,9 @@ def bipartition_tree_random(
balance_edge_fn: Callable = find_balanced_edge_cuts_memoization,
choice: Callable = random.choice,
max_attempts: Optional[int] = None
):
"""This is like :func:`bipartition_tree` except it chooses a random balanced
) -> Union[Set[Any], None]:
"""
This is like :func:`bipartition_tree` except it chooses a random balanced
cut, rather than the first cut it finds.
This function finds a balanced 2 partition of a graph by drawing a
Expand All @@ -333,27 +414,38 @@ def bipartition_tree_random(
Builds up a connected subgraph with a connected complement whose population
is ``epsilon * pop_target`` away from ``pop_target``.
Returns a subset of nodes of ``graph`` (whose induced subgraph is connected).
The other part of the partition is the complement of this subset.
:param graph: The graph to partition
:param pop_col: The node attribute holding the population of each node
:param pop_target: The target population for the returned subset of nodes
:param graph: The graph to partition (must be an instance of nx.Graph)
:type graph: nx.Graph
:param pop_col: The node attribute holding the population of each node (must be a string)
:type pop_col: str
:param pop_target: The target population for the returned subset of nodes (must be an int or float)
:type pop_target: Union[int, float]
:param epsilon: The allowable deviation from ``pop_target`` (as a percentage of
``pop_target``) for the subgraph's population
``pop_target``) for the subgraph's population (must be a float)
:type epsilon: float
:param node_repeats: A parameter for the algorithm: how many different choices
of root to use before drawing a new spanning tree.
of root to use before drawing a new spanning tree (default is 1, must be an int)
:type node_repeats: int

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param repeat_until_valid: Determines whether to keep drawing spanning trees
until a tree with a balanced cut is found. If `True`, a set of nodes will
always be returned; if `False`, `None` will be returned if a valid spanning
tree is not found on the first try.
tree is not found on the first try (default is True, must be a bool)
:type repeat_until_valid: bool

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param spanning_tree: The spanning tree for the algorithm to use (used when the
algorithm chooses a new root and for testing)
algorithm chooses a new root and for testing) (must be an instance of nx.Graph or None)
:type spanning_tree: Optional[nx.Graph]
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
tree is not provided
:param balance_edge_fn: The algorithm used to find balanced cut edges
:param choice: :func:`random.choice`. Can be substituted for testing.
:param max_atempts: The max number of attempts that should be made to bipartition.
tree is not provided (must be a callable)
:type spanning_tree_fn: Callable

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param balance_edge_fn: The algorithm used to find balanced cut edges (must be a callable)
:type balance_edge_fn: Callable

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param choice: :func:`random.choice`. Can be substituted for testing. (must be a callable)
:type choice: Callable

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param max_attempts: The max number of attempts that should be made to bipartition. (must be an int or None)
:type max_attempts: Optional[int]
:return: A subset of nodes of ``graph`` (whose induced subgraph is connected) or None if a valid spanning tree is not found.
:rtype: Union[Set[Any], None]
"""
possible_cuts = _bipartition_tree_random_all(
graph=graph,
Expand Down Expand Up @@ -381,18 +473,28 @@ def recursive_tree_part(
node_repeats: int = 1,
method: Callable = partial(bipartition_tree, max_attempts=10000)
) -> Dict:
"""Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a tree into
"""
Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a tree into
``len(parts)`` parts of population ``pop_target`` (within ``epsilon``). Can be used to
generate initial seed plans or to implement ReCom-like "merge walk" proposals.
:param graph: The graph
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``
:type graph: nx.Graph
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``)
:type parts: Sequence
:param pop_target: Target population for each part of the partition
:type pop_target: Union[float, int]
:param pop_col: Node attribute key holding population data
:type pop_col: str
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
of the partition can be
:type epsilon: float
:param node_repeats: Parameter for :func:`~gerrychain.tree.bipartition_tree` to use.
:param method: The partition method to use.
Defaults to 1.
:type node_repeats: int, optional
:param method: The partition method to use. Defaults to
`partial(bipartition_tree, max_attempts=10000)`.
:type method: Callable, optional
:return: New assignments for the nodes of ``graph``.
:rtype: dict
"""
Expand Down Expand Up @@ -452,12 +554,24 @@ def get_seed_chunks(
balanced within new_epsilon <= ``epsilon`` of a balanced target population.
:param graph: The graph
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``
:param pop_target: target population of the districts (not of the chunks)
:type graph: nx.Graph
:param num_chunks: The number of chunks to partition the graph into
:type num_chunks: int
:param num_dists: The number of districts
:type num_dists: int
:param pop_target: The target population of the districts (not of the chunks)
:type pop_target: Union[int, float]
:param pop_col: Node attribute key holding population data
:type pop_col: str
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
of the partition can be
:param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use.
:type epsilon: float
:param node_repeats: Parameter for :func:`~gerrychain.tree.bipartition_tree_random`
to use.

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

What is the default?

:type node_repeats: int, optional
:param method: The method to use for bipartitioning the graph.
Defaults to :func:`~gerrychain.tree.bipartition_tree_random`
:type method: Callable, optional
:return: New assignments for the nodes of ``graph``.
:rtype: dict
"""
Expand Down Expand Up @@ -534,29 +648,40 @@ def get_seed_chunks(


def get_max_prime_factor_less_than(
n, ceil
):
n: int, ceil: int
) -> Optional[int]:
"""
Helper function for recursive_seed_part. Returns the largest prime factor of ``n`` less than
Helper function for recursive_seed_part_inner. Returns the largest prime factor of ``n`` less than or equal to
``ceil``, or None if all are greater than ceil.
:param n: The number to find the largest prime factor for.
:type n: int
:param ceil: The upper limit for the largest prime factor.
:type ceil: int
:return: The largest prime factor of ``n`` less than or equal to ``ceil``, or None if all prime factors are greater than ``ceil``.
:rtype: int or None
"""
factors = []
i = 2
if n <= 1 or ceil <= 1:
return None

largest_factor = None
while n % 2 == 0:
largest_factor = 2
n //= 2

i = 3
while i * i <= n:
if n % i:
i += 1
else:
while n % i == 0:
if i <= ceil:
largest_factor = i
n //= i
factors.append(i)
if n > 1:
factors.append(n)

if len(factors) == 0:
return 1
m = [i for i in factors if i <= ceil]
if m == []:
return None
return int(max(m))
i += 2

if n > 1 and n <= ceil:
largest_factor = n

return largest_factor



def recursive_seed_part_inner(
Expand Down Expand Up @@ -681,29 +806,40 @@ def recursive_seed_part(
method: Callable = partial(bipartition_tree, max_attempts=10000),
node_repeats: int = 1,
n: Optional[int] = None,
ceil: None = None
ceil: Optional[int] = None
) -> Dict:
"""
Returns a partition with ``num_dists`` districts balanced within ``epsilon`` of
``pop_target`` by recursively splitting graph using recursive_seed_part_inner.
:param graph: The graph
:type graph: nx.Graph
:param parts: Iterable of part labels (like ``[0,1,2]`` or ``range(4)``)
:type parts: Sequence
:param pop_target: Target population for each part of the partition
:type pop_target: Union[float, int]
:param pop_col: Node attribute key holding population data
:type pop_col: str
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
of the partition can be
:type epsilon: float
:param method: Function used to find balanced partitions at the 2-district level
Defaults to :func:`~gerrychain.tree.bipartition_tree`
:type method: Callable

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

optional

:param node_repeats: Parameter for :func:`~gerrychain.tree.bipartition_tree` to use.
Defaults to 1.
:type node_repeats: int, optional
:param n: Either a positive integer (greater than 1) or None. If n is a positive integer,
this function will recursively create a seed plan by either biting off districts from graph
or dividing graph into n chunks and recursing into each of these. If n is None, this
function prime factors ``num_dists``=n_1*n_2*...*n_k (n_1 > n_2 > ... n_k) and recursively
partitions graph into n_1 chunks.
this function will recursively create a seed plan by either biting off districts from graph
or dividing graph into n chunks and recursing into each of these. If n is None, this
function prime factors ``num_dists``=n_1*n_2*...*n_k (n_1 > n_2 > ... n_k) and recursively
partitions graph into n_1 chunks.
:type n: Optional[int]

This comment has been minimized.

Copy link
@cdonnay

cdonnay Jan 12, 2024

Contributor

default value?

:param ceil: Either a positive integer (at least 2) or None. Relevant only if n is None. If
``ceil`` is a positive integer then finds the largest factor of ``num_dists`` less than or
equal to ``ceil``, and recursively splits graph into that number of chunks, or bites off a
district if that number is 1.
``ceil`` is a positive integer then finds the largest factor of ``num_dists`` less than or
equal to ``ceil``, and recursively splits graph into that number of chunks, or bites off a
district if that number is 1. Defaults to None.
:type ceil: Optional[int]
:return: New assignments for the nodes of ``graph``.
:rtype: dict
"""
Expand Down

1 comment on commit aa02ee4

@cdonnay
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Read through!

Please sign in to comment.