In [None]:
import math
from pathlib import Path

In [None]:
# Puzzle input: one hexadecimal transmission string, with surrounding
# whitespace (e.g. the trailing newline) stripped. The relative path is
# resolved against the kernel's current working directory.
input_1 = Path("input_1.txt").read_text().strip()

In [None]:
def hex_to_bin(hex_values):
    """Expand a hexadecimal string into a binary string, 4 bits per hex digit.

    Every digit is converted on its own and zero-padded to 4 bits, so leading
    zeros survive (e.g. "0a" -> "00001010"). An empty input yields "".
    """
    nibbles = []
    for digit in hex_values:
        nibbles.append(format(int(digit, 16), "04b"))
    return "".join(nibbles)

def read_packet_version(bits):
    """Consume the 3-bit version field at the front of a packet.

    Returns:
        (version, remaining_bits)
    """
    version_field, remainder = bits[:3], bits[3:]
    return int(version_field, 2), remainder

def read_packet_type(bits):
    """Consume the 3-bit type-ID field that follows the version field.

    Returns:
        (type_id, remaining_bits)
    """
    type_field = bits[:3]
    tail = bits[3:]
    return int(type_field, 2), tail

def read_literal_value(bits):
    """Decode a literal value stored as a run of 5-bit groups.

    Groups are read with ``read_literal_value_group`` until one reports it is
    the last group; the 4 payload bits of every group are concatenated and
    interpreted as one binary number.

    Returns:
        (value, remaining_bits)
    """
    payload_chunks = []
    remainder = bits
    while True:
        is_last, chunk, remainder = read_literal_value_group(remainder)
        payload_chunks.append(chunk)
        if is_last:
            break
    return int("".join(payload_chunks), 2), remainder

def read_literal_value_group(bits):
    """Split one 5-bit literal group off the front of ``bits``.

    The first bit is a continuation flag: 0 marks the final group, 1 means
    more groups follow. The next 4 bits are payload.

    Returns:
        (is_last_group, four_payload_bits, remaining_bits)
    """
    flag_bit, payload, remainder = bits[0], bits[1:5], bits[5:]
    return int(flag_bit, 2) == 0, payload, remainder

def read_length_type(bits):
    """Parse the length-type ID and length field of an operator packet.

    The first bit selects how the operator's sub-packets are delimited:
      * 0 -> the next 15 bits give the total bit length of the sub-packets;
      * 1 -> the next 11 bits give the number of direct sub-packets.

    Returns:
        (length_type_id, value_length, remaining_bits)
    """
    length_type_id, rest = int(bits[:1], 2), bits[1:]
    # A single-bit field can only be 0 or 1, so an if/else covers every case
    # and guarantees value_length is always bound.
    field_width = 15 if length_type_id == 0 else 11
    value_length, rest = int(rest[:field_width], 2), rest[field_width:]
    return length_type_id, value_length, rest

def read_packet(bits):
    """Decode one packet (and, recursively, its sub-packets) from ``bits``.

    Type ID 4 is a literal value; every other type ID is an operator whose
    value is a tuple of ``(version, type_id, value)`` sub-packet triples.

    Returns:
        (version, type_id, value, remaining_bits)
    """
    version, remainder = read_packet_version(bits)
    type_id, remainder = read_packet_type(remainder)
    payload_reader = read_literal_value if type_id == 4 else read_operator
    payload, remainder = payload_reader(remainder)
    return version, type_id, payload, remainder

def read_operator(bits):
    """Read the sub-packets of an operator packet.

    The length-type header decides whether the sub-packets are delimited by a
    total bit length (type 0) or by a packet count (type 1).

    Returns:
        (tuple of (version, type_id, value) sub-packets, remaining_bits)
    """
    length_type, value_length, remainder = read_length_type(bits)
    sub_packet_reader = {
        0: read_packets_by_bit_length,
        1: read_packets_by_group_count,
    }[length_type]
    sub_packets, remainder = sub_packet_reader(remainder, value_length)
    return sub_packets, remainder

def read_packets_by_bit_length(bits, length):
    """Read consecutive sub-packets until ``length`` bits have been consumed.

    Args:
        bits: binary string positioned at the first sub-packet.
        length: total number of bits occupied by the sub-packets.

    Returns:
        (tuple of (version, type_id, value) sub-packets, remaining_bits).
        When ``length`` is 0 no packet is read and ``bits`` is returned as-is.
    """
    packets = []
    consumed = 0
    while consumed < length:
        packet_version, packet_type, value, rest = read_packet(bits)
        packets.append((packet_version, packet_type, value))
        consumed += len(bits) - len(rest)
        bits = rest
    # Return `bits` rather than the loop-local `rest`: with no sub-packets the
    # loop never runs and `rest` would be unbound (UnboundLocalError).
    return tuple(packets), bits

def read_packets_by_group_count(bits, count):
    """Read exactly ``count`` consecutive sub-packets from ``bits``.

    Args:
        bits: binary string positioned at the first sub-packet.
        count: number of sub-packets to read.

    Returns:
        (tuple of (version, type_id, value) sub-packets, remaining_bits).
        When ``count`` is 0 no packet is read and ``bits`` is returned as-is.
    """
    packets = []
    for _ in range(count):
        packet_version, packet_type, value, rest = read_packet(bits)
        packets.append((packet_version, packet_type, value))
        bits = rest
    # Return `bits` rather than the loop-local `rest`: with count == 0 the loop
    # never runs and `rest` would be unbound (UnboundLocalError).
    return tuple(packets), bits

def packet_versions(packet, versions=None):
    """Collect the version numbers of a packet and all of its sub-packets.

    Versions are listed depth-first, pre-order: the packet's own version,
    then those of each sub-packet in order.

    Args:
        packet: a ``(version, type_id, value)`` triple, i.e. ``read_packet``'s
            result without the trailing remainder. For type 4 (literal)
            ``value`` is an int; otherwise it is a sequence of sub-packet
            triples.
        versions: ignored; kept only so existing callers that pass it keep
            working.

    Returns:
        list[int] of version numbers.
    """
    version, type_id, payload = packet
    collected = [version]
    if type_id != 4:
        for sub_packet in payload:
            collected.extend(packet_versions(sub_packet))
    return collected

def packet_value(packet):
    """Evaluate a decoded packet tree as an expression.

    Type IDs:
      0 sum, 1 product, 2 minimum, 3 maximum  (over all sub-packets);
      4 literal value;
      5 greater-than, 6 less-than, 7 equal-to (exactly two sub-packets),
        each evaluating to the integer 1 if true, else 0.

    Args:
        packet: a ``(version, type_id, value)`` triple; the version is ignored.

    Returns:
        int value of the expression.

    Raises:
        NotImplementedError: for an unknown type ID.
    """
    _, packet_type, value = packet

    if packet_type == 4:
        return value

    if packet_type in (0, 1, 2, 3):
        reducer = (sum, math.prod, min, max)[packet_type]
        return reducer(packet_value(sub_packet) for sub_packet in value)

    if packet_type in (5, 6, 7):
        first, second = value
        left, right = packet_value(first), packet_value(second)
        if packet_type == 5:
            outcome = left > right
        elif packet_type == 6:
            outcome = left < right
        else:
            outcome = left == right
        # Comparison packets are defined to yield 1 or 0; convert the bool so
        # the result is always a plain int (displays as 1/0, not True/False).
        return int(outcome)

    raise NotImplementedError(f"Operator '{packet_type}' not implemented.")

In [None]:
# Part 1 - Test
# Worked examples from the puzzle statement. Each parser step is checked on
# its own before whole packets are decoded.

# literal value packet (type 4)
assert hex_to_bin("D2FE28") == "110100101111111000101000"
assert read_packet_version("110100101111111000101000") == (6, "100101111111000101000")
assert read_packet_type("100101111111000101000") == (4, "101111111000101000")
assert read_literal_value_group("101111111000101000") == (False, "0111", "1111000101000")
assert read_literal_value("101111111000101000") == (2021, "000")
assert read_packet(hex_to_bin("D2FE28")) == (6, 4, 2021, "000")

# operator packet, length type 0: 15-bit total bit length of sub-packets
assert hex_to_bin("38006F45291200") == "00111000000000000110111101000101001010010001001000000000"
assert read_packet_version("00111000000000000110111101000101001010010001001000000000") == (1, "11000000000000110111101000101001010010001001000000000")
assert read_packet_type("11000000000000110111101000101001010010001001000000000") == (6, "00000000000110111101000101001010010001001000000000")
assert read_length_type("00000000000110111101000101001010010001001000000000") == (0, 27, "1101000101001010010001001000000000")
assert read_packets_by_bit_length("1101000101001010010001001000000000", 27) == (((6, 4, 10), (2, 4, 20)), "0000000")
assert read_packet(hex_to_bin("38006F45291200")) == (1, 6, ((6, 4, 10), (2, 4, 20)), "0000000")

# operator packet, length type 1: 11-bit count of sub-packets
assert hex_to_bin("EE00D40C823060") == "11101110000000001101010000001100100000100011000001100000"
assert read_packet_version("11101110000000001101010000001100100000100011000001100000") == (7, "01110000000001101010000001100100000100011000001100000")
assert read_packet_type("01110000000001101010000001100100000100011000001100000") == (3, "10000000001101010000001100100000100011000001100000")
assert read_length_type("10000000001101010000001100100000100011000001100000") == (1, 3, "01010000001100100000100011000001100000")
assert read_packets_by_group_count("01010000001100100000100011000001100000", 3) == (((2, 4, 1), (4, 4, 2), (1, 4, 3)), '00000')
assert read_packet(hex_to_bin("EE00D40C823060")) == (7, 3, ((2, 4, 1), (4, 4, 2), (1, 4, 3)), '00000')

# Version sums for nested examples. `[:-1]` drops the unparsed trailing bits so
# packet_versions receives a (version, type_id, value) triple.
assert sum(packet_versions(read_packet(hex_to_bin("8A004A801A8002F478"))[:-1])) == 16
assert sum(packet_versions(read_packet(hex_to_bin("620080001611562C8802118E34"))[:-1])) == 12
assert sum(packet_versions(read_packet(hex_to_bin("C0015000016115A2E0802F182340"))[:-1])) == 23
assert sum(packet_versions(read_packet(hex_to_bin("A0016C880162017C3686B18A3D4780"))[:-1])) == 31

In [None]:
# Part 1
# Sum of every version number in the decoded packet hierarchy; `[:-1]` drops
# the unparsed trailing bits before walking the tree.
transmission = hex_to_bin(input_1)
sum(packet_versions(read_packet(transmission)[:-1]))

In [None]:
# Part 2 - Test
# Expression examples from the puzzle statement; `[:-1]` drops the unparsed
# trailing bits before evaluating the packet tree.
assert packet_value(read_packet(hex_to_bin("C200B40A82"))[:-1]) == 3  # sum: 1 + 2
assert packet_value(read_packet(hex_to_bin("04005AC33890"))[:-1]) == 54  # product: 6 * 9
assert packet_value(read_packet(hex_to_bin("880086C3E88112"))[:-1]) == 7  # min(7, 8, 9)
assert packet_value(read_packet(hex_to_bin("CE00C43D881120"))[:-1]) == 9  # max(7, 8, 9)
assert packet_value(read_packet(hex_to_bin("D8005AC2A8F0"))[:-1]) == 1  # less-than: 5 < 15
assert packet_value(read_packet(hex_to_bin("F600BC2D8F"))[:-1]) == 0  # greater-than: 5 > 15
assert packet_value(read_packet(hex_to_bin("9C005AC2F8F0"))[:-1]) == 0  # equal-to: 5 == 15
assert packet_value(read_packet(hex_to_bin("9C0141080250320F1802104A08"))[:-1]) == 1  # equal-to: 1 + 3 == 2 * 2

In [None]:
# Part 2
# Evaluate the whole transmission as an expression tree; `[:-1]` drops the
# unparsed trailing bits.
transmission = hex_to_bin(input_1)
packet_value(read_packet(transmission)[:-1])