In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
from collections import defaultdict

In [None]:
# Puzzle input via aoc_utils.load_data (arguments presumably year and day).
data = load_data(2023, 16)

In [None]:
# Example cases for check() below. Each tuple is
# (puzzle input, expected Part 1 answer, expected Part 2 answer).
tests = [
    (
        r""".|...\....
|.-.\.....
.....|-...
........|.
..........
.........\
..../.\\..
.-.-/..|..
.|....-|.\
..//.|....""",
        46,
        51,
    ),
]

# Part 1

In [None]:
def gen_grid(lines):
    """Map every character of ``lines`` to a complex coordinate.

    The real part is the column and the imaginary part is the row, so adding
    ``1`` moves one tile right and adding ``1j`` moves one tile down.
    """
    return {
        col + row * 1j: char
        for row, text in enumerate(lines)
        for col, char in enumerate(text)
    }

In [None]:
def _travel(grid, light, pos=0, direction=1):
    while pos in grid and direction not in light[pos]:
        light[pos].append(direction)
        match grid[pos]:
            case ".":
                pos = pos + direction
            case "\\":
                direction = direction.imag + direction.real * 1j
                pos = pos + direction
            case "/":
                direction = -direction.imag - direction.real * 1j
                pos = pos + direction
            case "-", complex(imag=0):
                pos = pos + direction
            case "-":
                _travel(grid, light, pos - 1, -1)
                _travel(grid, light, pos + 1, 1)
            case "|", complex(real=0):
                pos = pos + direction
            case "|":
                _travel(grid, light, pos - 1j, -1j)
                _travel(grid, light, pos + 1j, 1j)
            case _:
                raise ValueError(f"Unknown grid cell: {grid[pos]}")

def travel(data):
    """Count the tiles energized by a beam entering the top-left tile heading right."""
    energized = defaultdict(list)
    _travel(gen_grid(data.splitlines()), energized, 0, 1)
    return len(energized)

In [None]:
# Validate Part 1 on the examples in `tests`, then solve the real input.
# NOTE(review): check() lives in aoc_utils; assumed to compare travel(example)
# against each tuple's Part 1 expectation.
check(travel, tests)
travel(data)

# Part 2

In [None]:
def travel(data):
    """Return the most tiles a single beam can energize when entering from an edge.

    Every border tile is tried as an entry point with the beam heading into
    the grid: up from the bottom row, down from the top row, left from the
    right column and right from the left column. Corner tiles are tried from
    both of their edges. Returns 0 for empty input.
    """
    lines = data.splitlines()
    if not lines:
        return 0
    grid = gen_grid(lines)
    width, height = len(lines[0]), len(lines)
    # The last valid column/row index is width - 1 / height - 1. Starting at
    # width / height would place the beam off the grid and energize nothing.
    entries = [
        (-1j, range(width), [height - 1]),  # bottom row, heading up
        (1j, range(width), [0]),  # top row, heading down
        (-1, [width - 1], range(height)),  # right column, heading left
        (1, [0], range(height)),  # left column, heading right
    ]
    best = 0
    for direction, xs, ys in entries:
        for x in xs:
            for y in ys:
                light = defaultdict(list)
                _travel(grid, light, x + 1j * y, direction)
                best = max(best, len(light))
    return best

In [None]:
# Validate Part 2 on the examples, then solve the real input.
# NOTE(review): the trailing 2 presumably tells aoc_utils.check to use each
# tuple's Part 2 expectation; confirm in aoc_utils.
check(travel, tests, 2)
travel(data)