In [84]:
import math
import re
from collections import defaultdict, namedtuple

# Relative path to the day-3 puzzle input; each line becomes one row of the grid.
input_file = "input_files/day_03_1.txt"


## Part One

In [108]:
Symbol = namedtuple('Symbol', ['value', 'row', 'col'])

def find_symbols_inbounds(matrix, bounds):
    '''
    Return the symbols found in `matrix` inside a rectangular window.

    `matrix` is a sequence of strings, one per row. `bounds` is
    (row_start, row_end, col_start, col_end) with exclusive ends, as for
    `range`. A symbol is any character that is neither a digit nor '.'.

    Columns beyond the end of a shorter row are skipped. A ragged grid,
    such as one with a trailing empty line, therefore does not raise
    IndexError.

    Returns a list of Symbol(value, row, col) in row-major order.
    '''
    row_start, row_end, col_start, col_end = bounds

    symbols = []
    for row in range(row_start, row_end):
        line = matrix[row]
        # Callers size col_end from the number's own row; clamp it to this
        # row so a shorter neighbouring row is not over-indexed.
        for col in range(col_start, min(col_end, len(line))):
            char = line[col]
            if not char.isdigit() and char != '.':
                symbols.append(Symbol(char, row, col))
    return symbols
            
    
def find_symbols(matrix):
    '''
    Map every symbol in `matrix` to the numbers adjacent to it.

    Each run of digits in a row is one number. Its neighbourhood is the
    rectangle one cell larger on every side, clipped to the grid. Every
    symbol that `find_symbols_inbounds` finds there gets the number
    appended to its list.

    Keys are Symbol(value, row, col), so equal characters at different
    positions are distinct keys. Symbols that touch no number do not
    appear.

    Returns a defaultdict(list) mapping Symbol -> list of int, filled in
    scan order.
    '''
    lookup = defaultdict(list)
    total_rows = len(matrix)

    for r, text in enumerate(matrix):
        for match in re.finditer(r'\d+', text):
            number = int(match.group(0))
            neighbourhood = [
                max(r - 1, 0),                    # first row (inclusive)
                min(r + 2, total_rows),           # last row (exclusive)
                max(match.start() - 1, 0),        # first column (inclusive)
                min(match.end() + 1, len(text)),  # last column (exclusive)
            ]
            for symbol in find_symbols_inbounds(matrix, neighbourhood):
                lookup[symbol].append(number)

    return lookup


In [107]:
test_input = [
    '467..114..',
    '...*......',
    '..35..633.',
    '......#...',
    '617*......',
    '.....+.58.',
    '..592.....',
    '......755.',
    '...$.*....',
    '.664.598..'
]
# Use a distinct name: Part Two reads the real-input `markers`, so reusing
# that name here would let a re-run of this cell silently change its answer.
test_markers = find_symbols(test_input)

# Expected lookup for the worked example; 114 and 58 touch no symbol.
assert dict(test_markers) == {
    Symbol('*', 1, 3): [467, 35],
    Symbol('#', 3, 6): [633],
    Symbol('*', 4, 3): [617],
    Symbol('+', 5, 5): [592],
    Symbol('*', 8, 5): [755, 598],
    Symbol('$', 8, 3): [664],
}
test_markers

defaultdict(list,
            {Symbol(value='*', row=1, col=3): [467, 35],
             Symbol(value='#', row=3, col=6): [633],
             Symbol(value='*', row=4, col=3): [617],
             Symbol(value='+', row=5, col=5): [592],
             Symbol(value='*', row=8, col=5): [755, 598],
             Symbol(value='$', row=8, col=3): [664]})

In [109]:
with open(input_file) as file:
    matrix = file.read().splitlines()

# Symbol -> adjacent numbers; also consumed by Part Two.
markers = find_symbols(matrix)

# Summing markers.values() would count a number once per adjacent symbol.
# A part number counts once, so walk the numbers themselves and keep each
# one that has at least one neighbouring symbol.
part_number_total = 0
for row_number, line in enumerate(matrix):
    for m in re.finditer(r'\d+', line):
        bounds = [
            max(row_number - 1, 0),           # row start
            min(row_number + 2, len(matrix)), # row end
            max(m.start() - 1, 0),            # col start
            min(m.end() + 1, len(line))       # col end
        ]
        if find_symbols_inbounds(matrix, bounds):
            part_number_total += int(m.group(0))
part_number_total

522726

## Part Two

In [103]:
# A gear is a '*' symbol touching exactly two part numbers; its gear ratio
# is the product of those two numbers. Sum the ratios of all gears.
sum(
    math.prod(part_numbers)
    for gear_candidate, part_numbers in markers.items()
    if gear_candidate.value == '*' and len(part_numbers) == 2
)

81721933