# Performing Additions on Long numbers using LongFormer

In [3]:
# Imports go here
import torch
import torch.nn as nn
from long_attention import SliddingWindowAttention
import random
import math
import re
import time
import numpy as np

# Tokenizer


<div class="alert alert-info">

### **Tokenizer: Organizing the addition:**

The tokenizer implemented here rearranges the digits of a sum so that digits of the same significance sit next to each other, ordered from least to most significant (i.e. each number is read from right to left). Operands are first zero-padded to a common fixed width; ignoring that padding, the sum `13 + 54` is encoded as:
$$ [3, 4, 1, 5] $$

This encoding method aligns well with the Self-Attention Mechanism used in Transformer models. Since the attention score of a token $x[t]$ at position $t$ is computed with respect to all previous tokens (positions $< t$), we want to ensure that the model attends to digits of lower significance before higher ones. This simulates the way humans naturally perform addition.
</div>

In [7]:
# Tokenizer
eos_token = '[EOS]'


class Tokenizer:
    """
    Digit-level tokenizer for decimal addition prompts such as ``"12+34=46"``.

    The vocabulary is the ten decimal digits, the end-of-sequence token and
    ``=``. Every number is left-padded with zeros to ``number_bits + 1``
    digits. The digits of the operands (the numbers before the first ``=``)
    are interleaved by significance, least significant first, so that
    ``"999+900="`` with ``number_bits=4`` becomes
    ``9 0 9 0 9 9 0 0 0 0 =``.
    """

    def __init__(self, number_bits):
        """
        Args:
            number_bits: numbers are zero-padded to ``number_bits + 1`` digits.
        """
        # Separators are captured so that '=' survives the split; everything
        # else ('+', ',', whitespace, '[EOS]') is dropped by `encode`.
        self.delimiters = r'(\[EOS\]|[,\+\=\s])'
        self.vocab = [str(x) for x in range(10)] + [eos_token] + ["="]  # No need for pad token
        self.token_to_id = {v: k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k: v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        # Not used inside this class; kept because it is a public attribute.
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
        self.number_bits = number_bits

    def encode(self, text):
        """
        Convert a prompt (or a bare answer) into a list of token ids.

        Args:
            text: string containing numbers and optionally ``=``. Other
                separators and unrecognised pieces are ignored.

        Returns:
            List of token ids. With a single number (or a lone ``=``) the
            zero-padded digits are returned in reading order. Otherwise the
            operands are interleaved least-significant digit first, followed
            by ``=`` and the zero-padded digits of anything after it.

        Raises:
            ValueError: if the text contains neither a number nor ``=``, or
                if an operand has more than ``number_bits + 1`` digits.
        """
        pieces = re.split(self.delimiters, text)

        # Keep only numbers and the '=' symbol.
        tokens = [piece for piece in pieces if piece.isdigit() or piece == '=']
        if not tokens:
            raise ValueError("Invalid prompt, please use at least one number or the sign =")

        # Everything before the first '=' is an operand, the rest is the answer.
        idx_equal = tokens.index('=') if '=' in tokens else len(tokens)

        # Left-pad every number with zeros up to the fixed width.
        width = self.number_bits + 1
        tokens = [tok.zfill(width) if tok.isdigit() else tok for tok in tokens]

        # A single token (number or '=') is encoded directly, without reordering.
        if len(tokens) == 1:
            return [self.token_to_id[c] for c in tokens[0]]

        operands = tokens[:idx_equal]

        # The interleaving below reads exactly `width` digits per operand, so a
        # longer operand would silently lose digits. Fail loudly instead.
        for operand in operands:
            if len(operand) > width:
                raise ValueError(
                    f"Operand {operand} has more than {width} digits "
                    f"(number_bits={self.number_bits})"
                )

        if len(operands) == 1:
            # A lone operand is kept in reading order (not reversed); `decode`
            # treats such short sequences as plain answers.
            arranged_digits = list(operands[0])
        else:
            # Units of every operand first, then tens, then hundreds, ...
            # Lower-significance digits therefore appear earlier in the
            # sequence, where tokens at later positions can attend to them.
            arranged_digits = [
                operand[~i] for i in range(width) for operand in operands
            ]

        # Append '=' and the answer digits (if any) in reading order.
        for token in tokens[idx_equal:]:
            arranged_digits += list(token)

        return [self.token_to_id[c] for c in arranged_digits]

    def merge_digits(self, l):
        """
        Collapse runs of digit tokens into numbers without leading zeros.

        Args:
            l: list of token strings (single digits, ``=``, ``[EOS]``).

        Returns:
            List of strings where each run of digits is replaced by its
            integer value as a string; non-digit tokens are kept as they are.
        """
        result = []
        num = ""
        for char in l:
            if char.isdigit():
                num += char
                continue
            if num:
                # int() strips the zeros that were added as padding.
                result.append(str(int(num)))
                num = ""
            result.append(char)

        if num:  # trailing number with no token after it
            result.append(str(int(num)))

        return result

    def decode(self, token_list):
        """
        Convert a list of token ids back into text.

        Sequences of at most ``number_bits + 2`` tokens, or sequences ending
        with the end-of-sequence token, are treated as answers: digit runs
        are merged and joined. Anything else is treated as a query: the
        tokens before ``=`` are de-interleaved into operands joined by ``+``,
        then the merged tokens from ``=`` onwards are appended.

        Args:
            token_list: list of token ids.

        Returns:
            The decoded string, with padding zeros removed.
        """
        tokens = [self.id_to_token[j] for j in token_list]

        if len(tokens) <= self.number_bits + 2 or tokens[-1] == eos_token:
            # Answer: nothing to de-interleave.
            return ''.join(self.merge_digits(tokens))

        # Query: the operands are the tokens before the equal sign.
        idx_equal = tokens.index('=') if '=' in tokens else len(tokens)
        numbers_before = tokens[:idx_equal]

        # Number of operands that were interleaved.
        k = len(numbers_before) // (self.number_bits + 1)

        numbers = []
        for i in range(k):
            # Every k-th token belongs to operand i, least significant first.
            digits = reversed(numbers_before[i::k])
            numbers.append(str(int(''.join(digits))))  # int() removes padding zeros

        text = '+'.join(numbers)

        # Append what follows the operands: '=' and the solution, if present.
        return text + ''.join(self.merge_digits(tokens[idx_equal:]))

In [9]:
# Round-trip check: encode a noisy prompt (repeated '+', extra spaces) and
# decode it back. `tokenizer` is the instance reused by the cells below.
tokenizer = Tokenizer(4)
prompt = "999 + +  900 = 2000"
inputs = tokenizer.encode(prompt)
print(inputs)
tokenizer.decode(inputs)

[9, 0, 9, 0, 9, 9, 0, 0, 0, 0, 11, 0, 2, 0, 0, 0]


'999+900=2000'

# Positional encoding

In [10]:
# Because we group each pair of digits of the same significance next to each other, 
# then a good positional encoding could be to give each pair the same positional encoding!

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding shared by groups of consecutive tokens.

    Consecutive tokens are grouped `group_size` at a time and every token of
    a group receives the same encoding, i.e. the position indices are
    ``[0, 0, 1, 1, 2, 2, ...]`` for the default ``group_size=2``.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied after adding the encoding.
        repeats: number of distinct positions in the table; the table holds
            ``group_size * repeats`` rows.
        number_bits: stored on the module; it does not influence the table.
        group_size: how many consecutive tokens share one position.
    """

    def __init__(self, d_model, dropout=0.1, repeats = 50, number_bits = 3, group_size = 2):
        # Prompts are expected to be clean (no doubled '+' etc.) so that
        # tokens really come in aligned groups.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.number_bits = number_bits

        # Each position index is repeated `group_size` times; shape (rows, 1).
        position = (
            torch.arange(repeats, dtype=torch.float)
            .repeat_interleave(group_size)
            .unsqueeze(1)
        )

        # Standard sine/cosine table.
        pe = torch.zeros(position.size(0), d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so the surplus frequency is dropped.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        pe = pe.unsqueeze(1)  # shape (group_size * repeats, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Add the positional encoding to `x` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        The table only holds ``group_size * repeats`` rows, so the sequence
        length must not exceed that.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Dataset

In [None]:
from addtions import *
dataset_size = 64000
# NOTE(review): the original cell left `number_bits` without a value, which is
# a syntax error. Reusing the tokenizer's setting keeps the generated numbers
# consistent with the width the tokenizer pads to — confirm this is the
# intended value.
number_bits = tokenizer.number_bits
# Draw `dataset_size` samples from the project helper `sample_datapoint`.
data = [sample_datapoint(number_bits) for _ in range(dataset_size)]
data[:4]

In [None]:
# Split the samples: the first `train_proportion` share is used for training,
# the remainder for testing.
# NOTE(review): `train_proportion` is not defined in the visible cells —
# presumably it comes from the star import above; confirm.
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

# Model: Longformer

Still need the final implementation

In [None]:
# NOTE(review): `LongFormer` and `device` are not defined in the visible
# cells (the model implementation is still pending) — confirm where they
# come from before running this cell.
model = LongFormer()
model.to(device)

In [None]:
model.eval() # disable dropout!

# Encode one prompt as a column tensor of shape (sequence length, 1), let the
# model continue it, and show both the raw token ids and the decoded text.
# NOTE(review): `generate` is not defined in the visible cells — presumably
# it returns the prompt followed by the generated tokens; confirm.
prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output.tolist()[0])

# Preprocessing steps

In [None]:
def pad(token_list, type_list = "prompts"):
    """Prepare a batch of token-id lists for stacking.

    Despite its name this function adds no padding tokens: prompts are
    returned unchanged and every answer gets the end-of-sequence id appended.

    Args:
        token_list: list of token-id lists.
        type_list: ``"prompts"`` or ``"answers"``.

    Returns:
        ``(out, max_length)`` where ``out`` is the processed batch and
        ``max_length`` is the longest input length, measured before the
        end-of-sequence id is appended (0 for an empty batch).

    Raises:
        ValueError: if `type_list` is neither ``"prompts"`` nor ``"answers"``.
    """
    if type_list not in ("prompts", "answers"):
        # Previously an unknown kind silently produced an empty batch.
        raise ValueError(
            f"Unknown type_list {type_list!r}; expected 'prompts' or 'answers'"
        )

    max_length = max((len(x) for x in token_list), default=0)

    if type_list == "prompts":
        out = list(token_list)
    else:
        eos_id = tokenizer.token_to_id[eos_token]
        out = [x + [eos_id] for x in token_list]
    return out, max_length

In [None]:
# Sanity check of encode + pad on two hand-written examples: decode the
# processed prompts and answers back to text.
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

In [None]:
def get_batch(split, i):
    """Build one batch of encoded prompts and answers.

    Args:
        split: ``'train'`` selects `data_train`; anything else selects
            `data_test`.
        i: index of the first sample of the batch.

    Returns:
        ``(X, Y, length_prompts, length_answers)`` where ``X`` and ``Y`` are
        tensors of shape (sequence length, batch size) and the two lengths are
        the values reported by `pad` (the answer length excludes the appended
        end-of-sequence token).
    """
    data = data_train if split == 'train' else data_test
    # Slicing (rather than indexing each position) lets the final batch be
    # shorter than `batch_size` instead of raising IndexError.
    samples = data[i:i + batch_size]

    prompts = [tokenizer.encode(sample[0]) for sample in samples]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    answers = [tokenizer.encode(sample[1]) for sample in samples]
    padded_answers, length_answers = pad(answers, "answers")

    # Stack along dim 1 so that each column is one sample.
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, length_prompts, length_answers

In [None]:
# Fetch one training batch starting at sample 243 and inspect its shapes.
X, Y, length_prompts, length_answers = get_batch("train", 243)
X.shape, Y.shape, length_prompts, length_answers

In [None]:
# Token ids of the first prompt in the batch (samples are stored as columns).
X[:, 0]

In [None]:
def evaluate(model):
    """Return the exact-match accuracy of `model` on `data_test`.

    A sample counts as correct only when every generated answer token,
    including the final end-of-sequence token, equals the target.

    Args:
        model: the model passed to `generate`.

    Returns:
        Fraction of correct test samples, as a Python float.
    """
    # Evaluation mode disables dropout.
    model.eval()
    # A plain integer counter, so the function also returns cleanly when the
    # loop body never runs (previously `.item()` was called on a float).
    correct = 0
    with torch.no_grad():
        for i in range(0, len(data_test) - 1, batch_size):
            prompts, target_answers, length_prompts, length_answers = get_batch("test", i)
            prompts = prompts.to(device) # (length_prompts, batch_size)
            target_answers = target_answers.to(device) # (length_answers + 1, batch_size)
            output = generate(model, prompts, length_answers + 1) # (length_prompts + length_answers + 1, batch_size)
            answers_tokens = output[length_prompts:, :] # (length_answers + 1, batch_size), contains tokens
            equality_test = answers_tokens == target_answers # (length_answers + 1, batch_size), contains boolean values
            # A column is correct when all of its tokens match.
            correct += int(torch.all(equality_test, dim=0).sum().item())
    return correct / len(data_test)

In [None]:
# Accuracy of the current model on the test split, before any training below.
evaluate(model)

# Training

In [None]:
def train_epoch(model):
    """Run one pass over `data_train`, updating `model` with AdamW.

    A new optimizer is created on every call. Every `log_interval` batches
    the mean loss, its exponential (perplexity) and the time per batch are
    printed.

    Args:
        model: called as ``model(input_tensor)``; the first element of its
            return value is used as the per-position logits.
    """
    # NOTE(review): `F` is expected to be torch.nn.functional; it is not
    # imported in the visible cells — confirm it comes from the star import.
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    running_loss = 0.
    window_start = time.time()

    for step, offset in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", offset)
        prompts = prompts.to(device)                # (length_prompts, batch_size)
        target_answers = target_answers.to(device)  # (length_answers + 1, batch_size), EOS included
        # Teacher forcing: the model sees the prompt followed by the true answer.
        input_tensor = torch.cat((prompts, target_answers), 0)
        model.zero_grad()

        # Logits for every position of the concatenated sequence.
        output, _ = model(input_tensor)
        # The logits at position t predict token t + 1, so the answer tokens
        # are predicted from the last prompt position up to the one before
        # the final token.
        shifted = output[length_prompts-1:-1,:,:]
        logits = shifted.reshape(-1, ntokens)
        # Flatten the targets in the same (position, sample) order.
        flat_targets = target_answers.view(-1)
        # Mean cross-entropy over all answer tokens of the batch.
        loss = F.cross_entropy(logits, flat_targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if step % log_interval != 0 or step == 0:
            continue

        cur_loss = running_loss / log_interval
        elapsed = time.time() - window_start
        print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(
            step, len(data_train) // batch_size,
            elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
        # Start a fresh logging window.
        running_loss = 0
        window_start = time.time()

def train(model, epochs):
    """Train `model` for `epochs` epochs and checkpoint the best one.

    The test accuracy is printed before training and after every epoch.
    Whenever an epoch reaches a strictly higher test accuracy than any
    previous epoch (the first epoch always qualifies), the whole model is
    saved to ``arithmetic.pt``.

    Args:
        model: the model to train in place.
        epochs: number of epochs to run.
    """
    best_test_accuracy = None
    test_accuracy = evaluate(model)
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train_epoch(model)
        test_accuracy = evaluate(model)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        # Higher accuracy is better, and `is None` (rather than a falsy test)
        # keeps a legitimate best accuracy of 0.0 from being treated as unset.
        if best_test_accuracy is None or test_accuracy > best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [None]:
# Build a freshly initialised model (no checkpoint is loaded in this cell),
# move it to `device` and train it.
# NOTE(review): `TransformerModel`, `ntokens` and `device` are not defined in
# the visible cells — confirm where they come from.
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)
# `learning_rate` is read as a global by `train_epoch`.
learning_rate = 8e-4
epochs = 30
train(model, epochs)

# Evaluation

In [None]:
def show_examples(model, data_test, n_examples=20):
    """Print the model's completion next to the expected answer.

    Args:
        model: the model passed to `generate`.
        data_test: sequence of ``(prompt, answer)`` string pairs.
        n_examples: maximum number of samples to show, taken from the start
            of `data_test`; fewer are shown if the dataset is shorter.
    """
    model.eval()
    # Put the prompt on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        for prompt, answers in data_test[:n_examples]:
            prompt_tensor = torch.tensor(
                tokenizer.encode(prompt), device=model_device
            ).view((-1, 1))
            # Generate as many tokens as the encoded (zero-padded) answer
            # plus the end-of-sequence token, the same length `evaluate` uses.
            n_new_tokens = len(tokenizer.encode(answers)) + 1
            output = generate(model, prompt_tensor, n_new_tokens).view((1, -1))
            print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

# Print model completions for the first test samples.
show_examples(model, data_test)

----