Inspired by https://github.com/liutiedong/goat/blob/main/dataset.ipynb

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import os
import random
from typing import Union

# Set up prompting

In [7]:
# Prompt that goes ahead of each question
prompt = "What is "

# Addition

In [8]:
# Create addition examples

# Creates
pairs = \
[(random.randint(10**(i-1), 10**i), random.randint(10**(j-1), 10**j)) for i in range(1,16) for j in range(i,16) for k in range(1000)] +\
[(random.randint(10**(i-1), 10**i), random.randint(10**(j-1), 10**j)) for i in range(3,16) for j in range(i,16) for k in range(1000)] +\
[(random.randint(10**(i-1), 10**i), random.randint(10**(j-1), 10**j)) for i in range(6,16) for j in range(i,16) for k in range(1000)] +\
[(random.randint(10**(i-1), 10**i), random.randint(10**(j-1), 10**j)) for i in range(9,16) for j in range(i,16) for k in range(1000)] +\
[(random.randint(10**(i-1), 10**i), random.randint(10**(j-1), 10**j)) for i in range(12,16) for j in range(i,16) for k in range(1000)] 

print('FIRST 5:',pairs[:5])
print('MIN:', min(pairs))
print('MAX:', max(pairs))
print('Total:', len(pairs))

FIRST 5: [(8, 8), (10, 8), (7, 3), (7, 2), (9, 9)]
MIN: (1, 1)
MAX: (999966048596408, 402391066750083)
Total: 304000


In [9]:
# Shuffle data
random.shuffle(pairs)

print("Shuffled First 5:", pairs[:5])

Shuffled First 5: [(3231165, 4533275279), (53, 781681064), (2545499, 272615628), (1077359346, 78753630566), (892699, 11129023)]


Data successfully shuffled.

## Create json

In [11]:
# Create output json
addition_output = []

# Do addition
for num1, num2 in pairs:
    
    # Randomly swap numbers
    if random.random()<0.5:
        num1, num2 = num2, num1 

    # Create Question and output
    answer = num1 + num2
    question = f"{prompt}{num1} + {num2}" 
    output = f"{answer}"
    
    addition_output.append({'input': question, 'output': output, 'answer': str(answer)})
    
# View examples
print(addition_output[0])
print(addition_output[-1])

{'input': 'What is 3231165 + 4533275279', 'output': '4536506444', 'answer': '4536506444'}
{'input': 'What is 5488131067032 + 34334519', 'output': '5488165401551', 'answer': '5488165401551'}


# Subtraction

In [13]:
# Create output json
subtraction_output = []

# Do addition
for num1, num2 in tqdm(pairs):
    
    # Randomly swap numbers
    if random.random()<0.5:
        num1, num2 = num2, num1 

    # Create Question and output
    answer = num1 - num2
    question = f"{prompt}{num1} - {num2}" 
    output = f"{num1} - {num2} = {answer}"
    
    subtraction_output.append({'input': question, 'output': output, 'answer': str(answer)})
    
# View examples
print(subtraction_output[0])
print(subtraction_output[-1])

100%|██████████| 304000/304000 [00:00<00:00, 594128.10it/s]

{'input': 'What is 3231165 - 4533275279', 'output': '3231165 - 4533275279 = -4530044114', 'answer': '-4530044114'}
{'input': 'What is 5488131067032 - 34334519', 'output': '5488131067032 - 34334519 = 5488096732513', 'answer': '5488096732513'}





# Multiplication

In [14]:
# Create output json
multiplication_output = []

# Do addition
for num1, num2 in tqdm(pairs):
    
    # Randomly swap numbers
    if random.random()<0.5:
        num1, num2 = num2, num1 

    # Create Question and output
    answer = num1 * num2
    question = f"{prompt}{num1} * {num2}" 
    output = f"{num1} * {num2} = {answer}"
    
    multiplication_output.append({'input': question, 'output': output, 'answer': str(answer)})
    
# View examples
print(multiplication_output[0])
print(multiplication_output[-1])

100%|██████████| 304000/304000 [00:00<00:00, 569195.21it/s]

{'input': 'What is 3231165 * 4533275279', 'output': '3231165 * 4533275279 = 14647760416870035', 'answer': '14647760416870035'}
{'input': 'What is 5488131067032 * 34334519', 'output': '5488131067032 * 34334519 = 188432340395500477608', 'answer': '188432340395500477608'}





# Division

In [15]:
# Create output json
division_output = []

# Do addition
for num1, num2 in tqdm(pairs):
    
    # Randomly swap numbers
    if random.random()<0.5:
        num1, num2 = num2, num1 

    # Create Question and output
    answer = num1 / num2
    question = f"{prompt}{num1} / {num2}" 
    output = f"{num1} / {num2} = {answer}"
    
    division_output.append({'input': question, 'output': output, 'answer': str(answer)})
    
# View examples
print(division_output[0])
print(division_output[-1])

100%|██████████| 304000/304000 [00:00<00:00, 339723.36it/s]

{'input': 'What is 3231165 / 4533275279', 'output': '3231165 / 4533275279 = 0.0007127661130503344', 'answer': '0.0007127661130503344'}
{'input': 'What is 5488131067032 / 34334519', 'output': '5488131067032 / 34334519 = 159842.95766694736', 'answer': '159842.95766694736'}





# Save Outputs

In [16]:
# Addition

# Write output json
addition_output_path = './data/addition_dataset.json'

with open(addition_output_path, "w") as f:
    json.dump(addition_output, f, indent=4)
    
# Verify creation
assert("addition_dataset.json" in os.listdir("data"))

# Subtraction

# Write output json
subtraction_output_path = './data/subtraction_dataset.json'

with open(subtraction_output_path, "w") as f:
    json.dump(subtraction_output, f, indent=4)
    
# Verify creation
assert("subtraction_dataset.json" in os.listdir("data"))

# Multiplication

# Write output json
multiplication_output_path = './data/multiplication_dataset.json'

with open(multiplication_output_path, "w") as f:
    json.dump(multiplication_output, f, indent=4)
    
# Verify creation
assert("multiplication_dataset.json" in os.listdir("data"))

# Division

# Write output json
division_output_path = './data/division_dataset.json'

with open(division_output_path, "w") as f:
    json.dump(division_output, f, indent=4)
    
# Verify creation
assert("division_dataset.json" in os.listdir("data"))

In [22]:
# Read in each dataset and print examples to verify
print("Addition Data")
with open(addition_output_path, "r") as f:
    file = f.readlines()
    
    for line in file[:11]:
        print(line)
        
print("Subtraction Data")
with open(subtraction_output_path, "r") as f:
    file = f.readlines()
    
    for line in file[:11]:
        print(line)
        
print("Multiplication Data")
with open(multiplication_output_path, "r") as f:
    file = f.readlines()
    
    for line in file[:11]:
        print(line)
        
print("Division Data")
with open(division_output_path, "r") as f:
    file = f.readlines()
    
    for line in file[:11]:
        print(line)

Addition Data
[

    {

        "input": "What is 3231165 + 4533275279",

        "output": "4536506444",

        "answer": "4536506444"

    },

    {

        "input": "What is 53 + 781681064",

        "output": "781681117",

        "answer": "781681117"

    },

Subtraction Data
[

    {

        "input": "What is 3231165 - 4533275279",

        "output": "3231165 - 4533275279 = -4530044114",

        "answer": "-4530044114"

    },

    {

        "input": "What is 53 - 781681064",

        "output": "53 - 781681064 = -781681011",

        "answer": "-781681011"

    },

Multiplication Data
[

    {

        "input": "What is 3231165 * 4533275279",

        "output": "3231165 * 4533275279 = 14647760416870035",

        "answer": "14647760416870035"

    },

    {

        "input": "What is 53 * 781681064",

        "output": "53 * 781681064 = 41429096392",

        "answer": "41429096392"

    },

Division Data
[

    {

        "input": "What is 3231165 / 4533275279",

        