In [1]:
from __future__ import annotations

from dataclasses import dataclass, field

In [2]:
@dataclass
class Node:
    """A node of a binary expression tree.

    Inner nodes carry an operator (``'+'`` or ``'*'``) applied to ``left`` and
    ``right``; leaves carry a number, either as an ``int`` or a digit string.
    """

    value: str | int
    left: Node | None = None
    right: Node | None = None
    # Back-link to the enclosing node. Excluded from ``==``: comparing it would
    # walk child -> parent -> child between two linked trees and never stop.
    parent: Node | None = field(default=None, compare=False)

    def replace(self, child, replacement):
        """Swap the direct child ``child`` for ``replacement``.

        ``child`` is matched by identity, not equality: two distinct leaves may
        hold the same number and must not be confused with each other.

        Raises:
            ValueError: if ``child`` is neither ``self.left`` nor ``self.right``.
        """
        if self.left is child:
            self.left = replacement
        elif self.right is child:
            self.right = replacement
        else:
            raise ValueError(f"{child!r} is not a child of {self!r}")
        if replacement is not None:
            # Keep back-links consistent, exactly as add() does.
            replacement.parent = self

    def add(self, child):
        """Attach ``child`` in the first free slot: ``left`` if empty, else ``right``.

        If both slots are already taken, ``right`` is overwritten.
        """
        if self.left is None:
            self.left = child
        else:
            self.right = child
        child.parent = self

    def find_root(self) -> Node:
        """Follow ``parent`` links upward and return the topmost node."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def evaluate(self):
        """Recursively compute the integer value of the subtree rooted here.

        Raises:
            ValueError: if a leaf's value cannot be converted with ``int()``.
        """
        if self.value == '+':
            return self.left.evaluate() + self.right.evaluate()
        if self.value == '*':
            return self.left.evaluate() * self.right.evaluate()
        return int(self.value)

    def __repr__(self):
        """Compact prefix form, e.g. ``(+ (1 2))``; a leaf prints as its value."""
        if self.left is None:
            return str(self.value)
        return f"({self.value} ({self.left} {self.right}))"

In [3]:
def get_input(fname):
    """Read ``fname`` and return one entry per line, with surrounding whitespace removed."""
    stripped = []
    with open(fname) as handle:
        for raw_line in handle:
            stripped.append(raw_line.strip())
    return stripped

In [4]:
test_data = get_input(fname="test.txt")  # example expressions, one per line

In [5]:
# Show the example expressions loaded from test.txt.
test_data

['1 + 2 * 3 + 4 * 5 + 6',
 '1 + (2 * 3) + (4 * (5 + 6))',
 '2 * 3 + (4 * 5)',
 '5 + (8 * 3 + 9 + 3 * 4 * 3)',
 '5 * 9 * (7 * 3 * 3 + 9 * 3 + (8 + 6 * 4))',
 '((2 + 4 * 9) * (6 + 9 * 8 + 6) + 6) + 2 + 4 * 2']

In [6]:
def _apply_op(op, acc, num):
    """Combine the running total ``acc`` with ``num``: add for ``'+'``, otherwise multiply."""
    return acc + num if op == '+' else acc * num


def evaluate(expression, start=0):
    """Evaluate an expression strictly left to right (``+`` and ``*`` share precedence).

    Parenthesised groups are evaluated first through a recursive call.

    Args:
        expression: the expression text, e.g. ``"1 + (2 * 3)"``.
        start: index at which to begin reading; the recursive call passes the
            position just after an opening ``(``.

    Returns:
        A tuple ``(value, end)``. ``end`` is the index of the ``)`` that
        stopped this (sub-)expression, or ``len(expression)`` if the text ran
        out. The caller skips past that ``)`` itself.
    """
    result = 0       # with last_op '+', the first operand is simply added to 0
    last_op = '+'
    digits = []
    crt = start
    while crt < len(expression) and expression[crt] != ')':
        char = expression[crt]
        if char not in '+*(' and not char.isspace():
            digits.append(char)  # part of a (possibly multi-digit) number
        else:
            # Any separator or operator ends the pending number, so unspaced
            # input such as "1+2" tokenises the same as "1 + 2".
            if digits:
                result = _apply_op(last_op, result, int(''.join(digits)))
                digits = []
            if char in '+*':
                last_op = char
            elif char == '(':
                num, crt = evaluate(expression, crt + 1)
                result = _apply_op(last_op, result, num)
        crt += 1
    if digits:
        result = _apply_op(last_op, result, int(''.join(digits)))
    return result, crt

In [7]:
# Returns (value, index where parsing stopped).
evaluate(test_data[0], start=0)

(71, 21)

In [8]:
# Pin the example results so a regression is caught on Restart & Run All.
part1_examples = [evaluate(expr) for expr in test_data]
assert [value for value, _end in part1_examples] == [71, 51, 26, 437, 12240, 13632]
part1_examples

[(71, 21), (51, 27), (26, 15), (437, 27), (12240, 41), (13632, 47)]

In [9]:
input_data = get_input(fname="input.txt")  # one expression per line

In [10]:
# Sum of every input expression under strict left-to-right evaluation.
sum(value for value, _end in map(evaluate, input_data))

16332191652452

In [11]:
n = Node(value='12')  # a single leaf holding a digit string

In [12]:
# A leaf evaluates to int() of its stored value.
n.evaluate()

12

In [13]:
Node('+', left=Node(1), right=Node(2)).evaluate()

3

In [14]:
Node('*', left=Node(1), right=Node(2)).evaluate()

2

In [15]:
def eval_tree(expression, start=0):
    """Evaluate an expression in which ``+`` binds tighter than ``*``.

    The text is parsed into a tree of ``Node`` objects, which is then
    evaluated. Parenthesised groups are handled by a recursive call.

    Precedence falls out of how the tree is reshaped:
    * ``*`` makes the whole tree built so far its left operand, so each
      multiplication sits above everything to its left;
    * ``+`` takes over only the most recent operand (the current leaf), so it
      lands below any pending ``*``.

    Args:
        expression: the expression text, e.g. ``"2 * 3 + (4 * 5)"``.
        start: index at which to begin reading; the recursive call passes the
            position just after an opening ``(``.

    Returns:
        A tuple ``(value, end)``. ``end`` is the index of the ``)`` that
        stopped this (sub-)expression, or ``len(expression)`` if the text ran
        out. The caller skips past that ``)`` itself.
    """
    # The "0 +" seed gives the first operand a parent, so a following "+"
    # always has a node to splice itself into.
    root = Node('+', Node(0))
    current_node = root
    digits = []

    def attach(num):
        # Hang a new leaf under the current operator node and make it current.
        nonlocal current_node
        current_node.add(Node(num))
        current_node = current_node.right

    def flush():
        # Turn any buffered characters into a number leaf.
        if digits:
            attach(int(''.join(digits)))
            digits.clear()

    crt = start
    while crt < len(expression) and expression[crt] != ')':
        char = expression[crt]
        if char not in '+*(' and not char.isspace():
            digits.append(char)  # part of a (possibly multi-digit) number
        else:
            # Any separator or operator ends the pending number, so unspaced
            # input such as "1+2" tokenises the same as "1 + 2".
            flush()
            if char == '*':
                old_root = root
                root = Node('*')
                root.add(old_root)
                current_node = root
            elif char == '+':
                new_node = Node('+')
                current_node.parent.replace(current_node, new_node)
                new_node.add(current_node)
                current_node = new_node
            elif char == '(':
                num, crt = eval_tree(expression, crt + 1)
                attach(num)
        crt += 1
    flush()
    return root.evaluate(), crt

In [16]:
# Print each example and pin it, so a regression fails loudly on re-run.
expected_part2 = [231, 51, 46, 1445, 669060, 23340]
assert len(test_data) == len(expected_part2)
for expr, expected in zip(test_data, expected_part2):
    value = eval_tree(expr)[0]
    print(expr, '=', value)
    assert value == expected, f"{expr}: got {value}, expected {expected}"

1 + 2 * 3 + 4 * 5 + 6 = 231
1 + (2 * 3) + (4 * (5 + 6)) = 51
2 * 3 + (4 * 5) = 46
5 + (8 * 3 + 9 + 3 * 4 * 3) = 1445
5 * 9 * (7 * 3 * 3 + 9 * 3 + (8 + 6 * 4)) = 669060
((2 + 4 * 9) * (6 + 9 * 8 + 6) + 6) + 2 + 4 * 2 = 23340


In [17]:
# Sum of every input expression with '+' evaluated before '*'.
sum(value for value, _end in map(eval_tree, input_data))

351175492232654