# MathQA Preprocessing

## Usage

##### Requirements
- Ensure you are using python version 3.10 or above (necessary for match/case statements)
- Make sure all of the packages imported below are installed (you probably won't have anytree/pyparsing yet)

##### How to use
- Running all of the cells below will generate three csv files, (train.csv, validation.csv, test.csv), into a folder called MathQA in the parent directory. (Change the DIRECTORY constant as needed) (Will probably take some time to run)
  
##### Columns
- The CSV file contains the following columns:
- problem
- solution (the numeric solution)
- formula (The preprocessed formula that only includes basic arithmetic operators: *,/,+,-,^.)
- formula_no_const (same as the above, but constants are converted to numeric values. If you would like to evaluate it with Python's eval function, be sure to change ^ to **)
- annotated_formula (the original formula provided by the MathQA dataset in case needed for training experiments)
- incremental (semicolon-delimited labels that split each formula into a stepwise approach, explained later)
- tree (A dictionary representation of the anytree expression tree object representing the formula, can be imported using the anytree DictImporter)
- Note: the incremental label type will have some NaN values (due to exceeding the hyperparameter K mentioned below); these must be removed if you intend to train on this label.

#### Imports

In [1]:
from sklearn.metrics import f1_score, accuracy_score
from datasets import Dataset, load_dataset, DatasetDict
import numpy as np
import pandas as pd
import os

import re
import ast
import pyparsing
import pyparsing as pp
from anytree import Node, RenderTree, PreOrderIter, PostOrderIter, LevelOrderIter, LevelOrderGroupIter
from anytree.exporter import DictExporter
from enum import Enum
from copy import deepcopy
import math
from math import sqrt
import warnings
warnings.filterwarnings("error")

#### Constants and Enums

In [2]:
# K: the maximum number of items allowed per layer when creating step-wise layers
# (explained in greater detail later in the notebook).
K = 6
# DIRECTORY: folder the generated CSV files are written to; change as needed.
DIRECTORY = '../MathQA/'

class Op(Enum):
    """Symbols of the five basic arithmetic operators used as node names in preprocessed trees."""
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'
    
class Const(Enum):
    """Names of the constants that may appear as leaves in MathQA formulas.

    NOTE: the declaration order is significant. The `values` list defined right after
    this class is paired with these members positionally to build `const2val`, so any
    member added, removed or moved here needs the matching change in `values`.
    """
    # Special constants
    CONST_PI = 'const_pi'
    CONST_NEG_1 = 'const_neg_1' # Added by the notebook author (used when rewriting 'negate' as -1 * x)
    CONST_DEG_TO_RAD = 'const_deg_to_rad' # pi / 180 (There is only one example of this and its actually used incorrectly)
    
    # Integer constants
    CONST_1 = 'const_1'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600' 
    
    # Decimal constants spelled with an underscore in place of the decimal point
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_0_4535 = 'const_0_4535'
    CONST_0_6 = 'const_0_6'
    CONST_1_6 = 'const_1_6'
    CONST_2_2046 = 'const_2_2046'
    CONST_2_54 = 'const_2_54'
    CONST_3_6 = 'const_3_6'
    
    # Constants spelled with a literal dot; `convert_const` maps these onto other constant names
    CONST_0dot25 = 'const_0.25'
    CONST_0dot5 = 'const_0.5' 
    CONST_2dot0 = 'const_2.0'
    CONST_3dot0 = 'const_3.0'
    CONST_4dot0 = 'const_4.0'
    CONST_60dot0 = 'const_60.0'
    CONST_100dot0 = 'const_100.0'

# Numeric value of each Const member, listed in the members' declaration order
# (exactly one entry per member).
values = [math.pi, -1, math.pi/180, 1, 2, 3, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600,
          0.25, 0.2778, 1/3, 0.3937, 0.4535, 0.6, 1.6, 2.2046, 2.54, 3.6, 0.25, 0.5, 2.0, 3.0, 4.0, 60.0, 100.0]
# strict=True (Python 3.10+) raises ValueError if `values` and Const ever get out of sync,
# instead of silently dropping the unmatched tail the way a plain zip() would.
const2val = {member.value: number for member, number in zip(Const, values, strict=True)}
# 'const_0_5' has no Const member of its own; it is the canonical spelling that
# 'const_0.5' is converted to below, so it needs a value here as well.
const2val['const_0_5'] = 0.5

# Some of the constants are duplicates of other constants, so they are consolidated here:
# every constant name maps to itself unless it is overridden with its canonical spelling.
convert_const = {member.value: member.value for member in Const}
convert_const[Const.CONST_0dot25.value] = Const.CONST_0_25.value
convert_const[Const.CONST_0dot5.value] = 'const_0_5'
convert_const[Const.CONST_2dot0.value] = Const.CONST_2.value
convert_const[Const.CONST_3dot0.value] = Const.CONST_3.value
convert_const[Const.CONST_4dot0.value] = Const.CONST_4.value
convert_const[Const.CONST_60dot0.value] = Const.CONST_60.value
convert_const[Const.CONST_100dot0.value] = Const.CONST_100.value

## Loading the Data

In [3]:
# Load the 'math_qa' dataset through `datasets.load_dataset`.
# NOTE(review): presumably this downloads/caches the data on first use -- confirm against the datasets docs.
data = load_dataset('math_qa')

Allows data to be obtained for a single problem type category

In [4]:
def get_category(data, category):
    """Return a DatasetDict holding only the rows whose 'category' equals `category`.

    For every subset listed in `data.column_names`, a boolean mask is computed from the
    'category' column and applied to each column; the filtered columns are wrapped in a
    new Dataset whose format is set to 'torch'.
    """
    filtered = {}
    for split_name, column_names in data.column_names.items():
        split = data[split_name]
        keep = np.array(split['category']) == category

        selected = {}
        for col in column_names:
            selected[col] = np.array(split[col])[keep]

        subset = Dataset.from_dict(selected)
        subset.set_format('torch')
        filtered[split_name] = subset
    return DatasetDict(filtered)
        
#general_data = get_category(data, 'general')

The following functions convert mathQA annotated formulas into expression trees.

In [5]:
# Gets the children for each parent recursively
# Recursively attaches the parsed tokens in `curr` to the tree below `parent`
def process_children(curr, parent):
    """Create a Node under `parent` for every name token in `curr`.

    A string token (other than a lone comma) becomes a new Node, with any commas stripped
    from its name. A nested ParseResults holds the arguments of the most recently created
    Node and is processed recursively with that Node as the parent. Any other token type
    raises an Exception.
    """
    last_node = None
    for token in curr:
        kind = type(token)

        # A nested group: these are the children of the node created just before it
        if kind is pyparsing.results.ParseResults:
            process_children(token, last_node)
            continue

        # sanity error check
        if kind is not str:
            raise Exception(f'Unexpected type: {type(token)} for item: {token}')

        # A lone comma is only a separator; everything else names a node
        if token != ',':
            last_node = Node(token.replace(',', ''), parent)

# Converts a MathQA annotated formula into a tree structure
# Converts a MathQA annotated formula into a tree structure
def create_tree(formula):
    """Parse a MathQA annotated formula string and return the root Node of its tree.

    The formula is expected to look like ``name(arg, arg, ...)`` where arguments may
    themselves be nested calls. The outermost operation name becomes the root and the
    parenthesised argument groups are attached below it by `process_children`.
    """
    # Use the PEP 8 spelling `nested_expr` to match `parse_string` below. The legacy
    # alias `nestedExpr` is a pre-PEP 8 synonym; since this notebook turns every warning
    # into an error, a deprecation warning on the alias would abort parsing.
    parser = pp.Word(pp.alphas + '_') + pp.nested_expr('(', ')')
    parsed = parser.parse_string(formula)
    root = Node(parsed[0]) # the first element is the name of the outermost operation

    # recursively adding children
    for group in parsed[1:]:
        process_children(group, root)

    return root
 
# Outputs a tree given its root in human readable format
# Prints a tree, given its root, in a human readable format
def print_tree(tree):
    """Print one line per node: the RenderTree branch prefix followed by the node's name."""
    for branch, _fill, item in RenderTree(tree):
        print(branch, item.name, sep='')

The following function preprocesses the created formula trees in the following ways:
- All operations are rewritten as simple addition, subtraction, division, multiplication, and exponentiation problems (for example, circle_area(r) can be written as just mult(pi, power(r,2))).
- There are a few operations that cannot be rewritten as simple addition, subtraction, division, multiplication, and exponentiation problems, so trees with these are simply removed. This is the standard for any work done with this dataset among the other state of the art papers. (CITE SOME OF THEM) To mark a tree as needing to be removed, None is returned. (CONSIDER TRYING THE OTHER OPERATORS ONCE A BASELINE PERFORMANCE IS ESTABLISHED)
- The addition, subtraction, division, multiplication, exponentiation operators are rewritten to +, -, /, *, ^.
- There is actually some noise in the data where some of the formulas do not actually give the correct answer. For these few cases, the trees are removed

In [6]:
# Message carried by the exception that match_node raises for operations it does not rewrite
UNIMPLEMENTED = 'UNIMPLEMENTED'

def preprocess_tree(root):
    """Return a rewritten deep copy of the tree at `root`, or None if it must be dropped.

    The input tree is left untouched; the copy is passed to `match_node`. If that raises
    an exception whose message equals UNIMPLEMENTED, the tree uses an unsupported
    operation and None is returned. Any other exception is re-raised unchanged.
    """
    try:
        return match_node(deepcopy(root))
    except Exception as err:
        #################################
        # TODO:
        #################################
        # Figure out a way to keep track of the indices for which elements are being removed
        if str(err) == UNIMPLEMENTED:
            # Unsupported operation: do not include in final tree list
            return None
        raise
        
def match_node(node):
    """Rewrite the MathQA operation at `node` (and, recursively, its subtree) using only +, -, *, /, ^.

    Dispatches on `node.name`:
    - basic binary operations are renamed to their Op symbol and their children rewritten;
    - compound operations (areas, volumes, percentages, ...) are expanded into a small
      subtree of basic operations and Const leaves;
    - operations that are not rewritten raise Exception(UNIMPLEMENTED);
    - any other name is treated as a leaf: a known constant is renamed through
      `convert_const`, a name accepted by float() is kept as is, anything else raises
      Exception('Not found: ...').

    Returns the node that now represents this subtree. Depending on the case this is
    either the same `node` object (renamed in place) or a brand-new Node created with
    `node.parent` as its parent, so callers must use the return value rather than the
    node they passed in.

    NOTE: the statement order inside each case matters -- Nodes constructed with a parent
    argument and assignments to `.children` both change the tree as they execute.
    """
    match node.name:   
        case 'add':
            node.name = Op.ADD.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'subtract':
            node.name = Op.SUB.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'multiply':
            node.name = Op.MULT.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'divide':
            node.name = Op.DIV.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'log':
            raise Exception(UNIMPLEMENTED)
        case 'sqrt': # x^(1/2)
            # 1/2
            half = Node(Op.DIV.value)
            Node(Const.CONST_1.value, half)
            Node(Const.CONST_2.value, half)
            
            # x^(1/2)
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), half)
        case 'factorial':
            raise Exception(UNIMPLEMENTED)
        case 'gcd':
            raise Exception(UNIMPLEMENTED)
        case 'lcm':
            raise Exception(UNIMPLEMENTED)
        case 'power':
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'max':
            raise Exception(UNIMPLEMENTED)
        case 'min':
            raise Exception(UNIMPLEMENTED)
        case 'reminder':
            raise Exception(UNIMPLEMENTED)
        case 'negate': # -1 * x
            x = node.children[0]
            
            # *
            # `node` is rebound to a new Node created under the old node's parent;
            # the old node object itself is not modified by this case.
            node = Node(Op.MULT.value, node.parent)
            
            # -1, x
            node.children = (Node(Const.CONST_NEG_1.value, parent = node), match_node(x))
        case 'inverse': # 1 / x
            x = node.children[0]
            
            # /
            node = Node(Op.DIV.value, node.parent)
            
            # 1, x
            node.children = (Node(Const.CONST_1.value, parent = node), match_node(x))
        case 'floor':
            raise Exception(UNIMPLEMENTED)
        case 'sine':
            raise Exception(UNIMPLEMENTED)
        case 'cosine':
            raise Exception(UNIMPLEMENTED)
        case 'tangent':
            raise Exception(UNIMPLEMENTED)
        case 'circle_area': # pi*r^2
            r = node.children[0]
            
            # *
            node = Node(Op.MULT.value, node.parent) # setting to new sub tree
            
            # pi
            l1_1 = Node(Const.CONST_PI.value, node)
            
            # r^2
            l1_2 = Node(Op.POW.value, node)            
            l1_2.children = (match_node(r), Node(Const.CONST_2.value, l1_2))
        case 'circumface': # circumference 2*pi*r ('circumface' is the operation name this case matches)
            r = node.children[0]
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # 2
            l1_1 = Node(Const.CONST_2.value, node)
            
            # pi * r
            l1_2 = Node(Op.MULT.value, node)
            l1_2.children = (Node(Const.CONST_PI.value, l1_2), match_node(r))
        case 'rectangle_perimeter': # 2*x1 + 2*x2
            x1 = node.children[0]
            x2 = node.children[1]
            
            # +
            node = Node(Op.ADD.value, node.parent)
            
            # 2*x1
            l1_1 = Node(Op.MULT.value, node)
            l1_1.children = (Node(Const.CONST_2.value, l1_1), match_node(x1))
            
            # 2*x2
            l1_2 = Node(Op.MULT.value, node)
            l1_2.children = (Node(Const.CONST_2.value, l1_2), match_node(x2))
        case 'rectangle_area': # x1 * x2
            node.name = Op.MULT.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'square_perimeter': # 4 * x
            node.name = Op.MULT.value
            node.children = (Node(Const.CONST_4.value, node), match_node(node.children[0]))
        case 'square_area': # x^2
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), Node(Const.CONST_2.value, node))
        case 'rhombus_perimeter': # would be 4 * x, but this operation is currently treated as unimplemented
            raise Exception(UNIMPLEMENTED)
        case 'rhombus_area': # pq/2
            p = node.children[0]
            q = node.children[1]
            
            # /
            node = Node(Op.DIV.value, node.parent)
            
            # p * q
            l1_1 = Node(Op.MULT.value, node)
            l1_1.children = (match_node(p), match_node(q))
            
            # 2
            l1_2 = Node(Const.CONST_2.value, node)
        case 'quadrilateral_area': # This is just trapezoid area according to training examples: (b1+b2)/2*h
            h = node.children[0]
            b1 = node.children[1]
            b2 = node.children[2]
            
            # * 
            node = Node(Op.MULT.value, node.parent)
            
            # b1+b2
            l1_1 = Node(Op.ADD.value, node)
            l1_1.children = (match_node(b1), match_node(b2))
            
            # h/2
            l2_1 = Node(Op.DIV.value, node)
            l2_1.children = (match_node(h), Node(Const.CONST_2.value, l2_1))
        case 'volume_cone': # pi*r^2*h/3
            r = node.children[0]
            h = node.children[1]
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # r^2
            r2 = Node(Op.POW.value)
            r2.children = (match_node(r), Node(Const.CONST_2.value))
            
            # pi*r^2
            l1_1 = Node(Op.MULT.value, node)
            l1_1.children = (Node(Const.CONST_PI.value), r2)
            
            # h/3
            h3 = Node(Op.DIV.value, node)
            h3.children = (match_node(h), Node(Const.CONST_3.value))
        case 'volume_rectangular_prism': # l*w*h
            l = node.children[0]
            w = node.children[1]
            h = node.children[2]
            
            # l*w
            lw = Node(Op.MULT.value)
            lw.children = (match_node(l), match_node(w))
            
            # (l*w) * h
            node.name = Op.MULT.value
            node.children = (lw, match_node(h))
        case 'volume_cube': # x^3
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), Node(Const.CONST_3.value))
        case 'volume_sphere': # 4/3 * pi * r^3
            r = node.children[0]
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # r^3
            l1_1 = Node(Op.POW.value, node)
            l1_1.children = (match_node(r), Node(Const.CONST_3.value))
            
            # pi * (4/3)
            l1_2 = Node(Op.MULT.value, node)
            Node(Const.CONST_PI.value, l1_2)
            
            l2_1 = Node(Op.DIV.value, l1_2)
            Node(Const.CONST_4.value, l2_1)
            Node(Const.CONST_3.value, l2_1)   
        case 'volume_cylinder': # pi*h*r^2
            r = node.children[0]
            h = node.children[1]
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # r^2
            l1_1 = Node(Op.POW.value, node)
            l1_1.children = (match_node(r), Node(Const.CONST_2.value))
            
            # h*pi
            l1_2 = Node(Op.MULT.value, node)
            l1_2.children = (match_node(h), Node(Const.CONST_PI.value))         
        case 'surface_cylinder': # 2*pi*r*(r + h)
            r = node.children[0]
            h = node.children[1]
            
            r = match_node(r)
            h = match_node(h)
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # r + h
            # r appears twice in the formula, so one occurrence uses a deep copy of its subtree
            l1_1 = Node(Op.ADD.value, node)
            l1_1.children = (deepcopy(r), h)
            
            # pi * r
            pi_r = Node(Op.MULT.value)
            pi_r.children = (Node(Const.CONST_PI.value), r)
            
            # 2 * (pi * r)
            l1_2 = Node(Op.MULT.value, node)
            l1_2.children = (Node(Const.CONST_2.value), pi_r)
        case 'surface_cube': # 6*x^2
            x = node.children[0]
            
            # * 
            node = Node(Op.MULT.value, node.parent)
            
            # 6
            Node(Const.CONST_6.value, node)
            
            # x^2
            l1_1 = Node(Op.POW.value, node)
            l1_1.children = (match_node(x), Node(Const.CONST_2.value)) 
        case 'surface_rectangular_prism': # 2*(wl + hl + hw)
            l = node.children[0]
            w = node.children[1]
            h = node.children[2]
            
            # Each dimension is used twice in the formula: rewrite once, then deep-copy for the second use
            l1 = match_node(l)
            w1 = match_node(w)
            h1 = match_node(h)
            l2 = deepcopy(l1)
            w2 = deepcopy(w1)
            h2 = deepcopy(h1)
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # 2
            Node(Const.CONST_2.value, node)
            
            # w*l
            wl = Node(Op.MULT.value)
            wl.children = (w1, l1)
            
            # h*l
            hl = Node(Op.MULT.value)
            hl.children = (h1, l2)
            
            # h*w
            hw = Node(Op.MULT.value)
            hw.children = (h2, w2)
            
            # hl + hw
            temp = Node(Op.ADD.value)
            temp.children = (hl, hw)
            
            # wl + (hl + hw)
            l1_1 = Node(Op.ADD.value, node)
            l1_1.children = (wl, temp)
        case 'surface_sphere': # 4*pi*r^2
            r = node.children[0]
            
            # *
            node = Node(Op.MULT.value, node.parent)
            
            # 4 * pi
            l1_1 = Node(Op.MULT.value, node)
            l1_1.children = (Node(Const.CONST_4.value), Node(Const.CONST_PI.value))
            
            # r^2
            l1_2 = Node(Op.POW.value, node)
            l1_2.children = (match_node(r), Node(Const.CONST_2.value))
        case 'cube_edge_by_volume': # x^(1/3)
            # 1/3
            third = Node(Op.DIV.value)
            Node(Const.CONST_1.value, third)
            Node(Const.CONST_3.value, third)
            
            # x^(1/3)
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), third)
        case 'diagonal': # pythagorean: (a^2 + b^2)^(1/2)     
            a = node.children[0]
            b = node.children[1]
            
            # 1/2
            half = Node(Op.DIV.value)
            Node(Const.CONST_1.value, half)
            Node(Const.CONST_2.value, half)
            
            # a^2
            a2 = Node(Op.POW.value)
            a2.children = (match_node(a), Node(Const.CONST_2.value))
            
            # b^2
            b2 = Node(Op.POW.value)
            b2.children = (match_node(b), Node(Const.CONST_2.value))
            
            # (a^2 + b^2)
            a2_b2 = Node(Op.ADD.value)
            a2_b2.children = (a2, b2)
            
            # (a^2 + b^2)^(1/2)
            node.name = Op.POW.value
            node.children = (a2_b2, half)
        case 'square_edge_by_perimeter': # x/4
            node.name = Op.DIV.value
            node.children = (match_node(node.children[0]), Node(Const.CONST_4.value))
        case 'square_edge_by_area': # x^(1/2)
            # 1/2
            half = Node(Op.DIV.value)
            Node(Const.CONST_1.value, half)
            Node(Const.CONST_2.value, half)
            
            # x^(1/2)
            node.name = Op.POW.value
            node.children = (match_node(node.children[0]), half)
        case 'triangle_perimeter': # a + b + c
            a = node.children[0]
            b = node.children[1]
            c = node.children[2]
            
            # b + c
            b_c = Node(Op.ADD.value)
            b_c.children = (match_node(b), match_node(c))
            
            # a + (b + c)
            node.name = Op.ADD.value
            node.children = (match_node(a), b_c)
        case 'triangle_area': # b * h/2
            b = node.children[0]
            h = node.children[1]
            
            # h/2
            h2 = Node(Op.DIV.value)
            h2.children = (match_node(h), Node(Const.CONST_2.value))
            
            # b * (h/2)
            node.name = Op.MULT.value
            node.children = (match_node(b), h2)
        case 'triangle_area_three_edges': 
            raise Exception(UNIMPLEMENTED)
        case 'negate_prob': # TODO: This operation makes no sense in the training, might be worth just writing my own formulas
            raise Exception(UNIMPLEMENTED)
        case 'permutation':
            raise Exception(UNIMPLEMENTED)
        case 'p_after_gain': # (x1/100)*x2
            x1 = node.children[0]
            x2 = node.children[1]
            
            # x1/100
            x_100 = Node(Op.DIV.value)
            x_100.children = (match_node(x1), Node(Const.CONST_100.value))
            
            # (x1/100) * x2
            node.name = Op.MULT.value
            node.children = (x_100, match_node(x2))
        case 'original_price_before_gain': # x/(1+p/100)
            p = node.children[0]
            x = node.children[1]
            
            # p/100
            p_100 = Node(Op.DIV.value)
            p_100.children = (match_node(p), Node(Const.CONST_100.value))
            
            # (1 + p/100)
            temp = Node(Op.ADD.value)
            temp.children = (Node(Const.CONST_1.value), p_100)
            
            # x/(1+p/100)
            node.name = Op.DIV.value
            node.children = (match_node(x), temp)
        case 'original_price_before_loss': # x/(1-p/100)
            p = node.children[0]
            x = node.children[1]
            
            # p/100
            p_100 = Node(Op.DIV.value)
            p_100.children = (match_node(p), Node(Const.CONST_100.value))
            
            # (1 - p/100)
            temp = Node(Op.SUB.value)
            temp.children = (Node(Const.CONST_1.value), p_100)
            
            # x/(1-p/100)
            node.name = Op.DIV.value
            node.children = (match_node(x), temp)
        case 'speed': # d/t
            node.name = Op.DIV.value
            node.children = (match_node(node.children[0]), match_node(node.children[1]))
        case 'speed_in_still_water':
            raise Exception(UNIMPLEMENTED)
        case 'stream_speed': # (x1+x2)/2
            x1 = node.children[0]
            x2 = node.children[1]
            
            # x1+x2
            temp = Node(Op.ADD.value)
            temp.children = (match_node(x1), match_node(x2))
            
            # (x1+x2)/2
            node.name = Op.DIV.value
            node.children = (temp, Node(Const.CONST_2.value))
        case 'choose':
            raise Exception(UNIMPLEMENTED)
        case _: 
            # Not a known operation: the name must be either a Const value or a numeric literal
            if not node.name in Const._value2member_map_:   
                try:
                    # Validation only -- the result is discarded and the name is left unchanged
                    float(node.name)
                except ValueError:
                    raise Exception(f'Not found: {node.name}')
            else:
                # Collapse duplicate spellings (e.g. 'const_2.0') onto their canonical constant name
                node.name = convert_const[node.name]
    return node

The following function evaluates a given expression tree given its root. The tree is expected to already have been preprocessed using preprocess_tree

In [7]:
def eval_tree_helper(node):
    """Recursively compute the numeric value of the preprocessed expression tree at `node`.

    Operator nodes combine the values of their first two children; a name found in
    `const2val` evaluates to its stored value; any other name is converted with float().
    Returns None when `node` is None. Raises Exception('Not found: ...') when a leaf name
    is neither a known constant nor parseable by float(). Arithmetic errors such as
    ZeroDivisionError are not handled here.
    """
    if node is None:
        return None

    symbol = node.name
    if symbol in (Op.ADD.value, Op.SUB.value, Op.MULT.value, Op.DIV.value, Op.POW.value):
        # Both operands are evaluated (left first) before the operator is applied
        lhs = eval_tree_helper(node.children[0])
        rhs = eval_tree_helper(node.children[1])
        if symbol == Op.ADD.value:
            return lhs + rhs
        if symbol == Op.SUB.value:
            return lhs - rhs
        if symbol == Op.MULT.value:
            return lhs * rhs
        if symbol == Op.DIV.value:
            return lhs / rhs
        return lhs ** rhs

    if symbol in const2val:
        return const2val[symbol]

    # Anything else has to be a numeric literal
    try:
        return float(symbol)
    except ValueError:
        raise Exception(f'Not found: {node.name}')
    
def eval_tree(node):
    """Evaluate the tree at `node`, returning None if the arithmetic cannot be carried out.

    Delegates to `eval_tree_helper`; a ZeroDivisionError (division by zero) or an
    OverflowError raised during evaluation yields None. Other exceptions propagate.
    """
    try:
        result = eval_tree_helper(node)
    except (ZeroDivisionError, OverflowError):
        return None
    return result

This is very similar to the previous function, except this time we are returning the unevaluated formula

In [8]:
def eval_tree_formula(node, use_const=True):
    """Render the preprocessed expression tree at `node` as a fully parenthesised infix string.

    Operator nodes become '(left <op> right)'. A name found in `const2val` is emitted as
    the constant's name when `use_const` is true, otherwise as its numeric value. Any
    other name is converted with float() and emitted as that number. Returns None when
    `node` is None; raises Exception('Not found: ...') for an unrecognised leaf name.
    """
    if node is None:
        return None

    # Text printed for each operator node name
    infix = {
        Op.ADD.value: '+',
        Op.SUB.value: '-',
        Op.MULT.value: '*',
        Op.DIV.value: '/',
        Op.POW.value: '^',
    }

    label = node.name
    if label in infix:
        left = eval_tree_formula(node.children[0], use_const)
        right = eval_tree_formula(node.children[1], use_const)
        return f'({left} {infix[label]} {right})'

    if label in const2val:
        return label if use_const else f'{const2val[label]}'

    try:
        number = float(label)
    except ValueError:
        raise Exception(f'Not found: {node.name}')
    return f'{number}'

Here answers are read from the MathQA dataset so they can be verified against the expected formula result. The MathQA dataset contains a massive amount of noise, so I use a large set of regexes to extract the numeric answer for a given problem. Solutions that contain variables, multiple answers, "none of the above", etc. are returned as None, indicating that the answer is not suitable for our task of generating mathematical formulas. Code is also included that evaluates formulas within the answers to a single float value for easier comparison.

In [9]:
# Unit / label fragments to strip from answers. Each entry is inserted into a regular
# expression WITHOUT escaping, so entries may contain regex syntax on purpose
# (e.g. 'm (east|west|north|south)') and literal metacharacters are pre-escaped (e.g. r'\$').
# Despite the name, the patterns built from these entries are anchored at the END of the string.
# NOTE(review): 'd' is listed twice (harmless, but one occurrence looks redundant).
prefix = [
    'year', 'month', 'day', 'hour', 'hr', 'h', 'minute', 'min', 'per kg', 'second', 'sec', 's', 'km', 'cubic meter', 'metre sq', 
    'meter', 'mtr', 'mps', 'gal', 'm \\^ 2', 'percent increase', 'percent decrease', 'less', 'm ³', 'answer', 'round',
    'm / s', 'm', 'a 2 cm', 'cm', 'tile', 'ohm', 'cm ²', 'm sqaure', 'rupee', 'monkey', 'cm m ³', 'cm square', 'pure',
    'kg', 'mph', 'o kmph', 'kmph', 'mile', 'coin', 'liter', 'square inches', 'inch', 'inches', 'foot', 'feet', 'litre', 'a',
    'degrees c', 'degree', 'bc', 'am', r'a \. m', r"p \. m", 'rs', 'metre', 'km / hr', 'kmh', 'm (east|west|north|south)', 'ft', 'are', 
    'loss', 'gain', 'yard', 'lb', 'pm', 'increase', 'decrease', 'prime number', 'number', 'toy', 'd', 'unit', 'apple', 'deg', 
    'time', r'\$', 'cubic \\. cm', 'hectare', 'of a', '× x', 'b', 
    'more', 'men', 'o', 'mts', 'cc', 'lt', 'yr', 'profit', 'x', 'c', r'cm \^ 2', 'z', 'dm', 'nd', 'percent', 'cm cube', 'mark',
    'noon', 'no change', 'º', 'cubic', 'acre', r'cm ³', 'gallon', 'soldier', 'm cube', 't', 'minw', 'women', 'of', 'ab',
    'v â € ™', 'cube', 'r', 'l', 'm â ²', 'y ² / z ²', 'meters (east|west|north|south)', 'b \\^ 2', 'excess', 'gm', 'm square',
    'st', 'd', 'p', 'rd',
]

# One pattern per fragment: whitespace, an optional qualifier ('%', 'cu', 'th', 'st', 'rd',
# 'sq', 'square'), an optional '. ', the fragment itself, then any run of trailing
# characters from ['.23s]] up to the end of the string.
remove = [r'\s(%\s|cu\s|th\s|st\s|rd\s|sq\s|square\s)?(\.\s)?' + p + r"(\s?['.23s\]]?)*$" for p in prefix]

# Trailing '%' or '+' followed by optional dots/whitespace; the empty-string entry yields a
# pattern that matches only trailing whitespace/dots.
remove.extend(r'\s?' + x + r"(\s?\.?)*$" for x in ['%', r'\+', ''])

# Remaining patterns: leading labels anchored at the START of the string ("x = ", currency
# markers, option letters such as "a ) " / "( a ) ", filler words) plus a few trailing
# markers ('°', 'sq', 'th').
remove.extend([
    r'^[A-Za-z] = (\$ )?',
    '^no = ', '^green balls = ',
    r"°(\s?['.\]fc]?)*$",
    r"sq(\s?['.23s\]]?)*$",
    r"th(\s?['.23s\]]?)*$",
    r'^rs \. ', '^rs ', r'^s \. ', '^s : ', '^s ',
    r'^\$ \$ ', r'^\$ ',
    r'^\+ ',
    r'^increases by ', r'^decreases by ', r'^after ', r'^becomes ', r'^divisible by ', r'^rupees ', r'^ratio ', r'^and ', r'^euro ',
    r'^a \)( \$)? ', r'^b \)( \$)? ', r'^c \)( \$)? ', r'^d \)( \$)? ', r'^e \)( \$)? ',
    r'^a : ', r'^b : ', r'^c : ', r'^d : ', r'^e : ',
    r'^a \. ', r'^b \. ', r'^c \. ', r'^d \. ', r'^e \. ',
    r'^\( a \) ', r'^\( b \) ', r'^\( c \) ', r'^\( d \) ', r'^\( e \) ',
    r'^a ', r'^b ', r'^c ', r'^d ', r'^e ',
])

# Maps a regular-expression pattern (key) to the replacement text (value) used to
# normalise special notations in answers: unusual fraction/multiplication signs, pi,
# 'million'/'billion', dash variants of a leading minus, spelled-out numbers,
# "no valid answer" phrases (mapped to the string 'None') and the '½' character.
# NOTE(review): presumably applied with re.sub by code further down the notebook -- the
# consumer is not visible here, so confirm the order in which the entries are applied.
more_special = {
    ' ⁄ ': ' / ',
    ' · ': ' ',
    'π|pi|∏': f'* {math.pi}',
    ' billion$': f' * {1e9}',
    ' million$': f' * {1e6}',
    '^− ': '-',
    '^– ': '-',
    '^- ': '-',
    '^one$': '1',
    '^two$': '2',
    '^three$': '3',
    # NOTE(review): unlike every other number word, this pattern has no trailing '$', so it
    # also matches strings that merely start with 'four' (e.g. 'fourteen'). Looks like a
    # typo for '^four$' -- confirm before changing, since it alters which answers match.
    '^four': '4',
    '^five$': '5',
    '^six$': '6',
    '^seven$': '7',
    '^eight$': '8',
    '^nine$': '9',
    '^ten$': '10',
    '^eleven$': '11',
    '^twelve$': '12',
    r'^none.*|^ca\sn\s\'\st\s.*|^all( of)? the.*|^iii\sonly|^(it )?can\snot.*|^no\sanswer|^undefined|^unidentified|^not\spossible|^insufficient data|^undentified|^not enough info.*|^data\sinadequate|^indeterminate': 'None',
    r'^\. ': r'.',
    r' ½$': ' + .5',
    r'^½$': '.5',
}

# Every pattern in `remove` joined into one alternation and compiled once, so the whole
# set can be applied in a single regex operation.
to_remove = re.compile('|'.join(remove))

def _num(m, name):
    """Return the named capture group *name* of match *m* as a float."""
    return float(m.group(name))


def _strip_thousands(text):
    """Remove thousands separators (commas and any single space around them)."""
    return re.sub(r'\s?,\s?', '', text)


def _evaluate_special_case(problem, answer):
    """Match *answer* against the known textual answer shapes and evaluate it.

    Patterns are tried in a fixed order and the first match wins. Returns a
    number for an evaluated shape, a string for a comma-separated big number
    or an unrecognised answer, and None for answers that are deliberately
    rejected. Arithmetic errors are NOT handled here; the caller catches them.
    """
    num = r'[+-]?((\d+(\.\d*)?)|(\.\d+))'
    big = r'-?\d{1,3}(\s?,\s?\d{3})+(\.\d*)?'
    # Escaped so the '.' in the constant is matched literally inside a regex.
    pi = re.escape(str(math.pi))

    if m := re.match(f'^(?P<num1>{num}) (/|\\\\) (?P<num2>{num})$', answer):  # x1 / x2
        return _num(m, 'num1') / _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) / (?P<num2>{big})$', answer):  # x1 / x2 (comma separated)
        return _num(m, 'num1') / float(_strip_thousands(m.group('num2')))

    if m := re.match(f'^(?P<num1>{num}) (:|ratio) (?P<num2>\\d+)$', answer):  # x1 : x2
        first, second = m.group('num1'), m.group('num2')
        looks_like_ratio = 'ratio' in problem['Problem'] or (len(first) == 1 and len(second) == 1)
        looks_like_clock = 'time' in problem['Problem'] or second == '00'
        if looks_like_clock and not looks_like_ratio:
            # hh : mm clock reading, expressed in hours
            return float(first) + float(second) / 60
        return float(first) / float(second)

    if m := re.match(f'^(?P<num1>{num}) \\* (?P<num2>{num})$', answer):  # x1 * x2
        return _num(m, 'num1') * _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) \\+ (?P<num2>{num})$', answer):  # x1 + x2
        return _num(m, 'num1') + _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) - (?P<num2>{num})$', answer):  # x1 - x2
        return _num(m, 'num1') - _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) \\^ (?P<num2>{num})$', answer):  # x1 ^ x2
        return _num(m, 'num1') ** _num(m, 'num2')

    if m := re.match(f'^(?P<num3>{num}) \\( (?P<num1>{num}) \\^ (?P<num2>{num} )\\)$', answer):  # x3 ( x1 ^ x2 )
        return _num(m, 'num3') * (_num(m, 'num1') ** _num(m, 'num2'))

    if m := re.match(f'^\\( (?P<num1>{num}) / (?P<num2>{num} )\\) \\^ (?P<num3>{num})$', answer):  # ( x1 / x2 ) ^ x3
        return (_num(m, 'num1') / _num(m, 'num2')) ** _num(m, 'num3')

    if m := re.match(f'^(?P<num1>{num}) \\^ (?P<num2>{num}) - (?P<num3>{num})$', answer):  # x1 ^ x2 - x3
        return (_num(m, 'num1') ** _num(m, 'num2')) - _num(m, 'num3')

    if m := re.match(f'^(?P<num3>{num}) \\( (?P<num1>{num}) / (?P<num2>{num} )\\)$', answer):  # x3 ( x1 / x2 )
        return _num(m, 'num3') * (_num(m, 'num1') / _num(m, 'num2'))

    if m := re.match(f'^(?P<num3>{num}) \\+ \\( (?P<num1>{num}) / (?P<num2>{num} )\\)$', answer):  # x3 + ( x1 / x2 )
        return _num(m, 'num3') + (_num(m, 'num1') / _num(m, 'num2'))

    if m := re.match(f'^(?P<num1>{num}) √ (?P<num2>{num})$', answer):  # x1 √ x2
        return _num(m, 'num1') * _num(m, 'num2') ** 0.5

    if m := re.match(f'^(?P<num1>{num}) / √ (?P<num2>{num})$', answer):  # x1 / √ x2
        return _num(m, 'num1') / _num(m, 'num2') ** 0.5

    if m := re.match(f'^(√|sqrt) (?P<num1>{num})$', answer):  # √ x1
        return _num(m, 'num1') ** 0.5

    if m := re.match(f'^(?P<num1>{num}) √ \\( (?P<num2>{num}) / (?P<num3>{num}) \\)$', answer):  # x1 √ ( x2 / x3 )
        return _num(m, 'num1') * (_num(m, 'num2') / _num(m, 'num3')) ** 0.5

    if m := re.match(f'^(?P<num1>{num}) √ (?P<num2>{num}) \\* {pi}$', answer):  # x1 √ x2 * pi
        return _num(m, 'num1') * _num(m, 'num2') ** 0.5 * math.pi

    if m := re.match(f'^(?P<num1>{num}) \\* {pi} - (?P<num2>{num})$', answer):  # x1 * pi - x2
        return _num(m, 'num1') * math.pi - _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) \\+ √ (?P<num2>{num})$', answer):  # x1 + √ x2
        return _num(m, 'num1') + _num(m, 'num2') ** 0.5

    if re.match(f'^-\\* {pi}$', answer):  # a bare negative pi
        return -math.pi

    if m := re.match(f'^(?P<num1>{num}) e - (?P<num2>{num})$', answer):  # x1 e - x2 (scientific notation)
        return _num(m, 'num1') * (10 ** -_num(m, 'num2'))

    if m := re.match(f'^(?P<num1>{num}) \\[ (?P<num2>{num}) / (?P<num3>{num}) \\]$', answer):  # x1 [ x2 / x3 ]
        return _num(m, 'num1') + _num(m, 'num2') / _num(m, 'num3')

    if m := re.match(f'^(\\[|\\() (?P<num1>{num}) (\\]|\\))$', answer):  # [ x1 ] or ( x1 )
        return _num(m, 'num1')

    if m := re.match(f'^(?P<num1>{num})( \\+)? (?P<num2>{num}) / (?P<num3>{num})$', answer):  # x1 x2 / x3 (mixed fraction)
        return _num(m, 'num1') + _num(m, 'num2') / _num(m, 'num3')

    if m := re.match(f'^(?P<num1>{num}) !$', answer):  # x1 !
        return math.factorial(int(m.group('num1')))

    if m := re.match(f'^(?P<num1>{num}) / (?P<num2>{num}) !$', answer):  # x1 / x2 !
        return _num(m, 'num1') / math.factorial(int(m.group('num2')))

    if re.match(f'^{big}$', answer):  # a big comma separated number; stays a string
        return _strip_thousands(answer)

    if m := re.match(f'^(?P<num1>{num}) (hour|hr)s?( and| \\.)? (?P<num2>{num})', answer):  # x1 hours and x2 minutes
        return _num(m, 'num1') * 60 + _num(m, 'num2')

    if m := re.match(f'^(?P<num1>{num}) (min|minute)s?( ,)? (?P<num2>{num})', answer):  # x1 minutes and x2 seconds
        return _num(m, 'num1') * 60 + _num(m, 'num2')

    # Not a big number, so the option lists multiple answers, which we do not
    # allow (this only accounts for roughly 200 problems).
    if re.match(r'\d+\s?(,|and)\s?\d+', answer):
        return None

    # The option is just one of these tokens, so it has no numeric answer.
    if re.match(r'^sec$|^mx$|^l$|^m$|^positive$|^no$|^e$|^d$|^p$', answer):
        return None

    # Contains a variable (or list/parity wording), so it is not considered.
    if re.search(r'(^|\s)[A-Za-z]{1,2}\s|and|,|even|odd', answer):
        return None

    return answer


def special_cases(problem, answer):
    """Evaluate or normalise a cleaned answer option.

    Args:
        problem: mapping for one dataset example; only problem['Problem'] is
            read, to disambiguate "x : y" answers (ratio vs. clock time).
        answer: the answer text after prefix/suffix stripping and the
            `more_special` substitutions.

    Returns:
        A number when the answer matches one of the recognised shapes, None
        when the answer is rejected or cannot be evaluated (division by zero,
        non-integer factorial, overflow, complex result), or otherwise a
        string rewritten so that it can be passed to eval().
    """
    answer = answer.replace('/ *', '/')
    answer = answer.replace('â','')
    answer = answer.replace(' ∗', '')

    try:
        answer = _evaluate_special_case(problem, answer)
    except (ZeroDivisionError, OverflowError, ValueError):
        # e.g. "1 / 0", "5.5 !" or an exponent too large for a float: such an
        # option has no usable numeric value, so reject it instead of crashing.
        return None

    # A negative base with a fractional exponent yields a complex number,
    # which is not a usable answer either.
    if isinstance(answer, complex):
        return None

    # Changing remaining formulas so they can be evaluated using eval
    if isinstance(answer, str):
        answer = answer.replace('^', '**')
        answer = answer.replace('( *', '(')
        answer = answer.replace(':', '/')
        answer = re.sub(r'√\s(\d+)', r'math.sqrt(\1)', answer)
        answer = re.sub(r'sqrt\s(\d+)', r'math.sqrt(\1)', answer)
        answer = re.sub(r'sqr\s(\d+)', r'math.sqrt(\1)', answer)
        answer = re.sub(r'^sqrt|^√', r'math.sqrt', answer)
        # A number directly followed by "(" / math / sqrt is an implicit product.
        answer = re.sub(r'^(\d+) (\(|math|sqrt)', r'\1 * \2', answer)

    return answer
    
# Expects data['train/test/validation'][idx]
def get_answer(problem):
    """Return the numeric value of the option marked correct, or None.

    Args:
        problem: mapping for one dataset example. Reads problem['options']
            (all five options in one string), problem['correct'] (a letter
            'a'-'e') and, through special_cases, problem['Problem'].

    Returns:
        The correct option as a float, or None when the options string has an
        unknown layout (it is then printed) or the option cannot be turned
        into a number.
    """
    answer_idx = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}
    options = problem['options']

    # The options come in two textual layouts; capture group i holds option i.
    m = re.match(r"^a \) (.*) , b \) (.*) , c \) (.*) , d \) (.*) , e \) (.*)$", options)
    if m is None:
        m = re.match(r"^\[\'a \) (.*)\', \'b \) (.*)\', \'c \) (.*)\', \'d \) (.*)\', \'e \) (.*)\'\]$", options)
    if m is None:
        # Unknown layout: report it so that it can be inspected.
        print(f'{problem["options"]}')
        print('\n\n')
        return None

    answer = m.group(answer_idx[problem['correct']])

    # Strip units / labels, then apply the ordered textual substitutions.
    answer = to_remove.sub('', answer)
    for k, v in more_special.items():
        answer = re.sub(k, v, answer)

    answer = special_cases(problem, answer)
    if answer is None or answer == 'None':
        return None

    try:
        return float(answer)
    except (TypeError, ValueError, OverflowError):
        pass  # not directly convertible; try evaluating it as an expression

    if not isinstance(answer, str):
        return None

    try:
        # NOTE(security): eval() runs text taken from the dataset with this
        # module's globals (needed for math.sqrt). Only use on trusted data.
        return float(eval(answer))
    except Exception:
        # Anything that does not evaluate to a number is treated as "no answer".
        return None


This function converts a preprocessed expression tree into a set of labels that generate the equation incrementally. Each step contains all independent subformulas, starting from the leaf level of the tree and moving upwards. Each step is padded with None to a length of K. For example, if we have the equation (1 * (2 * 3) * (3 * 4)), we could write steps as [2 * 3, 3 * 4, None, None, None, None] and [var1 * var2, None, None, None, None, None]. This is explained in greater detail later.

In [10]:
def create_stepwise_labels(tree):
    """Build the semicolon-delimited incremental labels for an expression tree.

    The tree is read level by level, deepest level first. For each level the
    nodes are paired up left to right, and pair i is combined with the i-th
    operator found on the level above. Operators on that upper level are then
    renamed 'var1', 'var2', ... so the next step refers to these results.

    Args:
        tree: root node of the expression tree; only each node's `name` is
            read. Operator nodes are recognised by `name` being a value of
            the `Op` enum.

    Returns:
        One list_str-formatted row per step (each padded with None to K
        entries), joined by ' ; '; or None when a level holds more than K
        pairs.
    """
    # Getting tree layer by layer, deepest layer first
    layers = [[node.name for node in group] for group in LevelOrderGroupIter(tree)]
    layers.reverse()
    size = len(layers) - 1

    # Creating step wise labels; the row width follows K rather than a literal
    # 6 so that changing the constant keeps padding and bound check in sync.
    labels = np.full((size, K), None)
    operator_values = Op._value2member_map_
    for idx1 in range(size):
        # Getting the operation for each pair of the current layer
        # Also sets upper layer ops to reference this layer by marking as 'var'
        upper_layer = layers[idx1 + 1]
        ops = []
        count = 1
        for idx2, x in enumerate(upper_layer):
            if x in operator_values:
                ops.append(x)
                upper_layer[idx2] = f'var{count}'
                count += 1

        # Creating the labels
        current_layer = layers[idx1]
        for idx2, (x1, x2) in enumerate(zip(current_layer[::2], current_layer[1::2])):
            if idx2 >= K:  # If more independent subformulas than K, we return None
                return None
            labels[idx1, idx2] = f'{x1}{ops[idx2]}{x2}'

    # Returns semicolon delimited label set
    return ' ; '.join([list_str(x) for x in labels])

def list_str(arr):
    """Render one step's labels as list-like text.

    The first entry is always quoted; each later entry is quoted when truthy
    and written bare otherwise (so padding shows up as an unquoted None).
    """
    pieces = [f"'{arr[0]}'"]
    for item in arr[1:]:
        pieces.append(f"'{item}'" if item else f'{item}')
    return '[' + ', '.join(pieces) + ']'

Simple function that returns true if two float values are within a certain threshold. This allows values like 80.0 and 80.00000001 to be marked as equal. 0.5 is chosen as the threshold to account for rounding.

In [11]:
def float_equal(x1, x2):
    """Return whether both values are present and differ by at most 0.5."""
    if x1 is None or x2 is None:
        return False
    return abs(x1 - x2) <= 0.5

As mentioned, the MathQA dataset contains a lot of noise. As a result, expected answers are verified against the annotated formula to make sure both match. If these do not match, the example is thrown out. This is consistent with other literature on the dataset (CITE SOME OF THE PAPERS THAT DO THIS).

The following function saves the final preprocessed dataset to ../MathQA/x.csv

In [12]:
def preprocessing_pipeline(name):
    """Preprocess one dataset split and write it to DIRECTORY + '<name>.csv'.

    An example is kept only when the option marked correct and the value its
    formula evaluates to agree according to float_equal.

    Args:
        name: key of the split inside the module-level `data` mapping
            (called below with 'train', 'validation' and 'test').
    """
    dataset = data[name]

    # Converting MathQA annotated formula into expression tree
    trees = [create_tree(x) for x in dataset['annotated_formula']]

    # Output columns, filled in parallel for every example that is kept
    problems = []
    solutions = []
    new_formulas = []
    new_formulas_no_const = []
    original_formula = []
    incremental = []
    expression_trees = []

    for idx, tree in enumerate(trees):
        # Fetch the row once. Column-then-index access (dataset['Problem'][idx])
        # builds the whole column on every iteration.
        # NOTE(review): assumes dataset[idx][col] == dataset[col][idx], as for
        # a Hugging Face Dataset — confirm if `data` is another type.
        example = dataset[idx]

        # Converting expression tree to only use +,-,/,*,^ (This also marks unincluded operations like sine as None)
        fixed_tree = preprocess_tree(tree)

        # Getting the answer the problem marked as correct
        correct_answer = get_answer(example)

        # Getting the answer the expression tree evaluates to (evaluating None returns None)
        formula_answer = eval_tree(fixed_tree)

        # Skip the example unless both answers match (if either is None, they do not equal)
        if not float_equal(correct_answer, formula_answer):
            continue

        problems.append(example['Problem'])
        solutions.append(correct_answer)
        new_formulas.append(eval_tree_formula(fixed_tree))
        new_formulas_no_const.append(eval_tree_formula(fixed_tree, use_const=False))
        original_formula.append(example['annotated_formula'])
        incremental.append(create_stepwise_labels(fixed_tree))
        expression_trees.append(DictExporter().export(fixed_tree))

    df = pd.DataFrame({
        'problem': problems,
        'solution': solutions,
        'formula': new_formulas,
        'formula_no_const': new_formulas_no_const,
        'annotated_formula': original_formula,
        'incremental': incremental,
        'tree': expression_trees,
    })

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(DIRECTORY, exist_ok=True)
    df.to_csv(DIRECTORY+f'{name}.csv', index=False)

In [13]:
# Write one CSV per split, in the same order as before: test, validation, train.
for _split in ('test', 'validation', 'train'):
    preprocessing_pipeline(_split)