In [56]:
from datasets import load_from_disk
import os
DATA_DIR = "../data/" 

#check if data directory exists, if not create it
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

dataset = load_from_disk(os.path.join(DATA_DIR, "gsm8k"))


In [57]:
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 7473
    })
    test: Dataset({
        features: ['question', 'answer'],
        num_rows: 1319
    })
})

In [36]:
train_dataset = dataset["train"]

print("Example of a training sample:")
print("~~~~~Question:~~~~~" )
print(train_dataset["question"][0])
print("~~~~~Answer:~~~~~" )
print(train_dataset["answer"][0])

Example of a training sample:
~~~~~Question:~~~~~
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
~~~~~Answer:~~~~~
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72


## Parsing Dataset

In [49]:
import re

# def find_encapsulated_strings(input_string, pattern = r'<<(.*?)>>'):
    
#     def get_matched_string_with_indices(match):
#         matched_str_w_idx = {
#             "equation_w_pattern": match.group(0),
#             "equation_witout_pattern ": match.group(1),
#             "start_idx": match.start(),
#             "end_idx": match.end()
#         }
#         index_equal_sign = matched_str_w_idx["equation_w_pattern"].index('=')
#         len_pre_equal_sign = len(matched_str_w_idx["equation_w_pattern"][:index_equal_sign])
#         matched_str_w_idx["idx_equal_sign"] = index_equal_sign
#         matched_str_w_idx["og_len_pre_equal_sign"] = len_pre_equal_sign
#         return matched_str_w_idx

#     # Use regular expression to find all matches
#     matches = re.finditer(pattern, input_string)
    
#     # Store matches and their indices in a list of tuples
#     matched_strings_with_indices = [get_matched_string_with_indices(match) for match in matches]
#     return matched_strings_with_indices

def find_pattern(input_string,pattern):
        
    # Use regular expression to find all matches
    matches = re.finditer(pattern, input_string)
    
    # Store indices of matches in a list
    equal_sign_indices = [{"start": match.start(), "end": match.end(), "pattern": pattern} for match in matches]
    
    return equal_sign_indices

# def add_pause_to_matched_strings(matched_strings_with_indices, pause_token = "<|PAUSE|>", n_pauses = 1):
#     def add_pause_to_matched_string(matched_str_with_idx):
#         content = matched_str_with_idx["equation_w_pattern"]
#         index = content.index('=') + 1
#         matched_str_with_idx["equation_w_pattern"] = add_pause(content, index, n_pauses, pause_token)
#         return matched_str_with_idx

#     augmented_matched_strings_with_indices = [add_pause_to_matched_string(matched_str_with_idx) for matched_str_with_idx in matched_strings_with_indices]
#     return augmented_matched_strings_with_indices

def add_pause(string, idx ,n_pauses, pause_token):
    pause_toks = n_pauses * pause_token
    return string[:idx] + pause_toks + string[idx:]

def inject_pause_to_str(input_string, n_pauses_per_patterns, pause_token):
    patterns = list(n_pauses_per_patterns.keys())
    
    pattern_occurences = []
    for pat in patterns:
        pattern_occurences.extend(find_pattern(input_string, pattern= pat))
    
   
    pattern_occurences.sort(key=lambda x: x["start"], reverse=True)
    augmented_string = input_string
    for patt in pattern_occurences:
        augmented_string =  add_pause(
            string = augmented_string,
            idx = patt["start"] + 1,
            n_pauses = n_pauses_per_patterns[patt["pattern"]],
            pause_token = pause_token
        )
    
    if r"\n" in patterns:
        
        augmented_string =  add_pause(
                string = augmented_string,
                idx = 0,
                n_pauses = n_pauses_per_patterns[r"\n"],
                pause_token = pause_token
            )
    return augmented_string

def inject_pauses(
        sample,
        n_pauses_per_patterns = {
            r"=": 1,
            r"\n": 1
            },
        pause_token = "<|PAUSE|>"
    ):
    
    input_string = sample["answer"]
    sample["pause_augmented_answer"]  = inject_pause_to_str(input_string, n_pauses_per_patterns, pause_token)
    return sample

# Example usage
input_string = "Natalia sold 48/2 = <<48/2=24>>24 clips in May. Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May. #### 72"

res = inject_pause_to_str(input_string, n_pauses_per_patterns = {r"=": 1, r"\n": 1}, pause_token = "<|PAUSE|>")
print(res)



<|PAUSE|>Natalia sold 48/2 =<|PAUSE|> <<48/2=<|PAUSE|>24>>24 clips in May. Natalia sold 48+24 =<|PAUSE|> <<48+24=<|PAUSE|>72>>72 clips altogether in April and May. #### 72


In [53]:
n_pauses_per_patterns = {
            r"=": 1,
            r"\n": 1
            }
pause_token = "<|PAUSE|>"

train_dataset = train_dataset.map(lambda sample: inject_pauses(sample,n_pauses_per_patterns, pause_token)) 

Map:  15%|█▍        | 1087/7473 [00:00<00:00, 10774.48 examples/s]

Map: 100%|██████████| 7473/7473 [00:00<00:00, 13097.92 examples/s]


In [58]:
train_dataset.sample()

AttributeError: 'Dataset' object has no attribute 'sample'

In [54]:
for sample in train_dataset:
    print("~~~~~OG Answer:~~~~~" )
    print(sample["answer"])
    print("~~~~~Answer:~~~~~" )
    print(sample["pause_augmented_answer"])
    print("~~~~~~~~~~~~~~~~~")
    

~~~~~OG Answer:~~~~~
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72
~~~~~Answer:~~~~~
<|PAUSE|><|PAUSE|>Natalia sold 48/2 =<|PAUSE|> <<48/2=<|PAUSE|>24>>24 clips in May.
<|PAUSE|><|PAUSE|>Natalia sold 48+24 =<|PAUSE|> <<48+24=<|PAUSE|>72>>72 clips altogether in April and May.
<|PAUSE|><|PAUSE|>#### 72
~~~~~~~~~~~~~~~~~
~~~~~OG Answer:~~~~~
Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.
Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.
#### 10
~~~~~Answer:~~~~~
<|PAUSE|><|PAUSE|>Weng earns 12/60 =<|PAUSE|> $<<12/60=<|PAUSE|>0.2>>0.2 per minute.
<|PAUSE|><|PAUSE|>Working 50 minutes, she earned 0.2 x 50 =<|PAUSE|> $<<0.2*50=<|PAUSE|>10>>10.
<|PAUSE|><|PAUSE|>#### 10
~~~~~~~~~~~~~~~~~
~~~~~OG Answer:~~~~~
In the beginning, Betty has only 100 / 2 = $<<100/2=50>>50.
Betty's grandparents gave her 15 * 2 = $<<15*2=30>>30.
This means, Betty needs 100 - 50 - 30 - 15 = $<<100-50-30-15=5>>5 more.
#### 5
~~~