In [30]:
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

emb_dim = 32
max_digits = 6 # maximum number of digits in the two numbers we are adding/mutliplying/dividing
context_length = 2*max_digits + 2*max_digits + 2 # number1 * number2 = number3 (which can have 2*max_digits)
batch_size = 32

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, masked=True):
        super().__init__()
        self.n_heads = n_heads
        self.masked = masked
        self.q = nn.Linear(emb_dim, emb_dim, bias=False) #compute for all heads in parrallel, n_heads * head_size = emb_dim
        self.k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.v = nn.Linear(emb_dim, emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        if self.masked:
            self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        B, T, C = x.shape
        Q, K, V = self.q(x), self.k(x), self.v(x)
        Q = Q.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #self attention needs to be done individually for each head
        K = K.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #transpose needed so we do self attention in the given char context not between the different heads
        V = V.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2)
        wei = Q @ K.transpose(-1, -2)
        if self.masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei / C**-0.5
        out = F.softmax(wei, dim=-1) @ V
        out = out.transpose(1,2).contiguous().view(B, T, C)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Sequential(
        nn.Linear(emb_dim, 4*emb_dim), nn.ReLU(),
        nn.Linear(4*emb_dim, emb_dim)
        )

    def forward(self, x):
        return self.ff(x)

class TransformerBlock(nn.Module):
    def __init__(self, n_head):
        super().__init__()
        self.mha = MultiHeadAttention(n_head)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward()
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        out = x + self.ff(self.ln2(x))
        return out


In [2]:
class GPT(nn.Module):
    def __init__(self, vocab_size, n_heads=4, n_blocks=2): 
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(n_heads) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embs = self.token_embedding(idx) # B, T -> B, T, C
        pos_embs = self.position_embedding(torch.arange(T))
        out = token_embs + pos_embs
        out = self.blocks(out)
        logits = self.lm_head(out)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def calculate(self, idx, max_tokens=100):
        for _ in range(max_tokens): # just to make sure we dont run into infinite loop if model fails to end its output
            context = idx[:, -context_length:]
            logits, _ = self(context)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.argmax(probs, keepdim=True) # use argmax instead of multinomial, there is only one correct answer
            idx = torch.cat((idx, next_token), dim=1)
            if next_token == 15:
                break
        return idx

In [51]:
stoi = {'+': 10, '*':11, '/': 12, '=': 14, '<END>':15}
itos= {i: op for op, i in stoi.items()}

def encode(num1, num2, op, res=None):
    # encode symbols
    op_enc = torch.tensor(stoi[op], dtype=torch.long).view(1)
    equals_enc = torch.tensor(stoi['='], dtype=torch.long).view(1)

    def _encode_num(num, num_digits, reverse=False):
        if reverse:
            out = torch.tensor([int(digit) for digit in reversed(str(num).zfill(num_digits))], dtype=torch.long)
        else:
            out = torch.tensor([int(digit) for digit in str(num).zfill(num_digits)], dtype=torch.long)
        return out

    # encode input numbers (left pad with zeros until max_digits)
    num1_enc = _encode_num(num1, max_digits)
    num2_enc = _encode_num(num2, max_digits)

    if res == None:
        out = torch.cat([num1_enc, op_enc, num2_enc, equals_enc])
    else:
        res_enc = _encode_num(res, 2*max_digits, True)
        out = torch.cat([num1_enc, op_enc, num2_enc, equals_enc, res_enc])

    return out

def decode(x):
    out = []
    for idx in x:
        if idx < 10: out.append(str(idx.item())) # if its a digit just add it as a str
        elif idx == stoi['<END>']: break # END token means we are done
        else: out.append(itos[idx.item()]) # otherwise encode op

    return "".join(out)

# we sample two random numbers as input and their sum as the label
def sample_mathproblems(num_problems): 
    ops = torch.randint(10, 12, (num_problems, ), dtype=torch.long)
    all_nums = torch.randint(0, 10**(max_digits)-1, (num_problems, 2), dtype=torch.long)
    
    x = torch.zeros((num_problems, context_length), dtype=torch.long)

    for i, (nums, op) in enumerate(zip(all_nums, ops)):
        num1, num2 = nums[0].item(), nums[1].item()
        op_c = itos[op.item()]
        match op_c:
            case '+':
                res = num1 + num2
            case '*':
                res = num1 * num2
            case '/':
                res = num1 // num2
        x[i] = encode(num1, num2, op_c, res)

    input_size = 2*max_digits+2
    masked_loss = -100 * torch.ones((num_problems, input_size-1), dtype=torch.long)
    end_token = stoi['<END>'] * torch.ones((num_problems, 1), dtype=torch.long)
    y = torch.cat([masked_loss, x[:, input_size:], end_token], dim=1)
    return x, y

n_samples = 1_000_000
x, y = sample_mathproblems(n_samples)
x[0], y[0]

(tensor([ 2,  5,  3,  9,  0,  9, 10,  9,  1,  6,  3,  3,  6, 14,  5,  4,  2,  0,
          7,  1,  1,  0,  0,  0,  0,  0]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100,    5,    4,    2,    0,    7,    1,    1,    0,    0,    0,    0,
            0,   15]))

In [52]:
train_size = int(0.8 * n_samples)
val_size = int(0.9 * n_samples)
x_train, y_train = x[:train_size], y[:train_size]
x_val, y_val = x[train_size:val_size], y[train_size:val_size]
x_test, y_test = x[val_size:], y[val_size:]
x_train.shape, y_train.shape, x_val.shape, y_val.shape, x_test.shape, y_test.shape

(torch.Size([800000, 26]),
 torch.Size([800000, 26]),
 torch.Size([100000, 26]),
 torch.Size([100000, 26]),
 torch.Size([100000, 26]),
 torch.Size([100000, 26]))

In [53]:
@torch.no_grad()
def estimate_loss(model, eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            if split == 'train':
                idx = torch.randint(0, len(x_train), (batch_size, ))
                X, Y = x_train[idx], y_train[idx]
            elif split == 'val':
                idx = torch.randint(0, len(x_val), (batch_size, ))
                X, Y = x_val[idx], y_val[idx]
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def train_gpt(model, optimizer, train_steps=100_000, eval_iters=200):
    for step in range(train_steps):
        # forward pass
        idx = torch.randint(0, len(x_train), (batch_size,))
        x, y = x_train[idx], y_train[idx]
        _, loss = model(x, y)

        # backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10_000 == 0:
            losses = estimate_loss(model, eval_iters) 
            train_loss = losses['train']
            val_loss = losses['val']
            print(f"{step}/{train_steps} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [54]:
model = GPT(16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

train_gpt(model, optimizer, train_steps=100_000)

0/100000 - Train Loss: 2.9064, Val Loss: 2.9089
10000/100000 - Train Loss: 1.0955, Val Loss: 1.0995
20000/100000 - Train Loss: 0.9225, Val Loss: 0.9389
30000/100000 - Train Loss: 0.8842, Val Loss: 0.9027
40000/100000 - Train Loss: 0.8825, Val Loss: 0.8954
50000/100000 - Train Loss: 0.8562, Val Loss: 0.8997


KeyboardInterrupt: 

In [62]:
num1 = 100000
num2 = 100000
prob1 = encode(num1, num2, '+').view(1, -1)
res_enc = model.calculate(prob1)[0]
decode(res_enc)

'100000+100000=0000120000000'