In [1]:
# Imports & read file
import time
import heapq
import numpy as np
import math

def read_file(filename):
    with open(filename) as infile:
        inp = infile.readline().strip()
        return f"{int(inp, 16):0>{len(inp)*4}b}"
    return None

In [2]:
# Part One
class BitsReader:
    def __init__(self, bits):
        self.bits = bits
        self.current = 0
    
    def read_next_n(self, n):
        self.current += n
        return self.bits[self.current - n:self.current]

def sum_versions(br):
    bits_read = br.current
    version = int(br.read_next_n(3), 2)
    version_sum = version
    typeid = int(br.read_next_n(3), 2)
    if typeid == 4:
        #literal value
        lit = ''
        while True:
            group = br.read_next_n(5)
            lit += group[1:]
            if group[0] == '0':
                break
        lit = int(lit, 2)
    else:
        #operator
        if br.read_next_n(1) == '0':
            #length in bits
            length = int(br.read_next_n(15), 2)
            start = br.current
            while br.current - start < length:
                version_sum += sum_versions(br)
        else:
            #length in packets
            length = int(br.read_next_n(11), 2)
            for _ in range(length):
                version_sum += sum_versions(br)
    return version_sum

def solve_part_one(file):
    return sum_versions(BitsReader(read_file(file)))

In [3]:
# Test Part One
start = time.time()
for f, s in enumerate([16, 12, 23, 31], 1):
    print(solve_part_one(f"test{f:02}.txt") == s)
time.time() - start

True
True
True
True


0.0

In [4]:
# Solve Part One
start = time.time()
print(solve_part_one("input.txt"))
time.time() - start

993


0.0

In [5]:
# Part Two
def calculate_value(br):
    bits_read = br.current
    version = int(br.read_next_n(3), 2)
    typeid = int(br.read_next_n(3), 2)
    if typeid == 4:
        #literal value
        lit = ''
        while True:
            group = br.read_next_n(5)
            lit += group[1:]
            if group[0] == '0':
                break
        return int(lit, 2)
    else:
        #operator
        values = []
        if br.read_next_n(1) == '0':
            #length in bits
            length = int(br.read_next_n(15), 2)
            start = br.current
            while br.current - start < length:
                values.append(calculate_value(br))
        else:
            #length in packets
            length = int(br.read_next_n(11), 2)
            for _ in range(length):
                values.append(calculate_value(br))
        if typeid == 0:
            return sum(values)
        if typeid == 1:
            return math.prod(values)
        if typeid == 2:
            return min(values)
        if typeid == 3:
            return max(values)
        if typeid == 5:
            return 1 if values[0] > values[1] else 0
        if typeid == 6:
            return 1 if values[0] < values[1] else 0
        if typeid == 7:
            return 1 if values[0] == values[1] else 0
    return None

def solve_part_two(file):
    return calculate_value(BitsReader(read_file(file)))

In [6]:
# Test Part Two
start = time.time()
for f, s in enumerate([3, 54, 7, 9, 1, 0, 0, 1], 5):
    print(solve_part_two(f"test{f:02}.txt") == s)
time.time() - start

True
True
True
True
True
True
True
True


0.0

In [7]:
# Solve Part Two
start = time.time()
print(solve_part_two("input.txt"))
time.time() - start

144595909277


0.0