# Module 9.1: State Space Models

**Goal**: Implement a State Space Model (SSM) from scratch and understand why its cost is O(n) in sequence length

**Time**: 90 minutes

**Concepts Covered**:
- SSM implementation from scratch
- Continuous vs discrete SSM
- O(n) complexity demonstration
- Long sequence handling
- SSM vs Transformer comparison
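For reference, here are the continuous-time SSM equations and one common way to discretize them: zero-order hold (ZOH) with step size $\Delta$. The `StateSpaceModel` class in this module skips the continuous form and learns the discrete-time matrices directly, so its `A` and `B` play the role of $\bar{A}$ and $\bar{B}$ below.

$$
h'(t) = A\,h(t) + B\,u(t), \qquad y(t) = C\,h(t) + D\,u(t)
$$

$$
\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B
$$

$$
h_t = \bar{A}\,h_{t-1} + \bar{B}\,u_t, \qquad y_t = C\,h_t + D\,u_t, \qquad h_{-1} = 0
$$

Unrolling the recurrence gives $y_t = \sum_{k=0}^{t} C\bar{A}^{k}\bar{B}\,u_{t-k} + D\,u_t$. With state size $N$ and model width $d$, each step of the recurrence costs $O(N^2 + Nd + d^2)$ for a dense $\bar{A}$ and a full $D$, so a length-$L$ sequence costs $O(L\,(N^2 + Nd + d^2))$, linear in $L$. The attention scores $QK^\top$ alone cost $O(L^2 d)$.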

## Setup

```python
!pip install torch transformers accelerate matplotlib seaborn numpy -q
```

```python
import torch
import torch.nn as nn


class StateSpaceModel(nn.Module):
    """Simple discrete-time State Space Model (SSM).

    h_t = A h_{t-1} + B u_t
    y_t = C h_t + D u_t
    """

    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # State space parameters, learned directly in discrete time.
        # A is scaled so its spectral radius is well below 1: with unscaled
        # randn entries (spectral radius ~ sqrt(d_state)) the recurrence
        # overflows to inf/NaN over a 2048-step sequence.
        self.A = nn.Parameter(torch.randn(d_state, d_state) / d_state)
        self.B = nn.Parameter(torch.randn(d_state, d_model) / d_model ** 0.5)
        self.C = nn.Parameter(torch.randn(d_model, d_state) / d_state ** 0.5)
        self.D = nn.Parameter(torch.randn(d_model, d_model) / d_model ** 0.5)

    def forward(self, u):
        """u: (batch, seq_len, d_model) -> y: (batch, seq_len, d_model).

        One fixed-cost update per time step, so cost is linear in seq_len.
        """
        batch_size, seq_len, _ = u.shape
        # Only the current state is carried between steps (h_{-1} = 0).
        # Rebinding h instead of writing into a preallocated tensor avoids
        # in-place ops that would break autograd's backward pass.
        h = torch.zeros(batch_size, self.d_state, device=u.device, dtype=u.dtype)
        outputs = []

        # Recurrent computation
        for t in range(seq_len):
            u_t = u[:, t]
            h = h @ self.A.T + u_t @ self.B.T
            outputs.append(h @ self.C.T + u_t @ self.D.T)

        return torch.stack(outputs, dim=1)


# Test SSM
d_model = 64
d_state = 16
ssm = StateSpaceModel(d_model, d_state)

seq_len = 2048
u = torch.randn(1, seq_len, d_model)

with torch.no_grad():
    output = ssm(u)

print(f"Input shape: {u.shape}")
print(f"Output shape: {output.shape}")
print(f"Output is finite: {torch.isfinite(output).all().item()}")
print("SSM cost: O(seq_len × (d_state² + d_state × d_model + d_model²)) = O(n) in sequence length")
print("Transformer attention: O(seq_len² × d_model) = O(n²) in sequence length")
```
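To make the continuous-vs-discrete distinction concrete, here is a minimal NumPy sketch. It assumes a diagonal continuous-time `A` with negative entries (a simplification, not the dense `A` used in `StateSpaceModel`), discretizes it with zero-order hold, and checks that running the recurrence step by step gives the same output as convolving the input with the kernel $K_k = C\bar{A}^k\bar{B}$. The helper names `zoh_discretize`, `run_recurrent`, and `run_convolutional` are hypothetical, chosen for illustration.

```python
import numpy as np


def zoh_discretize(a_diag, B, dt):
    """Zero-order-hold discretization, assuming a DIAGONAL continuous-time A.

    a_diag: (N,) diagonal entries of A (negative for stability)
    B:      (N, D) continuous-time input matrix
    dt:     step size Delta
    """
    a_bar = np.exp(dt * a_diag)                    # exp(Delta*A), elementwise for diagonal A
    b_bar = ((a_bar - 1.0) / a_diag)[:, None] * B  # A^{-1}(exp(Delta*A) - I) B
    return a_bar, b_bar


def run_recurrent(a_bar, b_bar, C, D, u):
    """Scan: h_t = A_bar h_{t-1} + B_bar u_t, y_t = C h_t + D u_t, with h_{-1} = 0."""
    h = np.zeros(a_bar.shape[0])
    ys = []
    for u_t in u:
        h = a_bar * h + b_bar @ u_t                # diagonal A_bar: elementwise multiply
        ys.append(C @ h + D @ u_t)
    return np.stack(ys)


def run_convolutional(a_bar, b_bar, C, D, u):
    """Same output via y_t = sum_k K_k u_{t-k} + D u_t, with K_k = C A_bar^k B_bar."""
    L = u.shape[0]
    powers = a_bar[None, :] ** np.arange(L)[:, None]   # (L, N): diagonal of A_bar^k
    K = np.einsum("on,kn,ni->koi", C, powers, b_bar)    # (L, D_out, D_in)
    y = np.zeros_like(u)
    for t in range(L):
        # Pair K_0..K_t with u_t..u_0 (naive O(L^2) loop, for clarity only)
        y[t] = np.einsum("koi,ki->o", K[: t + 1], u[t::-1])
    return y + u @ D.T


rng = np.random.default_rng(0)
N, Dm, L = 8, 4, 64
a_diag = -np.linspace(0.5, 2.0, N)   # negative entries -> 0 < a_bar < 1 (stable)
B = rng.standard_normal((N, Dm))
C = rng.standard_normal((Dm, N))
D = rng.standard_normal((Dm, Dm))
u = rng.standard_normal((L, Dm))

a_bar, b_bar = zoh_discretize(a_diag, B, dt=0.1)
y_rec = run_recurrent(a_bar, b_bar, C, D, u)
y_conv = run_convolutional(a_bar, b_bar, C, D, u)
print(np.allclose(y_rec, y_conv))  # True
```

The convolution above is a naive O(L²) loop kept for readability. The point is that a time-invariant SSM can be evaluated either as a scan, as in `StateSpaceModel.forward`, or as one long convolution; FFT-based convolution is a common way to make the second view fast on long sequences.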

## Key Takeaways

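- An SSM processes a sequence with a recurrence, h_t = A h_{t-1} + B u_t and y_t = C h_t + D u_t, carrying a fixed-size state from step to step.
- Every step costs the same regardless of position, so total cost grows linearly with sequence length: O(n), versus O(n²) for self-attention.
- The state matrix A must keep the recurrence stable (spectral radius below 1); otherwise the hidden state explodes over long sequences.

✅ **Module Complete**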

## Next Steps

Continue to the next module in the course.