In [163]:
import re
from collections import namedtuple
from operator import mul

from math import prod

from aoc import submit

# Puzzle day number, passed to the `submit` decorator as `day=DAY`.
DAY = 16

In [170]:
# Decoded BITS packet: `version` and `type_id` are the 3-bit header fields as
# ints; `content` is the int value for literal packets (type_id 4) and the
# list of child Packets for operator packets.
Packet = namedtuple('Packet', 'version type_id content')


def parse_input(raw):
    """Convert raw puzzle input into one bit string per transmission.

    Each non-blank line of ``raw`` is a hexadecimal transmission and is
    converted with ``hex_to_bits``. Lines are stripped of surrounding
    whitespace, and blank lines (e.g. from extra trailing newlines) are
    skipped. Stray whitespace therefore cannot corrupt the conversion, and an
    empty line cannot crash it.

    Returns:
        A list of bit strings, in input order.
    """
    stripped = (line.strip() for line in raw.splitlines())
    return [hex_to_bits(line) for line in stripped if line]


def hex_to_bits(hex_str):
    """Convert a hexadecimal string to a bit string of '0'/'1' characters.

    Each hex digit expands to exactly 4 bits, so leading zero digits are
    preserved (``'0A'`` -> ``'00001010'``). Surrounding whitespace is ignored.
    An empty (or all-whitespace) string yields ``''``.

    Raises:
        ValueError: if the stripped string contains a non-hex character
            (including a ``0x`` prefix, underscores or inner spaces).
    """
    # Expand digit by digit. Deriving the width from len(hex_str) would count
    # surrounding whitespace, which int() silently accepts, as extra digits
    # and prepend spurious zero bits.
    return ''.join(f'{int(digit, 16):04b}' for digit in hex_str.strip())


def split(arr, *at):
    """Lazily cut ``arr`` into consecutive slices at the given indices.

    ``split(s, i, j)`` yields ``s[:i]``, ``s[i:j]`` and ``s[j:]``. With no
    indices, it yields ``arr[0:]`` once.
    """
    # Pair each boundary with the next; None as the final bound means "to the end".
    bounds = (0, *at, None)
    for lo, hi in zip(bounds, bounds[1:]):
        yield arr[lo:hi]


def literal(bits, acc=''):
    """Decode a literal-value payload from the front of ``bits``.

    The payload is a sequence of 5-bit groups. Each group has a flag bit
    ('1' = more groups follow, '0' = this is the last group) followed by
    4 value bits. The value bits of all groups are appended to ``acc`` and
    read as an unsigned binary integer.

    Args:
        bits: Bit string starting at the first group.
        acc: Value bits already accumulated, prepended to the decoded bits.

    Returns:
        ``(value, rest)``, where ``rest`` is the unconsumed tail of ``bits``.

    Raises:
        ValueError: if ``bits`` ends before a complete final group is read.
    """
    # Loop with an index instead of recursing. Python has no tail-call
    # elimination, so one frame per group would hit the recursion limit on
    # long literals. Indexing also avoids copying the remaining bits per group.
    chunks = [acc]
    pos = 0
    while True:
        group = bits[pos:pos + 5]
        if len(group) < 5:
            raise ValueError(f'truncated literal: incomplete group at bit {pos}')
        chunks.append(group[1:])
        pos += 5
        if not int(group[0], 2):
            return int(''.join(chunks), 2), bits[pos:]


def operator(bits):
    """Decode the sub-packets of an operator packet from the front of ``bits``.

    The first bit is the length type ID:

    * '1': the next 11 bits give the number of immediate sub-packets.
    * '0': the next 15 bits give the total length, in bits, of the
      sub-packets.

    Returns:
        ``(children, rest)``: the list of decoded sub-packets and the
        unconsumed tail of ``bits``.

    Raises:
        ValueError: if the length type ID is missing or not a bit, or if
            ``bits`` is too short for the declared count or length.
    """
    head, bits = split(bits, 1)
    if head == '1':
        count, bits = split(bits, 11)
        if len(count) < 11:
            raise ValueError('truncated operator packet: incomplete sub-packet count')
        expected = int(count, 2)
        children, rest = unpack_all(bits, expected)
        # unpack_all stops early when the bits run out; don't silently accept that.
        if len(children) < expected:
            raise ValueError(
                f'operator packet declares {expected} sub-packets, found {len(children)}'
            )
        return children, rest
    if head == '0':
        length, bits = split(bits, 15)
        if len(length) < 15:
            raise ValueError('truncated operator packet: incomplete sub-packet length')
        size = int(length, 2)
        if len(bits) < size:
            raise ValueError(
                f'operator packet declares {size} sub-packet bits, only {len(bits)} remain'
            )
        body, rest = split(bits, size)
        return unpack_all(body)[0], rest
    raise ValueError(f'invalid operator length type ID: {head!r}')


def unpack_all(bits, n=None):
    """Decode consecutive packets from the front of ``bits``.

    Decoding stops when ``bits`` is exhausted or, if ``n`` is given, when
    ``n`` packets have been decoded, whichever comes first.

    Returns:
        ``(packets, rest)``: the decoded packets in order and the unconsumed
        tail of ``bits``.
    """
    packets = []
    remaining = bits
    while remaining:
        if n is not None and len(packets) >= n:
            break
        packet, remaining = unpack(remaining)
        packets.append(packet)
    return packets, remaining


def unpack(bits):
    version, type_id, bits = split(bits, 3, 6)
    version, type_id = int(version, 2), int(type_id, 2)

    match int(type_id):
        case 4:
            content, bits = literal(bits)
        case _:
            content, bits = operator(bits)

    return Packet(version, type_id, content), bits


def add_versions(packet: Packet):
    """Return the sum of the version numbers of ``packet`` and all its descendants.

    A packet whose content is not iterable (a literal's int value) contributes
    only its own version. Otherwise the content is iterated as child packets,
    which are summed recursively.
    """
    try:
        children = iter(packet.content)
    except TypeError:
        # Non-iterable content: a leaf packet.
        return packet.version
    return packet.version + sum(add_versions(child) for child in children)


@submit(day=DAY)
def part_one(raw):
    """Sum the version numbers of every packet in every transmission of ``raw``."""
    return sum(
        add_versions(root)
        for root, _rest in map(unpack, parse_input(raw))
    )


part_one:
✅ example: 82             (0.15 ms)
✅ input:   1007           (1.85 ms)


In [165]:
def evaluate(packet: Packet):
    try:
        iterator = iter(packet.content)
    except TypeError:
        return packet.content
    else:
        values = [evaluate(child) for child in iterator]
        match packet.type_id:
            case 0:
                return sum(values)
            case 1:
                return prod(values)
            case 2:
                return min(values)
            case 3:
                return max(values)
            case 5:
                return values[0] > values[1]
            case 6:
                return values[0] < values[1]
            case 7:
                return values[0] == values[1]


@submit(day=DAY, skip_example=True)
def part_two(raw):
    """Evaluate the expression encoded by the first transmission in ``raw``."""
    first_transmission = parse_input(raw)[0]
    root = unpack(first_transmission)[0]
    return evaluate(root)

part_two:
✅ input:   834151779165   (1.79 ms)
