
ML – From Scratch to Llama2, Mistral and Phi-2 PyTorch Code

Background

RNN (1986), LSTM (1997), GRU (2014), Attention for Sequence Learning

Sequence learning is used to learn from sequence data such as text, audio and video. Sequence data is hard to model because of order dependency and variable length. Before Transformers, Recurrent Neural Networks (RNNs) were the state of the art (SOTA) in sequence learning.

RNNs vs Feed-Forward Networks (FNNs): [Ref1], [Ref2].

| | Feed Forward | Recurrent |
|---|---|---|
| Input Length | Fixed | Variable |
| Data Flow | One-way, top-down (IN -> hidden0 -> OUT) | Directed graph (feedback loops) |
| Generation | Parallel, many per iteration | Sequential, one per iteration |

RNNs are effective at short-term dependencies but struggle with long-term dependencies due to vanishing and exploding gradients.
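To make the vanishing-gradient problem concrete, here is a minimal sketch of a plain (Elman-style) RNN unrolled in PyTorch. The sizes, the contractive recurrent weights `W_hh = 0.5 * I` and the random inputs are illustrative assumptions, not values from any reference. Each step needs the previous hidden state, so inputs are consumed one per iteration, and the gradient that reaches the first input ends up many orders of magnitude smaller than the one reaching the last input.

```python
import torch

torch.manual_seed(0)
hidden_dim, seq_len = 8, 50

# Illustrative Elman RNN: h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t)
W_hh = 0.5 * torch.eye(hidden_dim)   # assumed contractive recurrent weights
W_xh = torch.eye(hidden_dim)

x = torch.randn(seq_len, hidden_dim, requires_grad=True)
h = torch.zeros(hidden_dim)
for t in range(seq_len):             # sequential: step t depends on h_{t-1}
    h = torch.tanh(W_hh @ h + W_xh @ x[t])

h.sum().backward()
grad_norms = x.grad.norm(dim=-1)     # influence of each input step on the final state
print(f"first step: {grad_norms[0].item():.2e}, last step: {grad_norms[-1].item():.2e}")
```

The gradient to `x[0]` passes through 49 multiplications by `0.5 * I` (and by tanh derivatives no larger than 1), so it shrinks roughly like `0.5 ** 49`.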

Long Short-Term Memory (LSTM), an RNN variant, mitigates this issue with memory cells and gating that retain information over much longer intervals, up to thousands of steps. [Ref].

Gated Recurrent Unit (GRU), another variant, simplifies the LSTM architecture by merging the forget and input gates into a single update gate. [Ref]
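The difference in gate count is visible directly in PyTorch's built-in layers: `nn.LSTM` stacks four gate blocks (input, forget, cell, output) in its weight matrices, while `nn.GRU` stacks three (reset, update, new). The sizes below are arbitrary; this is a shape sketch, not a training setup.

```python
import torch
import torch.nn as nn

input_size, hidden_size, seq_len, batch = 16, 32, 10, 2

lstm = nn.LSTM(input_size, hidden_size)   # default layout: [seq, batch, feature]
gru = nn.GRU(input_size, hidden_size)

def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# Each gate block holds W_ih [hidden, input], W_hh [hidden, hidden] and two biases
per_gate = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
print(lstm.weight_ih_l0.shape, gru.weight_ih_l0.shape)   # 4*32 vs 3*32 rows
print(num_params(lstm) == 4 * per_gate, num_params(gru) == 3 * per_gate)

x = torch.randn(seq_len, batch, input_size)
out_lstm, (h_n, c_n) = lstm(x)   # LSTM carries a separate memory cell c_n
out_gru, h_gru = gru(x)          # GRU keeps only the hidden state
print(out_lstm.shape, out_gru.shape)
```

So for the same hidden size a GRU layer has 3/4 of the LSTM's parameters and one fewer state tensor to carry between steps.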

However, both LSTM and GRU remain limited by sequence length, requiring large networks and considerable processing time to expand the dependency window. (TODO Add Ref)

Before the Transformer architecture, attention was explored as a way to improve dependency modeling in RNN encoder-decoder models. [Google's Neural Machine Translation, Ref]
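As an illustration of that RNN-era attention, here is a minimal sketch of additive (Bahdanau-style) attention: the decoder state scores every encoder hidden state through a small MLP, the scores are softmax-normalized, and the context vector is the weighted sum of encoder states. Dimensions and random weights are placeholders, not values from any paper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
src_len, enc_dim, dec_dim, attn_dim = 7, 16, 12, 10

encoder_states = torch.randn(src_len, enc_dim)  # h_1..h_n from an RNN encoder
decoder_state = torch.randn(dec_dim)            # s_{t-1} from the RNN decoder

W_k = nn.Linear(enc_dim, attn_dim, bias=False)
W_q = nn.Linear(dec_dim, attn_dim, bias=False)
v = nn.Linear(attn_dim, 1, bias=False)

# score_j = v^T tanh(W_q s + W_k h_j), one score per source position
scores = v(torch.tanh(W_q(decoder_state) + W_k(encoder_states))).squeeze(-1)  # [src_len]
weights = torch.softmax(scores, dim=-1)         # attention distribution over the source
context = weights @ encoder_states              # [enc_dim], fed into the next decoder step
print(weights.sum().item(), context.shape)
```

The Transformer keeps the "score, softmax, weighted sum" pattern but replaces the MLP score with a scaled dot product and drops the recurrence.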

Naive Transformer Model (2017) – Multi-Head Self-Attention & MLP/FFN

Attention Is All You Need Paper

Figure: vanilla to modern transformer

import math
import torch
import torch.nn as nn

# Transformer parameters
vocab_size = 32 * 1024  # 32K tokens in the vocabulary
embedding_dim = 4096    # embedding dimension (d_model)
max_seq_length = 2048   # maximum number of input tokens

Positional Encoding

def build_absolute_positional_encoding_naive() -> torch.Tensor:
    positions = torch.arange(max_seq_length, dtype=torch.float32)  # [0, 1, ..., max_seq_length-1]
    positions = positions.unsqueeze(1)                             # [max_seq_length, 1]
    dims = torch.arange(0, embedding_dim, 2, dtype=torch.float32)  # even indices 2i: [0, 2, ..., embedding_dim-2]

    # pos/10000^(2i/dmodel) -> [max_seq_length, embedding_dim/2]
    angle = positions / torch.pow(10000.0, dims / embedding_dim)

    # PE(pos, 2i+0) = sin( pos/10000^(2i/dmodel) )
    # PE(pos, 2i+1) = cos( pos/10000^(2i/dmodel) )  (same angle for each sin/cos pair)
    positional_encoding = torch.empty(max_seq_length, embedding_dim)
    positional_encoding[:, 0::2] = torch.sin(angle)
    positional_encoding[:, 1::2] = torch.cos(angle)
    return positional_encoding
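A quick sanity check of the sinusoidal encoding, as a self-contained sketch with small sizes (64 positions, 16 dimensions; float64 is used only to tighten tolerances). Position 0 encodes to alternating 0/1 values, and for every sin/cos pair the encoding at `pos + k` is a fixed rotation of the encoding at `pos`, which is the relative-position property the original paper gives as its motivation and which RoPE later builds on.

```python
import torch

def sinusoidal_pe(seq_len: int, dim: int):
    pos = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)   # [seq_len, 1]
    two_i = torch.arange(0, dim, 2, dtype=torch.float64)            # [dim/2]
    inv_freq = 1.0 / torch.pow(10000.0, two_i / dim)                # 1/10000^(2i/d)
    angle = pos * inv_freq                                          # [seq_len, dim/2]
    pe = torch.empty(seq_len, dim, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe, inv_freq

pe, inv_freq = sinusoidal_pe(seq_len=64, dim=16)

# PE(pos + k) as a rotation of PE(pos): the rotation depends only on the offset k
p, k = 10, 5
sin_p, cos_p = pe[p, 0::2], pe[p, 1::2]
cos_k, sin_k = torch.cos(k * inv_freq), torch.sin(k * inv_freq)
rotated_sin = sin_p * cos_k + cos_p * sin_k   # sin(a + b)
rotated_cos = cos_p * cos_k - sin_p * sin_k   # cos(a + b)
print(torch.allclose(rotated_sin, pe[p + k, 0::2]), torch.allclose(rotated_cos, pe[p + k, 1::2]))
```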

Multi-Head Attention

# q, k, v projection weights are learned during training
q = nn.Linear(embedding_dim, embedding_dim, bias=False)
k = nn.Linear(embedding_dim, embedding_dim, bias=False)
v = nn.Linear(embedding_dim, embedding_dim, bias=False)

def multi_head_attention_naive(input_embd: torch.Tensor,
                               num_heads: int = 32) -> torch.Tensor:
    # input_embd: [seq_length, embedding_dim] (no batching)
    seq_length = input_embd.size(0)
    head_length = embedding_dim // num_heads  # embedding_dim must be divisible by num_heads

    q1 = q(input_embd)
    k1 = k(input_embd)
    v1 = v(input_embd)

    # We work on per-head sequences (no batching for now)
    # Rearrange data as [num_heads, seq_length, head_length]
    q2 = q1.view(seq_length, num_heads, head_length).transpose(0, 1)
    k2 = k1.view(seq_length, num_heads, head_length).transpose(0, 1)
    v2 = v1.view(seq_length, num_heads, head_length).transpose(0, 1)

    # q * k_transposed / sqrt(dk) -> [num_heads, seq_length, seq_length]
    dk = head_length
    qk = q2.matmul(k2.transpose(1, 2)) / math.sqrt(dk)

    # Causal mask (decoder): token i may only attend to tokens 0..i
    causal_mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool,
                                        device=input_embd.device), diagonal=1)
    qk = qk.masked_fill(causal_mask, float("-inf"))

    # out = softmax(qk) * v (no out projection for now)
    mh_attn = nn.functional.softmax(qk, dim=-1)
    mh_attn_out = mh_attn.matmul(v2)

    # Rearrange data back as [seq_length, embedding_dim]
    mh_attn_out = mh_attn_out.transpose(0, 1).reshape(seq_length, embedding_dim)
    return mh_attn_out
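A decoder-only model needs causal masking in this attention (token i may only attend to tokens 0..i). The self-contained sketch below applies it manually with toy sizes (6 tokens, 32 dimensions, 4 heads) and random matrices standing in for the learned `q`, `k`, `v` layers, then checks the result against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+), which expects `[batch, num_heads, seq, head_dim]`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, embedding_dim, num_heads = 6, 32, 4
head_length = embedding_dim // num_heads

x = torch.randn(seq_length, embedding_dim)
# Random stand-ins for the learned projections, scaled to keep scores moderate
w_q, w_k, w_v = (torch.randn(embedding_dim, embedding_dim) / math.sqrt(embedding_dim)
                 for _ in range(3))

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # [seq, embd] -> [num_heads, seq, head_length]
    return t.view(seq_length, num_heads, head_length).transpose(0, 1)

q, k, v = split_heads(x @ w_q.T), split_heads(x @ w_k.T), split_heads(x @ w_v.T)

# Manual causal attention: scaled scores, mask future keys, softmax, weight values
scores = q @ k.transpose(1, 2) / math.sqrt(head_length)
mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
manual = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

# Fused kernel; unsqueeze(0) adds the batch dimension
fused = F.scaled_dot_product_attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0),
                                       is_causal=True).squeeze(0)
print(torch.allclose(manual, fused, atol=1e-5))
```

Because the first token can only see itself, its output is exactly its own value vector in every head.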

Normalization

# LayerNorm: normalizes each token to zero mean and unit variance, then applies learned weight and bias
post_attn_norm_layer = nn.LayerNorm(embedding_dim)  # learnable
post_ffn_norm_layer = nn.LayerNorm(embedding_dim)   # learnable

def post_attention_norm(input_embd: torch.Tensor) -> torch.Tensor:
    return post_attn_norm_layer(input_embd)

def post_feed_forward_norm(input_embd: torch.Tensor) -> torch.Tensor:
    return post_ffn_norm_layer(input_embd)
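What `nn.LayerNorm` computes can be written out by hand. This sketch (toy sizes, default `eps=1e-5`, and the biased variance PyTorch uses) normalizes each token vector over the embedding dimension, then applies the learned scale (`weight`) and shift (`bias`), which are randomized here so the comparison is not trivial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, embedding_dim, eps = 5, 64, 1e-5
x = torch.randn(seq_len, embedding_dim) * 3 + 2

layer_norm = nn.LayerNorm(embedding_dim, eps=eps)
with torch.no_grad():                        # randomize the learnable affine params
    layer_norm.weight.normal_()
    layer_norm.bias.normal_()

def layer_norm_manual(x: torch.Tensor) -> torch.Tensor:
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)   # biased (population) variance
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return x_hat * layer_norm.weight + layer_norm.bias

print(torch.allclose(layer_norm_manual(x), layer_norm(x), atol=1e-5))
```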

Feed Forward (aka MLP)

ffn_linear1 = nn.Linear(embedding_dim, 4 * embedding_dim)
ffn_linear2 = nn.Linear(4 * embedding_dim, embedding_dim)
ffn_act = nn.ReLU()
ffn_drop = nn.Dropout(0.0)  # Training-only regularization; p=0.0 disables it (the paper's base model used 0.1)

def feed_forward_naive(input_embd: torch.Tensor) -> torch.Tensor:
    hidden_states = ffn_linear1(input_embd)
    hidden_states = ffn_drop(ffn_act(hidden_states))
    hidden_states = ffn_linear2(hidden_states)
    return hidden_states
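For scale, the FFN's parameter count follows from its two `nn.Linear` layers: `d*4d + 4d` for the up-projection plus `4d*d + d` for the down-projection. A small sketch that checks this formula against real modules at a toy size, then evaluates it at `embedding_dim = 4096`:

```python
import torch.nn as nn

def ffn_params(d: int, expansion: int = 4) -> int:
    hidden = expansion * d
    return (d * hidden + hidden) + (hidden * d + d)   # weights + biases of both layers

def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

d_small = 64
toy_ffn = nn.Sequential(nn.Linear(d_small, 4 * d_small), nn.ReLU(), nn.Linear(4 * d_small, d_small))
print(count(toy_ffn) == ffn_params(d_small))   # True
print(f"{ffn_params(4096):,}")                 # 134,238,208 (~134M per layer)
```

For comparison, the three bias-free `q`, `k`, `v` projections above add `3 * 4096**2`, about 50M parameters per layer.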

Transformer Decoder Block (Everything Together Naive)

prompt = 'Which fruits do you like?'
input_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)  # remove batching
llm_model(input_ids)

# llm_model forward() (naive: recomputes the whole sequence every step, no KV cache)
self.vocab_eos_token_id = tokenizer.convert_tokens_to_ids('<|end_of_text|>')
self.positional_encoding = build_absolute_positional_encoding_naive()
has_finished = False

while not has_finished:
    hidden_states = self.vocab_embedding(input_ids)
    # [seq_len, embedding_dim] (no batching)
    seq_length = hidden_states.size(0)
    # Naive: absolute positional encoding is added to the input embeddings,
    # so it flows into Q, K and V alike
    hidden_states = hidden_states + self.positional_encoding[:seq_length]

    for layer in self.decoder:
        # Naive: num_heads is the same across Q, K, V; Post-LN (norm after the residual add)
        residual = hidden_states
        hidden_states = layer.multi_head_attention_naive(hidden_states)
        hidden_states = layer.post_attention_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = layer.feed_forward_naive(hidden_states)
        hidden_states = layer.post_feed_forward_norm(residual + hidden_states)

    logits = self.lm_head(hidden_states)  # [seq_len, vocab_size]

    # Greedy sampling over last/most-recent next-token
    next_token_id = torch.argmax(logits[-1, :])
    input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)])

    # Stop at the context limit or when the EOS token id is generated
    if len(input_ids) >= self.max_seq_length or next_token_id.item() == self.vocab_eos_token_id:
        has_finished = True
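A runnable toy version of the same greedy loop, as a sketch under stated assumptions: PyTorch's built-in `nn.TransformerEncoderLayer` (post-norm and ReLU by default) with a causal mask stands in for the hand-written decoder block, the weights are random, positional encoding is omitted for brevity, and token id 0 is treated as a hypothetical EOS. It only demonstrates the control flow: recompute, take the last position's logits, append, stop on EOS or length.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model, max_seq_length, eos_id = 100, 32, 16, 0   # eos_id is hypothetical

embed = nn.Embedding(vocab_size, d_model)
block = nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=4 * d_model,
                                   dropout=0.0, batch_first=True)  # post-norm, ReLU
lm_head = nn.Linear(d_model, vocab_size, bias=False)

@torch.no_grad()
def greedy_generate(input_ids: torch.Tensor) -> torch.Tensor:
    while True:
        seq_len = input_ids.size(0)
        hidden = embed(input_ids).unsqueeze(0)                   # [1, seq, d_model]
        causal = nn.Transformer.generate_square_subsequent_mask(seq_len)
        hidden = block(hidden, src_mask=causal)                  # full recompute, no KV cache
        logits = lm_head(hidden[0])                              # [seq, vocab]
        next_id = torch.argmax(logits[-1])                       # greedy: last position only
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)])
        if input_ids.size(0) >= max_seq_length or next_id.item() == eos_id:
            return input_ids

output_ids = greedy_generate(torch.tensor([5, 7, 9]))
print(output_ids.tolist())
```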

H2-2024 Modern Transformer Improvements

New approaches that I have not yet seen applied in popular open models:

Improvements observed in models like LLaMA, Mixtral, Gemma, Grok, and Phi:

  • Rotary Positional Encoding (RoPE), applied to Q and K only (1D for text, 2D for images)

    • Encodes relative position directly in the Q·K scores, which can improve generalization to sequences longer than those seen during training
  • RMS Normalization

    • Rescales by the root mean square without mean-centering or bias; cheaper than LayerNorm with comparable quality
  • Multi-Head Attention Improvements

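    • Exact attention, O(n^2) compute with memory-efficient kernels: FlashAttention 2 & 3, xFormers, PyTorch SDPA
    • Sub-quadratic alternatives: Reformer (O(n log n)), Linformer (O(n)), Mamba (O(n), a state-space model that replaces attention)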
  • Grouped Query Attention (e.g. 32x Q, 8x K/V heads)

    • Shrinks the KV cache (4x for 32 Q / 8 KV heads) and speeds up inference with little quality loss
  • Sliding context window

    • Enables longer context while capping attention cost
  • KV Cache with sliding window

    • Stores each past token's keys and values once, so each new token only computes its own Q, K, V
    • Significantly boosts performance at the cost of higher memory usage
  • Multi Layer Perceptron (MLP) Improvements

    • SwiGLU and GELU activations
      • Empirically outperform ReLU in LLM MLPs
    • Sparse Mixture of Experts (e.g. 8x MLPs, 2x selected per-token)
      • Increases model capacity while keeping per-token compute low, since only the two highest-scoring experts run for each token
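Two of the items above, RMS Normalization and SwiGLU, are small enough to sketch directly. `RMSNorm` rescales by the root mean square with a learned gain and no mean subtraction or bias; `SwiGLUMLP` gates the up-projection with `SiLU(x W_gate)` in a bias-free, Llama-style layout. Class names, sizes and `eps` are illustrative assumptions, not taken from a specific checkpoint; the hidden size 172 is roughly (8/3)·dim, a common choice for gated MLPs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # learned gain, no bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight                  # no mean-centering, unlike LayerNorm

class SwiGLUMLP(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: (SiLU(x W_gate) * x W_up) W_down
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

torch.manual_seed(0)
x = torch.randn(3, 64)
y = RMSNorm(64)(x)
print(y.pow(2).mean(dim=-1))          # ~1.0 per token (gain initialized to ones)
print(SwiGLUMLP(64, 172)(x).shape)    # torch.Size([3, 64])
```

Since there is no centering, RMSNorm is invariant to rescaling its input but not to shifting it, which is the trade-off against LayerNorm.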

Multi-Head Self-Attention (2024) SOTA Implementations

Efficient Attention Reference List

Research Implementations

Production Implementations (Jan 2024): Hugging Face Transformers, vLLM

import torch
import torch.nn.functional as F
import xformers.ops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# q_proj, k_proj, v_proj, out_proj: learned [embedding_dim, embedding_dim] weights
# Fused QKV weight (one matmul instead of three): [3 * embedding_dim, embedding_dim]
qkv_weight = torch.cat((q_proj, k_proj, v_proj), dim=0)

hidden_states = vocab_embedding(input_ids) # [batches, seq_length, embedding_dim]
batches = hidden_states.size(0)
seq_length = hidden_states.size(1)

# Shared Attention Preamble Processing
def mha_proj_and_pos_encode(hidden_states: torch.Tensor, position_ids: torch.Tensor,
                            num_heads: int = 32):
    head_size = embedding_dim // num_heads

    # QKV Projection
    qkv = F.linear(hidden_states, qkv_weight)
    query, key, value = qkv.split([embedding_dim, embedding_dim, embedding_dim], dim=-1)

    # Rotary Position Embedding (RoPE) on Q and K only (apply_rotary_pos_emb: helper not shown)
    query, key = apply_rotary_pos_emb(query, key, position_ids)

    # Note: Non KV-Cached Step

    # Multi-Head Reshape. [batch, seq, embd] -> [batch*seq, num_heads, head_size]
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_heads, head_size)
    value = value.view(-1, num_heads, head_size)
    return query, key, value

# XFormers Memory Efficient Attention
def mha_xformers(hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    query, key, value = mha_proj_and_pos_encode(hidden_states, position_ids)

    # One causal block per sequence along the diagonal (all seqs have same length due to padding)
    attn_bias = BlockDiagonalCausalMask.from_seqlens([seq_length] * batches)

    # xformers layout: [1, batch*seq, num_heads, head_size]
    mh_attn_out = xformers.ops.memory_efficient_attention_forward(
        query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), attn_bias=attn_bias,
    )

    mh_attn_out = mh_attn_out.reshape(batches, seq_length, embedding_dim)
    mh_attn_out = F.linear(mh_attn_out, out_proj)
    return mh_attn_out

# Torch SDPA Attention
def mha_sdpa(hidden_states: torch.Tensor, position_ids: torch.Tensor,
             num_heads: int = 32) -> torch.Tensor:
    query, key, value = mha_proj_and_pos_encode(hidden_states, position_ids, num_heads)

    # SDPA expects [batch, num_heads, seq, head_size]; is_causal masks within each sequence
    head_size = embedding_dim // num_heads
    query, key, value = (t.view(batches, seq_length, num_heads, head_size).transpose(1, 2)
                         for t in (query, key, value))

    mh_attn_out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=None, is_causal=True,
    )

    # [batch, num_heads, seq, head_size] -> [batch, seq, embedding_dim]
    mh_attn_out = mh_attn_out.transpose(1, 2).reshape(batches, seq_length, embedding_dim)
    mh_attn_out = F.linear(mh_attn_out, out_proj)
    return mh_attn_out
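`apply_rotary_pos_emb` is not defined above. One possible implementation is the hypothetical sketch below, modeled on the "rotate-half" convention used in Llama-style code (dimension `j` paired with `j + head_dim/2`). It expects per-head tensors `[batch, seq, num_heads, head_dim]`, whereas the code above calls the helper before the multi-head reshape, so treat the signature, shapes and `base=10000.0` as assumptions. The check at the end shows the key property: the rotated q·k score depends only on the offset between positions.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [..., head_dim] -> [-x2, x1], pairing dim j with dim j + head_dim/2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(query, key, position_ids, base: float = 10000.0):
    # query/key: [batch, seq, num_heads, head_dim]; position_ids: [batch, seq]
    head_dim = query.size(-1)
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
    angles = position_ids[..., None].float() * inv_freq          # [batch, seq, head_dim/2]
    angles = torch.cat((angles, angles), dim=-1)[:, :, None, :]  # [batch, seq, 1, head_dim]
    cos, sin = angles.cos(), angles.sin()
    # Rotate each (j, j + head_dim/2) pair by its position-dependent angle
    return query * cos + rotate_half(query) * sin, key * cos + rotate_half(key) * sin

torch.manual_seed(0)
q = torch.randn(1, 1, 1, 8)
k = torch.randn(1, 1, 1, 8)

def score(pos_q: int, pos_k: int) -> torch.Tensor:
    rq, _ = apply_rotary_pos_emb(q, q, torch.tensor([[pos_q]]))
    _, rk = apply_rotary_pos_emb(k, k, torch.tensor([[pos_k]]))
    return (rq * rk).sum()

print(torch.allclose(score(3, 1), score(12, 10), atol=1e-4))  # same offset of 2
```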

Test: 100 iterations of 32 attention layers (xavier_normal_ initialized)

  • Baseline (Naive)
  • ~2.5x - xformers speed-up
  • ~3.5x - sdpa speed-up

Multi-Modal LLMs References

  • Images are split into 16x16-pixel patches, one <img> token each; <img_break> marks the end of each patch row

  • VisionTransformer (Conv2D + Transformer)

    • Conv2D(3, 1024, 16, 16). Image → [1, patches, 1024]
    • Transformer 16 layers, 1k embd, 16 QKV heads, MLP 4k intermediate with SiLU
      • ~256M. 16x (12M MLP, 4M Attn, 2k RMSNorm)
      • 2D RoPE, No LM_Head
  • VisionLanguageAdapter (MLP)

    • up_proj [1k → 5k] bias=True, then GELU
    • out_proj [5k → 5k] bias=True
  • PixtralForCausalLM

    • Multi-modal embeddings: <img> tokens are replaced with VisionLanguageAdapter outputs
    • 128k Vocabulary, 5k Embedding dimension
    • Grouped Query Attention (32xQ, 8xKV)
    • MLP 14k Intermediate size
    • 40 layers
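The patch embedding step maps directly onto a strided convolution: `Conv2D(3, 1024, 16, 16)` with kernel size and stride 16 turns each non-overlapping 16x16 RGB patch into one 1024-dim vector. This is a shape-only sketch; the 64x48 input size is an arbitrary assumption, and the `<img_break>` tokens and 2D RoPE of the real encoder are omitted.

```python
import torch
import torch.nn as nn

patch_size, vision_dim = 16, 1024
patch_conv = nn.Conv2d(3, vision_dim, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, 3, 64, 48)               # [batch, RGB, height, width]
grid = patch_conv(image)                        # [1, 1024, 64/16, 48/16] = [1, 1024, 4, 3]
patches = grid.flatten(2).transpose(1, 2)       # [1, 4*3, 1024], row-major patch order
print(grid.shape, patches.shape)
```

A 4x3 patch grid yields 12 patch embeddings; with row breaks, the encoder would see 4 rows of 3 `<img>` positions.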

LLMs References

Phi 3 (May 2024, Microsoft)

Llama 3 (Apr 2024, Meta)

  • Llama 3-8B vs Llama 2-7B
    • Vocabulary Size ~128k (128,256) vs 32k
    • Context Length 8k vs 4k
    • Attention: Grouped Query (32xQ, 8xKV) vs Multi-Head (32 heads)
    • MLP up-proj size 14k vs ~11k
    • dtype bfloat16 vs float16
    • 32 layers in both
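Grouped Query Attention in a few lines, as a sketch with Llama 3-8B-like head counts (32 query heads, 8 KV heads, head_dim 128 = 4096/32) but a toy sequence length: each group of 4 query heads shares one K/V head, so only 8 heads' worth of K/V needs caching. Expanding K/V with `repeat_interleave` before SDPA is one simple approach chosen here, not necessarily what any particular library does.

```python
import torch
import torch.nn.functional as F

batch, seq_len, head_dim = 1, 10, 128
num_q_heads, num_kv_heads = 32, 8
group_size = num_q_heads // num_kv_heads          # 4 query heads per K/V head

query = torch.randn(batch, num_q_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)     # only these are cached
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Expand K/V so query head h uses K/V head h // group_size
key_exp = key.repeat_interleave(group_size, dim=1)     # [1, 32, 10, 128]
value_exp = value.repeat_interleave(group_size, dim=1)
out = F.scaled_dot_product_attention(query, key_exp, value_exp, is_causal=True)

# KV-cache bytes per token per layer in bfloat16 (2 bytes): K and V for each KV head
bytes_gqa = 2 * num_kv_heads * head_dim * 2    # 4 KiB
bytes_mha = 2 * num_q_heads * head_dim * 2     # 16 KiB (Llama 2-7B-style MHA)
print(out.shape, bytes_mha // bytes_gqa)
```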

Mamba (2023, Stanford)

https://github.com/havenhq/mamba-chat
https://arxiv.org/abs/2312.00752

Phi 1 & 2 (2023, Microsoft)

  • 1.3B (Phi-1) and 2.7B (Phi-2) parameters
  • TODO

Mistral-1 (2023, Mistral)

  • Sliding Window Attention (4096 tokens)
  • Sliding KV-Cache
    • KV-Cache Pre-Fill & Chunking
  • FFN with Sparse Mixture of Experts (Mixtral 8x7B)
    • 8x FFN/MLP Models, 2x Used Per-Token
  • SiLU (Sigmoid Linear Unit) gated MLP (SwiGLU) instead of ReLU
  • Others
    • Byte-fallback Tokenizer
    • Model Partition
    • Efficient Queries Batching/Bundling
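Sliding window attention and the rolling (sliding) KV-cache can be illustrated with a tiny window. In this sketch a `window` of 3 stands in for Mistral's 4096: token `i` attends only to tokens `i-window+1 .. i`, so the cache never needs more than `window` entries, and position `i` can be written to slot `i % window` of a fixed-size buffer. The cache stores strings as placeholders for K/V vectors; the function and class names are hypothetical.

```python
import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True = may attend. Key j is visible to query i iff i - window < j <= i
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

print(sliding_window_causal_mask(5, 3).int())

class RollingKVCache:
    """Fixed-size buffer: position pos is stored at slot pos % window."""
    def __init__(self, window: int):
        self.window = window
        self.slots = [None] * window

    def append(self, pos: int, kv) -> None:
        self.slots[pos % self.window] = kv        # overwrites the entry that left the window

    def visible(self, pos: int) -> list:
        start = max(0, pos - self.window + 1)
        return [self.slots[p % self.window] for p in range(start, pos + 1)]

cache = RollingKVCache(window=3)
for pos in range(6):
    cache.append(pos, f"kv{pos}")
print(cache.visible(5))   # ['kv3', 'kv4', 'kv5']
```

The boolean mask follows the `scaled_dot_product_attention` convention where `True` means "take part in attention".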

Llama-1 & 2 (2023, Meta)

  • Pre-normalization with RMSNorm (norm moved before attention/MLP instead of after)
  • Rotary Positional Encoding (Q, K only)
  • Grouped Query Attention on Llama 2 34B/70B
  • SwiGLU instead of ReLU

S4 (2021, Stanford)

Switch Transformer (2021, Google)

  • Mixture of Experts on FFN
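A minimal sparse MoE layer, as a sketch: a linear router scores every expert per token, only the `top_k` highest-scoring experts run, and their outputs are mixed using a softmax over the selected scores (one common choice). With `top_k=1` this resembles Switch Transformer routing; with 8 experts and `top_k=2` it matches the "8x FFN, 2x per token" setup listed under Mistral. Load-balancing losses, capacity limits and efficient batched dispatch are deliberately left out; the class name and sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim))
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, dim]
        logits = self.router(x)                              # [num_tokens, num_experts]
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)  # [num_tokens, top_k]
        weights = F.softmax(top_vals, dim=-1)                # mix only the chosen experts
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            token_ids, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue                                     # expert unused for these tokens
            w = weights[token_ids, slot].unsqueeze(-1)       # [n, 1]
            out[token_ids] += w * expert(x[token_ids])
        return out

torch.manual_seed(0)
moe = SparseMoE(dim=16, hidden_dim=32, num_experts=8, top_k=2)
tokens = torch.randn(5, 16)
print(moe(tokens).shape)   # torch.Size([5, 16])
```

Only 2 of the 8 expert MLPs run per token, so compute per token is about a quarter of running them all while total parameters stay 8x.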

GPT-2 & 3 (2019~2020, OpenAI)

  • ~1.5B and ~175B parameters respectively

Transformer-XL (2019, Google)

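  • Segment-level recurrence: caches the previous segment's hidden states for reuse (a precursor of the KV-Cache)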

GPT Model (2018, OpenAI)

  • ~120M parameters
  • Decoder Only

TODO Good Images Ref