## Part 1

In [6]:
from enum import Enum

In [155]:
TEST_INFILE = "inputs/day_16_test_1.txt"
INFILE = "inputs/day_16_1.txt"
# Set to True to run against the small worked example instead of the real puzzle input
# (replaces toggling a commented-out `open` line).
USE_TEST_INPUT = False

with open(TEST_INFILE if USE_TEST_INPUT else INFILE) as infile:
    lines = infile.read().splitlines()

# One list of single-character tiles per input line; tiles are addressed as grid[row][col].
grid = [list(line) for line in lines]

In [156]:
# Inspect the parsed grid: a list of rows, each a list of single-character tiles.
grid

[['\\',
  '.',
  '.',
  '|',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '|',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '\\',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '\\',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '-',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '/',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '/',
  '.',
  '.',
  '.',
  '\\',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '-',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '/',
  '.',
  '.',
  '.'],
 ['.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '/',
  '.',
  '.',
  '.',
  '-',
  '/',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '\\',
  '.',
  '|',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',
  '.',


In [157]:
class Point:
    """A (row, col) grid coordinate that supports component-wise addition.

    Points are hashable so they can live in sets and serve as Enum values.
    Do not mutate ``row``/``col`` once a Point is stored in a set or dict key,
    because ``__hash__`` is derived from them.
    """

    def __init__(self, row, col):
        self.row = row
        self.col = col

    def __repr__(self):
        return f"({self.row}, {self.col})"

    def __add__(self, other):
        """Return the component-wise sum of two Points."""
        if not isinstance(other, Point):
            # Let Python try the reflected operation / raise a proper TypeError.
            return NotImplemented
        return Point(self.row + other.row, self.col + other.col)

    def __radd__(self, other):
        """Reflected addition; treats the int 0 as the identity so ``sum(points)`` works."""
        # sum() starts from the int 0, which is the only non-Point left operand we accept.
        if isinstance(other, int) and other == 0:
            return Point(self.row, self.col)
        return self.__add__(other)

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.row == other.row and self.col == other.col

    def __hash__(self):
        return hash((self.row, self.col))
    

class Direction(Enum):
    """The four axis-aligned headings.

    Each member's value is the Point offset (row delta, col delta) that one
    step in that heading adds to a position.
    """

    UP = Point(row=-1, col=0)
    DOWN = Point(row=1, col=0)
    LEFT = Point(row=0, col=-1)
    RIGHT = Point(row=0, col=1)


# Sanity checks for Point equality and addition. `+=` falls back to __add__
# because Point defines no __iadd__.
assert Point(row=0, col=0) == Point(0, 0)
assert Point(2, 10) == (Point(1, 0) + Point(1, 10))
p = Point(row=1, col=0)
p += Point(1, 10)
assert Point(2, 10) == p


class Vector:
    """A beam state: a grid position (``point``) plus the heading it travels in.

    Hashable and comparable so that states can be de-duplicated in sets.
    """

    def __init__(self, point, direction):
        self.point = point
        self.direction = direction

    def __repr__(self):
        return f"({self.point.row}, {self.point.col}) => {self.direction}"

    def __eq__(self, other):
        if not isinstance(other, Vector):
            # Comparing against a non-Vector is simply "not equal", not an error.
            return NotImplemented
        return self.point == other.point and self.direction == other.direction

    def __hash__(self):
        return hash((self.point.row, self.point.col, self.direction))
    

# Two Vectors with the same position and heading must compare equal (visited states are kept in a set).
assert Vector(point=Point(row=0, col=0), direction=Direction.UP) == Vector(Point(0, 0), Direction.UP)


# Headings that run parallel to a vertical / horizontal splitter.
_VERTICAL = (Direction.UP, Direction.DOWN)
_HORIZONTAL = (Direction.LEFT, Direction.RIGHT)

# New heading after bouncing off each mirror, keyed by mirror symbol then incoming heading.
_MIRROR_TURNS = {
    "/": {
        Direction.UP: Direction.RIGHT,
        Direction.DOWN: Direction.LEFT,
        Direction.LEFT: Direction.DOWN,
        Direction.RIGHT: Direction.UP,
    },
    "\\": {
        Direction.UP: Direction.LEFT,
        Direction.DOWN: Direction.RIGHT,
        Direction.LEFT: Direction.UP,
        Direction.RIGHT: Direction.DOWN,
    },
}


def get_neighbors_through_cell(vector, grid=grid):
    """Step a beam onto the next tile and return the beam state(s) leaving that tile.

    Args:
        vector: current beam state (position and heading).
        grid: rows of single-character tiles; defaults to the puzzle grid as bound
            when this cell was executed.

    Returns:
        A list of Vectors, all positioned on the entered tile:
        - ``[]`` if the step leaves the grid or the tile symbol is unrecognised;
        - one Vector for empty space, a parallel splitter, or a mirror;
        - two Vectors for a splitter hit side-on (UP then DOWN for '|',
          LEFT then RIGHT for '-').
    """
    heading = vector.direction
    target = vector.point + heading.value

    # Stepping past any edge ends this beam.
    if not 0 <= target.row < len(grid):
        return []
    if not 0 <= target.col < len(grid[0]):
        return []

    tile = grid[target.row][target.col]

    if tile == ".":
        return [Vector(target, heading)]

    if tile == "|":
        if heading in _VERTICAL:
            return [Vector(target, heading)]
        if heading in _HORIZONTAL:
            return [Vector(target, Direction.UP), Vector(target, Direction.DOWN)]
        return []

    if tile == "-":
        if heading in _HORIZONTAL:
            return [Vector(target, heading)]
        if heading in _VERTICAL:
            return [Vector(target, Direction.LEFT), Vector(target, Direction.RIGHT)]
        return []

    if tile in ("/", "\\"):
        turned = _MIRROR_TURNS[tile].get(heading)
        return [] if turned is None else [Vector(target, turned)]

    return []

In [158]:
# Each case: a one-row grid whose second tile is the symbol under test, paired with the
# beam state(s) expected after a beam at (0, 0) heading RIGHT steps onto that tile.
_neighbor_cases = [
    ([[".", "."]], [Vector(Point(0, 1), Direction.RIGHT)]),
    ([[".", "-"]], [Vector(Point(0, 1), Direction.RIGHT)]),
    ([[".", "|"]], [
        Vector(Point(0, 1), Direction.UP),
        Vector(Point(0, 1), Direction.DOWN),
    ]),
    ([[".", "/"]], [Vector(Point(0, 1), Direction.UP)]),
    ([[".", "\\"]], [Vector(Point(0, 1), Direction.DOWN)]),
]

for test_grid, _expected in _neighbor_cases:
    assert get_neighbors_through_cell(Vector(Point(0, 0), Direction.RIGHT), test_grid) == _expected

In [159]:
def get_energized(start, debug=False, grid=grid):
    """Trace a beam from ``start`` and count the tiles it energizes.

    Every reachable beam state (position + heading) is expanded at most once, so
    loops created by mirrors and splitters terminate.

    Args:
        start: beam state *before* its first step; e.g. ``Vector(Point(0, -1), Direction.RIGHT)``
            enters the top-left tile heading right.
        debug: if True, print each state as it is expanded.
        grid: rows of single-character tiles to trace through; defaults to the puzzle grid
            as bound when this cell was executed.

    Returns:
        ``(visited, n_energized)``: the set of expanded Vectors (including ``start``) and
        the number of distinct tile positions occupied by any visited state other than
        ``start`` itself.
    """
    frontier = {start}
    visited = set()

    while frontier:
        cell = frontier.pop()
        if cell in visited:
            continue

        if debug:
            print(f"Visiting {cell}")
        visited.add(cell)

        # Skip states already expanded; the check above still catches duplicates
        # that were queued before their first expansion.
        frontier.update(n for n in get_neighbors_through_cell(cell, grid) if n not in visited)

    # Exclude the start *state* (normally one step outside the grid) rather than its tile,
    # so an in-grid start tile that the beam re-enters is still counted.
    energized = {v.point for v in visited if v != start}
    return visited, len(energized)

In [160]:
def print_visited(visited, grid=grid):
    """Draw the grid to stdout, one line per row: 'X' where any visited state sits, '.' elsewhere.

    Headings in ``visited`` are ignored; only positions matter. Every row is drawn
    with the width of the first row of ``grid``.
    """
    lit = {state.point for state in visited}
    for r in range(len(grid)):
        width = len(grid[0])
        print("".join("X" if Point(r, c) in lit else "." for c in range(width)))


In [161]:
# Part 1: the beam enters the top-left tile heading right, so it starts one step left of column 0.
start = Vector(point=Point(row=0, col=-1), direction=Direction.RIGHT)
visited, n_energized = get_energized(start=start)

In [162]:
# Render which tiles the Part 1 beam energized.
print_visited(visited=visited)

X...X....................X..XX...X.........X.X.....X....X......XX.X........X.......XXXXX.X.X........X.....XXXX
X...X....................X..XX...X.........X.X.....X....X......XX.X........X.......X.X.X.X.X........X.....X...
X...X....................X..XX...X.........X.X.....X....X......XX.X...XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
X...X....................X..XX...X.........X.X.....X....X......XX.X...X....XX......X.X.X.XXXXXXXXXXXXXXXX.X...
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...X......X.X.X.X.X........X...X.X...
X...X.................XXXX..XX...X.........X.X...XXXXXXXXXXXXXXXXXXXXXXXXXXXX......X.X.X.X.X...XXXXXXX..X.X...
XXXXX.................X.....XX...X.........X.X...X.X....X......XXXXXXXX.X..XX......X.X.X.X.X...X....XX..X.X...
X..XX.................X.....XX...X.........X.X...X.X....XXXXXXXXXXXXXXXXXXXXXXXX...X.X.X.X.X...X....XX..X.X...
X..XX.................XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...X....XX..X.X...
X

In [163]:
# Part 1 answer: number of distinct tiles the beam passes through.
n_energized

6816

## Part 2

In [166]:
# Part 2: try every edge entry. Each start sits one step outside the grid pointing
# inward, so the first step lands on an edge tile.
n_rows, n_cols = len(grid), len(grid[0])
top_down    = [Vector(Point(-1, col), Direction.DOWN) for col in range(n_cols)]
bottom_up   = [Vector(Point(n_rows, col), Direction.UP) for col in range(n_cols)]
left_right  = [Vector(Point(row, -1), Direction.RIGHT) for row in range(n_rows)]
# Right-edge entries start one column past the last column (col n_cols, not n_rows):
# the grid need not be square.
right_left  = [Vector(Point(row, n_cols), Direction.LEFT) for row in range(n_rows)]

# Generator keeps `start`/`visited`/`n_energized` from Part 1 intact.
max_n_energized = max(
    (get_energized(start)[1] for start in top_down + bottom_up + left_right + right_left),
    default=0,
)

max_n_energized

8163