## Imports

In [1]:
import math
from enum import IntEnum
from typing import List, Literal, Optional, Tuple, Union

from more_itertools import sliced
from pydantic import BaseModel, conint

from aoc_utilities import puzzle_input

<h2>--- Day 16: Packet Decoder ---</h2><p>As you leave the cave and reach open waters, you receive a transmission from the Elves back on the ship.</p>
<p>The transmission was sent using the Buoyancy Interchange Transmission System (<span title="Just be glad it wasn't sent using the BuoyancY Transmission Encoding System.">BITS</span>), a method of packing numeric expressions into a binary sequence. Your submarine's computer has saved the transmission in <a href="https://en.wikipedia.org/wiki/Hexadecimal" target="_blank">hexadecimal</a> (your puzzle input).</p>
<p>The first step of decoding the message is to convert the hexadecimal representation into binary. Each character of hexadecimal corresponds to four bits of binary data:</p>
<pre><code>0 = 0000
1 = 0001
2 = 0010
3 = 0011
4 = 0100
5 = 0101
6 = 0110
7 = 0111
8 = 1000
9 = 1001
A = 1010
B = 1011
C = 1100
D = 1101
E = 1110
F = 1111
</code></pre>
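
Converting digit by digit keeps the leading zeros that a whole-number conversion such as `int(hexadecimal, 16)` would drop. Here is a minimal sketch; `hex_to_bits` is a hypothetical helper name:

```python
def hex_to_bits(hexadecimal: str) -> str:
    """Expand each hex digit to exactly four bits, most significant bit first."""
    # int(digit, 16) maps '0'-'9' and 'A'-'F' to 0-15; the '04b' format spec
    # zero-pads every digit to four binary digits.
    return ''.join(f'{int(digit, 16):04b}' for digit in hexadecimal.strip())


print(hex_to_bits('D2FE28'))  # 110100101111111000101000
print(hex_to_bits('0A'))      # 00001010 (the leading '0' digit still yields 0000)
```
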
<p>The BITS transmission contains a single <em>packet</em> at its outermost layer which itself contains many other packets. The hexadecimal representation of this packet might encode a few extra <code>0</code> bits at the end; these are not part of the transmission and should be ignored.</p>
<p>Every packet begins with a standard header: the first three bits encode the packet <em>version</em>, and the next three bits encode the packet <em>type ID</em>. These two values are numbers; all numbers encoded in any packet are represented as binary with the most significant bit first. For example, a version encoded as the binary sequence <code>100</code> represents the number <code>4</code>.</p>
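
Reading the header amounts to two 3-bit slices interpreted as unsigned binary numbers. A minimal sketch, with `parse_header` as a hypothetical helper that takes a string of `'0'`/`'1'` characters:

```python
from typing import Tuple


def parse_header(bits: str) -> Tuple[int, int]:
    """Return (version, type_id) from the first six bits of a packet."""
    # Both fields are 3-bit unsigned integers written most significant bit
    # first, so int(..., 2) decodes them directly.
    version = int(bits[0:3], 2)
    type_id = int(bits[3:6], 2)
    return version, type_id


print(int('100', 2))                             # 4, as in the example above
print(parse_header('110100101111111000101000'))  # (6, 4)
```
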
<p>Packets with type ID <code>4</code> represent a <em>literal value</em>. Literal value packets encode a single binary number. To do this, the binary number is padded with leading zeroes until its length is a multiple of four bits, and then it is broken into groups of four bits. Each group is prefixed by a <code>1</code> bit except the last group, which is prefixed by a <code>0</code> bit. These groups of five bits immediately follow the packet header. For example, the hexadecimal string <code>D2FE28</code> becomes:</p>
<pre><code>110100101111111000101000
VVVTTTAAAAABBBBBCCCCC
</code></pre>
<p>Below each bit is a label indicating its purpose:</p>
<ul>
<li>The three bits labeled <code>V</code> (<code>110</code>) are the packet version, <code>6</code>.</li>
<li>The three bits labeled <code>T</code> (<code>100</code>) are the packet type ID, <code>4</code>, which means the packet is a literal value.</li>
<li>The five bits labeled <code>A</code> (<code>10111</code>) start with a <code>1</code> (not the last group, keep reading) and contain the first four bits of the number, <code>0111</code>.</li>
<li>The five bits labeled <code>B</code> (<code>11110</code>) start with a <code>1</code> (not the last group, keep reading) and contain four more bits of the number, <code>1110</code>.</li>
<li>The five bits labeled <code>C</code> (<code>00101</code>) start with a <code>0</code> (last group, end of packet) and contain the last four bits of the number, <code>0101</code>.</li>
<li>The three unlabeled <code>0</code> bits at the end are extra due to the hexadecimal representation and should be ignored.</li>
</ul>
<p>So, this packet represents a literal value with binary representation <code>011111100101</code>, which is <code>2021</code> in decimal.</p>
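
The literal encoding described above can be sketched in both directions. `encode_literal` and `decode_literal` are hypothetical helpers that work on the bits after the 6-bit header:

```python
import math
from typing import Tuple


def encode_literal(value: int) -> str:
    """Encode a non-negative integer as BITS literal groups (header not included)."""
    bits = f'{value:b}'
    # Pad with leading zeros until the length is a multiple of four bits.
    bits = bits.zfill(math.ceil(len(bits) / 4) * 4)
    nibbles = [bits[i:i + 4] for i in range(0, len(bits), 4)]
    # Prefix every group with '1', except the last group, which gets '0'.
    return ''.join(
        ('0' if i == len(nibbles) - 1 else '1') + nibble
        for i, nibble in enumerate(nibbles)
    )


def decode_literal(bits: str) -> Tuple[int, int]:
    """Decode literal groups from the start of `bits`.

    Returns (value, number of bits consumed); trailing bits are ignored.
    """
    value_bits = []
    position = 0
    while True:
        group = bits[position:position + 5]
        if len(group) < 5:
            raise ValueError('literal ended before a group starting with 0')
        value_bits.append(group[1:])  # drop the continuation flag
        position += 5
        if group[0] == '0':  # a leading 0 marks the last group
            return int(''.join(value_bits), 2), position


packet = '110100101111111000101000'  # D2FE28
print(decode_literal(packet[6:]))    # (2021, 15)
print(encode_literal(2021))          # 101111111000101
```

The 15 consumed bits plus the 6-bit header give the 21-bit packet; the three remaining `0` bits are the hexadecimal padding.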
<p>Every other type of packet (any packet with a type ID other than <code>4</code>) represents an <em>operator</em> that performs some calculation on one or more sub-packets contained within. Right now, the specific operations aren't important; focus on parsing the hierarchy of sub-packets.</p>
<p>An operator packet contains one or more packets. To indicate which subsequent binary data represents its sub-packets, an operator packet can use one of two modes indicated by the bit immediately after the packet header; this is called the <em>length type ID</em>:</p>
<ul>
<li>If the length type ID is <code>0</code>, then the next <em>15</em> bits are a number that represents the <em>total length in bits</em> of the sub-packets contained by this packet.</li>
<li>If the length type ID is <code>1</code>, then the next <em>11</em> bits are a number that represents the <em>number of sub-packets immediately contained</em> by this packet.</li>
</ul>
<p>Finally, after the length type ID bit and the 15-bit or 11-bit field, the sub-packets appear.</p>
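
Reading the length type ID and the field after it can be sketched as follows; `read_length_field` is a hypothetical helper, and the sample inputs are the corresponding bits from the two examples below:

```python
from typing import Tuple


def read_length_field(bits: str) -> Tuple[int, int, int]:
    """Read an operator's length type ID and the field that follows it.

    `bits` must start immediately after the 6-bit packet header.
    Returns (length_type_id, field_value, bits_consumed).
    """
    length_type_id = int(bits[0], 2)
    # Type 0: 15-bit total length in bits; type 1: 11-bit sub-packet count.
    width = 15 if length_type_id == 0 else 11
    field_value = int(bits[1:1 + width], 2)
    return length_type_id, field_value, 1 + width


print(read_length_field('0' + '000000000011011'))  # (0, 27, 16)
print(read_length_field('1' + '00000000011'))      # (1, 3, 12)
```
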
<p>For example, here is an operator packet (hexadecimal string <code>38006F45291200</code>) with length type ID <code>0</code> that contains two sub-packets:</p>
<pre><code>00111000000000000110111101000101001010010001001000000000
VVVTTTILLLLLLLLLLLLLLLAAAAAAAAAAABBBBBBBBBBBBBBBB
</code></pre>
<ul>
<li>The three bits labeled <code>V</code> (<code>001</code>) are the packet version, <code>1</code>.</li>
<li>The three bits labeled <code>T</code> (<code>110</code>) are the packet type ID, <code>6</code>, which means the packet is an operator.</li>
<li>The bit labeled <code>I</code> (<code>0</code>) is the length type ID, which indicates that the length is a 15-bit number representing the number of bits in the sub-packets.</li>
<li>The 15 bits labeled <code>L</code> (<code>000000000011011</code>) contain the length of the sub-packets in bits, <code>27</code>.</li>
<li>The 11 bits labeled <code>A</code> contain the first sub-packet, a literal value representing the number <code>10</code>.</li>
<li>The 16 bits labeled <code>B</code> contain the second sub-packet, a literal value representing the number <code>20</code>.</li>
</ul>
<p>After reading 11 and 16 bits of sub-packet data, the total length indicated in <code>L</code> (27) is reached, and so parsing of this packet stops.</p>
<p>As another example, here is an operator packet (hexadecimal string <code>EE00D40C823060</code>) with length type ID <code>1</code> that contains three sub-packets:</p>
<pre><code>11101110000000001101010000001100100000100011000001100000
VVVTTTILLLLLLLLLLLAAAAAAAAAAABBBBBBBBBBBCCCCCCCCCCC
</code></pre>
<ul>
<li>The three bits labeled <code>V</code> (<code>111</code>) are the packet version, <code>7</code>.</li>
<li>The three bits labeled <code>T</code> (<code>011</code>) are the packet type ID, <code>3</code>, which means the packet is an operator.</li>
<li>The bit labeled <code>I</code> (<code>1</code>) is the length type ID, which indicates that the length is an 11-bit number representing the number of sub-packets.</li>
<li>The 11 bits labeled <code>L</code> (<code>00000000011</code>) contain the number of sub-packets, <code>3</code>.</li>
<li>The 11 bits labeled <code>A</code> contain the first sub-packet, a literal value representing the number <code>1</code>.</li>
<li>The 11 bits labeled <code>B</code> contain the second sub-packet, a literal value representing the number <code>2</code>.</li>
<li>The 11 bits labeled <code>C</code> contain the third sub-packet, a literal value representing the number <code>3</code>.</li>
</ul>
<p>After reading 3 complete sub-packets, the number of sub-packets indicated in <code>L</code> (3) is reached, and so parsing of this packet stops.</p>
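
One way to double-check the bit accounting in both examples is to split each binary string by its label line. `split_by_labels` is a hypothetical helper; the label lines are rebuilt from the field widths stated above rather than copied character by character:

```python
from itertools import groupby
from typing import List, Tuple


def split_by_labels(bits: str, labels: str) -> List[Tuple[str, str]]:
    """Pair each run of identical label characters with the bits beneath it."""
    fields = []
    position = 0
    for label, run in groupby(labels):
        width = sum(1 for _ in run)
        fields.append((label, bits[position:position + width]))
        position += width
    return fields


total_length_bits = '00111000000000000110111101000101001010010001001000000000'
total_length_labels = 'VVV' + 'TTT' + 'I' + 'L' * 15 + 'A' * 11 + 'B' * 16
fields = dict(split_by_labels(total_length_bits, total_length_labels))
print(int(fields['L'], 2), len(fields['A']) + len(fields['B']))  # 27 27

count_bits = '11101110000000001101010000001100100000100011000001100000'
count_labels = 'VVV' + 'TTT' + 'I' + 'L' * 11 + 'A' * 11 + 'B' * 11 + 'C' * 11
fields = dict(split_by_labels(count_bits, count_labels))
print(int(fields['L'], 2), [fields[k] for k in 'ABC'])  # 3, then three 11-bit sub-packets
```
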
<p>For now, parse the hierarchy of the packets throughout the transmission and <em>add up all of the version numbers</em>.</p>
<p>Here are a few more examples of hexadecimal-encoded transmissions:</p>
<ul>
<li><code>8A004A801A8002F478</code> represents an operator packet (version 4) which contains an operator packet (version 1) which contains an operator packet (version 5) which contains a literal value (version 6); this packet has a version sum of <code><em>16</em></code>.</li>
<li><code>620080001611562C8802118E34</code> represents an operator packet (version 3) which contains two sub-packets; each sub-packet is an operator packet that contains two literal values. This packet has a version sum of <code><em>12</em></code>.</li>
<li><code>C0015000016115A2E0802F182340</code> has the same structure as the previous example, but the outermost packet uses a different length type ID. This packet has a version sum of <code><em>23</em></code>.</li>
<li><code>A0016C880162017C3686B18A3D4780</code> is an operator packet that contains an operator packet that contains an operator packet that contains five literal values; it has a version sum of <code><em>31</em></code>.</li>
</ul>
<p>Decode the structure of your hexadecimal-encoded BITS transmission; <em>what do you get if you add up the version numbers in all packets?</em></p>
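
For comparison with the pydantic-based solution below, here is a compact sketch that tracks an integer cursor into the bit string instead of slicing and re-validating the remainder at each level. The names `hex_to_bits`, `parse_packet`, and `sum_versions` are hypothetical, and this sketch only sums versions; it does not build a tree:

```python
from typing import Tuple


def hex_to_bits(hexadecimal: str) -> str:
    # Four bits per hex digit, so leading zero digits are preserved.
    return ''.join(f'{int(digit, 16):04b}' for digit in hexadecimal.strip())


def parse_packet(bits: str, position: int = 0) -> Tuple[int, int]:
    """Parse the packet starting at `position`.

    Returns (sum of the versions of this packet and all packets nested in it,
    index of the first bit after this packet).
    """
    version = int(bits[position:position + 3], 2)
    type_id = int(bits[position + 3:position + 6], 2)
    position += 6

    if type_id == 4:
        # Literal value: skip groups whose flag is 1, then the final group.
        while bits[position] == '1':
            position += 5
        return version, position + 5

    total = version
    length_type_id = bits[position]
    position += 1
    if length_type_id == '0':
        # 15-bit field: total length in bits of the sub-packets.
        end = position + 15 + int(bits[position:position + 15], 2)
        position += 15
        while position < end:
            sub_total, position = parse_packet(bits, position)
            total += sub_total
    else:
        # 11-bit field: number of immediately contained sub-packets.
        count = int(bits[position:position + 11], 2)
        position += 11
        for _ in range(count):
            sub_total, position = parse_packet(bits, position)
            total += sub_total
    return total, position


def sum_versions(hexadecimal: str) -> int:
    # Padding bits after the outermost packet are never read.
    return parse_packet(hex_to_bits(hexadecimal))[0]


for example in ['8A004A801A8002F478', '620080001611562C8802118E34',
                'C0015000016115A2E0802F182340', 'A0016C880162017C3686B18A3D4780']:
    print(example, sum_versions(example))  # 16, 12, 23, 31 respectively
```

Passing an integer position avoids copying the remaining bits on every recursive call, at the cost of not producing `Packet` objects.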


## Solution

In [2]:
ENCODED_NIBBLE_BITS = 5
HEADER_BITS = 6
HEADER_FIELD_BITS = 3
PACKET_MIN_BITS = HEADER_BITS + ENCODED_NIBBLE_BITS


class PacketType(IntEnum):
    LITERAL = 4
    UNKNOWN_OPERATOR_1 = 0
    UNKNOWN_OPERATOR_2 = 1
    UNKNOWN_OPERATOR_3 = 2
    UNKNOWN_OPERATOR_4 = 3
    UNKNOWN_OPERATOR_5 = 5
    UNKNOWN_OPERATOR_6 = 6
    UNKNOWN_OPERATOR_7 = 7


class SubpacketLengthType(IntEnum):
    TOTAL_LENGTH = 0
    TOTAL_CHILDREN = 1


class Packet(BaseModel):
    version: conint(ge=0, le=7)
    type_id: PacketType
    subpacket_length_type: Optional[SubpacketLengthType] = None
    content: Union[int, List['Packet']]
    binary_length: conint(ge=11)

    @property
    def is_literal(self) -> bool:
        return self.type_id == PacketType.LITERAL

    def walk_depth_first(self):
        yield self
        if not self.is_literal:
            for subpacket in self.content:
                yield from subpacket.walk_depth_first()


Packet.update_forward_refs()
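
`walk_depth_first` is a pre-order traversal built from generators: each packet is yielded before its sub-packets, and `yield from` splices every subtree's packets into the outer stream. The sketch below mirrors it with a hypothetical standard-library `Node` dataclass (no pydantic), using the nesting of the first version-sum example; the operator type IDs are placeholders, since only "not 4" matters here:

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Union


@dataclass
class Node:
    """Hypothetical stand-in for `Packet` using only the standard library."""
    version: int
    type_id: int
    content: Union[int, List['Node']] = field(default_factory=list)

    def walk_depth_first(self) -> Iterator['Node']:
        # Pre-order: yield this node, then each subtree from left to right.
        yield self
        if self.type_id != 4:
            for child in self.content:
                yield from child.walk_depth_first()


# Shape of 8A004A801A8002F478: operator v4 > operator v1 > operator v5 > literal v6.
tree = Node(4, 0, [Node(1, 0, [Node(5, 0, [Node(6, 4, 0)])])])
print([node.version for node in tree.walk_depth_first()])     # [4, 1, 5, 6]
print(sum(node.version for node in tree.walk_depth_first()))  # 16
```
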

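We can now decode BITS transmissions recursively, generating an (abstract syntax) tree of `Packet`s in the process: for each operator packet, `_parse_subpackets` calls `decode_bits_transmission` again on the remaining bits to decode every sub-packet.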

In [3]:
def decode_bits_transmission(transmission: str, hexadecimal=True) -> Packet:
    if hexadecimal:
        transmission_binary = _hexadecimal_to_padded_binary(transmission)
    else:
        transmission_binary = transmission

    if not set(transmission_binary).issubset({'0', '1'}):
        raise ValueError(
            'Transmissions must be converted to binary by this function '
            'or its caller.'
        )

    packet_data = {}

    # Header: a 3-bit version followed by a 3-bit type ID, most significant bit first.
    packet_data['version'] = int(transmission_binary[:HEADER_FIELD_BITS], base=2)

    type_id = PacketType(int(transmission_binary[HEADER_FIELD_BITS:HEADER_BITS], base=2))
    packet_data['type_id'] = type_id

    packet_data['content'], subpacket_length_type, content_length = (
        _parse_content(type_id, transmission_binary[HEADER_BITS:])
    )

    packet_data['subpacket_length_type'] = subpacket_length_type

    # Operators spend 1 bit on the length type ID plus either an 11-bit
    # sub-packet count (12 bits in total) or a 15-bit length (16 bits in total);
    # literal values have no such field.
    subpacket_length_info_bits = (
        12 * (type_id != PacketType.LITERAL)
        + 4 * (subpacket_length_type == SubpacketLengthType.TOTAL_LENGTH)
    )
    packet_data['binary_length'] = (
        HEADER_BITS
        + subpacket_length_info_bits
        + content_length
    )

    return Packet(**packet_data)


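def _hexadecimal_to_padded_binary(hexadecimal: str) -> str:
    # Every hex digit is exactly four bits, so pad to 4 * (number of digits);
    # int() alone drops leading zero digits. Strip any trailing newline first.
    hexadecimal = hexadecimal.strip()
    return f'{int(hexadecimal, base=16):0{4 * len(hexadecimal)}b}'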


def _parse_content(
    type_id: PacketType,
    content: str
) -> Tuple[Union[int, List['Packet']], Optional[SubpacketLengthType], int]:
    if type_id == PacketType.LITERAL:
        return _parse_literal(content)
    else:
        return _parse_subpackets(content)


def _parse_literal(literal: str) -> Tuple[int, Literal[None], int]:
    # Each 5-bit group is a continuation flag followed by a 4-bit nibble;
    # the first complete group whose flag is 0 ends the literal.
    groups = list(sliced(literal, ENCODED_NIBBLE_BITS))
    last_index = next(
        (
            index for index, group in enumerate(groups)
            if len(group) == ENCODED_NIBBLE_BITS and group[0] == '0'
        ),
        None,
    )
    if last_index is None:
        raise ValueError('Literal value has no terminating group.')

    groups = groups[:last_index + 1]
    value_binary = ''.join(group[1:] for group in groups)
    return int(value_binary, base=2), None, len(groups) * ENCODED_NIBBLE_BITS


def _parse_subpackets(content: str) -> Tuple[List['Packet'], SubpacketLengthType, int]:
    raw_subpackets, subpacket_length_type, expected_length, expected_count = (
        _raw_subpackets_and_length_constraints(content)
    )

    subpackets = []
    subpacket_start_index = 0
    while (
            subpacket_start_index < expected_length
            and len(subpackets) < expected_count
    ):
        subpacket = decode_bits_transmission(
            raw_subpackets[subpacket_start_index:],
            hexadecimal=False
        )
        subpackets.append(subpacket)
        subpacket_start_index += subpacket.binary_length

    has_total_length_constraint = (
        subpacket_length_type == SubpacketLengthType.TOTAL_LENGTH
    )

    if has_total_length_constraint and subpacket_start_index != expected_length:
        raise ValueError('Actual and alleged transmission lengths differ.')
    if not has_total_length_constraint and len(subpackets) != expected_count:
        raise ValueError('Actual and alleged subpacket counts differ.')

    return subpackets, subpacket_length_type, subpacket_start_index


def _raw_subpackets_and_length_constraints(
    subpackets_content: str
) -> Tuple[str, SubpacketLengthType, int, int]:
    # Start from loose upper bounds (all remaining bits; at most one sub-packet
    # per PACKET_MIN_BITS bits), then tighten whichever one the packet declares.
    expected_length = len(subpackets_content)
    expected_count = len(subpackets_content) // PACKET_MIN_BITS

    subpacket_length_type = SubpacketLengthType(int(subpackets_content[0], base=2))

    if subpacket_length_type == SubpacketLengthType.TOTAL_LENGTH:
        length_bits = 15
    else:
        length_bits = 11

    declared_limit = int(subpackets_content[1:length_bits + 1], base=2)
    if subpacket_length_type == SubpacketLengthType.TOTAL_LENGTH:
        expected_length = declared_limit
    else:
        expected_count = declared_limit

    raw_subpackets = subpackets_content[length_bits + 1:]

    return raw_subpackets, subpacket_length_type, expected_length, expected_count

### Testing

#### Literal Packet Example

In [4]:
actual_packet = decode_bits_transmission('D2FE28')  # 110100101111111000101000
expected_packet = Packet(
    version=6,
    type_id=4,
    binary_length=len('110100101111111000101'),
    content=2021
)

assert actual_packet == expected_packet

#### Total Subpacket Length Example

In [5]:
actual_packet = decode_bits_transmission('38006F45291200')

expected_packet = Packet(
    version=1,
    type_id=6,
    subpacket_length_type=SubpacketLengthType.TOTAL_LENGTH,
    binary_length=len('0011100000000000011011110100010100101001000100100'),
    content=[
        Packet(
            version=6,
            type_id=PacketType.LITERAL,
            binary_length=len('11010001010'),
            content=10
        ),
        Packet(
            version=2,
            type_id=PacketType.LITERAL,
            binary_length=len('0101001000100100'),
            content=20
        )
    ]
)

assert actual_packet == expected_packet

#### Subpacket Count Example

In [6]:
actual_packet = decode_bits_transmission('EE00D40C823060')

expected_packet = Packet(
    version=7,
    type_id=3,
    subpacket_length_type=1,
    binary_length=len('111011100000000011010100000011001000001000110000011'),
    content=[
        Packet(
            version=2,
            type_id=4,
            binary_length=len('01010000001'),
            content=1
        ),
        Packet(
            version=4,
            type_id=4,
            binary_length=len('10010000010'),
            content=2
        ),
        Packet(
            version=1,
            type_id=4,
            binary_length=len('00110000011'),
            content=3
        )
    ]
)

assert actual_packet == expected_packet

### Answer

In [7]:
packet = decode_bits_transmission(puzzle_input.as_text(day=16))

print(
    'Sum of all packet and subpacket version numbers:',
    sum(subpacket.version for subpacket in packet.walk_depth_first())
)

Sum of all packet and subpacket version numbers: 945
