<a href="https://colab.research.google.com/github/chen-star/llm_model_trainings/blob/main/3_4_transformer_impl_full_transformer_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1> ⭐ Transformer Decoder ⭐ </h1>

# ✈ Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import override

# 🔢 Hyperparameters

In [2]:
# Pick the compute device: the first CUDA GPU when one is present,
# otherwise the CPU. The bare name on the last line makes Jupyter show it.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')
device

device(type='cpu')

In [3]:
# Use the same hyperparameters as GPT2-124M (batch_size is our own choice).
batch_size = 8

num_transformer_blocks = 12

embedding_dimension = 768
num_heads = 12 # embedding_dimension must be divisible by num_heads
# Enforce the constraint above here, instead of failing later with an
# opaque shape error inside the attention layer.
if num_heads <= 0 or embedding_dimension % num_heads != 0:
  raise ValueError(
      f"embedding_dimension ({embedding_dimension}) must be divisible by "
      f"num_heads ({num_heads})"
  )

context_window_size = 1024
vocabulary_size = 50257

# [1] 🏚 Model Impl

## (1.1) 👓 Multi-head Attention

In [4]:
class MultiHeadAttention(nn.Module):
  def __init__(self, embedding_dimension, num_heads):
    super().__init__()

    # define W_Q, W_K, W_V
    self.q_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.k_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)
    self.v_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)

    # define W0
    self.w0_layer = nn.Linear(embedding_dimension, embedding_dimension, bias=False)

    # ***** multi-head *****
    self.num_heads = num_heads
    self.head_dimension = embedding_dimension // num_heads
    # *****************************


  @override
  def forward(self, X):
    batch_size, context_window_size, embedding_dimension = X.shape

    # Q = XW_Q
    # K = XW_K
    # V = XW_V
    Q = self.q_layer(X)
    K = self.k_layer(X)
    V = self.v_layer(X)

    # ***** Split Q,K,V *****
    Q = Q.view(batch_size, context_window_size, self.num_heads, self.head_dimension)
    K = K.view(batch_size, context_window_size, self.num_heads, self.head_dimension)
    V = V.view(batch_size, context_window_size, self.num_heads, self.head_dimension)

    # For attention score calculation, pytorch expects the shape to be
    # [batch_size, num_heads, context_window_size, head_dimension]
    Q = Q.transpose(1,2)
    K = K.transpose(1,2)
    V = V.transpose(1,2)
    # *****************************

    attention_score = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

    # Transpose back
    attention_score = attention_score.transpose(1,2)

    # ***** Merge heads *****
    attention_score = attention_score.reshape(batch_size, context_window_size, embedding_dimension)
    # *****************************

    return self.w0_layer(attention_score)

## (1.2) 🏃 Single MLP

In [5]:
class MLP(nn.Module):
  def __init__(self, embedding_dimension, expansion: int=4):
    super().__init__()

    # define W1, Gelu, W2
    self.w1_layer = nn.Linear(embedding_dimension, expansion * embedding_dimension) # 4x expansion
    self.gelu = nn.GELU()
    self.w2_layer = nn.Linear(expansion * embedding_dimension, embedding_dimension) # 4x contraction


  @override
  def forward(self, X):
    W1 = self.w1_layer(X)
    GELU = self.gelu(W1)
    W2 = self.w2_layer(GELU)

    return W2

## (1.3) 🔲 Transformer Block

In [6]:
class TransformerBlock(nn.Module):
  """Pre-LayerNorm decoder block: causal self-attention, then an MLP.

  forward(X) computes
    X = X + attention(LayerNorm(X))
    X = X + mlp(LayerNorm(X))
  so the output has the same shape as X.

  Args:
    embedding_dimension: width of the residual stream.
    num_heads: number of attention heads. Optional for backward
      compatibility: when omitted, the notebook-level ``num_heads``
      hyperparameter is used (the previous, implicit behaviour).
  """
  def __init__(self, embedding_dimension, num_heads=None):
    super().__init__()

    if num_heads is None:
      # The parameter shadows the global of the same name, so look the
      # hyperparameter up explicitly in the module namespace.
      num_heads = globals()["num_heads"]

    # Attention
    self.layerNorm_attention = nn.LayerNorm(embedding_dimension)
    self.attention_heads = MultiHeadAttention(embedding_dimension, num_heads)

    # MLP / FeedForward
    self.layerNorm_mlp = nn.LayerNorm(embedding_dimension)
    self.mlp = MLP(embedding_dimension)


  @override
  def forward(self, X):
    # --- Attention ---
    # X -> layerNorm -> attention_head
    #                                     +   = output
    #                                X
    X = X + self.attention_heads(self.layerNorm_attention(X))

    # --- MLP ---
    # X -> layerNorm -> mlp
    #                         +   = output
    #                     X
    X = X + self.mlp(self.layerNorm_mlp(X))

    return X

## (1.4) 🏢 Model

In [7]:
class LanguageModel(nn.Module):
  """Decoder-only language model built from the notebook hyperparameters.

  Token embedding + learned position embedding -> ``num_transformer_blocks``
  TransformerBlocks -> final LayerNorm -> unembedding to vocabulary logits.
  The unembedding weight is tied to the token-embedding weight.

  Args:
    device: stored as ``self.device`` for backward compatibility. forward()
      builds its position ids on the device of the input tokens, so the model
      keeps working after being moved with ``.to(...)``.
  """
  def __init__(self, device):
    super().__init__()

    self.device = device

    # ----- Token Embedding + Position Encoding -----
    self.wte = nn.Embedding(vocabulary_size, embedding_dimension) # token embedding
    self.wpe = nn.Embedding(context_window_size, embedding_dimension) # position encoding

    # ----- Transformer Blocks -----
    self.transformer_blocks = nn.Sequential(*[
          TransformerBlock(embedding_dimension) for _ in range(num_transformer_blocks)
        ])

    # ----- Final layernorm -----
    self.final_layernorm = nn.LayerNorm(embedding_dimension)

    # ----- Unembedding -----
    self.unembedding = nn.Linear(embedding_dimension, vocabulary_size, bias=False)
    # Tie the weights by sharing the *same* Parameter object. Wrapping it in
    # nn.Parameter(...) would create a distinct Parameter that gets its own
    # gradient and is counted and optimized separately from the embedding.
    self.unembedding.weight = self.wte.weight


  @override
  def forward(self, token_ids):
    """Map token ids [batch_size, seq_len] to logits [batch_size, seq_len, vocab].

    Raises:
      ValueError: if seq_len exceeds the position-embedding table size.
    """
    seq_len = token_ids.shape[-1]
    max_len = self.wpe.num_embeddings
    if seq_len > max_len:
      # An out-of-range embedding index is an opaque (device-side) error on
      # GPU; report the real problem instead.
      raise ValueError(f"sequence length {seq_len} exceeds context window {max_len}")

    # ----- Token Embedding + Position Encoding -----
    # [batch_size, seq_len, embedding_dimension]
    token_embedding = self.wte(token_ids)
    # [seq_len, embedding_dimension]; ids live on the input's device
    positions = torch.arange(seq_len, device=token_ids.device)
    position_encoding = self.wpe(positions)
    # broadcast over the batch: [batch_size, seq_len, embedding_dimension]
    X = token_embedding + position_encoding

    # ----- Transformer Blocks -----
    X = self.transformer_blocks(X)

    # ----- Final layernorm -----
    X = self.final_layernorm(X)

    # ----- Unembedding -----
    # [batch_size, seq_len, vocab_size]
    logits = self.unembedding(X)

    return logits


  @torch.no_grad()
  def generate(self, token_ids, temperature=1.1, num_new_tokens=10):
    """Sample ``num_new_tokens`` tokens autoregressively and append them.

    Args:
      token_ids: integer tensor [batch_size, seq_len] used as the prompt.
      temperature: softmax temperature (> 0); higher means more random.
      num_new_tokens: number of tokens to sample.

    Returns:
      Tensor [batch_size, seq_len + num_new_tokens] holding the prompt
      followed by the sampled tokens.

    Raises:
      ValueError: if temperature is not positive.
    """
    if temperature <= 0:
      raise ValueError(f"temperature must be > 0, got {temperature}")

    max_len = self.wpe.num_embeddings
    for _ in range(num_new_tokens):
      # forward on the most recent tokens that fit the position table
      # [batch_size, <=max_len, vocab_size] -> last position: [batch_size, vocab_size]
      logits = self(token_ids[:, -max_len:])[:, -1, :]

      # softmax with temperature: [batch_size, vocab_size]
      probabilities = F.softmax(logits / temperature, dim=-1)

      # sample: [batch_size, 1]
      next_token_id = torch.multinomial(probabilities, num_samples=1)

      # append
      token_ids = torch.cat((token_ids, next_token_id), dim=1)

    # Return only after the loop so all num_new_tokens tokens are appended.
    return token_ids

## (1.5) 🧪 Random Data Test

In [8]:
# Build the model from the notebook hyperparameters and move its parameters
# to `device`. The bare `model` on the last line makes Jupyter display the
# module tree.
model = LanguageModel(device).to(device)
model

LanguageModel(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (layerNorm_attention): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention_heads): MultiHeadAttention(
        (q_layer): Linear(in_features=768, out_features=768, bias=False)
        (k_layer): Linear(in_features=768, out_features=768, bias=False)
        (v_layer): Linear(in_features=768, out_features=768, bias=False)
        (w0_layer): Linear(in_features=768, out_features=768, bias=False)
      )
      (layerNorm_mlp): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (w1_layer): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GELU(approximate='none')
        (w2_layer): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
    (1): TransformerBlock(
      (layerNorm_attention): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention_heads):

In [9]:
# Smoke test: report the configuration, then push one batch of random token
# ids through the model and print the input/output shapes.
print(
  f"batch_size: {batch_size}\n"
  f"context_window_size: {context_window_size}\n"
  f"embedding_dimension: {embedding_dimension}\n"
  f"num_heads: {num_heads}\n"
  f"head_dimension: {embedding_dimension // num_heads}\n"
)

random_token_ids = (
  torch.randint(0, vocabulary_size, (batch_size, context_window_size))
  .to(device)
)
output = model(random_token_ids)

print(f"Input shape: {random_token_ids.shape}")
print(f"Output shape: {output.shape}")

batch_size: 8
context_window_size: 1024
embedding_dimension: 768
num_heads: 12
head_dimension: 64

Input shape: torch.Size([8, 1024])
Output shape: torch.Size([8, 1024, 50257])


# [2] 🗺 Compare with GPT2 Model

In [10]:
from transformers import AutoModelForCausalLM,GPT2Tokenizer

!pip install torchinfo # not installed by default in colab
from torchinfo import summary

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [11]:
# Load the pretrained Hugging Face "gpt2" checkpoint for a side-by-side
# comparison with our model (downloads config and weights on first use).
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [13]:
# Layer-by-layer summary of our model for one batch of token ids
# (dtype torch.long, shape [batch_size, context_window_size]).
summary(model, input_size=(batch_size, context_window_size), dtypes=[torch.long])

Layer (type:depth-idx)                   Output Shape              Param #
LanguageModel                            [8, 1024, 50257]          --
├─Embedding: 1-1                         [8, 1024, 768]            38,597,376
├─Embedding: 1-2                         [1024, 768]               786,432
├─Sequential: 1-3                        [8, 1024, 768]            --
│    └─TransformerBlock: 2-1             [8, 1024, 768]            --
│    │    └─LayerNorm: 3-1               [8, 1024, 768]            1,536
│    │    └─MultiHeadAttention: 3-2      [8, 1024, 768]            2,359,296
│    │    └─LayerNorm: 3-3               [8, 1024, 768]            1,536
│    │    └─MLP: 3-4                     [8, 1024, 768]            4,722,432
│    └─TransformerBlock: 2-2             [8, 1024, 768]            --
│    │    └─LayerNorm: 3-5               [8, 1024, 768]            1,536
│    │    └─MultiHeadAttention: 3-6      [8, 1024, 768]            2,359,296
│    │    └─LayerNorm: 3-7               [

In [14]:
# Same summary for the pretrained Hugging Face GPT-2, using the same input
# shape and dtype as above so the two tables can be compared directly.
summary(gpt2, input_size=(batch_size, context_window_size), dtypes=[torch.long])

Layer (type:depth-idx)                             Output Shape              Param #
GPT2LMHeadModel                                    --                        --
├─GPT2Model: 1-1                                   --                        --
│    └─Embedding: 2-1                              [8, 1024, 768]            38,597,376
│    └─Embedding: 2-2                              [1, 1024, 768]            786,432
│    └─Dropout: 2-3                                [8, 1024, 768]            --
│    └─ModuleList: 2-4                             --                        --
│    │    └─GPT2Block: 3-1                         [8, 1024, 768]            7,087,872
│    │    └─GPT2Block: 3-2                         [8, 1024, 768]            7,087,872
│    │    └─GPT2Block: 3-3                         [8, 1024, 768]            7,087,872
│    │    └─GPT2Block: 3-4                         [8, 1024, 768]            7,087,872
│    │    └─GPT2Block: 3-5                         [8, 1024, 768]         