[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/basf/mlipx/blob/main/docs/source/notebooks/structure_relaxation.ipynb)

# Structure Relaxation with Custom Nodes

You can combine `mlipx` with custom code by writing ZnTrack nodes.
We will write a Node to perform a geometry relaxation similar to `mlipx.StructureOptimization`.

In [None]:
# Create a Git and DVC repository in a temporary directory
import os
import tempfile

temp_dir = tempfile.TemporaryDirectory()
os.chdir(temp_dir.name)

As with all `mlipx` Nodes, we run our experiments inside a Git and DVC repository.
To make our custom code available, we structure the project as follows:

```
relaxation/
   ├── .git/
   ├── .dvc/
   ├── src/__init__.py
   ├── src/relaxation.py
   ├── models.py
   └── main.py
```

This layout allows us to import our code with `from src.relaxation import Relax`.
Alternatively, you can package your code and import it like any other Python package.
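
The `from src.relaxation import Relax` import works because `src/` is a package (it contains `__init__.py`) and the project root is on Python's import path. The stdlib-only sketch below reproduces this mechanism in a throwaway directory; the one-line `Relax` class written into it is a hypothetical placeholder, not the Node defined later in this notebook.

```python
import importlib
import pathlib
import sys
import tempfile

# Create a throwaway project root with the same src/ layout as above
root = pathlib.Path(tempfile.mkdtemp())
(root / "src").mkdir()
(root / "src" / "__init__.py").touch()  # marks src/ as a regular package
# Hypothetical placeholder for the real Relax Node
(root / "src" / "relaxation.py").write_text("class Relax:\n    pass\n")

# Put the project root on the import path, as it is when Python starts there
sys.path.insert(0, str(root))
importlib.invalidate_caches()
relaxation = importlib.import_module("src.relaxation")
Relax = relaxation.Relax
print(Relax.__module__)  # src.relaxation
```

Packaging the code instead (the alternative mentioned above) removes the dependency on the working directory altogether.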

In [None]:
!git init
!dvc init --quiet
!mkdir src
!touch src/__init__.py

Initialized empty Git repository in /tmp/tmpd5ehmvex/.git/



The code we want to put into our `Relax` `Node` is the following:


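```python
import ase
import ase.io
from ase.calculators.calculator import Calculator
from ase.optimize import BFGS

# Inputs (placeholders, provided elsewhere): structures to relax and an ASE calculator
data: list[ase.Atoms]
calc: Calculator

end_structures = []
for atoms in data:
    atoms.calc = calc  # attach the calculator to the structure
    opt = BFGS(atoms)
    opt.run(fmax=0.05)  # relax until the largest force is below 0.05 eV/Å
    end_structures.append(atoms)

ase.io.write('end_structures.xyz', end_structures)
```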

To do so, we need to identify the inputs and outputs of our code and declare them as ZnTrack fields.
The `data: list[ase.Atoms]` input is provided by a data-loading Node, so we declare it as `data: list = zntrack.deps()`.
If you want to read directly from a file instead, you can use `data_path: str = zntrack.deps_path()`.
The calculator is accessed the same way via `model: NodeWithCalculator = zntrack.deps()`.
`mlipx` provides the `NodeWithCalculator` abstract base class as a common interface for sharing `ASE` calculators between Nodes.
Another convention is to take inputs as `data: list[ase.Atoms]` and provide outputs as `frames: list[ase.Atoms]`.
Because we save our results to a file, we define `frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')`, which stores the output trajectory as `frames.xyz` in the node working directory (nwd).
`zntrack.nwd` provides a unique directory per `Node` in which to store its data.
Since `mlipx` Nodes communicate via `ase.Atoms` objects, we also define a `frames` property that loads the file back into a list of `ase.Atoms`.
Within this property we could also modify the `ase.Atoms` objects before returning them. This makes the communication between Nodes independent of the file format and allows data to be exchanged via code, following the Data as Code (DaC) approach.
To summarize, each Node defines how to `save` and `load` the data it produces, which simplifies communication and avoids problems with differing file format conventions.

Besides the fields used here, ZnTrack also provides `x: dict = zntrack.params()`, `x: dict = zntrack.metrics()` and `x: pd.DataFrame = zntrack.plots()`, as well as their file-path counterparts `x: str | pathlib.Path = zntrack.params_path()`, `zntrack.metrics_path()` and `zntrack.plots_path()`.
For general outputs there is `x: Any = zntrack.outs()`. More information can be found at https://dvc.org/doc/start/data-pipelines/metrics-parameters-plots.

In [None]:
%%writefile src/relaxation.py
import zntrack
from mlipx.abc import NodeWithCalculator
from ase.optimize import BFGS
import ase.io
import pathlib



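class Relax(zntrack.Node):
    # Inputs: structures from an upstream Node and a Node that provides an ASE calculator
    data: list = zntrack.deps()
    model: NodeWithCalculator = zntrack.deps()
    # Output: relaxed structures, stored in this Node's working directory
    frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')

    def run(self):
        end_structures = []
        for atoms in self.data:
            atoms.calc = self.model.get_calculator()  # calculator provided by the model Node
            opt = BFGS(atoms)
            opt.run(fmax=0.05)  # relax until the largest force is below 0.05 eV/Å
            end_structures.append(atoms)
        with open(self.frames_path, 'w') as f:
            ase.io.write(f, end_structures, format='extxyz')

    @property
    def frames(self) -> list[ase.Atoms]:
        # Load the stored trajectory back into ase.Atoms objects
        with self.state.fs.open(self.frames_path, "r") as f:
            return ase.io.read(f, format='extxyz', index=':')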


Writing src/relaxation.py


With this Node definition, we can import the Node and connect it with `mlipx` to form a graph.

In [None]:
import zntrack
from src.relaxation import Relax

import mlipx

In [None]:
project = zntrack.Project()

emt = mlipx.GenericASECalculator(
    module="ase.calculators.emt",
    class_name="EMT",
)

with project:
    confs = mlipx.Smiles2Conformers(smiles="CCCC", num_confs=5)
    relax = Relax(data=confs.frames, model=emt)

project.build()

2024-10-28 13:38:39,442 - INFO: Saving params.yaml


100%|██████████| 2/2 [00:00<00:00, 467.44it/s]


To execute the graph, we run `dvc repro` through `project.repro()`. Since the graph was already built with `project.build()`, we pass `build=False`.

In [None]:
project.repro(build=False)

Running stage 'Smiles2Conformers':
> zntrack run mlipx.nodes.smiles.Smiles2Conformers --name Smiles2Conformers
Generating lock file 'dvc.lock'
Updating lock file 'dvc.lock'

Running stage 'Relax':
> zntrack run src.relaxation.Relax --name Relax
      Step     Time          Energy          fmax
BFGS:    0 13:38:43        4.725117        4.227687
BFGS:    1 13:38:43        3.694942        2.695542
BFGS:    2 13:38:43        3.002297        1.492195
BFGS:    3 13:38:43        2.821368        1.084164
BFGS:    4 13:38:43        2.664348        1.026548
BFGS:    5 13:38:43        2.491416        0.799464
BFGS:    6 13:38:43        2.425102        0.284223
BFGS:    7 13:38:43        2.421374        0.196049
BFGS:    8 13:38:43        2.419329        0.177517
BFGS:    9 13:38:43        2.415170        0.194933
BFGS:   10 13:38:43        2.410946        0.242349
BFGS:   11 13:38:43        2.406163        0.232391
BFGS:   12 13:38:43        2.403577        0.115007
BFGS:   13 13:38:43        2.

Once the graph has been executed, we can look at the resulting structures.

In [None]:
relax.frames

[Atoms(symbols='C4H10', pbc=False, calculator=SinglePointCalculator(...)),
 Atoms(symbols='C4H10', pbc=False, calculator=SinglePointCalculator(...)),
 Atoms(symbols='C4H10', pbc=False, calculator=SinglePointCalculator(...)),
 Atoms(symbols='C4H10', pbc=False, calculator=SinglePointCalculator(...)),
 Atoms(symbols='C4H10', pbc=False, calculator=SinglePointCalculator(...))]

In [None]:
temp_dir.cleanup()