In [4]:
def read_input(filename):
    """Parse a puzzle input file into the initial registers and the program.

    Reads ``../inputs/{filename}.txt`` relative to the current working
    directory. Expected layout: three ``Register X: <int>`` lines, one
    separator line, then ``Program: <comma-separated ints>``.

    Returns:
        ``(registers, program)`` where ``registers`` is ``[A, B, C]`` and
        ``program`` is a list of ints.
    """
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(f'../inputs/{filename}.txt', 'r') as f:
        # Take everything after the colon; int() ignores surrounding whitespace.
        registers = [int(f.readline().partition(':')[2]) for _ in range(3)]
        f.readline()  # blank separator line
        program = [int(x) for x in f.readline().partition(':')[2].split(',')]
    return (registers, program)

In [32]:
class Computer:
    """Interpreter for a 3-bit machine with registers A, B, C and 8 opcodes.

    ``inst_pt`` is the index of the current opcode; each instruction reads
    the opcode and the operand that follows it. Most instructions advance
    ``inst_pt`` by 2, while ``jnz`` may jump to its operand instead.
    """

    def __init__(self, registers):
        """Create a machine.

        Args:
            registers: initial values of registers A, B and C, in that order.
                The sequence is copied, so the caller's list is never mutated.
        """
        self.inst_pt = 0
        self.registers = list(registers)
        # Dispatch table: opcode -> handler taking the raw operand.
        self._opcode = {
            0: self._adv,
            1: self._bxl,
            2: self._bst,
            3: self._jnz,
            4: self._bxc,
            5: self._out,
            6: self._bdv,
            7: self._cdv,
        }

    def _combo(self, operand):
        """Resolve a combo operand: 0-3 are literals, 4-6 read registers A-C.

        Raises:
            ValueError: for the reserved combo operand 7.
        """
        if operand <= 3:
            return operand
        if operand >= 7:
            raise ValueError(f'combo operand {operand} is reserved')
        return self.registers[operand - 4]

    @staticmethod
    def _trunc_div_pow2(numerator, exponent):
        """Return ``numerator / 2**exponent`` truncated toward zero, exactly.

        Integer shifts avoid the float precision loss of ``/`` for values
        above 2**53, and avoid materialising ``2**exponent`` for huge exponents.
        """
        if exponent < 0:
            # Dividing by 2**-k is multiplying by 2**k; the result is an exact int.
            return numerator << -exponent
        quotient = abs(numerator) >> exponent
        return quotient if numerator >= 0 else -quotient

    def _adv(self, operand):
        """adv: A = trunc(A / 2**combo)."""
        self.registers[0] = self._trunc_div_pow2(self.registers[0], self._combo(operand))
        self.inst_pt += 2

    def _bxl(self, operand):
        """bxl: B = B XOR literal operand."""
        self.registers[1] ^= operand
        self.inst_pt += 2

    def _bst(self, operand):
        """bst: B = combo mod 8."""
        self.registers[1] = self._combo(operand) % 8
        self.inst_pt += 2

    def _jnz(self, operand):
        """jnz: jump to the literal operand when A != 0, else fall through."""
        if self.registers[0] == 0:
            self.inst_pt += 2
        else:
            self.inst_pt = operand

    def _bxc(self, operand):
        """bxc: B = B XOR C (the operand is read but ignored)."""
        self.registers[1] ^= self.registers[2]
        self.inst_pt += 2

    def _out(self, operand):
        """out: return combo mod 8 as the emitted value."""
        value = self._combo(operand) % 8
        self.inst_pt += 2
        return value

    def _bdv(self, operand):
        """bdv: B = trunc(A / 2**combo)."""
        self.registers[1] = self._trunc_div_pow2(self.registers[0], self._combo(operand))
        self.inst_pt += 2

    def _cdv(self, operand):
        """cdv: C = trunc(A / 2**combo)."""
        self.registers[2] = self._trunc_div_pow2(self.registers[0], self._combo(operand))
        self.inst_pt += 2

    def _ex_opcode(self, op, operand):
        """Execute one instruction; return its output value, or None if it emits nothing."""
        return self._opcode[op](operand)

    def execute(self, program):
        """Run ``program`` until it halts and return the list of emitted values.

        The machine halts when the instruction pointer no longer addresses a
        complete opcode/operand pair.
        """
        output = []
        while self.inst_pt + 1 < len(program):
            out = self._ex_opcode(program[self.inst_pt], program[self.inst_pt + 1])
            if out is not None:
                output.append(out)
        return output

In [3]:
def solve1(input_filename):
    """Run the input program from its initial registers.

    Returns:
        The emitted values joined with commas, e.g. ``'4,6,3,5'``.
    """
    registers, program = read_input(input_filename)
    output = Computer(registers).execute(program)
    return ','.join(map(str, output))

## Problem 2

In [33]:
class Computer2(Computer):
    """Computer variant that can stop as soon as the first value is emitted."""

    def first_output(self, program):
        """Run ``program`` until its first emitted value and return it.

        Returns:
            The first output value, or None if the program halts (the
            instruction pointer leaves the program) without emitting anything.
        """
        while self.inst_pt + 1 < len(program):
            out = self._ex_opcode(program[self.inst_pt], program[self.inst_pt + 1])
            if out is not None:
                return out
        return None

In [34]:
def solve2(input_filename):
    """Find the smallest initial value of register A for which the program prints itself.

    This is input-specific. It assumes the k-th emitted value depends only on
    a 10-bit window of the initial A starting at bit 3*k, with B and C
    starting at 0 — TODO confirm against the actual program. Candidates are
    extended 3 bits at a time and pruned against each program digit. The
    smallest candidate that reproduces the program on a full run is returned.

    Returns:
        The smallest verified initial value for register A.

    Raises:
        ValueError: if no candidate reproduces the program.
    """
    registers, program = read_input(input_filename)
    if not program:
        return 0  # nothing to reproduce; A = 0 is trivially minimal

    window = 1024  # 10 bits of A assumed to determine one output digit
    # first_outputs[a] = first value emitted when A starts at a (B = C = 0).
    first_outputs = [Computer2([a, 0, 0]).first_output(program) for a in range(window)]
    candidates = [a for a, out in enumerate(first_outputs) if out == program[0]]

    step = window  # every candidate is < step; the next 3 bits go above it
    for i in range(1, len(program)):
        shift = 3 * i  # digit i reads the window that starts 3*i bits up
        candidates = [
            c
            for r in candidates
            for c in range(r, r + 8 * step, step)
            # Integer shift, not float division: candidates can exceed 2**53.
            if first_outputs[(c >> shift) % window] == program[i]
        ]
        step *= 8

    # Candidates may carry extra high bits that make the program emit more
    # digits than it has, so confirm each one with a full run.
    for a in sorted(candidates):
        if Computer2([a, registers[1], registers[2]]).execute(program) == program:
            return a
    raise ValueError('no initial value of register A reproduces the program')