In [1]:
#| default_exp shunting
from nbdev import *
from nbdev.showdoc import *

In [2]:
#| exporti
from collections import deque
import operator

In [3]:
#| export
# https://en.wikipedia.org/wiki/Associative_property
# https://en.wikipedia.org/wiki/Shunting-yard_algorithm
# Tested on:
# https://adventofcode.com/2020/day/18
# https://leetcode.com/problems/basic-calculator/
class ShuntingYard:
    """
    Evaluate infix arithmetic expressions with the shunting-yard algorithm.

    Init with an optional precedence dictionary and an optional operator
    dictionary, then call SY.calc(line) with a string as input.
    Operators and numbers must be separated from each other by whitespace;
    parentheses and commas may be written without surrounding spaces.

    Example precedence dictionary. Higher numbers have a higher precedence:
    prec = defaultdict(lambda: int(9))
    prec.update({'*':4,
            '+':4,
            '/':4,
            ':':4,
            '-':4,
            '^':4,
            '**':4})

    self.ops is a dictionary with the two-argument functions that are called
    for certain symbols, e.g.
    self.ops = {
            '+' : operator.add,
            '-' : operator.sub,
            '*' : operator.mul,
            '/' : operator.truediv,
            ':' : operator.truediv,
            '%' : operator.mod,
            '^' : operator.xor,
            '**' : pow,
        }

    Notes:
    * A '-' at the start of the expression, after '(' or ',' or after another
      operator is a unary minus. It is represented internally by the token
      '$'; with the default table it binds tighter than '**'.
    * Operators listed in LEFT_ASSOC group left-to-right at equal precedence,
      all other operators (e.g. '**') group right-to-left.
    * A token that names a callable (e.g. 'max') is treated as a function and
      is always called with exactly two arguments: "max ( 1 , 2 )".
    * Malformed expressions raise ValueError.
    """

    # Internal token that get_postfix emits for a unary minus.
    UNARY_MINUS = '$'
    # Operators that group left-to-right when precedences are equal.
    LEFT_ASSOC = frozenset({'+', '-', '*', '/', ':', '%', '<<', '>>'})

    def __init__(self, prec=None, ops=None):
        """Store the precedence table and the operator implementations.

        prec: mapping symbol -> precedence; the default table is used when it
              is None or empty.
        ops:  mapping symbol -> two-argument callable; the default table is
              used when it is None or empty.
        """
        self.prec = {
            '$': 8,
            '**': 7,
            '/': 6,
            '*': 6,
            ':': 6,
            '%': 6,
            '+': 5,
            '-': 5,
            '<<': 4,
            '>>': 4,
            '&': 3,
            '^': 2,
            '|': 1} if not prec else prec

        self.ops = {
            '$': operator.sub,
            '+': operator.add,
            '-': operator.sub,
            '*': operator.mul,
            '/': operator.truediv,
            ':': operator.truediv,
            '%': operator.mod,
            '<<': operator.lshift,
            '>>': operator.rshift,
            '&': operator.and_,
            '^': operator.xor,
            '|': operator.or_,
            '**': pow,
        } if not ops else ops

    def _resolve_callable(self, s):
        """Return the callable named by the (dotted) identifier s, else None."""
        if not isinstance(s, str):
            return None
        # Only plain dotted names are looked up, so evaluating a token can
        # never execute a call expression such as "exit()".
        if not all(part.isidentifier() for part in s.split('.')):
            return None
        try:
            # NOTE(security): eval resolves the name against this module's
            # globals and the builtins, so any callable reachable from here
            # can be invoked by an expression. Do not feed untrusted input.
            obj = eval(s)
        except (NameError, AttributeError, SyntaxError):
            return None
        return obj if callable(obj) else None

    def is_callable_string(self, s):
        """Return True when s is the name of a callable, False otherwise."""
        return self._resolve_callable(s) is not None

    def _precedence(self, symbol):
        """Return the precedence of an operator that sits on the stack."""
        try:
            return self.prec[symbol]
        except KeyError:
            if symbol == self.UNARY_MINUS:
                # Custom tables need not list '$': let it bind tightest.
                return max(self.prec.values(), default=0) + 1
            raise

    def get_postfix(self, list_of_symbols):
        """Convert a list of infix tokens to postfix order.

        list_of_symbols: numbers (int/float) and string tokens (operators,
                         function names, '(', ')' and ',').
        Returns a deque with the tokens in postfix (reverse Polish) order.
        Raises ValueError on unknown tokens or unbalanced parentheses.
        """
        op_stack = deque()
        output_stack = deque()
        unary = self.UNARY_MINUS
        # A '-' is unary (e.g. -2 instead of 1 - 2) when no operand precedes it.
        possible_unary = True

        for symbol in list_of_symbols:
            if isinstance(symbol, (int, float)):
                output_stack.append(symbol)
                possible_unary = False
            elif self.is_callable_string(symbol):
                op_stack.append(symbol)
            elif symbol in self.prec:
                if symbol == '-' and possible_unary:
                    # Prefix operator: it has no left operand, so nothing
                    # that is already on the stack may be popped for it.
                    op_stack.append(unary)
                else:
                    # Functions and '(' on the stack stop the popping.
                    while op_stack and (op_stack[-1] in self.prec or op_stack[-1] == unary):
                        top_prec = self._precedence(op_stack[-1])
                        if top_prec > self.prec[symbol] or (
                                top_prec == self.prec[symbol] and symbol in self.LEFT_ASSOC):
                            output_stack.append(op_stack.pop())
                        else:
                            break
                    op_stack.append(symbol)
                # Whatever follows an operator starts a new operand.
                possible_unary = True
            elif symbol == '(':
                op_stack.append(symbol)
                possible_unary = True
            elif symbol == ',':
                # Argument separator: finish the current argument.
                while op_stack and op_stack[-1] != '(':
                    output_stack.append(op_stack.pop())
                if not op_stack:
                    raise ValueError("',' outside of parentheses")
                possible_unary = True
            elif symbol == ')':
                while op_stack and op_stack[-1] != '(':
                    output_stack.append(op_stack.pop())
                if not op_stack:
                    raise ValueError("unbalanced parentheses: unexpected ')'")
                op_stack.pop()  # remove the (
                # A function name directly before the '(' is applied now.
                if op_stack and self.is_callable_string(op_stack[-1]):
                    output_stack.append(op_stack.pop())
                possible_unary = False
            else:
                raise ValueError(f'unknown token: {symbol!r}')

        while op_stack:
            top = op_stack.pop()
            if top == '(':
                raise ValueError("unbalanced parentheses: missing ')'")
            output_stack.append(top)
        return output_stack

    def eval_postfix(self, output_stack):
        """Evaluate tokens in postfix order and return the resulting value.

        output_stack: iterable of numbers and operator/function tokens, as
                      produced by get_postfix.
        Raises ValueError when the token sequence is not a complete
        expression or when a token has no implementation.
        """
        res = []
        for symbol in output_stack:
            if isinstance(symbol, (int, float)):
                res.append(symbol)
                continue
            if not res:
                raise ValueError(f'missing operand for {symbol!r}')
            second = res.pop()
            if symbol == self.UNARY_MINUS:  # unary minus: 0 - second
                if symbol in self.ops:
                    res.append(self.ops[symbol](0, second))
                else:
                    res.append(-second)
                continue
            if not res:
                raise ValueError(f'missing operand for {symbol!r}')
            first = res.pop()
            if symbol in self.ops:
                func = self.ops[symbol]
            else:
                func = self._resolve_callable(symbol)
                if func is None:
                    raise ValueError(f'no implementation for {symbol!r}')
            res.append(func(first, second))
        if len(res) != 1:
            raise ValueError('expression does not reduce to a single value')
        return res[0]

    @staticmethod
    def _parse_token(token):
        """Turn a numeric token into int/float, leave other tokens as str."""
        # isdecimal (not isnumeric): int() rejects characters such as '²'.
        if token.isdecimal():
            return int(token)
        if token.count('.') == 1 and token.replace('.', '').isdecimal():
            return float(token)
        return token

    def calc(self, line):
        """Evaluate the infix expression in the string line and return its value."""
        # Parentheses and commas become separate tokens even when unspaced.
        for char in '(),':
            line = line.replace(char, f' {char} ')
        tokens = [self._parse_token(arg) for arg in line.split()]
        return self.eval_postfix(self.get_postfix(tokens))

In [4]:
# Regression check against the expressions stored in shunting.txt
# (presumably the Advent of Code 2020 day 18 puzzle input - confirm).
# The context manager closes the file handle as soon as the lines are read.
with open('shunting.txt') as f:
    lines = [line.rstrip('\n') for line in f]

# '*' and '+' share one precedence level.
prec = {'*':4, '+':4}
SY = ShuntingYard(prec)
ans = sum(SY.calc(line) for line in lines)
assert ans == 21993583522852

# '+' binds tighter than '*'.
prec = {'*':4, '+':5}
SY = ShuntingYard(prec)
ans = sum(SY.calc(line) for line in lines)
assert ans == 122438593522757


In [7]:
# Checks with the default precedence/operator tables: unary minus,
# operator precedence, parentheses and irregular whitespace.
SY = ShuntingYard()
expected_results = (
    ("- 1 - 4", -5),
    ("1 - ( - 2 )", 3),
    ("- 1 - ( - 2 )", 1),
    ("- 1 * 3 + 1", -2),
    ("5 - 0 * 7", 5),
    ("(5 - 0) * 7", 35),
    (" (            5        -    0     )   *   7", 35),
)
for expression, expected in expected_results:
    assert SY.calc(expression) == expected