In [None]:
def get_grid(filename):
    """Load a character grid from the text file *filename*.

    Every line, stripped of surrounding whitespace, is one row. The returned
    dict maps ``(row, col)`` to the character at that position, with keys
    inserted in row-major order.
    """
    with open(filename) as handle:
        return {
            (r, c): ch
            for r, text in enumerate(handle)
            for c, ch in enumerate(text.strip())
        }

In [None]:
def print_grid(grid, size=5):
    """Print a square character grid, one row per line.

    Parameters:
        grid: mapping of (row, col) -> single-character string.
        size: side length of the square region to print. Defaults to 5, the
            puzzle's grid size, so existing calls behave as before.

    Cells missing from ``grid`` (e.g. the centre of a recursive level) are
    shown as '.'.
    """
    for row in range(size):
        print(''.join(grid.get((row, col), '.') for col in range(size)))

In [None]:
def adjacent_bugs(pos, grid):
    """Return how many of the four orthogonal neighbours of *pos* hold a bug ('#').

    Neighbours that are not keys of ``grid`` count as empty.
    """
    row, col = pos
    neighbours = ((row, col + 1), (row, col - 1), (row + 1, col), (row - 1, col))
    return sum(1 for cell in neighbours if grid.get(cell) == '#')

In [None]:
def biodiversity(grid, width=None):
    """Return the biodiversity rating of *grid*.

    A tile at ``(row, col)`` is worth ``2 ** (row * width + col)``, i.e. the
    power of two of its row-major index. The rating is the sum of the worths
    of all tiles holding a bug ('#'). The weight is derived from the tile's
    coordinates, not from dict insertion order. Grids built in any order, or
    with tiles missing (such as the empty centre of a recursive level), are
    therefore rated correctly.

    Parameters:
        grid: mapping of (row, col) -> character.
        width: number of columns per row. When omitted it is inferred as one
            more than the largest column index present in ``grid``.

    Returns:
        The rating as an int; 0 for an empty grid.
    """
    if not grid:
        return 0
    if width is None:
        width = 1 + max(col for _, col in grid)
    return sum(
        1 << (row * width + col)
        for (row, col), char in grid.items()
        if char == '#'
    )

In [None]:
def evolve(old):
    """Return the 5x5 bug grid one minute after *old*.

    Each tile's four orthogonal neighbours are counted with adjacent_bugs:
    - A bug ('#') dies unless exactly one neighbour is a bug.
    - An empty tile ('.') becomes a bug when one or two neighbours are bugs.
    - Any other tile keeps its character.

    *old* is left untouched; the new grid is built in row-major order.
    """
    new = {}
    for index in range(25):
        row, col = divmod(index, 5)
        tile = old[row, col]
        neighbours = adjacent_bugs((row, col), old)
        if tile == '#':
            new[row, col] = '#' if neighbours == 1 else '.'
        elif tile == '.' and neighbours in (1, 2):
            new[row, col] = '#'
        else:
            new[row, col] = tile
    return new

# Part 1

In [None]:
def part1(filename):
    """Return the biodiversity rating of the first layout that appears twice.

    The grid is read from *filename* and evolved one minute at a time. Each
    layout is identified by its biodiversity rating; every tile contributes a
    distinct power of two, so the rating is unique per layout. The starting
    layout is recorded as well, because it can itself be the first layout to
    reappear.
    """
    state = get_grid(filename)
    seen = {biodiversity(state)}
    while True:
        state = evolve(state)
        bio = biodiversity(state)
        if bio in seen:
            return bio
        seen.add(bio)

In [None]:
# Test: for the example layout, the first repeated layout has biodiversity 2129920.
assert part1("day24-test1.input") == 2129920

In [None]:
# Part 1 answer for the real puzzle input; the notebook shows it as the cell's output.
part1("day24.input")

# Part 2

In [None]:
def adjacent_bugs_part2(pos, grids, level):
    """Count the bugs ('#') adjacent to *pos* on the recursive grid.

    Parameters:
        pos: (row, col) of a tile on a 5x5 level (not the centre (2, 2)).
        grids: mapping of level -> {(row, col): char}.
            - Level ``level + 1`` is the grid nested inside the centre tile.
            - Level ``level - 1`` is the grid that contains this one.
            - Levels missing from ``grids`` count as empty.
            - ``grids`` is only read: looking up a missing level does not
              insert it, even when ``grids`` is a ``defaultdict``.
        level: the level that *pos* is on.

    Returns:
        The number of bugs among the neighbours of *pos*.
        - A neighbour that is the centre tile stands for the five tiles along
          the facing edge of the inner grid.
        - A neighbour off the edge of the grid is the tile next to the centre
          of the outer grid on that side.
    """
    current = grids.get(level, {})
    inner = grids.get(level + 1, {})
    outer = grids.get(level - 1, {})
    row, col = pos
    count = 0
    for drow, dcol in ((0, 1), (0, -1), (1, 0), (-1, 0)):
        n_row, n_col = row + drow, col + dcol
        if (n_row, n_col) == (2, 2):
            # Stepping into the centre: every tile on the inner grid's edge
            # that faces us is a neighbour.
            if dcol == 1:      # from (2, 1): inner left column
                edge = [(r, 0) for r in range(5)]
            elif dcol == -1:   # from (2, 3): inner right column
                edge = [(r, 4) for r in range(5)]
            elif drow == 1:    # from (1, 2): inner top row
                edge = [(0, c) for c in range(5)]
            else:              # from (3, 2): inner bottom row
                edge = [(4, c) for c in range(5)]
            count += sum(inner.get(cell) == '#' for cell in edge)
        elif n_row == -1:
            count += outer.get((1, 2)) == '#'  # outer tile above the centre
        elif n_row == 5:
            count += outer.get((3, 2)) == '#'  # outer tile below the centre
        elif n_col == -1:
            count += outer.get((2, 1)) == '#'  # outer tile left of the centre
        elif n_col == 5:
            count += outer.get((2, 3)) == '#'  # outer tile right of the centre
        else:
            # Ordinary neighbour on the same level.
            count += current.get((n_row, n_col)) == '#'
    return count

In [None]:
def evolve_part2(old):
    """Advance the recursive (multi-level) bug grid by one minute.

    Parameters:
        old: mapping of level -> {(row, col): char}. Level ``level + 1`` is
            the grid inside the centre tile of ``level``. Missing levels and
            missing tiles count as empty ('.'). ``old`` is only read; its
            levels are looked up with ``.get``.

    Returns:
        A ``collections.defaultdict(dict)`` holding the new levels.
        - Only levels within one of a level that currently holds a bug are
          computed. A tile's neighbours lie on its own level or the two
          adjacent ones, so no other level can gain a bug.
        - The centre tile (2, 2) of each level is never assigned: it stands
          for the nested grid.
    """
    new = collections.defaultdict(dict)
    occupied = [lvl for lvl, grid in old.items() if '#' in grid.values()]
    if not occupied:
        # No bugs anywhere: nothing can survive or be born.
        return new
    for level in range(min(occupied) - 1, max(occupied) + 2):
        grid = old.get(level, {})
        for row in range(5):
            for col in range(5):
                if (row, col) == (2, 2):
                    # Don't assign a value to the center tile
                    continue
                # A tile absent from a freshly reached level is empty; it must
                # be able to become a bug, so default it to '.'.
                tile = grid.get((row, col), '.')
                count = adjacent_bugs_part2((row, col), old, level)
                if tile == '#' and count != 1:
                    new[level][row, col] = '.'
                elif tile == '.' and count in (1, 2):
                    new[level][row, col] = '#'
                else:
                    new[level][row, col] = tile
    return new

In [None]:
# Set up part 2 on the example input. Level 0 is the grid read from the file.
# Any other level is created as an empty dict the first time it is indexed,
# because `grids` is a defaultdict(dict).
import collections
grids = collections.defaultdict(dict)
grids[0] = get_grid("day24-test1.input")
print_grid(grids[0])

In [None]:
# Run ten minutes of the recursive simulation on the example grid.
# NOTE(review): the loop variable is named `generation` because evolve_part2
# may read it as a module-level global to bound the levels it computes. Keep
# the name unless evolve_part2 no longer depends on it.
for generation in range(10):
    grids = evolve_part2(grids)

In [None]:
# Show which recursion levels are present after the simulation (displayed as the cell's value).
sorted(grids.keys())

In [None]:
# Print recursion level -5, i.e. five levels outward from the starting grid (level - 1 encloses level).
print_grid(grids[-5])