<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/LLM_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# --- 1. Tokenization ---
# A very simple tokenizer for demonstration purposes.
# In a real-world scenario, this would be a more complex BPE or WordPiece tokenizer.
class SimpleTokenizer:
    """Whitespace tokenizer over a fixed vocabulary list.

    Deliberately minimal for demonstration purposes; a real system would use
    a subword scheme such as BPE or WordPiece.

    Args:
        vocab: Ordered sequence of token strings. A token's position in
            ``vocab`` is its integer ID. If a token appears more than once,
            ``encode`` maps it to its last position.
    """

    def __init__(self, vocab):
        # Keep both directions of the token <-> ID mapping.
        self.vocab = vocab
        self.token_to_id = {token: i for i, token in enumerate(vocab)}
        self.id_to_token = {i: token for i, token in enumerate(vocab)}

    def encode(self, text):
        """Split ``text`` on whitespace and map each token to its integer ID.

        Out-of-vocabulary tokens map to the ID of ``'<unk>'``.

        Raises:
            KeyError: if ``text`` contains an out-of-vocabulary token and the
                vocabulary has no ``'<unk>'`` entry.
        """
        ids = []
        for token in text.split():
            token_id = self.token_to_id.get(token)
            if token_id is None:
                # Resolve '<unk>' only when actually needed, so a vocabulary
                # without '<unk>' can still encode fully in-vocabulary text.
                token_id = self.token_to_id["<unk>"]
            ids.append(token_id)
        return ids

    def decode(self, ids):
        """Map token IDs back to tokens joined by single spaces.

        IDs not present in the vocabulary decode to ``'<unk>'``.
        """
        return " ".join(self.id_to_token.get(token_id, "<unk>") for token_id in ids)

# --- 2. Token and Positional Embedding ---
# This class combines token embeddings and positional embeddings.
class TokenAndPositionalEmbedding(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        vocab_size: Number of distinct token IDs.
        d_model: Embedding width.
        max_len: Longest sequence that can be embedded; one learned vector is
            allocated for each position ``0 .. max_len - 1``.
    """

    def __init__(self, vocab_size, d_model, max_len=512):
        super().__init__()
        self.max_len = max_len
        # Lookup table: token ID -> d_model vector.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Lookup table: position index -> d_model vector. This is what gives
        # the model information about where each token sits in the sequence.
        self.positional_embedding = nn.Embedding(max_len, d_model)

    def forward(self, token_ids):
        """Embed a batch of token IDs.

        Args:
            token_ids: Integer tensor; dimension 1 is treated as the sequence
                axis (assumed layout ``(batch, seq_len)``).

        Returns:
            ``token_embedding(token_ids)`` plus the per-position embeddings,
            which broadcast across the leading (batch) dimension.

        Raises:
            IndexError: if the sequence length exceeds ``max_len``.
        """
        seq_len = token_ids.size(1)
        if seq_len > self.max_len:
            # Fail with an actionable message instead of an opaque embedding
            # index error (or a device-side assert on GPU).
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        token_embs = self.token_embedding(token_ids)

        # Positions 0 .. seq_len - 1, created on the same device as the input.
        position_ids = torch.arange(seq_len, device=token_ids.device)
        position_embs = self.positional_embedding(position_ids)

        # Element-wise sum keeps the width at d_model (concatenation would
        # widen it).
        return token_embs + position_embs

# --- 3. Multi-Head Attention with Dropout ---
class MultiHeadAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    Each position attends only to itself and earlier positions. Dropout is
    applied to the attention probabilities before they weight the values.

    Args:
        d_model: Model width; must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability for the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super().__init__()
        if d_model % num_heads:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model, self.num_heads = d_model, num_heads
        self.head_dim = d_model // num_heads

        # Registration order fixes parameter-initialisation order and the
        # state_dict / repr layout, so it is kept stable.
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.W_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def _split_heads(self, projected, batch, length):
        # (batch, length, d_model) -> (batch, num_heads, length, head_dim)
        return projected.view(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, length, width = x.size()

        queries = self._split_heads(self.Wq(x), batch, length)
        keys = self._split_heads(self.Wk(x), batch, length)
        values = self._split_heads(self.Wv(x), batch, length)

        # Scaled dot-product scores: (batch, heads, length, length).
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # True strictly above the diagonal marks "future" keys to hide.
        future = torch.ones(length, length, device=x.device).triu(diagonal=1).bool()
        scores = scores.masked_fill(future, float('-inf'))

        weights = self.dropout(torch.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, values)

        # Re-join heads: (batch, heads, length, head_dim) -> (batch, length, width).
        merged = mixed.transpose(1, 2).contiguous().view(batch, length, width)
        return self.W_out(merged)

# --- 4. Feed-Forward Network with GELU and Dropout ---
class FeedForward(nn.Module):
    """Position-wise MLP: expand, GELU, dropout, then project back.

    Args:
        d_model: Input and output width.
        ff_dim: Hidden width of the expansion layer.
        dropout_rate: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, ff_dim, dropout_rate=0.1):
        super().__init__()
        # Registration order fixes parameter-initialisation order and the
        # state_dict / repr layout, so it is kept stable.
        self.linear1 = nn.Linear(d_model, ff_dim)   # d_model -> ff_dim
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)
        self.linear2 = nn.Linear(ff_dim, d_model)   # ff_dim -> d_model

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.dropout(self.gelu(self.linear1(x)))
        return self.linear2(hidden)

# --- 5. The Transformer Block with Residual Connections ---
class TransformerBlock(nn.Module):
    """One pre-norm Transformer decoder block.

    Computes::

        x = x + Dropout(Attention(LayerNorm1(x)))
        x = x + Dropout(FeedForward(LayerNorm2(x)))

    LayerNorm is applied to the *input* of each sub-layer (pre-norm), not
    after the residual addition.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads (must divide ``d_model``).
        dropout_rate: Dropout probability passed to the attention and
            feed-forward sub-layers and used on both residual branches.
        ff_dim: Hidden width of the feed-forward network. Defaults to
            ``4 * d_model``, the conventional expansion factor.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1, ff_dim=None):
        super().__init__()
        if ff_dim is None:
            ff_dim = d_model * 4

        self.attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        # One LayerNorm per sub-layer, applied to that sub-layer's input.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.feed_forward = FeedForward(d_model, ff_dim, dropout_rate)

        # Dropout on each sub-layer output before it is added to the residual.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply attention then feed-forward, each wrapped in a residual."""
        # Attention sub-layer: normalise first, then add back onto the
        # un-normalised residual stream.
        attention_output = self.attention(self.norm1(x))
        x = x + self.dropout1(attention_output)

        # Feed-forward sub-layer, same pre-norm residual pattern.
        ff_output = self.feed_forward(self.norm2(x))
        x = x + self.dropout2(ff_output)

        return x

# --- 6. Putting it all together: A more complete Simple LLM Model ---
class SimpleLLM(nn.Module):
    """Small decoder-only language model.

    Pipeline: token + positional embeddings -> ``num_layers`` Transformer
    blocks -> final LayerNorm -> linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token IDs (size of the output logits).
        d_model: Model width.
        num_heads: Attention heads per block (must divide ``d_model``).
        num_layers: Number of stacked Transformer blocks.
        max_len: Longest supported sequence (size of the positional table).
        dropout_rate: Dropout probability used inside every Transformer
            block. Defaults to 0.1, the value previously hard-coded.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_len=512, dropout_rate=0.1):
        super().__init__()
        self.embedding = TokenAndPositionalEmbedding(vocab_size, d_model, max_len)
        # Stack of identical (but independently parameterised) blocks.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, dropout_rate) for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.linear_output = nn.Linear(d_model, vocab_size)

    def forward(self, token_ids):
        """Return unnormalised logits over the vocabulary for every position.

        For ``token_ids`` laid out as ``(batch, seq_len)`` the result has
        shape ``(batch, seq_len, vocab_size)``.
        """
        x = self.embedding(token_ids)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.final_norm(x)
        return self.linear_output(x)

def get_loss(logits, targets, pad_token_id):
    """Mean cross-entropy over every target position that is not padding.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary; all
            leading dimensions are flattened together.
        targets: Integer class IDs whose shape matches ``logits`` minus the
            last dimension.
        pad_token_id: Target value to exclude from the loss.

    Returns:
        Scalar loss tensor.
    """
    num_classes = logits.size(-1)
    return nn.functional.cross_entropy(
        logits.reshape(-1, num_classes),
        targets.reshape(-1),
        ignore_index=pad_token_id,
    )

# --- 7. Demonstration ---
if __name__ == "__main__":
    # Toy vocabulary: "<unk>" is the fallback for unknown words and "<pad>" is
    # the padding token that the loss ignores.
    vocab = ["<unk>", "<pad>", "hello", "world", "this", "is", "a", "test", "llm", "demo", "training", "works"]
    pad_token_id = vocab.index("<pad>")
    vocab_size = len(vocab)
    d_model = 16
    num_heads = 4
    num_layers = 2

    tokenizer = SimpleTokenizer(vocab)
    model = SimpleLLM(vocab_size, d_model, num_heads, num_layers)

    # A single short sentence serves as both the prediction example and the
    # (tiny) training set.
    input_text = "hello world this is a test llm demo training works"
    token_ids = tokenizer.encode(input_text)
    input_tensor = torch.tensor([token_ids])  # batch of one sequence

    print("--- Model Structure ---")
    print(model)
    print("\n--- Input and Tokenization ---")
    print(f"Input text: '{input_text}'")
    print(f"Token IDs: {input_tensor}")

    # Inference must run in eval mode: freshly constructed modules start in
    # training mode, where dropout would randomly perturb the outputs.
    model.eval()
    with torch.no_grad():
        logits = model(input_tensor)

    print("\n--- Model Output (Logits) ---")
    print(f"Logits shape: {logits.shape}")

    print("\n--- Next-Word Prediction Demo ---")
    # Position 0 holds 'hello'; its logits score the token that follows it.
    prediction_logits = logits[0, 0]
    predicted_token_id = torch.argmax(prediction_logits).item()
    predicted_token = tokenizer.decode([predicted_token_id])

    print("\nPredicting next token after 'hello'...")
    print(f"Predicted token ID: {predicted_token_id}")
    print(f"Predicted token: '{predicted_token}'")

    print("\n--- Training Loop Demonstration ---")

    # The training batch is the same single sequence; targets are the same
    # tokens, shifted by one position inside next_token_loss.
    input_batch = input_tensor
    targets_batch = input_batch.clone()

    def next_token_loss(batch_logits):
        # Logits at position t are scored against the token at t + 1, so the
        # last logit position and the first target position are dropped.
        return get_loss(batch_logits[:, :-1, :], targets_batch[:, 1:], pad_token_id)

    optimizer = optim.AdamW(model.parameters(), lr=0.001)

    # Early stopping: stop once the training loss has failed to improve by at
    # least min_delta for `patience` consecutive epochs.
    num_epochs = 100
    patience = 5
    min_delta = 1e-4
    best_loss = float('inf')
    epochs_no_improve = 0
    epochs_run = 0

    # Baseline loss in eval mode (no dropout) for comparison after training.
    model.eval()
    with torch.no_grad():
        initial_logits = model(input_batch)
        initial_loss = next_token_loss(initial_logits)
    print(f"Initial Loss (before training): {initial_loss.item():.4f}")

    model.train()  # Enable dropout for the optimisation steps.
    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        loss = next_token_loss(model(input_batch))

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Early stopping is judged on this step's (dropout-noisy) training
        # loss, measured before the parameter update.
        loss_value = loss.item()
        if loss_value < best_loss - min_delta:
            best_loss = loss_value
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epochs_run} epochs.")
                break

    # Back to eval mode so the final loss and prediction are deterministic
    # and directly comparable with the initial (eval-mode) loss.
    model.eval()
    with torch.no_grad():
        final_logits = model(input_batch)
        final_loss = next_token_loss(final_logits)

    print(f"Final Loss (after {epochs_run} epochs): {final_loss.item():.4f}")

    # --- Next-Word Prediction after training ---
    print("\n--- Next-Word Prediction Demo after Training ---")
    with torch.no_grad():
        logits_after_training = model(input_batch)
        # Position 0 holds 'hello'; in the training text the next token is
        # 'world', which is what these logits were trained to predict.
        prediction_logits_after_training = logits_after_training[0, 0]
        predicted_token_id_after_training = torch.argmax(prediction_logits_after_training).item()
        predicted_token_after_training = tokenizer.decode([predicted_token_id_after_training])

    print(f"Predicted token ID after training: {predicted_token_id_after_training}")
    print(f"Predicted token after training: '{predicted_token_after_training}'")

    print("\nThis concludes the training loop demonstration with loss improvement and a final prediction.")

--- Model Structure ---
SimpleLLM(
  (embedding): TokenAndPositionalEmbedding(
    (token_embedding): Embedding(12, 16)
    (positional_embedding): Embedding(512, 16)
  )
  (transformer_blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attention): MultiHeadAttention(
        (Wq): Linear(in_features=16, out_features=16, bias=True)
        (Wk): Linear(in_features=16, out_features=16, bias=True)
        (Wv): Linear(in_features=16, out_features=16, bias=True)
        (W_out): Linear(in_features=16, out_features=16, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear1): Linear(in_features=16, out_features=64, bias=True)
        (gelu): GELU(approximate='none')
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=16, bias=True)
