In [7]:
import numpy as np
import sys
from functools import cache

In [2]:
def read_input(fname):
    """Load a character grid from a text file.

    Every non-blank line becomes one row of the grid, with leading/trailing
    whitespace removed and each remaining character stored as its own cell.

    Args:
        fname: Path of the text file to read.

    Returns:
        numpy.ndarray: array built from the list of per-line character lists
        (2-D when all kept lines have the same length).
    """
    with open(fname, 'r') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        rows = [list(text) for text in stripped_lines if text]
    return np.array(rows)

In [3]:
def map_path(y, x, direct, map, known, emap, pid):
    """Recursively trace a light beam through ``map``, marking energized tiles.

    Reference (slow) implementation: every step of the beam is a recursive
    call, so the call depth grows with the path length and a raised
    ``sys.setrecursionlimit`` may be needed for long paths.

    Args:
        y, x: Current row / column of the beam (may be outside the grid).
        direct: ``(dy, dx)`` unit step the beam is travelling in.
        map: 2-D array of single-character tiles ('.', '/', '\\', '-', '|').
        known: List of ``(x, y, direct)`` states already visited. It is
            appended to in place and shared with split-off beams so that
            loops terminate.
        emap: Integer array shaped like ``map``; ``emap[y, x]`` is incremented
            every time a beam arrives on that tile.
        pid: Beam id, incremented for split-off beams (informational only).

    Returns:
        None. Results are accumulated in ``known`` and ``emap``.

    Raises:
        ValueError: if the beam reaches a tile character it does not recognise.
    """
    # Bounds come from the grid being traced, not from a notebook global.
    n_rows, n_cols = map.shape
    if not (0 <= x < n_cols and 0 <= y < n_rows):
        return

    # Count every arrival, including one that is about to stop on a loop.
    emap[y, x] += 1

    if (x, y, direct) in known:
        # Same tile, same heading: the rest of this path is already traced.
        return
    known.append((x, y, direct))

    tile = map[y, x]
    dy, dx = direct
    if tile == '.':
        return map_path(y + dy, x + dx, direct, map, known, emap, pid)
    if tile == '/':
        # '/' swaps and negates the components: right -> up, down -> left.
        return map_path(y - dx, x - dy, (-dx, -dy), map, known, emap, pid)
    if tile == '\\':
        # '\' swaps the components: right -> down, up -> left.
        return map_path(y + dx, x + dy, (dx, dy), map, known, emap, pid)
    if tile == '-':
        if dy == 0:
            # Already horizontal: pass straight through.
            return map_path(y + dy, x + dx, direct, map, known, emap, pid)
        # Vertical beam: split into two horizontal beams.
        map_path(y, x + dy, (0, dy), map, known, emap, pid + 1)
        return map_path(y, x - dy, (0, -dy), map, known, emap, pid)
    if tile == '|':
        if dx == 0:
            # Already vertical: pass straight through.
            return map_path(y + dy, x + dx, direct, map, known, emap, pid)
        # Horizontal beam: split into two vertical beams.
        map_path(y + dx, x, (dx, 0), map, known, emap, pid + 1)
        return map_path(y - dx, x, (-dx, 0), map, known, emap, pid)
    raise ValueError(f'unknown tile {str(tile)!r} at row {y}, column {x}')


In [4]:
def map_path_fast(y, x, direct, map, known, emap, pid):
    """Trace a light beam through ``map`` iteratively, marking energized tiles.

    Straight runs are followed in a loop and split-off beams are queued on an
    explicit stack, so no recursion is used and the path length is not
    limited by Python's recursion limit.

    Args:
        y, x: Row / column where the beam enters (may be outside the grid).
        direct: ``(dy, dx)`` unit step the beam is travelling in.
        map: 2-D array of single-character tiles ('.', '/', '\\', '-', '|').
        known: List of already-visited ``(x, y, direct)`` states. New states
            are appended in place; states already present count as visited.
        emap: Integer array shaped like ``map``; ``emap[y, x]`` is incremented
            every time a beam arrives on that tile.
        pid: Unused; kept so the signature matches ``map_path``.

    Returns:
        None. Results are accumulated in ``known`` and ``emap``.

    Raises:
        ValueError: if a beam reaches a tile character it does not recognise.
    """
    # Bounds come from the grid being traced, not from a notebook global.
    n_rows, n_cols = map.shape
    # Set mirror of ``known`` for O(1) membership tests. ``known`` itself is
    # still appended to, so callers see the same visited states.
    seen = set(known)
    pending = [(y, x, tuple(direct))]

    while pending:
        row, col, (dy, dx) = pending.pop()
        while 0 <= row < n_rows and 0 <= col < n_cols:
            # Count every arrival, including one that is about to stop on a loop.
            emap[row, col] += 1
            state = (col, row, (dy, dx))
            if state in seen:
                # Same tile, same heading: the rest of this path is already traced.
                break
            seen.add(state)
            known.append(state)

            tile = map[row, col]
            if tile == '/':
                # Swap and negate: right -> up, down -> left.
                dy, dx = -dx, -dy
            elif tile == '\\':
                # Swap: right -> down, up -> left.
                dy, dx = dx, dy
            elif tile == '-' and dy != 0:
                # Vertical beam on a horizontal splitter: queue one half and
                # keep following the other.
                pending.append((row, col + dy, (0, dy)))
                dy, dx = 0, -dy
            elif tile == '|' and dx != 0:
                # Horizontal beam on a vertical splitter.
                pending.append((row + dx, col, (dx, 0)))
                dy, dx = -dx, 0
            elif tile not in ('.', '-', '|'):
                raise ValueError(
                    f'unknown tile {str(tile)!r} at row {row}, column {col}')
            # '.' and splitters hit edge-on fall through and go straight on.
            row += dy
            col += dx


In [5]:
def find_best(data):
    """Return the largest number of energized tiles over every edge entry point.

    A beam may enter at any border tile, heading directly away from that
    border. Corner tiles therefore get two candidate headings.

    Args:
        data: 2-D array of single-character tiles.

    Returns:
        int: maximum count of energized tiles (0 for an empty grid).
    """
    if data.size == 0:
        return 0
    n_rows, n_cols = data.shape

    def edge_starts():
        """Yield ``(row, col, (dy, dx))`` for every inward-facing border start."""
        for col in range(n_cols):
            yield 0, col, (1, 0)             # top row, heading down
            yield n_rows - 1, col, (-1, 0)   # bottom row, heading up
        for row in range(n_rows):
            yield row, 0, (0, 1)             # left column, heading right
            yield row, n_cols - 1, (0, -1)   # right column, heading left

    max_energy = 0
    for ys, xs, direct in edge_starts():
        emap = np.zeros(data.shape, dtype=int)
        map_path_fast(ys, xs, direct, data, [], emap, 0)
        max_energy = max(max_energy, int(np.count_nonzero(emap)))
    return max_energy


In [8]:
# Beam splits may be traced recursively (map_path recurses once per step),
# so allow a deeper stack than Python's default.
sys.setrecursionlimit(5000)


def count_energized(grid, start=(0, 0), direct=(0, 1)):
    """Count tiles energized by one beam entering ``grid`` at ``start``.

    Args:
        grid: 2-D array of tile characters.
        start: ``(row, column)`` where the beam enters (top-left by default).
        direct: ``(dy, dx)`` initial heading (rightwards by default).

    Returns:
        int: number of tiles the beam passes through at least once.
    """
    emap = np.zeros(grid.shape, dtype=int)
    map_path_fast(start[0], start[1], direct, grid, [], emap, 0)
    return int(np.count_nonzero(emap))


print('*****\nPuzzle1\n*****\n')

for label, fname, expected in (('Test case\n', 'input16a.txt', 46),
                               ('\nPuzzle case\n', 'input16.txt', 7788)):
    print(label)
    # Assign the module-level `data` before tracing, in case the tracing
    # helpers take their grid bounds from this global.
    data = read_input(fname)
    n_e = count_energized(data)
    print(f'# energized is {n_e}')
    assert n_e == expected

print('\n*****\nPuzzle2\n*****\n')

for label, fname, expected in (('Test case\n', 'input16a.txt', 51),
                               ('\nPuzzle case\n', 'input16.txt', 7987)):
    print(label)
    data = read_input(fname)
    m_e = find_best(data)
    print(f'Max energized is {m_e}')
    assert m_e == expected

*****
Puzzle1
*****

Test case

# energized is 46

Puzzle case

# energized is 7788

*****
Puzzle2
*****

Test case

Max energized is 51

Puzzle case

Max energized is 7987
