In [48]:
from collections import deque
from enum import Enum
from itertools import islice


class VirusGrid:
    """Infinite grid walked by the part-1 virus carrier.

    Infected nodes are stored sparsely as a set of ``(x, y)`` coordinates,
    with ``y`` growing downwards. Iterating the grid performs one burst per
    step (forever) and yields ``True`` whenever that burst infected a node.
    """

    def __init__(self, infected):
        """Create a grid whose carrier starts at ``(0, 0)`` facing up.

        :param infected: iterable of ``(x, y)`` coordinates that start
            infected. It is copied, so the caller's collection is never
            mutated by later bursts.
        """
        self.infected = set(infected)
        # directions[0] is the current heading (up, right, down, left with
        # y pointing down); rotate(-1) turns right, rotate(1) turns left.
        self.directions = deque([(0, -1), (1, 0), (0, 1), (-1, 0)])
        self.pos = 0, 0

    @classmethod
    def from_lines(cls, lines):
        """Build a grid from a text map where ``#`` marks an infected node.

        The carrier starts in the middle of the map, so coordinates are
        shifted to make the centre cell ``(0, 0)``. Each axis is centred
        independently, so rectangular maps work too; blank lines (e.g. a
        trailing newline) are ignored. Every row is assumed to be as wide
        as the first one.

        :param lines: iterable of text lines (a file object works).
        :return: a new instance of ``cls``.
        """
        rows = [row for row in (line.strip() for line in lines) if row]
        if not rows:
            return cls(())
        x_offset = len(rows[0]) // 2
        y_offset = len(rows) // 2
        infected = {
            (x, y)
            for y, row in enumerate(rows, -y_offset)
            for x, cell in enumerate(row, -x_offset)
            if cell == '#'
        }
        return cls(infected)

    def __iter__(self):
        return self

    def __next__(self):
        """Perform one burst; return ``True`` if it infected the current node.

        On an infected node the carrier turns right and cleans it; on a clean
        node it turns left and infects it. It then moves one step forward in
        its new heading.
        """
        infected = self.pos in self.infected
        self.directions.rotate(-1 if infected else 1)
        if infected:
            self.infected.remove(self.pos)
        else:
            self.infected.add(self.pos)
        dx, dy = self.directions[0]
        self.pos = self.pos[0] + dx, self.pos[1] + dy
        return not infected


class EvolvedStates(Enum):
    """Node states for the evolved (part 2) virus.

    Each member's value is ``(name of the state a burst moves the node to,
    argument for deque.rotate on the carrier's heading deque)``; with the
    heading deque ordered up/right/down/left, 1 turns left, 0 keeps going,
    -1 turns right and 2 reverses.
    """
    clean = 'weakened', 1
    weakened = 'infected', 0
    infected = 'flagged', -1
    flagged = 'clean', 2

    def __init__(self, next_, rotation):
        # The successor is kept by name: members are created in definition
        # order, so e.g. `clean` cannot refer to `weakened` directly here.
        self._next, self.rotation = next_, rotation

    @property
    def next(self):
        """The member this node's state becomes after a burst visits it."""
        return self.__class__.__members__[self._next]

    def __repr__(self):
        return '{}.{}'.format(self.__class__.__name__, self.name)

    
class EvolvedVirusGrid(VirusGrid):
    """Part-2 grid: nodes cycle clean -> weakened -> infected -> flagged.

    Non-clean nodes are stored in a dict mapping ``(x, y)`` to their
    ``EvolvedStates`` member; clean nodes are simply absent. Iterating the
    grid performs one burst per step and yields the state that burst left
    its node in.
    """

    def __init__(self, infected):
        """Create a grid where every coordinate in ``infected`` starts infected.

        :param infected: iterable of ``(x, y)`` coordinates (consumed once).
        """
        # Only the heading deque and start position are wanted from the base
        # initialiser; pass it nothing so `infected` is iterated exactly once.
        super().__init__(())
        self.infected = dict.fromkeys(infected, EvolvedStates.infected)

    def __next__(self):
        """Perform one evolved burst and return the node's new state.

        The carrier turns according to the current node's state, advances
        that node to its next state (dropping it from the dict once it is
        clean again), then moves one step forward.
        """
        state = self.infected.get(self.pos, EvolvedStates.clean)
        new_state = state.next
        self.directions.rotate(state.rotation)
        if new_state is EvolvedStates.clean:
            # Only flagged nodes become clean, and flagged nodes are always
            # stored, so this key is guaranteed to exist.
            del self.infected[self.pos]
        else:
            self.infected[self.pos] = new_state
        dx, dy = self.directions[0]
        self.pos = self.pos[0] + dx, self.pos[1] + dy
        return new_state


In [42]:
# Worked example from the puzzle: 5587 of the first 10,000 bursts infect a node.
test_grid = VirusGrid.from_lines(['..#', '#..', '...'])
assert sum(islice(test_grid, 10000)) == 5587

In [43]:
# Part 1: count how many of the first 10,000 bursts infect a node.
with open('inputs/day22.txt') as puzzle_file:
    grid = VirusGrid.from_lines(puzzle_file)

infections = sum(islice(grid, 10000))
print('Part 1:', infections)

In [50]:
# Worked example for the evolved virus: 2511944 of the first 10,000,000
# bursts leave their node infected.
test_grid = EvolvedVirusGrid.from_lines(['..#', '#..', '...'])
infected_bursts = sum(
    1 for state in islice(test_grid, 10_000_000)
    if state is EvolvedStates.infected
)
assert infected_bursts == 2511944

In [51]:
# Part 2: count how many of the first 10,000,000 bursts leave their node infected.
with open('inputs/day22.txt') as puzzle_file:
    grid = EvolvedVirusGrid.from_lines(puzzle_file)

infections = sum(
    1 for state in islice(grid, 10_000_000)
    if state is EvolvedStates.infected
)
print('Part 2:', infections)

Part 2: 2511416
