In [1]:
# Advent of Code, Day 9 - Jim Carson. 
import numpy as np
import pandas as pd
def parse(puzzle_input):
    """Read a grid of single-digit heights from a text file.

    Each non-blank line becomes one row and each character on it one integer
    cell. Surrounding whitespace is stripped and blank lines (e.g. a trailing
    empty line) are skipped, so they cannot produce a ragged array.

    Parameters
    ----------
    puzzle_input : path of the text file to read.

    Returns
    -------
    numpy.ndarray of ints, one row per non-blank input line.
    """
    with open(puzzle_input, "r") as fp:
        lines = [line.strip() for line in fp.read().splitlines()]
    rows = [line for line in lines if line]
    return np.array([[int(ch) for ch in row] for row in rows], dtype=int)

In [2]:
def get_neighbors(d, r, c):
    """Return the in-bounds orthogonal neighbours of cell (r, c).

    Only ``d.shape[0]`` and ``d.shape[1]`` are consulted. The result is a set of
    (row, column) tuples drawn from up, down, left and right of (r, c).
    """
    n_rows, n_cols = d.shape[0], d.shape[1]
    candidates = ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
    return {
        (nr, nc)
        for nr, nc in candidates
        if 0 <= nr < n_rows and 0 <= nc < n_cols
    }

def is_lowpoint(d, r, c):
    """Return True if cell (r, c) of ``d`` is strictly lower than every
    orthogonal neighbour returned by ``get_neighbors``.

    A cell with no in-bounds neighbours counts as a low point.
    """
    height = d[r, c]
    # Neighbour heights come from ``d`` (the grid passed in), never from a
    # module-level global, so the function works for any grid.
    return all(height < d[nbr] for nbr in get_neighbors(d, r, c))

In [3]:
def heightmap(data):
    """Return a float grid holding only the low-point heights of ``data``.

    Cells for which ``is_lowpoint`` returns True keep their height from
    ``data``; every other cell is NaN. Cells are visited in row-major order.
    """
    n_rows, n_cols = data.shape[0], data.shape[1]
    z = np.full((n_rows, n_cols), np.nan, dtype=float)
    for i, j in np.ndindex(n_rows, n_cols):
        if is_lowpoint(data, i, j):
            z[i, j] = data[i, j]
    return z

In [4]:
# Credit to Rodrigo (RojerGS) for the elegant approach of doing this with a
# while loop and sets -- *much* cleaner than what I had last night!
def find_basin(d, p):
    """Return the set of cells in the basin that contains cell ``p``.

    Flood-fills outward from ``p`` through orthogonal neighbours (as given by
    ``get_neighbors``). Cells whose value is 9 act as walls: they are never
    added to the basin and the fill does not continue through them. If ``p``
    itself holds a 9, the result is an empty set.

    Parameters
    ----------
    d : grid indexable by (row, col) tuples, with ``.shape`` as used by
        ``get_neighbors``.
    p : (row, col) tuple of the starting cell.

    Returns
    -------
    set of (row, col) tuples belonging to the basin.
    """
    todo = {p}
    # Mark the seed as seen up front; otherwise it is re-queued as a neighbour
    # of the first expanded cell and processed a second time.
    seen = {p}
    basin = set()
    while todo:
        cell = todo.pop()
        if d[cell] == 9:
            continue  # walls bound the basin and are not expanded
        basin.add(cell)
        neighbours = get_neighbors(d, *cell)
        todo.update(neighbours - seen)
        seen.update(neighbours)
    return basin

In [5]:
# Part 1 on the small example input. Each low point contributes height + 1,
# so the total is the sum of low-point heights plus the number of low points.
# The expected answer for the example is 15.
data = parse("input_files/day09.test.txt")
z = heightmap(data)
# Count low points as the non-NaN cells explicitly instead of relying on NaN
# comparing False in `z >= 0`.
test_total = np.nansum(z) + np.count_nonzero(~np.isnan(z))
assert test_total == 15, test_total
print(test_total)

15.0


In [6]:
# Part 1 on the real puzzle input: sum of (height + 1) over all low points,
# i.e. the low-point heights plus the number of low points. Known answer: 566.
data = parse("input_files/day09.txt")
z = heightmap(data)
# Count low points as the non-NaN cells explicitly instead of relying on NaN
# comparing False in `z >= 0`.
part1 = np.nansum(z) + np.count_nonzero(~np.isnan(z))
assert part1 == 566, part1
print(part1)

566.0


In [7]:
# Give every low point (a non-NaN cell of z) a label 0, 1, 2, ... in row-major
# order. basin_locations maps label -> (row, col); the label is also written
# back into z at that cell.
basin_locations = {}
low_number = 0
for row, col in np.ndindex(z.shape[0], z.shape[1]):
    if np.isnan(z[row, col]):
        continue
    basin_locations[low_number] = (row, col)
    z[row, col] = low_number
    low_number += 1

In [8]:
# Flood-fill outward from each numbered low point and record how many cells
# its basin covers, keyed by the same label.
basin_sizes = {
    label: len(find_basin(data, seed))
    for label, seed in basin_locations.items()
}

In [9]:
# Part 2 needs the three LARGEST basins: sort (label, size) pairs by size in
# descending order and keep the first three.
three_largest_basins = sorted(basin_sizes.items(), key=lambda item: item[1], reverse=True)[:3]
# The old name said "smallest", the opposite of what the sort selects; it is
# kept as an alias so cells that still use it keep working.
three_smallest_basins = three_largest_basins

In [10]:
# Part 2 answer: product of the sizes of the three basins selected in the
# previous cell (the three largest, since that list is sorted descending).
# Each entry is a (label, size) pair. Known answer: 891684.
product = 1
for _, size in three_smallest_basins:
    product *= size
assert product == 891684, product
print(product)

891684
