[![](https://img.shields.io/badge/Source%20on%20GitHub-orange)](https://github.com/laminlabs/lamin-mlops/blob/main/docs/wandb.ipynb)
[![](https://img.shields.io/badge/Source%20%26%20report%20on%20LaminHub-mediumseagreen)](https://lamin.ai/laminlabs/lamindata/transform/nrPNwWEVUsL95zKv)

# Weights & Biases

We show how LaminDB can be integrated with W&B to track the training process and associate datasets & parameters with models.

In [None]:
# pip install lamindb torchvision lightning wandb
!lamin init --storage ./lamin-mlops
!wandb login

In [None]:
import lamindb as ln
import wandb
import lightning as pl
from pathlib import Path

from torch import utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from autoencoder import LitAutoEncoder

In [None]:
MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}

hyperparameter = ln.Feature(name="Autoencoder hyperparameter", is_type=True).save()
hyperparams = ln.Feature.from_dict(MODEL_CONFIG, feature_type=hyperparameter)
ln.save(hyperparams)

metrics_to_annotate = ["train_loss", "val_loss", "current_epoch"]
for metric in metrics_to_annotate:
    dtype = int if metric == "current_epoch" else float
    ln.Feature(name=metric, dtype=dtype).save()

# create all W&B-related features, such as 'wandb_run_id'
_ = ln.examples.ml_tracking.create_wandb_schema()
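The loop above registers `current_epoch` as an integer feature and both losses as float features. The following is a minimal, stdlib-only sketch of what that declaration implies for values logged during training. `METRIC_DTYPES` and `validate_metrics` are hypothetical names used only for illustration. They are not part of the LaminDB API.

```python
# Hypothetical illustration, not LaminDB code: mirror the dtype rule used above
metrics_to_annotate = ["train_loss", "val_loss", "current_epoch"]
METRIC_DTYPES = {
    metric: int if metric == "current_epoch" else float
    for metric in metrics_to_annotate
}


def validate_metrics(values: dict) -> list[str]:
    """Return a list of problems with `values` relative to METRIC_DTYPES."""
    problems = []
    for name, value in values.items():
        expected = METRIC_DTYPES.get(name)
        if expected is None:
            problems.append(f"{name}: no feature registered")
        # bool is a subclass of int in Python, so reject it explicitly
        elif not isinstance(value, expected) or isinstance(value, bool):
            problems.append(
                f"{name}: expected {expected.__name__}, got {type(value).__name__}"
            )
    return problems


print(validate_metrics({"train_loss": 0.21, "val_loss": 0.19, "current_epoch": 4}))
print(validate_metrics({"current_epoch": 4.0, "lr": 1e-3}))
```

The first call reports no problems. The second call flags the float epoch and the unregistered `lr` metric.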

In [None]:
# track this notebook/script run so that all checkpoint artifacts are associated with the source code
ln.track(params=MODEL_CONFIG, project=ln.Project(name="Wandb tutorial").save())

## Define a model

We use a basic PyTorch Lightning autoencoder as an example model.

````{dropdown} Code of LitAutoEncoder
```{eval-rst}
.. literalinclude:: autoencoder.py
   :language: python
   :caption: Simple autoencoder model
```
````

## Query & download the MNIST dataset

We saved the MNIST dataset in the [curation notebook](/mnist), so it now shows up in the Artifact registry:

In [None]:
ln.Artifact.filter(kind="dataset").to_dataframe()

You can also find it on lamin.ai if your instance is connected to LaminHub.

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/LlMSvBjHuXbs36TBGoCM.png" alt="instance view" width="800px">

Let's get the dataset:

In [None]:
artifact = ln.Artifact.get(key="testdata/mnist")
artifact

And download it to a local cache:

In [None]:
path = artifact.cache()
path

Create a PyTorch-compatible dataset:

In [None]:
dataset = MNIST(path.as_posix(), transform=ToTensor())
dataset

## Monitor training with wandb

Train our example model and track the training progress with `wandb`.

In [None]:
from lightning.pytorch.loggers import WandbLogger

# create the data loaders, using the batch size from MODEL_CONFIG
train_dataset = MNIST(root="./data", train=True, download=True, transform=ToTensor())
val_dataset = MNIST(root="./data", train=False, download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(train_dataset, batch_size=MODEL_CONFIG["batch_size"])
val_loader = utils.data.DataLoader(val_dataset, batch_size=MODEL_CONFIG["batch_size"])

# init model
autoencoder = LitAutoEncoder(
    MODEL_CONFIG["hidden_size"], MODEL_CONFIG["bottleneck_size"]
)

# initialize the logger
wandb_logger = WandbLogger(project="lamin")

# add batch size to the wandb config
wandb_logger.experiment.config["batch_size"] = MODEL_CONFIG["batch_size"]
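The amount of data the run actually sees follows from the dataset sizes and the batch size. Here is a small sketch of that arithmetic. It assumes the standard MNIST split of 60,000 training and 10,000 test images, along with the `limit_train_batches=3` and `max_epochs=5` settings passed to the trainer in the training cell.

```python
import math

# Assumption: standard torchvision MNIST split sizes
N_TRAIN, N_VAL = 60_000, 10_000
BATCH_SIZE = 32  # MODEL_CONFIG["batch_size"]

# DataLoader defaults to drop_last=False, so a partial last batch still counts
train_batches_per_epoch = math.ceil(N_TRAIN / BATCH_SIZE)
val_batches_per_epoch = math.ceil(N_VAL / BATCH_SIZE)

# Trainer settings used for this tutorial run
LIMIT_TRAIN_BATCHES, MAX_EPOCHS = 3, 5
train_steps = min(train_batches_per_epoch, LIMIT_TRAIN_BATCHES) * MAX_EPOCHS

print(train_batches_per_epoch, val_batches_per_epoch, train_steps)  # 1875 313 15
```

The tutorial therefore trains on only 15 batches (480 images) in total, which keeps the run short. `limit_train_batches` does not cap validation, so each epoch still validates over all 313 validation batches.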

In [None]:
# Create a LaminDB Lightning callback that saves checkpoints as artifacts and (optionally) annotates them with the desired metrics
lamindb_callback = ln.integrations.lightning.Callback(
    # f-string so that the W&B run ID is substituted into the local file name
    path=Path("model_checkpoints") / f"{wandb_logger.version}_last_epoch.ckpt",
    key=f"testmodels/wandb/{wandb_logger.experiment.id}.ckpt",
    features={
        "wandb_run_id": wandb_logger.experiment.id,
        "wandb_run_name": wandb_logger.experiment.name,
        **{
            metric: None for metric in metrics_to_annotate
        },  # auto-populated through callback
    },
)

# train model
trainer = pl.Trainer(
    limit_train_batches=3,
    max_epochs=5,
    logger=wandb_logger,
    callbacks=[lamindb_callback],
)
trainer.fit(
    model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader
)
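The `None` values in `features` act as placeholders that the callback fills in when it saves a checkpoint, presumably from the values the trainer reports at that point. The merge step can be sketched in plain Python as shown below. `fill_placeholders` is a hypothetical helper, and the run ID and metric values are made-up inputs for illustration. This is not LaminDB's implementation.

```python
def fill_placeholders(features: dict, logged: dict) -> dict:
    """Return a copy of `features` with None values replaced by logged metrics."""
    filled = {}
    for name, value in features.items():
        if value is None and name in logged:
            filled[name] = logged[name]  # auto-populated from the trainer
        else:
            filled[name] = value  # fixed annotation such as the run ID
    return filled


features = {
    "wandb_run_id": "abc123",  # hypothetical run ID
    "train_loss": None,
    "val_loss": None,
    "current_epoch": None,
}
logged = {"train_loss": 0.052, "val_loss": 0.049, "current_epoch": 4}  # illustrative values
print(fill_placeholders(features, logged))
```

Returning a new dict leaves the original `features` template untouched, so the same template can be reused for every checkpoint.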

In [None]:
wandb_logger.experiment.name

In [None]:
wandb_logger.version
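`wandb_logger.version` is the run ID that goes into the local checkpoint file name. Whether a `{...}` placeholder in such a path gets filled depends on how the string is built. A plain string keeps the braces literally, while an f-string (or `str.format`) substitutes the value. The following stdlib-only illustration uses a hypothetical run ID in place of `wandb_logger.version`:

```python
from pathlib import Path

run_id = "abc123"  # hypothetical stand-in for wandb_logger.version

# Plain string: the braces are kept verbatim in the file name
literal = Path("model_checkpoints") / "{run_id}_last_epoch.ckpt"
# f-string: the run ID is substituted when the string is built
formatted = Path("model_checkpoints") / f"{run_id}_last_epoch.ckpt"
# str.format: the same substitution, deferred until .format() is called
template = "{run_id}_last_epoch.ckpt"
deferred = Path("model_checkpoints") / template.format(run_id=run_id)

print(literal.as_posix())    # model_checkpoints/{run_id}_last_epoch.ckpt
print(formatted.as_posix())  # model_checkpoints/abc123_last_epoch.ckpt
print(formatted == deferred)  # True
```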

In [None]:
wandb.finish()

**See the training progress in the `wandb` UI:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/awrTvbxrLaiNav17VxBN.png" alt="Wandb training ui" width="800px">
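The callback annotates each checkpoint with `wandb_run_id`, so the matching run page can be reconstructed from a checkpoint's annotations. The sketch below assumes the usual `https://wandb.ai/<entity>/<project>/runs/<run_id>` URL layout. The entity name and run ID are hypothetical placeholders, and the project matches `WandbLogger(project="lamin")`. `wandb_run_url` is an illustrative helper, not part of the `wandb` package.

```python
from urllib.parse import quote


def wandb_run_url(entity: str, project: str, run_id: str) -> str:
    """Build a W&B run page URL (assumed layout: wandb.ai/<entity>/<project>/runs/<id>)."""
    parts = (entity, project, "runs", run_id)
    # percent-encode each path segment so unusual names cannot break the URL
    return "https://wandb.ai/" + "/".join(quote(p, safe="") for p in parts)


print(wandb_run_url("my-team", "lamin", "abc123"))
# https://wandb.ai/my-team/lamin/runs/abc123
```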

## Save model in LaminDB

**See the checkpoints:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/248fOMXqxT0U4f7LRSgj.png" alt="Wandb check points" width="800px">

All checkpoints are automatically annotated with the specified training metrics and the W&B run ID & name to keep both frameworks in sync:

In [None]:
last_checkpoint_af = ln.Artifact.filter(
    key__startswith="testmodels/wandb/", suffix__endswith="ckpt", is_latest=True
).last()
last_checkpoint_af.describe()
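The query above combines three conditions: the key prefix, the file suffix, and `is_latest=True`, which keeps only the newest version of each artifact. `.last()` then returns the final record in the query's ordering, or `None` if nothing matches. The following is a rough in-memory analogue over plain dicts. The records are hypothetical, and the sketch assumes creation order (by `id`). LaminDB evaluates the real query in its database, and its ordering may differ.

```python
from typing import Optional

# Hypothetical registry records for illustration
records = [
    {"id": 1, "key": "testmodels/wandb/abc123.ckpt", "suffix": ".ckpt", "is_latest": False},
    {"id": 2, "key": "testmodels/wandb/abc123.ckpt", "suffix": ".ckpt", "is_latest": True},
    {"id": 3, "key": "testmodels/other/xyz.ckpt", "suffix": ".ckpt", "is_latest": True},
    {"id": 4, "key": "testmodels/wandb/notes.txt", "suffix": ".txt", "is_latest": True},
]


def last_matching(records: list, prefix: str, suffix_end: str) -> Optional[dict]:
    matches = [
        r
        for r in records
        if r["key"].startswith(prefix)  # key__startswith
        and r["suffix"].endswith(suffix_end)  # suffix__endswith
        and r["is_latest"]  # is_latest=True
    ]
    # assumption: order by id (creation order) before taking the last record
    matches.sort(key=lambda r: r["id"])
    return matches[-1] if matches else None


print(last_matching(records, "testmodels/wandb/", "ckpt")["id"])  # 2
```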

If you want to reuse the checkpoint later, download it to the local cache:

In [None]:
last_checkpoint_af.cache()

In [None]:
ln.finish()