In [27]:
from enum import Enum
import re
import graphviz

### Defining Opcodes & Regex

In [28]:
# Define opcodes and their associated regexes. Order determines precedence
class Opcode(Enum):
  # Groupings
  kNeg = 7
  kParens = 8
  kBrackets = 9
  # Integrals & Summation
  kInt = 1
  kSum = 2
  # Constants & Varaibles
  kInteger = 5
  kVariable = 4
  kPi = 3
  kInfty = 6
  kBinom = 10
  kFrac = 11
  # Addition & Subtraction
  kAdd = 12
  kSub = 13
  # Equality
  kEqual = 14
  # Products & Divsion
  kJuxt = 15
  kCdot = 16
  kProd = 17
  # Trigonometric Functions
  kSin = 18
  kCos = 19
  kTan = 20
  # Exponents & Roots
  kPow = 21
  kExp = 22
  kLn = 23
  # Factorials
  kFact = 24

grouping = r"({.*}|\\[a-z]+|.)"

regexes = {
  Opcode.kInteger: r"-?[0-9]+",
  Opcode.kVariable: r"[a-z]",
  Opcode.kNeg: r"-(.*)",
  Opcode.kJuxt: r"([0-9]+|[a-z])([a-z])",
  Opcode.kParens: r"\((.*)\)",
  Opcode.kBrackets: r"{(.*)}",
  Opcode.kPi: r"\\pi",
  Opcode.kInfty: r"\\infty",
  Opcode.kInt: r"\\int_" + grouping + r"\^" + grouping + r"(.*)d(.*)",
  Opcode.kProd: r"\\prod_" + grouping + r"\^" + grouping + r"(.*)",
  Opcode.kCdot: r"(.*)\\cdot(.*)",
  Opcode.kFrac: r"\\frac" + grouping + grouping,
  Opcode.kSum: r"\\sum_" + grouping + r"\^" + grouping + r"(.*)",
  Opcode.kSin: r"\\sin(.*)",
  Opcode.kCos: r"\\cos(.*)",
  Opcode.kTan: r"\\tan(.*)",
  Opcode.kExp: r"\\exp(.*)",
  Opcode.kLn: r"\\ln(.*)",
  Opcode.kPow: r"(.*)\^" + grouping,
  Opcode.kAdd: r"(.*)\+(.*)",
  Opcode.kEqual: r"(.*)=(.*)",
  Opcode.kSub: r"(.*)-(.*)",
  Opcode.kFact: r"(.*)!",
  Opcode.kBinom: r"(.*)\\choose(.*)"
}

In [29]:
examples = ["(\int_0^1 x \, dx) + 1",
            "\sum_{n = 1}^\infty n",
            "x + iy",
            "8x + 1",
            "\sin x^2",
            "\\frac{a + b}{2}",
            "a \cdot b^3",
            "n^2!",
            "\sin(2z)",
            "{n - 1 \choose k}",
            "\int_a^{-3} x^2 \, dx",
            "\sum_{n = 1}^\infty \\frac{1}{n}",
            "z^3+z^2+z+1",
            "\int_0^1 x^n \, dx + 1"
            "e^{-z^2}",
            "2 \pi i"
            ]

In [49]:
import re

class Tokens(Enum):
  kInteger = 1
  kVariable = 2
  kPi = 3
  kCdot = 4
  kChoose = 5
  kFrac = 6
  kSin = 7
  kSum = 8
  kInt  = 9
  kInfty = 10
  kUnderscore = 11
  kCarat = 12
  kOpenBrace = 13
  kCloseBrace = 14
  kOpenParen = 15
  kCloseParen = 16
  kPlus = 17
  kMinus = 18
  kEquals = 19
  kExclamation = 20

regexes = {
  Tokens.kInteger: "[0-9]+",
  Tokens.kVariable: "[a-z]",
  Tokens.kPi: r"\\pi",
  Tokens.kCdot: r"\\cdot",
  Tokens.kChoose: r"\\choose",
  Tokens.kFrac: r"\\frac",
  Tokens.kSin: r"\\sin",
  Tokens.kSum: r"\\sum",
  Tokens.kInt: r"\\int",
  Tokens.kInfty: r"\\infty",
  Tokens.kUnderscore: "_",
  Tokens.kCarat: "\^",
  Tokens.kOpenBrace: "{",
  Tokens.kCloseBrace: "}",
  Tokens.kOpenParen: "\(",
  Tokens.kCloseParen: "\)",
  Tokens.kPlus: "\+",
  Tokens.kMinus: "-",
  Tokens.kEquals: "=",
  Tokens.kExclamation: "!"
}

class Instruction:
  def __init__(self, opcode, value):
    self.opcode = opcode
    self.value = value
    self.operands = []

  def __str__(self):
    return self.opcode.name + '(' + ', '.join(map(str, self.operands)) + ')'

def to_tokens(tex):
  tex = re.sub(r"\s|\\,", "", tex)

  tokens = []
  while tex:
    for token in Tokens:
      match = re.match("(" + regexes[token] + ")", tex)
      if match:
        tokens.append(match.groups()[0])
        # tokens.append(token)
        tex = tex[match.end():]
        break
  return tokens

def to_tree(tokens):
  for token in tokens:
    if token == '!':
      print(['!', tokens[:-1]])
    # if target in tokens:
    #   print(target)
    #   break

for ex in examples:
  print(ex)
  print(to_tokens(ex))
  to_tree(to_tokens(ex))

x + iy
['x', '+', 'i', 'y']
8x + 1
['8', 'x', '+', '1']
\sin x^2
['\\sin', 'x', '^', '2']
\frac{a + b}{2}
['\\frac', '{', 'a', '+', 'b', '}', '{', '2', '}']
a \cdot b^3
['a', '\\cdot', 'b', '^', '3']
n^2!
['n', '^', '2', '!']
['!', ['n', '^', '2']]
\sin(2z)
['\\sin', '(', '2', 'z', ')']
{n - 1 \choose k}
['{', 'n', '-', '1', '\\choose', 'k', '}']
\int_a^{-3} x^2 \, dx
['\\int', '_', 'a', '^', '{', '-', '3', '}', 'x', '^', '2', 'd', 'x']
\sum_{n = 1}^\infty \frac{1}{n}
['\\sum', '_', '{', 'n', '=', '1', '}', '^', '\\infty', '\\frac', '{', '1', '}', '{', 'n', '}']
z^3+z^2+z+1
['z', '^', '3', '+', 'z', '^', '2', '+', 'z', '+', '1']


### Parsing

In [31]:
class Instruction:
  def __init__(self, opcode, value):
    self.opcode = opcode
    self.value = value
    self.operands = []

  def __str__(self):
    return self.opcode.name + '(' + ', '.join(map(str, self.operands)) + ')'

  def __eq__(self, other):
    if self.opcode != other.opcode:
      return False
    if self.value != other.value:
      return False
    for (self_operand, other_operand) in zip(self.operands, other.operands):
      # TODO Maybe handle commutativity
      if self_operand != other_operand:
        return False
    return True

In [32]:
# Construct a tree representation of LaTeX input
def parse(tex):
  # Avoid escaping characters and clear out $, whitespace, etc
  tex = re.sub(r"\$|\s|\\,|\\left|\\right", "", tex)
  return recursive_parse(tex)

def recursive_parse(tex):
  for op in Opcode:
    match = re.fullmatch(regexes[op], tex)
    if match:
      # Groupings are unnecessary in a tree structure
      if op == Opcode.kParens or op == Opcode.kBrackets:
        return recursive_parse(match.groups()[0])

      # If this is an integer constant, we need to store its value
      value = None
      if op == Opcode.kInteger:
        value = tex

      # TODO We might already have the instruction -- kill redundancy

      # Create the instruction and recursively parse operands
      node = Instruction(op, value)
      for operand in match.groups():
        node.operands.append(recursive_parse(operand))
      break
  return node

### Plotting with GraphViz

In [33]:
# Traverse the node and build up the graph
def graph(node, dot=None):
  if dot == None:
    dot = graphviz.Digraph(graph_attr={'bgcolor':'deepskyblue'}, node_attr={'shape':'box',
    'style':'rounded,filled', 'fillcolor':'hotpink', 'fontname':'Courier', 'fontcolor':'white'})

  # Add the node to the graph its id in memory as an identifier
  node_id = str(id(node))
  dot.node(node_id, node.value if node.value else node.opcode.name)

  # Recurse on children and draw edges to them
  for i, operand in enumerate(node.operands):
    graph(operand, dot)
    dot.edge(node_id, str(id(operand)), label=str(i))
  return dot

In [34]:
examples = ["x + iy",
            "8x + 1",
            "\sin x^2",
            "\\frac{a + b}{2}",
            "a \cdot b^3",
            "n^2!",
            "\sin(2z)",
            "{n - 1 \choose k}",
            "\int_a^{-3} x^2 \, dx",
            "\sum_{n = 1}^\infty \\frac{1}{n}",
            "z^3+z^2+z+1",
            # "\int_0^1 x^n \, dx + 1"
            # "e^{-z^2}",
            # "2 \pi i"
            ]
for example in examples:
  display(graph(parse(example)))

KeyError: <Opcode.kNeg: 7>

### Optimization

In [72]:
# Common sub-expession elimination
def cse(node):
  seen = []

  # Postorder traversal
  def cse_traverse(node):
    for i, operand in enumerate(node.operands):
      if operand in seen:
        # Update current node to depend on already seen version
        node.operands[i] = seen[seen.index(operand)]
      else:
        # Otherwise, mark operand as seen and traverse
        cse_traverse(operand)
        seen.append(operand)

  cse_traverse(node)

In [None]:
# Ensure z^3 + z^2 + z + 1 is processed as desired
node = parse("z^3+z^2+z+1")
display(graph(node))
cse(node)
display(graph(node))

In [None]:
# TODO Add binarization
node = parse("1+2+3+4+5+6+7+8")
display(graph(node))

# TODO Rework parser to handle this
node2 = parse("((1+2)+(3+4))+((5+6)+(7+8))")
display(graph(node2))

### Code Generation

In [9]:
# import math

# def codegen(node):
#   # TODO Do one layer of recursion
#   f = None
#   if node.opcode == Opcode.kVariable:
#     f = lambda x : x
#   elif node.opcode == Opcode.kSin:
#     f = lambda x : math.sin(x)
#   elif node.opcode == Opcode.kPow:
#     f = lambda x, y : x**y
#   elif node.opcode == Opcode.kAdd:
#     f = lambda x, y : x + y
#   elif node.opcode == Opcode.kInteger:
#     f = lambda : node.value

#   print(node.opcode)
#   return f(*map(codegen, node.operands))


TODO
* Remove kBrackets/kJuxt/kParens
* Fix broken examples