# IR Analysis and Optimization

The core capability of the brainscale framework is converting a user-defined neural network model into an intermediate representation (IR) and then analyzing and optimizing the model on that representation. brainscale extracts the information flow and dependencies between the model's key components: hidden states, trainable parameters, and the operations that connect them. From this it generates efficient training code for online learning, enabling fast training and inference.

This guide walks through the IR analysis and optimization process in brainscale:
- **Model Information Extraction**: Obtaining the complete structural information of a model
- **State Group Analysis**: Identifying sets of interdependent state variables
- **Parameter-State Relationships**: Analyzing how trainable parameters connect to hidden states
- **State Perturbation Mechanism**: Implementing efficient gradient computation

## Environment Setup and Model Preparation

### Dependency Imports

```python
import brainscale
import brainstate
import brainunit as u

# Set up simulation environment
brainstate.environ.set(dt=0.1 * u.ms)
```

### Example Model

As a demonstration case, we use a test model that ships with brainscale. It combines an ALIF (Adaptive Leaky Integrate-and-Fire) neuron layer with an STP (Short-Term Plasticity) synapse model and is defined in `brainscale._etrace_model_test`:

```python
from brainscale._etrace_model_test import ALIF_STPExpCu_Dense_Layer

# Create network instance
n_in = 3    # Input dimension
n_rec = 4   # Recurrent layer dimension

net = ALIF_STPExpCu_Dense_Layer(n_in, n_rec)
brainstate.nn.init_all_states(net)

# Prepare input data
input_data = brainstate.random.rand(n_in)
```

## 1. Model Information Extraction: `ModuleInfo`

### 1.1 What is ModuleInfo

`ModuleInfo` is brainscale's complete description of a neural network model. It contains:
- **Input/Output Interface**: Where data enters and leaves the model
- **State Variables**: Dynamic variables such as neuron and synaptic states
- **Parameter Variables**: Trainable parameters such as weights and biases
- **Computational Graph Structure**: The model's computation, expressed as a JAX expression (jaxpr)

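For readers unfamiliar with JAX expressions, the following sketch shows the kind of object that a closed jaxpr is. It traces a hypothetical one-state update, `leaky_step`, which is not part of brainscale, using `jax.make_jaxpr`. It then prints the resulting `ClosedJaxpr`'s input variables, output variables, and equations. The `Var(id=...)` entries printed elsewhere in this guide appear to be JAX variables of this same kind.

```python
import jax
import jax.numpy as jnp


def leaky_step(v, x, w):
    """Hypothetical one-state update: leaky integration of a weighted input."""
    return 0.9 * v + jnp.dot(x, w)


v = jnp.zeros(4)       # hidden state, 4 neurons
x = jnp.ones(3)        # external input, 3 channels
w = jnp.ones((3, 4))   # weight matrix

closed = jax.make_jaxpr(leaky_step)(v, x, w)  # a ClosedJaxpr
jaxpr = closed.jaxpr

# Input variables: one per array argument (v, x, w), each with an abstract shape/dtype
for var in jaxpr.invars:
    print("invar ", var, var.aval)

# Output variables: the updated hidden state
for var in jaxpr.outvars:
    print("outvar", var, var.aval)

# Equations: the primitive operations connecting inputs to outputs
for eqn in jaxpr.eqns:
    print(eqn.primitive.name)
```

Each argument becomes one input variable, so an analysis can label variables by role (state, input, or parameter). `ModuleInfo` stores labels of this kind in its path-to-variable mappings.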
### 1.2 Extracting Model Information

```python
# Extract complete model information
info = brainscale.extract_module_info(net, input_data)
print("Model information extraction completed")
print(f"Number of hidden states: {len(info.hidden_path_to_invar)}")
print(f"Number of compiled model states: {len(info.compiled_model_states)}")
```

```text
Model information extraction completed
Number of hidden states: 5
Number of compiled model states: 6
```


### 1.3 Core Components of ModuleInfo Explained

#### Hidden State Mapping Relationships
- **`hidden_path_to_invar`**: Maps each hidden state path to the jaxpr input variable holding that state's value before the update
- **`hidden_path_to_outvar`**: Maps each hidden state path to the jaxpr output variable holding the updated value
- **`invar_to_hidden_path`**: Reverse of `hidden_path_to_invar`
- **`outvar_to_hidden_path`**: Reverse of `hidden_path_to_outvar`
- **`hidden_outvar_to_invar`**: Maps each hidden state's output variable to its input variable, i.e. the state transfer from one time step to the next

#### Training Parameter Management
- **`weight_invars`**: The jaxpr input variables of all trainable parameters
- **`weight_path_to_invars`**: Maps each parameter path to its input variables; one path can own several variables (in the example below, `weight_op` owns two)
- **`invar_to_weight_path`**: Reverse mapping from variables to parameter paths

#### Computational Graph Representation
- **`closed_jaxpr`**: Closed JAX expression of the model, describing the complete computational flow

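To make the mapping dictionaries concrete, here is a hypothetical sketch that builds analogous dictionaries by hand for a toy two-state update. The states `V` and `a` loosely mimic an adaptive neuron. The state paths and the function are invented for illustration; brainscale constructs its own mappings automatically, and its details may differ.

```python
import jax
import jax.numpy as jnp


def two_state_step(V, a, I):
    """Hypothetical adaptive-neuron update with element-wise coupling between V and a."""
    V_new = 0.9 * V + I - a
    a_new = 0.95 * a + 0.05 * V
    return V_new, a_new


n = 4
closed = jax.make_jaxpr(two_state_step)(jnp.zeros(n), jnp.zeros(n), jnp.ones(n))
jaxpr = closed.jaxpr

# Hypothetical state paths, in the same order as the arguments and return values
paths = [("neu", "V"), ("neu", "a")]

hidden_path_to_invar = dict(zip(paths, jaxpr.invars[:2]))  # path -> value before update
hidden_path_to_outvar = dict(zip(paths, jaxpr.outvars))    # path -> value after update
invar_to_hidden_path = {v: p for p, v in hidden_path_to_invar.items()}
outvar_to_hidden_path = {v: p for p, v in hidden_path_to_outvar.items()}

# Link each updated value back to the variable holding the same state's previous value
hidden_outvar_to_invar = {
    hidden_path_to_outvar[p]: hidden_path_to_invar[p] for p in paths
}

for out_var, in_var in hidden_outvar_to_invar.items():
    print(outvar_to_hidden_path[out_var], ":", out_var, "->", in_var)
```

JAX variables hash by identity, so they can serve directly as dictionary keys in both directions.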
### 1.4 Practical Application Examples

```python
# View specific paths of hidden states
print("=== Hidden State Paths ===")
for path, var in info.hidden_path_to_invar.items():
    print(f"Path: {path}")
    print(f"Variable: {var}")
    print("---")

# View training parameter information
print("=== Training Parameter Information ===")
for path, invars in info.weight_path_to_invars.items():
    print(f"Parameter path: {path}")
    print(f"Number of variables: {len(invars)}")
    for i, var in enumerate(invars):
        print(f"  Variable{i}: {var}")
```

```text
=== Hidden State Paths ===
Path: ('neu', 'V')
Variable: Var(id=2593518192448):float32[4]
---
Path: ('neu', 'a')
Variable: Var(id=2593518192512):float32[4]
---
Path: ('stp', 'u')
Variable: Var(id=2593518192576):float32[4]
---
Path: ('stp', 'x')
Variable: Var(id=2593518192640):float32[4]
---
Path: ('neu', '_before_updates', "(<class 'brainscale.nn.Expon'>, (4,), {'tau': 10. * msecond}) // (<class 'brainstate.nn.CUBA'>, (), {})", 'syn', 'g')
Variable: Var(id=2593518192896):float32[4]
---
=== Training Parameter Information ===
Parameter path: ('syn', 'comm', 'weight_op')
Number of variables: 2
  Variable0: Var(id=2593518192768):float32[4]
  Variable1: Var(id=2593518192832):float32[7,4]
```


## 2. State Group Analysis: `HiddenGroup`

### 2.1 Concept of State Groups

A state group (`HiddenGroup`) is a set of hidden states in the model with the following characteristics:
- **Interdependence**: The states depend on one another, directly or indirectly
- **Element-wise Operations**: The operations connecting them are primarily element-wise
- **Joint Updates**: The states are updated together within each time step

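The grouping idea can be sketched as a connected-components search. Treat each hidden state as a node and add an edge whenever one state's update reads another through element-wise operations. Each connected component is then a group. The state names and edges below are hypothetical and hand-written. brainscale derives the real dependencies from the jaxpr, and its algorithm may differ.

```python
def find_groups(states, edges):
    """Group states into connected components using union-find."""
    parent = {s: s for s in states}

    def find(s):
        # Path halving keeps the union-find trees shallow
        while parent[s] != s:
            parent[s] = parent[parent[s]]
            s = parent[s]
        return s

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    for a, b in edges:
        union(a, b)

    groups = {}
    for s in states:
        groups.setdefault(find(s), []).append(s)
    return list(groups.values())


# Hypothetical states and element-wise dependencies (not extracted from a real model)
states = ["V", "a", "g", "u", "x", "r"]
edges = [("V", "a"), ("V", "g"), ("u", "x"), ("x", "g")]

groups = find_groups(states, edges)
print(groups)  # [['V', 'a', 'g', 'u', 'x'], ['r']]
```

In this sketch, `r` has no element-wise links, so it ends up alone. Every other state is reachable from every other through some chain of edges, so they form one group.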
### 2.2 State Group Extraction

```python
# Extract state groups from ModuleInfo
hidden_groups, hid_path_to_group = brainscale.find_hidden_groups_from_minfo(info)

print(f"Number of state groups discovered: {len(hidden_groups)}")
print(f"Number of state path to group mappings: {len(hid_path_to_group)}")
```

```text
Number of state groups discovered: 1
Number of state path to group mappings: 5
```


### 2.3 State Group Structure Analysis

```python
# Detailed analysis of the first state group
if hidden_groups:
    group = hidden_groups[0]
    print("=== Detailed State Group Information ===")
    print(f"Group index: {group.index}")
    print(f"Number of state paths included: {len(group.hidden_paths)}")
    print(f"Number of hidden states: {len(group.hidden_states)}")
    print(f"Number of input variables: {len(group.hidden_invars)}")
    print(f"Number of output variables: {len(group.hidden_outvars)}")

    print("\n--- State Path List ---")
    for i, path in enumerate(group.hidden_paths):
        print(f"{i+1}. {path}")
```

```text
=== Detailed State Group Information ===
Group index: 0
Number of state paths included: 5
Number of hidden states: 5
Number of input variables: 5
Number of output variables: 5

--- State Path List ---
1. ('stp', 'u')
2. ('stp', 'x')
3. ('neu', 'V')
4. ('neu', '_before_updates', "(<class 'brainscale.nn.Expon'>, (4,), {'tau': 10. * msecond}) // (<class 'brainstate.nn.CUBA'>, (), {})", 'syn', 'g')
5. ('neu', 'a')
```


### 2.4 Practical Significance of State Groups

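Identifying state groups matters for optimization:
- **Parallel Computing**: Because states within a group interact through element-wise operations, each neuron's slice of the group can be updated independently and in parallel
- **Memory Optimization**: Related states can be stored contiguously in memory
- **Gradient Computation**: Element-wise interaction gives the Jacobian between a group's states at consecutive time steps a per-neuron structure with no cross-neuron terms, which simplifies gradient propagation through the group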

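The per-neuron structure of element-wise coupling can be checked numerically. The Jacobian of a group's state transition only links each neuron's states to that same neuron's states. Every state-to-state block is therefore diagonal, and the whole Jacobian compresses to one small matrix per neuron. The sketch below verifies this with `jax.jacobian` on a hypothetical two-state update; it is an illustration, not brainscale code.

```python
import jax
import jax.numpy as jnp

n = 4  # number of neurons


def group_step(h):
    """Hypothetical element-wise update of two states, V and a, stacked as h = [V, a]."""
    V, a = h[:n], h[n:]
    V_new = 0.9 * V - a + 0.1
    a_new = 0.95 * a + 0.05 * V
    return jnp.concatenate([V_new, a_new])


h0 = jnp.linspace(-1.0, 1.0, 2 * n)
J = jax.jacobian(group_step)(h0)  # shape (2n, 2n)

# Split into the four state-to-state blocks: dV/dV, dV/da, da/dV, da/da
blocks = [[J[:n, :n], J[:n, n:]], [J[n:, :n], J[n:, n:]]]

# Element-wise coupling => every block is diagonal (no cross-neuron terms)
for row in blocks:
    for blk in row:
        off_diag = blk - jnp.diag(jnp.diag(blk))
        assert jnp.allclose(off_diag, 0.0)

# Compress to one 2x2 matrix per neuron: per_neuron[k, i, j] = blocks[i][j][k, k]
per_neuron = jnp.stack(
    [jnp.stack([jnp.diag(blocks[i][j]) for j in range(2)], axis=-1) for i in range(2)],
    axis=-2,
)  # shape (n, 2, 2)
print(per_neuron[0])  # approximately [[0.9, -1.0], [0.05, 0.95]]
```

With `n` neurons and `k` states per group, the compressed form stores `n * k**2` numbers instead of `(n * k)**2`.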
## 3. Parameter-State Relationship Analysis: `HiddenParamOpRelation`

### 3.1 Definition of Relationship Groups

`HiddenParamOpRelation` describes how a trainable parameter operates on hidden states, which is key to understanding how the model learns. It records:
- **Parameter Operations**: The operation through which the weights act, with its input `x` and output `y`
- **Connection Patterns**: Which hidden states the operation's output reaches
- **Computational Dependencies**: Which hidden groups are affected, and therefore how parameter updates affect state transitions

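The kind of information such a relation records can be sketched by scanning a jaxpr for the equation that consumes a weight. Its other operand plays the role of `x`, and its result plays the role of `y`. The toy step below is an assumption: it concatenates a 3-dimensional input with 4 recurrent values so the shapes match this guide's output (`x`: float32[7], `y`: float32[4]). Here `y` is taken before the bias is added. The scanning logic is a hypothetical illustration, not brainscale's implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax

n_in, n_rec = 3, 4


def dense_rec_step(W, b, inp, spk, V):
    """Hypothetical recurrent step: a dense op on [input, spikes], then a leaky update."""
    x = jnp.concatenate([inp, spk])                          # shape (n_in + n_rec,)
    y = lax.dot_general(x, W, (((0,), (0,)), ((), ()))) + b  # shape (n_rec,)
    return 0.9 * V + y


args = (jnp.ones((n_in + n_rec, n_rec)), jnp.zeros(n_rec),
        jnp.ones(n_in), jnp.zeros(n_rec), jnp.zeros(n_rec))
jaxpr = jax.make_jaxpr(dense_rec_step)(*args).jaxpr
w_var = jaxpr.invars[0]  # the weight matrix W

# Find the dot_general equation that reads W; record its other operand (x) and result (y)
relations = []
for eqn in jaxpr.eqns:
    if eqn.primitive.name == "dot_general" and any(v is w_var for v in eqn.invars):
        x_var = next(v for v in eqn.invars if v is not w_var)
        y_var = eqn.outvars[0]
        relations.append((x_var, y_var))

for x_var, y_var in relations:
    print("x:", x_var.aval, " y:", y_var.aval)
```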
### 3.2 Relationship Extraction

```python
# Extract parameter-state relationships
hidden_param_op = brainscale.find_hidden_param_op_relations_from_minfo(info, hid_path_to_group)

print(f"Number of parameter-state relationships discovered: {len(hidden_param_op)}")
```

```text
Number of parameter-state relationships discovered: 1
```


### 3.3 Detailed Relationship Structure

```python
if hidden_param_op:
    relation = hidden_param_op[0]
    print("=== Parameter-State Relationship Details ===")
    print(f"Parameter path: {relation.path}")
    print(f"Input variable x: {relation.x}")
    print(f"Output variable y: {relation.y}")
    print(f"Number of affected hidden groups: {len(relation.hidden_groups)}")
    print(f"Number of connected hidden paths: {len(relation.connected_hidden_paths)}")

    print("\n--- Connected Hidden Paths ---")
    for path in relation.connected_hidden_paths:
        print(f"  {path}")
```

```text
=== Parameter-State Relationship Details ===
Parameter path: ('syn', 'comm', 'weight_op')
Input variable x: Var(id=2593518203520):float32[7]
Output variable y: Var(id=2593518203712):float32[4]
Number of affected hidden groups: 1
Number of connected hidden paths: 2

--- Connected Hidden Paths ---
  ('neu', '_before_updates', "(<class 'brainscale.nn.Expon'>, (4,), {'tau': 10. * msecond}) // (<class 'brainstate.nn.CUBA'>, (), {})", 'syn', 'g')
  ('neu', 'V')
```


### 3.4 Gradient Computation Optimization

Analyzing parameter-state relationships enables brainscale to:
- **Track Precisely**: Determine which hidden states each parameter affects (in this example, only two of the five hidden states, the synaptic variable `g` and the membrane potential `V`)
- **Compute Efficiently**: Compute only the gradient paths that are needed
- **Save Memory**: Avoid storing unnecessary intermediate gradients

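A short calculation shows why recording `x` and `y` is enough for the weight gradient of a dense operation $y = xW + b$. Since $\partial y_k / \partial W_{ik} = x_i$, the gradient is the outer product $\partial L / \partial W = x \, (\partial L / \partial y)^\top$, and $\partial L / \partial b = \partial L / \partial y$. The sketch below checks this with `jax.grad`, using the relation's shapes reported above (`x`: 7, `y`: 4). Reading those 7 input dimensions as 3 external inputs plus 4 recurrent values is only an inference from the shapes (3 + 4 = 7). The upstream gradient `g_y` is random and hypothetical.

```python
import jax
import jax.numpy as jnp

k_x, k_w, k_g = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.uniform(k_x, (7,))      # input to the weight op
W = jax.random.normal(k_w, (7, 4))     # weight matrix
b = jnp.zeros(4)                       # bias
g_y = jax.random.normal(k_g, (4,))     # hypothetical upstream gradient dL/dy


def surrogate_loss(W, b):
    """Linear surrogate whose gradient with respect to y is exactly g_y."""
    y = x @ W + b
    return jnp.dot(y, g_y)


dW, db = jax.grad(surrogate_loss, argnums=(0, 1))(W, b)

# The weight gradient factorizes into the recorded input x and the gradient at y
print(jnp.allclose(dW, jnp.outer(x, g_y), atol=1e-6))  # True
print(jnp.allclose(db, g_y))                           # True
```

In this sketch, the full 7 × 4 weight gradient never requires a weight-sized Jacobian; the vectors `x` and `g_y` suffice.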
## 4. State Perturbation Mechanism: `HiddenStatePerturbation`

### 4.1 Perturbation Mechanism Principles

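State perturbation is brainscale's core technique for efficient gradient computation. A perturbation variable $\Delta$ is added to each hidden state where it enters the computation, $y = f(x, h + \Delta)$, and evaluated at $\Delta = 0$, so the forward values are unchanged. Because $\partial L / \partial \Delta = \partial L / \partial (h + \Delta)$, differentiating with respect to $\Delta$ yields the gradient of the loss with respect to the hidden state. This lets the system:
- **Expose Hidden-State Gradients**: Obtain $\partial L / \partial h$ for each perturbed hidden state as the gradient with respect to its perturbation variable
- **Automatic Differentiation**: Build efficient backpropagation graphs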

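The identity behind the perturbation can be checked directly. With $u = h + \Delta$, the chain rule gives $\partial L / \partial \Delta = \partial L / \partial u$, and at $\Delta = 0$ the forward values are unchanged. The sketch below uses a hypothetical update `f` and loss. It compares the gradient with respect to a zero-valued perturbation against the gradient with respect to the hidden state itself. It illustrates the mathematical idea only and does not call brainscale.

```python
import jax
import jax.numpy as jnp


def f(x, h):
    """Hypothetical hidden-state update."""
    return jnp.tanh(0.9 * h + x)


def loss_with_perturbation(delta, x, h):
    # The perturbation enters where the hidden state is used: y = f(x, h + delta)
    y = f(x, h + delta)
    return jnp.sum(y ** 2)


def loss_wrt_state(h, x):
    return jnp.sum(f(x, h) ** 2)


x = jnp.array([0.1, -0.2, 0.3, 0.5])
h = jnp.array([0.5, 0.0, -0.4, 0.2])
delta = jnp.zeros_like(h)  # zero-valued perturbation

g_delta = jax.grad(loss_with_perturbation)(delta, x, h)  # dL/d(delta) at delta = 0
g_h = jax.grad(loss_wrt_state)(h, x)                     # dL/dh

print(jnp.allclose(g_delta, g_h))  # True
# Forward pass is unchanged by a zero perturbation
print(jnp.allclose(loss_with_perturbation(delta, x, h), loss_wrt_state(h, x)))  # True
```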
### 4.2 Perturbation Information Extraction

```python
# Extract state perturbation information
hidden_perturb = brainscale.add_hidden_perturbation_from_minfo(info)

print("=== State Perturbation Information ===")
print(f"Number of perturbation variables: {len(hidden_perturb.perturb_vars)}")
print(f"Number of perturbation paths: {len(hidden_perturb.perturb_hidden_paths)}")
print(f"Number of perturbation states: {len(hidden_perturb.perturb_hidden_states)}")
```

```text
=== State Perturbation Information ===
Number of perturbation variables: 5
Number of perturbation paths: 5
Number of perturbation states: 5
```


### 4.3 Perturbation Path Analysis

```python
print("\n--- Perturbation Path Details ---")
for i, (path, var, state) in enumerate(
    zip(
        hidden_perturb.perturb_hidden_paths,
        hidden_perturb.perturb_vars,
        hidden_perturb.perturb_hidden_states
    )
):
    print(f"{i+1}. Path: {path}")
    print(f"   Variable: {var}")
    print(f"   State type: {type(state).__name__}")
    print("---")
```

```text
--- Perturbation Path Details ---
1. Path: ('stp', 'u')
   Variable: Var(id=2593518099072):float32[4]
   State type: ETraceState
---
2. Path: ('stp', 'x')
   Variable: Var(id=2593518150208):float32[4]
   State type: ETraceState
---
3. Path: ('neu', '_before_updates', "(<class 'brainscale.nn.Expon'>, (4,), {'tau': 10. * msecond}) // (<class 'brainstate.nn.CUBA'>, (), {})", 'syn', 'g')
   Variable: Var(id=2593518143808):float32[4]
   State type: ETraceState
---
4. Path: ('neu', 'V')
   Variable: Var(id=2593518141952):float32[4]
   State type: ETraceState
---
5. Path: ('neu', 'a')
   Variable: Var(id=2593518143360):float32[4]
   State type: ETraceState
---
```


### 4.4 Role of Perturbation in Training

- **Gradient Precision**: Provides high-precision gradient estimation
- **Computational Efficiency**: Reduces unnecessary computational overhead
- **Numerical Stability**: Maintains numerical stability of the training process

## Summary

This guide covered the core concepts and practical methods of IR analysis and optimization in brainscale. It introduced model information extraction, state group analysis, parameter-state relationship analysis, and state perturbation. With an understanding of these steps, developers can:

- **Improve Performance**: Raise training and inference efficiency through IR optimization
- **Deepen Understanding**: Inspect the internal structure and computational flow of their models
- **Optimize Design**: Design more efficient model architectures based on analysis results
- **Debug**: Locate problems in model training more quickly

brainscale's IR analysis provides the foundation for efficient neural network implementations and high-performance online learning. The functions shown in this guide can also be called individually to inspect each stage of the analysis and to adapt it to specific application requirements. These are `extract_module_info`, `find_hidden_groups_from_minfo`, `find_hidden_param_op_relations_from_minfo`, and `add_hidden_perturbation_from_minfo`.