---
# 02_modelgraph_basics.ipynb
---

## Building a ModelGraph

One of the most powerful aspects of ModularML is its `ModelGraph` abstraction, which represents a directed acyclic graph (DAG) of computation. This structure allows multiple `ModelStage` instances to be flexibly connected into a larger model pipeline.

Each `ModelStage` can use any supported backend, such as PyTorch, TensorFlow/Keras, or Scikit-learn. This enables the creation of complex multi-objective modeling workflows using a unified interface.

In this example, we demonstrate a two-stage modeling pipeline: a CNN encoder processes input voltage features into a latent embedding, followed by an MLP regressor that estimates the battery state-of-health (SOH) from this embedding.

ModularML provides pre-built classes for commonly used model types such as sequential CNNs and MLPs. While we use those here, any custom model can be integrated by subclassing `modularml.BaseModel` and implementing the required methods.

Let's import the necessary components:

In [None]:
import modularml as mml
from modularml.core import FeatureSet, ModelGraph, ModelStage, Optimizer
from modularml.models.torch import SequentialCNN, SequentialMLP

We will use the FeatureSet created in the [01_featureset_basics.ipynb](./01_featureset_basics.ipynb) notebook.

Let's reload that FeatureSet and underlying FeatureTransforms from the `.joblib` file:

In [None]:
from pathlib import Path

FILE_FEATURE_SET = Path("downloaded_data/charge_samples.joblib")
charge_samples = FeatureSet.load(FILE_FEATURE_SET)
charge_samples

Now we can start creating our `ModelStages`.

The `modularml.models` module provides convenient, pre-built implementations such as `SequentialCNN` and `SequentialMLP`, which allow for rapid prototyping of convolutional and dense architectures with configurable layer depth and hidden sizes. Please refer to the module documentation for a full list of available initialization parameters.

A key feature of the `ModelStage` abstraction is its support for **lazy shape inference**. Input and output shapes do not need to be explicitly specified during model construction. Instead, ModularML dynamically infers the required shapes at runtime based on how FeatureSets and other ModelStages are connected in the ModelGraph.

While input shape inference is automatic, it is generally advisable to specify the desired output shape for clarity and to avoid unintended behavior.
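To make the idea of lazy shape inference concrete, here is a minimal standalone sketch in plain Python. It is not ModularML code: `ToyStage`, its fields, and the `(1, 101)` source shape are hypothetical and chosen only for illustration. Each stage leaves `input_shape` unset at construction and resolves it from its upstream stage when `build()` is called.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape = Tuple[int, ...]


@dataclass
class ToyStage:
    """Hypothetical stand-in for a graph stage; not a ModularML class."""

    label: str
    upstream: Optional["ToyStage"] = None
    output_shape: Optional[Shape] = None  # may be declared up front
    input_shape: Optional[Shape] = None  # resolved lazily by build()

    def build(self) -> None:
        """Resolve shapes once the connections between stages are known."""
        if self.upstream is not None:
            # Build the upstream stage first, then adopt its output shape.
            self.upstream.build()
            self.input_shape = self.upstream.output_shape
        if self.output_shape is None:
            if self.input_shape is None:
                raise ValueError(f"Cannot infer any shape for '{self.label}'.")
            # Assumption of this sketch: an undeclared output mirrors the input.
            self.output_shape = self.input_shape


# The (1, 101) feature shape is made up for this example.
source = ToyStage(label="ChargePulses", output_shape=(1, 101))
encoder = ToyStage(label="Encoder", upstream=source, output_shape=(1, 32))
regressor = ToyStage(label="Regressor", upstream=encoder, output_shape=(1, 1))

regressor.build()
print(encoder.input_shape, regressor.input_shape)  # (1, 101) (1, 32)
```

Declaring `output_shape` explicitly, as recommended above, avoids relying on the fallback branch, whose behavior in a real model would depend on the architecture.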

A `ModelStage` is constructed from the following arguments (all required except `optimizer`):

* `model`: The machine learning model to be wrapped, which must inherit from `BaseModel`.

* `label`: A unique string identifier for the stage.

* `upstream_node`: The node that feeds into this stage. This can be the node's label (str) or the node object itself.

* `optimizer`: An optional `Optimizer` object used for training. It is required only if the model parameters are to be updated during optimization. Note that an `Optimizer` can be defined at the stage level (useful when stages use different backends) or at the graph level (in which case all stages must share the same backend).

In [None]:
ms_encoder = ModelStage(
    model=SequentialCNN(output_shape=(1, 32), n_layers=2, hidden_dim=16, flatten_output=True),
    label="Encoder",
    upstream_node="ChargePulses",  # Note that we could also pass the charge_samples object itself
)

In [None]:
ms_regressor = ModelStage(
    model=SequentialMLP(output_shape=(1, 1), n_layers=2, hidden_dim=16),
    label="Regressor",
    upstream_node=ms_encoder,  # Here, we pass the encoder object itself, but we could also use the string 'Encoder'
)

With both stages defined, we construct the `ModelGraph`. 

`ModelGraph` requires only one argument:

* `nodes`: A list of `ModelStage` or `FeatureSet` instances to incorporate into this ModelGraph. The order of the nodes does not matter, as long as all required inputs are included.

The ModelGraph will handle all data routing, shape inference, and connection validation with the `.build_all()` method.
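Node order can be irrelevant because a DAG can always be sorted so that every node appears after its inputs. The sketch below illustrates this with Python's standard-library `graphlib`; it is not how ModularML is implemented, and the `upstream` dictionary is a hypothetical stand-in for the connections a graph object would track.

```python
from graphlib import TopologicalSorter

# Map each node label to the labels of its upstream nodes.
# The entries are deliberately listed in "wrong" order.
upstream = {
    "Regressor": {"Encoder"},
    "Encoder": {"ChargePulses"},
    "ChargePulses": set(),
}

# Connection validation: every referenced input must itself be a node.
referenced = {dep for deps in upstream.values() for dep in deps}
missing = referenced - set(upstream)
if missing:
    raise ValueError(f"Missing upstream nodes: {sorted(missing)}")

# static_order() yields each node only after all of its upstream nodes,
# and raises graphlib.CycleError if the graph is not acyclic.
build_order = list(TopologicalSorter(upstream).static_order())
print(build_order)  # ['ChargePulses', 'Encoder', 'Regressor']
```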

In [None]:
mg = ModelGraph(
    nodes=[charge_samples, ms_encoder, ms_regressor],
    optimizer=Optimizer(name="adam", backend=mml.Backend.TORCH),
)
mg.build_all()

We see that the missing input shapes have been correctly inferred: the encoder's matches `charge_samples.feature_shape`, and the regressor's matches the encoder's output shape.


ModelGraph has another useful validation function called `dummy_forward`.
This performs a full forward pass of all connected stages with dummy batch data.

In [None]:
all_stage_results = mg.dummy_forward(batch_size=32)
all_stage_results["Regressor"].feature_shape

Great. We have a fully functional ModelGraph that correctly outputs a target with shape (1, 1).

Although this ModelGraph is very simple, as the number of nodes increases, it can become difficult to keep track of how all stages are connected.
We can visualize these node connections with the `visualize` method.

In [None]:
mg.visualize()

---
### Using `ModelGraph.insert`

Instead of rebuilding a whole `ModelGraph` with 
```python
    ModelGraph(nodes=[...], ...)
```
you can modify an existing graph with the `insert` method.

The method signature is:
```python
    ModelGraph.insert(node, before=None, after=None, inplace=True)
```
Arguments:
* `node`: the new graph node (FeatureSet, ModelStage, etc.) to insert.
* `before`: name (or list of names) of downstream node(s) that the new node should feed into.
* `after`: name (or list of names) of upstream node(s) whose outputs should feed into the new node.
* `inplace`: if True, modifies the current graph directly; if False, returns a new graph with the insertion.

How it rewires the graph:
* If both after and before are provided:
  - Inserts the new node between them, replacing the existing connection.
* If only before is given:
  - Redirects all inputs to the specified before node so they first pass through the new node.
* If only after is given:
  - Redirects all outputs from the specified after node so they flow through the new node.

This is useful for quickly adding new head nodes, intermediate model layers, or feature sets without re-instantiating the entire ModelGraph.
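The three rewiring rules can be sketched on a bare set of edges. This is a toy model in plain Python, not the library's implementation: `insert_node`, the `Edge` alias, and the example edge set are all hypothetical, and real nodes carry models and shapes rather than just labels.

```python
from typing import Optional, Set, Tuple

Edge = Tuple[str, str]  # (upstream label, downstream label)


def insert_node(
    edges: Set[Edge],
    node: str,
    before: Optional[str] = None,
    after: Optional[str] = None,
) -> Set[Edge]:
    """Return a new edge set with `node` spliced in (toy model only)."""
    if before is None and after is None:
        raise ValueError("Provide at least one of 'before' or 'after'.")

    if before is not None and after is not None:
        # Replace only the single connection after -> before.
        if (after, before) not in edges:
            raise ValueError(f"No connection from '{after}' to '{before}'.")
        moved = {(after, before)}
        added = {(after, node), (node, before)}
    elif before is not None:
        # Every input of `before` is redirected into the new node.
        moved = {(src, dst) for src, dst in edges if dst == before}
        added = {(src, node) for src, _ in moved} | {(node, before)}
    else:
        # Every output of `after` now originates from the new node.
        moved = {(src, dst) for src, dst in edges if src == after}
        added = {(node, dst) for _, dst in moved} | {(after, node)}

    return (edges - moved) | added


edges = {
    ("ChargePulses", "Encoder A"),
    ("ChargePulses", "Encoder B"),
    ("Encoder A", "Merge"),
    ("Encoder B", "Merge"),
    ("Merge", "Regressor"),
}

both = insert_node(edges, "New", before="Merge", after="Encoder A")
only_before = insert_node(edges, "New", before="Merge")
only_after = insert_node(edges, "New", after="ChargePulses")
```

In `both`, the connection from "Encoder B" to "Merge" is untouched, whereas in `only_before` both encoders feed "New", which is then the sole input of "Merge". Returning a new set rather than mutating the argument loosely mirrors `inplace=False`.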

Let's start with a slightly more complex ModelGraph to better visualize how connections are modified:

In [None]:
from modularml.core import ConcatStage

nodes = [
    charge_samples,
    ModelStage(label="Encoder A", model=SequentialMLP(output_shape=(1, 32)), upstream_node=charge_samples),
    ModelStage(label="Encoder B", model=SequentialMLP(output_shape=(1, 32)), upstream_node=charge_samples),
    ConcatStage(label="Merge", upstream_nodes=["Encoder A", "Encoder B"], axis=1),
    ModelStage(label="Regressor", model=SequentialMLP(output_shape=(1, 1)), upstream_node="Merge"),
]
mg = ModelGraph(nodes=nodes, optimizer=Optimizer(name="adam", backend=mml.Backend.TORCH))
mg.build_all()
mg.visualize()

**Scenario 1: provided after and before**

Here we insert a new node "Before+After" between our FeatureSet and "Encoder A"

In [None]:
# Splice the new stage into the single ChargePulses -> Encoder A connection
new_node = ModelStage(label="Before+After", model=SequentialMLP(output_shape=(1, 64)), upstream_node="ChargePulses")
mg.insert(node=new_node, before="Encoder A", after="ChargePulses", inplace=True)
mg.visualize()

We see that it is inserted between the specified nodes, replacing the existing connection.

We'll remove it before exploring the alternative insert methods.

In [None]:
mg.remove("Before+After")
mg.visualize()

**Scenario 2: provided only before**

We could've achieved the same result using only the 'before' argument.

The two approaches differ only when the node we insert before has multiple inputs.
In that case, ***all inputs*** are rewired to pass through the new node, whereas specifying both `before` and `after` rewires only the single connection between them.

In [None]:
new_node = ModelStage(label="Before", model=SequentialMLP(output_shape=(1, 64)), upstream_node="ChargePulses")
mg.insert(node=new_node, before="Encoder A", inplace=True)
mg.visualize()

As expected, we got the same result as using before and after.

If we instead insert before 'Merge', we'll see how all of its inputs get rewired.

In [None]:
mg.remove("Before")
mg.visualize()

In [None]:
new_node = ConcatStage(label="Before", upstream_nodes=["Encoder A", "Encoder B"], axis=1)
mg.insert(node=new_node, before="Merge", inplace=True)
mg.visualize()

In [None]:
mg.remove("Before")
mg.visualize()

**Scenario 3: provided only after**

Similarly, specifying only 'after' results in shifting all downstream connections to the new node.

In [None]:
new_node = ModelStage(label="After", model=SequentialMLP(output_shape=(1, 64)), upstream_node="ChargePulses")
mg.insert(node=new_node, after="ChargePulses", inplace=True)
mg.visualize()

We see that the two downstream connections of 'ChargePulses' were moved onto the new 'After' ModelStage.
Now 'ChargePulses' only outputs into 'After'.

This concludes the **02_modelgraph_basics** notebook.

The next tutorial explains the `Experiment` container and the ModelGraph training/evaluation logic: [03_training_and_evaluation.ipynb](./03_training_and_evaluation.ipynb)