# Module 9.4: Hybrid Architectures

**Goal**: Build Jamba-style hybrid models

**Time**: 90 minutes

**Concepts Covered**:
- Jamba-style hybrid (Mamba + Attention)
- Layer placement strategies
- Attention frequency experiments
- Quality vs speed trade-offs
- Custom hybrid builder
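
Layer placement and attention frequency both come down to one decision: which layers get attention. The sketch below is plain Python with no dependencies. It encodes that decision as a string (`'A'` = attention, `'M'` = Mamba) and estimates the inference KV cache, since only attention layers cache keys and values. `layer_pattern` and `kv_cache_bytes` are hypothetical helpers. The cache estimate assumes standard multi-head attention (no grouped-query attention) and 2-byte (fp16/bf16) values, and the model sizes are illustrative only.

```python
def layer_pattern(num_layers, attention_frequency, offset=0):
    """Return a string like 'AMMM...' with 'A' wherever attention is placed.

    offset shifts where attention sits inside each group of
    `attention_frequency` layers (0 = first layer of the group).
    """
    if not 0 <= offset < attention_frequency:
        raise ValueError("offset must be in [0, attention_frequency)")
    return "".join(
        "A" if i % attention_frequency == offset else "M"
        for i in range(num_layers)
    )


def kv_cache_bytes(pattern, seq_len, d_model, batch=1, bytes_per_elem=2):
    """Inference KV cache size (assumes full multi-head attention).

    Each attention layer stores K and V, each of shape (batch, seq_len, d_model).
    Mamba layers keep a fixed-size state instead, so they add nothing here.
    """
    num_attn = pattern.count("A")
    return num_attn * 2 * batch * seq_len * d_model * bytes_per_elem


num_layers, seq_len, d_model = 32, 8192, 4096  # illustrative sizes
for freq in (1, 2, 4, 8):
    p = layer_pattern(num_layers, freq)
    gib = kv_cache_bytes(p, seq_len, d_model) / 2**30
    print(f"attention every {freq} layers: {p.count('A'):2d} attn layers, KV cache {gib:.2f} GiB")

# 1:7 attention-to-Mamba ratio (the ratio reported for Jamba);
# offset=4 is an arbitrary illustrative choice of position within each group
print(layer_pattern(16, 8, offset=4))
```

With these sizes the estimate halves each time the attention frequency doubles, from 4 GiB with attention in every layer down to 0.5 GiB with attention in every 8th layer. That is the memory side of the quality vs speed trade-off. The quality side has to be measured on real tasks.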

## Setup

```python
!pip install torch transformers accelerate matplotlib seaborn numpy -q
```

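```python
import torch
import torch.nn as nn

# Assumes the course's `advanced_architectures` package (from the Mamba module)
# is importable and that SelectiveSSM(d_model) maps
# (batch, seq_len, d_model) -> (batch, seq_len, d_model).
from advanced_architectures.mamba import SelectiveSSM


class HybridBlock(nn.Module):
    """Hybrid Mamba + Attention block (Jamba-style).

    Every block has a Mamba (SSM) mixer and an FFN. Blocks built with
    use_attention=True also get a causal self-attention sublayer.
    """

    def __init__(self, d_model, num_heads=8, use_attention=True):
        super().__init__()
        self.d_model = d_model
        self.use_attention = use_attention

        # Mamba layer: linear-time sequence mixing
        self.mamba = SelectiveSSM(d_model)
        self.norm1 = nn.LayerNorm(d_model)

        # Attention layer: only created where it is used, so attention-free
        # blocks carry no unused parameters
        if use_attention:
            self.attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
            self.norm2 = nn.LayerNorm(d_model)

        # Position-wise FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)

        # Mamba branch (post-norm residual)
        mamba_out = self.mamba(x)
        x = self.norm1(x + mamba_out)

        # Optional attention branch, causally masked to match the
        # left-to-right Mamba scan (True = position may NOT be attended to)
        if self.use_attention:
            seq_len = x.size(1)
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
            attn_out, _ = self.attention(x, x, x, attn_mask=causal_mask, need_weights=False)
            x = self.norm2(x + attn_out)

        # FFN
        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        return x


def create_hybrid_model(num_layers, attention_frequency=4, d_model=512, num_heads=8):
    """Stack HybridBlocks with attention in every `attention_frequency`-th block
    (layers 0, N, 2N, ...)."""
    layers = []
    for i in range(num_layers):
        use_attention = (i % attention_frequency == 0)
        layers.append(HybridBlock(d_model, num_heads=num_heads, use_attention=use_attention))
    # The attention choice is fixed at construction time, so nn.Sequential
    # can call each block with the input tensor alone
    return nn.Sequential(*layers)


print("Hybrid Architecture Benefits:")
print("- Mamba: linear-time sequence mixing (O(n)) with a fixed-size state")
print("- Attention: exact recall of any earlier token, at O(n^2) cost")
print("- Hybrid: a few attention layers restore recall; Mamba keeps most layers linear")
```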
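
The cell above depends on the course's `SelectiveSSM`. To try layer placement without it, here is a self-contained sketch of a custom hybrid builder driven by a pattern string such as `"MMMAMMMM"`. Everything in it is hypothetical and illustrative. `TinySelectiveSSM` is a deliberately minimal diagonal selective SSM with a slow Python-loop scan, not the hardware-efficient Mamba kernel, and it omits Mamba's convolution and gating. `HybridLayer` and `build_hybrid` are illustrative names, and the layers use pre-norm residuals. Unlike `HybridBlock`, each layer here has *either* an SSM or an attention mixer, which is closer to how Jamba interleaves separate Mamba and attention layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySelectiveSSM(nn.Module):
    """Hypothetical stand-in: diagonal selective SSM with a sequential scan.

    h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t
    y_t = <C_t, h_t> + D * x_t
    """

    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_state = d_state
        # A = -exp(A_log) < 0 keeps the state decaying (one row per channel)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()).repeat(d_model, 1))
        # Input-dependent ("selective") delta, B and C from one projection
        self.proj = nn.Linear(d_model, d_model + 2 * d_state)
        self.D = nn.Parameter(torch.ones(d_model))

    def forward(self, x):  # x: (batch, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        delta, B, C = torch.split(self.proj(x), [d_model, self.d_state, self.d_state], dim=-1)
        delta = F.softplus(delta)                      # positive step sizes
        A = -torch.exp(self.A_log)                     # (d_model, d_state)
        h = x.new_zeros(batch, d_model, self.d_state)  # state size does not grow with seq_len
        ys = []
        for t in range(seq_len):
            dA = torch.exp(delta[:, t, :, None] * A)                          # (batch, d_model, d_state)
            dBx = delta[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]  # (batch, d_model, d_state)
            h = dA * h + dBx
            ys.append((h * C[:, t, None, :]).sum(-1))                          # (batch, d_model)
        return torch.stack(ys, dim=1) + x * self.D


class HybridLayer(nn.Module):
    """One layer: a mixer ('M' = SSM, 'A' = causal attention) plus an FFN, pre-norm."""

    def __init__(self, kind, d_model, num_heads=4):
        super().__init__()
        self.kind = kind
        self.norm_mix = nn.LayerNorm(d_model)
        if kind == "A":
            self.mixer = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        elif kind == "M":
            self.mixer = TinySelectiveSSM(d_model)
        else:
            raise ValueError(f"unknown layer kind {kind!r}")
        self.norm_ffn = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        h = self.norm_mix(x)
        if self.kind == "A":
            L = x.size(1)
            # True above the diagonal = future positions are masked out
            mask = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device), diagonal=1)
            h, _ = self.mixer(h, h, h, attn_mask=mask, need_weights=False)
        else:
            h = self.mixer(h)
        x = x + h
        return x + self.ffn(self.norm_ffn(x))


def build_hybrid(pattern, d_model=64, num_heads=4):
    """Custom builder: one HybridLayer per character of the pattern string."""
    return nn.Sequential(*(HybridLayer(kind, d_model, num_heads) for kind in pattern))


torch.manual_seed(0)
model = build_hybrid("MMMAMMMM")  # 1 attention layer per 8 (1:7 ratio)
x = torch.randn(2, 16, 64)
y = model(x)
print(y.shape)  # torch.Size([2, 16, 64])
print(sum(isinstance(layer.mixer, nn.MultiheadAttention) for layer in model), "attention layer(s)")
```

Changing the pattern string (for example `"AMMM" * 2` versus `"MMMMMMMA"`) is a quick way to run placement experiments. For speed comparisons, use the real Mamba kernel instead: this Python-loop scan would distort any timing against attention.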

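## Key Takeaways

- Mamba (SSM) layers mix the sequence in linear time with a fixed-size state. Attention layers add exact token-to-token retrieval at quadratic cost.
- `attention_frequency` sets the attention-to-Mamba ratio: with attention every N layers, roughly 1/N of the layers are quadratic.
- Only attention layers keep a KV cache, so fewer attention layers also means less inference memory at long context.

✅ **Module Complete**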

## Next Steps

Continue to the next module in the course.