# Tutorial 14: Computation Graph and Node System

In this tutorial, we'll explore BrainState's computation graph system based on the `graph.Node` class, which provides the foundation for building complex neural network architectures.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Understand the `graph.Node` base class and its role
- Build computation graphs with nested modules
- Manage state in graph structures
- Traverse and inspect computation graphs
- Analyze graph statistics such as node counts, depth, and parameter totals
- Create custom graph nodes and operators
- Implement advanced architectures using the graph system

## What is the Graph System?

BrainState's graph system provides:
- **Hierarchical structure**: Organize modules in tree-like graphs
- **State management**: Automatic tracking of parameters and state
- **Composability**: Combine simple nodes into complex networks
- **Introspection**: Query and analyze model structure
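To make these four ideas concrete without relying on BrainState internals, here is a toy, library-free sketch. `ToyState`, `ToyParam`, `ToyHidden`, `ToyNode` and `collect_states` are hypothetical names invented for illustration. BrainState's own traversal is more complete, but the core idea is the same one behind the `states()` calls used below: walk the node hierarchy, key each state by its attribute path, and filter by state type.

```python
from typing import Dict, Tuple, Type


class ToyState:
    """Hypothetical stand-in for a state container holding one value."""

    def __init__(self, value):
        self.value = value


class ToyParam(ToyState):
    """Trainable parameter (plays the role of ParamState)."""


class ToyHidden(ToyState):
    """Non-trainable hidden state (plays the role of ShortTermState)."""


class ToyNode:
    """Hypothetical node: children and states are stored as plain attributes."""


def collect_states(node: ToyNode, state_type: Type[ToyState] = ToyState,
                   path: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], ToyState]:
    """Walk attributes recursively and return {path: state} for matching states."""
    found: Dict[Tuple[str, ...], ToyState] = {}
    for name, attr in vars(node).items():
        if isinstance(attr, ToyNode):
            # Recurse into a child node, extending the path with its attribute name
            found.update(collect_states(attr, state_type, path + (name,)))
        elif isinstance(attr, state_type):
            found[path + (name,)] = attr
    return found


# Build a two-level hierarchy
toy_layer = ToyNode()
toy_layer.weight = ToyParam([1.0, 2.0])
toy_layer.bias = ToyParam([0.0, 0.0])
toy_layer.activation = ToyHidden([0.0, 0.0])

toy_model = ToyNode()
toy_model.encoder = toy_layer
toy_model.step = ToyHidden(0)

toy_params = collect_states(toy_model, ToyParam)
toy_hidden = collect_states(toy_model, ToyHidden)
print(sorted(toy_params))  # [('encoder', 'bias'), ('encoder', 'weight')]
print(sorted(toy_hidden))  # [('encoder', 'activation'), ('step',)]
```

Filtering by a base class (`ToyState`) returns everything, while a subclass returns only that kind of state. This mirrors how the type filters are used in the rest of this tutorial.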

In [1]:
import brainstate
import jax
import jax.numpy as jnp

# Set random seed
brainstate.random.seed(42)

## 1. The graph.Node Base Class

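All BrainState modules inherit from `graph.Node`, which provides the core machinery of the computation graph. Convenience methods such as `states()`, which this tutorial uses throughout, are defined on `brainstate.nn.Module`, a subclass of `graph.Node`. For that reason, the examples below subclass `brainstate.nn.Module`.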

In [2]:
# Basic Node example
class SimpleNode(brainstate.nn.Module):  # nn.Module is a graph.Node that also provides states()
    """A simple node with parameters and state."""
    
    def __init__(self, size: int):
        super().__init__()
        # Parameters (trainable)
        self.weight = brainstate.ParamState(brainstate.random.randn(size))
        self.bias = brainstate.ParamState(jnp.zeros(size))
        
        # Hidden state (non-trainable, changes during computation)
        self.activation = brainstate.ShortTermState(jnp.zeros(size))
        
        # Regular attributes
        self.size = size
    
    def __call__(self, x):
        # Compute output
        output = x * self.weight.value + self.bias.value
        # Update hidden state
        self.activation.value = output
        return output

# Create and use
node = SimpleNode(size=5)
x = jnp.ones(5)
output = node(x)

print(f"Output: {output}")
print(f"\nNode attributes:")
print(f"  weight: {node.weight.value.shape}")
print(f"  bias: {node.bias.value.shape}")
print(f"  activation: {node.activation.value.shape}")
print(f"  size: {node.size}")

Output: [ 0.60576403  0.7990441  -0.908927   -0.63525754 -1.2226585 ]

Node attributes:
  weight: (5,)
  bias: (5,)
  activation: (5,)
  size: 5


### Node Features

In [3]:
# Explore Node capabilities
print("=" * 60)
print("Node State Management")
print("=" * 60)

# Get all states
all_states = node.states()
print(f"\nAll states: {list(all_states.keys())}")

# Get only parameters
params = node.states(brainstate.ParamState)
print(f"Parameters: {list(params.keys())}")

# Get only hidden states
hidden = node.states(brainstate.ShortTermState)
print(f"Hidden states: {list(hidden.keys())}")

# Count parameters
total_params = sum(s.value.size for s in params.values())
print(f"\nTotal parameters: {total_params}")


## 2. Building Hierarchical Graphs

Nodes can contain other nodes, creating a hierarchical computation graph.

In [None]:
# Hierarchical model
class Layer(brainstate.nn.Module):
    """A single layer."""
    
    def __init__(self, in_features, out_features, name=None):
        super().__init__(name=name)
        self.linear = brainstate.nn.Linear(in_features, out_features)
    
    def __call__(self, x):
        return jnp.tanh(self.linear(x))

class MLP(brainstate.nn.Module):
    """Multi-layer perceptron with hierarchical structure."""
    
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = []
        
        for i in range(len(layer_sizes) - 1):
            layer = Layer(
                layer_sizes[i], 
                layer_sizes[i + 1],
                name=f'layer_{i}'
            )
            self.layers.append(layer)
            # Register as attribute for proper graph construction
            setattr(self, f'layer_{i}', layer)
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Create hierarchical model
model = MLP([10, 20, 15, 5])
x = brainstate.random.randn(3, 10)
output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nModel structure:")
print(f"  Number of layers: {len(model.layers)}")

### Graph Traversal

In [None]:
# Traverse the computation graph
def print_graph_structure(node, prefix="", is_last=True):
    """Recursively print graph structure."""
    marker = "└── " if is_last else "├── "
    node_name = node.__class__.__name__
    
    # Get node identifier
    if hasattr(node, 'name') and node.name:
        node_id = f"{node_name}(name='{node.name}')"
    else:
        node_id = node_name
    
    print(f"{prefix}{marker}{node_id}")
    
    # Get child nodes
    children = []
    for attr_name in dir(node):
        if attr_name.startswith('_'):
            continue
        try:
            attr = getattr(node, attr_name)
            if isinstance(attr, brainstate.graph.Node):
                children.append((attr_name, attr))
        except Exception:  # some attributes (e.g. properties) may raise on access
            pass
    
    # Print children
    extension = "    " if is_last else "│   "
    for i, (name, child) in enumerate(children):
        is_last_child = (i == len(children) - 1)
        print_graph_structure(child, prefix + extension, is_last_child)

print("Graph Structure:")
print("=" * 60)
print_graph_structure(model)
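`MLP` (and `ResNet` below) keep each child both in a Python list and as a `layer_i`/`block_i` attribute, so a walker that follows every reference can reach the same object through two routes. A common guard is to remember visited objects by `id()`. The sketch below shows this on plain Python objects. `Box` and `walk_unique` are hypothetical names, not part of BrainState.

```python
from typing import List


class Box:
    """Hypothetical container node with named children."""

    def __init__(self, label: str):
        self.label = label


def walk_unique(root) -> List[str]:
    """Depth-first walk that records each Box once, even if it is referenced twice."""
    seen = set()
    order: List[str] = []

    def visit(obj):
        if id(obj) in seen:
            return  # already reached through another attribute or list
        seen.add(id(obj))
        order.append(obj.label)
        for value in vars(obj).values():
            # Treat lists as containers of potential children
            children = value if isinstance(value, list) else [value]
            for child in children:
                if isinstance(child, Box):
                    visit(child)

    visit(root)
    return order


net = Box("net")
child = Box("layer_0")
net.layers = [child]   # reference 1: inside a list
net.layer_0 = child    # reference 2: as a named attribute
print(walk_unique(net))  # ['net', 'layer_0']
```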

## 3. State Collection and Management

In [None]:
# Collect states from hierarchical model
print("State Collection:")
print("=" * 60)

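# All parameter states in the model, keyed by their path in the graph
all_params = model.states(brainstate.ParamState)
print(f"\nTotal parameter states: {len(all_params)}")
print("\nParameter details:")
for name, param in list(all_params.items())[:5]:  # Show first 5
    # A state's value may be one array or a dict of arrays (e.g. a Linear's weight and bias)
    shapes = jax.tree.map(jnp.shape, param.value)
    size = sum(leaf.size for leaf in jax.tree.leaves(param.value))
    print(f"  {name}: shapes={shapes}, size={size}")

# Calculate total parameters by counting every array inside every state
total_params = sum(leaf.size for p in all_params.values() for leaf in jax.tree.leaves(p.value))
print(f"\nTotal parameters: {total_params:,}")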
print(f"Memory (float32): {total_params * 4 / 1024:.2f} KB")
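As a sanity check, the total can be worked out by hand. This assumes each `brainstate.nn.Linear(in, out)` holds an `in × out` weight matrix plus an `out`-length bias. `linear_params` and `mlp_params` are hypothetical helpers.

```python
def linear_params(n_in: int, n_out: int, bias: bool = True) -> int:
    """Number of scalars in a dense layer: weight matrix plus optional bias."""
    return n_in * n_out + (n_out if bias else 0)


def mlp_params(layer_sizes) -> int:
    """Sum dense-layer counts over consecutive (in, out) pairs."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))


total = mlp_params([10, 20, 15, 5])  # 220 + 315 + 80
kib = total * 4 / 1024               # 4 bytes per float32 scalar
print(total, f"{kib:.2f} KB")        # 615 2.40 KB
```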

### State Manipulation

In [None]:
# Save and restore states
class StatefulModel(brainstate.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = brainstate.nn.Linear(5, 3)
        self.counter = brainstate.ShortTermState(jnp.array(0))
    
    def __call__(self, x):
        self.counter.value += 1
        return self.linear(x)

model = StatefulModel()

# Use the model
x = brainstate.random.randn(2, 5)
y1 = model(x)
y2 = model(x)
print(f"Counter after 2 calls: {model.counter.value}")

# Save state
saved_states = {}
for name, state in model.states().items():
    saved_states[name] = state.value.copy()

print(f"\nSaved {len(saved_states)} states")

# Continue using
for _ in range(5):
    model(x)
print(f"Counter after 7 total calls: {model.counter.value}")

# Restore state: write the saved values back into the existing State objects
current_states = model.states()
for name, value in saved_states.items():
    current_states[name].value = value

print(f"Counter after restoration: {model.counter.value}")
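The save/restore pattern above depends on one property: a snapshot must not change when the live state changes. JAX arrays are immutable, so holding a reference to one is already safe. Mutable Python containers are not safe in the same way. The toy below uses a hypothetical `Slot` class in place of a BrainState state. It copies values when taking the snapshot, and it restores by writing into the same holder objects, so any other code that holds those objects sees the restored values.

```python
import copy


class Slot:
    """Hypothetical mutable state holder (stands in for a State object)."""

    def __init__(self, value):
        self.value = value


live = {("counter",): Slot(2), ("buffer",): Slot([0.0, 0.0])}

# Snapshot: deep-copy each value so later in-place edits cannot leak into it
snapshot = {path: copy.deepcopy(slot.value) for path, slot in live.items()}

# Keep using the "model": rebind one value, mutate another in place
live[("counter",)].value = 7
live[("buffer",)].value.append(1.0)

# Restore by writing into the existing Slot objects (their identity is preserved)
buffer_slot = live[("buffer",)]
for path, value in snapshot.items():
    live[path].value = copy.deepcopy(value)  # copy again so the snapshot stays reusable

print(live[("counter",)].value, live[("buffer",)].value)  # 2 [0.0, 0.0]
print(live[("buffer",)] is buffer_slot)                   # True
```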

## 4. Advanced Graph Patterns

### Skip Connections (ResNet-style)

In [None]:
class ResidualBlock(brainstate.nn.Module):
    """Residual block with skip connection."""
    
    def __init__(self, dim):
        super().__init__()
        self.linear1 = brainstate.nn.Linear(dim, dim)
        self.linear2 = brainstate.nn.Linear(dim, dim)
    
    def __call__(self, x):
        # Main path
        residual = x
        x = jax.nn.relu(self.linear1(x))
        x = self.linear2(x)
        
        # Skip connection
        return jax.nn.relu(x + residual)

class ResNet(brainstate.nn.Module):
    """Simple ResNet with multiple residual blocks."""
    
    def __init__(self, dim, n_blocks):
        super().__init__()
        self.blocks = [ResidualBlock(dim) for _ in range(n_blocks)]
        # Register blocks
        for i, block in enumerate(self.blocks):
            setattr(self, f'block_{i}', block)
    
    def __call__(self, x):
        for block in self.blocks:
            x = block(x)
        return x

# Test ResNet
resnet = ResNet(dim=10, n_blocks=3)
x = brainstate.random.randn(5, 10)
output = resnet(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nResNet structure:")
print_graph_structure(resnet)
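One reason the skip connection helps: if the residual branch contributes nothing (for example, when its second layer's weights and bias are all zero), the block reduces to `relu(x)`. The extra depth therefore does not make a near-identity mapping harder to represent. Below is a pure-Python sketch of `relu(relu(x @ W1 + b1) @ W2 + b2 + x)` on a 2-dimensional input. The matrices are made-up illustration values, and `dense` follows the row-vector `x @ W + b` convention with `W` shaped `(in, out)`.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


def dense(x: Vector, w: Matrix, b: Vector) -> Vector:
    """Row-vector convention x @ W + b, with W shaped (in, out)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def residual_block(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    h = relu(dense(x, w1, b1))                        # main path, first layer
    f = dense(h, w2, b2)                              # main path, second layer
    return relu([fi + xi for fi, xi in zip(f, x)])    # add the skip, then activate


x = [1.5, -2.0]
w1 = [[0.5, -1.0], [2.0, 0.25]]
b1 = [0.1, 0.0]
zeros_w = [[0.0, 0.0], [0.0, 0.0]]
zeros_b = [0.0, 0.0]

print(residual_block(x, w1, b1, zeros_w, zeros_b))  # [1.5, 0.0], i.e. relu(x)
```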

### Multi-Path Networks (Inception-style)

In [None]:
class InceptionBlock(brainstate.nn.Module):
    """Inception-style block with multiple parallel paths."""
    
    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Every path is a stand-in built from Linear layers (no real convolution or pooling)
        # Path 1: plays the role of a 1x1 conv
        self.path1 = brainstate.nn.Linear(in_dim, out_dim // 4)
        
        # Path 2: plays the role of 1x1 -> 3x3
        self.path2_a = brainstate.nn.Linear(in_dim, out_dim // 4)
        self.path2_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)
        
        # Path 3: plays the role of 1x1 -> 5x5
        self.path3_a = brainstate.nn.Linear(in_dim, out_dim // 4)
        self.path3_b = brainstate.nn.Linear(out_dim // 4, out_dim // 4)
        
        # Path 4: plays the role of pool -> 1x1 (the pooling step is omitted here)
        self.path4 = brainstate.nn.Linear(in_dim, out_dim // 4)
    
    def __call__(self, x):
        # Execute all paths
        out1 = jax.nn.relu(self.path1(x))
        
        out2 = jax.nn.relu(self.path2_a(x))
        out2 = jax.nn.relu(self.path2_b(out2))
        
        out3 = jax.nn.relu(self.path3_a(x))
        out3 = jax.nn.relu(self.path3_b(out3))
        
        out4 = jax.nn.relu(self.path4(x))
        
        # Concatenate outputs
        return jnp.concatenate([out1, out2, out3, out4], axis=-1)

# Test Inception block
inception = InceptionBlock(in_dim=16, out_dim=32)
x = brainstate.random.randn(4, 16)
output = inception(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
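# Each ParamState may hold several arrays (e.g. weight and bias), so count the array leaves
n_param_arrays = sum(
    len(jax.tree.leaves(p.value)) for p in inception.states(brainstate.ParamState).values()
)
print(f"\nInception block has {n_param_arrays} parameter arrays")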
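Each of the four paths emits `out_dim // 4` features, so the concatenated width is `4 * (out_dim // 4)`. That equals `out_dim` only when `out_dim` is divisible by 4. Otherwise the block silently produces fewer features than requested. A quick check, using a hypothetical `inception_width` helper:

```python
def inception_width(out_dim: int, n_paths: int = 4) -> int:
    """Concatenated feature width when every path emits out_dim // n_paths features."""
    return n_paths * (out_dim // n_paths)


for out_dim in (32, 30, 6):
    print(out_dim, "->", inception_width(out_dim))
# 32 -> 32
# 30 -> 28
# 6 -> 4
```

If exact widths matter, asserting `out_dim % 4 == 0` in `__init__` is a cheap guard.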

## 5. Dynamic Graphs and Conditional Execution

In [None]:
class AdaptiveDepthNetwork(brainstate.nn.Module):
    """Network that adapts its depth based on input."""
    
    def __init__(self, dim, max_depth=5):
        super().__init__()
        self.dim = dim
        self.max_depth = max_depth
        
        # Create all possible layers
        self.layers = []
        for i in range(max_depth):
            layer = brainstate.nn.Linear(dim, dim)
            self.layers.append(layer)
            setattr(self, f'layer_{i}', layer)
        
        # Confidence predictor
        self.confidence = brainstate.nn.Linear(dim, 1)
    
    def __call__(self, x, confidence_threshold=0.9):
        """Process input, stopping early if confident."""
        depth_used = 0
        
        for i, layer in enumerate(self.layers):
            x = jax.nn.relu(layer(x))
            depth_used = i + 1
            
            # Check confidence (simplified)
            conf = jax.nn.sigmoid(self.confidence(x))
            
            # Python `break` only works eagerly; under jax.jit an early exit needs jax.lax.while_loop
            if jnp.mean(conf) > confidence_threshold:
                break
        
        return x, depth_used

# Test adaptive network
adaptive_net = AdaptiveDepthNetwork(dim=8, max_depth=5)

# Weights are untrained, so the exit depths below are essentially arbitrary.
# Small, uniform input checked against a lenient threshold
x_easy = jnp.ones((3, 8)) * 0.1
out_easy, depth_easy = adaptive_net(x_easy, confidence_threshold=0.5)

# Larger random input checked against a strict threshold
x_hard = brainstate.random.randn(3, 8) * 2
out_hard, depth_hard = adaptive_net(x_hard, confidence_threshold=0.9)

print(f"Easy input used {depth_easy} layers")
print(f"Hard input used {depth_hard} layers")
print("\nLayers after an early exit are never evaluated; that is where the savings come from.")
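Under `jax.jit`, a data-dependent early exit is usually written with `jax.lax.while_loop(cond_fun, body_fun, init_val)`, which carries all loop state explicitly. The sketch below restates the adaptive-depth logic in that carry form. It uses a pure-Python `while_loop` with the reference semantics documented for JAX's version, so it runs without JAX. The scalar "layers" and the sigmoid confidence are made-up stand-ins. With real JAX, the `and` in `cond_fun` would become `jnp.logical_and`, and the per-depth layer lookup would need a traceable representation, such as stacked weights indexed by depth.

```python
import math
from typing import Callable, List, Tuple

Carry = Tuple[float, int, float]  # (activation, depth_used, confidence)


def while_loop(cond_fun: Callable[[Carry], bool],
               body_fun: Callable[[Carry], Carry],
               init_val: Carry) -> Carry:
    """Pure-Python reference semantics of jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def adaptive_forward(x: float, scales: List[float], threshold: float) -> Tuple[float, int]:
    max_depth = len(scales)

    def cond_fun(carry: Carry) -> bool:
        _, depth, conf = carry
        # Keep going while layers remain and we are not yet confident
        return depth < max_depth and conf <= threshold

    def body_fun(carry: Carry) -> Carry:
        h, depth, _ = carry
        h = max(0.0, scales[depth] * h)  # toy "layer": scale, then ReLU
        conf = sigmoid(h)                # toy confidence head
        return h, depth + 1, conf

    h, depth, _ = while_loop(cond_fun, body_fun, (x, 0, 0.0))
    return h, depth


print(adaptive_forward(1.0, [2.0, 2.0, 2.0], threshold=0.95))  # (4.0, 2): exits after 2 layers
```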

## 6. Graph Analysis

In [None]:
# Analyze graph structure
def analyze_graph(node):
    """Collect node-level statistics for a computation graph."""
    stats = {
        'total_nodes': 0,
        'total_params': 0,
        'param_tensors': 0,
        'node_types': {},
        'depth': 0
    }
    
    # Count parameters once, at the root: states() on a node already includes
    # every state beneath it, so summing it per node would double-count.
    params = node.states(brainstate.ParamState)
    leaves = [leaf for p in params.values() for leaf in jax.tree.leaves(p.value)]
    stats['param_tensors'] = len(leaves)
    stats['total_params'] = sum(leaf.size for leaf in leaves)
    
    seen = set()  # ids of visited nodes (a child may be reachable through two attributes)
    
    def traverse(n, depth=0):
        if id(n) in seen:
            return
        seen.add(id(n))
        stats['total_nodes'] += 1
        stats['depth'] = max(stats['depth'], depth)
        
        # Count node type
        node_type = n.__class__.__name__
        stats['node_types'][node_type] = stats['node_types'].get(node_type, 0) + 1
        
        # Traverse children
        for attr_name in dir(n):
            if attr_name.startswith('_'):
                continue
            try:
                attr = getattr(n, attr_name)
            except Exception:  # some attributes may raise on access
                continue
            if isinstance(attr, brainstate.graph.Node):
                traverse(attr, depth + 1)
    
    traverse(node)
    return stats

# Analyze different models
models = {
    'MLP': MLP([10, 20, 15, 5]),
    'ResNet': ResNet(dim=10, n_blocks=3),
    'Inception': InceptionBlock(in_dim=16, out_dim=32)
}

print("Graph Analysis:")
print("=" * 80)
print(f"{'Model':<15} {'Nodes':<8} {'Depth':<8} {'Param Tensors':<15} {'Total Params':<15}")
print("-" * 80)

for name, model in models.items():
    stats = analyze_graph(model)
    print(f"{name:<15} {stats['total_nodes']:<8} {stats['depth']:<8} "
          f"{stats['param_tensors']:<15} {stats['total_params']:<15,}")

print("\n" + "=" * 80)

# Detailed analysis of one model
print("\nDetailed Analysis of ResNet:")
resnet_stats = analyze_graph(models['ResNet'])
print(f"  Node types: {resnet_stats['node_types']}")
print(f"  Memory footprint (float32): {resnet_stats['total_params'] * 4 / 1024:.2f} KB")
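Parameter totals in a graph walk should be taken once, at the root. A `states()` query on a node already covers everything beneath it; that is how the section 3 cell collects every parameter of `MLP` from the root alone. Summing each node's recursive total during a walk therefore counts a parameter once per ancestor. The toy tree below makes the inflation visible. It uses nested dicts with hypothetical `own` (direct parameter count) and `children` keys.

```python
# Toy tree: each node lists its direct parameter count and its children
leaf_a = {"own": 110, "children": []}
leaf_b = {"own": 110, "children": []}
block = {"own": 0, "children": [leaf_a, leaf_b]}
root = {"own": 0, "children": [block]}


def subtree_total(node) -> int:
    """Recursive total, like querying every parameter under a node."""
    return node["own"] + sum(subtree_total(c) for c in node["children"])


def all_nodes(node):
    """List the node and all of its descendants."""
    return [node] + [n for c in node["children"] for n in all_nodes(c)]


correct = subtree_total(root)                               # counted once, at the root
inflated = sum(subtree_total(n) for n in all_nodes(root))   # counted at every node
print(correct, inflated)  # 220 660
```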

## 7. Custom Graph Operations

In [None]:
# Parameter initialization across graph
def initialize_graph(node, init_fn):
    """Re-initialize every parameter array in the graph with init_fn(shape)."""
    params = node.states(brainstate.ParamState)
    for param in params.values():
        # A state may hold one array or a dict of arrays (e.g. a Linear's weight and bias)
        param.value = jax.tree.map(lambda leaf: init_fn(leaf.shape), param.value)
    return node

# Test initialization
model = MLP([5, 10, 5])

# Custom initialization: Xavier/Glorot
def xavier_init(shape):
    if len(shape) == 1:
        return jnp.zeros(shape)
    fan_in, fan_out = shape[0], shape[-1]
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))
    return brainstate.random.uniform(-limit, limit, shape)

# Before initialization: inspect the first 2-D weight matrix in the graph
params_before = model.states(brainstate.ParamState)
sample_param_before = next(
    leaf for p in params_before.values()
    for leaf in jax.tree.leaves(p.value) if leaf.ndim == 2
)
print("Before initialization (default init):")
print(f"  Sample mean: {jnp.mean(sample_param_before):.4f}")
print(f"  Sample std: {jnp.std(sample_param_before):.4f}")

# Initialize
initialize_graph(model, xavier_init)

# After initialization: same inspection on the new values
params_after = model.states(brainstate.ParamState)
sample_param_after = next(
    leaf for p in params_after.values()
    for leaf in jax.tree.leaves(p.value) if leaf.ndim == 2
)
print("\nAfter Xavier initialization:")
print(f"  Sample mean: {jnp.mean(sample_param_after):.4f}")
print(f"  Sample std: {jnp.std(sample_param_after):.4f}")
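The Glorot/Xavier uniform rule draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). The standard deviation of U(-a, a) is a / sqrt(3), so the target std is sqrt(2 / (fan_in + fan_out)). That gives a reference value for the "Sample std" printed after initialization, up to sampling noise. The two weight matrices of `MLP([5, 10, 5])` are assumed here to be 5 × 10 and 10 × 5, so fan_in + fan_out = 15 for both. The helper names below are hypothetical.

```python
import math


def glorot_uniform_limit(fan_in: int, fan_out: int) -> float:
    """Half-width a of U(-a, a) under the Glorot/Xavier uniform rule."""
    return math.sqrt(6.0 / (fan_in + fan_out))


def uniform_std(limit: float) -> float:
    """Standard deviation of U(-limit, limit)."""
    return limit / math.sqrt(3.0)


a = glorot_uniform_limit(5, 10)
print(f"limit={a:.4f}, expected std={uniform_std(a):.4f}")  # limit=0.6325, expected std=0.3651
```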

### Freeze/Unfreeze Parameters

In [None]:
# Parameter freezing for transfer learning
class FreezableModel(brainstate.nn.Module):
    """Model with freezable layers."""
    
    def __init__(self):
        super().__init__()
        self.feature_extractor = brainstate.nn.Linear(10, 20, name='features')
        self.classifier = brainstate.nn.Linear(20, 5, name='classifier')
        self.frozen_params = set()
    
    def freeze_features(self):
        """Freeze feature extractor parameters."""
        feature_params = self.feature_extractor.states(brainstate.ParamState)
        for name in feature_params.keys():
            self.frozen_params.add(id(feature_params[name]))
    
    def get_trainable_params(self):
        """Get only non-frozen parameters."""
        all_params = self.states(brainstate.ParamState)
        return {k: v for k, v in all_params.items() 
                if id(v) not in self.frozen_params}
    
    def __call__(self, x):
        x = jax.nn.relu(self.feature_extractor(x))
        return self.classifier(x)

# Test freezing
model = FreezableModel()

print("Before freezing:")
print(f"  Total params: {len(model.states(brainstate.ParamState))}")
print(f"  Trainable params: {len(model.get_trainable_params())}")

# Freeze features
model.freeze_features()

print("\nAfter freezing features:")
print(f"  Total params: {len(model.states(brainstate.ParamState))}")
print(f"  Trainable params: {len(model.get_trainable_params())}")
print(f"  Trainable param names: {list(model.get_trainable_params().keys())}")
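Collecting the trainable parameters is only half of freezing: the update step must also leave the frozen ones untouched. Below is a minimal pure-Python sketch of an SGD step that skips any state whose `id()` is in a frozen set, mirroring the bookkeeping in `FreezableModel`. `Param`, `sgd_step` and the numbers are hypothetical, and real code would update arrays rather than floats.

```python
from typing import Dict, Set, Tuple


class Param:
    """Hypothetical parameter holder with a single float value."""

    def __init__(self, value: float):
        self.value = value


toy_params: Dict[Tuple[str, ...], Param] = {
    ("features", "weight"): Param(1.0),
    ("classifier", "weight"): Param(1.0),
}
frozen_ids: Set[int] = {id(toy_params[("features", "weight")])}


def sgd_step(params: Dict[Tuple[str, ...], Param], grads: Dict[Tuple[str, ...], float],
             lr: float, frozen: Set[int]) -> None:
    """In-place SGD update that skips frozen parameters."""
    for path, p in params.items():
        if id(p) in frozen:
            continue  # frozen: keep the pretrained value
        p.value -= lr * grads[path]


grads = {path: 0.5 for path in toy_params}
sgd_step(toy_params, grads, lr=0.5, frozen=frozen_ids)
print({path: p.value for path, p in toy_params.items()})
# {('features', 'weight'): 1.0, ('classifier', 'weight'): 0.75}
```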

## 8. Practical Example: Neural Architecture Search (NAS)

In [None]:
# Simple NAS with different layer types
class SearchableBlock(brainstate.nn.Module):
    """Block that can choose between different operations."""
    
    def __init__(self, dim):
        super().__init__()
        # Multiple operation choices
        self.ops = {
            'linear': brainstate.nn.Linear(dim, dim),
            'skip': lambda x: x,
            'zero': lambda x: jnp.zeros_like(x),
        }
        
        # Architecture parameters (which op to use)
        self.arch_params = brainstate.ParamState(jnp.ones(len(self.ops)) / len(self.ops))
        
        # Register linear op for state tracking
        self.linear = self.ops['linear']
    
    def __call__(self, x, mode='soft'):
        """Execute block.
        
        mode='soft': weighted combination of all ops
        mode='hard': use best op only
        """
        if mode == 'soft':
            # Soft selection: weighted combination
            weights = jax.nn.softmax(self.arch_params.value)
            output = jnp.zeros_like(x)
            
            for i, (name, op) in enumerate(self.ops.items()):
                output += weights[i] * op(x)
            
            return output
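        else:
            # Hard selection: use the best op only. int() turns the argmax into a
            # plain Python index; this branch runs eagerly and is not jit-compatible.
            best_idx = int(jnp.argmax(self.arch_params.value))
            ops_list = list(self.ops.values())
            return ops_list[best_idx](x)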

class SearchableNetwork(brainstate.nn.Module):
    """Network with searchable architecture."""
    
    def __init__(self, dim, n_blocks):
        super().__init__()
        self.blocks = [SearchableBlock(dim) for _ in range(n_blocks)]
        for i, block in enumerate(self.blocks):
            setattr(self, f'block_{i}', block)
    
    def __call__(self, x, mode='soft'):
        for block in self.blocks:
            x = block(x, mode=mode)
        return x
    
    def get_architecture(self):
        """Get discovered architecture."""
        arch = []
        op_names = list(self.blocks[0].ops.keys())
        
        for i, block in enumerate(self.blocks):
            weights = jax.nn.softmax(block.arch_params.value)
            best_op = op_names[jnp.argmax(weights)]
            arch.append((i, best_op, float(jnp.max(weights))))
        
        return arch

# Create searchable network
nas_net = SearchableNetwork(dim=8, n_blocks=3)
x = brainstate.random.randn(4, 8)

# Simulate architecture search (random for demo)
for block in nas_net.blocks:
    # Randomly prefer different operations
    block.arch_params.value = brainstate.random.randn(len(block.ops))

# Get discovered architecture
architecture = nas_net.get_architecture()

print("Discovered Architecture:")
print("=" * 60)
for block_id, op_name, confidence in architecture:
    print(f"  Block {block_id}: {op_name:<10} (confidence: {confidence:.2%})")

# Test both modes
out_soft = nas_net(x, mode='soft')
out_hard = nas_net(x, mode='hard')

print(f"\nOutput shape (soft): {out_soft.shape}")
print(f"Output shape (hard): {out_hard.shape}")
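The "soft" mode is a continuous relaxation in the style of DARTS: the output is sum_k softmax(alpha)_k · op_k(x), and "hard" mode keeps only the argmax op. Here is a scalar, pure-Python sketch. The three ops mirror `linear` (replaced by a hypothetical fixed scaling), `skip` and `zero`. Because softmax ignores a constant shift, an all-zero `alpha` gives the same uniform mixture as the `ones / len(ops)` initialization above.

```python
import math
from typing import Callable, Dict, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


ops: Dict[str, Callable[[float], float]] = {
    "linear": lambda x: 2.0 * x,  # hypothetical stand-in for a Linear layer
    "skip": lambda x: x,
    "zero": lambda x: 0.0,
}


def soft_output(x: float, alpha: List[float]) -> float:
    """Softmax-weighted combination of all ops."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, ops.values()))


def hard_output(x: float, alpha: List[float]) -> float:
    """Apply only the op with the largest architecture weight."""
    best = max(range(len(alpha)), key=lambda i: alpha[i])  # argmax
    return list(ops.values())[best](x)


print(soft_output(3.0, [0.0, 0.0, 0.0]))   # about (6 + 3 + 0) / 3 = 3.0
print(hard_output(3.0, [0.1, 2.0, -1.0]))  # 'skip' wins -> 3.0
```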

## Summary

In this tutorial, we covered:

1. **graph.Node Base Class**: Foundation of BrainState's module system
2. **Hierarchical Graphs**: Building complex nested structures
3. **State Management**: Collecting and manipulating states across graphs
4. **Advanced Patterns**: ResNet, Inception, adaptive depth networks
5. **Graph Analysis**: Introspection and statistics
6. **Custom Operations**: Initialization, freezing, graph manipulation
7. **Practical Applications**: Neural architecture search

## Key Takeaways

- **graph.Node** provides hierarchical structure and state management
- States are automatically tracked across the graph
- Graphs can be **traversed, analyzed, and manipulated**
- **Skip connections** and **multi-path networks** are easy to implement
- Graph operations enable **transfer learning** and **NAS**
- The graph system is **composable** and **flexible**

## Best Practices

1. Always call `super().__init__()` in custom nodes
2. Register child nodes as attributes for proper graph construction
3. Use descriptive names for better debugging
4. Leverage state types (ParamState, ShortTermState) appropriately
5. Design for composability - small, reusable components
6. Document your graph structure for clarity

## Next Steps

In the next tutorial, we'll explore:
- **Mixin System**: Mode, JointMode, Batching, Training
- Computation modes for different behaviors
- Custom mixins for specialized functionality