In [83]:
from dataclasses import dataclass
import numpy as np

def hex_to_bin(hexStr):
    """Expand a hexadecimal string into a string of "0"/"1" characters.

    Every hex digit becomes exactly four bits, so leading zero digits are
    preserved (``"0A"`` -> ``"00001010"``). Going through a single ``int``
    would silently drop them and shift every later bit offset.
    Surrounding whitespace (e.g. the trailing newline of an input file)
    is ignored.

    Raises:
        ValueError: if a character other than a hex digit is present.
    """
    digits = hexStr.strip()
    return "".join(format(int(digit, 16), "04b") for digit in digits)

# Load the hex transmission and expand it into one long bit string.
with open("input.txt", "r") as f:
    raw_hex = f.read()
packet = hex_to_bin(raw_hex)

class Packet:
    """Base class for a decoded packet.

    Every packet begins with a 6-bit header: a 3-bit version followed by a
    3-bit type ID. Subclasses add the size of their body to ``bit_length``.
    """

    version: int
    pckt_type: int
    bit_length: int  # bits consumed by this packet, header included

    @staticmethod
    def parse(bits):
        """Decode the packet that starts at the beginning of ``bits``.

        ``bits`` is a string of "0"/"1" characters. Bits after the end of
        the packet are not consumed; the returned packet's ``bit_length``
        tells the caller where the next packet starts.

        Raises:
            ValueError: if ``bits`` is too short to hold a packet header.
        """
        if len(bits) < 6:
            raise ValueError(f"need at least 6 bits for a packet header, got {len(bits)}")
        version = int(bits[:3], 2)
        pckt_type = int(bits[3:6], 2)
        return Packet._create(version, pckt_type, bits)

    @staticmethod
    def _create(version: int, pckt_type: int, bits):
        """Build the subclass for ``pckt_type``: 4 -> Literal, anything else -> Operator."""
        if pckt_type == 4:
            return Literal(version, bits)
        return Operator(version, pckt_type, bits)

    def __init__(self, version, pckt_type):
        self.version = version
        self.pckt_type = pckt_type
        self.bit_length = 6  # version (3 bits) + packet type (3 bits)
        
class Literal(Packet):
    """Packet of type 4 that carries a single integer value.

    The body is a run of 5-bit groups. Each group holds one leading
    continuation bit and four value bits. The first group whose leading
    bit is "0" is the last one.
    """

    def __init__(self, version, bits):
        super().__init__(version, 4)
        n_blocks, self._value = self._parse_bits(bits)
        # Every group read (5 bits each) belongs to this packet.
        self.bit_length += n_blocks * 5

    def _parse_bits(self, bits):
        """Decode the literal body that follows the header.

        Returns:
            ``(number_of_groups_read, value)``.

        Raises:
            ValueError: if the available bits run out before a group whose
            leading bit is "0" is found.
        """
        body = bits[self.bit_length:]
        nibbles = []
        # Only visit groups that are fully present (start + 5 <= len(body)).
        for start in range(0, len(body) - 4, 5):
            nibbles.append(body[start + 1:start + 5])
            if body[start] == "0":
                return len(nibbles), int("".join(nibbles), 2)
        raise ValueError("literal packet has no terminating group")

    def sum_versions(self):
        return self.version

    def value(self):
        return self._value
    
class Operator(Packet):
    """Packet of any type other than 4 that combines the values of its subpackets.

    After the header comes a length-type bit:
    "0" -> the next 15 bits give the total bit length of the subpackets;
    "1" -> the next 11 bits give the number of subpackets.
    """

    def __init__(self, version, type, bits):
        # The parameter name `type` shadows the builtin here; it is kept so
        # existing callers keep working.
        super().__init__(version, type)
        self._parse_length(bits)
        self._parse_subpackets(bits)

    def _parse_length(self, bits):
        """Read the length-type bit and the length field that follows it."""
        self.length_type = bits[self.bit_length]
        self.bit_length += 1
        bits_to_read = 15 if self.length_type == "0" else 11
        self.length = int(bits[self.bit_length:self.bit_length + bits_to_read], 2)
        self.bit_length += bits_to_read

    def _parse_subpackets(self, bits):
        """Parse subpackets until the declared bit length or count is reached.

        The limit is checked before each parse, so a declared length of 0
        yields no subpackets instead of reading one anyway.
        """
        self.subpackets = []
        by_bit_count = self.length_type == "0"
        consumed = 0
        while (consumed if by_bit_count else len(self.subpackets)) < self.length:
            subpacket = Packet.parse(bits[self.bit_length:])
            self.subpackets.append(subpacket)
            self.bit_length += subpacket.bit_length
            consumed += subpacket.bit_length

    def sum_versions(self):
        """Return this packet's version plus the versions of all nested packets."""
        return self.version + sum(p.sum_versions() for p in self.subpackets)

    def value(self):
        """Evaluate this operator.

        Type IDs: 0 sum, 1 product, 2 minimum, 3 maximum,
        5 greater-than, 6 less-than, 7 equal-to. The comparisons use the
        first two subpackets and return 1 or 0.

        Raises:
            ValueError: for a type ID that has no defined operation.
        """
        kind = self.pckt_type
        if kind in (5, 6, 7):
            left = self.subpackets[0].value()
            right = self.subpackets[1].value()
            if kind == 5:
                return 1 if left > right else 0
            if kind == 6:
                return 1 if left < right else 0
            return 1 if left == right else 0

        values = [p.value() for p in self.subpackets]
        if kind == 0:
            return sum(values)
        if kind == 1:
            product = 1
            for v in values:
                product *= v
            return product
        if kind == 2:
            return min(values, default=np.inf)
        if kind == 3:
            return max(values, default=-np.inf)
        raise ValueError(f"operator packet has unsupported type {kind}")
                
# Decode the whole transmission once; both answers walk the same packet tree.
op = Packet.parse(packet)

# Part 1: sum of every version number in the hierarchy.
answer_1 = op.sum_versions()
print(f"answer_1: {answer_1}")

# Part 2: evaluate the expression encoded by the outermost packet.
print(f"answer_2: {op.value()}")


answer_1: 957
answer_2: 744953223228
