In [None]:
import re

In [None]:
USE_EXAMPLE = False  # set to True to run on the worked example from the puzzle text
filename = "day24.example.input" if USE_EXAMPLE else "day24.input"

# One list of step names ("e", "w", "ne", "se", "sw", "nw") per input line.
directions = []
with open(filename) as file:
    for line in file:
        line = line.strip()
        if not line:
            # A blank line would otherwise become an empty path and flip the origin tile.
            continue
        steps = re.findall("(e|w|ne|se|sw|nw)", line)
        # findall silently skips characters it cannot match; make sure the whole
        # line was consumed so malformed input fails loudly instead of mis-walking.
        if "".join(steps) != line:
            raise ValueError(f"Unparseable directions line: {line!r}")
        directions.append(steps)

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class Tile:
    """A hexagonal floor tile addressed by cube coordinates (x, y, z).

    Every step in `_STEPS` leaves ``x + y + z`` unchanged, so tiles reached by
    walking from the default ``Tile()`` keep ``x + y + z == 0``.
    """

    x: int = 0
    y: int = 0
    z: int = 0

    # Cube-coordinate offset (dx, dy, dz) applied by each named step.
    # (No annotation, so dataclass treats this as a class constant, not a field.)
    _STEPS = {
        "e": (1, -1, 0),
        "w": (-1, 1, 0),
        "ne": (1, 0, -1),
        "sw": (-1, 0, 1),
        "nw": (0, 1, -1),
        "se": (0, -1, 1),
    }
    # Order in which `neighbours` reports the six adjacent tiles.
    _NEIGHBOUR_ORDER = ("ne", "e", "se", "sw", "w", "nw")

    def neighbour(self, direction: str):
        """Return the tile one step away from this one in `direction`.

        Raises:
            ValueError: if `direction` is not one of e, w, ne, sw, nw, se.
        """
        offset = self._STEPS.get(direction)
        if offset is None:
            raise ValueError(f"Unknown direction: {direction}")
        dx, dy, dz = offset
        return Tile(self.x + dx, self.y + dy, self.z + dz)

    @property
    def neighbours(self):
        """The six adjacent tiles, listed in the order ne, e, se, sw, w, nw."""
        return [self.neighbour(step) for step in self._NEIGHBOUR_ORDER]

    @property
    def radius(self) -> int:
        """This tile's distance from the center."""
        return max(map(abs, (self.x, self.y, self.z)))

# Part 1: follow each path from the reference tile, flip the tile it ends on, and count black tiles

In [None]:
# Map each visited tile to its colour: 0 = white, 1 = black.
grid = {}
for path in directions:
    # Walk the path one step at a time, starting from the reference tile at the origin.
    position = Tile()
    for step in path:
        position = position.neighbour(step)
    # Every arrival toggles the tile's colour.
    grid[position] = 1 - grid.get(position, 0)

sum(grid.values())

# Part 2: apply the daily flipping rules for 100 days and count black tiles

In [None]:
# Evolve a copy rather than `grid` itself: re-running this cell then does not
# advance the floor a further 100 days, and later cells reading `grid` are unaffected.
floor = dict(grid)
for _ in range(100):  # the puzzle asks for the floor after 100 days
    old_floor = floor.copy()
    seen = set()
    for tile, value in old_floor.items():
        if value == 0:
            continue  # Check only black tiles
        neighbours = tile.neighbours  # computed once, used twice below
        num_black = sum(old_floor.get(n, 0) for n in neighbours)
        if num_black == 0 or num_black > 2:
            floor[tile] = 0

        # A white tile can only flip if it has black neighbours, so checking the
        # white neighbours of every black tile covers every candidate.
        for neighbour in neighbours:
            if neighbour in seen or old_floor.get(neighbour, 0) == 1:
                continue  # skip black and already-evaluated neighbours
            seen.add(neighbour)
            num_black = sum(old_floor.get(n, 0) for n in neighbour.neighbours)
            if num_black == 2:
                floor[neighbour] = 1

sum(floor.values())

## Alternative solution: apply the rules to every tile within a radius large enough to contain every tile that can change colour

In [None]:
# Evolve a copy so `grid` is left unchanged and re-running this cell starts over
# from `grid` instead of advancing a further 100 days.
alt_floor = dict(grid)
for _ in range(100):  # the puzzle asks for the floor after 100 days
    old_floor = alt_floor.copy()

    # Only tiles within one step of a black tile can change colour, so scan the
    # hexagon reaching one ring beyond the outermost black tile. `default` keeps
    # an empty (all-white) floor from raising.
    radius = max((t.radius for t, v in old_floor.items() if v == 1), default=0) + 1
    for x in range(-radius, radius + 1):
        # Bound y so that |z| = |x + y| <= radius too: a hexagon, not a rhombus.
        for y in range(max(-radius, -x - radius), min(radius, radius - x) + 1):
            tile = Tile(x, y, -x - y)
            num_black = sum(old_floor.get(n, 0) for n in tile.neighbours)
            is_black = old_floor.get(tile, 0) == 1
            if is_black and (num_black == 0 or num_black > 2):
                alt_floor[tile] = 0
            elif not is_black and num_black == 2:
                alt_floor[tile] = 1

sum(alt_floor.values())