In [1]:
from collections import Counter
from pathlib import Path
from typing import List, Tuple, Union
from dataclasses import dataclass
from loguru import logger
from tqdm import tqdm
import copy
import sys

# Puzzle input: one instruction per line (e.g. "acc +1"). Decode explicitly
# instead of relying on the locale-dependent default encoding.
text = Path("input.txt").read_text(encoding="utf-8")

In [2]:
@dataclass
class State:
    """Interpreter state for a program of ``acc`` / ``jmp`` / ``nop`` instructions.

    Fields (all shown in the dataclass repr used by the log messages):
        history: pointers of the instructions executed so far, in order.
        counter: number of instructions executed.
        pointer: index of the next instruction to execute.
        accumulator: value modified by ``acc``.
        terminated: True once the pointer has moved past the last instruction.
        curr: the most recently fetched ``(instr, n)`` pair.

    Passing ``program_text`` to the constructor parses and runs it immediately.
    """
    history: List[int]
    counter: int = 0
    pointer: int = 0
    accumulator: int = 0
    terminated: bool = False
    curr: Tuple[str, int] = ("ready", -1)

    # Instruction names `call` may dispatch to. Unannotated, so it is a plain
    # class attribute rather than a dataclass field (and stays out of the repr).
    OPCODES = ("acc", "jmp", "nop")

    def __init__(self, program_text=None):
        # @dataclass keeps a hand-written __init__; the remaining fields fall
        # back to the class-level defaults declared above.
        self.history = []
        if program_text:
            program = self.parse(program_text)
            self.execute(program)

    @staticmethod
    def parse(program: str) -> List[Tuple[str, int]]:
        """Parse program text into ``(instr, n)`` tuples, one per non-blank line."""
        lines = []
        for line in program.strip().splitlines():
            if not line.strip():
                continue  # tolerate blank lines inside the program text
            instr, n = line.split()
            lines.append((instr, int(n)))
        return lines

    def execute(self, program: List[Tuple[str, int]]):
        """Run `program` from the current pointer until it stops.

        Execution stops when the pointer moves past the last instruction
        (``terminated`` is set), when the pointer becomes negative, or
        immediately *before* an instruction would run a second time, so
        ``accumulator`` holds the value from just before the repeat.
        """
        self.program = program
        # A set gives an O(1) repeat check per step; rebuilding a Counter over
        # the whole history on every step made the run quadratic.
        seen = set(self.history)
        while True:
            if self.pointer >= len(self.program):
                # stop running if we ran out of instructions
                self.terminated = True
                logger.info(f"terminated: {self}")
                break
            if self.pointer < 0:
                # Python would silently wrap a negative index to the end of the list.
                logger.warning(f"jumped to an invalid instruction [line={self.pointer}], aborting!")
                break
            if self.pointer in seen:
                # stop infinite loops *before* re-running the instruction
                logger.warning(f"hitting an instruction [line={self.pointer}] for the second time, aborting!")
                break
            seen.add(self.pointer)
            self.curr = self.program[self.pointer]
            logger.debug(f"running: {self}")
            self.call(*self.curr)

    def call(self, instr, n):
        """Execute a single instruction and record its pointer in ``history``.

        Raises:
            ValueError: if `instr` is not one of ``OPCODES``.
        """
        if instr not in self.OPCODES:
            # Without this check any attribute name (e.g. "execute") could be invoked.
            raise ValueError(f"unknown instruction {instr!r} at line {self.pointer}")
        self.history.append(self.pointer)
        getattr(self, instr)(n)
        self.counter += 1

    def jmp(self, n):
        """Move the pointer by `n` relative to the current instruction."""
        self.pointer += n

    def acc(self, n):
        """Add `n` to the accumulator and advance to the next instruction."""
        self.accumulator += n
        self.pointer += 1

    def nop(self, *args):
        """Ignore the argument and advance to the next instruction."""
        self.pointer += 1

In [3]:
# Part 1: run the puzzle input once; execution stops either when it runs off
# the end of the program or when a repeated instruction is detected.
logger.remove()
logger.add(sys.stderr, level="INFO")
input_program = State.parse(text)
state = State()
state.execute(input_program)
state.accumulator



1610

In [4]:
# Step-by-step DEBUG trace of the example program with instruction 7 (0-based)
# already patched from `jmp -4` to `nop -4`, so it runs to completion.
logger.remove()
logger.add(sys.stderr, level="DEBUG")
example_fixed = State.parse("""
nop +0
acc +1
jmp +4
acc +3
jmp -3
acc -99
acc +1
nop -4
acc +6
""")
state = State()
state.execute(example_fixed)

2020-12-08 23:37:56.882 | DEBUG    | __main__:execute:38 - running: State(history=[], counter=0, pointer=0, accumulator=0, terminated=False, curr=('nop', 0))
2020-12-08 23:37:56.883 | DEBUG    | __main__:execute:38 - running: State(history=[0], counter=1, pointer=1, accumulator=0, terminated=False, curr=('acc', 1))
2020-12-08 23:37:56.883 | DEBUG    | __main__:execute:38 - running: State(history=[0, 1], counter=2, pointer=2, accumulator=1, terminated=False, curr=('jmp', 4))
2020-12-08 23:37:56.884 | DEBUG    | __main__:execute:38 - running: State(history=[0, 1, 2], counter=3, pointer=6, accumulator=1, terminated=False, curr=('acc', 1))
2020-12-08 23:37:56.885 | DEBUG    | __main__:execute:38 - running: State(history=[0, 1, 2, 6], counter=4, pointer=7, accumulator=2, terminated=False, curr=('nop', -4))
2020-12-08 23:37:56.885 | DEBUG    | __main__:execute:38 - running: State(history=[0, 1, 2, 6, 7], counter=5, pointer=8, accumulator=2, terminated=False, curr=('acc', 6))
2020-12-08 23:37

In [5]:
# Drop every handler, then re-add stderr at ERROR level so the many
# candidate runs in the search below do not emit DEBUG/INFO lines.
logger.remove()
logger.add(sink=sys.stderr, level="ERROR")

def flip_nop_jmp(program, i):
    """Return a copy of `program` with instruction `i` swapped between nop and jmp.

    The instruction's argument is kept and `program` itself is not modified.

    Args:
        program: list of ``(instr, n)`` tuples.
        i: index of the instruction to flip; negative indices count from the end.

    Raises:
        ValueError: if instruction `i` is neither ``nop`` nor ``jmp``.
        IndexError: if `i` is out of range.
    """
    # Normalise negative indices; otherwise `program[i + 1:]` would be wrong
    # (e.g. i=-1 gives program[0:], duplicating the whole program).
    i = range(len(program))[i]
    instr, n = program[i]
    swap = {"nop": "jmp", "jmp": "nop"}
    if instr not in swap:
        raise ValueError(f"instruction {i} is {instr!r}; only 'nop' and 'jmp' can be flipped")
    return program[:i] + [(swap[instr], n)] + program[i + 1:]

def find_terminating(program):
    """Find a single nop<->jmp flip that makes `program` run to completion.

    Each ``nop``/``jmp`` instruction is flipped in turn, in program order, and
    the patched program is run on a fresh State.

    Args:
        program: parsed program, a list of ``(instr, n)`` tuples as produced by
            ``State.parse``; it is not modified.

    Returns:
        The first State (in instruction order) whose execution terminated.

    Raises:
        ValueError: if no single flip yields a terminating program.
    """
    idxs = [i for i, (instr, *_rest) in enumerate(program) if instr in ("nop", "jmp")]
    logger.info(f"trying {idxs}")
    for i in tqdm(idxs):
        # flip_nop_jmp builds a new list via slicing and the interpreter only
        # reads the instruction tuples, so no per-candidate deepcopy is needed.
        tweaked = flip_nop_jmp(program, i)
        state = State()
        state.execute(tweaked)
        if state.terminated:
            return state
    raise ValueError("Did not find a terminating program!")
                
# The puzzle's example loops (instruction 7 jumps back into the cycle);
# searching for a terminating flip must yield an accumulator of 8.
example_program = State.parse("""
nop +0
acc +1
jmp +4
acc +3
jmp -3
acc -99
acc +1
jmp -4
acc +6
""")
state = find_terminating(example_program)
assert state.accumulator == 8

 75%|███████▌  | 3/4 [00:00<00:00, 3551.49it/s]


In [6]:
# Part 2: search the puzzle input for the nop/jmp flip that lets it terminate.
input_program = State.parse(text)
state = find_terminating(input_program)

 28%|██▊       | 81/292 [00:00<00:01, 155.58it/s]


In [7]:
# Part 2 answer: accumulator of the first patched program that ran to completion.
state.accumulator

1703