# Tutorial 16: Type System and Annotations

In this tutorial, we'll explore how Python type annotations and BrainState's type helpers (`OneOfTypes` and `JointTypes` in `brainstate.mixin`) can make neural network code easier to read, check, and validate.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Use the typing module for type annotations
- Work with OneOfTypes for union-like types
- Use JointTypes for composite type requirements
- Apply type annotations to neural network code
- Validate inputs using type constraints
- Design type-safe APIs
- Follow typing best practices

## Why Types Matter

Type annotations provide:
- **Early error detection**: A static checker such as mypy or pyright can catch mismatched types before the code runs
- **Better documentation**: Types document expected inputs and outputs
- **IDE support**: More accurate autocomplete and hints
- **Refactoring safety**: The checker flags call sites that a signature change breaks
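
Python itself does not enforce annotations at runtime; they are metadata read by IDEs and static checkers. The toy function below is a hypothetical illustration (not BrainState code) of both points: `typing.get_type_hints` exposes the annotations, and a call with the wrong argument type still runs.

```python
from typing import Optional, get_type_hints


def scale_features(n_features: int, factor: float = 1.0, name: Optional[str] = None) -> float:
    """Toy function whose annotations describe intent only."""
    return n_features * factor


# Annotations are stored as metadata that IDEs and checkers such as mypy can read
hints = get_type_hints(scale_features)
print(hints["n_features"], hints["return"])

# Python does not check them: a str argument "works" (str * int repeats the string)
result = scale_features("3", 2)  # mypy or pyright would reject this call
print(result)
```

This is why the benefits listed above depend on running a checker as part of your workflow.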

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from typing import Union, Optional, Tuple, List, Dict, Callable

# Set random seed
bst.random.seed(42)

## 1. Basic Type Annotations

Start with standard Python typing.

In [None]:
# Basic typed layer
class TypedLinear(bst.graph.Node):
    """Linear layer with type annotations."""
    
    def __init__(self, in_features: int, out_features: int, use_bias: bool = True):
        super().__init__()
        self.in_features: int = in_features
        self.out_features: int = out_features
        
        # Parameter with type annotation
        self.weight: bst.ParamState = bst.ParamState(
            bst.random.randn(in_features, out_features) * 0.1
        )
        
        # Declare the optional attribute once, then fill it in if requested
        self.bias: Optional[bst.ParamState] = None
        if use_bias:
            self.bias = bst.ParamState(jnp.zeros(out_features))
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass.
        
        Args:
            x: Input array of shape (..., in_features)
            
        Returns:
            Output array of shape (..., out_features)
        """
        output: jnp.ndarray = x @ self.weight.value
        
        if self.bias is not None:
            output = output + self.bias.value
        
        return output

# Create typed layer
layer = TypedLinear(in_features=10, out_features=5)
x = bst.random.randn(3, 10)
output = layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nType annotations help IDEs understand the code!")
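
Grouping constructor arguments into a typed config object is one common way to make layer hyperparameters explicit. The `LinearConfig` dataclass below is a hypothetical helper, not part of BrainState; its field names simply mirror `TypedLinear`'s parameters, so `TypedLinear(**asdict(cfg))` would work. Because annotations are not enforced, `__post_init__` checks the invariants that matter.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class LinearConfig:
    """Hypothetical typed configuration for a TypedLinear-style layer."""
    in_features: int
    out_features: int
    use_bias: bool = True

    def __post_init__(self) -> None:
        # Dataclasses do not validate types or values, so check them explicitly
        if self.in_features <= 0 or self.out_features <= 0:
            raise ValueError("in_features and out_features must be positive")


cfg = LinearConfig(in_features=10, out_features=5)
print(asdict(cfg))  # {'in_features': 10, 'out_features': 5, 'use_bias': True}

try:
    LinearConfig(in_features=0, out_features=5)
except ValueError as err:
    print(f"Rejected: {err}")
```

`frozen=True` makes the config immutable, so one instance can be shared safely between layers.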

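### Union and Callable Types

`Union[A, B]` accepts a value of either type, and `Callable[[Arg], Ret]` describes a function's signature. The layer below accepts either an activation name or a custom function.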

In [None]:
# Using Union and Callable
from typing import Callable, Union

class FlexibleActivation(bst.graph.Node):
    """Activation that accepts function or string."""
    
    def __init__(self, activation: Union[str, Callable[[jnp.ndarray], jnp.ndarray]]):
        super().__init__()
        
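        # Resolve the activation to a callable
        if isinstance(activation, str):
            # Look up a built-in activation by name
            activations: Dict[str, Callable[[jnp.ndarray], jnp.ndarray]] = {
                'relu': jax.nn.relu,
                'tanh': jnp.tanh,
                'sigmoid': jax.nn.sigmoid,
            }
            if activation not in activations:
                raise ValueError(
                    f"Unknown activation {activation!r}; expected one of {sorted(activations)}"
                )
            fn = activations[activation]
        else:
            # Use the caller's function as-is
            fn = activation
        
        # Annotate the attribute once
        self.activation: Callable[[jnp.ndarray], jnp.ndarray] = fn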
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.activation(x)

# Test with string
act1 = FlexibleActivation('relu')
x = jnp.array([-1.0, 0.0, 1.0])
print(f"ReLU activation: {act1(x)}")

# Test with custom function
def custom_activation(x):
    return jnp.maximum(0, x) ** 2

act2 = FlexibleActivation(custom_activation)
print(f"Custom activation: {act2(x)}")
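
One refinement, sketched here under the assumption that only a fixed set of names is supported, is to annotate the name with `typing.Literal`: a static checker then rejects misspelled names, while the runtime lookup still guards dynamic input. `resolve_activation` and `_REGISTRY` are hypothetical names, and the sketch uses scalar `math` functions instead of array functions so it runs without JAX.

```python
import math
from typing import Callable, Dict, Literal, Union, get_args

# Names a static checker will accept; a typo such as "rleu" is flagged before runtime
ActivationName = Literal["relu", "tanh", "sigmoid"]
# Scalar functions keep this sketch dependency-free; the tutorial's layers use array functions
ActivationFn = Callable[[float], float]

_REGISTRY: Dict[str, ActivationFn] = {
    "relu": lambda v: max(0.0, v),
    "tanh": math.tanh,
    "sigmoid": lambda v: 1.0 / (1.0 + math.exp(-v)),
}


def resolve_activation(activation: Union[ActivationName, ActivationFn]) -> ActivationFn:
    """Return a callable for a known name, or pass a custom function through."""
    if isinstance(activation, str):
        # Runtime guard for callers that bypass the static checker
        if activation not in _REGISTRY:
            raise ValueError(
                f"Unknown activation {activation!r}; expected one of {get_args(ActivationName)}"
            )
        return _REGISTRY[activation]
    return activation


print(resolve_activation("relu")(-1.0))  # 0.0
print(resolve_activation(abs)(-2.0))     # 2.0
```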

## 2. OneOfTypes: Union-like Types

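BrainState's `OneOfTypes[X, Y]` expresses "either X or Y", much like `Union[X, Y]`. Below, `NumericType` documents the accepted scale types, while the constructor's annotation uses the equivalent standard `Union`, which static checkers such as mypy understand.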

In [None]:
from brainstate.mixin import OneOfTypes

# OneOfTypes example
# Define a type that can be int, float, or array
NumericType = OneOfTypes[int, float, jnp.ndarray]

class ScalableLayer(bst.graph.Node):
    """Layer with flexible scaling."""
    
    def __init__(self, features: int, scale: Union[int, float, jnp.ndarray]):
        super().__init__()
        self.features = features
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
        
        # Scale can be scalar or array
        if isinstance(scale, (int, float)):
            self.scale = jnp.array(scale)
        else:
            self.scale = scale
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return (x @ self.weight.value) * self.scale

# Test with different scale types
x = bst.random.randn(2, 5)

# Scalar scale
layer1 = ScalableLayer(5, scale=2.0)
out1 = layer1(x)
print(f"Scalar scale (2.0): mean output = {jnp.mean(out1):.3f}")

# Array scale (per-feature)
layer2 = ScalableLayer(5, scale=jnp.array([0.5, 1.0, 1.5, 2.0, 2.5]))
out2 = layer2(x)
print(f"Array scale: output shape = {out2.shape}")
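
The `isinstance(scale, (int, float))` check in `ScalableLayer` spells out the union's members by hand. The sketch below, a plain-Python illustration rather than BrainState's `OneOfTypes` implementation, derives the tuple from a standard `Union` with `typing.get_args`, so the annotation and the runtime check cannot drift apart. `Numeric` and `is_one_of` are hypothetical names, and `complex` stands in for the array type so no JAX is needed.

```python
from typing import Union, get_args

# Plain-Python stand-in for the tutorial's OneOfTypes[int, float, jnp.ndarray];
# complex replaces the array type so the sketch needs no JAX
Numeric = Union[int, float, complex]


def is_one_of(value: object, union_type: object) -> bool:
    """True if value is an instance of at least one member of a Union of plain classes."""
    return isinstance(value, get_args(union_type))


print(is_one_of(2, Numeric), is_one_of(2.5, Numeric), is_one_of("2", Numeric))
# True True False
```

Two caveats: `get_args` only works for `isinstance` when every member is a plain class (not `List[int]`, for example), and because `bool` subclasses `int`, `True` passes this check.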

## 3. JointTypes: Composite Type Requirements

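`JointTypes[X, Y]` denotes a value that must satisfy *all* of the listed types at once (an intersection, as opposed to the "either/or" of `OneOfTypes`). The cell below imports it but illustrates the underlying idea with plain runtime checks: an input must pass a type check, a shape check, and a value-range check together.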

In [None]:
from brainstate.mixin import JointTypes

# Conceptual example: the checks below must all pass, mirroring the
# "satisfy every constraint" idea behind JointTypes (which is mostly used
# internally by BrainState)

class ValidatedLayer(bst.graph.Node):
    """Layer with input validation."""
    
    def __init__(self, features: int, min_val: float = -1.0, max_val: float = 1.0):
        super().__init__()
        self.features = features
        self.min_val = min_val
        self.max_val = max_val
        self.weight = bst.ParamState(bst.random.randn(features, features) * 0.1)
    
    def validate_input(self, x: jnp.ndarray) -> bool:
        """Validate input constraints."""
        # Check type
        if not isinstance(x, jnp.ndarray):
            return False
        
        # Check shape
        if x.shape[-1] != self.features:
            return False
        
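        # Check value range. This needs concrete values, so the method cannot
        # be traced by jax.jit (a Python `if` on a traced array fails).
        if jnp.any(x < self.min_val) or jnp.any(x > self.max_val):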
            return False
        
        return True
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        if not self.validate_input(x):
            raise ValueError(
                f"Input must be ndarray with shape (..., {self.features}) "
                f"and values in [{self.min_val}, {self.max_val}]"
            )
        
        return x @ self.weight.value

# Test validation
layer = ValidatedLayer(features=5, min_val=-2.0, max_val=2.0)

# Valid input
x_valid = bst.random.randn(3, 5) * 0.5  # Values in range
try:
    out = layer(x_valid)
    print("Valid input: Success!")
except ValueError as e:
    print(f"Error: {e}")

# Invalid input (every value is outside [-2, 2])
x_invalid = jnp.full((3, 5), 5.0)
try:
    out = layer(x_invalid)
    print("Invalid input was accepted (unexpected)")
except ValueError as e:
    print(f"Expected error: {e}")
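
To make the "must satisfy all" idea concrete at the type level, the helper below checks that a value is an instance of every listed class. It is a plain-Python illustration of what `JointTypes[X, Y]` means, not a description of how BrainState implements it; `is_all_of` is a hypothetical name.

```python
from collections.abc import Iterable, Sized


def is_all_of(value: object, *types: type) -> bool:
    """True only if value is an instance of every given type (an intersection check)."""
    return all(isinstance(value, t) for t in types)


# A list has a length and can be iterated
print(is_all_of([1, 2, 3], Sized, Iterable))        # True
# An iterator can be iterated but has no len()
print(is_all_of(iter([1, 2, 3]), Sized, Iterable))  # False
```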

## 4. Type Annotations for Neural Networks

In [None]:
# Comprehensive type annotations
from typing import Tuple, List, Dict, Optional

class WellTypedMLP(bst.graph.Node):
    """MLP with comprehensive type annotations."""
    
    def __init__(
        self,
        layer_sizes: List[int],
        activation: str = 'relu',
        dropout_rate: Optional[float] = None,
        use_batchnorm: bool = False
    ):
        """Initialize MLP.
        
        Args:
            layer_sizes: List of layer dimensions [input, hidden1, ..., output]
            activation: Activation function name
            dropout_rate: Dropout rate (None to disable)
            use_batchnorm: Whether to use batch normalization
        """
        super().__init__()
        
        self.layer_sizes: List[int] = layer_sizes
        self.activation_name: str = activation
        self.dropout_rate: Optional[float] = dropout_rate
        self.use_batchnorm: bool = use_batchnorm
        
        # Build layers
        self.layers: List[bst.nn.Linear] = []
        for i in range(len(layer_sizes) - 1):
            layer = bst.nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            self.layers.append(layer)
            setattr(self, f'layer_{i}', layer)
        
        # Optional batch norm (declared once, then filled in if requested)
        self.batchnorms: Optional[List[bst.nn.BatchNorm1d]] = None
        if use_batchnorm:
            self.batchnorms = [
                bst.nn.BatchNorm1d(size) for size in layer_sizes[1:-1]
            ]
            for i, bn in enumerate(self.batchnorms):
                setattr(self, f'bn_{i}', bn)
        
        # Optional dropout (declared once). Check the bst.nn.Dropout docs for
        # your version: its argument may be the probability of *keeping* a unit.
        self.dropout: Optional[bst.nn.Dropout] = None
        if dropout_rate is not None:
            self.dropout = bst.nn.Dropout(dropout_rate)
        
        # Activation function
        activations: Dict[str, Callable] = {
            'relu': jax.nn.relu,
            'tanh': jnp.tanh,
            'sigmoid': jax.nn.sigmoid,
        }
        self.activation: Callable[[jnp.ndarray], jnp.ndarray] = activations[activation]
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass.
        
        Args:
            x: Input array of shape (batch, input_dim)
            
        Returns:
            Output array of shape (batch, output_dim)
        """
        # Process through layers
        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x)
            
            # Optional batch norm
            if self.batchnorms is not None:
                x = self.batchnorms[i](x)
            
            # Activation
            x = self.activation(x)
            
            # Optional dropout
            if self.dropout is not None:
                x = self.dropout(x)
        
        # Final layer (no activation)
        x = self.layers[-1](x)
        return x
    
    def get_config(self) -> Dict[str, object]:
        """Get configuration dictionary.
        
        Returns:
            Configuration dict
        """
        return {
            'layer_sizes': self.layer_sizes,
            'activation': self.activation_name,
            'dropout_rate': self.dropout_rate,
            'use_batchnorm': self.use_batchnorm,
        }

# Create typed MLP
model = WellTypedMLP(
    layer_sizes=[10, 20, 15, 5],
    activation='relu',
    dropout_rate=0.3,
    use_batchnorm=True
)

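# Test. Dropout and batch norm read the training-phase flag `fit` from
# BrainState's environment, so set it explicitly for this call.
x = bst.random.randn(8, 10)
with bst.environ.context(fit=True):
    output = model(x)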

print(f"Model configuration:")
for key, value in model.get_config().items():
    print(f"  {key}: {value}")

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
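
`Dict[str, Any]` says little about what `get_config` returns. A `TypedDict` can give each key its own type while remaining an ordinary `dict` at runtime. `MLPConfig` and `make_config` below are hypothetical helpers whose keys mirror `WellTypedMLP.get_config()`.

```python
from typing import List, Optional, TypedDict


class MLPConfig(TypedDict):
    """Hypothetical per-key type for the dictionary returned by get_config()."""
    layer_sizes: List[int]
    activation: str
    dropout_rate: Optional[float]
    use_batchnorm: bool


def make_config(
    layer_sizes: List[int],
    activation: str = "relu",
    dropout_rate: Optional[float] = None,
    use_batchnorm: bool = False,
) -> MLPConfig:
    # At runtime a TypedDict is an ordinary dict; the key types exist for the checker
    return MLPConfig(
        layer_sizes=layer_sizes,
        activation=activation,
        dropout_rate=dropout_rate,
        use_batchnorm=use_batchnorm,
    )


cfg = make_config([10, 20, 15, 5], dropout_rate=0.3, use_batchnorm=True)
print(cfg["layer_sizes"], cfg["dropout_rate"])  # [10, 20, 15, 5] 0.3
```

Annotating `get_config` with `-> MLPConfig` would let a checker verify, for example, that `cfg["dropout_rate"]` is treated as `Optional[float]`.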

## 5. Type Hints for State Management

In [None]:
# Typed state management

class TypedStatefulLayer(bst.graph.Node):
    """Layer with explicitly typed states."""
    
    def __init__(self, input_dim: int, hidden_dim: int, batch_size: int = 1):
        super().__init__()
        
        # Parameters with type hints
        self.W_input: bst.ParamState = bst.ParamState(
            bst.random.randn(input_dim, hidden_dim) * 0.1
        )
        self.W_hidden: bst.ParamState = bst.ParamState(
            bst.random.randn(hidden_dim, hidden_dim) * 0.1
        )
        self.bias: bst.ParamState = bst.ParamState(
            jnp.zeros(hidden_dim)
        )
        
        # Hidden state of shape (batch_size, hidden_dim), matching the output
        # of __call__ so the state keeps a fixed shape across steps
        self.hidden: bst.ShortTermState = bst.ShortTermState(
            jnp.zeros((batch_size, hidden_dim))
        )
        
        # Metrics
        self.call_count: bst.ShortTermState = bst.ShortTermState(
            jnp.array(0, dtype=jnp.int32)
        )
    
    def reset_state(self) -> None:
        """Reset hidden state."""
        self.hidden.value = jnp.zeros_like(self.hidden.value)
        self.call_count.value = jnp.array(0, dtype=jnp.int32)
    
    def get_parameters(self) -> Dict[str, jnp.ndarray]:
        """Get parameter dictionary.
        
        Returns:
            Dict mapping parameter names to values
        """
        params: Dict[str, jnp.ndarray] = {}
        for name, state in self.states(bst.ParamState).items():
            params[name] = state.value
        return params
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass with state update.
        
        Args:
            x: Input of shape (batch, input_dim)
            
        Returns:
            Output of shape (batch, hidden_dim)
        """
        # Update count
        self.call_count.value += 1
        
        # Compute new hidden state
        h_new: jnp.ndarray = jnp.tanh(
            x @ self.W_input.value +
            self.hidden.value @ self.W_hidden.value +
            self.bias.value
        )
        
        # Update state
        self.hidden.value = h_new
        
        return h_new

# Test typed stateful layer
layer = TypedStatefulLayer(input_dim=5, hidden_dim=10)

# Process sequence
sequence = [bst.random.randn(1, 5) for _ in range(5)]

for i, x in enumerate(sequence):
    output = layer(x)
    if i == 0:
        print(f"Step {i}: output shape = {output.shape}")

print(f"\nProcessed {layer.call_count.value} steps")
print(f"Final hidden state norm: {jnp.linalg.norm(layer.hidden.value):.3f}")

# Get parameters
params = layer.get_parameters()
print(f"\nParameters: {list(params.keys())}")
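
`TypeVar` and `Generic` let you write a container whose element type is chosen by the user and then tracked by the checker. `TypedBox` below is a hypothetical sketch of a typed, state-like wrapper; it does not describe how BrainState's `State` classes are implemented.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class TypedBox(Generic[T]):
    """Hypothetical state-like container; not BrainState's State API."""

    def __init__(self, value: T) -> None:
        self._value = value

    @property
    def value(self) -> T:
        return self._value

    @value.setter
    def value(self, new_value: T) -> None:
        # Static checkers enforce T; this runtime guard also rejects a change of Python type
        if type(new_value) is not type(self._value):
            raise TypeError(
                f"expected {type(self._value).__name__}, got {type(new_value).__name__}"
            )
        self._value = new_value


counter: TypedBox[int] = TypedBox(0)
counter.value += 1
print(counter.value)  # 1
```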

## 6. Type-Safe API Design

In [None]:
# Type-safe initialization
from typing import Protocol, runtime_checkable
# DTypeLike accepts dtype objects and scalar types such as jnp.float32
from jax.typing import DTypeLike

@runtime_checkable
class Initializer(Protocol):
    """Protocol for weight initializers."""
    
    def __call__(self, shape: Tuple[int, ...], dtype: DTypeLike = jnp.float32) -> jnp.ndarray:
        """Generate initial weights."""
        ...

class XavierInitializer:
    """Xavier/Glorot uniform initialization."""
    
    def __call__(self, shape: Tuple[int, ...], dtype: DTypeLike = jnp.float32) -> jnp.ndarray:
        if len(shape) < 2:
            return jnp.zeros(shape, dtype=dtype)
        
        fan_in, fan_out = shape[0], shape[-1]
        limit = jnp.sqrt(6.0 / (fan_in + fan_out))
        return bst.random.uniform(-limit, limit, shape).astype(dtype)

class HeInitializer:
    """He (Kaiming) normal initialization for ReLU networks."""
    
    def __call__(self, shape: Tuple[int, ...], dtype: DTypeLike = jnp.float32) -> jnp.ndarray:
        if len(shape) < 2:
            return jnp.zeros(shape, dtype=dtype)
        
        fan_in = shape[0]
        std = jnp.sqrt(2.0 / fan_in)
        return (bst.random.randn(*shape) * std).astype(dtype)

class ConfigurableLayer(bst.graph.Node):
    """Layer with type-safe initialization."""
    
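    def __init__(
        self,
        in_features: int,
        out_features: int,
        initializer: Optional[Initializer] = None
    ):
        super().__init__()
        
        # Default to Xavier; creating it here avoids a shared default instance
        if initializer is None:
            initializer = XavierInitializer()
        
        # runtime_checkable only checks that __call__ exists; it does not
        # verify the argument or return types
        if not isinstance(initializer, Initializer):
            raise TypeError("initializer must implement the Initializer protocol")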
        
        # Initialize weights
        self.weight: bst.ParamState = bst.ParamState(
            initializer((in_features, out_features))
        )
        self.bias: bst.ParamState = bst.ParamState(
            jnp.zeros(out_features)
        )
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return x @ self.weight.value + self.bias.value

# Test different initializers
print("Initializer Comparison:")
print("=" * 60)

x = bst.random.randn(10, 20)

# Xavier initialization
layer_xavier = ConfigurableLayer(20, 10, initializer=XavierInitializer())
out_xavier = layer_xavier(x)
print(f"Xavier init - weight std: {jnp.std(layer_xavier.weight.value):.4f}")

# He initialization
layer_he = ConfigurableLayer(20, 10, initializer=HeInitializer())
out_he = layer_he(x)
print(f"He init - weight std: {jnp.std(layer_he.weight.value):.4f}")
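
A caveat worth seeing directly: `isinstance` against a `@runtime_checkable` protocol only checks that the required members exist, not their signatures. In the pure-Python sketch below, the hypothetical `ShapeInitializer` and `ZerosInit` stand in for the tutorial's classes, and a function with the wrong signature still passes, so static checking remains the main safeguard.

```python
import math
from typing import List, Protocol, Tuple, runtime_checkable


@runtime_checkable
class ShapeInitializer(Protocol):
    """Hypothetical protocol: build a flat list of initial values for a shape."""

    def __call__(self, shape: Tuple[int, ...]) -> List[float]: ...


class ZerosInit:
    def __call__(self, shape: Tuple[int, ...]) -> List[float]:
        # math.prod gives the number of elements in the shape
        return [0.0] * math.prod(shape)


def not_an_initializer(text: str) -> str:
    return text.upper()


print(isinstance(ZerosInit(), ShapeInitializer))         # True
print(isinstance(not_an_initializer, ShapeInitializer))  # True: only __call__ is checked
print(isinstance(42, ShapeInitializer))                  # False: ints are not callable
```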

## 7. Type Checking and Validation

In [None]:
# Runtime type checking
def check_shape(x: jnp.ndarray, expected_shape: Tuple[int, ...], name: str = "input") -> None:
    """Validate array shape.
    
    Args:
        x: Array to check
        expected_shape: Expected shape (use -1 for any size)
        name: Name for error messages
        
    Raises:
        ValueError: If shape doesn't match
    """
    if len(x.shape) != len(expected_shape):
        raise ValueError(
            f"{name} has {len(x.shape)} dimensions, expected {len(expected_shape)}"
        )
    
    for i, (actual, expected) in enumerate(zip(x.shape, expected_shape)):
        if expected != -1 and actual != expected:
            raise ValueError(
                f"{name} dimension {i} is {actual}, expected {expected}"
            )

class ValidatedConv(bst.graph.Node):
    """Convolution with strict input validation."""
    
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        
        self.kernel = bst.ParamState(
            bst.random.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.1
        )
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply convolution.
        
        Args:
            x: Input of shape (batch, height, width, in_channels)
            
        Returns:
            Output of shape (batch, height, width, out_channels)
        """
        # Validate input shape
        check_shape(x, (-1, -1, -1, self.in_channels), "input")
        
        # Apply a 'SAME'-padded, stride-1 convolution. The kernel is stored as
        # (out_channels, in_channels, kernel_h, kernel_w), i.e. 'OIHW'.
        output = jax.lax.conv_general_dilated(
            x,
            self.kernel.value,
            window_strides=(1, 1),
            padding='SAME',
            dimension_numbers=('NHWC', 'OIHW', 'NHWC'),
        )
        
        return output

# Test validation
conv = ValidatedConv(in_channels=3, out_channels=16, kernel_size=3)

# Valid input
x_valid = bst.random.randn(2, 32, 32, 3)
try:
    out = conv(x_valid)
    print(f"Valid input: output shape = {out.shape}")
except ValueError as e:
    print(f"Error: {e}")

# Invalid input (wrong channels)
x_invalid = bst.random.randn(2, 32, 32, 5)
try:
    out = conv(x_invalid)
    print(f"Invalid input: output shape = {out.shape}")
except ValueError as e:
    print(f"Expected error: {e}")

## 8. Best Practices Summary

In [None]:
# Best practices example
from typing import Any, Dict, List, Optional, Tuple, Union, Callable

class BestPracticeModel(bst.graph.Node):
    """Model demonstrating typing best practices.
    
    This class shows:
    - Complete type annotations
    - Docstrings with type information
    - Runtime validation
    - Clear return types
    """
    
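    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        activation: Union[str, Callable[[jnp.ndarray], jnp.ndarray]] = 'relu',
        dropout_rate: Optional[float] = None
    ) -> None:
        """Initialize model.
        
        Args:
            input_dim: Input feature dimension
            hidden_dims: List of hidden layer dimensions
            output_dim: Output dimension
            activation: Activation function (name or callable)
            dropout_rate: Dropout rate (None to disable)
        """
        super().__init__()
        
        # Store config
        self.input_dim: int = input_dim
        self.hidden_dims: List[int] = hidden_dims
        self.output_dim: int = output_dim
        
        # Resolve the activation (applied after every hidden layer)
        if isinstance(activation, str):
            activations: Dict[str, Callable[[jnp.ndarray], jnp.ndarray]] = {
                'relu': jax.nn.relu,
                'tanh': jnp.tanh,
                'sigmoid': jax.nn.sigmoid,
            }
            if activation not in activations:
                raise ValueError(f"Unknown activation {activation!r}")
            activation = activations[activation]
        self.activation: Callable[[jnp.ndarray], jnp.ndarray] = activation
        
        # Optional dropout after every hidden activation
        self.dropout: Optional[bst.nn.Dropout] = None
        if dropout_rate is not None:
            self.dropout = bst.nn.Dropout(dropout_rate)
        
        # Build architecture
        all_dims: List[int] = [input_dim] + hidden_dims + [output_dim]
        self.layers: List[bst.nn.Linear] = []
        
        for i, (in_d, out_d) in enumerate(zip(all_dims[:-1], all_dims[1:])):
            layer: bst.nn.Linear = bst.nn.Linear(in_d, out_d)
            self.layers.append(layer)
            setattr(self, f'layer_{i}', layer)
    
    def forward(self, x: jnp.ndarray) -> jnp.ndarray:
        """Forward pass.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            
        Returns:
            Output tensor of shape (batch_size, output_dim)
            
        Raises:
            ValueError: If input shape is invalid
        """
        # Validate
        if x.shape[-1] != self.input_dim:
            raise ValueError(
                f"Expected input dim {self.input_dim}, got {x.shape[-1]}"
            )
        
        # Hidden layers: linear -> activation (-> dropout)
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
            if self.dropout is not None:
                x = self.dropout(x)
        
        # Output layer (no activation)
        return self.layers[-1](x)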
    
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.forward(x)
    
    def summary(self) -> Dict[str, Any]:
        """Get model summary.
        
        Returns:
            Dictionary with model information
        """
        n_params: int = sum(
            p.value.size for p in self.states(bst.ParamState).values()
        )
        
        summary: Dict[str, Any] = {
            'input_dim': self.input_dim,
            'hidden_dims': self.hidden_dims,
            'output_dim': self.output_dim,
            'n_layers': len(self.layers),
            'n_parameters': n_params,
        }
        
        return summary

# Create and test
model: BestPracticeModel = BestPracticeModel(
    input_dim=10,
    hidden_dims=[20, 15],
    output_dim=5
)

summary: Dict[str, Any] = model.summary()
print("Model Summary:")
print("=" * 60)
for key, value in summary.items():
    print(f"  {key}: {value}")
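
To sanity-check the `n_parameters` entry, you can count the parameters of a dense stack by hand. The helper below is a hypothetical sketch that assumes each `bst.nn.Linear` layer holds one weight matrix and one bias vector (check your BrainState version's defaults); under that assumption, the model above would report 615 parameters.

```python
from typing import List


def count_dense_params(dims: List[int], use_bias: bool = True) -> int:
    """Count parameters of a stack of dense layers: in*out weights plus out biases each."""
    return sum(
        n_in * n_out + (n_out if use_bias else 0)
        for n_in, n_out in zip(dims[:-1], dims[1:])
    )


# [input_dim] + hidden_dims + [output_dim] for the model above
print(count_dense_params([10, 20, 15, 5]))  # 220 + 315 + 80 = 615
```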

## Summary

In this tutorial, we covered:

1. **Basic Type Annotations**: Using standard Python typing
2. **Union and Generic Types**: Union, Optional, Callable, Tuple, List, Dict
3. **OneOfTypes**: BrainState's union-like types
4. **JointTypes**: Composite type requirements
5. **Neural Network Types**: Comprehensive typing for models
6. **State Management Types**: Typed state handling
7. **Type-Safe APIs**: Protocols and runtime checking
8. **Validation**: Runtime type and shape validation
9. **Best Practices**: Complete typing guidelines

## Key Takeaways

- **Type annotations improve code quality** and, with a static checker, catch bugs early
- Use **Optional** for nullable values
- Use **Union** for multiple type options
- **Document types** in docstrings
- **Validate inputs** at runtime when needed
- **Protocol** enables structural typing
- Types make **refactoring safer**

## Best Practices

1. Always annotate function signatures
2. Use Optional for nullable parameters
3. Document types in docstrings
4. Validate critical inputs at runtime
5. Use Protocol for interface requirements
6. Keep type annotations simple and readable
7. Use typing for self-documentation

## Next Steps

In the next tutorial, we'll explore:
- **Utility Functions**: filter, struct, PrettyObject, DictManager
- Helper functions for common tasks
- Pretty printing and visualization
- Dictionary and structure management