In [125]:
from dataclasses import dataclass
from typing import List, Union
from functools import reduce

# Type ID that marks a literal-value packet; parse_packet treats every other
# type ID as an operator packet that contains child packets.
LITERAL = 4

@dataclass
class Packet:
    """A decoded BITS packet.

    For literal packets (``type_id == LITERAL``) ``data`` is the decoded integer;
    for operator packets it is the list of child ``Packet`` objects.
    """
    version: int  # 3-bit version number read from the packet header
    type_id: int  # 3-bit type ID read from the packet header
    data: Union[int, List['Packet']]

In [170]:
def hex2bin(h):
    """Expand a hexadecimal string into a string of binary digits.

    Every hex digit contributes exactly four bits, so leading zeros inside each
    digit are kept (e.g. ``'1'`` -> ``'0001'``).
    """
    chunks = []
    for digit in h:
        bits = format(int(digit, 16), 'b')
        chunks.append(bits.zfill(4))
    return ''.join(chunks)

def parse_header(p):
    """Split the 6-bit packet header off the front of bit string ``p``.

    Returns ``((version, type_id), rest)``: bits 0-2 are the version, bits 3-5
    the type ID, each read as an unsigned binary number; ``rest`` is ``p``
    without those six bits.
    """
    version_bits, type_bits, rest = p[0:3], p[3:6], p[6:]
    return (int(version_bits, 2), int(type_bits, 2)), rest

def parse_literal(p):
    """Decode a literal value from the front of bit string ``p``.

    The value is stored in 5-bit groups: one flag bit followed by 4 payload bits.
    A flag of ``'0'`` marks the final group; any other flag means more groups
    follow. The payload bits of all groups are concatenated and read as a single
    binary number.

    Returns ``(value, rest)`` where ``rest`` is everything after the final group.
    """
    payload = ''
    pos = 0
    while True:
        flag = p[pos]
        payload += p[pos + 1:pos + 5]
        pos += 5
        if flag == '0':
            break
    return int(payload, 2), p[pos:]

def parse_int(p, n_bits):
    """Read an unsigned ``n_bits``-wide binary integer from the front of ``p``.

    Returns ``(value, rest)`` where ``rest`` is ``p`` with those bits removed.
    """
    head, tail = p[:n_bits], p[n_bits:]
    return int(head, 2), tail

def parse_packet(p):
    """Decode one packet (and, recursively, its children) from bit string ``p``.

    Literal packets (``type_id == LITERAL``) carry an integer. Any other type ID
    is an operator packet. After the header, its next bit selects how the
    children are framed:

    - ``'0'``: a 15-bit field giving the total number of bits the children occupy;
    - otherwise: an 11-bit field giving the number of children.

    Returns ``(packet, rest)`` where ``rest`` is the unconsumed suffix of ``p``.
    """
    (version, type_id), rest = parse_header(p)
    if type_id == LITERAL:
        value, rest = parse_literal(rest)
        return Packet(version, type_id, value), rest

    children = []
    length_type, rest = rest[0], rest[1:]
    if length_type == '0':
        n_bits, rest = parse_int(rest, 15)
        # Keep parsing children until n_bits bits have been consumed, i.e. until
        # the remaining string has shrunk to (or below) this length.
        stop_len = len(rest) - n_bits
        while len(rest) > stop_len:
            child, rest = parse_packet(rest)
            children.append(child)
    else:
        n_children, rest = parse_int(rest, 11)
        for _ in range(n_children):
            child, rest = parse_packet(rest)
            children.append(child)
    return Packet(version, type_id, children), rest

def flatten_packet(p):
    """Yield ``p`` and every packet nested beneath it, in depth-first pre-order.

    Each packet is yielded before its children, and children are visited in the
    order they appear in ``p.data`` (left to right).

    :param p: root packet; its ``data`` is either a literal value or a list of
        child packets.
    :yield: every packet in the tree, starting with ``p`` itself.
    """
    # Iterative walk with an explicit stack: no self-referencing recursion and
    # no risk of hitting Python's recursion limit on deeply nested packets.
    stack = [p]
    while stack:
        packet = stack.pop()
        yield packet
        if isinstance(packet.data, list):
            # Push children reversed so the leftmost child is popped (visited) next.
            stack.extend(reversed(packet.data))

def product(xs):
    """Multiply together all items of the iterable ``xs``; the empty product is 1."""
    result = 1
    for x in xs:
        result = result * x
    return result
            
# Reducers for operator packets, keyed by type ID. Each takes the list of
# evaluated child values. Type 4 has no entry: eval_packet returns a literal
# packet's value directly instead of reducing it.
_ops = {
    0: sum,
    1: product,
    2: min,
    3: max,
    5: lambda xs: int(xs[0] > xs[1]),   # GT: 1 if first > second else 0
    6: lambda xs: int(xs[0] < xs[1]),   # LT: 1 if first < second else 0
    7: lambda xs: int(xs[0] == xs[1]),  # EQ: 1 if first == second else 0
}
def eval_packet(p):
    """Recursively evaluate a packet tree to a single integer.

    A literal packet evaluates to its stored value. An operator packet evaluates
    each child, then applies the reducer registered for its type ID in ``_ops``
    to the list of child values.

    :param p: root ``Packet`` to evaluate.
    :return: the computed value.
    :raises KeyError: if an operator packet's type ID has no entry in ``_ops``.
    """
    # Compare against the shared constant (not a bare 4) so this stays in
    # agreement with parse_packet's notion of a literal packet.
    if p.type_id == LITERAL:
        return p.data
    # Look up the operator before evaluating the children, like the call
    # expression `_ops[t]([...])` would, so unknown IDs fail immediately.
    op = _ops[p.type_id]
    return op([eval_packet(child) for child in p.data])

In [171]:
# Worked example: a single literal packet. Assert the expected decoding so a
# fresh "Restart & Run All" verifies it, then display the result.
example_literal = parse_packet(hex2bin('D2FE28'))
assert example_literal == (Packet(version=6, type_id=4, data=2021), '000')
example_literal

(Packet(version=6, type_id=4, data=2021), '000')

In [172]:
# Worked example for part 2: this operator tree is expected to evaluate to 1.
example_root, example_rest = parse_packet(hex2bin('9C0141080250320F1802104A08'))
example_value = eval_packet(example_root)
assert example_value == 1
example_value

1

In [173]:
# Read the puzzle input; the file is closed before any parsing work starts.
with open('../data/day16.txt') as infile:
    transmission = infile.read().strip()

root, remaining = parse_packet(hex2bin(transmission))
# Anything after the outermost packet should only be zero padding introduced by
# the hex encoding (compare the '000' left over in the D2FE28 example).
assert set(remaining) <= {'0'}, f'unexpected non-padding bits after root packet: {remaining!r}'

print('[p1] Sum of packet versions:', sum(p.version for p in flatten_packet(root)))
print('[p2] Evaluated packet:', eval_packet(root))

[p1] Sum of packet versions: 847
[p2] Evaluated packet: 333794664059
