In [1]:
import re, ast, math, operator
from collections import Counter
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


### Load GSM8k dataset

In [2]:
dataset = load_dataset("gsm8k", 'main')
dataset, dataset['train'][0]

Downloading readme: 100%|██████████| 7.94k/7.94k [00:00<00:00, 4.85MB/s]
Downloading data: 100%|██████████| 2.31M/2.31M [00:00<00:00, 6.17MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 1.10MB/s]]
Downloading data files: 100%|██████████| 2/2 [00:00<00:00,  2.56it/s]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 570.85it/s]
Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 199965.77 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 182138.90 examples/s]


(DatasetDict({
     train: Dataset({
         features: ['question', 'answer'],
         num_rows: 7473
     })
     test: Dataset({
         features: ['question', 'answer'],
         num_rows: 1319
     })
 }),
 {'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
  'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72'})

### Convert equations to toolcalls

Use a helper to savely evaluate equations to easily get their result values

In [3]:
def safe_eval(s):

    def checkmath(x, *args):
        if x not in [x for x in dir(math) if not "__" in x]:
            raise SyntaxError(f"Unknown func {x}()")
        fun = getattr(math, x)
        return fun(*args)

    binOps = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.Call: checkmath,
        ast.BinOp: ast.BinOp,
    }

    unOps = {
        ast.USub: operator.neg,
        ast.UAdd: operator.pos,
        ast.UnaryOp: ast.UnaryOp,
    }

    ops = tuple(binOps) + tuple(unOps)

    tree = ast.parse(s, mode='eval')

    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        elif isinstance(node, ast.Str):
            return node.s
        elif isinstance(node, ast.Num):
            return node.value
        elif isinstance(node, ast.Constant):
            return node.value
        elif isinstance(node, ast.BinOp):
            if isinstance(node.left, ops):
                left = _eval(node.left)
            else:
                left = node.left.value
            if isinstance(node.right, ops):
                right = _eval(node.right)
            else:
                right = node.right.value
            return binOps[type(node.op)](left, right)
        elif isinstance(node, ast.UnaryOp):
            if isinstance(node.operand, ops):
                operand = _eval(node.operand)
            else:
                operand = node.operand.value
            return unOps[type(node.op)](operand)
        elif isinstance(node, ast.Call):
            args = [_eval(x) for x in node.args]
            r = checkmath(node.func.id, *args)
            return r
        else:
            raise SyntaxError(f"Bad syntax, {type(node)}")

    return _eval(tree)

In [4]:
safe_eval('1 + 1')

2

Recursively process all equations and convert them to toolcalls. Make sure that nested equations are converted to multiple toolcalls (which are binary ops) bottom-up. 

In [5]:
s = 'Let S be the amount Alexis paid for the shoes.\nShe spent S + 30 + 46 + 38 + 11 + 18 = S + <<+30+46+38+11+18=143>>143.\nShe used all but $16 of her budget, so S + 143 = 200 - 16 = 184.\nThus, Alexis paid S = 184 - 143 = $<<184-143=41>>41 for the shoes.\n#### 41'

op__to_toolcall = {
    ast.Add: 'add',
    ast.Sub: 'subtract',
    ast.Mult: 'multiply',
    ast.Div: 'divide',
}

def depth_ast(root):
    return 1 + max(map(depth_ast, ast.iter_child_nodes(root)),
                   default = 0)

def rec_tree_to_toolcalls(tree):

    if isinstance(tree, ast.UnaryOp):
        if isinstance(tree.op, ast.UAdd):
            return rec_tree_to_toolcalls(tree.operand)
        elif isinstance(tree.op, ast.USub):
            return rec_tree_to_toolcalls(tree.operand)
        else:
            raise NotImplementedError()
    elif depth_ast(tree) == 1:
        return []

    value = safe_eval(ast.unparse(tree))
    left_value = safe_eval(ast.unparse(tree.left))
    right_value = safe_eval(ast.unparse(tree.right))
    op = tree.op

    toolcall = f"<T>{op__to_toolcall[type(op)]}({left_value}, {right_value})={value}"
    toolcalls = rec_tree_to_toolcalls(tree.left) + rec_tree_to_toolcalls(tree.right) + [toolcall]
    return toolcalls


def change_equations_to_toolcall(row):
    s = row['answer']

    equations_str = re.findall(r'<<(.+?)>>(\d*\.?\d*,?\d*,?\d*,?\d*)', s)
    equations_ast = [e[0].split('=')[0] for e in equations_str]
    equations_ast = [ast.parse(e).body[0].value for e in equations_ast]
    depths = [depth_ast(e) for e in equations_ast]

    toolcalls = [rec_tree_to_toolcalls(e) for e in equations_ast]
    return {
        'equations': [f"<<{s[0]}>>{s[1]}" for s in equations_str],
        'depths': depths,
        'toolcalls': toolcalls,
    }

change_equations_to_toolcall({'answer': '<<1+-(2*5)=1>>'})
change_equations_to_toolcall(dataset['train'][8])

{'equations': ['<<+30+46+38+11+18=143>>143.', '<<184-143=41>>41'],
 'depths': [6, 2],
 'toolcalls': [['<T>add(30, 46)=76',
   '<T>add(76, 38)=114',
   '<T>add(114, 11)=125',
   '<T>add(125, 18)=143'],
  ['<T>subtract(184, 143)=41']]}

In [6]:
def try_change_equations_to_toolcall(row):
    try:
        return change_equations_to_toolcall(row)
    except Exception as e:
        print(f"Exception {e} on {row}")
        return {'equations': [], 'depths': [], 'toolcalls': []}

dataset = dataset.map(try_change_equations_to_toolcall, load_from_cache_file=False)
dataset = dataset.filter(lambda x: len(x['equations']) > 0)

dataset

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:  10%|▉         | 738/7473 [00:00<00:02, 2492.63 examples/s]

Exception <class 'ast.FloorDiv'> on {'question': 'Chad sandwiches 2 crackers with a scoop of peanut butter.  He has 5 of these crackers a night before bed.  A box of crackers has 4 sleeves with each sleeve holding 28 crackers.  How many nights will 5 boxes of crackers last him?', 'answer': 'Chad uses 2 crackers for each "sandwich" and has 5 sandwiches a night so he eats 2*5 = <<2*5=10>>10 crackers a night\nThe box has 4 sleeves of crackers and each sleeve has 28 crackers for a total of 4*28 = <<4*28=112>>112 crackers\n1 box has 112 crackers so 5 boxes will have 112*5 = <<112*5=560>>560 crackers\nHe eats 10 crackers a night and 5 boxes have 560 crackers so they will last him 560//10 = <<560//10=56>>56 nights\n#### 56'}


Map:  65%|██████▌   | 4859/7473 [00:02<00:01, 2387.95 examples/s]

Exception <class 'ast.FloorDiv'> on {'question': 'An apple tree has three times as many apples as the number of plums on a plum tree. If Damien picks 3/5 of the fruits from the trees, calculate the total number of plums and apples remaining on the tree if there were 180 apples on the apple tree before Damien picked any of the fruits.', 'answer': 'Initially, the were 180//3= <<180//3=60>>60 plums on the plum tree.\nWhen Damien picked 3/5 of the apples from the apple tree, he picked 3/5*180 = <<3/5*180=108>>108 apples\nThe number of apples remaining on the apple tree is 180-108= <<180-108=72>>72\nThe number of plums that Damien picked from the plum tree is 3/5*60 = <<3/5*60=36>>36\nThe number of plums remaining on the plum tree is 60-36 = <<60-36=24>>24\nThe total number of fruits remaining on the trees is 72 apples+24 plums = <<72+24=96>>96\n#### 96'}


Map: 100%|██████████| 7473/7473 [00:03<00:00, 2281.03 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 2346.22 examples/s]
Filter: 100%|██████████| 7473/7473 [00:00<00:00, 28827.96 examples/s]
Filter: 100%|██████████| 1319/1319 [00:00<00:00, 26752.26 examples/s]


DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 7376
    })
    test: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 1301
    })
})

Check how deep our equations are. Equations of a depth >2 are challenging as we do not know how to convert them to proper text with explanations. 

In [7]:
depth_counter = Counter([d for row in dataset['train'] for d in row['depths']])
depth_counter

Counter({2: 20153, 3: 2754, 1: 350, 4: 345, 5: 75, 6: 22, 7: 6, 10: 1})

Remove rows with too deep or shallow toolcalls as they are not suitable for our dataset

In [8]:
dataset_suitable = dataset\
    .filter(lambda x: 1 not in x['depths'])\
    .filter(lambda x: {2} == set(x['depths']))\
    .filter(lambda x: all([len(t) == 1 for t in x['toolcalls']]))
dataset_suitable

Filter: 100%|██████████| 7376/7376 [00:00<00:00, 15838.46 examples/s]
Filter: 100%|██████████| 1301/1301 [00:00<00:00, 18727.72 examples/s]
Filter: 100%|██████████| 7048/7048 [00:00<00:00, 19591.60 examples/s]
Filter: 100%|██████████| 1240/1240 [00:00<00:00, 9618.81 examples/s]
Filter: 100%|██████████| 4591/4591 [00:00<00:00, 19416.25 examples/s]
Filter: 100%|██████████| 793/793 [00:00<00:00, 18741.24 examples/s]


DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 4587
    })
    test: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 791
    })
})

In [9]:
depth_counter = Counter([d for row in dataset_suitable['train'] for d in row['depths']])
depth_counter

Counter({2: 13695})

In [10]:
dataset_suitable['train'][5]

{'question': 'Ken created a care package to send to his brother, who was away at boarding school.  Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds.  Then, he added enough brownies to cause the weight to triple.  Next, he added another 2 pounds of jelly beans.  And finally, he added enough gummy worms to double the weight once again.  What was the final weight of the box of goodies, in pounds?',
 'answer': 'To the initial 2 pounds of jelly beans, he added enough brownies to cause the weight to triple, bringing the weight to 2*3=<<2*3=6>>6 pounds.\nNext, he added another 2 pounds of jelly beans, bringing the weight to 6+2=<<6+2=8>>8 pounds.\nAnd finally, he added enough gummy worms to double the weight once again, to a final weight of 8*2=<<8*2=16>>16 pounds.\n#### 16',
 'equations': ['<<2*3=6>>6', '<<6+2=8>>8', '<<8*2=16>>16'],
 'depths': [2, 2, 2],
 'toolcalls': [['<T>multiply(2, 3)=6'],
  ['<T>add(6, 2)=8'],
  ['<T>multip

Now, we only have equations of depth=2, i.e. equations that are a single toolcall. We convert our dataset to contain toolcalls in the answer column

In [11]:
def replace_equations_with_toolcalls(row):
    answer = row['answer']
    for equation, toolcall in zip(row['equations'], row['toolcalls']):
        assert len(toolcall) == 1, f"Equation {equation} has multiple tool calls {toolcall}"
        toolcall = toolcall[0]
        answer = answer.replace(equation, toolcall)
    return {'answer': answer}

dataset_suitable = dataset_suitable.map(replace_equations_with_toolcalls)
dataset_suitable = dataset_suitable.map(lambda row: {'answer': re.sub(r'#### \d+', '', row['answer'])})

dataset_suitable, dataset_suitable['train'][5]

Map:   0%|          | 0/4587 [00:00<?, ? examples/s]

Map: 100%|██████████| 4587/4587 [00:00<00:00, 5834.60 examples/s]
Map: 100%|██████████| 791/791 [00:00<00:00, 5340.85 examples/s]
Map: 100%|██████████| 4587/4587 [00:00<00:00, 9286.08 examples/s] 
Map: 100%|██████████| 791/791 [00:00<00:00, 10243.21 examples/s]


(DatasetDict({
     train: Dataset({
         features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
         num_rows: 4587
     })
     test: Dataset({
         features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
         num_rows: 791
     })
 }),
 {'question': 'Ken created a care package to send to his brother, who was away at boarding school.  Ken placed a box on a scale, and then he poured into the box enough jelly beans to bring the weight to 2 pounds.  Then, he added enough brownies to cause the weight to triple.  Next, he added another 2 pounds of jelly beans.  And finally, he added enough gummy worms to double the weight once again.  What was the final weight of the box of goodies, in pounds?',
  'answer': 'To the initial 2 pounds of jelly beans, he added enough brownies to cause the weight to triple, bringing the weight to 2*3=<T>multiply(2, 3)=6 pounds.\nNext, he added another 2 pounds of jelly beans, bringing the weight to 6+2=<T>add(6, 2)=

In [12]:
for row in dataset_suitable['train'].select(range(50)):
    print(row['answer'])
    print()

Natalia sold 48/2 = <T>divide(48, 2)=24.0 clips in May.
Natalia sold 48+24 = <T>add(48, 24)=72 clips altogether in April and May.


Weng earns 12/60 = $<T>divide(12, 60)=0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<T>multiply(0.2, 50)=10.0


Maila read 12 x 2 = <T>multiply(12, 2)=24 pages today.
So she was able to read a total of 12 + 24 = <T>add(12, 24)=36 pages since yesterday.
There are 120 - 36 = <T>subtract(120, 36)=84 pages left to be read.
Since she wants to read half of the remaining pages tomorrow, then she should read 84/2 = <T>divide(84, 2)=42.0 pages.


He writes each friend 3*2=<T>multiply(3, 2)=6 pages a week
So he writes 6*2=<T>multiply(6, 2)=12 pages every week
That means he writes 12*52=<T>multiply(12, 52)=624 pages a year


He eats 32 from the largest pizzas because 2 x 16 = <T>multiply(2, 16)=32
He eats 16 from the small pizza because 2 x 8 = <T>multiply(2, 8)=16
He eats 48 pieces because 32 + 16 = <T>add(32, 16)=48


To the initial 2 pounds of jelly b

Save the dataset

In [21]:
train_eval = dataset_suitable["train"].train_test_split(test_size=0.1, seed=42)
dataset_suitable["train"] = train_eval["train"]
dataset_suitable["eval"] = train_eval["test"]
dataset_suitable

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 4128
    })
    test: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 791
    })
    eval: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 459
    })
})

In [22]:
dataset_suitable.push_to_hub('jvhoffbauer/gsm8k-toolcalls', private=True)

Creating parquet from Arrow format: 100%|██████████| 5/5 [00:00<00:00, 21.10ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.60s/it]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 94.22ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  3.55it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 28.11ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.36s/it]


In [23]:
dataset_suitable

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 4128
    })
    test: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 791
    })
    eval: Dataset({
        features: ['question', 'answer', 'equations', 'depths', 'toolcalls'],
        num_rows: 459
    })
})