# Math 5680 Final Project -- Experiments

## Simple Dataset

Choose a folder in which to save the experiment results.

In [1]:
MODEL_DATABASE = 'model5'

Create the experiment object.

In [2]:
from experiments import simple

simple_exp = simple.SimpleDatasetExperiment(MODEL_DATABASE)

Run training.

In [None]:
simple_exp.run_training()

Run evaluation.

In [3]:
simple_exp.run_evaluation(6)

  0%|          | 0/10000 [00:00<?, ?batches/s]

Test accuracy: 7.09
Test accuracy (first): 100.00


## Easy Arithmetic

Choose a folder in which to save the experiment results.

In [1]:
MODEL_DATABASE = 'model4'

Create the experiment object.

In [2]:
from experiments import arithmetic_easy

ar_easy_exp = arithmetic_easy.ArithmeticEasyExperiment(MODEL_DATABASE)

Count the number of trainable parameters. For comparison, the original transformer model (which Saxton *et al.* also used on the math dataset) had ~30 million parameters.

In [None]:
sum(p.numel() for p in ar_easy_exp.model.parameters() if p.requires_grad)
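
To make the counting expression concrete: `numel()` gives the number of elements in a tensor, so the sum adds up the sizes of all trainable tensors. The sketch below reproduces that arithmetic in plain Python. The function name `count_parameters` and the tensor shapes are hypothetical and are not taken from this project's model.

```python
from math import prod


def count_parameters(shapes):
    """Sum the element counts of tensors described only by their shapes."""
    return sum(prod(shape) for shape in shapes)


# Hypothetical shapes: a 512 x 512 weight matrix and its 512-element bias.
example_shapes = [(512, 512), (512,)]
print(count_parameters(example_shapes))  # 262656
```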

Run training.

In [None]:
ar_easy_exp.run_training()

Run evaluation.

In [4]:
ar_easy_exp.run_evaluation(40)

  0%|          | 0/599999 [00:00<?, ?batches/s]

Test accuracy: 40.46%
Test accuracy (first): 76.35%


Evaluate on the provided test data, which contains a mix of easy, medium, and hard questions.

In [4]:
import torch
from math_dataset import MathDataset

# Load all arithmetic questions
test_subcats = MathDataset.subcategories()['interpolate']['arithmetic']
datasets = [MathDataset('interpolate', 'arithmetic', s) for s in test_subcats]
test_dl = torch.utils.data.ConcatDataset(datasets)

# Load model weights
ar_easy_exp.trainer.load(40)

# Define an evaluation function to measure exact-match accuracy and "first" accuracy
# (the answers agree over the length of the shorter of the two)
n_total = 0
n_correct = 0
n_first_correct = 0


def evaluate(model_answer, actual_answer):
    global n_total, n_correct, n_first_correct

    n_total += 1

    if model_answer == actual_answer:
        n_correct += 1

    first = min(len(model_answer), len(actual_answer))
    if model_answer[:first] == actual_answer[:first]:
        n_first_correct += 1


# Run evaluation
ar_easy_exp.trainer.evaluate(test_dl, evaluate)

# Print results
print(f'Test accuracy: {n_correct / n_total * 100:.02f}%')
print(f'Test accuracy (first): {n_first_correct / n_total * 100:.02f}%')

  0%|          | 0/90000 [00:00<?, ?batches/s]

Test accuracy: 7.96%
Test accuracy (first): 10.40%
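
The two metrics in the cell above can also be computed without global counters. The following is a minimal sketch under that assumption. The helper name `score` and the example pairs are illustrative only and are not part of the `experiments` package.

```python
def score(pairs):
    """Return (exact accuracy, "first" accuracy) in percent.

    `pairs` is an iterable of (model_answer, actual_answer) strings.
    """
    n_total = n_exact = n_first = 0
    for model_answer, actual_answer in pairs:
        n_total += 1
        if model_answer == actual_answer:
            n_exact += 1
        # Same comparison as the cell above: agree over the shorter length.
        k = min(len(model_answer), len(actual_answer))
        if model_answer[:k] == actual_answer[:k]:
            n_first += 1
    if n_total == 0:
        return 0.0, 0.0
    return n_exact / n_total * 100, n_first / n_total * 100


# Illustrative (model answer, actual answer) pairs
pairs = [("5", "5326"), ("1", "1"), ("-1/4", "1/4"), ("8", "8/9")]
exact, first = score(pairs)
print(f"Test accuracy: {exact:.2f}%")
print(f"Test accuracy (first): {first:.2f}%")
```

Note that under this comparison an empty model answer counts as a match for the "first" metric, because the shorter length is zero.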


View some sample questions and compare true and model responses.

In [5]:
sample_dl = torch.utils.data.DataLoader(ar_easy_exp.test_dataset, batch_size=64, shuffle=True)
sample_it = iter(sample_dl)

In [6]:
questions, answers = next(sample_it)

ar_easy_exp.model.eval()
with torch.no_grad():
    for q, a in zip(questions, answers):
        model_a = ar_easy_exp.model(q)
        print(f'Q: {q}')
        print(f'A: {a}')
        print(f'M: {model_a}')
        print()

Q: What is 4672 - -654?
A: 5326
M: 5

Q: Simplify (sqrt(19) + (sqrt(171)*-2 - sqrt(171)))*-6.
A: 48*sqrt(19)
M: -126*sqrt(19)

Q: Divide -8 by -9.
A: 8/9
M: 8

Q: What is 2 + (2 - 3) + (-3 - -3)?
A: 1
M: 1

Q: In base 12, what is 1387 - -5?
A: 1390
M: 1

Q: Evaluate (22/8 + -3)/((-2)/2).
A: 1/4
M: -1/4

Q: Divide -6734 by -1.
A: 6734
M: 6

Q: Divide -1 by 4459.
A: -1/4459
M: -1/14

Q: Work out 5 * -0.015.
A: -0.075
M: -0.

Q: What is the square root of 21120 to the nearest integer?
A: 145
M: 145

Q: What is -3739 - -3?
A: -3736
M: -3

Q: In base 5, what is 1132 + 1414?
A: 3101
M: 3

Q: What is the value of (-273)/(-182)*(-1)/(-6)*1?
A: 1/4
M: 1/4

Q: Calculate -4 divided by 297.
A: -4/297
M: -4/

Q: In base 6, what is -13 - 0?
A: -13
M: -13

Q: What is the product of -7.144 and 0.5?
A: -3.572
M: -3.

Q: (-1)/5 - 1089/(-2420)
A: 1/4
M: 1/4

Q: Simplify -4 + (sqrt(95)/sqrt(5) + -1 + sqrt(19) - sqrt(171)*-4).
A: -5 + 14*sqrt(19)
M: -5 sqrt(19)

Q: What is -36 divided by 10?
A: -18/5
M: -1
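
Many of the model responses above are a leading fragment of the true answer rather than an unrelated string. One way to tally this, sketched below, is to label each response as exact, truncated, or wrong. The function `classify` and its labels are hypothetical. The pairs are copied from the first six samples printed above.

```python
from collections import Counter


def classify(model_answer, actual_answer):
    """Label a response as 'exact', 'truncated' (a proper prefix), or 'wrong'."""
    if model_answer == actual_answer:
        return "exact"
    if model_answer and actual_answer.startswith(model_answer):
        return "truncated"
    return "wrong"


# (actual answer, model answer) pairs from the sample output
samples = [
    ("5326", "5"),
    ("48*sqrt(19)", "-126*sqrt(19)"),
    ("8/9", "8"),
    ("1", "1"),
    ("1390", "1"),
    ("1/4", "-1/4"),
]
counts = Counter(classify(m, a) for a, m in samples)
print(counts)
```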

Test the model's counting ability by asking it to sum an increasing number of ones.

In [8]:
ar_easy_exp.model.eval()

with torch.no_grad():
    for i in range(1, 11):
        q = 'What is ' + ('1 + ' * i)[:-3] + '?'
        print(f'Q: {q}')
        print(f'A: {i}')
        print(f'M: {ar_easy_exp.model(q)}')
        print()

Q: What is 1?
A: 1
M: 1

Q: What is 1 + 1?
A: 2
M: 2

Q: What is 1 + 1 + 1?
A: 3
M: 3

Q: What is 1 + 1 + 1 + 1?
A: 4
M: 4

Q: What is 1 + 1 + 1 + 1 + 1?
A: 5
M: 6

Q: What is 1 + 1 + 1 + 1 + 1 + 1?
A: 6
M: 5

Q: What is 1 + 1 + 1 + 1 + 1 + 1 + 1?
A: 7
M: 5

Q: What is 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1?
A: 8
M: 115555555555555555555555555555

Q: What is 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1?
A: 9
M: 555555555555555555555555555555

Q: What is 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1?
A: 10
M: 115355555555555555555555555555
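
The question strings in the counting cell are built by repeating `'1 + '` and slicing off the trailing `' + '`. A possibly clearer alternative, sketched here with a hypothetical helper name `counting_question`, uses `str.join` and produces the same strings.

```python
def counting_question(n):
    """Build 'What is 1 + 1 + ... + 1?' containing n ones."""
    return "What is " + " + ".join(["1"] * n) + "?"


for n in range(1, 4):
    print(counting_question(n))
```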

