In [48]:
from typing import Sequence, Optional, List, Tuple
from dataclasses import dataclass
from functools import reduce

def hex_char_to_bin(hex_char:str)->str:
    """Convert a single hex character to a 4-bit binary string.

    Args:
        hex_char: exactly one hexadecimal digit (0-9, a-f, A-F).

    Returns:
        A four-character string of '0'/'1', most significant bit first.

    Raises:
        ValueError: if ``hex_char`` is not exactly one hexadecimal digit.
    """
    # int(..., 16) on its own also accepts signs, surrounding whitespace and
    # "0x" prefixes; e.g. "-1" used to come out as the garbage string "00b1".
    # Validate explicitly so the result is always exactly four binary digits.
    if len(hex_char) != 1 or hex_char not in "0123456789abcdefABCDEF":
        raise ValueError(f"Expected a single hex digit, got {hex_char!r}")
    return format(int(hex_char, 16), "04b")

In [49]:
def hex_to_bin(hex_str:str)->str:
    """Expand a hex string into binary by converting each character in turn.

    Every character of ``hex_str`` is passed to ``hex_char_to_bin`` and the
    resulting pieces are concatenated in order. An empty string yields "".
    """
    nibbles = map(hex_char_to_bin, hex_str)
    return "".join(nibbles)

# Sanity check: six hex digits expand to 24 bits, four per digit.
assert hex_to_bin("D2FE28") == "110100101111111000101000"

In [57]:
@dataclass
class Packet:
    """One parsed packet: header fields plus a literal value or child packets."""
    # Header fields of the packet.
    version: int
    type_id: int
    # Child packets; left as the empty tuple when none are supplied.
    subpackets: Sequence['Packet'] = ()
    # Literal payload; None when no value is supplied.
    value: Optional[int] = None

    def sum_of_versions(self) -> int:
        """Return this packet's version plus the versions of all descendants."""
        total = self.version
        for child in self.subpackets:
            total += child.sum_of_versions()
        return total
@dataclass
class Bitstream:
    """Sequential reader over a string of bits, tracking a read position."""
    # The characters to read from.
    bits: str
    # Offset into ``bits`` of the next unread character.
    index:int=0
    
    def read(self, num_bits:int)->str:
        """Return the next ``num_bits`` characters and advance past them.

        Args:
            num_bits: how many characters to consume; must be >= 0.

        Returns:
            The consumed slice of ``bits`` (empty when ``num_bits`` is 0).

        Raises:
            ValueError: if ``num_bits`` is negative or exceeds what remains.
                In both cases ``index`` is left unchanged.
        """
        # A negative count would slip past the "enough left" check below,
        # return a wrong slice and move the cursor backwards.
        if num_bits < 0:
            raise ValueError(f"Cannot read a negative number of bits: {num_bits}")
        if num_bits > len(self.bits) - self.index:
            raise ValueError(f"Not enough bits left to read {num_bits}")
        start = self.index
        self.index = start + num_bits
        return self.bits[start:self.index]

    
def _parse(bitstream: Bitstream)->Packet:
    """
    Parse a single packet from a bitstream,
    consuming the bits that make it up.

    Layout as read here: 3 bits of version, 3 bits of type id, then either a
    literal payload (type id 4) or an operator payload (any other type id)
    whose subpackets are parsed recursively.

    Args:
        bitstream: reader positioned at the first bit of the packet; on
            return it is positioned just past the packet.

    Returns:
        The parsed Packet. Literals carry ``value``; operators carry
        ``subpackets``.

    Raises:
        ValueError: if the stream runs out of bits, a numeric field is not
            valid binary, the length type id is neither "0" nor "1", or
            subpackets declared by total bit length do not end exactly on
            the declared boundary.
    """
    # First three bits are the version, the next three the type id.
    version = int(bitstream.read(3), 2)
    type_id = int(bitstream.read(3), 2)

    if type_id == 4:
        # Literal: 5-bit groups, each a continuation flag followed by four
        # value bits. A flag of '1' means another group follows.
        digits = []
        while bitstream.read(1) == '1':
            digits.append(bitstream.read(4))
        # The group whose flag was not '1' is the last one; its four bits
        # are still part of the value.
        digits.append(bitstream.read(4))
        return Packet(version, type_id, value=int("".join(digits), 2))

    # Operator: one bit selects how the extent of the subpackets is encoded.
    length_type_id = bitstream.read(1)
    subpackets = []
    if length_type_id == "0":
        # Next 15 bits give the total length, in bits, of all subpackets.
        length = int(bitstream.read(15), 2)
        end = bitstream.index + length
        # Loop on the cursor so that a declared length of 0 yields no
        # subpackets instead of consuming whatever follows.
        while bitstream.index < end:
            subpackets.append(_parse(bitstream))
        if bitstream.index != end:
            raise ValueError(
                f"Subpackets overran declared length of {length} bits "
                f"by {bitstream.index - end}"
            )
    elif length_type_id == "1":
        # Next 11 bits give the number of immediate subpackets.
        num_subpackets = int(bitstream.read(11), 2)
        for _ in range(num_subpackets):
            subpackets.append(_parse(bitstream))
    else:
        raise ValueError(f"Unknown length type id:{length_type_id}")
    return Packet(version, type_id, subpackets=subpackets)

            
def parse(raw:str, hex:bool=True)->Packet:
    """Parse the packet at the start of a transmission.

    Args:
        raw: the transmission text.
        hex: when truthy (the default) ``raw`` is treated as hexadecimal and
            expanded with ``hex_to_bin`` first; otherwise ``raw`` is used as
            the bit string directly.

    Returns:
        The Packet produced by ``_parse`` from the start of the bit string.
    """
    bit_string = hex_to_bin(raw) if hex else raw
    return _parse(Bitstream(bit_string))

def add_up_all_version_numbers(hex_string:str)->int:
    """Parse a hex transmission and return the sum of every packet version in it."""
    return parse(hex_string).sum_of_versions()

# Regression checks: each hex transmission below must yield the listed sum
# of version numbers across its whole packet tree.
assert add_up_all_version_numbers('8A004A801A8002F478') == 16
assert add_up_all_version_numbers('620080001611562C8802118E34') == 12
assert add_up_all_version_numbers('C0015000016115A2E0802F182340') == 23
assert add_up_all_version_numbers('A0016C880162017C3686B18A3D4780') == 31

In [40]:
# Scratch check of Bitstream.read: 'D2FE28' expands to bits starting 1101...,
# so the first single-bit read returns '1'.
bits = Bitstream(hex_to_bin('D2FE28'))
bits.read(1)

'1'

In [41]:
# Second single-bit read from the same stream (the cursor has advanced by one).
bits.read(1)

'1'

In [42]:
# Third single-bit read from the same stream.
bits.read(1)

'0'