In [None]:
import textwrap
import random
import collections
import typing

In [None]:
# Record types for the digit-by-digit addition scratchpad.

# An addition problem: compute x + y.
Question = collections.namedtuple("Question", "x y")
# One scratchpad line: the operand digits not yet consumed (x, y), the sum
# digits produced so far as a string (new digits are prepended), and the
# carry going into the next column.
RationaleStep = collections.namedtuple("RationaleStep", "x y acc carry")
# A whole scratchpad, first line = the untouched question.
Rationale = typing.List[RationaleStep]
# The first wrong line of a rationale: its index, what was written, and what
# should have been written.
Correction = collections.namedtuple("Correction", "line actual_rs expected_rs")
# A clean (or corrupted) training example.
Example = collections.namedtuple("Example", "question rationale answer")
# A corrupted example together with its correction.
# NOTE: the field is spelled "incorrect_rational" (sic); kept so existing
# keyword/attribute access keeps working.
CorruptedExample = collections.namedtuple(
    "CorruptedExample",
    "question incorrect_rational incorrect_answer"
    " correction correction_rationale correction_answer",
)

In [None]:
def make_question(num_digits: int) -> Question:
    """Draw a random addition problem with operands of at most `num_digits` digits.

    Both operands are sampled independently and uniformly from the inclusive
    range [0, 99...9] (`num_digits` nines), x first and then y, using the
    module-level `random` state.

    Raises:
        AssertionError: if `num_digits` is not positive.
    """
    assert num_digits > 0, "need at least one digit to do arithmetic"
    largest_value = int("9" * num_digits)  # e.g. 999 for three digits
    operands = [random.randint(0, largest_value) for _ in range(2)]
    return Question(*operands)


make_question(num_digits=3)

In [None]:
def make_rationale_step_corrupt(rs: RationaleStep) -> RationaleStep:
    """Return a copy of `rs` with exactly one field deliberately made wrong.

    The field is chosen from `sample = random.uniform(0, 1)`:
      * 0 < sample <= 0.1  -> x gets 1-9 added
      * 0.1 < sample <= 0.2 -> y gets 1-9 added
      * 0.2 < sample <= 0.5 -> the accumulator's numeric value gets 1-9 added
      * otherwise           -> the carry is flipped (0 -> 1, anything else -> 0)

    Every branch changes its field, so the result always differs from `rs`
    and the mistake can be detected by comparing against the expected step.
    """
    corrupt_x, corrupt_y, corrupt_acc, corrupt_carry = rs
    sample = random.uniform(0, 1)

    if 0 < sample <= 0.1:
        corrupt_x = corrupt_x + random.randint(1, 9)  # corrupt x
    elif 0.1 < sample <= 0.2:
        corrupt_y = corrupt_y + random.randint(1, 9)  # corrupt y
    elif 0.2 < sample <= 0.5:
        # corrupt acc. int() drops leading zeros ("00" + 3 -> "3"), which would
        # also delete digits from the partial sum; zfill restores the original
        # width so the only mistake is the wrong value. An empty accumulator
        # is treated as 0 instead of crashing int("").
        bumped = int(corrupt_acc or 0) + random.randint(1, 9)
        corrupt_acc = str(bumped).zfill(len(corrupt_acc))
    else:
        corrupt_carry = 1 if corrupt_carry == 0 else 0  # corrupt carry

    return RationaleStep(
        x=corrupt_x,
        y=corrupt_y,
        acc=corrupt_acc,
        carry=corrupt_carry,
    )


rs = RationaleStep(x=2, y=5, acc="6", carry=1)
make_rationale_step_corrupt(rs)

In [None]:
def next_rational_step(rs: RationaleStep) -> RationaleStep:
    """Advance the addition scratchpad by one column.

    Adds the lowest remaining digit of each operand plus the incoming carry,
    prepends the resulting units digit to the accumulator string, drops the
    consumed digit from both operands and records the outgoing carry.
    """
    column_total = rs.x % 10 + rs.y % 10 + rs.carry
    carry_out, sum_digit = divmod(column_total, 10)
    return RationaleStep(
        x=rs.x // 10,
        y=rs.y // 10,
        acc=f"{sum_digit}{rs.acc}",
        carry=carry_out,
    )


next_rational_step(RationaleStep(21, 28, "07", 1))

In [None]:
def make_rationale(q: Question, is_corrupted=False) -> Rationale:
    """Build the column-by-column addition scratchpad for `q`.

    The rationale starts with the untouched question (empty accumulator,
    carry 0). It then applies `next_rational_step` once per digit of the
    longer operand. The last line prepends the final carry to the
    accumulator, giving the full sum as a digit string (possibly with a
    leading "0", e.g. "086" for 29 + 57).

    If `is_corrupted` is true, one randomly chosen column step is replaced by
    `make_rationale_step_corrupt`. Every later step is derived from that wrong
    step, so the mistake propagates to the final answer. When the corruption
    leaves digits in x or y after the usual number of columns (e.g. a digit
    added to an operand that was already 0), extra column steps are taken
    until both operands are consumed, rather than silently discarding them.
    """
    n_steps = max(len(str(q.x)), len(str(q.y)))
    corrupt_idx = random.randint(0, n_steps - 1) if is_corrupted else -1
    rationale = [RationaleStep(x=q.x, y=q.y, acc="", carry=0)]

    for i in range(n_steps):
        rationale_step = next_rational_step(rationale[-1])
        if i == corrupt_idx:
            rationale_step = make_rationale_step_corrupt(rationale_step)
        # Each step is built from rationale[-1], so a corrupted step's error
        # carries forward into every later line.
        rationale.append(rationale_step)

    # Only entered after an x/y corruption added digits that the fixed number
    # of columns did not consume. The `> 0` test (rather than `!= 0`) keeps
    # negative operands, whose floor division never reaches 0, from looping.
    while rationale[-1].x > 0 or rationale[-1].y > 0:
        rationale.append(next_rational_step(rationale[-1]))

    last_step = rationale[-1]
    rationale.append(RationaleStep(
        x=0,
        y=0,
        acc=f"{last_step.carry}{last_step.acc}",
        carry=0,
    ))

    return rationale


q = Question(x=29, y=57)
print(q)
print(make_rationale(q, is_corrupted=False))

In [None]:
def correct_rationale(incorrect_r: Rationale) -> typing.Union[Correction, None]:
    """
    Locate the first line of `incorrect_r` that a correct scratchpad would
    not have produced.

    The first line is trusted as the question itself. Each later line is
    compared with `next_rational_step` applied to the expected previous line.

    Returns:
        A `Correction` holding the 0-based index of the first mismatching
        line, the line as written and the line as expected. Returns None when
        the rationale is empty or contains no mistake.
    """
    if not incorrect_r:
        return None

    expected_rs = incorrect_r[0]
    for line_no, actual_rs in enumerate(incorrect_r):
        if actual_rs == expected_rs:
            expected_rs = next_rational_step(expected_rs)
            continue
        return Correction(line=line_no, actual_rs=actual_rs, expected_rs=expected_rs)

    return None


q = Question(29, 57)
incorrect_r = make_rationale(q, is_corrupted=True)
print(incorrect_r)
print(correct_rationale(incorrect_r))

In [None]:
def complete_rationale(rs: RationaleStep, max_steps=1_000) -> Rationale:
    """Continue a correct scratchpad from `rs` through to its final line.

    Column steps are generated with `next_rational_step` until both operands
    are consumed. One more step is then added, which prepends the last carry
    to the accumulator. This gives the same ending `make_rationale` produces
    (e.g. ", 8 6 C: 0" followed by "0 8 6"), whether or not the final carry
    is 0. A correction rationale therefore has the same shape as the
    rationale it corrects.

    Args:
        rs: step to continue from; it is the first element of the result.
            Assumed not to already be a rationale's final line (that line
            would get one more "0" prepended).
        max_steps: maximum number of column steps to generate. Guards against
            operands that never reach 0 (e.g. negative numbers).

    Raises:
        AssertionError: if more than `max_steps` column steps would be needed.
    """
    rationale = [rs]
    steps_taken = 0

    # acc == "" marks the initial question line, which still needs at least
    # one column step even for 0 + 0.
    while not (rs.x == 0 and rs.y == 0 and rs.acc != ""):
        if steps_taken >= max_steps:
            raise AssertionError("Maximum number of rationale steps exceeded!")
        rs = next_rational_step(rs)
        rationale.append(rs)
        steps_taken += 1

    # Final line: (0, 0, f"{carry}{acc}", 0), exactly as make_rationale ends.
    rationale.append(next_rational_step(rs))
    return rationale


complete_rationale(RationaleStep(x=2, y=9, acc='8', carry=0))

In [None]:
def make_corrupted_example(q: typing.Optional[Question] = None, default_num_digits: int = 3) -> CorruptedExample:
    """Build a corrupted scratchpad together with its correction.

    Args:
        q: question to use; if None, a random one is drawn with
            `make_question(num_digits=default_num_digits)`.
        default_num_digits: operand width used only when `q` is None.

    Returns:
        A `CorruptedExample` containing:
          * the corrupted rationale and the answer it arrives at;
          * the `Correction` for its first wrong line;
          * the correct rationale continued from that line, and its answer.

    Raises:
        AssertionError: if the corrupted rationale unexpectedly contains no
            mistake. This is raised explicitly so the check survives `python -O`.
    """
    if q is None:
        q = make_question(num_digits=default_num_digits)
    incorrect_r = make_rationale(q, is_corrupted=True)
    incorrect_a = int(incorrect_r[-1].acc)

    # Computed once; the result is both the sanity check and the correction.
    correction = correct_rationale(incorrect_r)
    if correction is None:
        raise AssertionError("Mistake must be present in corrupted example")

    correction_r = complete_rationale(correction.expected_rs)
    correct_a = int(correction_r[-1].acc)

    return CorruptedExample(
        question=q,
        incorrect_rational=incorrect_r,
        incorrect_answer=incorrect_a,
        correction=correction,
        correction_rationale=correction_r,
        correction_answer=correct_a,
    )


make_corrupted_example(default_num_digits=3)

In [None]:
def make_example(num_digits, is_corrupted=False) -> Example:
    """Sample a random question and bundle it with its rationale and answer.

    The answer is read from the rationale's last accumulator. For a corrupted
    rationale it is therefore the (wrong) answer that rationale arrives at,
    not x + y.
    """
    question = make_question(num_digits)
    rationale = make_rationale(question, is_corrupted=is_corrupted)
    return Example(question=question, rationale=rationale, answer=int(rationale[-1].acc))


make_example(3, is_corrupted=True)

In [None]:
def number_to_str(x: typing.Union[int, str]) -> str:
    """Space out the characters of `x`'s string form, e.g. 123 -> "1 2 3".

    Accepts digit strings as well as ints. The scratchpad accumulator is a
    string such as "086", which is rendered as "0 8 6" with its leading
    zeros preserved. An empty string yields "".
    """
    return " ".join(str(x))


number_to_str(123)

In [None]:
def rationale_to_str(r: Rationale) -> str:
    """Render a rationale as scratchpad text, one step per line.

    Every line except the last has the form "<x> + <y> , <acc> C: <carry>".
    The "<x> + <y> " prefix is omitted once both operands are 0. The last
    line shows the accumulator digits, preceded by the same prefix if an
    operand is still non-zero. Numbers are spaced digit-by-digit via
    `number_to_str`. An empty rationale renders as "".
    """
    last_idx = len(r) - 1
    lines = []

    for i, step in enumerate(r):
        line = ""
        if step.x != 0 or step.y != 0:
            line += f"{number_to_str(step.x)} + {number_to_str(step.y)} "

        # acc is always a digit string (possibly ""), so it is rendered
        # directly; number_to_str("") is simply "".
        if i == last_idx:
            line += number_to_str(step.acc)
        else:
            line += f", {number_to_str(step.acc)} C: {step.carry}"

        lines.append(line)

    return "\n".join(lines)


q = Question(x=54, y=2)
r = make_rationale(q)
print(rationale_to_str(r))

In [None]:
def correction_to_str(c: Correction) -> str:
    if c.actual_rs.x != c.expected_rs.x:
        return f"line {c.line} : \"{c.actual_rs.x}\" should be \"{c.expected_rs.x}\""
    elif c.actual_rs.y != c.expected_rs.y:
        return f"line {c.line} : \"{c.actual_rs.y}\" should be \"{c.expected_rs.y}\""
    elif c.actual_rs.acc != c.expected_rs.acc:
        return f"line {c.line} : \"{c.actual_rs.acc}\" should be \"{c.expected_rs.acc}\""
    elif c.actual_rs.carry != c.expected_rs.carry:
        return f"line {c.line} : \"C: {c.actual_rs.carry}\" should be \"C: {c.expected_rs.carry}\""
    else:
        raise ValueError(f"Expected rational step is the same as actual! Got {c.actual_rs} for both actual and expected.")


# Demo: on line 2 the y operand was written as 0 instead of 1.
c = Correction(2, RationaleStep(1, 0, "816", 0), RationaleStep(1, 1, "816", 0))
correction_to_str(c)

In [None]:
def example_to_str(e: Example) -> str:
    """Format an example as an "Input:" / "Target:" prompt.

    The question is written as spaced digits, the rationale is wrapped in
    <scratch> tags, and the spaced-out answer comes last. Leading and
    trailing whitespace is stripped from the result.
    """
    question, rationale, answer = e
    sections = [
        "Input:",
        f"{number_to_str(question.x)} + {number_to_str(question.y)}",
        "",
        "Target:",
        "<scratch>",
        rationale_to_str(rationale),
        "</scratch>",
        number_to_str(answer),
    ]
    return textwrap.dedent("\n".join(sections)).strip()


print(example_to_str(make_example(2, is_corrupted=False)))

In [None]:
def corrupted_example_to_str(ce: CorruptedExample) -> str:
    """Format a corrupted example followed by its correction.

    The first part mirrors `example_to_str`: the question, then the incorrect
    rationale in <scratch> tags, then the incorrect answer. The second part,
    introduced by "Correction:", gives the one-line correction, the corrected
    rationale in <scratch> tags and the corrected answer.
    """
    question, incorrect_r, incorrect_a, correction, correction_r, correction_a = ce
    sections = [
        "Input:",
        f"{number_to_str(question.x)} + {number_to_str(question.y)}",
        "",
        "Target:",
        "<scratch>",
        rationale_to_str(incorrect_r),
        "</scratch>",
        number_to_str(incorrect_a),
        "",
        "Correction:",
        correction_to_str(correction),
        "<scratch>",
        rationale_to_str(correction_r),
        "</scratch>",
        number_to_str(correction_a),
    ]
    return textwrap.dedent("\n".join(sections)).strip()


print(corrupted_example_to_str(make_corrupted_example(default_num_digits=2)))

In [None]:
# Sanity check: for operand widths 1..999, the correction rationale of a
# randomly corrupted example must arrive at the true sum. The check fails
# loudly (instead of printing and breaking), so "Restart & Run All" cannot
# pass over a broken generator unnoticed.
for n in range(1, 1_000):
    ce = make_corrupted_example(default_num_digits=n)
    actual = ce.correction_answer
    expected = ce.question.x + ce.question.y
    assert actual == expected, f"Failed for {ce.question}: got {actual}, expected {expected}"

In [None]:
# Render one corrupted example for a fixed question (the corruption itself is random).
print(
    corrupted_example_to_str(
        make_corrupted_example(q=Question(x=2185, y=2822))
    )
)