In [7]:
from types import SimpleNamespace

from sympy import *
import networkx as nx
import numpy as np
from scipy import sparse
from scipy.sparse import linalg

In [70]:
def get_num_spanning_trees(p, district):
    '''
    Given a partition p and a district number, this function returns the
    number of spanning trees in the subgraph of p corresponding to the
    district. Uses Kirchhoff's matrix-tree theorem: the count equals the
    determinant of the reduced Laplacian, i.e. the graph Laplacian with one
    row and the *same* column removed.

    The determinant is computed in log space with ``np.linalg.slogdet`` and
    exponentiated with sympy's ``exp`` (in scope via ``from sympy import *``),
    so the result is a floating-point approximation of the integer count.

    :param p: :class:`gerrychain.Partition`
    :param district: A district in p
    :return: The number of spanning trees in the subgraph of p corresponding
        to district, or 0 if that subgraph is disconnected
    :raises ValueError: if the district's subgraph has no nodes
    '''
    graph = p.subgraphs[district]

    # Manual check (see the grid-graph section of this notebook): uncommenting
    # the two lines below replaces the district with an n x n grid graph.
    # Testing using nxn grid graphs:
    #n = 4
    #graph = nx.grid_graph(dim=[n,n])

    if graph.number_of_nodes() == 0:
        raise ValueError("district {!r} has an empty subgraph".format(district))
    # A disconnected graph has no spanning tree. Checking connectivity exactly
    # avoids returning round-off noise from a numerically singular determinant.
    if not nx.is_connected(graph):
        return 0

    laplacian = nx.laplacian_matrix(graph).toarray()
    # Drop row 0 AND column 0 (a principal minor). For a connected graph the
    # reduced Laplacian is positive definite, so its determinant is positive
    # and equals the spanning-tree count. Removing a row and a *different*
    # column instead yields a cofactor whose determinant is negated.
    _, logdet = np.linalg.slogdet(laplacian[1:, 1:])
    return exp(logdet)

## Testing using partitions

In [62]:
# Generating an initial partition from Colorado shapefiles from here:
# https://github.com/mggg-states/CO-shapefiles

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt
from gerrychain import (GeographicPartition, Partition, Graph, MarkovChain,
                        proposals, updaters, constraints, accept, Election)
from gerrychain.proposals import recom
from functools import partial
import pandas

graph = Graph.from_file("CO-shapefiles/co_precincts.shp")

elections = [
    Election("GOV18", {"Democratic": "GOV18D", "Republican": "GOV18R"}),
]

# One updater per election, keyed by the election's name, so the partition can
# report results computed from the vote-total columns of the shapefile.
election_updaters = {e.name: e for e in elections}

# Population tally first, followed by the election updaters.
my_updaters = {
    "population": updaters.Tally("TOTPOP", alias="population"),
    **election_updaters,
}

initial_partition = GeographicPartition(
    graph,
    assignment="CD116FP",
    updaters=my_updaters,
)

In [64]:
district_number = '05'  # any district label from '01' through '07' may be used
print("number of spanning trees: ")
print(get_num_spanning_trees(p=initial_partition, district=district_number))

number of spanning trees: 
4.80693066301099e+214


## Testing using $n \times n$ grid graphs

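To run the function on a grid instead of a district, uncomment the two lines in `get_num_spanning_trees()` directly below the comment
>`# Testing using nxn grid graphs:`

and re-run that cell. Set `n` to any positive integer and compare the returned value with OEIS [A007341](https://oeis.org/A007341), the number of spanning trees of the $n \times n$ grid graph (e.g. 192 for $n = 3$ and 100352 for $n = 4$).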

In [69]:
# Build the n x n grid graph in this cell and wrap it in a minimal stand-in that
# mimics the `subgraphs`/`parts` attributes of a gerrychain Partition. The check
# then runs on a fresh kernel without hand-editing get_num_spanning_trees().
n = 4
grid = nx.grid_graph(dim=[n, n])
grid_partition = SimpleNamespace(subgraphs={"grid": grid}, parts={"grid": set(grid.nodes)})
grid_count = get_num_spanning_trees(grid_partition, "grid")
print("number of spanning trees for a grid graph with n={}: ".format(n))
print(str(grid_count))

# Known terms of OEIS A007341 (spanning trees of the n x n grid graph).
OEIS_A007341 = {1: 1, 2: 4, 3: 192, 4: 100352, 5: 557568000}
if n in OEIS_A007341:
    assert np.isclose(float(grid_count), OEIS_A007341[n], rtol=1e-9), grid_count

number of spanning trees for a grid graph with n=4: 
100352.000000000
