# In [53]:
import advent
from functools import reduce

data = advent.get_lines(16)[0].strip()

# data = '8A004A801A8002F478'  # small sample transmission for quick checks

# Convert the hex transmission to a list of '0'/'1' characters, 4 bits per
# hex digit. bin() drops leading zeros, so without zfill an input whose first
# hex digit is below 8 (e.g. '0'..'7') would lose bits and shift every header
# field the parser reads. strip() keeps a stray newline from inflating the
# expected bit count.
data = list(bin(int(data, 16))[2:].zfill(len(data) * 4))

# In [61]:
def to_str(packet):
    """Describe a packet header for debugging.

    Reports the version (bits 0-2), the type ID (bits 3-5) and the raw bit 6.
    Bit 6 is the length-type flag for operator packets (see parse_packet) but
    the first group's continuation flag for literal packets (see
    parse_literal).
    """
    type_id = bitstr_to_int(packet[3:6])
    flag_bit = packet[6]
    version_no = bitstr_to_int(packet[0:3])
    return f"v {version_no}, t {type_id}, l {flag_bit}"

def bitstr_to_int(packet):
    """Interpret a sequence of '0'/'1' characters as an unsigned binary int.

    Works on a list of single-character strings or on a plain string.
    Raises ValueError when the joined text is not a valid base-2 literal
    (for example when the sequence is empty).
    """
    digits = "".join(packet)
    return int(digits, base=2)

def apply(type, result):
    """Evaluate an operator packet over its sub-packet values.

    Args:
        type: the operator's type ID (4 is a literal and is rejected).
        result: list of the evaluated sub-packet values.

    Returns:
        int: sum (0), product (1), minimum (2), maximum (3), or 1/0 for the
        comparisons greater-than (5), less-than (6) and equal-to (7), which
        compare the first two values.

    Raises:
        ValueError: for type 4 or any type ID outside 0-7.
    """
    if type == 0:
        return sum(result)
    if type == 1:
        return reduce(lambda x, y: x * y, result, 1)
    if type == 2:
        return min(result)
    if type == 3:
        return max(result)
    if type == 4:
        raise ValueError(f"Type should not be 4: {result}")
    # Comparisons: coerce the bool to int so every operator honours the
    # "returns an integer" contract (and results print as 1/0, not True/False).
    if type == 5:
        return int(result[0] > result[1])
    if type == 6:
        return int(result[0] < result[1])
    if type == 7:
        return int(result[0] == result[1])
    raise ValueError(f"{type}, {result} is invalid!")

def parse_packet(packet):
    """Parse one packet from the front of a bit sequence.

    Dispatches on the type ID (bits 3-5): type 4 is a literal; any other type
    is an operator whose sub-packet layout is chosen by bit 6 ('0' = total
    bit length, '1' = sub-packet count).

    Returns:
        (r, p, v): r the packet's value, p the unconsumed remainder of
        `packet`, v the version of this packet plus those of all nested
        sub-packets.

    Raises:
        ValueError: if bit 6 of an operator packet is neither '0' nor '1'.
    """
    type_id = bitstr_to_int(packet[3:6])
    if type_id == 4:
        return parse_literal(packet)
    length_type = packet[6]
    if length_type == '0':
        return parse_bitlength(packet)
    if length_type == '1':
        return parse_packlength(packet)
    # Fail here rather than implicitly returning None, which would only blow
    # up later as an opaque tuple-unpacking error in the caller.
    raise ValueError(f"Invalid length type {length_type!r} in operator packet")

def parse_literal(packet):
    """Parse a literal-value packet (type ID 4) from the front of `packet`.

    After the 6 header bits, the value is stored in 5-bit groups: one
    continuation flag followed by 4 value bits. Groups are read until one
    whose flag is not '1'.

    Returns:
        (l, p, v): l the literal's integer value, p the bits following the
        last group, v the packet's version (bits 0-2).
    """
    pos = 6
    value_bits = []
    while True:
        flag = packet[pos]
        value_bits.extend(packet[pos + 1:pos + 5])
        pos += 5
        if flag != '1':
            break
    return bitstr_to_int(value_bits), packet[pos:], bitstr_to_int(packet[:3])

def parse_bitlength(packet):
    """Parse an operator packet with length type 0 (total sub-packet bits).

    Bits 7-21 give the total number of bits occupied by the sub-packets,
    which start at bit 22. Sub-packets are parsed until exactly that many
    bits have been consumed.

    Returns:
        (l, p, v): l the operator applied to the sub-packet values (see
        apply), p the bits remaining after the sub-packets, v the version
        of this packet plus those of all its sub-packets.

    Raises:
        ValueError: if the declared length exceeds the available bits, or the
        sub-packets do not end exactly on the declared boundary.
    """
    v_total = bitstr_to_int(packet[:3])
    type_id = bitstr_to_int(packet[3:6])
    length = bitstr_to_int(packet[7:22])
    rest = packet[22:]
    if length > len(rest):
        raise ValueError(
            f"Sub-packet length {length} exceeds the {len(rest)} bits available"
        )
    # The sub-packets are finished once exactly `stop_at` bits remain. A
    # bounded `>` loop cannot run past that boundary, and a zero length
    # correctly yields no sub-packets.
    stop_at = len(rest) - length
    values = []
    while len(rest) > stop_at:
        value, rest, version = parse_packet(rest)
        v_total += version
        values.append(value)
    if len(rest) != stop_at:
        raise ValueError(
            f"Sub-packets overran their declared length of {length} bits"
        )
    return apply(type_id, values), rest, v_total


def parse_packlength(packet):
    """Parse an operator packet with length type 1 (sub-packet count).

    Bits 7-17 hold the number of sub-packets, which follow from bit 18.

    Returns:
        (l, p, v): l the operator applied to the sub-packet values (see
        apply), p the bits remaining after the last sub-packet, v the
        version of this packet plus those of all its sub-packets.
    """
    v_total = bitstr_to_int(packet[:3])
    type_id = bitstr_to_int(packet[3:6])
    count = bitstr_to_int(packet[7:18])
    rest = packet[18:]
    values = []
    for _ in range(count):
        value, rest, version = parse_packet(rest)
        v_total += version
        values.append(value)
    return apply(type_id, values), rest, v_total

# Parse the whole transmission: `result` is the evaluated value of the
# outermost packet, `version` the sum of all version numbers, and `packet`
# the leftover bits (presumably only trailing padding — not checked here).
result, packet, version = parse_packet(data)
# Bare expression: shown as the cell output in a notebook; no effect when
# this file is run as a plain script.
result, version


# Out: (539051801941, 879)