In [None]:
from aocd import data, models, submit
from io import StringIO
from pathlib import Path
import re

import pandas as pd

# Load data and examples

In [None]:
puzzle_year = 2024
# The puzzle day is read from the name of the current working directory,
# which has to start with "day" followed by the day number (e.g. "day17").
_day_match = re.match(r"day(\d+)", Path.cwd().name)
if _day_match is None:
    # re.match returns None when the name does not fit; fail with a message
    # that says what is wrong instead of an AttributeError on None.
    raise ValueError(
        f"Cannot infer the puzzle day: working directory {Path.cwd().name!r} "
        "does not start with 'day<number>'"
    )
puzzle_day = int(_day_match.group(1))

In [None]:
# Build the aocd Puzzle object for the configured year and day.
todays_puzzle = models.Puzzle(year=puzzle_year, day=puzzle_day)
# The puzzle's examples as exposed by aocd; later cells iterate over this list.
todays_examples = todays_puzzle.examples
# NOTE: this rebinds `data`, replacing the `data` name imported from aocd in
# the import cell with this puzzle's input text.
data = todays_puzzle.input_data

In [None]:
# Display the loaded examples as this cell's output.
todays_examples

# Part A

In [None]:
def combo_operand(operand: int, register: list) -> int | None:
    if 0 <= operand <= 3:
        return operand
    elif 4 <= operand <= 7:
        return register[operand - 4]
    else:
        return None

In [None]:
def run_program(registers: list, program: list, should_clone=False):
    """Execute ``program`` on the three-register machine and collect its output.

    ``program`` is a flat list of ``opcode, operand`` pairs and ``registers``
    holds ``[A, B, C]``; the register list is updated in place while the
    program runs.  Execution stops as soon as the instruction pointer no
    longer addresses a complete opcode/operand pair.

    When ``should_clone`` is true the run is abandoned (returning an empty
    string) as soon as more values were emitted than ``program`` has entries,
    or the most recently emitted value differs from ``program`` at the same
    position.

    Returns the list of values emitted by the ``out`` instruction (opcode 5).
    """
    emitted = []
    # The pointer must leave room for the operand that follows the opcode.
    pair_limit = len(program) - 1

    def a_divided_by_power(raw_operand: int) -> int:
        # Shared by adv/bdv/cdv: A floor-divided by 2 ** (combo operand).
        return registers[0] // (2 ** combo_operand(raw_operand, registers))

    pointer = 0
    while 0 <= pointer < pair_limit:
        operand = program[pointer + 1]
        opcode = program[pointer]
        following = pointer + 2  # default: step to the next pair

        if opcode == 0:  # adv: result goes to A
            registers[0] = a_divided_by_power(operand)
        elif opcode == 6:  # bdv: result goes to B
            registers[1] = a_divided_by_power(operand)
        elif opcode == 7:  # cdv: result goes to C
            registers[2] = a_divided_by_power(operand)
        elif opcode == 1:  # bxl: B xor literal operand
            registers[1] = registers[1] ^ operand
        elif opcode == 4:  # bxc: B xor C (operand ignored)
            registers[1] ^= registers[2]
        elif opcode == 2:  # bst: low three bits of the combo operand
            registers[1] = combo_operand(operand, registers) % 8
        elif opcode == 5:  # out: emit low three bits of the combo operand
            emitted.append(combo_operand(operand, registers) % 8)
        elif opcode == 3 and registers[0] != 0:  # jnz: jump unless A is zero
            following = operand

        pointer = following

        if should_clone and emitted:
            produced = len(emitted)
            # Length is tested first so the index below is always valid.
            if produced > len(program) or emitted[-1] != program[produced - 1]:
                return ""
    return emitted

In [None]:
def part_a(data: str) -> str:
    """Run the program described in ``data`` and return its output.

    ``data`` must contain ``Register A/B/C: <number>`` lines and a
    ``Program: <comma separated numbers>`` line.  The emitted values are
    returned joined by commas.
    """
    register_patterns = (
        r"Register A: (\d+)",
        r"Register B: (\d+)",
        r"Register C: (\d+)",
    )
    # First match of each pattern, in A, B, C order.
    registers = [int(re.findall(pattern, data)[0]) for pattern in register_patterns]
    program_text = re.findall(r"Program: (.*)", data)[0]
    program = list(map(int, program_text.split(",")))
    emitted = run_program(registers, program)
    return ",".join(map(str, emitted))

In [None]:
# Check part A against every example that carries an expected answer, then
# submit the answer for the real input.
for example_index, example in enumerate(todays_examples):
    # A missing answer can be None (the Part B loop tests for that) or ""
    # (used by the hand-written Part B example); skip both.
    if not example.answer_a:
        continue
    # Compute once and reuse the value for both the report and the check.
    example_answer = part_a(str(example.input_data))
    print(
        f"Example {example_index} part a: {example_answer} (expected {example.answer_a})"
    )
    assert example_answer == example.answer_a
submit(part_a(data), part="a", year=puzzle_year, day=puzzle_day)

# Part B

In [None]:
from aocd.examples import Example

In [None]:
# Hand-written example for Part B; answer_a is left empty so the Part A check
# skips it.
part_b_example = Example(
    input_data="""Register A: 2024
Register B: 0
Register C: 0

Program: 0,3,5,4,3,0""",
    answer_a="",
    answer_b="117440",
)
# Only append once so that re-running this cell does not duplicate the example.
if part_b_example not in todays_examples:
    todays_examples.append(part_b_example)

In [None]:
def find_lowest_register(program):
    """Find the lowest initial register A for which ``program`` outputs itself.

    The puzzle input this was written for does the following on every loop
    iteration:

    1) B = A % 8
    2) B' = B xor 3
    3) C = A // 2**B'  (== A >> B')
    4) A' = A // 2**3  (== A >> 3)
    5) B'' = B' xor C
    6) B''' = B'' xor 5
    7) output B''' % 8

    Every iteration drops the lowest three bits of A, and the value it prints
    depends only on those bits and on higher bits of A.  The last outputs are
    therefore fixed by the most significant bits of A, so A can be rebuilt one
    three-bit digit at a time, starting from the top and matching the program
    from its end towards its start.

    NOTE: the search assumes a program of that shape (A shifted right by
    exactly three bits and one value emitted per iteration).

    The search is depth-first and tries digits in ascending order.  A digit is
    kept only when running the program on the value built so far prints
    exactly ``program[index:]``; otherwise the search backtracks.  The first
    complete value reached this way is returned.

    Args:
        program: the program as a list of integers.

    Returns:
        The register A value found; ``0`` for an empty program.

    Raises:
        ValueError: if no value reproduces the program.
    """
    # Stack of (value built so far, program index still to be matched).
    pending = [(0, len(program) - 1)]
    while pending:
        prefix, index = pending.pop()
        if index < 0:
            # Every position has been matched: prefix is the full register value.
            return prefix
        wanted_tail = program[index:]
        # Compare the whole tail rather than just the first output: this
        # rejects values whose run is shorter or longer than the tail (for
        # example because of leading zero digits) and is safe when a run
        # prints nothing at all.
        matching = [
            candidate
            for candidate in ((prefix << 3) + digit for digit in range(8))
            if run_program([candidate, 0, 0], program) == wanted_tail
        ]
        # Push in descending order so the smallest candidate is explored first.
        pending.extend((candidate, index - 1) for candidate in reversed(matching))
    raise ValueError("no register A value makes the program output itself")

In [None]:
def part_b(data: str) -> str:
    """Return, as a string, the register A value that makes the program print itself.

    Only the ``Program: ...`` line of ``data`` is used.  The value found by
    ``find_lowest_register`` is verified by running the program with
    B = C = 0 before it is returned.
    """
    program_text = re.findall(r"Program: (.*)", data)[0]
    program = list(map(int, program_text.split(",")))
    candidate = find_lowest_register(program)

    # Sanity check: the candidate must really reproduce the program.
    assert run_program([candidate, 0, 0], program) == program
    return str(candidate)

In [None]:
# Display the examples again; the list now also holds the Part B example
# appended in the cell above.
todays_examples

In [None]:
# Check part B against every example that carries an expected answer, then
# submit the answer for the real input.
for example_index, example in enumerate(todays_examples):
    # Skip examples without a part B answer, whether it is None or "".
    if not example.answer_b:
        continue
    # Compute once and reuse the value for both the report and the check.
    example_answer = part_b(str(example.input_data))
    print(
        f"Example {example_index} part b: {example_answer} (expected {example.answer_b})"
    )
    assert example_answer == example.answer_b
submit(part_b(data), part="b", year=puzzle_year, day=puzzle_day)