In [1]:
import math

import numpy as np

In [2]:
# Load the hex transmission; the context manager closes the file deterministically.
with open('inputs/16') as puzzle_file:
    puzzle_input = puzzle_file.read().strip()

In [3]:
# Part 1 worked examples whose version sums are displayed by the cells below.
test1, test2, test3 = 'D2FE28', '38006F45291200', 'EE00D40C823060'

# Part 1 examples whose version sums are asserted further down.
test4, test5, test6, test7 = (
    '8A004A801A8002F478',
    '620080001611562C8802118E34',
    'C0015000016115A2E0802F182340',
    'A0016C880162017C3686B18A3D4780',
)

In [4]:
def hex_to_4_bit_string(s):
    """Convert a hex string to a binary string with four bits per hex digit.

    The result is left-padded with '0' to exactly ``4 * len(s)`` characters,
    so leading zero nibbles (e.g. the '0' in '0F') are preserved.
    """
    width = len(s) * 4
    return format(int(s, 16), f'0>{width}b')

In [5]:
class BitString:
    """Sequential reader over a string of '0'/'1' characters.

    ``consumed`` is the read cursor (index of the next unread bit). It is
    public so callers can measure how many bits a parsing step read.
    """

    def __init__(self, s):
        self.s = s          # the full bit string
        self.consumed = 0   # number of bits read so far

    def get_remaining(self):
        """Return the not-yet-consumed suffix of the bit string."""
        return self.s[self.consumed:]

    def consume_str(self, i):
        """Read the next ``i`` bits as a string and advance the cursor by ``i``.

        Reading past the end returns a shorter (possibly empty) string, but
        the cursor still advances by ``i``.
        """
        s = self.s[self.consumed:(self.consumed + i)]
        self.consumed += i
        return s

    def consume_int(self, i):
        """Read the next ``i`` bits and interpret them as an unsigned binary int.

        Raises ValueError (from ``int``) if no bits are left to read.
        """
        return int(self.consume_str(i), 2)

In [6]:
def p1(s):
    def parse_packet(bs):
        version_sum = 0

        version = bs.consume_int(3)
        version_sum += version

        packet_type =  bs.consume_int(3)

        match packet_type:
            case 4:
                # literal
                bits = ""

                while True:
                    done = bs.consume_int(1) == 0     
                    bits += bs.consume_str(4)

                    if done:
                        break

                literal = int(bits, 2)
            case _:
                # operator
                length_type_id = bs.consume_int(1)

                match length_type_id:
                    # type 0, the length tells us total number of bits of subpackets
                    case 0:
                        length_subpackets = bs.consume_int(15)                    
                        consumed_prior = bs.consumed

                        while bs.consumed - consumed_prior != length_subpackets:
                            version_sum += parse_packet(bs)

                    # type 0, the length tells us total number of subpackets
                    case 1:
                        number_of_subpackets = bs.consume_int(11)

                        for i in range(number_of_subpackets):
                            version_sum += parse_packet(bs)

        return version_sum
    
    return parse_packet(BitString(hex_to_4_bit_string(s)))

In [7]:
# Literal packet (header bits 110 100: version 6, type 4); keep displaying the sum
# but also guard it against regressions.
test1_version_sum = p1(test1)
assert test1_version_sum == 6, test1_version_sum
test1_version_sum

6

In [8]:
# Operator packet using length type ID 0 (total sub-packet bit length).
test2_version_sum = p1(test2)
assert test2_version_sum == 9, test2_version_sum
test2_version_sum

9

In [9]:
# Operator packet using length type ID 1 (sub-packet count).
test3_version_sum = p1(test3)
assert test3_version_sum == 14, test3_version_sum
test3_version_sum

14

In [10]:
# Part 1 examples with known version sums; table-driven so a failure names the input.
P1_EXAMPLES = [
    (test4, 16),
    (test5, 12),
    (test6, 23),
    (test7, 31),
]
for packet_hex, expected in P1_EXAMPLES:
    actual = p1(packet_hex)
    assert actual == expected, f'p1({packet_hex!r}) = {actual}, expected {expected}'

In [11]:
# Part 1 on the real puzzle input: sum of all packet version numbers.
p1(puzzle_input)

891

In [12]:
def operation(packet_type):
    match packet_type:
        case 0:
            return sum
        case 1:
            return np.product
        case 2:
            return min
        case 3:
            return max
        case 5:
            return lambda vs: vs[0] > vs[1]
        case 6:
            return lambda vs: vs[0] < vs[1]
        case 7:
            return lambda vs: vs[0] == vs[1]

In [15]:
def p2(s):
    def parse_packet(bs):
        version = bs.consume_int(3)
        packet_type =  bs.consume_int(3)

        match packet_type:
            case 4:
                # literal
                bits = ""

                while True:
                    done = bs.consume_int(1) == 0     
                    bits += bs.consume_str(4)

                    if done:
                        break

                return int(bits, 2)
            case _:
                # operator
                length_type_id = bs.consume_int(1)
                subpacket_values = []

                match length_type_id:
                    # type 0, the length tells us total number of bits of subpackets
                    case 0:
                        length_subpackets = bs.consume_int(15)                    
                        consumed_prior = bs.consumed

                        while bs.consumed - consumed_prior != length_subpackets:
                            subpacket_values.append(parse_packet(bs))

                    # type 0, the length tells us total number of subpackets
                    case 1:
                        number_of_subpackets = bs.consume_int(11)

                        for i in range(number_of_subpackets):
                            subpacket_values.append(parse_packet(bs))

                return int(operation(packet_type)(subpacket_values))
    
    return parse_packet(BitString(hex_to_4_bit_string(s)))

In [16]:
# Part 2 examples with their expected values; table-driven so a failure names the input.
P2_EXAMPLES = [
    ('C200B40A82', 3),
    ('04005AC33890', 54),
    ('880086C3E88112', 7),
    ('CE00C43D881120', 9),
    ('D8005AC2A8F0', 1),
    ('F600BC2D8F', 0),
    ('9C005AC2F8F0', 0),
    ('9C0141080250320F1802104A08', 1),
]
for packet_hex, expected in P2_EXAMPLES:
    actual = p2(packet_hex)
    assert actual == expected, f'p2({packet_hex!r}) = {actual}, expected {expected}'

In [17]:
# Part 2 on the real puzzle input: value of the outermost packet's expression.
p2(puzzle_input)

673042777597