In [None]:
import textwrap
import random
import collections
import typing

In [None]:
# Plain record types for the addition-with-scratchpad dataset.

# An addition problem: x + y.
Question = collections.namedtuple("Question", "x y")
# One scratchpad line: the not-yet-processed high-order part of each operand
# (x, y), the answer digits produced so far (acc; a digit string in this
# notebook) and the carry out of the last processed column.
RationaleStep = collections.namedtuple("RationaleStep", "x y acc carry")
# A full scratchpad is an ordered list of steps.
Rationale = typing.List[RationaleStep]
# A correction record; nothing in this notebook constructs one yet.
Correction = collections.namedtuple("Correction", "line pred true rationale")
# A generated example: the question, its rationale, and the integer answer.
Example = collections.namedtuple("Example", "question rationale answer")

In [None]:
def make_rationale_step_corrupt(rs: RationaleStep) -> RationaleStep:
    """Return a copy of ``rs`` with exactly one field perturbed.

    The field is chosen from a single uniform sample:

    * 10% -- ``x`` is increased by a random amount in [1, 9];
    * 10% -- ``y`` is increased by a random amount in [1, 9];
    * 30% -- ``acc`` is read as an integer, increased by a random amount in
      [1, 9] and zero-padded back to at least its original width, so leading
      zeros (e.g. the "0" in "05") are not silently dropped; an empty ``acc``
      is treated as 0 instead of raising ``ValueError``;
    * 50% -- ``carry`` is flipped (0 -> 1, anything else -> 0).

    The corrupted ``acc`` is always returned as a string.
    """
    corrupt_x, corrupt_y, corrupt_acc, corrupt_carry = rs
    sample = random.uniform(0, 1)

    if sample <= 0.1:
        corrupt_x = corrupt_x + random.randint(1, 9)  # corrupt x
    elif sample <= 0.2:
        corrupt_y = corrupt_y + random.randint(1, 9)  # corrupt y
    elif sample <= 0.5:
        # int("05") == 5 would shrink the accumulator by one digit; zfill
        # restores the original width so the scratchpad stays aligned.
        acc_digits = str(corrupt_acc)
        width = len(acc_digits)
        corrupt_acc = str(int(acc_digits or "0") + random.randint(1, 9)).zfill(width)
    else:
        corrupt_carry = 1 if corrupt_carry == 0 else 0  # corrupt carry

    return RationaleStep(
        x=corrupt_x,
        y=corrupt_y,
        acc=corrupt_acc,
        carry=corrupt_carry,
    )


rs = RationaleStep(x=2, y=5, acc="6", carry=1)
make_rationale_step_corrupt(rs)

In [None]:
def make_question(num_digits: int) -> Question:
    """Sample an addition problem with operands of at most ``num_digits`` digits.

    Both operands are drawn independently and uniformly from
    ``[0, 99...9]`` (``num_digits`` nines), so shorter numbers, including 0,
    can occur. ``x`` is drawn first, then ``y``.
    """
    assert num_digits > 0, "need at least one digit to do arithmetic"
    upper = int("9" * num_digits)  # largest num_digits-digit number
    operands = [random.randint(0, upper) for _ in range(2)]
    return Question(*operands)


make_question(3)

In [None]:
def make_rationale(q: Question, is_corrupted=False) -> Rationale:
    """Build a digit-by-digit scratchpad for ``q.x + q.y``.

    The first step holds the untouched operands with an empty accumulator.
    Each following step consumes the lowest digit of both operands together
    with the incoming carry, prepends the resulting digit to ``acc`` and
    records the outgoing carry. A final step prepends the last carry (even
    when it is 0), so for non-negative operands and no corruption
    ``int(rationale[-1].acc) == q.x + q.y``.

    Args:
        q: the question to solve.
        is_corrupted: if True, one digit step (chosen uniformly at random) is
            passed through ``make_rationale_step_corrupt``; its corrupted
            values feed every later step, so the mistake propagates.

    Returns:
        A list of ``RationaleStep``: the initial step, one step per digit of
        the longer operand, and the final step.
    """
    x = q.x
    y = q.y
    acc = ""
    carry = 0
    rationale = [RationaleStep(x=x, y=y, acc=acc, carry=carry)]
    n_steps = max(len(str(q.x)), len(str(q.y)))
    corrupt_idx = random.randint(0, n_steps - 1) if is_corrupted else -1

    for i in range(n_steps):
        # The incoming carry must take part in both the emitted digit and the
        # outgoing carry; otherwise 9 + 0 + carry 1 emits "10" and drops the carry.
        column_sum = (x % 10) + (y % 10) + carry
        acc = f"{column_sum % 10}{acc}"
        carry = column_sum // 10
        x = x // 10
        y = y // 10
        rationale_step = RationaleStep(x=x, y=y, acc=acc, carry=carry)

        if i == corrupt_idx:
            rationale_step = make_rationale_step_corrupt(rationale_step)
            x, y, acc, carry = rationale_step  # accumulate mistakes

        rationale.append(rationale_step)

    rationale.append(RationaleStep(
        x=0,
        y=0,
        acc=f"{carry}{acc}",
        carry=0,
    ))

    return rationale


q = Question(x=29, y=57)
print(q)
print(make_rationale(q, is_corrupted=True))

# Regression check: a carry that ripples through several columns.
assert int(make_rationale(Question(x=199, y=801))[-1].acc) == 1000

In [None]:
def make_correction(incorrect_rs, correct_rs) -> Correction:
    """Placeholder for building a ``Correction`` from two rationales.

    Not implemented yet: it ignores both arguments and returns ``None``.
    TODO: decide how the ``line``/``pred``/``true``/``rationale`` fields are
    derived (presumably from the first step where the rationales disagree).
    """
    return None

In [None]:
def make_example(num_digits, is_corrupted=False) -> Example:
    """Generate one random ``Example``.

    Samples a question with ``make_question``, builds its rationale with
    ``make_rationale`` (optionally corrupted), and reads the integer answer
    from the accumulator of the rationale's last step.
    """
    question = make_question(num_digits)
    steps = make_rationale(question, is_corrupted=is_corrupted)
    final_step = steps[-1]
    return Example(question=question, rationale=steps, answer=int(final_step.acc))


make_example(3, is_corrupted=True)

In [None]:
def number_to_str(x: typing.Union[int, str]) -> str:
    """Render the characters of ``x`` separated by single spaces.

    Accepts ints (question operands) as well as digit strings (a rationale's
    ``acc``, which may keep leading zeros such as "086"); both go through
    ``str(x)``, so an empty string renders as "".

    >>> number_to_str(123)
    '1 2 3'
    """
    return " ".join(str(x))


number_to_str(123)

In [None]:
def rationale_to_str(r: Rationale) -> str:
    """Render a rationale as scratchpad text, one line per step.

    Every step except the last renders as ``"<x> + <y> , <acc> C: <carry>"``.
    The ``"<x> + <y> "`` part is omitted once both remaining operands are 0,
    except on the first step, which always shows the full question (so a
    ``0 + 0`` question is not lost). The last step renders only its
    accumulator, i.e. the answer digits. Digits are space-separated via
    ``number_to_str``; lines are joined with "\\n" (no trailing newline).
    """
    last_index = len(r) - 1
    lines = []

    for i, step in enumerate(r):
        line = ""
        if i == 0 or step.x != 0 or step.y != 0:
            line += f"{number_to_str(step.x)} + {number_to_str(step.y)} "

        if i == last_index:
            line += number_to_str(step.acc)
        else:
            line += f", {number_to_str(step.acc)} C: {step.carry}"

        lines.append(line)

    return "\n".join(lines)


q = Question(x=54, y=2)
r = make_rationale(q)
print(rationale_to_str(r))

# A 0 + 0 question must still show its operands on the first scratchpad line.
assert rationale_to_str(make_rationale(Question(x=0, y=0))).startswith("0 + 0 ")

In [None]:
def correction_to_str(c: Correction) -> str:
    """Placeholder renderer for a ``Correction``.

    Not implemented yet: it ignores ``c`` and always returns an empty string.
    """
    rendered = ""
    return rendered

In [None]:
def example_to_str(e: Example) -> str:
    """Format an example as an "Input / Target / Correction" text block.

    The Input section holds the space-separated question. The Target section
    wraps the rendered rationale in ``<scratch>`` tags and follows it with
    the answer digits. The Correction section is emitted with empty
    placeholders (its three strings are always "" for now). The assembled
    text is passed through ``textwrap.dedent`` and stripped of surrounding
    whitespace.
    """
    question, steps, final_answer = e
    operands_text = f"{number_to_str(question.x)} + {number_to_str(question.y)}"
    scratch_text = rationale_to_str(steps)
    answer_text = number_to_str(final_answer)

    correction_text = correction_steps_text = correction_answer_text = ""

    sections = [
        "Input:",
        operands_text,
        "",
        "Target:",
        "<scratch>",
        scratch_text,
        "</scratch>",
        answer_text,
        "",
        "Correction:",
        correction_text,
        "<scratch>",
        correction_steps_text,
        "</scratch>",
        correction_answer_text,
    ]
    return textwrap.dedent("\n".join(sections)).strip()


print(example_to_str(make_example(2, is_corrupted=False)))