In [126]:
import torch
import torch.nn as nn
import torch.nn.functional as F   
import numpy as np     

In [None]:
# Load the raw training corpus. The encoding is pinned explicitly so the
# character vocabulary does not depend on the platform's default locale.
with open('mini_shakespeare.txt', encoding='utf-8') as file:
    file_content = file.read()

# Character-level vocabulary: every distinct character of the corpus, sorted
# so that the char <-> id mapping is identical on every run.
alphabet = sorted(set(file_content))
stoi = {char: i for i, char in enumerate(alphabet)}
itos = {i: char for i, char in enumerate(alphabet)}


def encode(x):
    """Map an iterable of characters to a list of integer token ids."""
    return [stoi[char] for char in x]


def decode(x):
    """Map an iterable of token ids back to a list of characters.

    Each id is passed through int() first, so plain ints and one-element
    tensors (e.g. the items of a 1-D LongTensor) are both accepted.
    """
    return [itos[int(encoded_char)] for encoded_char in x]


# Encode the whole corpus once and split it 90% / 10% into train / validation.
data = torch.tensor(encode(file_content), dtype=torch.long)
cutoff = int(len(data)*0.9)
train = data[:cutoff]
val = data[cutoff:]

# Hyperparameters shared by the cells below.
vocab_size = len(alphabet)
block_size = 8    # number of tokens in one training window
batch_size = 32   # number of windows per batch
embed_dim = 64    # width of the token embedding
num_heads = 4     # must be a divisor of embed_dim
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

def get_batch(split):
    """Sample a random batch of (input, target) windows.

    `split == "train"` reads from the global `train` tensor; any other value
    reads from `val`. Returns two tensors of shape (batch_size, block_size);
    the target window is the input window shifted one position to the right.
    """
    source = val if split != "train" else train
    # One random start offset per row, kept as a (batch_size, 1) column so it
    # broadcasts against the in-window offsets below.
    starts = torch.randint(len(source) - block_size, (batch_size, 1))
    positions = starts + torch.arange(block_size)  # (batch_size, block_size)
    inputs = source[positions]
    targets = source[positions + 1]
    return inputs, targets

class AttentionHead(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: width of the query/key/value projections of this head.

    The three projections are bias-free linear maps from `embed_dim` features
    down to `head_size` features.
    """

    def __init__(self, head_size):
        super().__init__()
        self.queries = nn.Linear(embed_dim, head_size, bias = False)
        self.keys = nn.Linear(embed_dim, head_size, bias = False)
        self.value_down = nn.Linear(embed_dim, head_size, bias = False)
        self.head_size = head_size

    def forward(self, x):
        """Attend over `x` of shape (B, T, embed_dim); returns (B, T, head_size).

        Position t only receives information from positions <= t.
        """
        seq_len = x.shape[1]
        q = self.queries(x)                    # (B, T, head_size)
        k = self.keys(x).transpose(1, 2)       # (B, head_size, T)
        scores = q @ k / (self.head_size ** 0.5)  # (B, T, T)

        # The mask is sized from the actual input rather than from the global
        # batch/block sizes, so any batch size and any sequence length work and
        # the batch dimension is never broadcast to a different size. It is
        # created on the input's device and broadcasts over the batch axis.
        visible = torch.ones(seq_len, seq_len, device=x.device).tril().bool()
        scores = scores.masked_fill(~visible, float("-inf"))

        weights = torch.softmax(scores, dim=-1)  # (B, queries, keys)
        return weights @ self.value_down(x)      # (B, T, T) @ (B, T, head_size)

class MultiHeaded(nn.Module):
    """Several attention heads run side by side, then mixed by a linear layer.

    Args:
        num_heads: how many AttentionHead modules to create.
        head_size: output width of each individual head.

    The head outputs are concatenated along the last dimension and projected
    by a bias-free linear layer to `embed_dim` features.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(head_size))
        self.heads = nn.ModuleList(head_modules)
        self.value_up = nn.Linear(num_heads * head_size, embed_dim, bias=False)

    def forward(self, x):
        """Apply every head to `x` and project the concatenated result."""
        per_head_outputs = [head(x) for head in self.heads]
        merged = torch.cat(per_head_outputs, dim=-1)
        return self.value_up(merged)

class Transformer(nn.Module):
    """Character-level language model: embedding -> multi-head attention -> logits.

    The model is built from the notebook-level hyperparameters `vocab_size`,
    `embed_dim` and `num_heads`. Only a token embedding is used; no positional
    embedding is added to the input.
    """

    def __init__(self):
        super().__init__()
        self.token_embeding = nn.Embedding(vocab_size, embed_dim)
        self.multiheaded_attention = MultiHeaded(num_heads, embed_dim//num_heads)
        self.final = nn.Linear(embed_dim, vocab_size)  # final logits over vocab

    def forward(self, x, y = None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            x: integer token ids of shape (B, T).
            y: optional integer targets of shape (B, T).

        Returns:
            (logits, loss): logits has shape (B, T, vocab); loss is the mean
            cross-entropy over all positions, or None when `y` is None.
        """
        hidden = self.token_embeding(x)
        hidden = self.multiheaded_attention(hidden)
        logits = self.final(hidden)  # final logits over vocab

        loss = None
        # Identity check: comparing a tensor to None with `!=` only works by
        # accident of the comparison fallback.
        if y is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) scores and (N,) class indices.
            loss = F.cross_entropy(logits.reshape(B*T, C), y.reshape(B*T))
        return logits, loss

    def generate(self, context, max_tokens = 100):
        """Autoregressively sample `max_tokens` new tokens for every sequence.

        Args:
            context: integer token ids of shape (B, T0) used as the prompt.
            max_tokens: number of tokens to append to each sequence.

        Returns:
            Tensor of shape (B, T0 + max_tokens): the prompt followed by the
            sampled continuation.
        """
        num_sequences = context.shape[0]
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in range(max_tokens):
                # The model only ever looks at the last `block_size` tokens.
                window = context[:, -block_size:]
                logits, _loss = self(window)  # (B, T, C)
                # Last position only, exactly one row per prompt sequence
                # (the row slice is a no-op when the model preserves the
                # batch dimension).
                logits = logits[:num_sequences, -1, :]  # (B, C)
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
                context = torch.cat((context, next_token), dim=1)

        return context



In [207]:
# Train a fresh model with AdamW on randomly sampled training batches.
training_steps = 5000
log_every = 1000  # print the current batch loss every this many steps
model = Transformer()
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
# The loop index is read below, so it gets a real name; `_` is reserved for
# throwaway values (and for IPython's "last output" variable).
for step in range(training_steps):
    x, y = get_batch('train')
    _, loss = model(x, y)
    optim.zero_grad(set_to_none=True)
    loss.backward()
    optim.step()
    if step % log_every == 0:
        print(f"{step}: {loss.item()}")
# Report the loss of the last batch as a plain number rather than a tensor repr.
print(loss.item())
    

0: 4.196011066436768
1000: 2.2246246337890625
2000: 2.2780117988586426
3000: 2.2658772468566895
4000: 2.130584955215454
tensor(2.3902, grad_fn=<NllLossBackward0>)


In [208]:
# Prompt the model with an encoded string; wrapping the id list in another
# list adds the batch dimension, giving a tensor of shape (1, prompt_length).
start = torch.tensor([encode("HI MY NAME IS")], dtype=torch.long)

# Sample 100 more tokens after the prompt.
out = model.generate(start, max_tokens=100)

# Turn the ids of the single returned sequence back into text and show it.
generated_text = ''.join(itos[int(token_id)] for token_id in out[0])
print(generated_text)

HI MY NAME ISTHENTERTIIf Ifithalave rd.
Werne Phalelleol wed,
Thilon.

AUMIFor, Anof her f b.

Lilch wor fald th 
