<a href="https://colab.research.google.com/github/niyathikukkapalli/mini-llm/blob/main/Shakespeare_stuff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch # imports pytorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention.

    The query/key/value projections are split into ``num_heads`` heads of width
    ``d_model // num_heads``. Each position attends only to itself and earlier
    positions (lower-triangular mask). The heads are then concatenated and passed
    through the output projection ``Wo``.

    Args:
        num_heads: Number of attention heads. Must be positive and divide ``d_model``.
        d_model: Width of the input and output features.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        # Fail at construction with a clear message instead of an opaque reshape
        # error on the first forward pass.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)
        self.num_heads = num_heads

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        batch_size, seq_len, d_model = x.shape
        num_heads = self.num_heads
        d_head = d_model // num_heads

        def split_heads(t):
            # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_head)
            return t.reshape(batch_size, seq_len, num_heads, d_head).transpose(1, 2)

        Q = split_heads(self.Wq(x))
        K = split_heads(self.Wk(x))
        V = split_heads(self.Wv(x))

        # True where query position i may attend to key position j (j <= i).
        causal_mask = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool).tril()

        attn_scores = Q @ K.transpose(-2, -1) / d_head**0.5
        # Use the most negative finite value of the score dtype instead of a fixed
        # -1e9, which overflows float16 and makes masked_fill raise under half precision.
        attn_scores = attn_scores.masked_fill(~causal_mask, torch.finfo(attn_scores.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = attn_weights @ V  # (batch_size, num_heads, seq_len, d_head)

        # Merge heads back: -> (batch_size, seq_len, d_model)
        context = context.transpose(1, 2).reshape(batch_size, seq_len, d_model)
        return self.Wo(context)



class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is four times wider than ``d_model``. The same weights are
    applied independently to every position.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = 4 * d_model
        self.W1 = nn.Linear(d_model, hidden_dim)  # expand
        self.W2 = nn.Linear(hidden_dim, d_model)  # project back to model width

    def forward(self, x):
        """Map ``(..., d_model)`` features to ``(..., d_model)`` features."""
        return self.W2(F.gelu(self.W1(x)))

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer layer.

    Applies causal self-attention, then a feed-forward network. Each sub-layer
    normalizes its input and adds its output back onto the residual stream.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        # Sub-modules are registered in this order: ln1, mhsa, ln2, ffn.
        self.ln1 = nn.LayerNorm(d_model)
        self.mhsa = MultiHeadSelfAttention(num_heads, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        """Return ``x`` refined by attention and feed-forward residual updates."""
        attn_out = self.mhsa(self.ln1(x))
        residual = x + attn_out
        ffn_out = self.ffn(self.ln2(residual))
        return residual + ffn_out

class Transformer(nn.Module):
    """Decoder-only language model built from causal ``TransformerBlock`` layers.

    The model sums token and learned positional embeddings, and passes the
    result through ``n_layers`` blocks and a final LayerNorm. It then projects
    to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding / hidden width.
        n_layers: Number of ``TransformerBlock`` layers.
        num_heads: Attention heads per block.
        seq_len: Maximum input length (number of learned positions).
    """

    def __init__(self, vocab_size, d_model, n_layers, num_heads, seq_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(seq_len, d_model)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(num_heads, d_model) for _ in range(n_layers)]
        )
        self.ln = nn.LayerNorm(d_model)
        self.out_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer token ids of shape ``(batch_size, length)``, where
                ``length`` is at most the ``seq_len`` given at construction.

        Returns:
            Logits of shape ``(batch_size, length, vocab_size)``.

        Raises:
            ValueError: If ``length`` exceeds the positional embedding table size.
        """
        length = x.shape[1]
        max_len = self.pos_embedding.num_embeddings
        # Out-of-range position ids would otherwise surface as an opaque IndexError
        # on CPU or as a device-side assert on CUDA.
        if length > max_len:
            raise ValueError(f"sequence length {length} exceeds maximum {max_len}")
        positions = torch.arange(length, device=x.device, dtype=torch.long)
        h = self.embedding(x) + self.pos_embedding(positions)  # (batch_size, length, d_model)
        for block in self.transformer_blocks:
            h = block(h)
        return self.out_proj(self.ln(h))

In [None]:
from torch.utils.data import Dataset, DataLoader

# Hyperparameters.
d_model = 96
n_layers = 4
num_heads = 4
seq_len = 32  # context window; also the number of learned positional embeddings
batch_size = 16
lr = 1e-4
num_epochs = 20
device = "cuda" if torch.cuda.is_available() else "cpu"

with open("/content/sample_data/shakespeare.txt", encoding="utf-8") as f:
    text_corpus = f.read()

# Character-level vocabulary: every distinct character in the corpus, in sorted order.
char2token = {c: i for i, c in enumerate(sorted(set(text_corpus)))}
token2char = {i: c for c, i in char2token.items()}
# Derive the vocabulary size from the corpus instead of hard-coding it, so the
# embedding and output layers always match the tokenizer (any corpus whose
# character set differs from the hard-coded 65 would otherwise break).
vocab_size = len(char2token)

def tokenize(string):
    """Encode ``string`` as a list of integer token ids, one per character.

    Raises KeyError for any character missing from ``char2token``.
    """
    lookup = char2token
    return list(map(lookup.__getitem__, string))
def decode(tokens):
    """Map token ids back to characters via ``token2char`` and join them into a string."""
    chars = (token2char[t] for t in tokens)
    return "".join(chars)

text_tokens = tokenize(text_corpus)

# Next-character prediction: the target for position i is the token at i + 1.
inputs, labels = text_tokens[:-1], text_tokens[1:]
# Drop the tail so the stream divides into whole rows of seq_len tokens.
trunc_len = len(inputs) // seq_len * seq_len
labels = torch.tensor(labels[:trunc_len]).reshape(-1, seq_len)  # (num_rows, seq_len)
inputs = torch.tensor(inputs[:trunc_len]).reshape(-1, seq_len)  # (num_rows, seq_len)

class ShakespeareDataset(Dataset):
    """Map-style dataset that pairs each row of ``inputs`` with the same row of ``labels``.

    Indexing returns the tuple ``(inputs[idx], labels[idx])``. The length is the
    length of ``inputs``.
    """

    def __init__(self, inputs, labels):
        super().__init__()
        self.inputs, self.labels = inputs, labels

    def __len__(self):
        # One sample per input row.
        return len(self.inputs)

    def __getitem__(self, idx):
        x_row = self.inputs[idx]
        y_row = self.labels[idx]
        return x_row, y_row

# Wrap the (inputs, labels) row tensors so DataLoader can batch and shuffle them.
dataset = ShakespeareDataset(inputs, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer and loss function.
model = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layers=n_layers,
    num_heads=num_heads,
    seq_len=seq_len,
).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

@torch.no_grad()
def generate(prompt, num_tokens):
    """Sample ``num_tokens`` characters from the global ``model`` after ``prompt``.

    Only the last ``seq_len`` tokens are fed to the model at each step, because
    the positional embedding table has ``seq_len`` rows. The model is put in
    eval mode while sampling and its previous train/eval mode is restored afterwards.

    Args:
        prompt: Non-empty seed string. Every character must be in the vocabulary.
        num_tokens: Number of characters to sample and append.

    Returns:
        The prompt followed by the sampled continuation, as a string.

    Raises:
        ValueError: If ``prompt`` is empty (there is no position to predict from).
        KeyError: If ``prompt`` contains a character outside the vocabulary.
    """
    if not prompt:
        raise ValueError("prompt must contain at least one character")
    # Explicit long dtype: token ids index an embedding table.
    tokens = torch.tensor([tokenize(prompt)], dtype=torch.long, device=device)  # (1, len(prompt))
    was_training = model.training
    model.eval()
    try:
        for _ in range(num_tokens):
            logits = model(tokens[:, -seq_len:])  # (1, <=seq_len, vocab_size)
            # Only the last position's distribution is needed for the next token.
            probs = F.softmax(logits[:, -1, :], dim=-1)  # (1, vocab_size)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
            tokens = torch.cat([tokens, next_token], dim=-1)
    finally:
        model.train(was_training)
    return decode(tokens[0].cpu().tolist())


# Explicitly enter train mode before optimizing. This is a no-op for the model
# defined above, which has no dropout/batchnorm, but it keeps the loop correct if
# such layers are added.
model.train()
for epoch in range(num_epochs):
    for batch_idx, (x_batch, y_batch) in enumerate(dataloader):
        if batch_idx % 500 == 0:
            # Periodically print a 1000-character sample to gauge progress.
            print("\n" + generate(prompt="\n", num_tokens=1000) + "\n")
        # Batch tensors use their own names so the module-level `inputs`/`labels`
        # dataset tensors are not overwritten by the last batch.
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        logits = model(x_batch)  # (batch_size, seq_len) -> (batch_size, seq_len, vocab_size)
        # CrossEntropyLoss expects (N, C) logits and (N,) targets: flatten batch and time.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch}, batch {batch_idx}: loss={loss.item()}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m



DUKE OF ARH:
Our, but heaves, befole house?

HENRY:
Nay, Proclediess' no death ble! leef; Sorter.

LUCIO:
Govice an what me?

OKENE:
But of God TyMEan:
Richard life, Parition in!

SIe 'Tis are sluck.

HROMER:
Ayoring find fortck, beye broved jealies
For the vain of my friend; preset'
And of that of your been no chince,
This both within thy father streath!

Servant Lord, that, I two with I rep,
HoLe matters I till coes of sprose affalled vioullit him sweet st the kneebugh
I will men: at shall feel to mysent
Where any will her propleing
And the ablack mader'd but I be hope turt entil thy caush chasbenger uple a dive
To were'! thou had be then pours--

GLUCKENt:
To have:
I bein me now, all; I what worth he dew.

TRANxtranced her I have not all kins, far
all stripproud that fights pind
Teccard appartience.

SICINIU:
But my god came man piration.

First Might, can him man thy cast of things;
Of, those heards he may casingtl