In [1]:
input_filename = "input.txt"

# Load the hex-encoded transmission; surrounding whitespace (e.g. the
# trailing newline) is stripped before parsing.
with open(input_filename, mode="r") as input_file:
    raw_input_text = input_file.read()
hex_transmission = raw_input_text.strip()

# Parsing logic: decode the hex transmission into a tree of packets

In [2]:
from typing import List, Optional, Tuple


class Packet:
    """
    A decoded BITS packet.

    Literal packets (``type_id == 4``) carry ``literal_value``; every other
    type is an operator packet whose children are stored in ``subpackets``.
    """

    def __init__(self, version: int, type_id: int,
                 *,
                 literal_value: Optional[int] = None,
                 subpackets: Optional[List["Packet"]] = None):
        """
        Args:
            version: Packet version number.
            type_id: Packet type; 4 means literal, anything else is an operator.
            literal_value: Decoded value, only meaningful for literal packets.
            subpackets: Child packets, only meaningful for operator packets.
        """
        self.version = version
        self.type_id = type_id

        self.literal_value = literal_value
        self.subpackets = subpackets

    def __repr__(self) -> str:
        if self.type_id == 4:
            remaining_info = f"literal_value={self.literal_value}"
        else:
            # Tolerate operator packets constructed without a subpacket list
            # instead of failing with len(None).
            num_subpackets = len(self.subpackets) if self.subpackets is not None else 0
            remaining_info = f"num_subpackets={num_subpackets}"

        return f"Packet(V={self.version}, T={self.type_id}, {remaining_info})"


class Transmission:
    """
    Decoder for a hex-encoded BITS transmission.

    Hex characters are converted to bits lazily into the integer buffer
    `curr_bits`. An int cannot represent leading zeros, so `pad` counts the
    zero bits that conceptually sit in front of `curr_bits`; the number of
    buffered bits is always `curr_bits.bit_length() + pad`.

    The outermost packet is parsed on construction and stored in
    `self.packet`; `self.version_sum` accumulates every packet's version.
    """

    def __init__(self, hex_msg: str):
        self.hex_msg = hex_msg

        self.curr_bits = 0
        self.pad = 0  # Track how many 0s are padding the front

        self.idx = 0  # Track which hex character we're looking at
        self.num_bits_consumed = 0

        # Stores answer to part 1
        self.version_sum = 0

        self.already_parsed = False
        self.packet = self.parse()

    def parse(self) -> Packet:
        """Parse and return the outermost packet; may only be called once."""
        if self.already_parsed:
            raise Exception("Already parsed!")
        self.already_parsed = True
        return self.parse_packet()

    def parse_packet(self) -> Packet:
        """
        Parse one packet, recursing into its subpackets, starting at the
        current bit position. Adds the packet's version to `version_sum`.

        Raises:
            Exception: if every hex character has already been loaded.
        """
        if self.idx >= len(self.hex_msg):
            raise Exception("Packet already fully parsed!")

        version = self.get_next_bits(3)
        self.version_sum += version

        type_id = self.get_next_bits(3)

        if type_id == 4:
            literal_value = self.get_literal_value()
            return Packet(version, type_id, literal_value=literal_value)

        # Operator packet: a 1-bit length type ID selects how the
        # subpackets are delimited (total bit length vs. packet count).
        length_type_id = self.get_next_bits(1)
        if length_type_id == 0:
            total_length_in_bits = self.get_next_bits(15)
            subpackets = self.parse_subpackets_by_length(total_length_in_bits)
        elif length_type_id == 1:
            num_subpackets = self.get_next_bits(11)
            subpackets = self.parse_subpackets_by_num_subpackets(num_subpackets)
        else:
            raise Exception("You made a mistake!")

        return Packet(version, type_id, subpackets=subpackets)

    def get_literal_value(self) -> int:
        """
        Read 5-bit groups until one whose leading bit is 0, concatenating the
        low 4 bits of every group into the returned value.
        """
        value = 0
        while True:
            bit_group = self.get_next_bits(5)

            # Take current value and append last 4 bits of new bit_group
            value = (value << 4) | (bit_group & 0b1111)

            # A leading 0 bit marks the last group
            if not bit_group >> 4:
                break

        return value

    def parse_subpackets_by_length(self, total_length_in_bits: int) -> List[Packet]:
        """
        Parse packets until we've consumed `total_length_in_bits` number of bits.
        """
        subpackets = []
        end = self.num_bits_consumed + total_length_in_bits
        while self.num_bits_consumed < end:
            subpackets.append(self.parse_packet())
        # The subpackets must end exactly on the declared boundary.
        assert self.num_bits_consumed == end
        return subpackets

    def parse_subpackets_by_num_subpackets(self, num_subpackets: int) -> List[Packet]:
        """
        Parse `num_subpackets` packets in sequence.
        """
        return [self.parse_packet() for _ in range(num_subpackets)]

    def get_next_bits(self, num_bits: int) -> int:
        """
        Gets the next `num_bits` bits of the transmission as an int, loading
        more hex characters into the buffer as needed.

        Raises:
            IndexError: if the hex message runs out before `num_bits` bits
                are available.
        """
        while (self.curr_bits.bit_length() + self.pad) < num_bits:
            if self.idx >= len(self.hex_msg):
                raise IndexError(
                    f"Transmission ended while reading {num_bits} bits "
                    f"after {self.num_bits_consumed} bits were consumed"
                )
            self.append_next(self.hex_msg[self.idx])
            self.idx += 1

        self.num_bits_consumed += num_bits
        return self.consume(num_bits)

    def append_next(self, hex_char: str) -> None:
        """
        Takes curr_bits and appends binary value of hex_char at the end.
        For example, if curr_bits were 1 and hex_char were 5 = 0101, the
        resulting bits would be 10101.
        """
        bits_to_append = int(hex_char, 16)
        if self.curr_bits == 0:
            # Buffer is all zeros so far, so the new nibble's leading zeros
            # are still leading zeros and must be counted as padding.
            self.pad += max(4 - bits_to_append.bit_length(), 0)

        self.curr_bits = (self.curr_bits << 4) + bits_to_append

    def consume(self, num_bits: int) -> int:
        """
        Removes the first num_bits from curr_bits (including any padded 0s).
        Returns the removed bits.
        """
        # Leading pad zeros are consumed first; they contribute nothing to
        # the returned value.
        num_padded_removed = min(self.pad, num_bits)
        self.pad -= num_padded_removed

        num_bits = num_bits - num_padded_removed

        if num_bits > 0:
            num_bits_to_keep = (self.curr_bits.bit_length() + self.pad) - num_bits

            removed_bits = self.curr_bits >> num_bits_to_keep
            self.curr_bits = self.curr_bits & ((1 << num_bits_to_keep) - 1)

            # All padding was consumed above; recompute it for the remainder,
            # whose own leading zeros are now invisible to bit_length().
            assert self.pad == 0
            self.pad = num_bits_to_keep - self.curr_bits.bit_length()

            return removed_bits

        return 0

# Part 1: sum of all packet version numbers

In [3]:
def get_version_sum(msg: str) -> int:
    """Decode the hex `msg` and return the sum of the versions of all its packets."""
    return Transmission(msg).version_sum

In [4]:
# Test cases: (hex message, expected version sum)
for example_msg, expected_version_sum in (
    ("8A004A801A8002F478", 16),
    ("620080001611562C8802118E34", 12),
    ("C0015000016115A2E0802F182340", 23),
    ("A0016C880162017C3686B18A3D4780", 31),
):
    assert get_version_sum(example_msg) == expected_version_sum

In [5]:
# Real case
part1_answer = get_version_sum(hex_transmission)
part1_answer

923

# Part 2: evaluate the expression encoded by the packet hierarchy

In [6]:
def operate(packet: "Packet") -> int:
    """
    Evaluate the expression encoded by `packet` (part 2).

    Literal packets (type 4) evaluate to their value. Operator packets first
    evaluate their subpackets recursively, then combine the results:

    - 0: sum, 1: product, 2: minimum, 3: maximum
    - 5: greater than, 6: less than, 7: equal to
      (comparisons take exactly two subpackets and return 1 or 0)

    Raises:
        ValueError: if `packet.type_id` is not one of the types above.
    """
    # The annotation is quoted so this cell does not require Packet to be
    # defined at function-definition time.
    if packet.type_id == 4:
        return packet.literal_value

    values = [operate(subpacket) for subpacket in packet.subpackets]

    if packet.type_id == 0:
        return sum(values)

    elif packet.type_id == 1:
        product = 1
        for value in values:
            product *= value
        return product

    elif packet.type_id == 2:
        return min(values)

    elif packet.type_id == 3:
        return max(values)

    elif packet.type_id == 5:
        assert len(values) == 2
        return 1 if values[0] > values[1] else 0

    elif packet.type_id == 6:
        assert len(values) == 2
        return 1 if values[0] < values[1] else 0

    elif packet.type_id == 7:
        assert len(values) == 2
        return 1 if values[0] == values[1] else 0

    else:
        raise ValueError(f"Unexpected type_id: {packet.type_id}")

In [7]:
test_cases = [
    # hex message, expected value
    ("D2FE28", 2021),
    ("C200B40A82", 3),
    ("04005AC33890", 54),
    ("880086C3E88112", 7),
    ("CE00C43D881120", 9),
    ("D8005AC2A8F0", 1),
    ("F600BC2D8F", 0),
    ("9C005AC2F8F0", 0),
    ("9C0141080250320F1802104A08", 1),
]

# Report every failing example (with what went wrong) instead of stopping at
# the first one. An explicit comparison is used rather than `assert` so the
# check still runs under `python -O`, and only Exception is caught so
# KeyboardInterrupt can still interrupt the cell.
for hex_msg, expected_value in test_cases:
    try:
        actual_value = operate(Transmission(hex_msg).packet)
    except Exception as exc:
        print(hex_msg, expected_value, f"raised {exc!r}")
        continue
    if actual_value != expected_value:
        print(hex_msg, expected_value, f"got {actual_value}")

In [8]:
# Real case
t = Transmission(hex_transmission)
root_packet = t.packet
operate(root_packet)

258888628940