In [1]:
# Upper-case hex digit -> its 4-bit binary spelling, e.g. 'A' -> '1010'.
# Built from range(16) so the table cannot contain a typo; keys are inserted
# in digit order '0'..'9', 'A'..'F'.
HEX_EXPANSION = {format(digit, 'X'): format(digit, '04b') for digit in range(16)}

In [128]:
def read_input(path="input.txt"):
    """Return the hex transmission stored in ``path``.

    Surrounding whitespace is stripped: a trailing newline (common in saved
    input files) is not a hex digit, so removing it keeps the returned string
    a pure hex payload.
    """
    with open(path, "rt") as f:
        return f.read().strip()

In [37]:

from dataclasses import dataclass
from typing import List

In [120]:
@dataclass
class Packet:
	"""Header fields shared by every decoded packet."""
	# Value of the first 3 header bits.
	version: int
	# Value of the next 3 header bits; 4 is parsed as a literal, anything else as an operator.
	type_id: int

@dataclass
class LiteralValue(Packet):
	"""A packet that carries a single integer payload."""
	# Integer assembled from the packet's 5-bit groups.
	value: int

@dataclass
class OperatorValue(Packet):
	"""A packet whose value is derived from its nested sub-packets."""
	# Direct children, in the order they were parsed.
	sub_packets: List[Packet]

def to_int(chars):
	"""Interpret binary digits as an int.

	``chars`` may be a string such as '1011' or a list of single '0'/'1'
	characters; lists are joined into a string before conversion.
	"""
	bits = ''.join(chars) if isinstance(chars, list) else chars
	return int(bits, 2)

def peek(l, n=1):
	"""Return the first ``n`` items of ``l`` without removing them."""
	return l[slice(n)]

def consume(l, n=1):
	"""Remove and return the first ``n`` items of list ``l``.

	``l`` is mutated in place, so successive calls walk forward through a
	shared bit buffer.

	Raises:
		ValueError: if ``n`` is negative.
		IndexError: if ``l`` holds fewer than ``n`` items; ``l`` is left unchanged.
	"""
	if n < 0:
		raise ValueError(f"cannot consume a negative number of items ({n})")
	if n > len(l):
		# Fail before mutating so a truncated buffer is not half-drained.
		raise IndexError(f"cannot consume {n} items from a buffer of {len(l)}")
	result = l[:n]
	# A single slice deletion shifts the tail once, instead of n pop(0) shifts.
	del l[:n]
	return result

def expand_hex(chars) -> List[str]:
	"""Expand hex digits into a flat list of '0'/'1' characters, four per digit.

	Digits are matched case-insensitively. Whitespace (for example the
	trailing newline of an input file) is skipped rather than rejected.

	Raises:
		KeyError: if a non-whitespace character is not a hex digit.
	"""
	bits = []
	for c in chars:
		if c.isspace():
			continue
		bits.extend(HEX_EXPANSION[c.upper()])
	return bits

def parse_literal(version, type_id, buf: List[str]) -> LiteralValue:
	"""Parse the body of a literal packet from the front of ``buf``.

	The value is stored in 5-bit groups: one lead bit followed by four value
	bits. Groups are consumed until one whose lead bit is '0', and the
	collected value bits are read as a single binary number.
	"""
	value_bits = []
	keep_reading = True
	while keep_reading:
		group = consume(buf, 5)
		value_bits.extend(group[1:])
		keep_reading = group[0] != '0'
	return LiteralValue(version, type_id, to_int(value_bits))

def parse_total_length_operator(version, type_id, buf: List[str]) -> OperatorValue:
	"""Parse an operator whose sub-packets are bounded by a total bit length.

	Reads a 15-bit length field, removes exactly that many bits from ``buf``,
	and parses packets from those bits until none remain. ``buf`` is left
	positioned just after the sub-packets.
	"""
	total_length = to_int(consume(buf, 15))
	# consume returns a fresh list, so parsing the segment cannot run past the
	# declared boundary into bits that belong to the parent packet.
	segment = consume(buf, total_length)
	packets = []
	while segment:
		packets.append(parse_packet(segment))
	return OperatorValue(version, type_id, packets)

def parse_total_count_operator(version, type_id, buf: List[str]) -> OperatorValue:
	"""Parse an operator whose header gives the number of immediate sub-packets.

	Reads an 11-bit count, then parses that many packets one after another
	from the front of ``buf``.
	"""
	sub_packet_count = to_int(consume(buf, 11))
	sub_packets = [parse_packet(buf) for _ in range(sub_packet_count)]
	return OperatorValue(version, type_id, sub_packets)

def parse_operator(version, type_id, buf: List[str]) -> OperatorValue:
	"""Parse the body of an operator packet from the front of ``buf``.

	A single length-type bit selects the layout: 0 means a total bit length
	follows, 1 means a sub-packet count follows.
	"""
	length_type_id = to_int(consume(buf, 1))
	parse_body = parse_total_length_operator if length_type_id == 0 else parse_total_count_operator
	return parse_body(version, type_id, buf)


def parse_packet(buf: List[str]):
	"""Parse one packet from the front of ``buf``, consuming its bits.

	Every packet starts with a 3-bit version and a 3-bit type id. Type id 4
	is a literal value; any other type id is an operator.
	"""
	version = to_int(consume(buf, 3))
	type_id = to_int(consume(buf, 3))
	parse_body = parse_literal if type_id == 4 else parse_operator
	return parse_body(version, type_id, buf)

def sum_versions(packet: Packet):
	"""Return the sum of ``packet``'s version and every nested packet's version.

	Raises:
		ValueError: if a packet is neither a LiteralValue nor an OperatorValue.
	"""
	if isinstance(packet, OperatorValue):
		children = packet.sub_packets
	elif isinstance(packet, LiteralValue):
		children = []
	else:
		raise ValueError(f"Invalid packet {packet}")
	total = packet.version
	for child in children:
		total += sum_versions(child)
	return total

In [121]:
# Sample: a single literal packet.
assert parse_packet(expand_hex('D2FE28')) == LiteralValue(version=6, type_id=4, value=2021)

In [122]:
# Sample: an operator packet (length type 0) holding two literal sub-packets.
assert parse_packet(expand_hex('38006F45291200')) == OperatorValue(version=1, type_id=6, sub_packets=[LiteralValue(version=6, type_id=4, value=10), LiteralValue(version=2, type_id=4, value=20)])

In [123]:
# Sample: an operator packet (length type 1) holding three literal sub-packets.
assert parse_packet(expand_hex('EE00D40C823060')) == OperatorValue(version=7, type_id=3, sub_packets=[LiteralValue(version=2, type_id=4, value=1), LiteralValue(version=4, type_id=4, value=2), LiteralValue(version=1, type_id=4, value=3)])

In [124]:
# Version-sum sample; expected value is this cell's recorded output.
assert sum_versions(parse_packet(expand_hex('8A004A801A8002F478'))) == 16

In [125]:
# Version-sum sample; expected value is this cell's recorded output.
assert sum_versions(parse_packet(expand_hex('620080001611562C8802118E34'))) == 12

In [126]:
# Version-sum sample; expected value is this cell's recorded output.
assert sum_versions(parse_packet(expand_hex('C0015000016115A2E0802F182340'))) == 23

In [127]:
# Version-sum sample; expected value is this cell's recorded output.
assert sum_versions(parse_packet(expand_hex('A0016C880162017C3686B18A3D4780'))) == 31

In [130]:
# Part 1: total of every version number in the puzzle transmission.
sum_versions(
	parse_packet(expand_hex(read_input()))
)

904

In [135]:
def sum_packet(packet: OperatorValue) -> int:
	"""Type 0: add up the values of all sub-packets (0 when there are none)."""
	return sum(map(calculate_value, packet.sub_packets))

def product_packet(packet: OperatorValue) -> int:
	"""Type 1: multiply the values of all sub-packets (1 when there are none)."""
	product = 1
	for value in map(calculate_value, packet.sub_packets):
		product = product * value
	return product

def minimum_packet(packet: OperatorValue) -> int:
	"""Type 2: smallest sub-packet value; ``min`` raises ValueError if there are none."""
	return min(map(calculate_value, packet.sub_packets))

def maximum_packet(packet: OperatorValue) -> int:
	"""Type 3: largest sub-packet value; ``max`` raises ValueError if there are none."""
	return max(map(calculate_value, packet.sub_packets))

def _comparison_operands(packet: OperatorValue):
	"""Return the evaluated values of a comparison packet's two sub-packets.

	The comparison evaluators only use two operands, so any other sub-packet
	count means the packet is malformed. Raising here avoids silently
	ignoring extra operands or failing with a bare IndexError.

	Raises:
		ValueError: if ``packet`` does not have exactly two sub-packets.
	"""
	if len(packet.sub_packets) != 2:
		raise ValueError(
			f"Comparison packet needs exactly 2 sub-packets, got {len(packet.sub_packets)}: {packet}"
		)
	first, second = packet.sub_packets
	# Evaluate left to right, in the order the operands were parsed.
	return calculate_value(first), calculate_value(second)

def gt_packet(packet: OperatorValue) -> int:
	"""Type 5: 1 if the first sub-packet's value is greater than the second's, else 0."""
	left, right = _comparison_operands(packet)
	return int(left > right)

def lt_packet(packet: OperatorValue) -> int:
	"""Type 6: 1 if the first sub-packet's value is less than the second's, else 0."""
	left, right = _comparison_operands(packet)
	return int(left < right)

def eq_packet(packet: OperatorValue) -> int:
	"""Type 7: 1 if the two sub-packets have equal values, else 0."""
	left, right = _comparison_operands(packet)
	return int(left == right)

# Operator type id -> evaluator for OperatorValue packets. Type id 4 is
# absent on purpose: parse_packet builds a LiteralValue for it, and
# calculate_value returns a literal's value without consulting this table.
DISPATCH = dict([
	(0, sum_packet),      # sum
	(1, product_packet),  # product
	(2, minimum_packet),  # minimum
	(3, maximum_packet),  # maximum
	(5, gt_packet),       # greater than
	(6, lt_packet),       # less than
	(7, eq_packet),       # equal to
])

def calculate_value(packet: Packet) -> int:
	"""Evaluate ``packet`` recursively and return its integer value.

	Literals yield their stored value. Operators are evaluated by the
	DISPATCH entry for their type id.

	Raises:
		KeyError: if an operator's type id has no DISPATCH entry.
		ValueError: if ``packet`` is neither a LiteralValue nor an OperatorValue.
	"""
	if isinstance(packet, OperatorValue):
		evaluate = DISPATCH[packet.type_id]
		return evaluate(packet)
	if isinstance(packet, LiteralValue):
		return packet.value
	raise ValueError(f"Invalid packet {packet}")

In [136]:
# Sample: sum (type 0) of 1 and 2.
assert calculate_value(parse_packet(expand_hex('C200B40A82'))) == 3

In [137]:
# Sample: product (type 1) of 6 and 9.
assert calculate_value(parse_packet(expand_hex('04005AC33890'))) == 54

In [138]:
# Sample: minimum (type 2) of 7, 8 and 9.
assert calculate_value(parse_packet(expand_hex('880086C3E88112'))) == 7

In [139]:
# Sample: maximum (type 3) of 7, 8 and 9.
assert calculate_value(parse_packet(expand_hex('CE00C43D881120'))) == 9

In [140]:
# Sample: less-than (type 6); 5 < 15 is true.
assert calculate_value(parse_packet(expand_hex('D8005AC2A8F0'))) == 1

In [141]:
# Sample: greater-than (type 5); 5 > 15 is false.
assert calculate_value(parse_packet(expand_hex('F600BC2D8F'))) == 0

In [142]:
# Sample: equal-to (type 7); 5 == 15 is false.
assert calculate_value(parse_packet(expand_hex('9C005AC2F8F0'))) == 0

In [143]:
# Sample: equal-to (type 7) comparing 1 + 3 with 2 * 2.
assert calculate_value(parse_packet(expand_hex('9C0141080250320F1802104A08'))) == 1

In [144]:
# Part 2: evaluate the expression encoded by the puzzle transmission.
calculate_value(
	parse_packet(expand_hex(read_input()))
)

200476472872