In [1]:
# Path of the puzzle input file; a later cell reads it and strips trailing whitespace to get the hex string.
infile = "day16.txt"

In [2]:
from itertools import islice
from functools import partial
import operator as op
from math import prod

def bifunctor(f, it):
    """Apply the binary callable ``f`` to the two items of ``it``; return the result as int.

    ``it`` may be any iterable but must yield exactly two items, otherwise
    tuple unpacking raises ValueError.  With the comparison operators used in
    the operator table, the result is 1 (true) or 0 (false).
    """
    pair = tuple(it)
    left, right = pair
    outcome = f(left, right)
    return int(outcome)

# Operator table keyed by packet type ID.  Type 4 (literal value) is not an
# operator: Packet.next decodes it directly, so it has no entry here.
# Types 0-3 reduce the full list of sub-packet values; types 5-7 compare
# exactly two sub-packet values via bifunctor.
types = {0: sum, 1: prod, 2: min, 3: max}
types.update(
    (type_id, partial(bifunctor, comparison))
    for type_id, comparison in ((5, op.gt), (6, op.lt), (7, op.eq))
)

def hex2bin(hex_code):
    """Expand a hexadecimal string into binary digits, exactly 4 bits per hex digit.

    Each hex digit is converted on its own, so the result always has
    ``4 * len(hex_code)`` characters and keeps leading zeros.  An empty
    string yields an empty string.

    Raises:
        ValueError: if ``hex_code`` contains a character that is not a hex
            digit (e.g. a ``0x`` prefix, ``_`` separators or whitespace, which
            ``int(..., 16)`` on the whole string would silently accept and
            thereby break the 4-bits-per-digit width).
    """
    return ''.join(f"{int(digit, 16):04b}" for digit in hex_code)

def bin2int(bin_code):
    """Interpret a string of binary digits as a non-negative integer (base 2)."""
    return int(bin_code, base=2)

def bin2hex(bin_code):
    """Convert a string of binary digits to lowercase hex without the ``0x`` prefix.

    Leading zeros are not preserved (e.g. ``'0000'`` becomes ``'0'``).
    """
    value = int(bin_code, base=2)
    hex_text = hex(value)
    return hex_text[2:]

class Packet:
    """Recursive-descent decoder for a string of nested packets encoded as '0'/'1' characters.

    ``next()`` decodes one packet, including all of its sub-packets, and
    returns its evaluated value.  As a side effect it adds the version field
    of every decoded packet to ``vsum``.

    Attributes:
        it: iterator over the not-yet-consumed bit characters.
        n: number of bits consumed so far, plus the starting offset passed
            to ``__init__``.
        label: one character per consumed bit (initially ``n`` spaces); each
            ``read`` call tags its bits with the next letter A-Z so ``repr``
            shows field boundaries (debugging aid).
        ch: index (0-25) of the letter the next ``read`` will use.
        vsum: running sum of the version fields of all decoded packets.
    """

    def __init__(self, bin_code, n=0):
        self.it = iter(bin_code)
        self.n = n
        self.label = [' '] * n
        self.ch = 0
        self.vsum = 0

    def __repr__(self):
        return ''.join(self.label)

    def read(self, n):
        """Consume and return the next ``n`` bits as a string.

        Raises:
            ValueError: if fewer than ``n`` bits remain.  Without this check a
                truncated stream would return a short field that ``bin2int``
                silently decodes to a wrong value.
        """
        bits = ''.join(islice(self.it, n))
        if len(bits) < n:
            raise ValueError(
                f"unexpected end of packet data: wanted {n} bits at position "
                f"{self.n}, got {len(bits)}"
            )
        self.n += n
        self.label += [chr(ord('A') + self.ch)] * n
        self.ch = (self.ch + 1) % 26
        return bits

    def read_literal(self):
        """Decode a literal value and return it as an int.

        The value is stored in 5-bit groups: a leading '1' means another group
        follows; the remaining 4 bits of every group are concatenated.
        """
        group = self.read(5)
        chunks = [group[1:]]
        while group[0] == '1':
            group = self.read(5)
            chunks.append(group[1:])
        return bin2int(''.join(chunks))

    def next(self):
        """Decode one packet (recursing into sub-packets) and return its value.

        Type 4 is a literal; every other type applies the matching function
        from ``types`` to the list of sub-packet values.
        """
        version = bin2int(self.read(3))
        self.vsum += version
        type_id = bin2int(self.read(3))
        if type_id == 4:
            return self.read_literal()

        len_type = bin2int(self.read(1))
        f = types[type_id]
        if len_type:
            # Length type 1: the next 11 bits give the number of sub-packets.
            n_packets = bin2int(self.read(11))
            values = [self.next() for _ in range(n_packets)]
        else:
            # Length type 0: the next 15 bits give the total bit length of
            # the sub-packets, so decode until that many bits are consumed.
            len_packets = bin2int(self.read(15))
            end = self.n + len_packets
            values = []
            while self.n < end:
                values.append(self.next())
        return f(values)


In [3]:
# Load the whole input file, then drop trailing whitespace (e.g. the final newline).
with open(infile) as fd:
    hex_code = fd.read()
hex_code = hex_code.rstrip()

In [4]:
# Expand the hex transmission to bits once and feed the decoder from that same string.
bin_code = hex2bin(hex_code)
p = Packet(bin_code)

In [5]:
# p.next() must run first: it decodes the outermost packet and fills p.vsum.
# Arguments are evaluated left to right, so the value is printed before the version sum.
print(p.next(), p.vsum, sep='\n')

200476472872
904
