[![](https://img.shields.io/badge/Source%20on%20GitHub-orange)](https://github.com/laminlabs/lamin-mlops/blob/main/docs/wandb.ipynb)
[![](https://img.shields.io/badge/Source%20%26%20report%20on%20LaminHub-mediumseagreen)](https://lamin.ai/laminlabs/lamindata/transform/nrPNwWEVUsL95zKv)

# Wandb

We show how to integrate LaminDB with Weights & Biases (wandb) to track the training process, link datasets to the models trained on them, and query models by hyperparameters and other criteria.

In [None]:
# !pip install -q 'lamindb[jupyter,aws]' torch torchvision lightning wandb
!lamin init --storage ./lamin-mlops
!wandb login

In [None]:
import lamindb as ln
import wandb

ln.context.uid = "tULn4Va2yERp0000"
ln.context.track()

## Define a model

Define a simple autoencoder as an example model using PyTorch Lightning.

In [None]:
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning


class LitAutoEncoder(lightning.LightningModule):
    def __init__(self, hidden_size, bottleneck_size):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, bottleneck_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_size, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, 28 * 28)
        )
        # save hyperparameters to self.hparams; the wandb logger logs them automatically
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
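
As a rough sanity check on the architecture above, the number of trainable parameters can be worked out by hand: each `nn.Linear(n_in, n_out)` layer holds `n_in * n_out` weights plus `n_out` biases, and `nn.ReLU()` has none. The sketch below does this arithmetic in plain Python for the sizes used later in this notebook (`hidden_size=32`, `bottleneck_size=16`). The helpers `linear_params` and `autoencoder_params` are hypothetical names introduced only for illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


def autoencoder_params(hidden_size: int, bottleneck_size: int, n_pixels: int = 28 * 28) -> int:
    """Total parameters of the encoder and decoder defined in LitAutoEncoder."""
    encoder = linear_params(n_pixels, hidden_size) + linear_params(hidden_size, bottleneck_size)
    decoder = linear_params(bottleneck_size, hidden_size) + linear_params(hidden_size, n_pixels)
    return encoder + decoder


total = autoencoder_params(hidden_size=32, bottleneck_size=16)
print(total)  # 52064
```

Once the model is instantiated, `sum(p.numel() for p in autoencoder.parameters())` should report the same number.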


## Query & download the MNIST dataset

We curated the MNIST dataset in [this notebook](https://lamin.ai/laminlabs/lamindata/transform/mwaEQepEtFeh5zKv) and it now shows up in the artifact registry:

In [None]:
ln.Artifact.filter(type="dataset").df()

If your instance is connected to lamin.ai, you can also see it there.

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/LlMSvBjHuXbs36TBGoCM.png" width="800px">

Let's get the dataset:

In [None]:
artifact = ln.Artifact.get(key="testdata/mnist")
artifact

And download it to a local cache:

In [None]:
path = artifact.cache()
path

Create a PyTorch-compatible dataset:

In [None]:
dataset = MNIST(path.as_posix(), transform=ToTensor())
dataset
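
For orientation, `ToTensor()` scales 8-bit pixel values from `[0, 255]` to floats in `[0.0, 1.0]`, and the model's `training_step` flattens each 28 x 28 image into a vector of 784 values with `x.view(x.size(0), -1)`. The plain-Python sketch below mimics both steps on a made-up image, without PyTorch. `scale_and_flatten` is a hypothetical helper, not part of torchvision.

```python
def scale_and_flatten(image: list[list[int]]) -> list[float]:
    """Scale 8-bit pixels to [0.0, 1.0] and flatten the rows into one vector."""
    return [pixel / 255 for row in image for pixel in row]


# a made-up 28 x 28 "image": black everywhere except one white pixel
image = [[0] * 28 for _ in range(28)]
image[0][1] = 255

vector = scale_and_flatten(image)
print(len(vector))  # 784
print(vector[1])    # 1.0
```

This is why the first encoder layer and the last decoder layer use `28 * 28` as their input and output size.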

## Monitor training with wandb

Train our example model and track the training progress with `wandb`.

In [None]:
from lightning.pytorch.loggers import WandbLogger

MODEL_CONFIG = {
    "hidden_size": 32,
    "bottleneck_size": 16,
    "batch_size": 32
}

# create the data loader
train_loader = utils.data.DataLoader(dataset, batch_size=MODEL_CONFIG["batch_size"], shuffle=True)

# init model
autoencoder = LitAutoEncoder(MODEL_CONFIG["hidden_size"], MODEL_CONFIG["bottleneck_size"])

# initialize the logger
wandb_logger = WandbLogger(project="lamin")

# add batch size to the wandb config
wandb_logger.experiment.config["batch_size"] = MODEL_CONFIG["batch_size"]
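
The batch size is added to the wandb config by hand because `save_hyperparameters()` only records the arguments of the model's `__init__`, and `batch_size` belongs to the data loader rather than the model. A minimal sketch of that split, using the standard library's `inspect` module and a stand-in class (`TinyModel` is hypothetical and only mirrors the constructor signature of `LitAutoEncoder`):

```python
import inspect

MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}


class TinyModel:
    """Stand-in with the same constructor arguments as LitAutoEncoder."""

    def __init__(self, hidden_size, bottleneck_size):
        self.hidden_size = hidden_size
        self.bottleneck_size = bottleneck_size


# names of the constructor arguments, excluding `self`
init_args = set(inspect.signature(TinyModel.__init__).parameters) - {"self"}

# keys the model constructor accepts versus keys that must be logged separately
model_kwargs = {k: v for k, v in MODEL_CONFIG.items() if k in init_args}
extra_config = {k: v for k, v in MODEL_CONFIG.items() if k not in init_args}

print(model_kwargs)  # {'hidden_size': 32, 'bottleneck_size': 16}
print(extra_config)  # {'batch_size': 32}
```

With this split, `TinyModel(**model_kwargs)` constructs the model, and `extra_config` holds what would have to be added to the logger config manually.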

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

# store checkpoints to disk and upload to LaminDB after training
checkpoint_callback = ModelCheckpoint(
    dirpath=f"model_checkpoints/{wandb_logger.version}", 
    filename="last_epoch",
    save_top_k=1,
    monitor="train_loss"
)

# train model
trainer = lightning.Trainer(
    accelerator="cpu",
    limit_train_batches=3, 
    max_epochs=2,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
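
The trainer settings above keep the run deliberately tiny. As a back-of-the-envelope check, assuming Lightning's default of one optimizer step per batch (no gradient accumulation), the number of steps and processed samples follows directly from `limit_train_batches`, `max_epochs`, and the batch size:

```python
batch_size = 32          # MODEL_CONFIG["batch_size"]
limit_train_batches = 3  # batches per epoch (an integer limit, not a fraction)
max_epochs = 2

steps_per_epoch = limit_train_batches
total_steps = steps_per_epoch * max_epochs
samples_processed = total_steps * batch_size

print(total_steps)        # 6
print(samples_processed)  # 192
```

That is enough to demonstrate the tracking workflow, but far too little to train a useful model.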

In [None]:
wandb_logger.experiment.name

In [None]:
wandb_logger.version

In [None]:
wandb.finish()

**See the training progress in the `wandb` UI:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/awrTvbxrLaiNav17VxBN.png" width="800px">

## Save model in LaminDB

In [None]:
# save checkpoint as a model in LaminDB
artifact = ln.Artifact(
    f"model_checkpoints/{wandb_logger.version}",
    key="testmodels/litautoencoder",  # is automatically versioned
    type="model",
).save()

# create a label with the wandb experiment name
experiment_label = ln.ULabel(
    name=wandb_logger.experiment.name, 
    description="wandb experiment name"
).save()

# annotate the model artifact
artifact.ulabels.add(experiment_label)

# define the associated model hyperparameters in ln.Param
for k, v in MODEL_CONFIG.items():
    ln.Param(name=k, dtype=type(v).__name__).save()
artifact.params.add_values(MODEL_CONFIG)

# describe the artifact
artifact.describe()
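
The loop above derives each parameter's `dtype` from the Python type of its value via `type(v).__name__`. The sketch below shows what that expression yields for the values in `MODEL_CONFIG`, plus two made-up entries (`learning_rate`, `optimizer`) that are not part of this notebook and are included only to illustrate other types. Whether every resulting name is accepted as a `dtype` by `ln.Param` is not checked here.

```python
MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}

# hypothetical extra hyperparameters, for illustration only
extended_config = {**MODEL_CONFIG, "learning_rate": 1e-3, "optimizer": "adam"}

# same expression as in the loop above: the name of each value's Python type
dtypes = {name: type(value).__name__ for name, value in extended_config.items()}
print(dtypes)
# {'hidden_size': 'int', 'bottleneck_size': 'int', 'batch_size': 'int',
#  'learning_rate': 'float', 'optimizer': 'str'}
```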

**See the checkpoints:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/248fOMXqxT0U4f7LRSgj.png" width="800px">

To re-use the checkpoint later, download it to the local cache:

In [None]:
ln.Artifact.get(key="testmodels/litautoencoder").cache()

Or on the CLI:
```
lamin get artifact --key 'testmodels/litautoencoder'
```
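
Once the artifact is cached, the checkpoint file still has to be located inside the downloaded folder before it can be loaded. A minimal sketch, assuming `cache()` returns a local directory that contains the file written by `ModelCheckpoint` with `filename="last_epoch"` and Lightning's default `.ckpt` extension. `find_checkpoint` is a hypothetical helper, and the temporary directory only stands in for the cache folder:

```python
import tempfile
from pathlib import Path


def find_checkpoint(directory: Path, stem: str = "last_epoch") -> Path:
    """Return the single checkpoint file in `directory` whose name starts with `stem`."""
    matches = sorted(directory.glob(f"{stem}*.ckpt"))
    if len(matches) != 1:
        raise FileNotFoundError(
            f"expected one checkpoint in {directory}, found {len(matches)}"
        )
    return matches[0]


# simulate a cached artifact folder with an empty placeholder file
with tempfile.TemporaryDirectory() as tmp:
    cache_dir = Path(tmp)
    (cache_dir / "last_epoch.ckpt").touch()
    ckpt_name = find_checkpoint(cache_dir).name

print(ckpt_name)  # last_epoch.ckpt
```

With a real cache folder, the returned path could then be passed to `LitAutoEncoder.load_from_checkpoint(...)` to restore the model.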

In [None]:
# save notebook
# ln.context.finish()