# Day 16
## Part 1
I'm so ill that I'm using classes and inheritance. Really not thinking straight.

This is really the sort of problem that demands more robustness than it's worth implementing for a run-once solution.

In [1]:
from dataclasses import dataclass, field
from typing import List


@dataclass
class Packet:
    """Common header shared by every packet: a 3-bit version and a 3-bit type ID."""
    version: int
    type_id: int

    def version_sum(self):
        """Return the sum of version numbers in this packet (only its own, for a leaf)."""
        return self.version


@dataclass
class LiteralPacket(Packet):
    """A leaf packet (type ID 4) carrying a single literal integer."""
    value: int


@dataclass
class OperatorPacket(Packet):
    """A packet holding an ordered list of nested sub-packets."""
    subpackets: List[Packet] = field(default_factory=list)

    def version_sum(self):
        """Return this packet's version plus every version found beneath it."""
        total = self.version
        for child in self.subpackets:
            total += child.version_sum()
        return total

Define a bitstream class that tracks its own read position, rather than threading a pointer to the current character through every function. The gotcha is reading sub-packets up to a given total bit length, so the class also provides functions for counting how many bits have been consumed.

In [2]:
@dataclass
class Bitstream:
    """A read cursor over a string of '0'/'1' characters.

    Attributes are ``b`` (the bit string) and ``i`` (index of the next unread bit).
    """
    b: str
    i: int = 0

    def read(self, n):
        """Consume the next ``n`` bits and return them as an unsigned integer.

        ``read(0)`` consumes nothing and returns 0.

        Raises:
            ValueError: if ``n`` is negative.
            EOFError: if fewer than ``n`` bits remain. Without this check a
                slice past the end is silently truncated, so the caller would
                get a wrong value and the cursor would move past the end.
        """
        if n < 0:
            raise ValueError(f'cannot read a negative number of bits ({n})')
        end = self.i + n
        if end > len(self.b):
            raise EOFError(f'cannot read {n} bits at position {self.i}: '
                           f'only {len(self.b) - self.i} remain')
        # int('', 2) raises, so an empty read is special-cased to 0.
        x = int(self.b[self.i:end], 2) if n else 0
        self.i = end
        return x

    def count_start(self):
        """Return a marker for the current position, to pass to :meth:`count`."""
        return self.i

    def count(self, count_start):
        """Return how many bits have been consumed since ``count_start``."""
        return self.i - count_start


class PacketParser:
    """Decode a hexadecimal transmission into a tree of packets.

    The outermost packet is parsed on construction and exposed as
    ``self.packet``; ``self.b`` is the underlying bit cursor.
    """

    def __init__(self, hex_string):
        # Expand each hex digit to exactly four bits (zero-padded) so that
        # leading zeros inside each digit are preserved.
        self.b = Bitstream(''.join([f'{int(c, 16):04b}'
                                    for c in hex_string]))
        self.packet = self.read_packet()

    def read_packet(self):
        """Read one packet, and recursively its sub-packets, from the stream."""
        version = self.b.read(3)
        type_id = self.b.read(3)
        if type_id == 4:
            return LiteralPacket(version, type_id, self._read_literal_value())
        return OperatorPacket(version, type_id, self._read_subpackets())

    def _read_literal_value(self):
        """Read 5-bit groups (continue flag + 4 value bits) until a flag of 0."""
        value = 0
        while True:
            more = self.b.read(1)
            value = value * 16 + self.b.read(4)
            if not more:
                return value

    def _read_subpackets(self):
        """Read an operator's sub-packets according to its 1-bit length type.

        Raises:
            ValueError: if, for length type 0, the sub-packets do not end
                exactly on the declared bit length.
        """
        if self.b.read(1) == 0:
            # Length type 0: next 15 bits = total bit length of the sub-packets.
            length = self.b.read(15)
            start = self.b.count_start()
            subpackets = []
            while self.b.count(start) < length:
                subpackets.append(self.read_packet())
            # A malformed stream would otherwise overrun the budget silently.
            if self.b.count(start) != length:
                raise ValueError(f'sub-packets used {self.b.count(start)} bits, '
                                 f'but {length} were declared')
            return subpackets
        # Length type 1: next 11 bits = number of immediately contained sub-packets.
        n_subpackets = self.b.read(11)
        return [self.read_packet() for _ in range(n_subpackets)]
        
        
def parse_data(s):
    """Decode the hexadecimal transmission ``s`` and return its outermost packet."""
    parser = PacketParser(s)
    return parser.packet
    
# Decode a hex string that encodes a single literal packet.
parse_data('D2FE28')

LiteralPacket(version=6, type_id=4, value=2021)

In [3]:
# Operator packet with length type 0 (sub-packets bounded by a bit count),
# containing two literal sub-packets.
PacketParser('38006F45291200').packet

OperatorPacket(version=1, type_id=6, subpackets=[LiteralPacket(version=6, type_id=4, value=10), LiteralPacket(version=2, type_id=4, value=20)])

In [4]:
# Operator packet with length type 1 (sub-packets given by a count),
# containing three literal sub-packets.
PacketParser('EE00D40C823060').packet

OperatorPacket(version=7, type_id=3, subpackets=[LiteralPacket(version=2, type_id=4, value=1), LiteralPacket(version=4, type_id=4, value=2), LiteralPacket(version=1, type_id=4, value=3)])

In [5]:
def part_1(data):
    """Part 1 answer: the sum of version numbers over the whole packet tree."""
    total = data.version_sum()
    return total

# Nested operator example: versions 4, 1, 5 and 6 at successive depths.
p = parse_data('8A004A801A8002F478')
p

OperatorPacket(version=4, type_id=2, subpackets=[OperatorPacket(version=1, type_id=2, subpackets=[OperatorPacket(version=5, type_id=2, subpackets=[LiteralPacket(version=6, type_id=4, value=15)])])])

In [6]:
# 4 + 1 + 5 + 6 == 16
assert part_1(p) == 16

In [7]:
# Operator with two operator children, each holding two literals.
q = parse_data('620080001611562C8802118E34')
q

OperatorPacket(version=3, type_id=0, subpackets=[OperatorPacket(version=0, type_id=0, subpackets=[LiteralPacket(version=0, type_id=4, value=10), LiteralPacket(version=5, type_id=4, value=11)]), OperatorPacket(version=1, type_id=0, subpackets=[LiteralPacket(version=0, type_id=4, value=12), LiteralPacket(version=3, type_id=4, value=13)])])

In [8]:
# Same shape as q, but with different version numbers.
r = parse_data('C0015000016115A2E0802F182340')
r

OperatorPacket(version=6, type_id=0, subpackets=[OperatorPacket(version=0, type_id=0, subpackets=[LiteralPacket(version=0, type_id=4, value=10), LiteralPacket(version=6, type_id=4, value=11)]), OperatorPacket(version=4, type_id=0, subpackets=[LiteralPacket(version=7, type_id=4, value=12), LiteralPacket(version=0, type_id=4, value=13)])])

In [9]:
# 3 + (0 + 0 + 5) + (1 + 0 + 3) == 12
assert part_1(q) == 12

In [10]:
# 6 + (0 + 0 + 6) + (4 + 7 + 0) == 23
assert part_1(r) == 23

In [11]:
# Read the puzzle input inside a context manager so the file handle is closed.
with open('input', 'r') as f:
    data = parse_data(f.read().strip())

In [12]:
# Part 1 answer for the real input.
part_1(data)

891

## Part 2

If I'd known this was coming, I would probably have built a plain tree structure rather than the class hierarchy above. Never mind: just add a `calculation` method to each packet class.

In [13]:
import math


@dataclass
class Packet:
    """Common packet header: a 3-bit version and a 3-bit type ID."""
    version: int
    type_id: int

    def version_sum(self):
        """Return this packet's version (subclasses with children add theirs)."""
        return self.version

    def calculation(self):
        """Evaluate the packet; concrete packet kinds must override this."""
        # Returning None here would only surface later as a confusing
        # TypeError inside an operator, so fail loudly instead.
        raise NotImplementedError(
            f'{type(self).__name__} does not define calculation()')


@dataclass
class LiteralPacket(Packet):
    """A leaf packet (type ID 4) carrying a single literal integer."""
    value: int

    def calculation(self):
        """A literal evaluates to its own value."""
        return self.value


@dataclass
class OperatorPacket(Packet):
    """A packet whose value is computed from its sub-packets.

    ``type_id`` selects the operation: 0 sum, 1 product, 2 minimum,
    3 maximum, 5 greater-than, 6 less-than, 7 equal-to. Comparisons
    evaluate to 1 or 0 and require exactly two sub-packets.
    """
    subpackets: List[Packet] = field(default_factory=list)

    # Built once at class creation instead of on every calculation() call.
    # Unannotated, so the dataclass machinery does not treat these as fields.
    _OPERATIONS = {
        0: sum,
        1: math.prod,
        2: min,
        3: max,
        5: lambda xs: int(xs[0] > xs[1]),
        6: lambda xs: int(xs[0] < xs[1]),
        7: lambda xs: int(xs[0] == xs[1]),
    }
    _COMPARISONS = frozenset({5, 6, 7})

    def version_sum(self):
        """Return this packet's version plus every version found beneath it."""
        return self.version + sum(sp.version_sum() for sp in self.subpackets)

    def calculation(self):
        """Evaluate the sub-packets and combine them with this packet's operator.

        Raises:
            ValueError: for an unknown operator ``type_id``, or a comparison
                that does not have exactly two sub-packets.
        """
        try:
            op = self._OPERATIONS[self.type_id]
        except KeyError:
            raise ValueError(f'unknown operator type_id {self.type_id}') from None
        values = [sp.calculation() for sp in self.subpackets]
        # The comparison lambdas index xs[0] and xs[1]; any extra operands
        # would otherwise be ignored silently.
        if self.type_id in self._COMPARISONS and len(values) != 2:
            raise ValueError(f'comparison type_id {self.type_id} needs exactly '
                             f'2 sub-packets, got {len(values)}')
        return op(values)
    
def part_2(data):
    """Part 2 answer: the value the outermost packet evaluates to."""
    result = data.calculation()
    return result

In [14]:
# Sum operator (type ID 0) over two literals; evaluates to 3.
PacketParser('C200B40A82').packet.calculation()

3

In [15]:
# Equal-to operator (type ID 7) comparing two sub-expressions; evaluates to 1.
PacketParser('9C0141080250320F1802104A08').packet.calculation()

1

In [16]:
# Read the puzzle input inside a context manager so the file handle is closed.
with open('input', 'r') as f:
    data = parse_data(f.read().strip())
# Part 2 answer for the real input.
part_2(data)

673042777597