In [1]:
import itertools
import math

In [2]:
# Sample transmissions (hex strings) paired with reference answers.
# NOTE(review): testval/testop1/testop2 and their *ans companions are not used
# in the visible cells; presumably they are the literal-value and sub-packet
# examples from the puzzle text -- confirm before relying on them.
testval = 'D2FE28'
testvalans = '2021'
testop1 = '38006F45291200'
testop1ans = (10, 20)
testop2 = 'EE00D40C823060'
testop2ans = (1, 2, 3)

# Hex transmission -> expected sum of all packet version numbers.
testvsums = {
    '8A004A801A8002F478': 16,
    '620080001611562C8802118E34': 12,
    'C0015000016115A2E0802F182340': 23,
    'A0016C880162017C3686B18A3D4780': 31,
}

In [3]:
# Read the whole puzzle input file as text and drop surrounding whitespace
# (e.g. a trailing newline) so only the transmission characters remain.
with open('day16.txt', mode='r') as fp:
    puzzledata = fp.read().strip()

## part 1 ##

In [4]:
def bitstream(hexstr):
    """Lazily yield the bits of a hexadecimal string as '0'/'1' characters.

    Each hex digit expands to exactly four bits (zero-padded), emitted most
    significant bit first, in the order the digits appear in ``hexstr``.

    Raises:
        ValueError: when a character that is not a hex digit is reached
            (raised lazily, as the generator gets to that character).
    """
    for digit in hexstr:
        yield from f'{int(digit, 16):04b}'

In [5]:
def take(n, iterable):
    """Return up to the next ``n`` items of ``iterable`` as a list.

    When ``iterable`` is an iterator, the returned items are consumed from it.
    Fewer than ``n`` items (possibly none) are returned if it runs out first.
    """
    head = itertools.islice(iterable, n)
    return [*head]

In [6]:
def int_from_bits(bits):
    """Interpret an iterable of '0'/'1' strings as a base-2 integer.

    The pieces are concatenated in order, most significant bit first.

    Raises:
        ValueError: if the concatenation is not a valid base-2 literal,
            which includes the case where ``bits`` is empty.
    """
    bitstring = ''.join(bits)
    return int(bitstring, 2)

In [7]:
def read_value(it):
    """Decode a literal value from the bit iterator ``it``.

    The value is stored as consecutive five-bit groups: one prefix bit
    followed by four payload bits. A prefix of '0' marks the final group.
    The payload bits of all groups are concatenated and read as a base-2
    integer.

    Raises:
        StopIteration: if ``it`` is exhausted where a prefix bit is expected
            (``next`` is called without a default).
    """
    payload = []
    while True:
        prefix = next(it)
        payload.extend(take(4, it))
        if prefix == '0':
            # This was the last group of the literal.
            break
    return int_from_bits(payload)

In [8]:
def get_packets(it, num=-1):
    """Parse packets from an iterator of '0'/'1' characters.

    Args:
        it: iterator yielding single-character bit strings. It is consumed
            as packets are read.
        num: maximum number of packets to parse at this nesting level. The
            default of -1 never counts down to zero, so parsing continues
            until the iterator is exhausted.

    Returns:
        A list of packet dicts in stream order. Every packet has 'version'
        and 'typeid'; a literal packet (type ID 4) also has 'value', any
        other packet has 'subpackets' (a list of packet dicts).

    Raises:
        ValueError: if the length type bit is neither '0' nor '1'.

    Whenever the stream runs out while a header field is being read, the
    packets parsed so far are returned. This is how trailing hex padding is
    tolerated. The same rule now also applies to the 11-bit sub-packet count,
    which previously fell through to ``int_from_bits([])`` and raised an
    unrelated ValueError on a truncated stream.
    """
    packets = []
    while num != 0:
        num -= 1

        version_bits = take(3, it)
        if not version_bits:
            # iterator exhausted
            return packets
        version = int_from_bits(version_bits)

        typeid_bits = take(3, it)
        if not typeid_bits:
            # iterator exhausted
            return packets
        typeid = int_from_bits(typeid_bits)

        packet = {'version': version, 'typeid': typeid}

        if typeid == 4:
            # value packet
            packet['value'] = read_value(it)
        else:
            # op packet
            lengthtype = next(it, None)
            if lengthtype is None:
                # iterator exhausted
                return packets

            if lengthtype == '0':
                # Next 15 bits: total length in bits of the sub-packets.
                bitlength_bits = take(15, it)
                if not bitlength_bits:
                    # iterator exhausted
                    return packets
                numbits = int_from_bits(bitlength_bits)
                if numbits == 0:
                    # We've just been reading hex padding 0s; we're finished
                    return packets
                # Parse the sub-packets from their own bounded bit stream.
                packet['subpackets'] = get_packets(iter(take(numbits, it)))
            elif lengthtype == '1':
                # Next 11 bits: number of immediately contained sub-packets.
                numsub_bits = take(11, it)
                if not numsub_bits:
                    # iterator exhausted
                    return packets
                packet['subpackets'] = get_packets(it, int_from_bits(numsub_bits))
            else:
                raise ValueError(f'Bad length type ID: {lengthtype}')

        packets.append(packet)
    return packets

In [9]:
def get_version_ids(packets, acc=None):
    """Collect the version number of every packet, depth first.

    Args:
        packets: iterable of packet dicts; each has a 'version' key and may
            have a 'subpackets' list of further packet dicts.
        acc: optional list to append to. When given, it is extended in place
            and is also the object returned.

    Returns:
        The list of versions: each packet's own version, followed by the
        versions found in its sub-packets, before moving to the next packet.
    """
    collected = acc if acc is not None else []
    for pkt in packets:
        collected.append(pkt['version'])
        if 'subpackets' in pkt:
            # The recursive call appends to the same list object.
            get_version_ids(pkt['subpackets'], collected)
    return collected

In [10]:
# Print the expected version sum next to the computed one for each sample.
for key, expected in testvsums.items():
    packets = get_packets(iter(bitstream(key)))
    print(expected, sum(get_version_ids(packets)))

16 16
12 12
23 23
31 31


In [11]:
# Parse the full puzzle transmission into a list of (nested) packet dicts.
puzzlepackets = get_packets(iter(bitstream(puzzledata)))

In [12]:
# Part 1 answer: sum of the version numbers of every packet in the puzzle
# input, nested sub-packets included.
sum(get_version_ids(puzzlepackets))

929

## part 2 ##

In [13]:
def gt(v):
    """Return 1 if the first of the two values in ``v`` is greater than the second, else 0.

    ``v`` must unpack into exactly two items.
    """
    first, second = v
    return 1 if first > second else 0
    
def lt(v):
    """Return 1 if the first of the two values in ``v`` is less than the second, else 0.

    ``v`` must unpack into exactly two items.
    """
    first, second = v
    return 1 if first < second else 0
    
def eq(v):
    """Return 1 if the two values in ``v`` are equal, else 0.

    ``v`` must unpack into exactly two items.
    """
    first, second = v
    return 1 if first == second else 0
    
    
# Dispatch table: operator packet type ID -> function applied to the list of
# its evaluated sub-packet values. Type ID 4 has no entry because packets
# carrying a 'value' are returned directly by calc.
ops = {
    0: sum,
    1: math.prod,
    2: min,
    3: max,
    5: gt,
    6: lt,
    7: eq,
}

In [14]:
def calc(packet):
    """Evaluate a packet dict and return its value.

    A packet with a 'value' key evaluates to that value. Any other packet is
    an operator: its sub-packets are evaluated recursively and the resulting
    list is passed to the function registered in ``ops`` for its 'typeid'.

    Raises:
        KeyError: if an operator packet's 'typeid' has no entry in ``ops``.
    """
    if 'value' not in packet:
        # operator packet
        operator = ops[packet['typeid']]
        operands = [calc(child) for child in packet['subpackets']]
        return operator(operands)
    return packet['value']

In [15]:
# The first packet parsed from the puzzle input; this is the one evaluated below.
pzpkt = puzzlepackets[0]

In [16]:
# Part 2 answer: the value obtained by evaluating that packet with calc.
calc(pzpkt)

911945136934