In [None]:
def get_map(filename):
    """Load the height grid from *filename*.

    Each line of the file is one row; each character of the (stripped) line is
    a single-digit height. The result maps ``(row, col)`` tuples to ``int``
    heights, inserted in row-major order.
    """
    with open(filename) as handle:
        return {
            (r, c): int(ch)
            for r, text in enumerate(handle)
            for c, ch in enumerate(text.strip())
        }

In [None]:
def find_neighbours(pos, heightmap):
    """Return the orthogonally adjacent positions of *pos* present in *heightmap*.

    Candidates are checked in the order (x, y+1), (x, y-1), (x+1, y),
    (x-1, y); only those that are keys of *heightmap* are returned.
    """
    x, y = pos
    offsets = ((0, 1), (0, -1), (1, 0), (-1, 0))
    neighbours = []
    for dx, dy in offsets:
        candidate = (x + dx, y + dy)
        if candidate in heightmap:
            neighbours.append(candidate)
    return neighbours

In [None]:
# Parse the puzzle input into a {(row, col): height} dict used by both parts.
heightmap = get_map("day09.input")

# Part 1

In [None]:
def find_lowpoints(heightmap):
    """Find lowpoints on the map.

    A position is a low point when its height is strictly lower than the
    height of every orthogonal neighbour that exists on the map. A position
    with no on-map neighbours (e.g. a 1x1 map) counts as a low point, since
    the condition holds vacuously.

    Args:
        heightmap: dict mapping (row, col) positions to integer heights.

    Returns:
        List of low-point positions, in the dict's iteration order.
    """
    # all() instead of min(): min() of an empty neighbour list raises ValueError.
    return [
        pos
        for pos, height in heightmap.items()
        if all(height < heightmap[n] for n in find_neighbours(pos, heightmap))
    ]

In [None]:
lowpoints = find_lowpoints(heightmap)
# Each low point's risk level is its height + 1, so the total risk is the sum
# of the low-point heights plus one per low point.
sum(heightmap[point] for point in lowpoints) + len(lowpoints)

# Part 2

Start in the positions of the low points and explore outwards in a breadth-first manner (BFS). If the neighbouring height is higher (but not 9), it is included in the basin.

In [None]:
from collections import deque

def find_basin(lowpoint, heightmap):
    """Find basin surrounding a lowpoint, using BFS.

    Starting from *lowpoint*, step outwards to orthogonal neighbours whose
    height is not lower than the current position and is below 9. Height-9
    positions are never part of a basin.

    Args:
        lowpoint: starting (row, col) position, normally a result of
            find_lowpoints.
        heightmap: dict mapping (row, col) positions to integer heights.

    Returns:
        Set of positions in the basin, including *lowpoint* itself.
    """
    queue = deque([lowpoint])
    basin = {lowpoint}
    while queue:
        pos = queue.popleft()
        for neighbour in find_neighbours(pos, heightmap):
            if neighbour in basin:
                continue
            # Use <= rather than <: neighbouring cells of equal height (plateaus)
            # belong to the same basin, and a strict comparison stops at them,
            # undercounting the basin size.
            if heightmap[pos] <= heightmap[neighbour] < 9:
                basin.add(neighbour)
                queue.append(neighbour)
    return basin

In [None]:
import math

lowpoints = find_lowpoints(heightmap)
basin_sizes = list(map(len, (find_basin(point, heightmap) for point in lowpoints)))
# Answer: product of the three largest basin sizes.
math.prod(sorted(basin_sizes, reverse=True)[:3])