In [1]:
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
with open('shakespear.txt', 'r') as f:
    text = f.read()

In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
encode = lambda cs: [stoi[c] for c in cs]
decode = lambda ixs: ''.join([itos[i] for i in ixs])

decode(encode(text[:100]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [4]:
encoded_text = torch.tensor(encode(text), dtype=torch.long)
trainsize = int(len(encoded_text)*0.9)
trainset = encoded_text[:trainsize]
valset = encoded_text[trainsize:]
trainset.shape[0], valset.shape[0]

(1003854, 111540)

In [24]:
context_length = 8
batch_size = 16
emb_dim = 32

def get_batch(split):
    data = trainset if split == 'train' else valset
    label_idx = torch.randint(context_length, len(data), (batch_size,))
    x = torch.stack([trainset[i-context_length:i] for i in label_idx])
    y = torch.stack([trainset[i+1-context_length:i+1] for i in label_idx])
    return x, y

x, y = get_batch('train')
x[1], y[1]

(tensor([56, 53, 52, 45, 57,  1, 61, 47]),
 tensor([53, 52, 45, 57,  1, 61, 47, 58]))

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, masked=True):
        super().__init__()
        self.n_heads = n_heads
        self.masked = masked
        self.q = nn.Linear(emb_dim, emb_dim, bias=False) #compute for all heads in parrallel, n_heads * head_size = emb_dim
        self.k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.v = nn.Linear(emb_dim, emb_dim, bias=False)
        self.proj = nn.Linear(emb_dim, emb_dim)
        if self.masked:
            self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        B, T, C = x.shape
        Q, K, V = self.q(x), self.k(x), self.v(x)
        Q = Q.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #self attention needs to be done individually for each head
        K = K.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2) #transpose needed so we do self attention in the given char context not between the different heads
        V = V.view(B, T, self.n_heads, C//self.n_heads).transpose(1,2)
        wei = Q @ K.transpose(-1, -2)
        if self.masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = wei / C**-0.5
        out = F.softmax(wei, dim=-1) @ V
        out = out.transpose(1,2).contiguous().view(B, T, C)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Sequential(
        nn.Linear(emb_dim, 4*emb_dim), nn.ReLU(),
        nn.Linear(4*emb_dim, emb_dim)
        )

    def forward(self, x):
        return self.ff(x)

class TransformerBlock(nn.Module):
    def __init__(self, n_head):
        super().__init__()
        self.mha = MultiHeadAttention(n_head)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ff = FeedForward()
        self.ln2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        out = x + self.ff(self.ln2(x))
        return out

In [7]:
class GPT(nn.Module):
    def __init__(self, vocab_size, n_heads=4, n_blocks=2): 
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(n_heads) for _ in range(n_blocks)])
        self.lm_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embs = self.token_embedding(idx) # B, T -> B, T, C
        pos_embs = self.position_embedding(torch.arange(T))
        out = token_embs + pos_embs
        out = self.blocks(out)
        logits = self.lm_head(out)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, num_tokens=10_000):
        for _ in range(num_tokens):
            context = idx[:, -context_length:]
            logits, _ = self(context)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx


In [25]:
@torch.no_grad()
def estimate_loss(model, eval_iters, batch_fn):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = batch_fn(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def train_gpt(model, optimizer, batch_fn, train_steps=100_000, eval_iters=200):
    for step in range(train_steps):
        # forward pass
        x, y = batch_fn('train')
        logits, loss = model(x, y)

        # backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if step % 10_000 == 0:
            losses = estimate_loss(model, eval_iters, batch_fn) 
            train_loss = losses['train']
            val_loss = losses['val']
            print(f"{step}/{train_steps} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

In [30]:
char_gpt = GPT(len(chars))
optimizer = torch.optim.AdamW(char_gpt.parameters(), lr=1e-3)

train_gpt(char_gpt, optimizer, get_batch, train_steps=10_000)

0/10000 - Train Loss: 4.4618, Val Loss: 4.4680


In [None]:
generated_idx = char_gpt.generate(torch.zeros((1,1), dtype=torch.long))[0].tolist()
print(decode(generated_idx))


Ange for at to the one a my cousin trio merch gough see it morn and imphiteful reaters, and where commoundr you fatherefore you new,'s drandanior puty.

MEN
Nay;
I'll do with Byeth!
I would not patituitty!

COMINIUS:
Tut, for it war?

Five.

As to was that it:
First by in wrunked:
Etsint there than Romeore conten it rooulive to they eat, a siter call we womarnt, Cluring let in imble durm!
And thou was to

DATLUCKINGHAM:
A fill the marren my are mothat truate it, you:
And,
To you
't be your
and
Troumplours.
O force far, the father iteor hie with me;
And for our her sould you,
And train. You, that tour should my to then some the hath maken belinge edowns you light me.

PERCHIO:
Be own fear'd?

AUMENAMISALGUS:
Therefate, lering humas sirest sentof Clawly you may ere done to your her a cormonts:
If a proput thou handn?
Theirs; his too.

YONUS:
Wherefore my her!

GONUS:
Her ment twom to outed Soed on
mevibus!
Prother hath her may fell'd
So judge, not.

GLOUCESTER:
But were Humes; heat As p

## GPT Calculator

In [43]:
optoi = {'+': 10, '-':11, '*':12, '/': 13, '=': 14, '<END>':15}
itoop = {i: op for op, i in optoi.items()}
context_length = 18

# we sample two random numbers as input and their sum as the label
def get_mathproblem(split='train'):
    operation = '+' #TODO: Expand to all four basic operations
    max_digits = 4 if split == 'train' else 5

    # sample input data
    first_nums = torch.randint(0, 9, (batch_size, max_digits), dtype=torch.long)
    sum_symbols = torch.ones((batch_size,1), dtype=torch.long) * optoi[operation]
    second_nums = torch.randint(0, 9, (batch_size, max_digits), dtype=torch.long)
    equals_symbols = torch.ones((batch_size,1), dtype=torch.long) * optoi['=']
    x = torch.cat((first_nums, sum_symbols, second_nums, equals_symbols), dim=1)

    # sample output data
    masked_labels = -100 * torch.ones(batch_size, x.shape[1]-1, dtype=torch.long) # mask loss for first n inputs (-100 gets ignored by pytorch)
    first_nums = torch.tensor([int(''.join(map(str, num.tolist()))) for num in first_nums])
    second_nums = torch.tensor([int(''.join(map(str, num.tolist()))) for num in second_nums])

    if operation == '+':
        results = first_nums + second_nums
    labels = [[int(digit) for digit in reversed(str(result.item()))] for result in results]
    labels = torch.tensor(list(zip(*itertools.zip_longest(*labels, fillvalue=0))), dtype=torch.long)
    end_tokens = torch.ones((batch_size, 1), dtype=torch.long) * optoi['<END>']
    x = torch.cat((x, labels), dim=1)
    y = torch.cat((masked_labels, labels, end_tokens), dim=1)
    
    return x, y

x, y = get_mathproblem('train')
print(x.shape, y.shape)
x[1], y[1]

torch.Size([16, 15]) torch.Size([16, 15])


(tensor([ 6,  0,  3,  1, 10,  6,  2,  3,  1, 14,  2,  6,  2,  2,  1]),
 tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100,    2,    6,    2,
            2,    1,   15]))

In [44]:
adder = GPT(16) # numbers from 0 to 15
optimizer = torch.optim.AdamW(adder.parameters(), lr=1e-3)

train_gpt(adder, optimizer, get_mathproblem)

0/100000 - Train Loss: 2.8650, Val Loss: 3.0690
10000/100000 - Train Loss: 0.0120, Val Loss: 17.2601
20000/100000 - Train Loss: 0.0014, Val Loss: 15.5410
30000/100000 - Train Loss: 0.0001, Val Loss: 17.2003
40000/100000 - Train Loss: 0.0000, Val Loss: 16.9990


KeyboardInterrupt: 

In [63]:
example = torch.tensor([1, 0, 0, 0, 10, 1, 0, 0, 0, 14], dtype=torch.long).view(1, -1)
out = adder.generate(example, num_tokens=6)
out[0, 10:16]

tensor([ 0,  0,  0,  2,  0, 15])

tensor([ 8,  4,  2,  4,  0, 15])