In [6]:
from collections import namedtuple

OpCall = namedtuple('OpCall', 'op param')


class Console:
    """Tiny interpreter for programs made of ``nop``, ``acc`` and ``jmp`` ops.

    State is an accumulator (``acc``) and an instruction pointer (``ptr``);
    ``code`` holds the parsed program as a list of ``OpCall(op, param)``.
    """

    def __init__(self, program: str):
        """Parse *program*: one ``<op> <integer>`` instruction per line.

        Blank lines are skipped, so input with stray empty lines still parses.
        """
        self.acc = 0
        self.ptr = 0
        self.code = []

        for line in program.splitlines():
            fields = line.split()
            if not fields:
                continue
            op, param = fields
            self.code.append(OpCall(op, int(param)))

    def nop(self, param):
        """Do nothing and advance to the next instruction (*param* is ignored)."""
        self.ptr += 1

    def acc_op(self, param):
        """Add *param* to the accumulator and advance to the next instruction."""
        self.acc += param
        self.ptr += 1

    def jmp_op(self, param):
        """Move the pointer by *param* relative to the current instruction."""
        self.ptr += param

    def _terminated(self):
        """Return True once the pointer has left the program (either end)."""
        return not 0 <= self.ptr < len(self.code)

    def _step(self):
        """Execute the single instruction at ``self.ptr``.

        Raises:
            ValueError: if the op is not one of nop/acc/jmp.
        """
        op, param = self.code[self.ptr]
        if op == 'nop':
            self.nop(param)
        elif op == 'acc':
            self.acc_op(param)
        elif op == 'jmp':
            self.jmp_op(param)
        else:
            raise ValueError('unknown op %r at index %d' % (op, self.ptr))

    def run_until_repeat(self):
        """Run from a fresh state until an instruction is about to repeat.

        Returns the accumulator at that moment. If the program runs off
        either end before repeating, the accumulator at that point is
        returned instead.
        """
        self.reset()
        seen_ptrs = set()

        # Bounds check prevents IndexError and, for negative pointers,
        # Python's silent wrap-around indexing.
        while self.ptr not in seen_ptrs and not self._terminated():
            seen_ptrs.add(self.ptr)
            self._step()

        return self.acc

    def run(self):
        """Run the program from a fresh state (used for Part 2).

        Returns an empty list if the program loops, otherwise a one-element
        list containing the final accumulator. The list wrapper keeps a
        zero accumulator truthy-distinguishable from "looped".
        """
        self.reset()
        seen_ptrs = set()

        while self.ptr not in seen_ptrs:
            if self._terminated():
                return [self.acc]
            seen_ptrs.add(self.ptr)
            self._step()

        return []

    def swap_op(self, i):
        """Toggle instruction *i* between ``nop`` and ``jmp`` in place.

        Any other op is written back unchanged.
        """
        op, param = self.code[i]
        if op == 'nop':
            op = 'jmp'
        elif op == 'jmp':
            op = 'nop'
        self.code[i] = OpCall(op, param)

    def find_glitch(self):
        """Find the single nop<->jmp swap that makes the program terminate.

        Candidates are tried in program order. Each swap is undone after its
        trial. Returns the accumulator of the terminating program, or None
        if no single swap makes it terminate.
        """
        for i, (op, _) in enumerate(self.code):
            if op not in ('nop', 'jmp'):
                continue
            self.swap_op(i)
            result = self.run()
            self.swap_op(i)
            if result:
                return result[0]
        return None

    def reset(self):
        """Zero the accumulator and pointer; return self for chaining."""
        self.acc = 0
        self.ptr = 0
        return self

    def _trial_from_here(self, path):
        """Run from the current state until termination or a known dead end.

        *path* holds positions already executed by the unmodified program.
        Reaching one of them means execution leads back to the swapped
        instruction, i.e. a loop.

        Returns the accumulator if the program terminates, else None.
        """
        seen_ptrs = set()
        while self.ptr not in seen_ptrs and self.ptr not in path:
            if self._terminated():
                return self.acc
            seen_ptrs.add(self.ptr)
            self._step()
        return None

    # For post mortem
    def find_glitch_optimised(self):
        """Same result as ``find_glitch``, trying swaps in execution order.

        Walks the unmodified program once. At each not-yet-tried nop/jmp it
        caches (ptr, acc), swaps, and runs a trial from that point. On
        failure it restores the state and continues the unmodified walk.

        Returns the accumulator of the terminating program, or None if the
        unmodified walk itself loops without any swap succeeding.
        ``self.code`` is always left unmodified.
        """
        self.reset()
        swapped = set()  # positions whose swap has already been tried
        path = set()     # positions executed by the unmodified program

        while not self._terminated():
            start = self.ptr
            op, _ = self.code[start]

            if op in ('nop', 'jmp') and start not in swapped:
                swapped.add(start)
                # Cache the state just before the swap.
                saved_acc = self.acc
                self.swap_op(start)
                result = self._trial_from_here(path)
                self.swap_op(start)  # undo the swap whatever the outcome
                if result is not None:
                    return result
                self.ptr = start
                self.acc = saved_acc
            else:
                if start in path:
                    # The unmodified program has looped; no swap left to try.
                    return None
                path.add(start)
                self._step()

        return self.acc
        
                 
# Small example program; its accumulator is 5 just before the first
# instruction would execute a second time.
test_code = (
    'nop +0\n'
    'acc +1\n'
    'jmp +4\n'
    'acc +3\n'
    'jmp -3\n'
    'acc -99\n'
    'acc +1\n'
    'jmp -4\n'
    'acc +6'
)

assert Console(test_code).run_until_repeat() == 5

In [7]:
# Read the puzzle input; the context manager closes the file handle.
with open('input') as input_file:
    code = input_file.read()
Console(code).run_until_repeat()

1818

## Part 2

In [8]:
# Both Part 2 solvers must agree on the worked example.
assert Console(test_code).find_glitch() == 8
assert Console(test_code).find_glitch_optimised() == 8

In [9]:
# Part 2 answer for the puzzle input (brute-force swap search).
Console(program=code).find_glitch()

631

In [11]:
%%timeit
Console(code).find_glitch()

25 ms ± 260 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


### Post mortem

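This is clearly unacceptable performance. The brute-force version tries swaps in code line order and reruns the whole program from the start for each one. What happens if we instead try swaps in order of execution, caching the state (pointer and accumulator) just before each swap so that every trial resumes from there? A trial can also be abandoned as soon as it lands on an instruction the unmodified program has already executed. From there, execution leads straight back to the swapped instruction, so the trial is a loop.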

In [13]:
# Part 2 answer again, via the execution-order search.
Console(program=code).find_glitch_optimised()

631

In [14]:
%%timeit
Console(code).find_glitch_optimised()

6.27 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Much better: roughly a 4× speed-up (about 25 ms down to about 6.3 ms per run).