In [1]:
import re
import numpy as np

In [2]:
from collections import Counter

In [3]:
# Puzzle input: Before/instruction/After samples, then (after a run of blank
# lines) the test program. Use `with` so the file handle is closed promptly.
with open("inputs/16.input", encoding="utf-8") as input_file:
    data = input_file.read().strip()
# maxsplit=1 keeps the two-way unpack safe even if extra blank lines follow
# the samples; the leftover newlines are stripped when part2 is parsed.
part1, part2 = data.split("\n\n\n", 1)
examples = part1.strip().split("\n\n")

In [4]:
def get_numbers(s):
    """Return every run of decimal digits in ``s`` as an int, in order.

    Signs are not recognized, so "-3" yields [3].
    """
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    return list(map(int, re.findall(r"\d+", s)))

In [5]:
register_functions = {}  # instruction name -> instruction callable


def create_register_function(name, op):
    """Build an instruction from ``op`` and record it under ``name``.

    ``op(a, b, registers)`` computes a value. The returned instruction,
    called as ``instruction(a, b, c, registers)``, stores that value into
    slot ``c`` of a *copy* of ``registers`` and returns the copy, so the
    caller's list is never modified. The register list must have length 4.
    """
    def instruction(a, b, c, registers):
        updated = list(registers)
        assert len(registers) == 4
        updated[c] = op(a, b, registers)
        return updated

    register_functions[name] = instruction
    return instruction

# The 16 instructions. Each lambda gets (A, B, registers) and returns the value
# destined for register C. In the names, an "r" operand is read from a register
# and an "i" operand is used as-is (e.g. gtir compares immediate A with
# register B). Comparison results are stored as 1 or 0.
addr = create_register_function("addr", lambda x, y, regs: regs[x] + regs[y])
addi = create_register_function("addi", lambda x, y, regs: regs[x] + y)
mulr = create_register_function("mulr", lambda x, y, regs: regs[x] * regs[y])
muli = create_register_function("muli", lambda x, y, regs: regs[x] * y)
banr = create_register_function("banr", lambda x, y, regs: regs[x] & regs[y])
bani = create_register_function("bani", lambda x, y, regs: regs[x] & y)
borr = create_register_function("borr", lambda x, y, regs: regs[x] | regs[y])
bori = create_register_function("bori", lambda x, y, regs: regs[x] | y)
setr = create_register_function("setr", lambda x, y, regs: regs[x])
seti = create_register_function("seti", lambda x, y, regs: x)
gtir = create_register_function("gtir", lambda x, y, regs: int(x > regs[y]))
gtri = create_register_function("gtri", lambda x, y, regs: int(regs[x] > y))
gtrr = create_register_function("gtrr", lambda x, y, regs: int(regs[x] > regs[y]))
eqir = create_register_function("eqir", lambda x, y, regs: int(x == regs[y]))
eqri = create_register_function("eqri", lambda x, y, regs: int(regs[x] == y))
eqrr = create_register_function("eqrr", lambda x, y, regs: int(regs[x] == regs[y]))

# Part 1

In [6]:
# For each sample, count how many instructions turn "Before" into "After";
# the answer is the number of samples with at least three candidates.
matching_opcodes_per_sample = []
for example in examples:
    before, op, after = map(get_numbers, example.split("\n"))
    matching_opcodes = sum(
        op_fun(*op[1:], before) == after
        for op_fun in register_functions.values()
    )
    matching_opcodes_per_sample.append(matching_opcodes)

matching = np.array(matching_opcodes_per_sample)
(matching >= 3).sum()

663

# Part 2

In [34]:
# Parse every sample once instead of on every sweep.
parsed_samples = [[get_numbers(s) for s in example.split("\n")] for example in examples]

identified = {}  # instruction name -> opcode number
# Keep sweeping the samples until every instruction has been assigned a number.
while len(identified) < len(register_functions):
    len_before = len(identified)
    known_opcodes = set(identified.values())
    for before, op, after in parsed_samples:
        # A sample whose opcode number is already resolved carries no new
        # information. Examining it anyway (with its true instruction filtered
        # out) could hand that taken number to a coincidentally matching name.
        if op[0] in known_opcodes:
            continue
        opcode_matches = [
            name
            for name, op_fun in register_functions.items()
            if name not in identified and op_fun(*op[1:], before) == after
        ]
        if len(opcode_matches) == 1:  # uniquely identified this opcode
            identified[opcode_matches[0]] = op[0]
            known_opcodes.add(op[0])
    if len(identified) == len_before:
        print("Can not further identify anything!")
        break

In [35]:
# Invert the name -> number table into number -> instruction function.
opcode_func_mapping = {opcode: register_functions[opname] for opname, opcode in identified.items()}
# Fail loudly rather than with a KeyError mid-program: a shortfall means
# identification stalled or two names were given the same opcode number.
assert len(opcode_func_mapping) == len(register_functions), (
    f"only {len(opcode_func_mapping)} of {len(register_functions)} opcodes resolved"
)

In [40]:
# Each program line is "opcode A B C".
ops = [get_numbers(line) for line in part2.strip().split("\n")]

# Execute from an all-zero register file; each instruction returns a new list.
registers = [0] * 4
for op in ops:
    opcode, *operands = op
    registers = opcode_func_mapping[opcode](*operands, registers)

In [42]:
# Part 2 answer: register 0 after the program has run.
part2_answer = registers[0]
print(part2_answer)

525
