# Wandb

This notebook shows how to integrate LaminDB with Wandb to track a training run, link the training data to the resulting model, and query models by hyperparameters and other metadata.

In [None]:
# uncomment below to install necessary dependencies for this notebook:
# !pip install 'lamindb[jupyter,aws]' -q
# !pip install wandb -qU
# !pip install torch torchvision torchaudio lightning -q

In [None]:
# you can also pass s3://my-bucket
!lamin init --storage ./lamin-mlops

In [None]:
import lamindb as ln
import wandb

ln.settings.transform.stem_uid = "tULn4Va2yERp"
ln.settings.transform.version = "1"

ln.track()

In [None]:
!wandb login

## Define a model

Define a simple autoencoder as an example model using PyTorch Lightning.

In [None]:
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L


class LitAutoEncoder(L.LightningModule):
    def __init__(self, hidden_size, bottleneck_size):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, bottleneck_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_size, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, 28 * 28)
        )
        # save the constructor arguments to self.hparams; the wandb logger records them automatically
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
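
The loss computed in `training_step` is the mean squared reconstruction error. Assuming the default `reduction="mean"` of `mse_loss`, for a batch of $B$ images, each flattened to $28 \times 28 = 784$ pixels, it can be written as:

$$
\mathcal{L} = \frac{1}{784\,B} \sum_{i=1}^{B} \sum_{j=1}^{784} \left(\hat{x}_{ij} - x_{ij}\right)^2,
\qquad \hat{x}_i = \mathrm{decoder}(\mathrm{encoder}(x_i))
$$

The labels `y` in the batch are not used, since the autoencoder only reconstructs its input.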


## Query & cache MNIST dataset from LaminDB

**The dataset shows up in LaminHub:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/LlMSvBjHuXbs36TBGoCM.png" width="800px">

We can query it by the UID shown there or by any other combination of metadata.

Here, we query by description:

In [None]:
training_data_artifact = ln.Artifact.filter(description="MNIST-dataset").one()
training_data_artifact

Let's cache the dataset:

In [None]:
cache_path = training_data_artifact.cache()
cache_path

Inspect the contents of the cached directory:

In [None]:
!ls -r {cache_path}

Create a PyTorch-compatible dataset:

In [None]:
dataset = MNIST(cache_path / "raw", transform=ToTensor())
dataset

## Monitor training with Wandb

Train our example model and track training progress with Wandb.

In [None]:
MODEL_CONFIG = {
    "hidden_size": 32,
    "bottleneck_size": 16,
    "batch_size": 32
}

In [None]:
# create PyTorch dataloader
train_loader = utils.data.DataLoader(dataset, batch_size=MODEL_CONFIG["batch_size"], shuffle=True)
# init model
autoencoder = LitAutoEncoder(MODEL_CONFIG["hidden_size"], MODEL_CONFIG["bottleneck_size"])
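
To get a feel for the model size these settings imply, the sketch below counts trainable parameters in plain Python. It assumes each `nn.Linear` layer keeps its default bias term; the helper names `linear_param_count` and `autoencoder_param_count` are illustrative and not part of PyTorch or LaminDB.

```python
def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer (bias assumed enabled)."""
    return in_features * out_features + out_features


def autoencoder_param_count(
    hidden_size: int, bottleneck_size: int, input_size: int = 28 * 28
) -> int:
    """Total trainable parameters of the two-layer encoder and two-layer decoder."""
    encoder = linear_param_count(input_size, hidden_size) + linear_param_count(
        hidden_size, bottleneck_size
    )
    decoder = linear_param_count(bottleneck_size, hidden_size) + linear_param_count(
        hidden_size, input_size
    )
    return encoder + decoder


# same values as MODEL_CONFIG above
n_params = autoencoder_param_count(hidden_size=32, bottleneck_size=16)
print(n_params)  # 52064
```

With PyTorch installed, `sum(p.numel() for p in autoencoder.parameters())` should report the same number for the model instantiated above.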

In [None]:
from lightning.pytorch.loggers import WandbLogger

# initialise the wandb logger
wandb_logger = WandbLogger(project="lamin")
# add batch size to the wandb config
wandb_logger.experiment.config["batch_size"] = MODEL_CONFIG["batch_size"]

In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint

# store checkpoints to disk and upload to LaminDB after training
checkpoint_callback = ModelCheckpoint(
    dirpath=f"model_checkpoints/{wandb_logger.version}", 
    filename="last_epoch",
    save_top_k=1,
    monitor="train_loss"
)
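
As a rough sketch of where the checkpoint ends up on disk: `wandb_logger.version` is the Wandb run ID, and Lightning appends a `.ckpt` extension to `filename`. The helper `expected_checkpoint_path` and the run ID `"abc123xy"` are made up for illustration.

```python
from pathlib import Path


def expected_checkpoint_path(run_id: str, filename: str = "last_epoch") -> Path:
    """Compose the path that dirpath and filename above are expected to produce."""
    return Path("model_checkpoints") / run_id / f"{filename}.ckpt"


# "abc123xy" is a placeholder for the real Wandb run ID
path = expected_checkpoint_path("abc123xy")
print(path.as_posix())  # model_checkpoints/abc123xy/last_epoch.ckpt
```

Because the callback monitors `train_loss` with `save_top_k=1`, the file it keeps should be the checkpoint with the lowest monitored loss, which is not necessarily the one from the final epoch despite the `last_epoch` filename.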

In [None]:
# train model
trainer = L.Trainer(
    accelerator="cpu",
    limit_train_batches=3, 
    max_epochs=2,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
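
The trainer settings above keep this demo very short. A quick back-of-the-envelope calculation, assuming the default of one optimizer step per batch (no gradient accumulation):

```python
# values copied from the Trainer and MODEL_CONFIG cells above
limit_train_batches = 3  # batches per epoch
max_epochs = 2
batch_size = 32

total_steps = limit_train_batches * max_epochs
samples_seen = total_steps * batch_size
print(total_steps, samples_seen)  # 6 192
```

Lightning's `Trainer` logs every 50 steps by default (`log_every_n_steps`), so a run this short may send few or no step-level `train_loss` points to Wandb; lowering that value is one option if the chart looks empty.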

In [None]:
wandb_logger.experiment.name

In [None]:
wandb_logger.version

In [None]:
wandb.finish()

**Check out the training progress in the Wandb UI:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/awrTvbxrLaiNav17VxBN.png" width="800px">

## Save model in LaminDB

Upload the model checkpoint of the trained model to LaminDB.

We annotate the LaminDB artifact with the Wandb experiment name and the hyperparameters.

In [None]:
# save checkpoint in LaminDB
ckpt_artifact = ln.Artifact(
    f"model_checkpoints/{wandb_logger.version}",
    description="model-checkpoint",
    type="model",
).save()

In [None]:
# create a label with the wandb experiment name
experiment_label = ln.ULabel(
    name=wandb_logger.experiment.name, 
    description="wandb experiment name"
).save()
# annotate the artifact
ckpt_artifact.ulabels.add(experiment_label)

In [None]:
# define the associated model hyperparameters in ln.Param
for k, v in MODEL_CONFIG.items():
    ln.Param(name=k, dtype=type(v).__name__).save()
# annotate the artifact with them
ckpt_artifact.params.add_values(MODEL_CONFIG)
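
The `dtype` passed to `ln.Param` in the loop above is derived from each value's Python type. The following plain-Python sketch shows what those strings look like for this config; the `1e-3` learning rate is only an illustrative extra value and is not part of `MODEL_CONFIG`.

```python
MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}

# the same expression the loop above uses for dtype
param_dtypes = {name: type(value).__name__ for name, value in MODEL_CONFIG.items()}
print(param_dtypes)
# {'hidden_size': 'int', 'bottleneck_size': 'int', 'batch_size': 'int'}

# a float-valued hyperparameter (illustrative only) would map to "float"
print(type(1e-3).__name__)  # float
```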

In [None]:
# show info about the checkpoint artifact
ckpt_artifact.describe()

**Look at saved checkpoints in LaminHub:**

<img src="https://lamin-site-assets.s3.amazonaws.com/.lamindb/248fOMXqxT0U4f7LRSgj.png" width="800px">

In [None]:
# save notebook
# ln.finish()