# Train a decoder-only transformer (GPT-like) to do addition

I learned a lot from https://github.com/karpathy/minGPT, but I have rewritten all the code based on my own understanding.


In [1]:
import math
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Data processing
### Generate an addition dataset

In [2]:
PLUS_SIGN = 10
MUL_SIGN  = 11
MINUS_SIGN = 12
EQUAL_SIGN = 13
EOS = 14
BOS = 15
PAD = 16
UNK = 17

symbol_to_int_dict = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4,
                      "5": 5, "6": 6, "7": 7, "8": 8, "9": 9,
                      "+": PLUS_SIGN, "*": MUL_SIGN, "-": MINUS_SIGN,
                      "=": EQUAL_SIGN,  "<EOS>": EOS, "<BOS>": BOS,
                      "<pad>": PAD, "??": UNK
                      }

int_to_symbol_dict = {y:x for (x,y) in symbol_to_int_dict.items()}
vocab_size = len(symbol_to_int_dict)

def decode_equation(equation):
    '''convert an equation in list format to string format '''
    res = "".join([str(int_to_symbol_dict.get(x, UNK)) for x in equation.tolist()])
    return res.replace("<BOS>", "").replace("<EOS>", "")

def encode_equation(equation, max_ndigits, padQ=True):
    '''convert an equation(up to the equal sign in it) in string format to a list'''
    equal_size_loc = equation.index('=')
    plus_size_loc = equation.index('+')
    num1 = pad_number(equation[0:plus_size_loc], max_ndigits)
    num2 = pad_number(equation[plus_size_loc+1:equal_size_loc], max_ndigits)
    new_equation = num1 + "+" + num2 + "="
    return torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in new_equation]).to(DEVICE)


def pad_number(num, max_ndigits)->str:
    s = str(num)
    while len(s)<max_ndigits:
      s = "0"+s
    return s

def create_add_dataset(max_ndigits, dataset_size, padQ=True):
    ''' Function for creating an addition dataset.
    if padQ=True, pre-padding of 0s will be added on the numbers such that all the 
    numbers has the same length max_ndigits, for example, with max_ndigits=3,  
    32 will be represented 032.
    '''
    dataset_str = []
    for i in range(dataset_size):
        num1, num2 = np.random.randint(0, 10**max_ndigits, 2)
        ans = num1 + num2
        if padQ:
            equation = pad_number(num1, max_ndigits) + '+' + pad_number(num2, max_ndigits) + "=" + pad_number(ans, max_ndigits)
        else:
            equation = str(num1) + '+' + str(num2) + "=" + str(ans)
        dataset_str.append(equation)

    dataset = [torch.tensor([BOS]+[symbol_to_int_dict.get(n, UNK) for n in x]+[EOS])
               for x in dataset_str]
    return dataset, dataset_str

print(create_add_dataset(2,20, padQ=False))
print(create_add_dataset(2, 20, padQ=True))

([tensor([15,  7, 10,  8,  7, 13,  9,  4, 14]), tensor([15,  8,  7, 10,  7,  0, 13,  1,  5,  7, 14]), tensor([15,  8,  5, 10,  4,  4, 13,  1,  2,  9, 14]), tensor([15,  6,  7, 10,  7,  0, 13,  1,  3,  7, 14]), tensor([15,  0, 10,  1,  7, 13,  1,  7, 14]), tensor([15,  1,  6, 10,  5,  1, 13,  6,  7, 14]), tensor([15,  2,  1, 10,  9,  7, 13,  1,  1,  8, 14]), tensor([15,  5,  9, 10,  6,  9, 13,  1,  2,  8, 14]), tensor([15,  7,  3, 10,  1,  3, 13,  8,  6, 14]), tensor([15,  7,  4, 10,  7,  5, 13,  1,  4,  9, 14]), tensor([15,  5,  7, 10,  7,  8, 13,  1,  3,  5, 14]), tensor([15,  6,  9, 10,  1,  4, 13,  8,  3, 14]), tensor([15,  2,  9, 10,  1, 13,  3,  0, 14]), tensor([15,  6,  3, 10,  3,  4, 13,  9,  7, 14]), tensor([15,  5,  9, 10,  4,  3, 13,  1,  0,  2, 14]), tensor([15,  2,  2, 10,  9,  2, 13,  1,  1,  4, 14]), tensor([15,  8,  5, 10,  8,  4, 13,  1,  6,  9, 14]), tensor([15,  6,  4, 10,  2,  4, 13,  8,  8, 14]), tensor([15,  6,  3, 10,  6, 13,  6,  9, 14]), tensor([15,  1,  5, 10, 

### Create dataloders for the train, validation and test sets

In [3]:
class TranslationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

batch_size = 256
def pad_sequence(batch):
    input_padded = torch.nn.utils.rnn.pad_sequence(batch,
                                batch_first=True, padding_value = PAD)
    return input_padded

@dataclass
class DataLoaders:
    max_ndigits: int
    dataset_size: int
    padQ: bool = True
    val_loader = None
    test_loader = None
    train_loader = None

    def split_data(self, split=[0.7, 0.1, 0.2]):
        if isinstance(split[0], float):
            train_size  = round(self.dataset_size*split[0])
            val_size = round(self.dataset_size*split[1])
            test_size = self.dataset_size - train_size - val_size
        elif isinstance(split[0], int):
            val_size = split[0]
            test_size = split[1]
            train_size  = dataset_size - test_size - val_size


        dataset, _ = create_add_dataset(self.max_ndigits, self.dataset_size, padQ=self.padQ)
        train_set, val_set, test_set = torch.utils.data.random_split(dataset,
                                                             [train_size, val_size, test_size],
                                                    generator=torch.Generator().manual_seed(42) )

        self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                           shuffle=True, collate_fn = pad_sequence)
        self.test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)
        self.val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size,
                                           shuffle=True, collate_fn=pad_sequence)


## GPT model
Here is my implementation of the GPT model, including the multi-headed self-attention module.

In [4]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0 # check the h number
        self.d_k = d_model//h
        self.d_model = d_model
        self.h = h
        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_query, x_key, x_value, mask=None):
        nbatch = x_query.size(0) # get batch size
        # 1) Linear projections to get the multi-head query, key and value
        # x_query, x_key, x_value dimension: nbatch * seq_len * d_model
        # LHS query, key, value dimensions: nbatch * h * seq_len * d_k
        query = self.WQ(x_query).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        key   = self.WK(x_key).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        value = self.WV(x_value).view(nbatch, -1, self.h, self.d_k).transpose(1,2)
        # 2) Attention
        # scores has dimensions: nbatch * h * seq_len * seq_len
        scores = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_model)
        # 3) Mask out padding tokens and future tokens
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        # p_atten dimensions: nbatch * h * seq_len * seq_len
        p_atten = torch.nn.functional.softmax(scores, dim=-1)
        # x dimensions: nbatch * h * seq_len * d_k
        x = torch.matmul(p_atten, value)
        # x now has dimensions:nbtach * seq_len * d_model
        x = x.transpose(1, 2).contiguous().view(nbatch, -1, self.d_model)

        return self.linear(x) # final linear layer


class ResidualConnection(nn.Module):
  '''residual connection: x + dropout(sublayer(layernorm(x))) '''
  def __init__(self, dim, dropout):
      super().__init__()
      self.drop = nn.Dropout(dropout)
      self.norm = nn.LayerNorm(dim)

  def forward(self, x, sublayer):
      return x + self.drop(sublayer(self.norm(x)))


class Decoder(nn.Module):

    def __init__(self, vocab_size, h, d_embed, max_len, N=4, drop_rate=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_embed)
        self.pos_embed = nn.Embedding(max_len, d_embed)
        self.dropout = nn.Dropout(drop_rate)
        self.decoder_blocks = nn.Sequential(*[DecoderBlock(h, d_embed) for _ in range(N)])
        self.norm = nn.LayerNorm(d_embed)
        self.linear = nn.Linear(d_embed, vocab_size)

    def forward(self, trg, trg_pad_mask):
        pos_embedding = self.pos_embed(torch.tensor(range(trg.size(-1))).to(DEVICE))
        x = self.embed(trg) + pos_embedding
        x = self.dropout(x)
        for layer in self.decoder_blocks:
            x = layer( x, trg_pad_mask)
        x = self.norm(x)
        logits = self.linear(x)
        return logits


class DecoderBlock(nn.Module):
    def __init__(self, h, d_embed, dropout=0.1):
        super().__init__()
        self.atten1 = MultiHeadedAttention(h, d_embed)
        self.atten2 = MultiHeadedAttention(h, d_embed)
        self.ffn = nn.Sequential(
            nn.Linear(d_embed, 4*d_embed),
            nn.GELU(),
            nn.Linear(4*d_embed, d_embed),
            nn.Dropout(dropout)
        )
        self.residual1 = ResidualConnection(d_embed, dropout)
        self.residual2 = ResidualConnection(d_embed, dropout)
        self.residual3 = ResidualConnection(d_embed, dropout)

    def future_mask(self, seq_len):
        '''mask for masking out tokens at future positions'''
        mask = (torch.triu(torch.ones(seq_len, seq_len, requires_grad=False), diagonal=1)!=0).to(DEVICE)
        return mask.view(1, 1, seq_len, seq_len)

    def forward(self,  decoder_layer_input, decoder_pad_mask):
        y = decoder_layer_input
        seq_len = y.size(-2)
        decoder_mask = torch.logical_or(decoder_pad_mask, self.future_mask(seq_len))
        y = self.residual1(y, lambda y: self.atten1(y, y, y, mask=decoder_mask))

        return self.residual3(y, self.ffn)

class GPT(nn.Module):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def forward(self, input, pad_mask):
        return self.decoder(input, pad_mask)

### Let's creat a GPT!

In [5]:
@dataclass
class ModelConfig:
  d_embed: int
  # d_ff is the dimension of the fully-connected layer
  d_ff: int
  # h is the number of attention head
  h: int
  N_decoder: int
  max_len: int
  dropout: float


def make_GPT(config):
    model = GPT(Decoder(vocab_size, config.h, config.d_embed, config.max_len,
                        config.N_decoder)).to(DEVICE)
    # initialize model parameters
    # it seems that this initialization is very important!
    for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    return model

## Functions for training and input/output processing

In [6]:
def make_batch_input(x):
        'function for generating model input, target and pad_mask from raw input x'
        input = x[:, :-1].to(DEVICE)
        equal_sign_loc = [(equation==EQUAL_SIGN).nonzero().item() for equation in x]
        target = [torch.cat((torch.tensor([PAD]*equal_sign_loc[i]), x[i][equal_sign_loc[i]+1:])) for i in range(len(x))]
        target = torch.cat(target, 0).contiguous().view(-1).to(DEVICE)
        pad_mask = (input == PAD).view(input.size(0), 1, 1, input.size(-1))
        return input, target, pad_mask

In [34]:
def train_epoch(model, dataloader):
    model.train()
    grad_norm_clip = 1.0
    losses, acc, count = [], 0, 0
    num_batches = len(dataloader)
    pbar = tqdm(enumerate(dataloader), total=num_batches)
    for idx, x  in  pbar:
        optimizer.zero_grad()
        input, target, pad_mask = make_batch_input(x)
        pred = model(input, pad_mask).to(DEVICE)
        pred = pred.view(-1, pred.size(-1))
        loss = loss_fn(pred, target).to(DEVICE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
        # report progress
        if idx>0 and idx%50 == 0:
            pbar.set_description(f"ep: {scheduler.last_epoch//num_batches}, train loss={loss.item():.3f},lr={scheduler.get_last_lr()[0]:.5f}")
    return np.mean(losses)

def train(model, dataloaders, epochs):
    best_val_loss = float('inf')
    global early_stop_count
    train_size = len(dataloaders.train_loader)*batch_size
    for ep in range(epochs):
        train_loss = train_epoch(model, dataloaders.train_loader)
        val_loss = validate(model, dataloaders.val_loader)
        print(f'ep {ep}: train_loss: {train_loss:.5f}, val_loss: {val_loss:.5f}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
        else:
            if scheduler.last_epoch>2*warmup_steps:
                early_stop_count -= 1
                if early_stop_count<=0:
                    #torch.save(model, f'saved_models/{SRC}_to_{TRG}_train_size_{train_size}_model_size_{model_size}.pt')
            #f = open("save_model/{SRC}_to_{TRG}_dataset_size_{dataset_size}.txt", 'w'):
            #f.write()
                    return train_loss, val_loss
    return train_loss, val_loss


def validate(model, dataloder):
    'function for computing the loss on the validation set'
    model.eval()
    losses = []
    with torch.no_grad():
        for i, x in enumerate(dataloder):
            input, target, pad_mask = make_batch_input(x)
            pred = model(input, pad_mask).to(DEVICE)
            pred = pred.view(-1, pred.size(-1))
            losses.append(loss_fn(pred, target).item())
    return np.mean(losses)

In [22]:
@torch.no_grad()
def compute_sum(model, x):
    'Function for computing the sum of two numbers.'
    for i in range(max_ndigits+2):
        pad_mask = (x == PAD).view(1, 1, 1, x.size(-1)).to(DEVICE)
        logits = model(x, pad_mask)
        last_output = logits.argmax(-1)[:,-1].view(1,1)
        x = torch.cat((x, last_output), 1).to(DEVICE)
        if last_output.item() == EOS:
            break
    return x[0]

def evaluate(model, dataloader, num_batch=None):
    '''Function for evaluation the model.
    This function take equations, and truncate them up to the equal-sign, and feed them to the
    model to get the predictions, compare them with the correct answers, and output the accuracy.
    '''
    model.eval()
    acc, count = 0, 0
    num_wrong_to_display = 5
    for idx, x in enumerate(dataloader):
        for equation in x:
            loc_equal_sign = equation.tolist().index(EQUAL_SIGN)
            loc_EOS = equation.tolist().index(EOS)
            input = equation[0:loc_equal_sign+1].view(1, -1).to(DEVICE)
            ans = equation[:loc_EOS+1].tolist()
            ans_pred = compute_sum(model, input)
            count += 1

            if ans == ans_pred.tolist():
                acc +=1
            else:
                if num_wrong_to_display > 0:
                    print(f'correct equation: {decode_equation(equation).replace("<pad>","")}')
                    print(f'predicted:        {decode_equation(ans_pred)}')
                    num_wrong_to_display -= 1
        if num_batch and idx>num_batch:
            break
    return acc/count

def what_is(question:str)->str:
    'function for computing the sum of two numbers with input in literal string format'
    pred = compute_sum(model, encode_equation(question, max_ndigits).view(1,-1))
    pred = decode_equation(pred)
    pred = pred[pred.index("=")+1:]
    return question+pred


## 2-digit addition

In [32]:
max_ndigits = 2
# max_len is determined by 1+ max_ndigits + 1 + max_ndigits + 1 + max_ndigits +1 +1
max_len = 3*max_ndigits + 6
config = ModelConfig(d_embed=128, d_ff=256, h=4, N_decoder=2, max_len= max_len,
                           dropout=0.1)
dataset_size = 10000
data_loaders = DataLoaders(max_ndigits, dataset_size, padQ=True)
data_loaders.split_data(split=[1000, 2000])
train_size = len(data_loaders.train_loader)*batch_size
model = make_GPT(config)
model_size = sum([p.numel() for p in model.parameters()])
print(f'model_size: {model_size}, train_set_size: {train_size}')
warmup_steps = 3*len(data_loaders.train_loader)
# lr first increases in the warmup steps, and then descreases
#lr_fn = lambda step: 3*min([(step+1)/(3*len(data_loaders.train_loader)), (step+1)**(-0.5)])
lr_fn = lambda step: config.d_embed**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD)
early_stop_count = 10 # Setting early_stop_count to a large number, that is, I'm not implementing early_stop here

train_loss, val_loss = train(model, data_loaders, epochs=30)


model_size: 535570, train_set_size: 7168


100%|██████████| 28/28 [00:00<00:00, 65.79it/s]


train_loss: 2.08625, val_loss: 1.45997


100%|██████████| 28/28 [00:00<00:00, 69.45it/s]


train_loss: 1.37422, val_loss: 1.19786


100%|██████████| 28/28 [00:00<00:00, 69.52it/s]


train_loss: 1.16191, val_loss: 1.03280


100%|██████████| 28/28 [00:00<00:00, 66.76it/s]


train_loss: 1.01763, val_loss: 0.90767


100%|██████████| 28/28 [00:00<00:00, 67.13it/s]


train_loss: 0.93689, val_loss: 0.85948


100%|██████████| 28/28 [00:00<00:00, 70.69it/s]


train_loss: 0.87691, val_loss: 0.77854


100%|██████████| 28/28 [00:00<00:00, 68.69it/s]


train_loss: 0.72052, val_loss: 0.43769


100%|██████████| 28/28 [00:00<00:00, 69.17it/s]


train_loss: 0.42631, val_loss: 0.20224


100%|██████████| 28/28 [00:00<00:00, 69.75it/s]


train_loss: 0.26190, val_loss: 0.09397


100%|██████████| 28/28 [00:00<00:00, 69.70it/s]


train_loss: 0.17282, val_loss: 0.06079


100%|██████████| 28/28 [00:00<00:00, 68.66it/s]


train_loss: 0.12881, val_loss: 0.02600


100%|██████████| 28/28 [00:00<00:00, 67.95it/s]


train_loss: 0.09225, val_loss: 0.01808


100%|██████████| 28/28 [00:00<00:00, 65.06it/s]


train_loss: 0.07763, val_loss: 0.01069


100%|██████████| 28/28 [00:00<00:00, 68.70it/s]


train_loss: 0.06313, val_loss: 0.00614


100%|██████████| 28/28 [00:00<00:00, 67.29it/s]


train_loss: 0.05309, val_loss: 0.00383


100%|██████████| 28/28 [00:00<00:00, 66.48it/s]


train_loss: 0.04832, val_loss: 0.00399


100%|██████████| 28/28 [00:00<00:00, 70.75it/s]


train_loss: 0.04199, val_loss: 0.00221


100%|██████████| 28/28 [00:00<00:00, 64.95it/s]


train_loss: 0.03559, val_loss: 0.00253


100%|██████████| 28/28 [00:00<00:00, 66.48it/s]


train_loss: 0.02993, val_loss: 0.00191


100%|██████████| 28/28 [00:00<00:00, 69.32it/s]


train_loss: 0.02842, val_loss: 0.00102


100%|██████████| 28/28 [00:00<00:00, 71.35it/s]


train_loss: 0.02615, val_loss: 0.00168


100%|██████████| 28/28 [00:00<00:00, 67.35it/s]


train_loss: 0.02490, val_loss: 0.00071


100%|██████████| 28/28 [00:00<00:00, 67.29it/s]


train_loss: 0.02042, val_loss: 0.00046


100%|██████████| 28/28 [00:00<00:00, 70.18it/s]


train_loss: 0.01498, val_loss: 0.00053


100%|██████████| 28/28 [00:00<00:00, 70.89it/s]


train_loss: 0.01445, val_loss: 0.00057


100%|██████████| 28/28 [00:00<00:00, 69.41it/s]


train_loss: 0.01646, val_loss: 0.00098


100%|██████████| 28/28 [00:00<00:00, 60.21it/s]


train_loss: 0.01500, val_loss: 0.00033


100%|██████████| 28/28 [00:00<00:00, 62.89it/s]


train_loss: 0.01418, val_loss: 0.00034


100%|██████████| 28/28 [00:00<00:00, 67.31it/s]


train_loss: 0.01468, val_loss: 0.00029


100%|██████████| 28/28 [00:00<00:00, 61.46it/s]

train_loss: 0.01221, val_loss: 0.00023





In [33]:
test_loss = validate(model, data_loaders.test_loader)
test_acc = evaluate(model, data_loaders.test_loader)
val_acc = evaluate(model, data_loaders.test_loader)
train_acc = evaluate(model, data_loaders.train_loader, 20)
current_result = f'''train_size: {train_size}, train_loss: {train_loss},
                val_loss: {val_loss}, test_loss: {test_loss},
                test_acc: {test_acc}, val_acc: {val_acc}, train_acc: {train_acc}
                '''
print(current_result)

train_size: 7168, train_loss: 0.012206953301626657,
                val_loss: 0.0002306888454768341, test_loss: 0.0003604226258175913,
                test_acc: 1.0, val_acc: 1.0, train_acc: 1.0
                
