# MLflow

We show how LaminDB can be integrated with [MLflow](https://mlflow.org/) to track the training process and associate datasets and parameters with models.

In [None]:
# !pip install 'lamindb[jupyter]' torchvision lightning mlflow
!lamin init --storage ./lamin-mlops

In [None]:
import lamindb as ln
import mlflow
import lightning

from torch import utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from autoencoder import LitAutoEncoder

ln.track()

## Define a model

We use a basic PyTorch Lightning autoencoder as an example model.

{dropdown}
```{eval-rst}
.. literalinclude:: autoencoder.py
   :language: python
   :caption: Simple autoencoder model
```

## Query & download the MNIST dataset

We saved the MNIST dataset in the [curation notebook](/01_mnist), and it now shows up in the Artifact registry:

In [None]:
ln.Artifact.filter(kind="dataset").df()

You can also find it on lamin.ai if you connected your instance.

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/LlMSvBjHuXbs36TBGoCM.png" width="800px">

Let's get the dataset:

In [None]:
artifact = ln.Artifact.get(key="testdata/mnist")
artifact

And download it to a local cache:

In [None]:
path = artifact.cache()
path

Create a PyTorch-compatible dataset:

In [None]:
dataset = MNIST(path.as_posix(), transform=ToTensor())
dataset

## Monitor training with mlflow

Train our example model and track the training progress with `mlflow`.

In [None]:
mlflow.pytorch.autolog()

MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}

# Start MLflow run
with mlflow.start_run() as run:
    train_dataset = MNIST(
        root="./data", train=True, download=True, transform=ToTensor()
    )
    train_loader = utils.data.DataLoader(
        train_dataset, batch_size=MODEL_CONFIG["batch_size"]
    )

    # Initialize model
    autoencoder = LitAutoEncoder(
        MODEL_CONFIG["hidden_size"], MODEL_CONFIG["bottleneck_size"]
    )

    # Create checkpoint callback
    from lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        dirpath="model_checkpoints",
        filename=f"{run.info.run_id}_last_epoch",
        save_top_k=1,
        monitor="train_loss",
    )

    # Train model
    trainer = lightning.Trainer(
        accelerator="cpu",
        limit_train_batches=3,
        max_epochs=2,
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

    # Get run information
    run_id = run.info.run_id
    metrics = mlflow.get_run(run_id).data.metrics
    params = mlflow.get_run(run_id).data.params

    # Access model artifacts path
    model_uri = f"runs:/{run_id}/model"
    artifacts_path = run.info.artifact_uri

**See the training progress in the `mlflow` UI:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/C0seowxsq4Du2B4T0000.png" width="800px">
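
The checkpoint written during training lands at a path built from the `dirpath` and `filename` passed to `ModelCheckpoint`, with the `.ckpt` extension that Lightning appends. As a minimal sketch, the path can be reconstructed and checked before the file is registered anywhere. The helper names `checkpoint_path_for_run` and `require_checkpoint` are hypothetical and not part of Lightning, MLflow, or LaminDB.

```python
from pathlib import Path


def checkpoint_path_for_run(run_id: str, dirpath: str = "model_checkpoints") -> Path:
    """Build the checkpoint path this guide uses for a given MLflow run ID.

    Hypothetical helper: it mirrors the `dirpath` and `filename` arguments
    passed to ModelCheckpoint above, plus the ".ckpt" extension.
    """
    return Path(dirpath) / f"{run_id}_last_epoch.ckpt"


def require_checkpoint(run_id: str, dirpath: str = "model_checkpoints") -> Path:
    """Return the checkpoint path, raising a clear error if the file is missing."""
    path = checkpoint_path_for_run(run_id, dirpath)
    if not path.is_file():
        raise FileNotFoundError(f"No checkpoint found at {path}")
    return path
```

With `run_id` from the training cell, calling `require_checkpoint(run_id)` before saving the model would surface a missing or renamed checkpoint early, rather than at save time.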

## Save model in LaminDB

In [None]:
# save checkpoint as a model in LaminDB
artifact = ln.Artifact(
    f"model_checkpoints/{run_id}_last_epoch.ckpt",
    key="testmodels/mlflow/litautoencoder.ckpt",  # is automatically versioned
    kind="model",
).save()

# create a label from the mlflow run name (falls back to the run ID)
mlflow_run_name = mlflow.get_run(run_id).data.tags.get(
    "mlflow.runName", f"run_{run_id}"
)
experiment_label = ln.ULabel(
    name=mlflow_run_name, description="mlflow experiment name"
).save()

# annotate the model Artifact
artifact.ulabels.add(experiment_label)

# define the associated model hyperparameters in ln.Param
for k, v in MODEL_CONFIG.items():
    ln.Param(name=k, dtype=type(v).__name__).save()
artifact.params.add_values(MODEL_CONFIG)

# look at Artifact annotations
artifact.describe()
artifact.params
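
The `dtype` string given to each parameter above comes from `type(v).__name__`, that is, the name of the Python type of each config value. A minimal standard-library sketch shows which strings this produces. The helper `infer_dtypes` is hypothetical, and whether LaminDB accepts every resulting string as a `dtype` is an assumption that should be checked against its documentation.

```python
MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}


def infer_dtypes(config: dict) -> dict:
    """Map each parameter name to the name of its value's Python type.

    Hypothetical helper: mirrors the `type(v).__name__` expression
    used in the loop above.
    """
    return {name: type(value).__name__ for name, value in config.items()}


dtypes = infer_dtypes(MODEL_CONFIG)
print(dtypes)
# → {'hidden_size': 'int', 'bottleneck_size': 'int', 'batch_size': 'int'}
```

A float such as a learning rate would map to `"float"`, a string to `"str"`, and a boolean to `"bool"`, so mixed configs produce one dtype per parameter.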

**See the checkpoints:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/n0xxFoMRtZPiQ7VT0001.png" width="800px">

If you later want to reuse the checkpoint, you can download it like so:

In [None]:
ln.Artifact.get(key="testmodels/mlflow/litautoencoder.ckpt").cache()

Or on the CLI:
```
lamin get artifact --key 'testmodels/mlflow/litautoencoder.ckpt'
```

In [None]:
ln.finish()

In [None]:
!rm -rf ./lamin-mlops
!lamin delete lamin-mlops