In [None]:
import itertools

In [None]:
def get_input(filename):
    """Read the enhancement algorithm and the starting image from *filename*.

    The first line, stripped of surrounding whitespace, is the algorithm
    string. The second line is read and discarded (presumably the blank
    separator line). Every remaining line is stripped and each of its
    characters is stored under its (row, col) position, where row 0 is the
    third line of the file.

    Returns:
        A tuple ``(algorithm, image)`` where ``image`` maps (row, col) to the
        character found at that position.
    """
    with open(filename) as handle:
        algorithm = handle.readline().strip()
        handle.readline()  # skip the second line
        image = {
            (r, c): pixel
            for r, text in enumerate(handle)
            for c, pixel in enumerate(text.strip())
        }
    return algorithm, image

In [None]:
def print_image(image):
    """Print *image* row by row over its full bounding box.

    Every (row, col) inside the box spanned by the image's keys is looked up
    and printed; each row ends with a newline. Raises ``ValueError`` for an
    empty image and ``KeyError`` if a position inside the box is missing.
    """
    row_ids = [r for r, _ in image]
    col_ids = [c for _, c in image]
    first_row, last_row = min(row_ids), max(row_ids)
    first_col, last_col = min(col_ids), max(col_ids)

    for r in range(first_row, last_row + 1):
        for c in range(first_col, last_col + 1):
            print(image[r, c], end="")
        print()

In [None]:
def get_neighbours(pos):
    """Yield every position within one step of *pos*, including *pos* itself.

    Offsets come from ``itertools.product((-1, 0, 1), ...)``, so the first
    coordinate varies slowest. For a 2-D (row, col) position this is the 3x3
    block in row-major order (top-left first, bottom-right last), which is the
    bit order the enhancement index is built from. Positions of any length are
    supported; a position with ``k`` coordinates yields ``3**k`` neighbours.
    """
    for delta in itertools.product((-1, 0, 1), repeat=len(pos)):
        yield tuple(p + d for p, d in zip(pos, delta))

__NOTE__:

My algorithm starts with `#` and ends with `.`. This means that on every second iteration, the infinite area changes to `#`s, because a neighborhood of all `.`s translates to algorithm index 0, which is `#`. On the next iteration, the infinite area changes back to `.`s, because a neighborhood of all `#`s translates to algorithm index 511 (binary `111111111`), which is `.`.

We cannot keep track of an infinite number of pixels, but we can handle this by saying that the default value for pixels that are not part of the image is either `.` or `#`, depending on which iteration we're on.

In [None]:
def enhance(image, algorithm, iteration):
    """Apply one enhancement step to *image* and return the result as a new dict.

    How each output pixel is computed:
        * Read the pixel's 3x3 neighbourhood top-left to bottom-right as a
          9-bit number, counting ``"#"`` as 1 and anything else as 0.
        * The pixel becomes ``algorithm[number]``.

    Extent and side effects:
        * The output covers the input's bounding box grown by one pixel on
          every side.
        * Positions outside *image* take the current background colour.
        * *image* itself is not modified.

    Args:
        image: dict mapping (row, col) to a pixel character for the tracked
            (finite) region.
        algorithm: the enhancement lookup string. It must have at least 512
            characters, since indices range over 0..511.
        iteration: zero-based count of enhancement steps already applied to
            *image* (the caller's loop counter). Used only to work out the
            colour of the infinite untracked background.

    Returns:
        A new dict mapping (row, col) to the enhanced pixel character.
    """
    min_row = min(row for row, col in image)
    max_row = max(row for row, col in image)
    min_col = min(col for row, col in image)
    max_col = max(col for row, col in image)

    # The infinite background starts out all ".". A uniform background maps to
    # algorithm[0] when unlit (all-zero index) and to algorithm[511] when lit
    # (all-one index), so replay that rule once per step already taken.
    # Unlike a plain odd/even toggle, this is also right when both
    # algorithm[0] and algorithm[511] are "#" (the background then stays lit
    # after the first step).
    default = "."
    for _ in range(iteration):
        default = algorithm[511] if default == "#" else algorithm[0]

    new_image = dict()
    for row in range(min_row - 1, max_row + 2):
        for col in range(min_col - 1, max_col + 2):
            alg_index = 0
            # get_neighbours yields the 3x3 block in row-major order, so the
            # first neighbour ends up as the most significant bit.
            for neighbour in get_neighbours((row, col)):
                alg_index = (alg_index << 1) | (image.get(neighbour, default) == "#")
            new_image[(row, col)] = algorithm[alg_index]

    return new_image

# Part 1

In [None]:
algorithm, image = get_input("day20.input")

# Two enhancement steps; the loop counter is passed so enhance() knows the
# colour of the infinite background at each step.
for iteration in range(2):
    image = enhance(image, algorithm, iteration)

# Number of lit pixels in the tracked region (displayed as the cell's output).
list(image.values()).count("#")

# Part 2

In [None]:
algorithm, image = get_input("day20.input")

# Fifty enhancement steps, starting again from the original input image; the
# loop counter is passed so enhance() knows the infinite background colour.
for iteration in range(50):
    image = enhance(image, algorithm, iteration)

# Number of lit pixels in the tracked region (displayed as the cell's output).
list(image.values()).count("#")