<a href="https://colab.research.google.com/github/niral28/TransformersPuzzles/blob/main/GPTFromScratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [2]:
! wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-12-27 21:34:23--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-12-27 21:34:23 (135 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7df1a0120370>

In [4]:
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [5]:
def get_batch(split: str, context_length:int, batch_size: int, device:str):
  data = train_data if split == 'train' else val_data
  starting_idx = torch.randint(0, len(data)-context_length, (batch_size,))
  x = torch.stack([data[idx:idx+context_length] for idx in starting_idx])
  y = torch.stack([data[idx+1:idx+1+context_length] for idx in starting_idx])
  x, y = x.to(device), y.to(device)
  return x,y

In [38]:
dropout = 0.2
class SelfAttention(nn.Module):
  def __init__(self, embedding_dim:int, attention_dim: int):
    super().__init__()
    self.key_gen = nn.Linear(embedding_dim, attention_dim, bias=False) # Embedding Dim x Attention Dim
    self.query_gen = nn.Linear(embedding_dim, attention_dim, bias=False)
    self.value_gen = nn.Linear(embedding_dim, attention_dim, bias=False)
    self.dropout = nn.Dropout(dropout)

  def forward(self, embedded):
    K = self.key_gen(embedded) # Batch Size X Context Length X Embedding
    Q = self.query_gen(embedded)
    V = self.value_gen(embedded)
    B,T,C = embedded.shape
    scores = Q @ torch.transpose(K, 1, 2) # Batch Size X Context Length X Context Length
    # context_length, attention_dim = K.shape[1], K.shape[2] # Assuming Embedding Dim and Attention Dim are the same
    scores =  scores * (C**-0.5)
    lower_triangular = torch.tril(torch.ones(T, T)) # T x T
    lower_triangular = lower_triangular.to(device)
    scores = scores.masked_fill(lower_triangular[:T, :T] == 0, float('-inf'))
    scores = nn.functional.softmax(scores, dim=-1)
    scores = self.dropout(scores)
    return scores @ V # Tx A matrix

class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads: int, embed_size:int, attention_dim: int):
    super().__init__()
    self.head_size = attention_dim // num_heads
    self.heads = nn.ModuleList([SelfAttention(embed_size, self.head_size) for i in range(num_heads)])
    self.proj = nn.Linear(embed_size, embed_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, embedded):
    # Merge all the attention heads
    out = torch.cat([self.heads[i](embedded) for i in range(len(self.heads))], dim=-1)
    out = self.dropout(self.proj(out))
    return out


class FeedForwardNetwork(nn.Module):
  def __init__(self, embed_size, forward_expansion):
    super().__init__()
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size, forward_expansion*embed_size),
        nn.ReLU(),
        nn.Linear(forward_expansion*embed_size, embed_size)
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, embedding):
    return self.dropout(self.feed_forward(embedding))


class Transformer(nn.Module):
  def __init__(self, model_dim: int, num_heads:int):
    super().__init__()
    self.masked_multihead_attn = MultiHeadAttention(num_heads, model_dim, model_dim)
    # The normal multihead attention (un masked) is only used in the case of an encoder-decoder architecture
    #self.multihead_attn = self.MultiHeadAttention(num_heads, model_dim, model_dim, mask=False)
    self.norm1 = nn.LayerNorm(model_dim)
    #self.norm2 = nn.LayerNorm(model_dim)
    self.norm3 = nn.LayerNorm(model_dim)
    self.feed_forward = FeedForwardNetwork(model_dim, 4)

  def forward(self, embedding):
    embedded = embedding + self.masked_multihead_attn(self.norm1(embedding))
    #embedded = embedded + self.multihead_attn(self.norm2(embedded))
    embedded = embedded + self.feed_forward(self.norm3(embedded))
    return embedded


class Decoder(nn.Module):
  def __init__(self, vocab_size: int, context_length:int, model_dim: int, num_heads: int, num_blocks: int):
    super().__init__()
    self.context_length = context_length
    self.word_embedding = nn.Embedding(vocab_size, model_dim)
    self.positional_embedding = nn.Embedding(context_length, model_dim)
    self.transformer_block = nn.Sequential(*[
        Transformer(num_heads=num_heads, model_dim=model_dim) for _ in range(num_blocks)
    ])
    self.linear = nn.Linear(model_dim, vocab_size)
    self.finalLayerNorm = nn.LayerNorm(model_dim) # this is recommended by researchers as it results in a more stable training

  def forward(self, context, targets=None):
    T = context.shape[1] # B x T x 1 (tokenized input)
    positions = torch.arange(T, device=device) # create a list from 0 -> context length
    embedding = self.word_embedding(context)+self.positional_embedding(positions)
    output = self.transformer_block(embedding)
    logits = self.linear(self.finalLayerNorm(output))
    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      try:
        loss = nn.functional.cross_entropy(logits.view(B*T, C), targets.view(B*T))
      except Exception as e:
        # print(logits.shape)
        # print(logits)
        raise e
    return logits, loss

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, context_length, batch_size, device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [44]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 200
context_length = 64
batch_size = 16
model_dim = 128
num_heads = 8
num_blocks = 8

model = Decoder(context_length=context_length, vocab_size=len(stoi), model_dim=model_dim, num_heads=num_heads, num_blocks=num_blocks)
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

1.608257 M parameters


In [45]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(m)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name} grad norm: {param.grad.norm()}")
        print()
    # sample a batch of data
    xb, yb = get_batch('train', context_length, batch_size, device)

    # evaluate the loss
    logits, loss = model(xb, yb)
    # print(loss)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # for param in model.parameters():
    #   print(param.grad)
    optimizer.step()


step 0: train loss 4.3821, val loss 4.3845

step 100: train loss 2.5277, val loss 2.5241

step 200: train loss 2.4419, val loss 2.4385

step 300: train loss 2.3572, val loss 2.3688

step 400: train loss 2.2714, val loss 2.2875

step 500: train loss 2.1925, val loss 2.2176

step 600: train loss 2.1276, val loss 2.1609

step 700: train loss 2.0725, val loss 2.1163

step 800: train loss 2.0085, val loss 2.0629

step 900: train loss 1.9676, val loss 2.0328

step 1000: train loss 1.9120, val loss 2.0100

step 1100: train loss 1.8871, val loss 1.9677

step 1200: train loss 1.8527, val loss 1.9527

step 1300: train loss 1.8329, val loss 1.9496

step 1400: train loss 1.7998, val loss 1.9266

step 1500: train loss 1.7791, val loss 1.9050

step 1600: train loss 1.7666, val loss 1.9083

step 1700: train loss 1.7391, val loss 1.8806

step 1800: train loss 1.7314, val loss 1.8711

step 1900: train loss 1.7041, val loss 1.8469

step 2000: train loss 1.6967, val loss 1.8357

step 2100: train loss 1.6

In [46]:
def generate(m, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -context_length:]
        # get the predictions
        logits, loss = m(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)
        # apply softmax to get probabilities
        probs = nn.functional.softmax(logits, dim=-1) # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx


In [47]:
def generate(model, context, max_new_tokens):
  generator = torch.manual_seed(1337)
  initial_state = generator.get_state()
  res = ['']
  for _ in range(max_new_tokens):
    context = context[:, -context_length:]
    logits, loss = model(context) # B x T x Vocab Size
    probabilities = nn.functional.softmax(logits[:, -1, :], dim=-1)
    next_char = torch.multinomial(probabilities, num_samples=1)
    context = torch.cat((context, next_char), dim=-1)
    res.append(itos[next_char.item()])
  return ''.join(res)

In [48]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(generate(m,context, max_new_tokens=2000))


SICINIUS:
Me if enriarth, her well
Toice what I wonvions that I mothing distrared,
That are Cluition, my Volverian's
Putheace? is exadement have made to is my read
As four should to the sest.
The bist of heaven here, being, and not of yours.

AMILYCUS:
I'll do this may derevis
Exetter thyself to-mance summedis wait his despince wit.

VIReIENIA:
Would not replaint, full him him
On.
Pray, in lePt, with the mell was anot!

COMILLONIUS:
What's it ned thou marche dove;
What then on true least up this name he of the braign.
Where will not night stip of bo the bear me twine.

BENVOLIOLIO:
I tage! ime it men
Mean of one see in thousabn: way, I tender this?
Henry, nor curse, my law will.

AUTOMPELY:
So mereiIGt shame berace's in switne hand
Onless of your of surse!---'

RIVERS:
I wound in like tempt that; whwat, and hath infixe you
Do swant a was in sins the kind's fail'd?

Second Cabtit's-grain, loves,
O, but for help's; he will I there to my have
To rirslime many have that falmh thand knowr


In [10]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(generate(m, context, max_new_tokens=2000)[0].tolist()))