# Day 16: Packet Decoder

In [1]:
from dataclasses import dataclass
from math import prod
from pathlib import Path
from typing import Iterable, Iterator

from more_itertools import collapse, take

from aoc2021.util import bin2dec, hex2bin, read_as_str

## Puzzle input data

In [2]:
# Worked example from the puzzle statement (its version sum, 16, is checked in Part 1).
tdata = '8A004A801A8002F478'

# Personal puzzle input: the whole transmission as one hexadecimal string.
input_path = Path('./day16-input.txt')
data = read_as_str(input_path)
data

'005473C9244483004B001F79A9CE75FF9065446725685F1223600542661B7A9F4D001428C01D8C30C61210021F0663043A20042616C75868800BAC9CB59F4BC3A40232680220008542D89B114401886F1EA2DCF16CFE3BE6281060104B00C9994B83C13200AD3C0169B85FA7D3BE0A91356004824A32E6C94803A1D005E6701B2B49D76A1257EC7310C2015E7C0151006E0843F8D000086C4284910A47518CF7DD04380553C2F2D4BFEE67350DE2C9331FEFAFAD24CB282004F328C73F4E8B49C34AF094802B2B004E76762F9D9D8BA500653EEA4016CD802126B72D8F004C5F9975200C924B5065C00686467E58919F960C017F00466BB3B6B4B135D9DB5A5A93C2210050B32A9400A9497D524BEA660084EEA8EF600849E21EFB7C9F07E5C34C014C009067794BCC527794BCC424F12A67DCBC905C01B97BF8DE5ED9F7C865A4051F50024F9B9EAFA93ECE1A49A2C2E20128E4CA30037100042612C6F8B600084C1C8850BC400B8DAA01547197D6370BC8422C4A72051291E2A0803B0E2094D4BB5FDBEF6A0094F3CCC9A0002FD38E1350E7500C01A1006E3CC24884200C46389312C401F8551C63D4CC9D08035293FD6FCAFF1468B0056780A45D0C01498FBED0039925B82CCDCA7F4E20021A692CC012B00440010B8691761E0002190E21244C98EE0B0C0139297660B401A80002150E20A

## Puzzle answers
### Part 1

In [3]:
# Type of the raw puzzle input: the transmission as a hexadecimal string.
Input = str


@dataclass(frozen=True)
class Packet:
    """Header fields shared by every BITS packet.

    Subclasses (`Literal`, `Operator`) append their own fields after these
    two, so the order here also fixes the positional order of the subclass
    constructors.
    """
    version: int  # 3-bit version read first from the packet header
    typeid: int   # 3-bit type ID; 4 marks a literal, any other value an operator


@dataclass(frozen=True)
class Literal(Packet):
    """A literal-value packet (type ID 4) carrying a single integer."""
    value: int  # integer decoded from the packet's 4-bit value groups
    nbits: int  # total bits this packet occupies, 6 header bits included


@dataclass(frozen=True)
class Operator(Packet):
    length_typeid: str
    packets: list[Packet]

    @property
    def nbits(self):
        n = 7 + (15 if self.length_typeid == '0' else 11)
        return n + sum(p.nbits for p in self.packets)

    @property
    def value(self):
        match self.typeid:
            case 0: return sum(p.value for p in self.packets)
            case 1: return prod(p.value for p in self.packets)
            case 2: return min(p.value for p in self.packets)
            case 3: return max(p.value for p in self.packets)
            case 5: return 1 if self.packets[0].value > self.packets[1].value else 0
            case 6: return 1 if self.packets[0].value < self.packets[1].value else 0
            case 7: return 1 if self.packets[0].value == self.packets[1].value else 0
            case _: raise Exception(f'invalid typeid {self.typeid}')


def _read_uint(bs: Iterator[str], nbits: int) -> int:
    """Consume the next `nbits` bit characters from `bs` and decode them with `bin2dec`."""
    return bin2dec(''.join(take(nbits, bs)))


def parse_version(bs: Iterator[str]) -> int:
    """Consume and return the 3-bit packet version at the front of `bs`.

    `bs` must be an iterator shared across calls: each parse step consumes
    bits from it, and a re-iterable (e.g. a str) would restart at bit 0.
    """
    return _read_uint(bs, 3)


def parse_typeid(bs: Iterator[str]) -> int:
    """Consume and return the 3-bit packet type ID at the front of `bs`."""
    return _read_uint(bs, 3)


def parse_literal(bs: Iterator[str]) -> tuple[int, int]:
    """Consume a literal packet's value groups from `bs`.

    Each 5-bit group is a 1-bit prefix followed by 4 value bits; a prefix of
    '0' marks the final group.

    Returns:
        (value, nbits): the decoded integer and the number of bits consumed
        by the value groups (the 6-bit header is not included; the caller has
        already read it).

    Raises:
        ValueError: if `bs` runs out before the final group is complete.
    """
    value_bits = []
    nbits = 0
    while True:
        prefix = next(bs, None)
        group = take(4, bs)
        # Fail loudly instead of leaking StopIteration or decoding a short group.
        if prefix is None or len(group) != 4:
            raise ValueError('truncated literal value in bit stream')
        value_bits.extend(group)
        nbits += 5
        if prefix == '0':
            return bin2dec(''.join(value_bits)), nbits


def parse_packet(bs: Iterator[str]) -> Packet:
    """Consume one packet, and recursively all of its sub-packets, from `bs`.

    Returns a `Literal` for type ID 4, otherwise an `Operator`.

    Raises:
        ValueError: if a length-type-0 operator's sub-packets occupy more bits
            than its 15-bit length field declares.
    """
    ver = parse_version(bs)
    tid = parse_typeid(bs)
    if tid == 4:
        val, n = parse_literal(bs)
        # +6 accounts for the version and type ID bits read above.
        return Literal(version=ver, typeid=tid, value=val, nbits=n + 6)

    length_tid = next(bs)
    subpackets = []
    if length_tid == '0':
        # Next 15 bits: total bit length of all sub-packets.
        nbits_to_read = bin2dec(''.join(take(15, bs)))
        while nbits_to_read > 0:
            subpackets.append(parse_packet(bs))
            nbits_to_read -= subpackets[-1].nbits
        if nbits_to_read < 0:
            raise ValueError('sub-packets overran the declared bit length')
    else:
        # Next 11 bits: number of directly contained sub-packets.
        packets_to_read = bin2dec(''.join(take(11, bs)))
        subpackets.extend(parse_packet(bs) for _ in range(packets_to_read))
    return Operator(version=ver, typeid=tid, length_typeid=length_tid, packets=subpackets)


def parse_transmission(data: Input) -> Literal | Operator:
    """Decode a hexadecimal BITS transmission into its outermost packet.

    Bits remaining after the outermost packet has been parsed are not read.
    """
    bits = hex2bin(data, chunksize=4)
    bit_stream = iter(bits)
    return parse_packet(bit_stream)


def version_numbers(packet: Packet) -> list[int]:
    """List the version of `packet` and of every packet nested beneath it.

    Order is pre-order: the packet's own version first, then each
    sub-packet's versions in turn.
    """
    versions = [packet.version]
    if packet.typeid != 4:
        for child in packet.packets:
            versions.extend(version_numbers(child))
    return versions


# Hex-to-binary examples: every hex digit expands to 4 bits, leading zeros kept.
for _hex, _bits in [
    ('D2FE28', '110100101111111000101000'),
    ('38006F45291200', '00111000000000000110111101000101001010010001001000000000'),
    ('EE00D40C823060', '11101110000000001101010000001100100000100011000001100000'),
]:
    assert hex2bin(_hex, chunksize=4) == _bits

# Version-number sums: (transmission, expected sum over all packets).
for _hex, _expected in [
    (tdata, 16),
    ('38006F45291200', 1+6+2),
    ('EE00D40C823060', 7+2+4+1),
    ('620080001611562C8802118E34', 12),
    ('C0015000016115A2E0802F182340', 23),
    ('A0016C880162017C3686B18A3D4780', 31),
]:
    assert sum(version_numbers(parse_transmission(_hex))) == _expected

In [4]:
# Part 1: sum the version field over the outermost packet and every nested packet.
transmission = parse_transmission(data)
n = sum(version_numbers(transmission))
print(f'Adding up the version numbers in all packets of the transmission: {n}')

Adding up the version numbers in all packets of the transmission: 847


### Part 2

In [5]:
# Expression-evaluation examples: (transmission, expected value of the outermost packet).
for _hex, _expected in [
    ('C200B40A82', 3),
    ('04005AC33890', 54),
    ('880086C3E88112', 7),
    ('CE00C43D881120', 9),
    ('D8005AC2A8F0', 1),
    ('F600BC2D8F', 0),
    ('9C005AC2F8F0', 0),
    ('9C0141080250320F1802104A08', 1),
]:
    assert parse_transmission(_hex).value == _expected

In [6]:
# Part 2: evaluate the expression tree rooted at the outermost packet.
transmission = parse_transmission(data)
n = transmission.value
print(f'Evaluating the hexadecimal-encoded BITS transmission expression: {n}')

Evaluating the hexadecimal-encoded BITS transmission expression: 333794664059
