# --- Day 9: Smoke Basin ---

In [5]:
import aocd
import re
import operator
from collections import Counter, defaultdict, deque
from itertools import combinations
from functools import reduce, lru_cache

def prod(iterable):
    """Multiply together every element of *iterable*; 1 for an empty iterable."""
    product = 1
    for factor in iterable:
        product = product * factor
    return product

def count(iterable, predicate=bool):
    """Return how many items of *iterable* satisfy *predicate* (truthiness by default)."""
    total = 0
    for item in iterable:
        if predicate(item):
            total += 1
    return total

def first(iterable, default=None):
    """Return the first element of *iterable*, or *default* when it is empty."""
    for element in iterable:
        return element
    return default

def lmap(func, *iterables):
    """Eager ``map``: apply *func* across *iterables* and return a list."""
    return [*map(func, *iterables)]

def ints(s):
    """Extract every (optionally negative) decimal integer in *s*, in order."""
    return [int(token) for token in re.findall(r"-?\d+", s)]

def words(s):
    """Return every maximal run of ASCII letters in *s*."""
    return [match.group(0) for match in re.finditer(r"[a-zA-Z]+", s)]

def list_diff(x):
    """Return the successive differences ``x[i + 1] - x[i]`` of sequence *x*."""
    # map stops at the shorter argument, so the result has len(x) - 1 items.
    return list(map(operator.sub, x[1:], x))

def binary_to_int(lst):
    """Interpret the items of *lst* (bits, most significant first) as a base-2 number."""
    bit_string = "".join(map(str, lst))
    return int(bit_string, 2)

def get_column(lst, index):
    """Return ``row[index]`` for every row in *lst*."""
    return list(map(operator.itemgetter(index), lst))

In [29]:
def parse_line(line):
    """Convert one row of the height map into a list of single-digit ints."""
    return [int(digit) for digit in str(line)]
    
def parse_input(input):
    """Split the puzzle text into lines and parse each with ``parse_line``."""
    return [parse_line(text) for text in input.splitlines()]

In [30]:
# Fetch the real puzzle input (year 2021, day 9) through aocd and parse it
# into a list of digit rows; print the first five rows as a sanity check.
final_input = parse_input(aocd.get_data(day=9, year=2021))
print(final_input[:5])

[[9, 7, 5, 6, 7, 8, 9, 5, 4, 5, 9, 8, 7, 6, 5, 2, 3, 4, 5, 8, 9, 1, 0, 1, 2, 3, 6, 7, 8, 9, 9, 9, 8, 9, 9, 9, 1, 3, 5, 9, 9, 9, 9, 8, 6, 6, 4, 3, 5, 9, 7, 4, 3, 2, 1, 9, 2, 1, 2, 3, 4, 9, 8, 9, 9, 9, 4, 3, 4, 6, 6, 7, 9, 8, 9, 8, 7, 9, 9, 9, 9, 8, 7, 6, 5, 4, 9, 8, 7, 6, 5, 4, 3, 5, 6, 9, 7, 6, 5, 4], [6, 5, 4, 5, 8, 9, 8, 9, 5, 9, 8, 7, 6, 5, 4, 1, 2, 3, 6, 7, 8, 9, 2, 4, 5, 4, 5, 8, 9, 8, 9, 8, 7, 8, 9, 8, 9, 4, 9, 8, 9, 8, 7, 6, 5, 4, 3, 2, 3, 9, 6, 5, 6, 3, 9, 8, 9, 2, 9, 4, 9, 8, 7, 8, 9, 8, 9, 1, 3, 4, 5, 6, 9, 6, 5, 6, 6, 7, 8, 9, 9, 7, 6, 5, 4, 3, 4, 9, 9, 7, 9, 8, 5, 6, 7, 8, 9, 5, 4, 3], [5, 4, 3, 4, 5, 6, 7, 8, 9, 9, 7, 6, 5, 4, 3, 0, 4, 4, 5, 6, 7, 8, 9, 9, 7, 9, 6, 9, 8, 7, 6, 5, 6, 9, 8, 7, 8, 9, 8, 7, 6, 9, 9, 5, 4, 3, 2, 1, 9, 8, 9, 6, 7, 9, 8, 7, 8, 9, 8, 9, 4, 3, 6, 5, 6, 7, 8, 9, 4, 5, 6, 9, 8, 7, 4, 4, 5, 6, 7, 9, 9, 8, 7, 6, 5, 1, 3, 4, 5, 9, 8, 7, 6, 7, 8, 9, 9, 4, 3, 2], [6, 5, 4, 5, 6, 7, 8, 9, 3, 2, 9, 8, 4, 3, 2, 1, 2, 6, 7, 7, 8, 9, 7, 8, 9, 8, 9, 1, 9, 9, 7,

In [31]:
# Small example grid (presumably the puzzle's worked example), parsed the same
# way as the real input. The backslash right after the opening ''' stops the
# string from starting with an empty line.
test_input = parse_input('''\
2199943210
3987894921
9856789892
8767896789
9899965678
''')

print(test_input)

[[2, 1, 9, 9, 9, 4, 3, 2, 1, 0], [3, 9, 8, 7, 8, 9, 4, 9, 2, 1], [9, 8, 5, 6, 7, 8, 9, 8, 9, 2], [8, 7, 6, 7, 8, 9, 6, 7, 8, 9], [9, 8, 9, 9, 9, 6, 5, 6, 7, 8]]


## Solution 1

In [36]:
def getNeighbors(input, row, col, w, h):
    """Return the heights of the in-bounds orthogonal neighbours of (row, col).

    Heights are returned in the fixed order up, left, right, down. ``w`` and
    ``h`` are the grid width and height used for the bounds checks.
    """
    candidates = (
        (row > 0, row - 1, col),      # up
        (col > 0, row, col - 1),      # left
        (col < w - 1, row, col + 1),  # right
        (row < h - 1, row + 1, col),  # down
    )
    return [input[r][c] for inside, r, c in candidates if inside]

def solve_1(input):
    """Return the sum of ``height + 1`` over every low point of the grid.

    A cell is a low point when every orthogonal neighbour (as returned by
    ``getNeighbors``) is strictly higher than it.

    Args:
        input: grid of heights as a list of rows; the width is taken from the
            first row, so rows are assumed to have equal length.

    Returns:
        The summed ``height + 1`` of all low points; 0 for an empty grid.
    """
    # Without this guard an empty grid raises IndexError on input[0].
    if not input:
        return 0
    w = len(input[0])
    h = len(input)

    total_risk = 0
    for y, row in enumerate(input):
        for x, value in enumerate(row):
            neighbors = getNeighbors(input, y, x, w, h)
            if all(n > value for n in neighbors):
                total_risk += value + 1
    return total_risk
            

# Notebook display: part-1 result on the example grid (the cell output shows 15).
solve_1(test_input)

15

In [37]:
f"Solution 1: {solve_1(final_input)}"

'Solution 1: 560'

## Solution 2

In [49]:
def getNeighbors2(visited, input, row, col, w, h):
    """Return not-yet-visited, non-9 orthogonal neighbours of (row, col).

    A neighbour qualifies when it lies inside the ``w`` x ``h`` bounds, its
    height is not 9, and its ``(row, col)`` tuple is not already in
    ``visited``. Every qualifying cell is added to ``visited`` (mutated in
    place) and returned, in the order up, left, right, down.
    """
    candidates = (
        (row > 0, row - 1, col),      # up
        (col > 0, row, col - 1),      # left
        (col < w - 1, row, col + 1),  # right
        (row < h - 1, row + 1, col),  # down
    )
    discovered = []
    for inside, r, c in candidates:
        cell = (r, c)
        if inside and input[r][c] != 9 and cell not in visited:
            discovered.append(cell)
            visited.add(cell)
    return discovered

def solve_2(input):
    """Return the product of the sizes of the three largest basins.

    Low points are found as in part 1: every orthogonal neighbour is strictly
    higher. Each basin is then flood-filled breadth-first from its low point
    using ``getNeighbors2``, which never steps onto a height of 9.

    Args:
        input: grid of heights as a list of rows; the width is taken from the
            first row, so rows are assumed to have equal length.

    Returns:
        The product of the three largest basin sizes.

    Raises:
        IndexError: if the grid is empty or yields fewer than three basins.
    """
    w = len(input[0])
    h = len(input)

    lowpoints = [
        (y, x)
        for y, row in enumerate(input)
        for x, value in enumerate(row)
        if all(n > value for n in getNeighbors(input, y, x, w, h))
    ]

    sizes = []
    for start in lowpoints:
        # Seed the low point as visited so it is counted exactly once. Otherwise
        # it is only counted if rediscovered from a neighbour. A low point whose
        # neighbours are all 9 would then get size 0 instead of 1.
        visited = {start}
        queue = deque([start])  # O(1) popleft, unlike list.pop(0)
        while queue:
            r, c = queue.popleft()
            # getNeighbors2 adds every cell it returns to `visited` itself.
            queue.extend(getNeighbors2(visited, input, r, c, w, h))
        sizes.append(len(visited))
    sizes.sort()
    return sizes[-3] * sizes[-2] * sizes[-1]
        
# Notebook display: part-2 result on the example grid (the cell output shows 1134).
solve_2(test_input)

1134

In [50]:
f"Solution 2: {solve_2(final_input)}"

'Solution 2: 959136'