# Teach an LLM to do additions

## Name: Oliver Jack, Mail: olijacklu@gmail.com

## TLDR: Used aligning numbers & reversing for tokenizer (123+456 -> [(3,6), (2,5), (1,4)]) and Abacus approach for positional embedding

The goal of this project is to teach an LLM to do additions, playing only with two parts:
* the tokenizer
* the positional embedding

Both the model and the dataset are fixed.

You are allowed to tune the hyperparameters, but this is not the main goal. Depending on the quality of your tokenizer and positional embedding, you may change the number of bits. The initial value of 3 is very small.

In [31]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

import itertools

In [32]:
number_bits = 12

dataset_size = 64_000
train_proportion = 0.9

log_interval = 200
batch_size = 64
epochs = 4
learning_rate = 8e-4

## Step 1: Construct a tokenizer

In [33]:
pad_token="[PAD]"
eos_token="[EOS]"

### Baseline: character-level tokenizer

In [34]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [35]:
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens
ntokens

14

In [36]:
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

# Implement your tokenizer here!

You can do anything (as long as you do not compute the addition!).
Some ideas:
* reversing numbers left to right
* arranging by groups (of, 2, 3,...)
* aligning numbers

## My tokenizer

In [37]:
class MyTokenizer:
   def __init__(self):
       # Initialize vocabulary with digits 0-9
       self.vocab = [str(x) for x in range(10)]
       # Extend vocabulary with all digit pairs, plus special tokens
       self.vocab += [str(item) for item in list(itertools.product(self.vocab, self.vocab))] + ["+", "="] + [pad_token, eos_token]

       self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
       self.id_to_token = {k : v for k, v in enumerate(self.vocab)}

       # Regex patterns for tokenization
       self.pattern = re.compile(r'(\d+|\[PAD\]|\[EOS\]|\D)')
       self.num_pattern = re.compile(r'\d+')
       self.ntokens = len(self.vocab)

   def encode(self, text):
       # Remove spaces and split text into components
       text = text.replace(" ", "")
       text_list = re.findall(self.pattern, text)

       new_list = []
       # Handle equations of the form (num1+num2)
       if len(text_list) > 2 and self.num_pattern.match(text_list[0]) and self.num_pattern.match(text_list[2]):
           # Extract and pad numbers to equal length, then reverse
           num1 = text_list[0].zfill(max(len(text_list[0]), len(text_list[2])))[::-1]
           num2 = text_list[2].zfill(max(len(text_list[0]), len(text_list[2])))[::-1]

           # Create digit pairs from aligned numbers
           pairs = [str((num1[i], num2[i])) for i in range(len(num1))]
           new_list.extend(pairs)

           # Process any remaining numbers by reversing and splitting into digits
           for i in range(3, len(text_list)):
               if self.num_pattern.match(text_list[i]):
                   new_list.extend(text_list[i][::-1])
               else:
                   new_list.append(text_list[i])
       else:
           # Default case: process each element individually
           for elem in text_list:
               if self.num_pattern.match(elem):
                   # Reverse numbers and split into individual digits
                   new_list.extend(elem[::-1])
               else:
                   new_list.append(elem)

       # Convert tokens to IDs
       return [self.token_to_id[c] for c in new_list]

   def decode(self, token_list):
       # Convert IDs back to tokens
       tokens = [self.id_to_token[x] for x in token_list if x in self.id_to_token]
       
       # Case 1: Handle simple reversed digits
       if all(t.isdigit() for t in tokens):
           return ''.join(tokens[::-1])
       
       # Case 2: Handle equation with digit pairs
       tuples = [t for t in tokens if t.startswith('(') and t.endswith(')')]
       if tuples:
           # Extract and process digit pairs
           pairs = []
           for t in tuples:
               clean = t.strip('()').replace("'", "").replace('"', "")
               parts = clean.split(',')
               if len(parts) == 2 and parts[0].strip().isdigit() and parts[1].strip().isdigit():
                   pairs.append((parts[0].strip(), parts[1].strip()))
           
           # Reconstruct numbers from pairs and build equation
           if pairs:
               num1 = ''.join(d[0] for d in pairs)[::-1].lstrip('0') or '0'
               num2 = ''.join(d[1] for d in pairs)[::-1].lstrip('0') or '0'
               
               # Add result if equation has equals sign
               if '=' in tokens:
                   eq_idx = tokens.index('=')
                   result = ''.join(t for t in tokens[eq_idx+1:] if t.isdigit())
                   return f"{num1}+{num2}={result[::-1]}"
               return f"{num1}+{num2}"
       
       # Final fallback, join all tokens
       return ''.join(tokens)

In [38]:
tokenizer = MyTokenizer()
ntokens = tokenizer.ntokens
ntokens

114

## Step 2: Create a dataset for arithmetic operations

In [39]:
def sample_datapoint(number_bits = 3):
    """
    returns a string containing two random numbers on `number_bits` many bits and their sum.
    """
    a_list = [random.randint(0, 9) for _ in range(number_bits)]
    b_list = [random.randint(0, 9) for _ in range(number_bits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    sum_int = a_int + b_int
    return (str(a_int) + "+" + str(b_int) + "=", str(sum_int))

sample_datapoint(3)

('983+442=', '1425')

In [40]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(number_bits))
data[:4]

[('124441398717+821229841729=', '945671240446'),
 ('937299185907+993516402866=', '1930815588773'),
 ('946922115189+566800611371=', '1513722726560'),
 ('487533766643+194442577408=', '681976344051')]

In [41]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

### Basline: the classical Positional Embedding

In [42]:
class PositionalEmbedding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEmbedder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEmbedder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Implement your positional embedding here!

You can do anything. Some ideas:
* RoPE
* (randomised) FIRE
* Abacus

**!!! IMPORTANT !!!** This model of Transformers is "input first", meaning that an input is a tensor with shape
(length_prompts, batch_size)

In [43]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(device)

mps


## My positional embedding

In [44]:
class AbacusPositionalEmbedding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(AbacusPositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create empty positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        # Simple sinusoidal encoding focused on digit positions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Prepare for addition to input embeddings
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add positional encoding to input embedding
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [45]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = AbacusPositionalEmbedding(ninp)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

Please do not change these parameters!

In [46]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=114, bias=True)
  (input_emb): Embedding(114, 128)
  (pos_encoder): AbacusPositionalEmbedding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [47]:
def generate(model, prompts, new_tokens = 5):
    input_tensor = prompts # (length_prompts, batch_size)
    input_tensor = input_tensor.to(device)
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (length_prompts, batch_size, ntokens)
        last_output = output[-1,:,:] # (batch_size, ntokens)
        token = torch.argmax(last_output, -1).view((1,-1)) # (1, batch_size)
        input_tensor = torch.cat((input_tensor, token), 0)
    return input_tensor

In [48]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output.tolist()[0])

(tensor([[ 33, 111,  68,   0,   0,   0,   0]], device='mps:0'), '52+83=0000')

In [49]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x)))
    return out, max_length

In [50]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['1+1=', '21+35='], ['2[EOS][PAD]', '65[EOS]'])

In [51]:
def get_batch(split, i):
    data = data_train if split == 'train' else data_test
    prompts = [tokenizer.encode(data[i][0]) for i in range(i, i + batch_size)]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    answers = [tokenizer.encode(data[i][1]) for i in range(i, i + batch_size)]
    padded_answers, length_answers = pad(answers, "answers")
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, length_prompts, length_answers

In [52]:
X, Y, length_prompts, length_answers = get_batch("train", 243)
X.shape, Y.shape, length_prompts, length_answers

(torch.Size([13, 64]), torch.Size([14, 64]), 13, 13)

## Step 4: Evaluate

In [53]:
def evaluate():
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, length_prompts, length_answers = get_batch("test", i)
            prompts = prompts.to(device) # (length_prompts, batch_size)
            target_answers = target_answers.to(device) # (length_answers + 1, batch_size)
            output = generate(model, prompts, length_answers + 1) # (length_prompts + length_answers + 1, batch_size)
            answers_tokens = output[length_prompts:, :] # (length_answers + 1, batch_size), contains tokens
            equality_test = answers_tokens == target_answers # (length_answers + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [54]:
evaluate()

0.0

## Step 4: Train the model

In [55]:
def train_epoch():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    total_loss = 0.
    start_time = time.time()
    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", i)
        prompts = prompts.to(device) # (length_prompts, batch_size)
        target_answers = target_answers.to(device) # (length_answers, batch_size)
        input_tensor = torch.cat((prompts, target_answers), 0) # (length_prompts + length_answers, batch_size)
        model.zero_grad()
        output, _ = model(input_tensor) # (length_prompts + length_answers, batch_size, ntokens)
        output_answers = output[length_prompts-1:-1,:,:].reshape(-1, ntokens) # (length_answers * batch_size, ntokens)
        target_answers = target_answers.view(-1)
        loss = F.cross_entropy(output_answers, target_answers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                        elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def train():
    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train_epoch()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [56]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 57.34 | loss  2.17 | perplexity     8.80
|   400/  900 batches | ms/batch 58.51 | loss  1.22 | perplexity     3.39
|   600/  900 batches | ms/batch 58.54 | loss  0.17 | perplexity     1.19
|   800/  900 batches | ms/batch 58.09 | loss  0.05 | perplexity     1.05
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 77.16s | test accuracy  0.99
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 60.82 | loss  0.03 | perplexity     1.03
|   400/  900 batches | ms/batch 56.88 | loss  0.03 | perplexity     1.03
|   600/  900 batches | ms/batch 60.48 | loss  0.02 | perplexity     1.02
|   800/  900 batches | ms/

In [57]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers)).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

996106597082+445195073285=1441301670367	 actual result: 1441301670367
172902337353+543860561905=716762899258	 actual result: 716762899258
876030450100+234838827760=1110869277860	 actual result: 1110869277860
153046479547+90632912264=243679391811	 actual result: 243679391811
114750700752+976000521104=1090751221856	 actual result: 1090751221856
521931289912+792045587590=1313976877502	 actual result: 1313976877502
920879056650+707814711732=1628693768382	 actual result: 1628693768382
573709497379+250513637183=824223134562	 actual result: 824223134562
673054406245+325047874160=998102280405	 actual result: 998102280405
49797655384+378807193437=428604848821	 actual result: 428604848821
373562768775+382237869698=755800638473	 actual result: 755800638473
387972753517+961995164473=1349967917990	 actual result: 1349967917990
328316127883+765309242356=1093625370239	 actual result: 1093625370239
84728028247+159508358630=244236386877	 actual result: 244236386877
779760423185+89951121052=869711544237

In [61]:
data_test[:20]

[('996106597082+445195073285=', '1441301670367'),
 ('172902337353+543860561905=', '716762899258'),
 ('876030450100+234838827760=', '1110869277860'),
 ('153046479547+90632912264=', '243679391811'),
 ('114750700752+976000521104=', '1090751221856'),
 ('521931289912+792045587590=', '1313976877502'),
 ('920879056650+707814711732=', '1628693768382'),
 ('573709497379+250513637183=', '824223134562'),
 ('673054406245+325047874160=', '998102280405'),
 ('49797655384+378807193437=', '428604848821'),
 ('373562768775+382237869698=', '755800638473'),
 ('387972753517+961995164473=', '1349967917990'),
 ('328316127883+765309242356=', '1093625370239'),
 ('84728028247+159508358630=', '244236386877'),
 ('779760423185+89951121052=', '869711544237'),
 ('387385513906+545138404678=', '932523918584'),
 ('347600583277+355724944683=', '703325527960'),
 ('820129599080+474057074944=', '1294186674024'),
 ('816261259727+296927704937=', '1113188964664'),
 ('743307171048+505171638385=', '1248478809433')]

## Probing

This is just for fun...

In [59]:
import numpy as np

train_size = 1000
test_size = 100

model.eval()

def data_probing(size):
    X = []
    y = np.zeros(size)
    for i in range(size):
        input = torch.tensor(tokenizer.encode(data[i][0])).view((-1, 1)).to(device)
        _, output = model(input)
        output = output[-1,:,:].flatten()
        # determine whether there was a carry in the result:
        carry = len(data[i][1]) > len(data[i][0]) / 2
        X.append(output.cpu().detach().numpy())
        y[i] = carry
    return np.array(X), y

In [60]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

X_train, y_train = data_probing(train_size)
X_test, y_test = data_probing(test_size)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.fit_transform(X_test)

reg = LogisticRegression()
reg.fit(X_train,y_train)
reg.score(X_test, y_test)

0.99