# Tutorial

In this `energnn` tutorial, we will review:
- How to install `energnn`,
- The interaction with typical implementations of `energnn.problem.Problem`, `energnn.problem.ProblemBatch` and `energnn.problem.ProblemLoader`,
- The creation of a GNN model using `energnn.model.ready_to_use`,
- The training of the model with `energnn.trainer.SimpleTrainer`,
- The usage of the trained model.

## Installation

To install the latest stable release of `energnn` on CPU,
```bash
pip install energnn
```
For the GPU version, install the `gpu` extra:
```bash
pip install "energnn[gpu]"
```
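To check which release ended up in your environment, here is a minimal sketch that uses only the standard library (`importlib.metadata`). The helper name `installed_version` is ours and is not part of `energnn`.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print("energnn version:", installed_version("energnn"))
```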

## Problem Class

Let's consider the following use case.
Given the pair $(A, b)$, we wish to find $x$ such that $Ax = b$.
Let us generate a random problem instance and explore its interface.

```python
from tests.utils import TestProblemGenerator

pb_generator = TestProblemGenerator(seed=7, n_max=4)
problem = pb_generator.generate_problem()
```
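Conceptually, each generated instance boils down to the pair $(A, b)$ and its solution $x^\star = A^{-1} b$. The NumPy sketch below builds such an instance by hand. It is only an illustration and does not reproduce how `TestProblemGenerator` samples $A$ and $b$. The helper `make_linear_system` is hypothetical, and shifting a random matrix by $n I$ is our own choice to keep $A$ well conditioned.

```python
import numpy as np


def make_linear_system(n, seed=0):
    """Hypothetical helper: sample a random system A x = b and its exact solution."""
    rng = np.random.default_rng(seed)
    # Shifting by n * I makes A diagonally heavy, hence invertible with high probability.
    A = rng.normal(size=(n, n)) + n * np.eye(n)
    b = rng.normal(size=n)
    x_star = np.linalg.solve(A, b)  # Exact solution, used as the target decision.
    return A, b, x_star


A, b, x_star = make_linear_system(n=4, seed=7)
print(np.allclose(A @ x_star, b))  # True
```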

### Context Graph

The input of our GNN model is referred to as **the context**, instantiated as an `energnn.graph.Graph` object.

In this case, it is the pair $(A, b)$, framed as a Hyper Heterogeneous Multi Graph (H2MG).
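One natural way to frame $(A, b)$ as a graph is to turn every nonzero entry $A_{ij}$ into an **arrow** object with addresses `from = i`, `to = j` and feature `value = A_ij`, and every entry $b_i$ into a **source** object with address `id = i` and feature `value = b_i`. This mapping is an assumption made for illustration, since the tutorial does not spell out the exact convention of the test problem. Here is a NumPy sketch of the encoding and its inverse. The functions `encode_context` and `decode_context` are hypothetical.

```python
import numpy as np


def encode_context(A, b):
    """Hypothetical encoding of (A, b) into 'arrow' and 'source' tables."""
    rows, cols = np.nonzero(A)  # One arrow per nonzero entry of A.
    arrow = {"from": rows, "to": cols, "value": A[rows, cols]}
    source = {"id": np.arange(len(b)), "value": b}
    return {"arrow": arrow, "source": source}


def decode_context(context, n):
    """Rebuild the dense (A, b) pair from the tables produced by encode_context."""
    A = np.zeros((n, n))
    A[context["arrow"]["from"], context["arrow"]["to"]] = context["arrow"]["value"]
    b = np.zeros(n)
    b[context["source"]["id"]] = context["source"]["value"]
    return A, b


A = np.array([[2.0, 0.0, 1.0], [0.0, 3.0, 0.0], [1.0, 0.0, 4.0]])
b = np.array([1.0, -1.0, 0.5])
context = encode_context(A, b)
print(len(context["arrow"]["value"]), len(context["source"]["value"]))  # 5 3
A_rec, b_rec = decode_context(context, n=3)
```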

```python
# Let us explore the context structure
print(problem.context_structure)
```

```
         Addresses Features
Name
arrow   [from, to]  [value]
source        [id]  [value]
```


```python
# Print the context graph associated with the problem instance
context, _ = problem.get_context()
print(context)
```

```
arrow
          addresses       features
               from   to     value
object_id
0               0.0  1.0  1.457083
1               0.0  3.0  3.200915
2               1.0  2.0  2.025772
3               2.0  1.0  1.216732
4               3.0  2.0  0.748781
source
          addresses  features
                 id     value
object_id
0               0.0 -0.155426
1               1.0  1.770309
2               2.0 -0.406632
3               3.0  0.654355
```



The **context** of this specific problem instance has:

- 5 **arrow** objects,
- 4 **source** objects.

### Decision Graph
The output of our GNN model is referred to as **the decision**, instantiated as an `energnn.graph.Graph` object.

In this case, it is the variable $x$.

This specific problem class has a helper method called `get_zero_decision` (not part of the mandatory interface),
which returns a decision of the right shape and structure, filled with zeros.

```python
# Let us explore the decision structure
print(problem.decision_structure)
```

```
       Addresses Features
Name
source      None  [value]
```


Notice that decisions concern only a subset of the classes available in the context (here, only **source**), and that they have no addresses, only features.

```python
# Print the zero decision associated with the problem instance
decision, _ = problem.get_zero_decision()
print(decision)
```

```
source
          features
             value
object_id
0              0.0
1             -0.0
2              0.0
3              0.0
```



### Objective Function
The score of a given decision is instantiated as a `float`.

In this case, we use the Mean Squared Error $\frac{1}{n} \Vert x - x^\star \Vert^2$, where $x^\star$ is the solution and $n$ is the number of **source** objects.

We can evaluate it by injecting the zero decision we have just retrieved.

```python
metrics, _ = problem.get_metrics(decision=decision)
print(metrics)
```

```
0.2227775752544403
```


### Gradient Graph
The gradient of the objective function is instantiated as an `energnn.graph.Graph` object.

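In this case, it is the vector $x-x^\star$, which matches the gradient of the objective up to a positive scaling factor and therefore points in the same direction.
Notice that for more complex use cases, the gradient can have more complex expressions, or even require Monte Carlo simulations.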

We can evaluate it by injecting the zero decision.

```python
gradient, _ = problem.get_gradient(decision=decision)
print(gradient)
```

```
source
           features
              value
object_id
0         -0.070729
1          0.334200
2         -0.873894
3         -0.103574
```



Notice that this gradient is the exact same type of object as the decision.
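As a consistency check, here is a small NumPy computation on the values printed above. At the zero decision, the gradient $x - x^\star$ equals $-x^\star$, so the printed gradient gives $x^\star$ up to sign. Averaging the squared errors over the four **source** objects then recovers the printed objective. The match is approximate because the gradient was printed with six decimals.

```python
import numpy as np

# Gradient printed above for the zero decision; it equals -x_star (rounded to 6 decimals).
gradient_at_zero = np.array([-0.070729, 0.334200, -0.873894, -0.103574])
x_star = -gradient_at_zero
zero_decision = np.zeros_like(x_star)

# Mean of the squared errors over the 4 source objects.
mse = np.mean((zero_decision - x_star) ** 2)
print(round(float(mse), 6))  # 0.222778
```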

Just as a quick sanity check, we can perform a gradient descent and make sure that the objective decreases.

```python
from energnn.graph import JaxGraph

alpha = 0.5  # Step size of the gradient descent

objective, _ = problem.get_metrics(decision=decision)
print(f"Step 0, objective = {objective}")

for i in range(1, 11):
    gradient, _ = problem.get_gradient(decision=decision)

    # Update decision
    numpy_gradient = gradient.to_numpy_graph()  # For now, we need to convert the gradient to a numpy graph.
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)

    objective, _ = problem.get_metrics(decision=decision)
    print(f"Step {i}, objective = {objective}")
```

```
Step 0, objective = 0.2227775752544403
Step 1, objective = 0.05569439381361008
Step 2, objective = 0.013923597522079945
Step 3, objective = 0.003480899380519986
Step 4, objective = 0.0008702246705070138
Step 5, objective = 0.00021755615307483822
Step 6, objective = 5.438903463073075e-05
Step 7, objective = 1.3597240467788652e-05
Step 8, objective = 3.399258957870188e-06
Step 9, objective = 8.49789046242222e-07
Step 10, objective = 2.1243486969524383e-07
```
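The objective is divided by 4 at every step, which is expected. With the update $x_{k+1} = x_k - \alpha (x_k - x^\star)$, the error obeys $x_{k+1} - x^\star = (1 - \alpha)(x_k - x^\star)$. The objective is quadratic in the error, so it is multiplied by $(1 - \alpha)^2 = 0.25$ for $\alpha = 0.5$. Here is a NumPy sketch that reproduces this behavior using the $x^\star$ recovered from the printed gradient (rounded values).

```python
import numpy as np

alpha = 0.5
# Solution recovered from the gradient printed at the zero decision (rounded values).
x_star = np.array([0.070729, -0.334200, 0.873894, 0.103574])
x = np.zeros_like(x_star)  # Start from the zero decision.

objectives = [np.mean((x - x_star) ** 2)]
for _ in range(10):
    x = x - alpha * (x - x_star)  # Gradient step, using x - x_star as the gradient.
    objectives.append(np.mean((x - x_star) ** 2))

ratios = np.array(objectives[1:]) / np.array(objectives[:-1])
print(np.allclose(ratios, (1 - alpha) ** 2))  # True
```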


The objective function successfully decreases! Now, let's explore how multiple problems can be batched together.

## Problem Batch

Interacting with a single problem instance is useful at inference time or for debugging purposes.
But to train a Graph Neural Network model, we need to process batches of problem instances together.

```python
from tests.utils import TestProblemGenerator

pb_generator = TestProblemGenerator(seed=9, n_max=3)
problem_batch = pb_generator.generate_problem_batch(batch_size=3)

# Let us explore the context and decision structures
print("Context Structure:\n", problem_batch.context_structure, "\n")
print("Decision Structure:\n", problem_batch.decision_structure)
```

```
Context Structure:
          Addresses Features
Name
arrow   [from, to]  [value]
source        [id]  [value]

Decision Structure:
        Addresses Features
Name
source      None  [value]
```


Here, contexts are still graphs, but this time with an extra batch dimension:

```python
context, _ = problem_batch.get_context()
print(context)
```

```
arrow
                   addresses       features
                        from   to     value
batch_id object_id
0        0               0.0  0.0 -1.116066
         1               1.0  0.0 -0.481135
         2               1.0  1.0 -1.517331
         3               1.0  2.0 -0.490872
         4               2.0  1.0 -0.647947
         5               2.0  2.0  0.635891
         6               0.0  0.0  0.000000
         7               0.0  0.0  0.000000
         8               0.0  0.0  0.000000
1        0               0.0  0.0 -0.857040
         1               0.0  1.0  1.528224
         2               0.0  2.0  0.904988
         3               1.0  0.0  0.541645
         4               1.0  1.0  0.701052
         5               1.0  2.0 -0.054635
         6               2.0  0.0  0.081804
         7               2.0  1.0 -1.281731
         8               2.0  2.0  0.158457
2        0               0.0  0.0 -0.745505
         1               0
```

Notice that the different contexts of the batch do not have the same connectivity, and do not have the same number of **arrow** and **source** objects.
To batch the different contexts together, it is thus necessary to pad them with zeros.
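Zero-padding can be sketched in a few lines of NumPy. Each instance's feature array is copied into a fixed-size buffer, and a boolean mask records which rows hold real objects. In the output above, for example, rows 6 to 8 of the first instance are padding. The helper `pad_and_stack` and the explicit mask are illustrative assumptions, since the tutorial does not show how `energnn` stores this information internally.

```python
import numpy as np


def pad_and_stack(arrays, max_len):
    """Stack 1-D arrays of different lengths into a zero-padded (batch, max_len) array."""
    batch = np.zeros((len(arrays), max_len))
    mask = np.zeros((len(arrays), max_len), dtype=bool)
    for k, values in enumerate(arrays):
        batch[k, : len(values)] = values  # Real objects first, zeros afterwards.
        mask[k, : len(values)] = True
    return batch, mask


# Three instances with 6, 9 and 4 arrow objects (illustrative sizes), padded to the largest.
arrow_values = [np.ones(6), np.ones(9), np.ones(4)]
batch, mask = pad_and_stack(arrow_values, max_len=max(len(a) for a in arrow_values))
print(batch.shape, mask.sum(axis=1))  # (3, 9) [6 9 4]
```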

Still, a `ProblemBatch` can be handled in a very similar way to a single `Problem`.

```python
alpha = 0.5  # Step size of the gradient descent

decision, _ = problem_batch.get_zero_decision()
objective, _ = problem_batch.get_metrics(decision=decision)
print(f"Step 0, objective = {objective}")

for i in range(1, 11):
    gradient, _ = problem_batch.get_gradient(decision=decision)

    # Update decision (same numpy round-trip as for a single problem)
    numpy_gradient = gradient.to_numpy_graph()
    numpy_decision = decision.to_numpy_graph()
    numpy_decision.feature_flat_array -= alpha * numpy_gradient.feature_flat_array
    decision = JaxGraph.from_numpy_graph(numpy_decision)

    objective, _ = problem_batch.get_metrics(decision=decision)
    print(f"Step {i}, objective = {objective}")
```

```
Step 0, objective = [1.205530047416687, 0.3518087863922119, 0.006591798271983862]
Step 1, objective = [0.30138251185417175, 0.08795219659805298, 0.0016479495679959655]
Step 2, objective = [0.07534561306238174, 0.021988049149513245, 0.0004119873046875]
Step 3, objective = [0.018836403265595436, 0.005497012287378311, 0.000102996826171875]
Step 4, objective = [0.0047090970911085606, 0.0013742543524131179, 2.574920654296875e-05]
Step 5, objective = [0.0011772741563618183, 0.00034356300602667034, 6.4373016357421875e-06]
Step 6, objective = [0.00029431749135255814, 8.589089702581987e-05, 1.6093254089355469e-06]
Step 7, objective = [7.357935101026669e-05, 2.1472886146511883e-05, 4.023313522338867e-07]
Step 8, objective = [1.8394850485492498e-05, 5.3682501857110765e-06, 1.0058283805847168e-07]
Step 9, objective = [4.598733994498616e-06, 1.342021732853027e-06, 2.514570951461792e-08]
Step 10, objective = [1.149680656453711e-06, 3.3547598832228687e-07, 6.28642737865448e-09]
```


Notice that the scores are now lists of `float`, one per problem instance in the batch.
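With padding, a per-instance score must ignore the padded entries. Here is a minimal NumPy sketch that returns one `float` per instance. It assumes a mean squared error and a boolean mask of real objects, both chosen for illustration. The function `batched_mse` is hypothetical.

```python
import numpy as np


def batched_mse(decision, target, mask):
    """Mean squared error of each instance, averaged over its real (unpadded) objects."""
    squared_errors = np.where(mask, (decision - target) ** 2, 0.0)
    return (squared_errors.sum(axis=1) / mask.sum(axis=1)).tolist()


target = np.array([[1.0, 2.0, 0.0], [3.0, 0.0, 0.0]])
mask = np.array([[True, True, False], [True, False, False]])
decision = np.zeros_like(target)
print(batched_mse(decision, target, mask))  # [2.5, 9.0]
```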

## Problem Loader

Being able to process problem instances per batch is nice, but not enough.
To train a Graph Neural Network, we'll need to iterate over multiple minibatches of problem instances.
That's where the `ProblemLoader` class comes in.

```python
from tests.utils import TestProblemLoader

problem_loader = TestProblemLoader(batch_size=4, seed=7, dataset_size=16, n_max=4)

# Let us explore the context and decision structures
print("Context Structure:\n", problem_loader.context_structure, "\n")
print("Decision Structure:\n", problem_loader.decision_structure)
```

```
Context Structure:
          Addresses Features
Name
arrow   [from, to]  [value]
source        [id]  [value]

Decision Structure:
        Addresses Features
Name
source      None  [value]
```


It allows us to iterate over batches of problems.

```python
for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()
    decision, _ = problem_batch.get_zero_decision()
    objective, _ = problem_batch.get_metrics(decision=decision)
    print("Objective:", objective)
```

```
Objective: [0.2227775752544403, 0.5266321897506714, 0.04437382146716118, 0.17635096609592438]
Objective: [0.3786073327064514, 0.6252117156982422, 0.338590145111084, 0.010212971828877926]
Objective: [1.4664279222488403, 0.7719758749008179, 1.2088079452514648, 0.7055290937423706]
Objective: [1.4254088401794434, 2.1467173099517822, 0.37498927116394043, 0.3698219358921051]
```
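A problem loader essentially behaves like a Python iterable over minibatches. The class below is a hypothetical, minimal sketch of that behavior. It is not the implementation of `energnn.problem.ProblemLoader` or `TestProblemLoader`. With a dataset of 16 items and `batch_size=4`, it yields 4 batches, as in the output above.

```python
import math


class ToyLoader:
    """Hypothetical loader iterating over a dataset in fixed-size minibatches."""

    def __init__(self, dataset, batch_size):
        self.dataset = list(dataset)
        self.batch_size = batch_size

    def __len__(self):
        # Number of minibatches; the last one may be smaller.
        return math.ceil(len(self.dataset) / self.batch_size)

    def __iter__(self):
        for start in range(0, len(self.dataset), self.batch_size):
            yield self.dataset[start : start + self.batch_size]


loader = ToyLoader(dataset=range(16), batch_size=4)
print(len(loader), [len(batch) for batch in loader])  # 4 [4, 4, 4, 4]
```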


## Graph Neural Network Model

Let us instantiate a small Graph Neural Network model adapted to the context and decision structures of our problem class.

```python
from energnn.model.ready_to_use import TinyRecurrentEquivariantGNN

model = TinyRecurrentEquivariantGNN(
    in_structure=problem_loader.context_structure,
    out_structure=problem_loader.decision_structure
)
```

Make sure that your model is in evaluation mode first!

```python
model.eval()  # Set the model in evaluation mode.
# model.train()  # To set the model in train mode.
```

It takes a context as input and returns a decision.

```python
problem = pb_generator.generate_problem()
context, _ = problem.get_context()
decision, _ = model(context)
print(decision)
```

```
source
           features
              value
object_id
0          0.237434
```
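To give an intuition of how a GNN maps a context to a decision, here is one generic message-passing step written with NumPy. Each **arrow** sends a message from its `from` address to its `to` address, and the messages are summed per **source** object. This is a textbook illustration with hypothetical weights `w_self` and `w_msg`. It is not the architecture of `TinyRecurrentEquivariantGNN`.

```python
import numpy as np


def message_passing_step(source_value, arrow_from, arrow_to, arrow_value, w_self, w_msg):
    """One generic message-passing step producing one output per source object."""
    # Message carried by each arrow: arrow feature times the sender's feature.
    messages = arrow_value * source_value[arrow_from]
    # Sum incoming messages at the receiving address (np.add.at handles repeated indices).
    aggregated = np.zeros_like(source_value)
    np.add.at(aggregated, arrow_to, messages)
    return np.tanh(w_self * source_value + w_msg * aggregated)


source_value = np.array([1.0, -1.0, 0.5])
arrow_from = np.array([0, 1, 2])
arrow_to = np.array([1, 2, 1])
arrow_value = np.array([2.0, 1.0, 4.0])
out = message_passing_step(source_value, arrow_from, arrow_to, arrow_value, w_self=0.5, w_msg=0.1)
print(out.shape)  # (3,)
```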



It can also process batches of contexts and return batches of decisions.

```python
problem_batch = pb_generator.generate_problem_batch(batch_size=4)
context, _ = problem_batch.get_context()
decision, _ = model.forward_batch(graph=context)
print(decision)
```

```
source
                    features
                       value
batch_id object_id
0        0          0.294847
         1          0.000000
         2          0.000000
1        0          0.000000
         1          3.260567
         2          0.000000
2        0          0.326514
         1          0.000000
         2          0.000000
3        0         -0.784679
         1         -0.000000
         2         -0.000000
```



## Trainer

Let us train our Graph Neural Network model over a problem loader. The core training loop is defined by the following pseudocode.

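```python
# Pseudocode: `model.backprop` denotes backpropagating the decision gradient
# through the model to update its parameters.
for problem_batch in problem_loader:
    context, _ = problem_batch.get_context()
    decision, _ = model.forward_batch(graph=context)
    gradient, _ = problem_batch.get_gradient(decision=decision)
    model.backprop(gradient)
```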
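The key point of this loop is that the problem supplies the gradient of the objective with respect to the decision. The model only has to propagate that gradient back to its parameters with the chain rule. Here is a hypothetical NumPy sketch on a toy problem where the target decision is $x^\star = 2b$ and the "model" is $x = w\,b$ with a single weight $w$. Plain gradient descent stands in for the `optax` optimizer, and none of the names below belong to `energnn`.

```python
import numpy as np

rng = np.random.default_rng(0)
w = 0.0               # Single trainable parameter of the toy model x = w * b.
learning_rate = 0.1

for _ in range(200):
    b = rng.normal(size=8)          # Toy "context" features.
    x_star = 2.0 * b                # Toy target decision, unknown to the model.
    decision = w * b                # Forward pass.
    gradient = decision - x_star    # Decision gradient, as a problem would return it.
    # Backpropagation: chain rule d(objective)/dw = mean_i gradient_i * d(decision_i)/dw.
    grad_w = np.sum(gradient * b) / len(b)
    w = float(w - learning_rate * grad_w)  # Plain gradient descent step.

print(round(w, 3))  # 2.0
```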

In practice, the training logic is implemented in `energnn.trainer`, which lets us use:

- `optax` for the optimizer,
- `orbax` for checkpointing and saving/loading models.

```python
from energnn.trainer import SimpleTrainer
import optax

trainer = SimpleTrainer(model=model, gradient_transformation=optax.adam(learning_rate=1e-3))
```
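`optax.adam(learning_rate=1e-3)` is a gradient transformation that turns raw gradients into parameter updates. In `optax`, it is driven through `optimizer.init(params)`, `optimizer.update(grads, state)` and `optax.apply_updates(params, updates)`. To show what it computes, here is a from-scratch NumPy re-implementation of the Adam update rule with its usual defaults ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$). It is a sketch for illustration, not the `optax` code, and the class name `AdamSketch` is hypothetical.

```python
import numpy as np


class AdamSketch:
    """Minimal re-implementation of the Adam update rule (illustration only)."""

    def __init__(self, learning_rate, b1=0.9, b2=0.999, eps=1e-8):
        self.learning_rate, self.b1, self.b2, self.eps = learning_rate, b1, b2, eps
        self.m = self.v = None
        self.t = 0

    def update(self, grad):
        """Return the update to add to the parameters, given their gradient."""
        if self.m is None:
            self.m, self.v = np.zeros_like(grad), np.zeros_like(grad)
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad        # First-moment estimate.
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2   # Second-moment estimate.
        m_hat = self.m / (1 - self.b1 ** self.t)                # Bias corrections.
        v_hat = self.v / (1 - self.b2 ** self.t)
        return -self.learning_rate * m_hat / (np.sqrt(v_hat) + self.eps)


# Minimize (p - 3)^2 starting from p = 0; its gradient is 2 * (p - 3).
p = np.array([0.0])
optimizer = AdamSketch(learning_rate=0.01)
first_update = optimizer.update(2 * (p - 3))
p = p + first_update
for _ in range(3000):
    p = p + optimizer.update(2 * (p - 3))
print(first_update, p)
```

After bias correction, the very first update has a magnitude (almost exactly) equal to the learning rate, whatever the scale of the gradient. This scale invariance is one reason Adam is a common default.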

The training is performed by iterating over a **train** loader, and the validation score is periodically computed on a **validation** loader.

```python
train_loader = TestProblemLoader(seed=7, dataset_size=64, batch_size=4, n_max=3)
val_loader = TestProblemLoader(seed=8, dataset_size=8, batch_size=4, n_max=3)
```

```python
_ = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    eval_before_training=True,
    n_epochs=3,
)
```

```
Validation: 100%|██████████| 2/2 [00:02<00:00,  1.35s/batch, metrics=2.1002e+00]
Epoch 1/3: 100%|██████████| 16/16 [00:23<00:00,  1.46s/batch]
Validation: 100%|██████████| 2/2 [00:02<00:00,  1.24s/batch, metrics=1.0778e+00]
Epoch 2/3: 100%|██████████| 16/16 [00:20<00:00,  1.28s/batch]
Validation: 100%|██████████| 2/2 [00:02<00:00,  1.22s/batch, metrics=9.9687e-01]
Epoch 3/3: 100%|██████████| 16/16 [00:21<00:00,  1.33s/batch]
Validation: 100%|██████████| 2/2 [00:02<00:00,  1.26s/batch, metrics=9.3473e-01]
```
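The log above follows a simple schedule. With `eval_before_training=True`, one validation pass runs before the first epoch, and, as the log suggests, another runs after each epoch. Each pass iterates over the whole loader: 64 / 4 = 16 training batches and 8 / 4 = 2 validation batches. The pure-Python sketch below only reproduces this ordering. `training_schedule` is a hypothetical helper, not the `SimpleTrainer` implementation.

```python
import math


def training_schedule(n_epochs, train_size, val_size, batch_size, eval_before_training=True):
    """Hypothetical helper listing the passes of a train loop with per-epoch validation."""
    n_train_batches = math.ceil(train_size / batch_size)
    n_val_batches = math.ceil(val_size / batch_size)
    passes = []
    if eval_before_training:
        passes.append(("validation", n_val_batches))
    for epoch in range(1, n_epochs + 1):
        passes.append((f"epoch {epoch}/{n_epochs}", n_train_batches))
        passes.append(("validation", n_val_batches))
    return passes


for name, n_batches in training_schedule(n_epochs=3, train_size=64, val_size=8, batch_size=4):
    print(f"{name}: {n_batches} batches")
```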
