# Teach an LLM to do additions

The goal of this project is to teach an LLM to do additions, playing only with two parts:
* the tokenizer
* the positional embedding

Both the model and the dataset are fixed.

You are allowed to tune the hyperparameters, but this is not the main goal. Depending on the quality of your tokenizer and positional embedding, you may change the number of bits. The initial value of 3 is very small.

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
number_bits = 3

dataset_size = 64_000
train_proportion = 0.9

log_interval = 200
batch_size = 64
epochs = 4
learning_rate = 8e-4

## Step 1: Construct a tokenizer

In [4]:
pad_token="[PAD]"
eos_token="[EOS]"

### Baseline: character-level tokenizer

In [5]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [6]:
base_tokenizer = character_level_tokenizer()
ntokens = base_tokenizer.ntokens
ntokens

14

In [7]:
prompt = "12 + 42 ="
inputs = base_tokenizer.encode(prompt)
inputs, base_tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

# Implement your tokenizer here!

You can do anything (as long as you do not compute the addition!).
Some ideas:
* reversing numbers left to right
* arranging by groups (of, 2, 3,...)
* aligning numbers

Mon idée est que lorsque l'on faisait des opérations arithmétiques en primaire, on calculait dizaine par dizaine, avec des retenues. On traite donc les chiffres individuellement, mais à la dizaine pertinente. Ainsi, ici, je propose d'encoder chaque chiffre d'un nombre mais en donnant son unité. Par exemple, "123" sera tokénisé en ["100", "20", "3"].

In [8]:
class advanced_tokenizer:
    """
    Encodes numbers as a sequence of digits of variable length depending of their Power of 10
    (e.g. 123 is encoded as 100, 20, 3)
    """
    def __init__(self):
        """Initialises the tokenizer. 
        The vocabulary is all numbers of the form [digit]00...0 , +, =, [PAD], [EOS]
        0 is explicited added to the vocabulary to avoid repetition of the same number
        We make available in the vocabulary number of n_bits + 1 digits"""

        self.vocab = ["0"] + [str(x * 10 ** k) for x in range(1, 10) for k in range(number_bits +1)] + ["+", "="] + [pad_token, eos_token] 
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = r'[^0-9+=]' # Permet de ne garder que les chiffres, + et = et éviter la fragmentation des nombres à plusieurs chiffres (e.g. 100, 20 ,etc.)
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        transforms the text into a list of tokens
        a number is a sequence of digits that are encoded with their corresponding number in the vocabulary
        for example, 12 is encoded as 10, 2
        """
        number_length = 0
        text_list = []
        number_digits_list = []
        for c in text:
            if c in self.vocab:
                if c == "0" and number_length == 0:
                    continue # avoid adding 0 at the beginning of a number to not confuse the model
                elif c not in [str(x) for x in range(10)] and number_length == 0: # case for + and =
                    text_list.append(c)
                elif c in [str(x) for x in range(0, 10)]: # if c is a digit
                    number_digits_list.append(c)
                    number_length += 1
                #elif c == "0" and number_length > 0: # if c is a 0 in the middle of a number
                 #   number_length += 1 # avoid to have repeated 0 in the middle of a number
                else : # at the end of a number
                    if number_length > 0:
                        for i in range(number_length):
                            text_list.append(str(10 ** (number_length - i - 1) * int(number_digits_list[i])))
                        number_digits_list = []
                        number_length = 0
                    text_list.append(c)
            else:
                continue

        if number_length > 0: # if the number is at the end of the text (not supposed to happen)
            for i in range(number_length):
                text_list.append(str(10 ** (number_length - i - 1) * int(number_digits_list[i])))


        return text_list


    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        """Decoder"""
        output = []
        for x in token_list:
            if x > 9 * (number_bits +1): # if x is not corresponding to a number
                output.append(self.id_to_token[x])
            else:
                output.append(self.id_to_token[x][0])
        return "".join(output)

In [9]:
print("vocabulary: ", advanced_tokenizer().vocab)
print("pretokenisation of 12 + 42 = : ", advanced_tokenizer().pre_tokenization("12 + 42 ="))

vocabulary:  ['0', '1', '10', '100', '1000', '2', '20', '200', '2000', '3', '30', '300', '3000', '4', '40', '400', '4000', '5', '50', '500', '5000', '6', '60', '600', '6000', '7', '70', '700', '7000', '8', '80', '800', '8000', '9', '90', '900', '9000', '+', '=', '[PAD]', '[EOS]']
pretokenisation of 12 + 42 = :  ['10', '2', '+', '40', '2', '=']


In [10]:
tokenizer = advanced_tokenizer()
ntokens = tokenizer.ntokens
ntokens

41

In [11]:
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([2, 5, 37, 14, 5, 38], '12+42=')

### Inverser l'ordre des nombres

Dans les articles concernant les dernières méthodes de tokénisation il s'agit de l'approche préférée, pour combiner cela avec un abaccus (cf. plus bas). En voici une implémentation.

In [12]:
class inverse_tokenizer:
    """
    Encodes numbers as a sequence of digits of variable length depending of their Power of 10
    (e.g. 123 is encoded as 100, 20, 3)
    """
    def __init__(self):
        """Initialises the tokenizer. 
        Identical to the baseline tokenizer"""

        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        transforms the text into a list of tokens
        a number is a sequence of digits that are encoded with their corresponding number in the vocabulary
        for example, 12 is encoded as 10, 2
        """
        number_length = 0
        text_list = []
        number_digits_list = []
        for c in text:
            if c in self.vocab:
                if c == "0" and number_length == 0:
                    continue # avoid adding 0 at the beginning of a number to not confuse the model
                elif c not in [str(x) for x in range(10)] and number_length == 0: # case for + and =
                    text_list.append(c)
                elif c in [str(x) for x in range(0, 10)]: # if c is a digit
                    number_digits_list.append(c)
                    number_length += 1
                else : # at the end of a number
                    if number_length > 0:
                        for i in range(number_length):
                            text_list.append(number_digits_list[number_length - i - 1])
                        number_digits_list = []
                        number_length = 0
                    text_list.append(c)
            else:
                continue

        if number_length > 0: # if the number is at the end of the text (not supposed to happen)
            for i in range(number_length):
                text_list.append(number_digits_list[number_length - i - 1])


        return text_list


    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        """Decoder that removes the 0 at the end of a number"""
        output = []
        number_length = 0
        number_list = []
        for x in token_list:
            if x > 9 and number_length > 0: # if x is not corresponding to a number
                for i in range(number_length):
                    output.append(self.id_to_token[number_list[number_length - i - 1]])
                output.append(self.id_to_token[x])
                number_list = []
                number_length = 0
            elif x > 9 and number_length == 0: # if x is not corresponding to a number
                output.append(self.id_to_token[x])
            else:
                number_list.append(x)
                number_length += 1
        if number_length > 0:
            for i in range(number_length):
                output.append(self.id_to_token[number_list[number_length - i - 1]])

        return "".join(output)

In [13]:
print("vocabulary: ", inverse_tokenizer().vocab)
print("pretokenisation of 12 + 42 = : ", inverse_tokenizer().pre_tokenization("12 + 42 ="))

vocabulary:  ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '=', '[PAD]', '[EOS]']
pretokenisation of 12 + 42 = :  ['2', '1', '+', '2', '4', '=']


In [14]:
i_tokenizer = inverse_tokenizer()
ntokens = i_tokenizer.ntokens
ntokens

14

In [15]:
prompt = "12 + 42 ="
inputs = i_tokenizer.encode(prompt)
inputs, i_tokenizer.decode(inputs)

([2, 1, 10, 2, 4, 11], '12+42=')

## Step 2: Create a dataset for arithmetic operations

In [16]:
def sample_datapoint(number_bits = 3):
    """
    returns a string containing two random numbers on `number_bits` many bits and their sum.
    """
    a_list = [random.randint(0, 9) for _ in range(number_bits)]
    b_list = [random.randint(0, 9) for _ in range(number_bits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    sum_int = a_int + b_int
    return (str(a_int) + "+" + str(b_int) + "=", str(sum_int))

sample_datapoint(3)

('834+253=', '1087')

In [17]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(number_bits))
data[:4]

[('825+373=', '1198'),
 ('471+488=', '959'),
 ('821+874=', '1695'),
 ('939+339=', '1278')]

In [18]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

### Basline: the classical Positional Embedding

In [19]:
class PositionalEmbedding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEmbedder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEmbedder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [20]:
class BaseTransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(BaseTransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEmbedding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        # Passage à la positional encoder en fournissant les token ids
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

# Implement your positional embedding here!

You can do anything. Some ideas:
* RoPE
* (randomised) FIRE
* Abacus

Pour bien correspondre avec le tokeniseur que nous avons implémenté, nous allons utiliser un embedding inspiré d'un Abacus. L'idée est d'indiquer à chaque composante sa puissance de 10 en décomposant les nombre en niveaux hiérarchiques (centaines, dizaines, unités) et permettrait de renforcer les informations d'unité sur les nombres. Ainsi, par exemple 300 recevra l'embedding (2, 3) car $300 = 3 \times 10^2$. Cela permet d'avoir toutes les informations sur un nombre.

Concernant l'usage de l'abacus, celui que je propose est un peu différent, puisqu'il est dans les articles et codes que j'ai vu (nottament https://seunghan96.github.io/llm/nlp/AriTrans/ et https://github.com/mcleish7/arithmetic). En effet, les articles l'utilisant préfèrent le combiner avec un tokéniseur qui se contente de renverser l'ordre de traitement pour mettre d'abord les nombres les plus importants (par exemple 123 est tokénisé en "321"). Il s'agit donc, à ma connaissance d'un travail de recherche nouveau sur la question.

Comme dans la baseline, nous utilisons nn.dropout pour éviter une trop grande adaptation des réseaux de neurone (Les raisons de cela sont bien expliquées dans la documentation: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html).



In [21]:
class AbacusPositionalEmbedding(nn.Module):
    def __init__(self, d_model, dropout=0.5, max_len=5000, tokenizer=tokenizer):
        """
        d_model   : dimension des embeddings
        dropout   : taux de dropout
        max_len   : longueur maximale des séquences pour la positional encoding classique
        tokenizer : instance du tokenizer pour récupérer le vocabulaire et construire la table d'exposants
        """
        super(AbacusPositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.tokenizer = tokenizer
       
        exponent_table = []
        mantissa_table = []
        numeric_mask = []  # 1 pour token numérique (hors "0"), 0 sinon.
        max_exponent = 0
        for token_id, token in tokenizer.id_to_token.items():
            if token.isdigit() and token not in [pad_token, eos_token, "+", "="]:
                if token == "0":
                    exponent_table.append(0)
                    numeric_mask.append(0)
                    mantissa_table.append(0)
                else:
                    exp =  len(token) - 1  # "2" -> 0, "20" -> 1, "300" -> 2, etc. 
                    #exp = number_bits - len(token) + 1 # "2" -> 2, "20" -> 1, "300" -> 0, etc. : autre option pour utiliser le même principe que les nombres binaires (cf. sources plus haut)
                    exponent_table.append(exp)
                    mantissa_table.append(int(token[0]))
                    numeric_mask.append(1)
                    max_exponent = max(max_exponent, exp)
            else:
                exponent_table.append(0)
                numeric_mask.append(0)
                mantissa_table.append(0)

        self.register_buffer('exponent_table', torch.tensor(exponent_table, dtype=torch.long))
        self.register_buffer('numeric_mask', torch.tensor(numeric_mask, dtype=torch.float))
        self.register_buffer('mantissa_table', torch.tensor(mantissa_table, dtype=torch.float))
        # Embedding simple pour les exposants allant de 0 à max_exponent
        self.exp_embedding = nn.Embedding(max_exponent + 1, d_model)
        self.mant_embedding = nn.Embedding(10, d_model)

    def forward(self, x, token_ids=None):
        """
        x         : Tensor de shape (seq_len, batch_size, d_model) (output de l'embedding d'entrée)
        token_ids : Tensor de shape (seq_len, batch_size) contenant les indices originaux (utilisés pour l'abacus)
                    On ajoute l'information abacus.
        """
        
        #token_ids = token_ids.long()
        # Récupère pour chaque token son exposant depuis la table
        exponents = self.exponent_table[token_ids]       # shape : (seq_len, batch_size)
        mantissas = self.mantissa_table[token_ids]       # shape : (seq_len, batch_size)
        mask = self.numeric_mask[token_ids].unsqueeze(-1)  # shape : (seq_len, batch_size, 1)
        exp_emb = self.exp_embedding(exponents)            # shape : (seq_len, batch_size, d_model)
        mant_emb = self.mant_embedding(mantissas)            # shape : (seq_len, batch_size, d_model)
        x = x + (exp_emb + mant_emb) * mask # Ajout de l'information abacus
        x = x + exp_emb * mask 
        return self.dropout(x)

### Abacus for inverse tokeniser

This code was implemented to compare my implementation with the one suggested in the Abaccus paper (https://arxiv.org/pdf/2311.14737).

In [22]:
digit_tokens = inverse_tokenizer().encode("0123456789")

class InverseAbacusPositionalEmbedding(torch.nn.Module):
    """
    Abacus Embeddings, learned emebddings resued for each digit.
    Integers must be reversed for this to work correctly.
    Transformers Can Do Arithmetic with the Right Embeddings, McLeish et al. (2024)
    """

    def __init__(self, embedding_dim, dropout=0.5, max_seq_length=5000, max_k=99, digit_tokens=digit_tokens):
        """
        digit_tokens (list): list of the tokens for each of the 10 digits
        dropout (float): dropout rate (not used)
        embedding_dim (int): dimension to embed into
        max_seq_length (int): maximum number of embeddings that can be trained
        max_k (int): maximum k value which we randomly shift by during training
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(max_seq_length, embedding_dim)
        self.register_buffer("digits", torch.tensor(digit_tokens, dtype=torch.long), persistent=False)
        self.dropout = nn.Dropout(p=dropout)

        self.max_k = max_k

    def helper(self, mask, device=device):
        """
        Converts a binary mask of digit locations into spans of consecutive digits
        """
        mask_shape = mask.shape
        
        # Create a shifted version of the mask to detect changes from 0 to 1
        shifted_mask = torch.cat([torch.zeros((mask_shape[0], 1), device=device, dtype=mask.dtype), mask[:, :-1]], dim=1)
        starts = (shifted_mask != mask) & mask
        
        # Generate IDs for each segment of 1s, processing row-wise
        segment_ids = torch.cumsum(starts, dim=1)
        
        # Generate an index array row-wise
        index = torch.arange(mask.size(1)).repeat(mask.size(0), 1).to(device)
        
        # Reset index at the start of each segment
        reset_index = torch.zeros_like(mask).long()
        second_term = index * starts.long()
        reset_index = reset_index.scatter_add(1, segment_ids, second_term)
        
        # Calculate positions in segment
        positions = index - reset_index.gather(1, segment_ids) + 1
        
        # Ensure only values within 1-segments are non-zero
        result = positions * mask

        return result

    def forward(self, input_ids):
        """
        input_ids (tensor): a batch of inputs, each row is a sample
        """
        mask = torch.isin(input_ids, self.digits)
        output = self.helper(mask, input_ids.device)

        k=0
        if self.training:
            k = random.randint(0, self.max_k)
            output[output>0] += k # as we already have ones in the tensor, the tensor values will be k+1

        return self.embedding(output)

In [23]:
class InvTransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(InvTransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = InverseAbacusPositionalEmbedding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        # Calcul de l'embedding d'entrée
        src = self.input_emb(src) * math.sqrt(self.ninp)
        # Passage à la positional encoder en fournissant les token ids
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

### Transformer model

**!!! IMPORTANT !!!** This model of Transformers is "input first", meaning that an input is a tensor with shape
(length_prompts, batch_size)

In [24]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = AbacusPositionalEmbedding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        # Calcul de l'embedding d'entrée
        emb = self.input_emb(src) * math.sqrt(self.ninp)
        # Passage à la positional encoder en fournissant les token ids
        src = self.pos_encoder(emb, token_ids=src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return F.log_softmax(output_dec, dim=-1), output_enc

Please do not change these parameters!

In [25]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)

tokenizer = tokenizer



In [26]:
def generate(model, prompts, new_tokens = 5):
    input_tensor = prompts # (length_prompts, batch_size)
    input_tensor = input_tensor.to(device)

     # Debug: Print initial tensor shape
    print(f"Initial input_tensor shape: {input_tensor.shape}")
    
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (length_prompts, batch_size, ntokens)

        # Debug: Print output tensor shape
        print(f"Output tensor shape: {output.shape}")

        last_output = output[-1,:,:] # (batch_size, ntokens)
        token = torch.argmax(last_output, -1).view((1,-1)) # (1, batch_size)
        input_tensor = torch.cat((input_tensor, token), 0)
    return input_tensor

In [27]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output.tolist()[0])

Initial input_tensor shape: torch.Size([4, 1])


/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelectSmallIndex: block: [0,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1369: indexSelect

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [28]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x)))
    return out, max_length

In [29]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['[PAD][PAD]1+1=', '21+35='], ['2[EOS][PAD]', '56[EOS]'])

In [30]:
def get_batch(split, i):
    data = data_train if split == 'train' else data_test
    prompts = [tokenizer.encode(data[i][0]) for i in range(i, i + batch_size)]
    padded_prompts, length_prompts = pad(prompts, "prompts")
    answers = [tokenizer.encode(data[i][1]) for i in range(i, i + batch_size)]
    padded_answers, length_answers = pad(answers, "answers")
    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, length_prompts, length_answers

In [31]:
X, Y, length_prompts, length_answers = get_batch("train", 243)
X.shape, Y.shape, length_prompts, length_answers

(torch.Size([8, 64]), torch.Size([5, 64]), 8, 4)

## Step 4: Evaluate

In [32]:
def evaluate():
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, length_prompts, length_answers = get_batch("test", i)
            prompts = prompts.to(device) # (length_prompts, batch_size)
            target_answers = target_answers.to(device) # (length_answers + 1, batch_size)
            output = generate(model, prompts, length_answers + 1) # (length_prompts + length_answers + 1, batch_size)
            answers_tokens = output[length_prompts:, :] # (length_answers + 1, batch_size), contains tokens
            equality_test = answers_tokens == target_answers # (length_answers + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [33]:
evaluate()

Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([1

0.0

## Step 4: Train the model

In [34]:
def train_epoch():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    total_loss = 0.
    start_time = time.time()
    for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
        prompts, target_answers, length_prompts, length_answers = get_batch("train", i)
        prompts = prompts.to(device) # (length_prompts, batch_size)
        target_answers = target_answers.to(device) # (length_answers, batch_size)
        input_tensor = torch.cat((prompts, target_answers), 0) # (length_prompts + length_answers, batch_size)
        model.zero_grad()
        output, _ = model(input_tensor) # (length_prompts + length_answers, batch_size, ntokens)
        output_answers = output[length_prompts-1:-1,:,:].reshape(-1, ntokens) # (length_answers * batch_size, ntokens)
        target_answers = target_answers.view(-1)
        loss = F.cross_entropy(output_answers, target_answers)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                        elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def train():
    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train_epoch()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy < best_test_accuracy:
            with open("arithmetic.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [35]:
train()

Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([10, 64, 14])
Output tensor shape: torch.Size([11, 64, 14])
Output tensor shape: torch.Size([12, 64, 14])
Initial input_tensor shape: torch.Size([8, 64])
Output tensor shape: torch.Size([8, 64, 14])
Output tensor shape: torch.Size([9, 64, 14])
Output tensor shape: torch.Size([1

In [36]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers)).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

Initial input_tensor shape: torch.Size([8, 1])
Output tensor shape: torch.Size([8, 1, 14])
Output tensor shape: torch.Size([9, 1, 14])
Output tensor shape: torch.Size([10, 1, 14])
Output tensor shape: torch.Size([11, 1, 14])
737+347=1088	 actual result: 1084
Initial input_tensor shape: torch.Size([8, 1])
Output tensor shape: torch.Size([8, 1, 14])
Output tensor shape: torch.Size([9, 1, 14])
Output tensor shape: torch.Size([10, 1, 14])
Output tensor shape: torch.Size([11, 1, 14])
525+718=1200	 actual result: 1243
Initial input_tensor shape: torch.Size([6, 1])
Output tensor shape: torch.Size([6, 1, 14])
Output tensor shape: torch.Size([7, 1, 14])
Output tensor shape: torch.Size([8, 1, 14])
3+552=820	 actual result: 555
Initial input_tensor shape: torch.Size([8, 1])
Output tensor shape: torch.Size([8, 1, 14])
Output tensor shape: torch.Size([9, 1, 14])
Output tensor shape: torch.Size([10, 1, 14])
Output tensor shape: torch.Size([11, 1, 14])
838+935=1888	 actual result: 1773
Initial input_

## Probing

This is just for fun...

In [37]:
import numpy as np

train_size = 1000
test_size = 100

model.eval()

def data_probing(size):
    X = []
    y = np.zeros(size)
    for i in range(size):
        input = torch.tensor(tokenizer.encode(data[i][0])).view((-1, 1)).to(device)
        _, output = model(input)
        output = output[-1,:,:].flatten()
        # determine whether there was a carry in the result:
        carry = len(data[i][1]) > len(data[i][0]) / 2
        X.append(output.cpu().detach().numpy())
        y[i] = carry
    return np.array(X), y

In [38]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

X_train, y_train = data_probing(train_size)
X_test, y_test = data_probing(test_size)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.fit_transform(X_test)

reg = LogisticRegression()
reg.fit(X_train,y_train)
reg.score(X_test, y_test)

1.0