# Chain-of-Thought training data

In [1]:
import random

import torch
import matplotlib.pyplot as plt

from arithmetic_lm.model import TransformerDecoder, generate
from arithmetic_lm.tokenizer import CharTokenizer
from arithmetic_lm.formatting import format_line, split_operands_and_op

import warnings

warnings.filterwarnings("ignore")

In [40]:
def chain_of_thought_addition(a: str, b: str) -> str:
    """
    Input: 567+7890
    CoT: 7+0=7c0,6+9=5c1,5+8=3c1,0+7=8c0|567+7890=8457
    """
    res = ""

    length = max(len(a), len(b))
    a = a.zfill(length)
    b = b.zfill(length)

    # start from last digit
    for da, db in zip(reversed(a), reversed(b)):
        da = int(da)
        db = int(db)
        msum = (da + db) % 10
        carry = (da + db) // 10
        res += f"{da}+{db}={msum}c{carry},"
    res = res[:-1]  # remove last comma
    res += f"|{a}+{b}={int(a)+int(b)}"
    return res

In [41]:
example = "567+7890="

a, op, b = split_operands_and_op(example)
print(a, b, op)

cot = chain_of_thought_addition(a, b)
print("CoT:", cot)

567 7890 +
CoT: 7+0=7c0,6+9=5c1,5+8=3c1,0+7=7c0|0567+7890=8457


7+0
6+9
5+8
