In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
from collections import defaultdict

In [None]:
# Real puzzle input (year 2023, day 22), fetched via aoc_utils.load_data.
data = load_data(2023, 22)

In [None]:
# Example cases used to validate the solutions before running on `data`.
# Each entry is a tuple (data, part_1, part_2): the input text, then the
# expected answers for Part 1 and Part 2.
tests = [
    (
        """1,0,1~1,2,1
0,0,2~2,0,2
0,2,3~2,2,3
0,0,4~0,2,4
2,0,5~2,2,5
0,1,6~2,1,6
1,1,8~1,1,9""",
        5,
        7,
    ),
]

# Part 1

In [None]:
def get_bricks(data):
    """Retrieve bricks from the input.

    Each non-blank line holds two opposite corners ``x1,y1,z1~x2,y2,z2``.
    Every axis pair is normalized to ``(min, max)``, so the order in which
    the two corners are written does not matter.

    Parameters
    ----------
    data : str
        The puzzle input.

    Returns
    -------
    bricks : dict[int, tuple[tuple[int, int], tuple[int, int], tuple[int, int]]]
        The falling bricks.
        Keys are brick ids, numbered from 1 in input order.
        Values are (xmin, xmax), (ymin, ymax), (zmin, zmax).
    order : list[int]
        Brick ids sorted by increasing zmin (ties keep input order).
    xrange, yrange : range
        The projected ground ranges of the falling area
        (empty ranges when there are no bricks).
    """
    # Ignore blank lines (e.g. a stray empty line in a pasted input).
    lines = [line for line in data.splitlines() if line.strip()]

    bricks = {}
    xs = set()
    ys = set()
    for idx, line in enumerate(lines, start=1):
        x1, y1, z1, x2, y2, z2 = [int(v) for v in line.replace("~", ",").split(",")]
        xs |= {x1, x2}
        ys |= {y1, y2}
        # Normalize each pair: downstream code iterates range(lo, hi + 1) and
        # sorts by the first z value, both of which assume lo <= hi.
        bricks[idx] = (
            (min(x1, x2), max(x1, x2)),
            (min(y1, y2), max(y1, y2)),
            (min(z1, z2), max(z1, z2)),
        )

    return (
        bricks,
        sorted(bricks, key=lambda idx: bricks[idx][2][0]),
        # `default` keeps empty input from raising; it yields range(0, 0).
        range(min(xs, default=0), max(xs, default=-1) + 1),
        range(min(ys, default=0), max(ys, default=-1) + 1),
    )

In [None]:
def fall(bricks, order, xrange, yrange):
    """Let gravity do its work.

    Bricks are dropped one by one following ``order``; each brick comes to
    rest directly on the highest surface found under its footprint.

    Parameters
    ----------
    bricks, order, xrange, yrange :
        The output of get_bricks.

    Returns
    -------
    safe_to_remove : set[int]
        The bricks that are not the unique support of another brick.
    support : dict[int, set[int]]
        The vertical adjacency relation (a ``defaultdict(set)``).
        Entry A : {B, C} means that brick A supports bricks B and C.
        A brick may be listed under several supporters.
    """
    # Ground cell -> (stack height, id of the brick on top); id 0 is the ground.
    surface = {(x, y): (0, 0) for x in xrange for y in yrange}
    support = defaultdict(set)
    removable = set(bricks)

    for brick_id in order:
        (xa, xb), (ya, yb), (za, zb) = bricks[brick_id]
        footprint = [(x, y) for x in range(xa, xb + 1) for y in range(ya, yb + 1)]

        # The brick lands on the tallest column beneath it (never below 0).
        rest_height = max([0] + [surface[cell][0] for cell in footprint])
        below = {
            surface[cell][1] for cell in footprint if surface[cell][0] == rest_height
        }

        for supporter in below - {0}:
            support[supporter].add(brick_id)
        # A lone supporter cannot be removed without dropping this brick.
        if len(below) == 1:
            removable -= below

        landed = (rest_height + zb - za + 1, brick_id)
        for cell in footprint:
            surface[cell] = landed

    return removable, support

In [None]:
def safe_to_remove(data):
    """Count the bricks that can be disintegrated without dropping another.

    Parameters
    ----------
    data : str
        The puzzle input.

    Returns
    -------
    int
        Number of bricks that are not the unique support of any other brick.
    """
    removable = fall(*get_bricks(data))[0]
    return len(removable)

In [None]:
# Validate Part 1 on the examples in `tests` (presumably against their
# part_1 column — see aoc_utils.check), then solve the real input.
check(safe_to_remove, tests)
safe_to_remove(data)

# Part 2

In [None]:
def chain_reactions(data):
    """Sum chain reactions lengths.

    For every brick, count how many *other* bricks would fall if that brick
    alone were disintegrated, and add those counts up.

    Parameters
    ----------
    data : str
        The puzzle input.

    Returns
    -------
    int
        Sum, over all bricks, of the number of other bricks that fall.
    """
    _, support = fall(*get_bricks(data))

    # Invert the support relation: resting[B] holds the bricks B rests on.
    resting = defaultdict(set)
    for block, supported in support.items():
        for other in supported:
            resting[other].add(block)

    total = 0
    # A brick that supports nothing triggers no chain reaction, so only the
    # keys of `support` need to be examined.
    for brick in support:
        fallen = {brick}
        candidates = list(support[brick])
        while candidates:
            candidate = candidates.pop()
            # A brick is pushed once per supporter that falls. Once it has
            # fallen, later copies are redundant and would re-push everything
            # above it again (exponential blow-up on stacked "diamonds").
            if candidate in fallen:
                continue
            # If a supporter is still standing the brick stays put for now;
            # it is pushed again when that supporter falls.
            if resting[candidate] <= fallen:
                fallen.add(candidate)
                # .get avoids inserting empty entries into the defaultdict
                # while the outer loop is iterating over it.
                candidates.extend(support.get(candidate, ()))
        total += len(fallen) - 1
    return total

In [None]:
# Validate Part 2 on the examples; the trailing 2 presumably selects the
# part_2 column of `tests` (confirm in aoc_utils.check). Then solve the real input.
check(chain_reactions, tests, 2)
chain_reactions(data)