Sequence learning is the task of learning from sequence data such as text, audio and video. Sequence data is hard to model because of order dependencies and variable length. Recurrent Neural Networks (RNNs) were the state of the art (SOTA) in sequence learning.
RNNs vs Feed-Forward Networks (FNNs): [Ref1], [Ref2].
| | Feed Forward | Recurrent |
|---|---|---|
| Input Length | Fixed | Variable |
| Data Flow | One-way, top-down (IN -> hidden0 -> OUT) | Directed graph (feedback loops) |
| Generation | Parallel, many per iteration | Sequential, one per iteration |
RNNs are effective at short-term dependencies but struggle with long-term dependencies due to vanishing and exploding gradients.
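As a rough illustration (a scalar simplification, not a full RNN), backpropagating through T time steps multiplies the gradient by the recurrent factor once per step. A factor slightly below 1 shrinks the gradient toward zero; a factor slightly above 1 blows it up:

```python
def gradient_scale(recurrent_factor: float, steps: int) -> float:
    """Scale of a gradient after backpropagating through `steps` time steps.

    Simplification: a scalar linear RNN h_t = w * h_{t-1} + x_t, for which
    dh_T / dh_0 = w ** T. Real RNNs multiply Jacobian matrices instead.
    """
    scale = 1.0
    for _ in range(steps):
        scale *= recurrent_factor  # chain rule: one factor per time step
    return scale


for w in (0.9, 1.0, 1.1):
    print(f"w={w}: after 10 steps {gradient_scale(w, 10):.3g}, "
          f"after 100 steps {gradient_scale(w, 100):.3g}")
```

With w = 0.9 the signal from 100 steps back is scaled by roughly 2.7e-5 (vanishing); with w = 1.1 it is scaled by roughly 1.4e4 (exploding).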
Long Short-Term Memory (LSTM), an RNN variant, mitigates this issue with memory cells and gating that allow information to be retained over much longer intervals, up to thousands of steps. [Ref].
Gated Recurrent Unit (GRU), another variant, simplifies the LSTM architecture by combining the forget and input gates into a single update gate. [Ref]
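A minimal scalar sketch of one GRU step (hidden size 1; the weights are hypothetical toy values, not trained ones) shows how the single update gate `z` does the work of the LSTM's forget and input gates: it interpolates between keeping the old state and writing the new candidate. It follows the original GRU paper's form h_t = z * h_{t-1} + (1 - z) * candidate; some texts swap the roles of z and 1 - z.

```python
import math
from dataclasses import dataclass


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


@dataclass
class ScalarGRU:
    # Hypothetical toy weights (hidden size 1), for illustration only
    w_z: float = 1.0
    u_z: float = 1.0
    b_z: float = 0.0
    w_r: float = 1.0
    u_r: float = 1.0
    b_r: float = 0.0
    w_h: float = 1.0
    u_h: float = 1.0
    b_h: float = 0.0

    def step(self, x: float, h: float) -> float:
        z = sigmoid(self.w_z * x + self.u_z * h + self.b_z)  # update gate
        r = sigmoid(self.w_r * x + self.u_r * h + self.b_r)  # reset gate
        # Candidate state: the reset gate controls how much of h is used
        h_candidate = math.tanh(self.w_h * x + self.u_h * (r * h) + self.b_h)
        # One gate both "forgets" (z * h) and "inputs" ((1 - z) * candidate)
        return z * h + (1.0 - z) * h_candidate


gru = ScalarGRU()
h = 0.0
for x in [0.5, -1.0, 2.0]:
    h = gru.step(x, h)
    print(f"x={x:+.1f} -> h={h:.4f}")
```

Pushing `b_z` strongly positive makes the cell copy its previous state unchanged, which is how gated units carry information across many steps.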
However, both LSTM and GRU remain limited on long sequences: they process tokens one step at a time, and widening the dependency window requires large networks and considerable processing time. (TODO Add Ref)
Before the Transformer architecture, attention was explored as a way to improve the modeling of dependencies in RNNs. [Google's Neural Machine Translation, Ref]
Attention Is All You Need Paper
```python
import torch

# Transformer parameters
vocab_size = 32 * 1024   # 32K token embeddings in the vocabulary
embedding_dim = 4096     # embedding dimension (d_model)
max_seq_length = 2048    # maximum number of input tokens


def build_absolute_positional_encoding_naive() -> torch.Tensor:
    positions = torch.arange(max_seq_length)   # [0, 1, 2, ..., max_seq_length - 1]
    positions = positions.unsqueeze(1)         # shape [max_seq_length, 1]
    two_i = torch.arange(0, embedding_dim, 2)  # even indices [0, 2, ..., embedding_dim - 2]
    # pos / 10000^(2i / d_model), shape [max_seq_length, embedding_dim / 2]
    angle = positions / torch.pow(10000, two_i / embedding_dim)
    # PE(pos, 2i + 0) = sin(pos / 10000^(2i / d_model))
    # PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))
    positional_encoding = torch.empty(max_seq_length, embedding_dim)
    positional_encoding[:, 0::2] = torch.sin(angle)
    positional_encoding[:, 1::2] = torch.cos(angle)
    return positional_encoding
```
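The paper motivates the sinusoids by noting that, for a fixed offset k, PE(pos + k) is a linear function of PE(pos). A small pure-Python check (independent of the PyTorch code, with a toy d_model) shows why: each (sin, cos) pair at pos + k is the pair at pos rotated by the angle k * w_i, where w_i = 1 / 10000^(2i / d_model).

```python
import math


def sinusoidal_pe(pos: int, d_model: int) -> list[float]:
    """PE(pos, 2i) = sin(pos * w_i), PE(pos, 2i + 1) = cos(pos * w_i)."""
    pe = []
    for two_i in range(0, d_model, 2):
        w_i = 1.0 / (10000 ** (two_i / d_model))
        pe += [math.sin(pos * w_i), math.cos(pos * w_i)]
    return pe


def shift_pe(pe: list[float], k: int, d_model: int) -> list[float]:
    """Map PE(pos) to PE(pos + k) with a fixed rotation per (sin, cos) pair."""
    shifted = []
    for two_i in range(0, d_model, 2):
        w_i = 1.0 / (10000 ** (two_i / d_model))
        s, c = pe[two_i], pe[two_i + 1]
        # sin(a + b) = sin a cos b + cos a sin b
        # cos(a + b) = cos a cos b - sin a sin b
        shifted += [s * math.cos(k * w_i) + c * math.sin(k * w_i),
                    c * math.cos(k * w_i) - s * math.sin(k * w_i)]
    return shifted


d = 8
direct = sinusoidal_pe(5 + 3, d)
rotated = shift_pe(sinusoidal_pe(5, d), 3, d)
print(max(abs(a - b) for a, b in zip(direct, rotated)))  # only floating-point rounding remains
```

The rotation depends only on k, not on pos, which is the same idea rotary embeddings later apply directly to Q and K.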
```python
import math

import torch
from torch import nn

# q, k, v projection layers (weights are learned during training)
q = nn.Linear(embedding_dim, embedding_dim, bias=False)
k = nn.Linear(embedding_dim, embedding_dim, bias=False)
v = nn.Linear(embedding_dim, embedding_dim, bias=False)


def multi_head_attention_naive(input_embd: torch.Tensor,
                               num_heads: int = 32) -> torch.Tensor:
    # input_embd: [seq_length, embedding_dim] (no batching)
    seq_length = input_embd.size(0)
    head_length = embedding_dim // num_heads
    q1 = q(input_embd)
    k1 = k(input_embd)
    v1 = v(input_embd)
    # Work on per-head sequences: [num_heads, seq_length, head_length]
    q2 = q1.view(seq_length, num_heads, head_length).transpose(0, 1)
    k2 = k1.view(seq_length, num_heads, head_length).transpose(0, 1)
    v2 = v1.view(seq_length, num_heads, head_length).transpose(0, 1)
    # q * k_transposed / sqrt(dk) -> [num_heads, seq_length, seq_length]
    dk = head_length
    qk = q2.matmul(k2.transpose(1, 2)) / math.sqrt(dk)
    # Causal mask (decoder): token i may only attend to tokens 0..i
    causal_mask = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool,
                                        device=input_embd.device), diagonal=1)
    qk = qk.masked_fill(causal_mask, float("-inf"))
    # out = softmax(qk) * v (no output projection for now)
    mh_attn = nn.functional.softmax(qk, dim=-1)
    mh_attn_out = mh_attn.matmul(v2)
    # Rearrange back to [seq_length, embedding_dim]
    mh_attn_out = mh_attn_out.transpose(0, 1).reshape(seq_length, embedding_dim)
    return mh_attn_out
```
```python
# LayerNorm: normalizes each embedding to zero mean and unit variance,
# then applies a learned scale (weight) and shift (bias)
post_attn_norm_layer = nn.LayerNorm(embedding_dim)  # learnable
post_ffn_norm_layer = nn.LayerNorm(embedding_dim)   # learnable


def post_attention_norm(input_embd: torch.Tensor) -> torch.Tensor:
    return post_attn_norm_layer(input_embd)


def post_feed_forward_norm(input_embd: torch.Tensor) -> torch.Tensor:
    return post_ffn_norm_layer(input_embd)
```
```python
# Position-wise feed-forward network: expand to 4x, non-linearity, project back
ffn_linear1 = nn.Linear(embedding_dim, 4 * embedding_dim)
ffn_linear2 = nn.Linear(4 * embedding_dim, embedding_dim)
ffn_act = nn.ReLU()
ffn_drop = nn.Dropout(0.0)  # Training-only: zeroes a fraction of activations to reduce overfitting


def feed_forward_naive(input_embd: torch.Tensor) -> torch.Tensor:
    hidden_states = ffn_linear1(input_embd)
    hidden_states = ffn_drop(ffn_act(hidden_states))
    hidden_states = ffn_linear2(hidden_states)
    return hidden_states
```
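With the parameters above, the weight count of one naive layer follows directly from the layer shapes. A short sketch of that arithmetic, counting only the q/k/v projections, the two LayerNorms and the FFN defined here (there is no output projection yet):

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    # nn.Linear stores an [out, in] weight matrix plus an optional [out] bias
    return in_features * out_features + (out_features if bias else 0)


def naive_layer_params(d_model: int) -> dict[str, int]:
    counts = {
        "qkv": 3 * linear_params(d_model, d_model, bias=False),
        "ffn": linear_params(d_model, 4 * d_model) + linear_params(4 * d_model, d_model),
        "layer_norms": 2 * (2 * d_model),  # weight + bias per LayerNorm
    }
    counts["total"] = sum(counts.values())
    return counts


embedding_dim = 4096
vocab_size = 32 * 1024
print(naive_layer_params(embedding_dim))
print("token embedding:", vocab_size * embedding_dim)
```

For d_model = 4096 this gives about 50.3M weights for q/k/v, 134.2M for the FFN and 184.6M per layer, so the FFN holds roughly 73% of each layer's weights.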
```python
prompt = 'Which fruits do you like?'
input_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)  # remove batching
llm_model(input_ids)

# llm_model forward()
# Compare token ids, not strings: next_token_id.item() is an int
self.vocab_eos_token_id = tokenizer.convert_tokens_to_ids('<|end_of_text|>')
self.positional_encoding = build_absolute_positional_encoding_naive()
has_finished = False
while not has_finished:
    # No KV cache: the whole sequence is re-processed for every new token
    hidden_states = self.vocab_embedding(input_ids)
    # [seq_length, hidden_dim] (no batching)
    seq_length = hidden_states.size(0)
    # Naive: the positional encoding is added to the input embeddings,
    # so it reaches Q, K and V alike
    hidden_states = hidden_states + self.positional_encoding[:seq_length]
    for layer in self.decoder:
        # Naive: Q, K and V all use the same number of heads
        residual = hidden_states
        hidden_states = layer.multi_head_attention_naive(hidden_states)
        hidden_states = layer.post_attention_norm(residual + hidden_states)
        residual = hidden_states
        hidden_states = layer.feed_forward_naive(hidden_states)
        hidden_states = layer.post_feed_forward_norm(residual + hidden_states)
    logits = self.lm_head(hidden_states)
    # Greedy sampling: pick the most likely next token at the last position
    next_token_id = torch.argmax(logits[-1, :])
    input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)])
    if len(input_ids) >= self.max_seq_length or next_token_id.item() == self.vocab_eos_token_id:
        has_finished = True
```
New approaches that I have not yet seen applied in popular open models:
Improvements observed in models like LLaMA, Mixtral, Gemma, Grok, and Phi:
- Rotary Positional Encoding, applied to Q and K only (1D for text, 2D for images)
  - Improves generalization to sequences longer than those seen during training
- RMS Normalization
  - Normalization without mean centering (and without bias), potentially enhancing performance in deeper models
- Multi-Head Attention Improvements
  - O(n^2): FlashAttention 2 & 3, xFormers, SDPA
  - Sub-quadratic: Mamba and Linformer (O(n)), Reformer (O(n log n))
- Grouped Query Attention (e.g. 32x Q, 8x K/V heads)
  - Shrinks the KV cache (4x smaller with 32 Q / 8 K/V heads) and speeds up inference
- Sliding context window
  - Each token attends only to the previous W tokens, enabling longer context while capping attention cost
- KV Cache with sliding window
  - Computes and caches each token's keys and values once, instead of recomputing them for every new token
  - Significantly boosts performance at the cost of higher memory usage (capped by the window size)
- Multi Layer Perceptron (MLP) Improvements
  - SwiGLU and GELU activations
    - Outperform traditional ReLU
  - Sparse Mixture of Experts (e.g. 8x MLPs, 2x selected per token)
    - Enhances task-specific performance while reducing computational cost by activating only the two most relevant experts per token
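A minimal pure-Python sketch of rotary embeddings, assuming the interleaved pairing of dimensions (2i, 2i + 1) and base 10000 (some implementations, such as the Llama code in Hugging Face Transformers, pair dimension i with i + d/2 instead, which is equivalent up to a permutation). Q and K are rotated by an angle proportional to their position, so their dot product depends only on the relative offset m - n; that is why RoPE is applied to Q and K but not V. The query and key vectors are arbitrary example values.

```python
import math


def apply_rope(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate each (x[2i], x[2i + 1]) pair by the angle pos * theta_i."""
    d = len(x)
    out = []
    for two_i in range(0, d, 2):
        theta_i = base ** (-two_i / d)
        cos_a, sin_a = math.cos(pos * theta_i), math.sin(pos * theta_i)
        x0, x1 = x[two_i], x[two_i + 1]
        out += [x0 * cos_a - x1 * sin_a, x0 * sin_a + x1 * cos_a]
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


query = [0.1, 0.4, -0.3, 0.8, 0.5, -0.2, 0.7, 0.05]
key = [0.3, -0.6, 0.2, 0.1, -0.4, 0.9, 0.0, 0.25]
# Same relative offset (m - n = 3) at different absolute positions
score_a = dot(apply_rope(query, 10), apply_rope(key, 7))
score_b = dot(apply_rope(query, 103), apply_rope(key, 100))
print(score_a, score_b)  # equal up to floating-point rounding
```

Because rotations preserve vector norms, RoPE changes only the angle between Q and K, never their magnitudes.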
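A pure-Python sketch contrasting the two normalizations, with the learned weights left at 1 and the LayerNorm bias at 0 for clarity (eps = 1e-6 is an assumed value; implementations differ). LayerNorm subtracts the mean and divides by the standard deviation, while RMSNorm skips the centering and only divides by the root mean square:

```python
import math
from typing import Optional


def layer_norm(x: list[float], eps: float = 1e-6) -> list[float]:
    # Center (subtract mean), then divide by the standard deviation
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def rms_norm(x: list[float], weight: Optional[list[float]] = None,
             eps: float = 1e-6) -> list[float]:
    # No centering and no bias: divide by the root mean square, then apply a learned gain
    if weight is None:
        weight = [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]


x = [3.0, 4.0, 5.0, 8.0]
print(layer_norm(x))  # zero mean, unit variance
print(rms_norm(x))    # unit root mean square; the mean is not removed
```

Dropping the mean and the bias saves one reduction and one parameter vector per normalization layer.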
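A minimal sketch of a sliding-window KV cache for a single attention head, assuming window size W (the 4096-token window listed later in these notes would be a realistic value; W = 3 keeps the output readable). Each step caches only the new token's key and value and attends over at most the last W cached entries; `deque(maxlen=W)` evicts the oldest entry, giving the rolling-buffer behavior. Queries, keys and values are plain lists standing in for projected vectors (hypothetical inputs, no learned weights):

```python
import math
from collections import deque


class SlidingWindowKVCache:
    """Rolling K/V buffer for one attention head, holding at most `window` tokens."""

    def __init__(self, window: int):
        # deque(maxlen=...) silently drops the oldest entry once full
        self.keys: deque = deque(maxlen=window)
        self.values: deque = deque(maxlen=window)

    def attend(self, query: list[float], key: list[float], value: list[float]) -> list[float]:
        # The new token's K/V are computed once and cached; older ones are reused
        self.keys.append(key)
        self.values.append(value)
        scale = 1.0 / math.sqrt(len(query))
        scores = [scale * sum(q * k for q, k in zip(query, cached)) for cached in self.keys]
        # Numerically stable softmax over the in-window positions only
        max_score = max(scores)
        weights = [math.exp(s - max_score) for s in scores]
        total = sum(weights)
        return [sum(w * v[j] for w, v in zip(weights, self.values)) / total
                for j in range(len(value))]


cache = SlidingWindowKVCache(window=3)
for t in range(5):
    vec = [float(t), 1.0]  # hypothetical projected query/key for token t
    out = cache.attend(query=vec, key=vec, value=[float(t)])
    print(f"token {t}: cached={len(cache.keys)}, output={out[0]:.3f}")
```

Each new token costs O(W) work instead of re-running the whole sequence, and cache memory stays bounded by W entries per head and layer.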
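A toy sketch of both MLP ideas, assuming top-2 routing in which a router scores every expert, the two highest logits are softmaxed into mixing weights, and only those two experts run. Each expert is a SwiGLU MLP, W2 (SiLU(W1 x) * W3 x). The tiny weight matrices are seeded random stand-ins (hypothetical), not trained values:

```python
import math
import random


def silu(v: float) -> float:
    # SiLU(x) = x * sigmoid(x); sigmoid written via tanh to avoid exp overflow
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> list[list[float]]:
    # Hypothetical random weights standing in for trained ones
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


class SwiGLUExpert:
    """FFN_SwiGLU(x) = W2 (SiLU(W1 x) * W3 x), no biases."""

    def __init__(self, d_model: int, d_hidden: int, rng: random.Random):
        self.w1 = rand_matrix(d_hidden, d_model, rng)  # gate projection
        self.w3 = rand_matrix(d_hidden, d_model, rng)  # up projection
        self.w2 = rand_matrix(d_model, d_hidden, rng)  # down projection
        self.calls = 0  # counts how often this expert actually runs

    def __call__(self, x: list[float]) -> list[float]:
        self.calls += 1
        gate, up = matvec(self.w1, x), matvec(self.w3, x)
        return matvec(self.w2, [silu(g) * u for g, u in zip(gate, up)])


def moe_forward(x: list[float], experts: list[SwiGLUExpert], router: list[list[float]],
                top_k: int = 2) -> tuple[list[float], list[int]]:
    logits = matvec(router, x)  # one router score per expert
    top = sorted(range(len(experts)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Softmax over the selected experts' logits only
    max_logit = max(logits[i] for i in top)
    gates = {i: math.exp(logits[i] - max_logit) for i in top}
    total = sum(gates.values())
    out = [0.0] * len(x)
    for i in top:  # experts outside the top-k are never evaluated
        y = experts[i](x)
        out = [o + gates[i] / total * yj for o, yj in zip(out, y)]
    return out, top


rng = random.Random(0)
d_model, d_hidden, num_experts = 4, 8, 8
experts = [SwiGLUExpert(d_model, d_hidden, rng) for _ in range(num_experts)]
router = rand_matrix(num_experts, d_model, rng)
y, chosen = moe_forward([0.5, -1.0, 0.25, 2.0], experts, router)
print("selected experts:", chosen, "calls per expert:", [e.calls for e in experts])
```

All 8 experts contribute parameters, but each token pays the compute of only 2 expert MLPs plus the small router.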
Efficient Attention Reference List
Research Implementations
- 2020: Rethinking Attention with Performers (O(n) time and space via an approximation of softmax attention)
- 2021: Self-attention Does Not Need O(n^2) Memory (memory below O(n^2); time remains O(n^2))
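The memory savings in the second paper (and in FlashAttention) rest on computing the softmax incrementally. A pure-Python sketch for a single query, assuming keys and values are processed in chunks: it keeps only a running max, a running normalizer and a running weighted sum, so nothing of size n x n is ever stored, yet the result matches naive softmax attention up to rounding. The vectors are arbitrary example values.

```python
import math


def naive_attention(q: list[float], keys: list[list[float]],
                    values: list[list[float]]) -> list[float]:
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]


def chunked_attention(q: list[float], keys: list[list[float]],
                      values: list[list[float]], chunk_size: int = 2) -> list[float]:
    running_max = float("-inf")
    running_sum = 0.0               # softmax denominator so far
    acc = [0.0] * len(values[0])    # un-normalized weighted sum of values
    for start in range(0, len(keys), chunk_size):
        k_chunk = keys[start:start + chunk_size]
        v_chunk = values[start:start + chunk_size]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in k_chunk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max before adding this chunk
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_chunk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.2, -0.5, 1.0]
keys = [[0.1, 0.3, -0.2], [1.0, 0.0, 0.5], [-0.7, 0.2, 0.9], [0.4, -0.4, 0.4], [2.0, 1.0, -1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
print(naive_attention(q, keys, values))
print(chunked_attention(q, keys, values, chunk_size=2))
```

Production kernels apply the same rescaling trick to blocks of queries and keys on the GPU; the time complexity is still O(n^2).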
Production Implementations (Jan 2024): Hugging Face Transformers, vLLM
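```python
import torch
import torch.nn.functional as F
import xformers.ops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

# Defined elsewhere: q_proj, k_proj, v_proj, out_proj weights ([embedding_dim, embedding_dim]),
# vocab_embedding, input_ids, position_ids and the apply_rotary_pos_emb helper
num_heads = 32
head_size = embedding_dim // num_heads

qkv_weight = torch.cat((q_proj, k_proj, v_proj), dim=0)  # [3 * embedding_dim, embedding_dim]
hidden_states = vocab_embedding(input_ids)  # [batches, seq_length, embedding_dim]
batches = hidden_states.size(0)
seq_length = hidden_states.size(1)


# Shared Attention Preamble Processing
def mha_proj_and_pos_encode(hidden_states: torch.Tensor, position_ids: torch.Tensor,
                            num_heads: int = num_heads):
    head_size = embedding_dim // num_heads
    # QKV Projection (one fused matmul instead of three)
    qkv = F.linear(hidden_states, qkv_weight)
    query, key, value = qkv.split([embedding_dim, embedding_dim, embedding_dim], dim=-1)
    # Relative Rotary Position Embedding
    query, key = apply_rotary_pos_emb(query, key, position_ids)
    # Note: Non KV-Cached Step
    # Multi-Head Reshape. [batch, seq, embd] -> [batch*seq, num_heads, head_size]
    query = query.view(-1, num_heads, head_size)
    key = key.view(-1, num_heads, head_size)
    value = value.view(-1, num_heads, head_size)
    return query, key, value


# xFormers Memory Efficient Attention
def mha_xformers(hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    query, key, value = mha_proj_and_pos_encode(hidden_states, position_ids)
    # Block-diagonal causal mask: each packed sequence attends only to itself
    # (all sequences have the same length due to padding)
    attn_bias = BlockDiagonalCausalMask.from_seqlens([seq_length] * batches)
    # xFormers expects [batch, seq, heads, head_size]; all sequences are packed into batch 1
    mh_attn_out = xformers.ops.memory_efficient_attention_forward(
        query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0), attn_bias=attn_bias,
    )
    mh_attn_out = mh_attn_out.reshape(batches, seq_length, embedding_dim)
    mh_attn_out = F.linear(mh_attn_out, out_proj)
    return mh_attn_out


# Torch SDPA Attention
def mha_sdpa(hidden_states: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    query, key, value = mha_proj_and_pos_encode(hidden_states, position_ids)
    # SDPA expects [batch, heads, seq, head_size]; without this transpose the
    # causal mask would be applied across heads instead of across tokens
    query, key, value = (
        t.reshape(batches, seq_length, num_heads, head_size).transpose(1, 2)
        for t in (query, key, value)
    )
    mh_attn_out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=None, is_causal=True,
    )
    # [batch, heads, seq, head_size] -> [batch, seq, embd]
    mh_attn_out = mh_attn_out.transpose(1, 2).reshape(batches, seq_length, embedding_dim)
    mh_attn_out = F.linear(mh_attn_out, out_proj)
    return mh_attn_out
```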
Test: 100 iterations of 32 attention layers (xavier_normal_ initialized)
- Naive: baseline
- xFormers: ~2.5x speed-up
- SDPA: ~3.5x speed-up
- Image split into 16x16-pixel patches, one <img> token per patch; <img_break> marks each row break
- VisionTransformer (Conv2D + Transformer)
  - Conv2D(3, 1024, 16, 16). Image → [1, patches, 1024]
  - Transformer: 16 layers, 1k embd, 16 QKV heads, MLP with 4k intermediate size and SiLU
  - ~256M parameters. 16x (12M MLP, 4M Attn, 2k RMSNorm)
  - 2D RoPE, no LM head
- VisionLanguageAdapter (MLP)
  - up_proj [1k → 5k] bias=True, then GELU
  - out_proj [5k → 5k] bias=True
- PixtralForCausalLM
  - Multi-modal embeddings: <img> tokens replaced with VisionLanguageAdapter embeddings
  - 128k vocabulary, 5k embedding dimension
  - Grouped Query Attention (32xQ, 8xKV)
  - MLP 14k intermediate size
  - 40 layers
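The figures in this list can be cross-checked with a little arithmetic. The sketch below assumes one <img_break> after every patch row (the real tokenizer may end an image with a different special token), a gated SiLU MLP with three bias-free 1k x 4k matrices, four bias-free 1k x 1k attention projections and two RMSNorm weight vectors per layer. Under these assumptions one layer has about 16.8M weights, so 16 layers come to about 268M; the ~256M figure corresponds to the rounded 16 x (12M + 4M).

```python
def image_token_count(height: int, width: int, patch: int = 16) -> int:
    # Assumed layout: one <img> token per patch plus one <img_break> per patch row
    rows, cols = height // patch, width // patch
    return rows * cols + rows


def vision_layer_params(d_model: int = 1024, d_ff: int = 4096) -> dict[str, int]:
    counts = {
        "mlp": 3 * d_model * d_ff,      # gated SiLU MLP: gate, up, down (no bias)
        "attn": 4 * d_model * d_model,  # q, k, v, o projections (no bias)
        "rms_norm": 2 * d_model,        # two RMSNorm weight vectors
    }
    counts["total"] = sum(counts.values())
    return counts


layer = vision_layer_params()
print(image_token_count(512, 512))  # 32 x 32 patches + 32 row breaks
print(layer, "16 layers:", 16 * layer["total"])
```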
- Llama 3-8B vs Llama 2-7B
  - Vocabulary size ~128k (128,256) vs 32k
  - Context length 8k vs 4k
  - Grouped Query Attention (32xQ, 8xKV) vs Multi-Head Attention (32xQ, 32xKV)
  - MLP up-proj size 14k vs ~11k
  - dtype bfloat16 vs float16
  - 32 layers in both
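Grouped Query Attention pays off mostly in the KV cache. A back-of-the-envelope sketch, assuming 32 layers, head dimension 128 (4096 / 32 heads), 2 bytes per value (16-bit dtypes), 32 KV heads for Llama 2-7B (standard multi-head attention) and 8 KV heads for Llama 3-8B:

```python
def kv_cache_bytes(seq_len: int, num_layers: int = 32, num_kv_heads: int = 32,
                   head_dim: int = 128, bytes_per_value: int = 2, batch: int = 1) -> int:
    # Factor 2: one key and one value vector per KV head, per layer, per token
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * seq_len * batch


def kv_head_for_query_head(q_head: int, num_q_heads: int = 32, num_kv_heads: int = 8) -> int:
    # GQA: consecutive groups of query heads share one K/V head
    return q_head // (num_q_heads // num_kv_heads)


GIB = 1024 ** 3
llama2_7b = kv_cache_bytes(seq_len=4096, num_kv_heads=32)  # MHA, full 4k context
llama3_8b = kv_cache_bytes(seq_len=8192, num_kv_heads=8)   # GQA, full 8k context
print(f"Llama 2-7B @4k: {llama2_7b / GIB:.2f} GiB, Llama 3-8B @8k: {llama3_8b / GIB:.2f} GiB")
print([kv_head_for_query_head(h) for h in range(8)])  # query heads 0-3 -> KV head 0, 4-7 -> KV head 1
```

Per token this is 512 KiB versus 128 KiB, so under these assumptions Llama 3-8B holds twice the context in half the cache memory.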
https://github.com/havenhq/mamba-chat
https://arxiv.org/abs/2312.00752
- 2.5B parameters
- TODO
- Sliding Window Attention (4096 tokens)
- Sliding KV-Cache
- KV-Cache Pre-Fill & Chunking
- FFN with Sparse Mixture of Experts
- 8x FFN/MLP Models, 2x Used Per-Token
- SiLU (Sigmoid Linear Unit) instead of ReLU
- Others
- Byte-fallback Tokenizer
- Model Partition
- Efficient Queries Batching/Bundling
- RMS Normalization moved before each sub-layer (pre-normalization)
- Rotary Positional Encoder (Q, K only)
- Grouped Query Attention on 34B/70B
- SwiGLU instead of ReLU
- Mixture of Experts on FFN
- ~1.5B and ~175B parameters respectively
- KV-Cache
- ~120M parameters
- Encoder Only
TODO Good Images Ref