In [None]:
!pip install matplotlib

In [None]:
import pickle
from geo_dist_prep.geotree.data import TREE_FILE

# Load the pickled geotree from the parent directory. The file handle is
# closed deterministically by the context manager. Only unpickle trusted
# files: pickle can execute arbitrary code while loading.
with open("../" + TREE_FILE, "rb") as tree_file:
    tree = pickle.load(tree_file)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def compute_dimensions(node):
    """
    Compute the plot-space bounding box of a node.

    For an internal node, the box spans the ``x``/``y`` positions of its
    non-empty children, scaled by 360 horizontally and 180 vertically.
    A leaf, or a node whose children are all empty, gets a fixed
    0.01 x 0.01 box anchored at its own scaled position.

    Args:
        node: Tree node exposing ``is_leaf``, ``nodes``, ``x`` and ``y``;
            ``nodes`` may contain empty (falsy) slots.

    Returns:
        Tuple ``(x, y, width, height)`` in scaled plot units.
    """
    # Leaves and childless nodes get a tiny fixed-size box rather than a
    # zero-size one.
    if node.is_leaf or not any(node.nodes):
        return node.x * 360, node.y * 180, 0.01, 0.01

    # Filter the present children once (non-empty here because of the check
    # above). Use a distinct name so the ``node`` parameter is not shadowed.
    children = [child for child in node.nodes if child]
    xs = [child.x for child in children]
    ys = [child.y for child in children]

    x = min(xs)
    y = min(ys)
    width = max(xs) - x
    height = max(ys) - y

    return x * 360, y * 180, width * 360, height * 180

# Nodes deeper than this are neither drawn nor descended into.
max_depth = 20

# Nodes deeper than this are always drawn; shallower ones are drawn only
# when they carry a truthy ``value``.
# NOTE(review): min_depth > max_depth, so this threshold is never reached and
# only nodes with a value get drawn — confirm this is intended.
min_depth = 25

# Visible window in scaled plot units (horizontal, then vertical).
x1, x2 = 200, 210
y1, y2 = 140, 150

def plot_node(node, ax, depth=0):
    """
    Recursively draw a tree node and its descendants as rectangles on ``ax``.

    Reads the module-level settings ``max_depth`` and ``min_depth``, and the
    viewport ``x1``..``x2`` / ``y1``..``y2``.

    Drawing rules:
    - Nodes deeper than ``max_depth`` are ignored, along with their subtree.
    - A node is drawn when it is deeper than ``min_depth`` or has a truthy
      ``value``.
    - A drawn node whose box lies entirely outside the viewport is skipped
      together with its whole subtree.

    Colors: green for nodes with a value, red for other leaves, blue for other
    internal nodes. Opacity grows linearly with depth and is 0 at the root.

    Args:
        node: Tree node exposing ``is_leaf``, ``nodes``, ``value``, ``x``, ``y``.
        ax: Matplotlib Axes that receives the rectangle patches.
        depth: Depth of ``node`` in the tree (0 for the root).
    """
    if depth > max_depth:
        return

    if depth > min_depth or node.value:
        x, y, width, height = compute_dimensions(node)

        # Prune boxes that are disjoint from the viewport. compute_dimensions
        # never returns a negative width or height, so this is equivalent to
        # the longer "x < x1 and x + width < x1" form.
        if x + width < x1 or x > x2 or y + height < y1 or y > y2:
            return

        if node.value:
            color = 'g'
        elif node.is_leaf:
            color = 'r'
        else:
            color = 'b'

        # Same ramp as before (depth / 32). It is clamped because matplotlib
        # raises ValueError for alpha outside [0, 1], which would happen
        # once max_depth is set above 32.
        alpha = min(depth / 32, 1.0)

        rect = patches.Rectangle((x, y), width, height,
                                 linewidth=1, edgecolor=color, facecolor=color,
                                 alpha=alpha)
        ax.add_patch(rect)

    # Recurse into the non-empty children of internal nodes.
    if not node.is_leaf:
        for child in node.nodes:
            if child:
                plot_node(child, ax, depth + 1)

# Render the loaded tree, restricted to the configured viewport.
fig, ax = plt.subplots()
ax.set_xlim(x1, x2)
ax.set_ylim(y1, y2)
plot_node(tree.root, ax)
plt.show()