# MathQA Training

#### Imports

In [134]:
from enum import Enum
import os
import anytree
import pandas as pd
from itertools import permutations
import seaborn as sns
import math
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from sklearn.metrics import f1_score, accuracy_score
from datasets import Dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from sklearn.utils.class_weight import compute_class_weight
import re

#### Constants

In [168]:
# Directory prefix of the per-split CSV files; each split is read from
# f'{DATA_PATH}{name}.csv', so the trailing slash is required.
DATA_PATH = './dataset/'
# Names of the dataset splits; each one is also the stem of its CSV file name.
SET_NAMES = ['train', 'validation', 'test']

class Op(Enum):
    """Arithmetic operators, each valued by its operator symbol.

    The declaration order is significant: `op2id` numbers the symbols in
    this order, so reordering the members changes the label ids.
    """
    ADD = '+'
    SUB = '-'
    MULT = '*'
    DIV = '/'
    POW = '^'
    
class Const(Enum):
    """Constant tokens that can appear in a formula, valued by their token string.

    The declaration order is significant: `const2val` pairs these members
    positionally with the `values` list, so a member added or moved here needs
    the matching change in `values`.
    """
    CONST_PI = 'const_pi'
    CONST_NEG_1 = 'const_neg_1' # added by the notebook author
    CONST_DEG_TO_RAD = 'const_deg_to_rad' # pi / 180 (there is only one example of this and it's actually used incorrectly)
    CONST_1 = 'const_1'
    CONST_2 = 'const_2'
    CONST_3 = 'const_3'
    CONST_4 = 'const_4'
    CONST_5 = 'const_5'
    CONST_6 = 'const_6'
    CONST_10 = 'const_10'
    CONST_12 = 'const_12'
    CONST_26 = 'const_26'
    CONST_52 = 'const_52'
    CONST_60 = 'const_60'
    CONST_100 = 'const_100'
    CONST_180 = 'const_180'
    CONST_360 = 'const_360'
    CONST_1000 = 'const_1000'
    CONST_3600 = 'const_3600' 
    # Decimal constants written with '_' in place of the decimal point.
    CONST_0_25 = 'const_0_25'
    CONST_0_2778 = 'const_0_2778'
    CONST_0_33 = 'const_0_33'
    CONST_0_3937 = 'const_0_3937'
    CONST_0_4535 = 'const_0_4535'
    CONST_0_6 = 'const_0_6'
    CONST_1_6 = 'const_1_6'
    CONST_2_2046 = 'const_2_2046'
    CONST_2_54 = 'const_2_54'
    CONST_3_6 = 'const_3_6' 
    # Constants whose token contains a literal '.'; distinct tokens from the
    # '_' spellings above even when the numeric value is the same.
    CONST_0dot25 = 'const_0.25'
    CONST_0dot5 = 'const_0.5' 
    CONST_2dot0 = 'const_2.0'
    CONST_3dot0 = 'const_3.0'
    CONST_4dot0 = 'const_4.0'
    CONST_60dot0 = 'const_60.0'
    CONST_100dot0 = 'const_100.0'

# Numeric value of every `Const` member, listed in the same order in which the
# members are declared; the two sequences are paired positionally below.
values = [math.pi, -1, math.pi/180, 1, 2, 3, 4, 5, 6, 10, 12, 26, 52, 60, 100, 180, 360, 1000, 3600,
          0.25, 0.2778, 1/3, 0.3937, 0.4535, 0.6, 1.6, 2.2046, 2.54, 3.6, 0.25, 0.5, 2.0, 3.0, 4.0, 60.0, 100.0]

# zip() stops at the shorter input, so a constant added to only one of the two
# places would silently drop or misalign entries. Fail loudly instead.
if len(values) != len(Const):
    raise ValueError(
        f'`values` has {len(values)} entries but `Const` has {len(Const)} members'
    )

# Map each constant token (e.g. 'const_pi') to its numeric value.
const2val = {const.value: val for const, val in zip(Const, values)}
# 'const_0_5' has no `Const` member but is still given a numeric value.
const2val['const_0_5'] = 0.5

# Map each operator symbol to an integer id, following `Op` declaration order.
op2id = {op.value: op_id for op_id, op in enumerate(Op)}

## Loading the data

Read each split's CSV file into a dictionary of DataFrames keyed by split name.

In [245]:
# One DataFrame per split, keyed by the split name from SET_NAMES.
data = {name:pd.read_csv(f'{DATA_PATH}{name}.csv') for name in SET_NAMES}

Convert the set of operations used by each problem into a multi-label one-hot encoding.

In [170]:
def onehot_ops(data):
    """Build a multi-label one-hot matrix of the operators used by each problem.

    Args:
        data: Object exposing an ``ops`` attribute (e.g. a DataFrame column)
            whose entries are strings holding a Python literal collection of
            operator symbols, such as ``"{'-', '+'}"``.

    Returns:
        A float numpy array of shape ``(len(data.ops), len(op2id))`` in which
        entry ``[i, j]`` is 1 when problem ``i`` uses the operator with id
        ``j`` and 0 otherwise. An empty input yields shape ``(0, len(op2id))``
        so the result is always two-dimensional.

    Raises:
        KeyError: If an entry contains a symbol that is not a key of ``op2id``.
    """
    # Preallocating fixes the column count even when there are no rows.
    labels = np.zeros((len(data.ops), len(op2id)))
    for row, op_set in enumerate(data.ops):
        # NOTE(review): eval() executes whatever the string contains; this is
        # only acceptable while the CSV files are trusted. Consider
        # ast.literal_eval if the data source ever becomes external.
        used_ids = [op2id[op] for op in eval(op_set)]
        labels[row, used_ids] = 1
    return labels
        
#onehot_ops(data['train'])

Sort the nums of each problem in increasing order

In [221]:
def max_num(nums):
    """Return the largest numeric value among ``nums`` as a float.

    Each entry is either a key of ``const2val`` (resolved to its mapped value)
    or something ``float()`` can convert directly. Raises ``ValueError`` when
    ``nums`` is empty, because ``max`` has nothing to compare.
    """
    def as_float(token):
        # Constant tokens are looked up; everything else is parsed as a number.
        if token in const2val:
            return float(const2val[token])
        return float(token)

    return max(as_float(token) for token in nums)

def remove_const(data):
    """Collect, per problem, the non-constant numbers as a set of floats.

    Every entry of ``data.nums`` is a string holding a Python literal
    collection of number tokens. Tokens that are keys of ``const2val`` are
    dropped; the remaining ones are converted with ``float``.

    Returns:
        A list with one set of floats per entry of ``data.nums``.
    """
    per_problem = []
    for raw in data.nums:
        # NOTE(review): eval() runs arbitrary code from the dataset string;
        # fine for trusted CSVs only.
        tokens = eval(raw)
        per_problem.append({float(tok) for tok in tokens if tok not in const2val})
    return per_problem

def get_nums_from_problem(data):
    """Extract the set of numbers mentioned in each problem text.

    Numbers written with thousands separators (e.g. ``1,200`` or
    ``12,345.6``) are extracted first and removed from the text, so their
    comma-separated digit groups are not picked up again as separate numbers.
    The remaining text is then scanned for plain integers and decimals, with
    an optional leading sign.

    Args:
        data: Object exposing a ``problem`` attribute (e.g. a DataFrame
            column) that iterates over problem strings.

    Returns:
        A list with one set of floats per problem, in input order.
    """
    # Raw strings: '\d' and '\.' are invalid escape sequences in a normal
    # string literal and raise a SyntaxWarning on current Python versions.
    # Compiled once here rather than on every loop iteration.
    plain_re = re.compile(r'([+-]?((\d+(\.\d*)?)|(\.\d+)))')
    grouped_re = re.compile(r'(-?\d{1,3}(,\d{3})+(\.\d*)?)')

    nums = []
    for problem in data.problem:
        # findall returns one tuple per match; element 0 is the whole number.
        grouped = {float(match[0].replace(',', ''))
                   for match in grouped_re.findall(problem)}
        remainder = grouped_re.sub('', problem)
        plain = {float(match[0]) for match in plain_re.findall(remainder)}
        nums.append(grouped | plain)
    return nums

def sort_nums(data):
    """Sort the number tokens of every problem in increasing order.

    Args:
        data: Object exposing ``nums`` and ``nums_no_const`` attributes (e.g.
            DataFrame columns). Each entry is a string holding a Python
            literal collection of number tokens; ``nums`` may also contain
            constant tokens that are keys of ``const2val``.

    Returns:
        A tuple ``(nums_sorted, nums_no_const_sorted)`` of lists with one
        sorted token list per problem. In ``nums_sorted`` the plain numbers
        come first in increasing order, followed by the constant tokens in
        increasing order of their ``const2val`` value.
    """
    def const_last_key(token):
        # The (is_const, value) tuple puts every constant after every plain
        # number. Shifting constants by the maximum value does not: a
        # negative constant such as 'const_neg_1' then lands among the plain
        # numbers, and max() raises on an empty list.
        if token in const2val:
            return (True, float(const2val[token]))
        return (False, float(token))

    # NOTE(review): eval() runs arbitrary code from the dataset strings; this
    # is only acceptable while the CSV files are trusted.
    nums_no_const_sorted = []
    for raw in data.nums_no_const:
        nums_no_const_sorted.append(sorted(list(eval(raw)), key=float))

    nums_sorted = []
    for raw in data.nums:
        nums_sorted.append(sorted(list(eval(raw)), key=const_last_key))

    return nums_sorted, nums_no_const_sorted

#sort_nums(data['train'])

Here I test whether the numbers used in the equation can be found in the problem description using simple regexes. This works extremely well: there is only a single example where the expected numbers are not a subset of the obtained numbers. The comparison does not include constants. Constants are values that are not expected to occur in the problem description (for example pi, or the 2 in r^2).

In [235]:
# Compare, per validation problem, the non-constant numbers used by the
# annotated formula against the numbers the regexes find in the problem text.
expected = remove_const(data['validation'])
obtained = get_nums_from_problem(data['validation'])

for idx, (x, y) in enumerate(zip(expected, obtained)):
    # `x <= y` on sets is the subset test: report every problem whose formula
    # uses a number that was not found in its text.
    if not (x <= y):
        print('------------------')
        # idx is a position in the split, so look the row up positionally.
        print(data['validation']['problem'].iloc[idx])
        print(f'Expected: {x}')
        print(f'Obtained: {y}')
        print('------------------')

------------------
problem                 if you roll a fair - sided die twice , what is...
category                                                          general
solution                                                         0.166667
formula                 (((55.0 - ((const_4 * const_10) / const_100)) ...
formula_no_const        (((55.0 - ((4 * 10) / 100)) + 2) / (((55.0 - (...
annotated_formula       divide(add(subtract(5,5, divide(multiply(const...
incremental             ['const_4*const_10', 'const_4*const_10', None,...
nums                    {'const_2', 'const_100', 'const_10', 'const_4'...
ops                                                  {'-', '+', '/', '*'}
incremental_no_const    ['4*10', '4*10', None, None, None, None] ; ['4...
nums_no_const                               {'2', '100', '10', '4', '55'}
ops_no_const                                         {'-', '+', '/', '*'}
tree                    {'name': '/', 'children': [{'name': '+', 'chil...
Name: 2879, dtype: 

In [244]:
# Number of training problems per category.
data['train']['category'].value_counts()

category
general        7721
physics        5243
gain           3752
geometry       1506
other          1164
probability     196
Name: count, dtype: int64

In [240]:
# Number of rows in the training split.
len(data['train'])

19582

In [241]:
# Number of rows in the test split.
len(data['test'])

1938

In [242]:
# Number of rows in the validation split.
len(data['validation'])

2959