In [1]:
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

emb_dim = 32
max_digits = 6 # maximum number of digits in the two numbers we are adding
context_length = 2*max_digits + max_digits+1 + 2 # number1 + number2 = number3 (which can have max_digits+1)
batch_size = 32

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, masked=True):
        super().__init__()
        self.n_heads = n_heads
        self.masked = masked
        self.q = nn.Linear(emb_dim, emb_dim, bias=False) #compute for all heads in parrallel, n_heads * head_size = emb_dim
        self.k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.v = nn.Linear(emb_dim, emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        if self.masked:
            self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        B, T, C = x.shape
        Q, K, V = self.q(x), self.k(x), self.v(x)
        Q = Q.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #self attention needs to be done individually for each head
        K = K.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #transpose needed so we do self attention in the given char context not between the different heads
        V = V.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2)
        wei = Q @ K.transpose(-1, -2)
        if self.masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei / C**-0.5
        out = F.softmax(wei, dim=-1) @ V
        out = out.transpose(1,2).contiguous().view(B, T, C)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Sequential(
        nn.Linear(emb_dim, 4*emb_dim), nn.ReLU(),
        nn.Linear(4*emb_dim, emb_dim)
        )

    def forward(self, x):
        return self.ff(x)

class TransformerBlock(nn.Module):
    def __init__(self, n_head):
        super().__init__()
        self.mha = MultiHeadAttention(n_head)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward()
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        out = x + self.ff(self.ln2(x))
        return out


In [2]:
class GPT(nn.Module):
    def __init__(self, vocab_size, n_heads=4, n_blocks=2): 
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(n_heads) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embs = self.token_embedding(idx) # B, T -> B, T, C
        pos_embs = self.position_embedding(torch.arange(T))
        out = token_embs + pos_embs
        out = self.blocks(out)
        logits = self.lm_head(out)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def calculate(self, idx, max_tokens=100):
        for _ in range(max_tokens): # just to make sure we dont run into infinite loop if model fails to end its output
            context = idx[:, -context_length:]
            logits, _ = self(context)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.argmax(probs, keepdim=True) # use argmax instead of multinomial, there is only one correct answer
            idx = torch.cat((idx, next_token), dim=1)
            if next_token == 15:
                break
        return idx

In [3]:
optoi = {'+': 10, '-':11, '*':12, '/': 13, '=': 14, '<END>':15}
itoop = {i: op for op, i in optoi.items()}

# we sample two random numbers as input and their sum as the label
def sample_mathproblems(num_problems): 
    operation = '+' #TODO: Expand to all four basic operations

    # sample input data
    first_nums = torch.randint(0, 9, (num_problems, max_digits), dtype=torch.long)
    sum_symbols = torch.ones((num_problems,1), dtype=torch.long) * optoi[operation]
    second_nums = torch.randint(0, 9, (num_problems, max_digits), dtype=torch.long)
    equals_symbols = torch.ones((num_problems,1), dtype=torch.long) * optoi['=']
    x = torch.cat((first_nums, sum_symbols, second_nums, equals_symbols), dim=1)

    # sample output data
    masked_labels = -100 * torch.ones(num_problems, x.shape[1]-1, dtype=torch.long) # mask loss for first n inputs (-100 gets ignored by pytorch)
    first_nums = torch.tensor([int(''.join(map(str, num.tolist()))) for num in first_nums])
    second_nums = torch.tensor([int(''.join(map(str, num.tolist()))) for num in second_nums])

    if operation == '+':
        results = first_nums + second_nums
    labels = [[int(digit) for digit in reversed(str(result.item()))] for result in results]
    labels = torch.tensor(list(zip(*itertools.zip_longest(*labels, fillvalue=0))), dtype=torch.long)
    end_tokens = torch.ones((num_problems, 1), dtype=torch.long) * optoi['<END>']
    x = torch.cat((x, labels), dim=1)
    y = torch.cat((masked_labels, labels, end_tokens), dim=1)
    
    return x, y

n_samples = 1_000_000
x, y = sample_mathproblems(n_samples)
print(x.shape, y.shape)

torch.Size([1000000, 21]) torch.Size([1000000, 21])


In [4]:
train_size = int(0.8 * n_samples)
val_size = int(0.9 * n_samples)
x_train, y_train = x[:train_size], y[:train_size]
x_val, y_val = x[train_size:val_size], y[train_size:val_size]
x_test, y_test = x[val_size:], y[val_size:]
x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape

(torch.Size([800000, 21]),
 torch.Size([800000, 21]),
 torch.Size([100000, 21]),
 torch.Size([100000, 21]),
 torch.Size([100000, 21]),
 torch.Size([100000, 21]))

In [5]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            if split == 'train':
                idx = torch.randint(0, len(x_train), (batch_size, ))
                X, Y = x_train[idx], y_train[idx]
            elif split == 'val':
                idx = torch.randint(0, len(x_val), (batch_size, ))
                X, Y = x_val[idx], y_val[idx]
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def train_gpt(model, optimizer, train_steps=100_000, eval_iters=200):
    for step in range(train_steps):
        # forward pass
        idx = torch.randint(0, len(x_train), (batch_size,))
        x, y = x_train[idx], y_train[idx]
        _, loss = model(x, y)

        # backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10_000 == 0:
            losses = estimate_loss(model, eval_iters) 
            train_loss = losses['train']
            val_loss = losses['val']
            print(f"{step}/{train_steps} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [9]:
model = GPT(16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

train_gpt(model, optimizer, train_steps=40_000)

0/40000 - Train Loss: 3.1023, Val Loss: 3.0984
10000/40000 - Train Loss: 0.0022, Val Loss: 0.0023
20000/40000 - Train Loss: 0.0005, Val Loss: 0.0005
30000/40000 - Train Loss: 0.0001, Val Loss: 0.0001


In [12]:
def decode(x):
    x = x[0]
    assert x[-1].item() == optoi['<END>'], "Model did not end calculation with <END> token, result is wrong."
    num1 = ''.join(map(str, x[:max_digits].tolist()))
    op = itoop[x[max_digits].item()]
    num2 = ''.join(map(str, x[max_digits+1:2*max_digits+1].tolist()))
    equals = itoop[x[2*max_digits+1].item()]
    pred_res = ''.join(map(str, reversed(x[2*max_digits+2:-1].tolist())))
    pred_res = pred_res.lstrip('0')
    real_res = int(num1) + int(num2)
    return " ".join([num1, op, num2, equals, str(real_res)]), pred_res

for example in x_test[:10]:
    decoded_str, pred_res = decode(model.calculate(example[:14].view(1, -1)))
    print(f"{decoded_str}, predicted result is {pred_res}")

201014 + 185545 = 386559, predicted result is 386559
780634 + 852367 = 1633001, predicted result is 1633001
228510 + 372265 = 600775, predicted result is 600775
258532 + 150206 = 408738, predicted result is 408738
837800 + 761450 = 1599250, predicted result is 1599250
274460 + 427124 = 701584, predicted result is 701584
006353 + 720134 = 726487, predicted result is 726487
418111 + 670353 = 1088464, predicted result is 1088464
055536 + 682646 = 738182, predicted result is 738182
034806 + 447431 = 482237, predicted result is 482237
