# Performing Additions on Long numbers using LongFormer

In [103]:
# Imports go here
import torch
import torch.nn as nn
from long_attention import *
import random
import math
import re
import time
import numpy as np

In [104]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


# Tokenizer


<div class="alert alert-info">

### **Tokenizer: Organizing the addition:**

The tokenizer that we implemented here rearranges the digits in a sum by grouping digits of the same significance together, ordered from least to most significant (right to left). For example, the sum `13 + 54` is encoded as:
$$ [3, 4, 1, 5] $$

This encoding method aligns well with the Self-Attention Mechanism used in Transformer models. Since the attention score of a token $x[t]$ at position $t$ is computed with respect to all previous tokens (positions $< t$), we want to ensure that the model attends to digits of lower significance before higher ones. This simulates the way humans naturally perform addition.
</div>

In [105]:
# Tokenizer
eos_token = '[EOS]'
class Tokenizer:
    """
    Binary representation tokenizer
    """
    def __init__(self, number_bits):
        self.delimiters = r'(\[EOS\]|[,\+\=\s])'
        self.vocab = [str(x) for x in range(10)] + [eos_token] + ["="]  # No need for pad token
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
        self.number_bits = number_bits
        self.vocab_size = len(self.vocab)

    def encode(self, text):
        # Splitting number from symbols
        tokens_split = re.split(self.delimiters, text) # Splitting number from symbols

        # Keeping only numbers and = symbol
        tokens = [token for token in tokens_split if token.isdigit() or token == '=']

        if tokens == []:
            print("Invalid prompt, please use at least one number or the sign =")
            raise ValueError

        # get the index of '=' to separate number that should be added vs the answer!
        idx_equal = len(tokens)
        if '=' in tokens:
            idx_equal = tokens.index('=')

        # Pad number with 0 in the beginning if they have less than 'number_bits' digits
        for i in range(len(tokens)):
            if tokens[i].isdigit():
                tokens[i] = '0'*(self.number_bits + 1 - len(tokens[i])) + tokens[i]

        # If we have only one token (number or =), then return its encoding directly
        if len(tokens) == 1:
            return [self.token_to_id[c] for c in tokens[0]]

        # Now we are sure that we have all numbers of size self.number_bits + 1
        # Let us now put every two digits of the same base 10 position next to each other, starting from the units (unités, puis dixaines, puis centaines,..)
        # This ordering (from right to left) is chosen because the attention mechanism considers tokens that were shown in the past
        arranged_digits = []
        for i in range(self.number_bits + 1):
            # Pathological case: do not reverse
            if len(tokens[:idx_equal]) == 1:
                for token in tokens[:idx_equal]:
                    arranged_digits.append(token[i])

            else: # reverse
                for token in tokens[:idx_equal]:
                    arranged_digits.append(token[~i])

        # Add the answer now: remaining tokens after idx_equal
        for token in tokens[idx_equal:]:
            arranged_digits += list(token)

        return [self.token_to_id[c] for c in arranged_digits]

    def merge_digits(self, l):
        result = []
        num = ""
        for char in l:
            if char.isdigit():
                num += char  # Concatenate digits
            else:
                if num:  # If there is a collected number, add it to the result
                    result.append(str(int(num))) # to remove zeros in the beginning!
                    num = ""  # Reset num
                result.append(char)  # Add the non-digit character

        if num:  # Add any remaining number at the end
            result.append(str(int(num)))

        return result

    def decode(self, token_list):
        tokens = [self.id_to_token[j] for j in token_list]

        m = len(tokens)
        if m <= self.number_bits + 2 or self.id_to_token[token_list[-1]] == eos_token:# Answer
            l = self.merge_digits(tokens)
            return ''.join(l)

        else: # It a query
            # The number of input numbers for addition
            # Take the numbers before the sin equal
            idx_equal = len(tokens)
            if '=' in tokens:
                idx_equal = tokens.index('=')

            numbers_before = tokens[:idx_equal]
            k = len(numbers_before) // (self.number_bits + 1)

            numbers = []
            for i in range(k):
                num = list(reversed(numbers_before[i::k]))
                num = ''.join(num)
                num = str(int(num)) # To remove zeros used for padding
                numbers.append(num)

            text = '+'.join(numbers)

            # Now add the numbers after =, i.e the solution
            l = self.merge_digits(tokens[idx_equal:])
            text = text + ''.join(l)
            return text

In [106]:
number_bits = 3
tokenizer = Tokenizer(number_bits)
prompt = "999 + +  900 = 2000"
inputs = tokenizer.encode(prompt)
print(inputs)
tokenizer.decode(inputs)

[9, 0, 9, 0, 9, 9, 0, 0, 11, 2, 0, 0, 0]


'999+900=2000'

# Positional encoding

In [107]:
# Because we group each pair of digits of the same significance next to each other,
# then a good positional encoding could be to give each pair the same positional encoding!

class Abacus(nn.Module):

    def __init__(self, d_model, dropout=0.1, repeats = 50, number_bits = 3):
        # Create a positional embedding that is periodic in the number of bits used to represent numbers!
        # numbers should be consecutive with no double + or double so that this encoding works (i.e clear and clean prompts!)
        super(Abacus, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.number_bits = number_bits
        position = torch.arange(repeats, dtype= torch.float).repeat_interleave(2).unsqueeze(1) # each position is repeated 2 times: [0,0, 1, 1, 2, 2,...]

        # Positional encodings
        pe = torch.zeros(position.size(0), d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1) # shape (2 * repeats, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [108]:
class PositionalEmbedding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEmbedder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEmbedder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Dataset

In [109]:
def sample_datapoint(number_bits = 3):
    """
    returns a string containing two random numbers on `number_bits` many bits and their sum.
    """
    a_list = [random.randint(0, 9) for _ in range(number_bits)]
    b_list = [random.randint(0, 9) for _ in range(number_bits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    sum_int = a_int + b_int
    return (str(a_int) + "+" + str(b_int) + "=", str(sum_int))


In [110]:
dataset_size = 64000
number_bits = 3
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(number_bits))
data[:4]

[('302+160=', '462'),
 ('505+871=', '1376'),
 ('997+725=', '1722'),
 ('926+636=', '1562')]

In [111]:
train_proportion = 0.9
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

# Model: Longformer

Still need the final implementation

In [112]:
# Parameters of the model
vocab_size = tokenizer.vocab_size
embed_dim = 256
n_blocks = 4
num_heads = 4
window_size = 3

# Positional embedder
positional_encoder = PositionalEmbedding(embed_dim)

# LongFormer
model = LongFormer(vocab_size, embed_dim, num_heads, window_size, n_blocks, positional_encoder).to(device)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("The total number of trainable parameters is ", trainable_params)

The total number of trainable parameters is  3165708


In [113]:
def generate(model, prompts, new_tokens = 5, device = device):
    input_tensor = prompts # (batch_size, length_prompts)
    input_tensor = input_tensor.to(device)

    for _ in range(new_tokens):
        output = model(input_tensor) # (batch_size, length_prompts, ntokens)
        last_output = output[:,-1,:] # (batch_size, ntokens)
        token = torch.argmax(last_output, -1).view(-1, 1) # (batch_size, 1)
        input_tensor = torch.cat((input_tensor, token), 1)
    return input_tensor

In [114]:
model.eval() # disable dropout!

prompt = "2+332="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((1,-1)).to(device)
print(prompt_tensor.shape)
output = generate(model, prompt_tensor)
output, tokenizer.decode(output.tolist()[0])

torch.Size([1, 9])


(tensor([[ 2,  2,  0,  3,  0,  3,  0,  0, 11,  8,  2,  8,  2,  8]],
        device='cuda:0'),
 '2+332=82828')

# Preprocessing steps

In [115]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append(x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]])
    return out, max_length

In [116]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['1+1=', '21+35='], ['2[EOS]', '56[EOS]'])

In [117]:
def get_batch(split, i, batch_size):
    data = data_train if split == 'train' else data_test
    prompts = [tokenizer.encode(data[i][0]) for i in range(i, i + batch_size)]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    answers = [tokenizer.encode(data[i][1]) for i in range(i, i + batch_size)]
    padded_answers, length_answers = pad(answers, "answers")
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X.T, Y.T, length_prompts, length_answers

In [118]:
X, Y, length_prompts, length_answers = get_batch("train", 243, 64)
X.shape, Y.shape, length_prompts, length_answers

(torch.Size([64, 9]), torch.Size([64, 5]), 9, 4)

In [120]:
print(Y[0])

tensor([ 0,  9,  3,  3, 10])


In [121]:
print(length_answers)

4


In [122]:
def evaluate(model, batch_size = 64):
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, length_prompts, length_answers = get_batch("test", i, batch_size)
            prompts = prompts.to(device) # (batch_size, length_prompts)
            target_answers = target_answers.to(device) # (batch_size, length_answers + 1)
            output = generate(model, prompts, length_answers + 1) # (batch_size, length_prompts + length_answers + 1)
            answers_tokens = output[:, length_prompts:] # (batch_size, length_answers + 1), contains tokens
            equality_test = answers_tokens == target_answers # (batch_size, length_answers + 1), contains boolean values
            correct += torch.all(equality_test, axis=1).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [98]:
evaluate(model)

0.0

In [127]:
prompts, target_answers, length_prompts, length_answers = get_batch("train", 0, 32)
print(target_answers.shape)
target_answers = target_answers.reshape(-1)
print(target_answers.shape)

torch.Size([32, 5])
torch.Size([160])


# Training

In [132]:
def train_epoch(model, params, vocab_size):
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'])#, betas = (params['beta_1'], params['beta_2']))
    total_loss = 0.
    start_time = time.time()
    batch_size = params['batch_size']

    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", i, batch_size)
        prompts = prompts.to(device) # (batch_size, length_prompts)
        target_answers = target_answers.to(device) # (batch_size, length_answers)
        input_tensor = torch.cat((prompts, target_answers), 1) # (batch_size, length_prompts + length_answers)
        optimizer.zero_grad()

        output = model(input_tensor) # (batch_size, length_prompts + length_answers, ntokens)
        #output_answers = output[:,length_prompts-1:-1,:].reshape(-1, vocab_size) # (length_answers * batch_size, ntokens)
        output_answers = output[:,length_prompts:,:].reshape(-1, vocab_size)
        target_answers = target_answers.reshape(-1)
        loss = F.cross_entropy(output_answers, target_answers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch % params['log_interval'] == 0 and batch > 0:
            cur_loss = total_loss / params['log_interval']
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.4f} | perplexity {:8.4f}'.format(batch, len(data_train) // batch_size,
                                                                                                        elapsed * 1000 / params['log_interval'], cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def train(model, params, vocab_size):
    best_test_accuracy = None
    test_accuracy = evaluate(model)
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, params['epochs'] +1):
        epoch_start_time = time.time()
        train_epoch(model, params, vocab_size)
        test_accuracy = evaluate(model)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [140]:
# Training params
params = {'lr': 8e-4,
          'epochs': 5,
          #'beta_1': 0.9,
          #'beta_2': 0.999,
          'batch_size': 64,
          'log_interval': 200}
vocab_size = tokenizer.vocab_size

In [141]:
# Load again the model (random or from checkpoint)
# Parameters of the model
vocab_size = tokenizer.vocab_size
embed_dim = 128
n_blocks = 4
num_heads = 1
window_size = 3

# Positional embedder
positional_encoder = PositionalEmbedding(embed_dim)

# LongFormer
model = LongFormer(vocab_size, embed_dim, num_heads, window_size, n_blocks, positional_encoder)
model.to(device)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("The total number of trainable parameters is ", trainable_params)

The total number of trainable parameters is  796428


In [142]:
train(model, params, vocab_size)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 98.57 | loss 0.0514 | perplexity   1.0527
|   400/  900 batches | ms/batch 101.60 | loss 0.0014 | perplexity   1.0014
|   600/  900 batches | ms/batch 103.93 | loss 0.0007 | perplexity   1.0007
|   800/  900 batches | ms/batch 97.13 | loss 0.0004 | perplexity   1.0004
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 101.36s | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   200/  900 batches | ms/batch 101.88 | loss 0.0000 | perplexity   1.0000
|   400/  900 batches | ms/batch 98.32 | loss 0.0000 | perplexity   1.0000
|   600/  900 batches | ms/batch 100.56 | loss 0.0000 | perplexity   1.0000
|   800/  900 b

# Evaluation

In [143]:
def show_examples(model, data_test):
    model.eval()
    with torch.no_grad():
        for i in range(20):
            prompt, answers = data_test[i]
            prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((1,-1)).to(device) # shape (1, length_prompt)
            output = generate(model, prompt_tensor, len(answers) + 1) # shape (1, length_prompt+ length_answer)
            print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

show_examples(model, data_test)

64+875=1111	 actual result: 939
422+541=1111	 actual result: 963
443+479=1111	 actual result: 922
902+434=11111	 actual result: 1336
974+79=11111	 actual result: 1053
13+80=111	 actual result: 93
775+883=11111	 actual result: 1658
431+792=11111	 actual result: 1223
785+53=1111	 actual result: 838
386+988=11111	 actual result: 1374
220+13=1111	 actual result: 233
480+639=11111	 actual result: 1119
929+667=11111	 actual result: 1596
113+980=11111	 actual result: 1093
145+374=1111	 actual result: 519
681+773=11111	 actual result: 1454
179+336=1111	 actual result: 515
453+741=11111	 actual result: 1194
753+450=11111	 actual result: 1203
434+290=1111	 actual result: 724


----