In [1]:
# Read the hex-encoded input and expand every hex digit to exactly 4 bits
# (leading zeros kept) so bit offsets line up with the packet layout.
with open('Day16.txt') as file:
    # strip() drops a trailing newline/whitespace; int('\n', 16) would raise ValueError.
    data = file.read().strip()
binary_data = ''.join(f'{int(digit, 16):04b}' for digit in data)

In [2]:
from __future__ import annotations

from dataclasses import dataclass, field
from functools import reduce
from operator import mul, gt, lt, eq

@dataclass
class Packet:
    version: int
    type_id: int
    data: int | list[Packet] = field(compare=False)
    cursor: int = field(repr=False)
    
    def sum_versions(self):
        result = self.version
        if self.type_id != 4:
            for child in self.data:
                result += child.sum_versions()
        return result
    
    def evaluate(self):
        if self.type_id == 4:
            return self.data
        
        evals = [p.evaluate() for p in self.data]
        match self.type_id:
            case 0:
                return sum(evals)
            case 1:
                return reduce(mul, evals)
            case 2:
                return min(evals)
            case 3:
                return max(evals)
            case 5:
                return int(gt(*evals))
            case 6:
                return int(lt(*evals))
            case 7:
                return int(eq(*evals))

In [3]:
def parse(data):
    """Parse the BITS packet that starts at the beginning of ``data``.

    Args:
        data: string of '0'/'1' characters. Bits after the end of the first
            packet (e.g. trailing padding) are ignored.

    Returns:
        Packet whose ``cursor`` is the number of bits the packet occupied, so a
        caller can advance past it. Operator packets are parsed recursively.

    Raises:
        ValueError: if a length-type-0 operator's sub-packets run past the
            declared bit length. Truncated input is not checked explicitly.
    """
    version = int(data[0:3], 2)
    type_id = int(data[3:6], 2)
    cursor = 6  # header = 3 version bits + 3 type-id bits
    if type_id == 4:
        value, cursor = _parse_literal(data, cursor)
        return Packet(version, type_id, value, cursor)
    children, cursor = _parse_operator_body(data, cursor)
    return Packet(version, type_id, children, cursor)


def _parse_literal(data, cursor):
    """Read literal-value 5-bit groups starting at ``cursor``; return ``(value, cursor_after)``."""
    nibbles = []
    more = True
    while more:
        # Each group is a continuation flag bit followed by 4 value bits.
        more = data[cursor] == '1'
        nibbles.append(data[cursor + 1:cursor + 5])
        cursor += 5
    return int(''.join(nibbles), 2), cursor


def _parse_operator_body(data, cursor):
    """Read an operator's length-type bit and sub-packets from ``cursor``; return ``(children, cursor_after)``."""
    children = []
    length_type = data[cursor]
    cursor += 1
    if length_type == '0':
        # Next 15 bits: total number of bits taken by all sub-packets.
        bit_length = int(data[cursor:cursor + 15], 2)
        cursor += 15
        end = cursor + bit_length
        while cursor < end:
            child = parse(data[cursor:])
            children.append(child)
            cursor += child.cursor
        if cursor != end:
            raise ValueError(
                f'sub-packets overran declared length: ended at bit {cursor}, expected {end}'
            )
    else:
        # Next 11 bits: number of immediate sub-packets.
        count = int(data[cursor:cursor + 11], 2)
        cursor += 11
        for _ in range(count):
            child = parse(data[cursor:])
            children.append(child)
            cursor += child.cursor
    return children, cursor

In [4]:
def sum_version(packet):
    """Return the sum of ``version`` over ``packet`` and all of its descendants.

    Functional counterpart of ``Packet.sum_versions``. A packet counts as a leaf
    when its ``data`` is not a list, rather than by checking ``type_id``.
    """
    # Use a distinct loop name; the original loop rebound the `packet` parameter.
    children = packet.data if isinstance(packet.data, list) else []
    return packet.version + sum(sum_version(child) for child in children)

In [5]:
%%time
# Decode the whole binary string into a tree of Packet objects (the root packet).
packet = parse(binary_data)

CPU times: user 719 µs, sys: 0 ns, total: 719 µs
Wall time: 728 µs


In [6]:
# Total of the version numbers of every packet in the tree.
packet.sum_versions()

957

In [7]:
# Value of the expression encoded by the packet tree.
packet.evaluate()

744953223228