In [18]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer

In [19]:
h = nn.Embedding(10, 10)

In [20]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [21]:
# hyperparameters
batch_size = 32 
block_size = 8 
learning_rate = 1e-3
device = torch.device("mps")
vocab_size = tokenizer.vocab_size
n_embed = 512
head_size = 64
n_head = 8
block_iter = 5

In [22]:
sentence = "Let's see how you can tokenize this sentence"

In [23]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [24]:
data = torch.tensor(tokenizer(text).input_ids, dtype=torch.long)

Token indices sequence length is longer than the specified maximum sequence length for this model (338025 > 1024). Running this sequence through the model will result in indexing errors


In [25]:
data.shape

torch.Size([338025])

In [26]:
n = int(0.9 * len(data))
train_data = data[:n + 2]
test_data = data[n + 2:]


# x shape: Batch_size, Block_size, n_embed

In [27]:
class Head(nn.Module):
    def __init__(self, n_embed, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size)
        self.value = nn.Linear(n_embed, head_size)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size )))
        self.n_embed = n_embed
        self.head_size = head_size

    def forward(self, x):
        k = self.key(x) # B, T, head_size
        q = self.query(x) # B, T, head_size
        v = self.value(x) # B, T, head_size
        sa = (q @ k.transpose(-2, -1)) * self.head_size ** (-0.5) 
        sa = sa.masked_fill(self.tril==0, float("-inf"))
        x = F.softmax(sa, dim=-1) @ v
        return x



In [28]:
class MultiHead(nn.Module):
    def __init__(self, n_embed, head_size, n_head):
        super().__init__()
        self.head_size = head_size
        self.n_embed = n_embed
        self.heads = [Head(n_embed, head_size) for _ in range (n_head)]

    def forward(self, x):
        b, t, h = x.shape
        output = torch.empty((b, t, 0))
        for head in self.heads:
            output = torch.cat((output, head(x)), -1)
        return output
        

In [29]:
class FFN(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(n_embed, n_embed  * 4), nn.ReLU(), nn.Linear(n_embed * 4, n_embed))

    def forward(self, x):
        return self.ffn(x)
    

In [30]:
class Block(nn.Module):
    def __init__(self, n_embed, head_size, n_head):
        super().__init__()
        self.mh = MultiHead(n_embed, head_size, n_head)
        self.ln = nn.LayerNorm(n_embed)
        self.ffn = FFN(n_embed)

    def forward(self, x):
        x = x + self.mh(self.ln(x))
        x = x + self.ffn(self.ln(x))
        return x

In [31]:
class MiniLLM(nn.Module):
    def __init__(self, vocab_size, n_embed, head_size, n_head, block_iter):
        super().__init__()
        self.embedding_token = nn.Embedding(vocab_size, n_embed)
        self.embedding_pos = nn.Embedding(vocab_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, head_size, n_head) for _ in range(block_iter)])
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)
        self.embedding_token.weight = self.lm_head.weight
        self.block_iter = block_iter
        
        

    def forward(self, x, target=None):
        b, t = x.shape # B, T
        embed = self.embedding_token(x) #B, T, C
        pos = self.embedding_pos(torch.arange(t)) #C
        x = embed + pos #B, T, C
        x = self.blocks(x) #B, T, C
        if target is not None:
            logits = self.lm_head(x) #B, T, 1
            logits = logits.view(b*t, -1)
            target = target.view(b*t)
            loss = F.cross_entropy(logits, target)
        else :
            logits = self.lm_head(x)[:, [-1], :]#B, 1
            loss = None
        return logits, loss


    def generate(self, x, num_new_token=100, context_limit=8): 
        y = x
        for _ in range(num_new_token):
            logits, _ = self(x)
            logit = torch.argmax(logits, dim=-1)
            y = torch.cat((y, logit), 1)
            if x.shape[-1] < context_limit: 
                x = torch.cat((x, logit), 1)
            else: 
                x = torch.cat((x[:, 1:], logit), 1)
        return y
    
        
        
        
        
        

In [32]:
model = MiniLLM(vocab_size, n_embed, head_size, n_head, block_iter)

In [None]:
optimizer = optim.Adam(model.parameters())

In [40]:


num_epochs = 10000

ar = torch.arange(block_size)
br = torch.arange(1, block_size+1)
losses = []
for _ in range(num_epochs): 
    idx = torch.randint(len(train_data) - block_size - 1, (batch_size, 1))
    x = train_data[idx+ar]
    y = train_data[idx+br]
    logits, loss = model(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss)
    if _%50==0: 
        print(sum(losses)/len(losses), _)
        losses = []
b, t = x.shape

tensor(3.3457, grad_fn=<DivBackward0>) 0
tensor(3.3613, grad_fn=<DivBackward0>) 50
tensor(3.3707, grad_fn=<DivBackward0>) 100
tensor(3.3738, grad_fn=<DivBackward0>) 150
tensor(3.4087, grad_fn=<DivBackward0>) 200
tensor(3.3156, grad_fn=<DivBackward0>) 250
tensor(3.3274, grad_fn=<DivBackward0>) 300
tensor(3.2881, grad_fn=<DivBackward0>) 350
tensor(3.3368, grad_fn=<DivBackward0>) 400
tensor(3.3335, grad_fn=<DivBackward0>) 450
tensor(3.2873, grad_fn=<DivBackward0>) 500
tensor(3.1878, grad_fn=<DivBackward0>) 550
tensor(3.3186, grad_fn=<DivBackward0>) 600
tensor(3.2683, grad_fn=<DivBackward0>) 650
tensor(3.2914, grad_fn=<DivBackward0>) 700
tensor(3.2677, grad_fn=<DivBackward0>) 750
tensor(3.3069, grad_fn=<DivBackward0>) 800
tensor(3.2453, grad_fn=<DivBackward0>) 850
tensor(3.2582, grad_fn=<DivBackward0>) 900
tensor(3.2367, grad_fn=<DivBackward0>) 950
tensor(3.2502, grad_fn=<DivBackward0>) 1000
tensor(3.2176, grad_fn=<DivBackward0>) 1050
tensor(3.2319, grad_fn=<DivBackward0>) 1100
tensor(3.22

In [41]:
y = model.generate(x, num_new_token=500, context_limit=8)


In [42]:
text = tokenizer.decode(y[0,:])

In [43]:
text

":\nMarry, my child, and and my brother,\nThoughts for my poor lord, this this is the world goes,\nThe Clifford Clifford's rigour of the statute,\nTo make him the precedent and witness witness\nThe words that he hath hath fallen by prompt prompt me to the blood of your blood,\nCurrents from thence thence,\n\nSICINIUS:\nThe gods, for your voices!\n\nMENENIUS:\nThe gods have neither neither, my lord,\nBut repetition will what what thou go?\n\nMERCUTIO:\nNay, but he's dead; and he he knows not not,\nThe ape is dead, but that's my lord.\n\nKING RICHARD III:\nThe advancement of your children children, go we\nInjurious Margaret!\n\nMERCUTIO:\nNay, but he's dead; and he he knows not not,\nThe ape is dead, but that's my lord.\n\nKING RICHARD III:\nThe advancement of your children children, go we\nInjurious Margaret!\n\nMERCUTIO:\nNay, but he's dead; and he he knows not not,\nThe ape is dead, but that's my lord.\n\nKING RICHARD III:\nThe advancement of your children children, go we\nInjurious M