[![](https://img.shields.io/badge/Source%20on%20GitHub-orange)](https://github.com/laminlabs/lamin-mlops/blob/main/docs/mlflow.ipynb)

# MLflow

We show how LaminDB can be integrated with [MLflow](https://mlflow.org/) to track the training process and associate datasets & parameters with models.

In [None]:
# pip install lamindb torchvision lightning mlflow
!lamin init --storage ./lamin-mlops

In [None]:
import lamindb as ln
import lightning as pl
import mlflow
from pathlib import Path

from torch import utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from autoencoder import LitAutoEncoder

```{dropdown} Tracking models in both LaminDB and MLflow
It is not always necessary to track all model parameters and metrics in both LaminDB and MLflow.
However, if artifacts or runs should be queryable by specific model attributes, for example the learning rate, then track these attributes in LaminDB as well.
Below, we show how to do that for the batch size and learning rate; the approach generalizes to more features.
```

In [None]:
# define model run parameters & features
MODEL_CONFIG = {"batch_size": 32, "lr": 0.001}

hyperparameter = ln.Feature(name="Autoencoder hyperparameter", is_type=True).save()
hyperparams = ln.Feature.from_dict(MODEL_CONFIG, str_as_cat=True)
for param in hyperparams:
    param.type = hyperparameter
    param.save()

ln.track(params=MODEL_CONFIG, project=ln.Project(name="MLflow tutorial").save())
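
To make the mapping from `MODEL_CONFIG` to features more concrete, here is a minimal pure-Python sketch of the underlying idea: each key becomes a feature name and the Python type of its value suggests a dtype. The helper `infer_dtypes` and its dtype labels are illustrative assumptions, not LaminDB's implementation; `ln.Feature.from_dict` remains the call to use in practice.

```python
# Illustrative only: `infer_dtypes` is a hypothetical helper, not a LaminDB API.
def infer_dtypes(config: dict, str_as_cat: bool = True) -> dict[str, str]:
    """Map each config key to a dtype label based on the value's Python type."""
    dtypes = {}
    for name, value in config.items():
        # bool must be checked before int because bool is a subclass of int
        if isinstance(value, bool):
            dtypes[name] = "bool"
        elif isinstance(value, int):
            dtypes[name] = "int"
        elif isinstance(value, float):
            dtypes[name] = "float"
        elif isinstance(value, str):
            # optionally treat strings as categoricals
            dtypes[name] = "cat" if str_as_cat else "str"
        else:
            raise TypeError(f"Unsupported type for {name!r}: {type(value).__name__}")
    return dtypes


MODEL_CONFIG = {"batch_size": 32, "lr": 0.001}
feature_dtypes = infer_dtypes(MODEL_CONFIG)
print(feature_dtypes)
```

Keeping the configuration in a single dictionary means the same object can define the features and be passed to `ln.track(params=...)`.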

## Define a model

We use a basic PyTorch Lightning autoencoder as an example model.

````{dropdown} Code of LitAutoEncoder
```{eval-rst}
.. literalinclude:: autoencoder.py
   :language: python
   :caption: Simple autoencoder model
```
````

## Query & download the MNIST dataset

We saved the MNIST dataset in a [curation notebook](/mnist); it now shows up in the Artifact registry:

In [None]:
ln.Artifact.filter(kind="dataset").to_dataframe()

Let's get the dataset:

In [None]:
artifact = ln.Artifact.get(key="testdata/mnist")
artifact

And download it to a local cache:

In [None]:
path = artifact.cache()
path

Create a PyTorch-compatible dataset:

In [None]:
dataset = MNIST(path.as_posix(), transform=ToTensor())
dataset

## Monitor training with MLflow

Train our example model and track the training progress with `MLflow`.

In [None]:
# enable MLflow PyTorch autologging
mlflow.pytorch.autolog()

In [None]:
from lamindb.integrations import lightning as lnpl

with mlflow.start_run() as mlflow_run:
    train_dataset = MNIST(
        root="./data", train=True, download=True, transform=ToTensor()
    )
    val_dataset = MNIST(root="./data", train=False, download=True, transform=ToTensor())

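    # reuse the tracked batch size instead of hard-coding it
    train_loader = utils.data.DataLoader(
        train_dataset, batch_size=MODEL_CONFIG["batch_size"]
    )
    val_loader = utils.data.DataLoader(
        val_dataset, batch_size=MODEL_CONFIG["batch_size"]
    )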

    autoencoder = LitAutoEncoder(32, 16)

    ckpt_dir = Path("model_checkpoints")
    ckpt_filename = f"{mlflow_run.info.run_id}_last_epoch.ckpt"
    artifact_key = f"testmodels/mlflow/{mlflow_run.info.run_id}.ckpt"  # every run makes a new version of this artifact

    metrics = [
        ("epoch", int),
        ("global_step", int),
        ("train_loss", float),
        ("train_loss_step", float),
        ("val_loss", float),
        ("val_loss_step", float),
    ]

    # Create a LaminDB LightningCallback that also annotates checkpoints with the selected metrics
    metrics_to_annotate = ["train_loss", "val_loss"]
    for metric in metrics_to_annotate:
        ln.Feature(name=metric, dtype=float).save()
    lamindb_callback = lnpl.LightningCallback(
        path=ckpt_dir / ckpt_filename, key=artifact_key, annotate_by=metrics_to_annotate
    )

    trainer = pl.Trainer(
        accelerator="cpu",
        limit_train_batches=3,
        max_epochs=5,
        callbacks=[lamindb_callback],
    )

    trainer.fit(
        model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader
    )

    # Register model_summary.txt
    local_model_summary_path = (
        f"{mlflow_run.info.artifact_uri.removeprefix('file://')}/model_summary.txt"
    )
    mlflow_model_summary_af = ln.Artifact(
        local_model_summary_path,
        key=local_model_summary_path,
        kind="model",
    ).save()
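
The cell above derives the local path of `model_summary.txt` by stripping the `file://` prefix from the artifact URI, which assumes a local MLflow tracking store. A slightly more defensive sketch, using only the standard library, parses the URI and fails loudly for remote stores. The helper name `local_path_from_uri` and the example URI are hypothetical, and the sketch assumes POSIX-style paths.

```python
from pathlib import Path
from urllib.parse import unquote, urlparse


# Hypothetical helper: not part of MLflow or LaminDB.
def local_path_from_uri(uri: str) -> Path:
    """Convert a `file://` URI or a plain path into a local Path."""
    parsed = urlparse(uri)
    # an empty scheme means the URI is already a plain filesystem path
    if parsed.scheme not in ("", "file"):
        raise ValueError(f"Not a local artifact location: {uri}")
    # decode percent-escapes such as %20 before building the path
    return Path(unquote(parsed.path))


# Example URI made up for illustration
artifact_uri = "file:///tmp/mlruns/0/abc123/artifacts"
summary_path = local_path_from_uri(artifact_uri) / "model_summary.txt"
print(summary_path.as_posix())
```

Raising an error for other schemes avoids silently registering a path that does not exist locally when the run logs to a remote artifact store.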

**See the training progress in the `mlflow` UI:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/C0seowxsq4Du2B4T0001.png" alt="MLflow run UI" width="800px">

**See the experiment overview:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/n0xxFoMRtZPiQ7VT0002.png" alt="MLflow experiment overview UI" width="800px">

In [None]:
last_checkpoint_af = ln.Artifact.get(
    key__startswith="testmodels/mlflow/", suffix__endswith="ckpt", is_latest=True
)
last_checkpoint_af.describe()

To reuse the checkpoint later, download it to the local cache:

In [None]:
last_checkpoint_af.cache()

In [None]:
last_checkpoint_af.view_lineage()

In [None]:
ln.finish()