In [None]:
import io
import math
import operator
import os

from dataclasses import dataclass

In [None]:
@dataclass
class PacketBase:
    """Header fields shared by every packet.

    ``get_packet`` fills both fields from two consecutive 3-bit reads.
    """
    # First 3-bit header field; ``version_sum`` adds these up across the tree.
    version: int
    # Second 3-bit header field; ``get_packet`` treats 4 as a literal and
    # every other value as an operator.
    type_id: int

@dataclass
class Literal(PacketBase):
    """Packet that carries a single integer and has no sub-packets."""
    # Integer payload; ``get_packet`` obtains it from ``get_literal_value``.
    value: int

    def version_sum(self) -> int:
        """Return this packet's own version (there are no children to add)."""
        return self.version
        
@dataclass
class Operator(PacketBase):
    """Packet that contains other packets."""
    # Child packets, in the order ``get_packet`` parsed them.
    sub_packets: list

    def version_sum(self):
        """Return this packet's version plus the version sums of all children."""
        child_sums = [child.version_sum() for child in self.sub_packets]
        return self.version + sum(child_sums)

In [None]:
def get_bit_buffer(hex_string):
    """Expand a hexadecimal string into a readable stream of "0"/"1" characters.

    Every hex digit becomes exactly four bit characters (zero-padded on the
    left), so leading zero bits are preserved.

    Returns:
        An ``io.StringIO`` over the bit string, positioned at its start.
    """
    nibbles = [format(int(hex_digit, 16), "04b") for hex_digit in hex_string]
    return io.StringIO("".join(nibbles))

In [None]:
def read_int(buffer, size):
    """Read ``size`` bit characters from ``buffer`` and return them as an int.

    Args:
        buffer: Text stream of "0"/"1" characters (e.g. from ``get_bit_buffer``).
        size: Number of bit characters to consume.

    Returns:
        The consumed bits interpreted as an unsigned base-2 number.

    Raises:
        ValueError: If fewer than ``size`` characters remain in the buffer, or
            if the characters read are not valid base-2 digits.
    """
    bits = buffer.read(size)
    # ``read`` returns a shorter string at end-of-stream instead of failing;
    # decoding that would silently yield a wrong (smaller) number.
    # ``None`` and negative sizes mean "read everything", so they are exempt.
    if size is not None and len(bits) < size:
        raise ValueError(
            f"Expected {size} bits but only {len(bits)} were available"
        )
    return int(bits, 2)

In [None]:
def get_literal_value(buf):
    """Decode a literal payload from ``buf`` and return it as an int.

    The payload is a sequence of 5-character groups: one continuation flag
    followed by four value bits. A flag of "0" marks the final group. The
    value bits of all groups are concatenated and parsed as base 2.

    Raises:
        ValueError: If the buffer ends before a complete final group was read.
    """
    digits = []
    while True:
        prefix = buf.read(1)
        digit = buf.read(4)
        # At end-of-stream ``read`` returns "" / a short string. Without this
        # check the flag would never equal "0" and the loop would never end.
        if not prefix or len(digit) != 4:
            raise ValueError(
                "Literal value is truncated: ran out of bits before the final group"
            )
        digits.append(digit)
        if prefix == "0":
            break

    return int("".join(digits), 2)

In [None]:
def is_empty(buffer):
    """Return True when the read position of ``buffer`` is at its end.

    The position is restored before returning, so no input is consumed.
    """
    current = buffer.tell()
    # ``seek`` returns the new absolute position, i.e. the end offset here.
    end = buffer.seek(0, os.SEEK_END)
    buffer.seek(current)
    return end == current

In [None]:
def get_packet(buf):
    """Parse one packet, including all nested sub-packets, from a bit buffer.

    Layout as read by this function:
      * 3 bits version, 3 bits type id.
      * Type id 4: a literal payload (see ``get_literal_value``).
      * Any other type id: 1 bit "length type id", then either
          - 0: a 15-bit total length in bits of the sub-packets that follow, or
          - 1: an 11-bit count of the sub-packets that follow.

    Args:
        buf: Text stream of "0"/"1" characters positioned at a packet start.

    Returns:
        A ``Literal`` or an ``Operator`` describing the parsed packet.

    Raises:
        ValueError: If the declared sub-packet bit length exceeds the number
            of bits remaining in ``buf``.
    """
    version = read_int(buf, 3)
    type_id = read_int(buf, 3)

    if type_id == 4:
        return Literal(version, type_id, get_literal_value(buf))

    length_type_id = read_int(buf, 1)

    if length_type_id == 0:
        # Bit-length mode: carve out exactly the declared span and parse
        # packets from it until that span is used up.
        sub_packets_length = read_int(buf, 15)
        payload = buf.read(sub_packets_length)
        # ``read`` silently returns fewer characters at end-of-stream; a short
        # payload means the transmission is truncated.
        if len(payload) != sub_packets_length:
            raise ValueError(
                f"Sub-packet section declares {sub_packets_length} bits "
                f"but only {len(payload)} remain"
            )
        sub_buffer = io.StringIO(payload)
        sub_packets = []
        while not is_empty(sub_buffer):
            sub_packets.append(get_packet(sub_buffer))
    else:
        # Count mode: the sub-packets are read straight from the parent buffer.
        num_sub_packets = read_int(buf, 11)
        sub_packets = [get_packet(buf) for _ in range(num_sub_packets)]

    return Operator(version, type_id, sub_packets)

# Part 1

In [None]:
# Run some tests from the given examples (hex transmission -> expected version sum)
version_sum_examples = {
    "8A004A801A8002F478": 16,
    "620080001611562C8802118E34": 12,
    "C0015000016115A2E0802F182340": 23,
    "A0016C880162017C3686B18A3D4780": 31,
}

for example_hex, expected_sum in version_sum_examples.items():
    actual_sum = get_packet(get_bit_buffer(example_hex)).version_sum()
    # The message names the failing example, which a bare assert would not do.
    assert actual_sum == expected_sum, (
        f"{example_hex}: expected version sum {expected_sum}, got {actual_sum}"
    )

In [None]:
# Read the puzzle input, strip surrounding whitespace, and expand the hex
# text into a bit buffer.
with open("day16.input") as file:
    buffer = get_bit_buffer(file.read().strip())

# My answer: the version sum of the outermost packet and everything nested in it.
# Left as a bare expression so the notebook displays it as the cell output.
get_packet(buffer).version_sum()

# Part 2

We define a new version of the `Operator` packet class with an extra `value` property.

In [None]:
@dataclass
class Operator(PacketBase):
    """Operator packet that can also evaluate itself.

    Defining this class rebinds the name ``Operator``; ``get_packet`` looks
    the name up when it is called, so packets parsed afterwards are instances
    of this version.
    """
    # Child packets, in the order ``get_packet`` parsed them.
    sub_packets: list

    def version_sum(self):
        """Return this packet's version plus the version sums of all children."""
        child_sums = [child.version_sum() for child in self.sub_packets]
        return self.version + sum(child_sums)

    @property
    def value(self):
        """Evaluate this packet from the values of its sub-packets.

        Type ids 0-3 fold all sub-packet values (sum, product, min, max).
        Type ids 5-7 pass the sub-packet values as the arguments of a
        comparison (greater-than, less-than, equal) and return the result
        converted to ``int``.

        Raises:
            ValueError: If ``type_id`` is not one of the handled ids.
        """
        reducers = {0: sum, 1: math.prod, 2: min, 3: max}
        comparators = {5: operator.gt, 6: operator.lt, 7: operator.eq}
        # Lazy on purpose: nothing is evaluated for an unhandled type id.
        operands = (child.value for child in self.sub_packets)

        if self.type_id in reducers:
            return reducers[self.type_id](operands)
        if self.type_id in comparators:
            return int(comparators[self.type_id](*operands))

        raise ValueError(f"Unhandled type_id: {self.type_id}")

In [None]:
# Run some tests from the given examples (hex transmission -> expected value)
value_examples = {
    "C200B40A82": 3,
    "04005AC33890": 54,
    "880086C3E88112": 7,
    "CE00C43D881120": 9,
    "D8005AC2A8F0": 1,
    "F600BC2D8F": 0,
    "9C005AC2F8F0": 0,
    "9C0141080250320F1802104A08": 1,
}

for example_hex, expected_value in value_examples.items():
    actual_value = get_packet(get_bit_buffer(example_hex)).value
    # The message names the failing example, which a bare assert would not do.
    assert actual_value == expected_value, (
        f"{example_hex}: expected value {expected_value}, got {actual_value}"
    )

In [None]:
# Re-read the puzzle input into a fresh bit buffer, positioned at the first bit.
with open("day16.input") as file:
    buffer = get_bit_buffer(file.read().strip())

# My answer: the evaluated value of the outermost packet.
# Left as a bare expression so the notebook displays it as the cell output.
get_packet(buffer).value