<a href="https://colab.research.google.com/github/iterative/dvclive/blob/main/examples/DVCLive-PyTorch-Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DVCLive and PyTorch Lightning

## Setup

In [None]:
# Install DVCLive together with its PyTorch Lightning integration extra.
!pip install "dvclive[lightning]"

In [None]:
# Turn the working directory into a Git repository with a DVC project.
# DVC experiments build on Git commits, so a local committer identity and an
# initial commit are set up before any training runs.
!git init -q
!git config --local user.email "you@example.com"
!git config --local user.name "Your Name"
!dvc init -q
!git commit -m "DVC init"

### Define the LightningModule

In [None]:
import lightning.pytorch as pl
import torch

class LitAutoEncoder(pl.LightningModule):
    """Fully connected autoencoder for flattened 28 x 28 images.

    The encoder compresses each 784-value image vector to a 3-dimensional
    code and the decoder reconstructs the image from that code. The mean
    squared reconstruction error is both the training loss and the
    validation metric.

    Args:
        encoder_size: Width of the hidden layer used in both the encoder and
            the decoder.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, encoder_size=64, lr=1e-3):
        super().__init__()
        # Stores `encoder_size` and `lr` in `self.hparams` (read back in
        # `configure_optimizers`) and exposes them to the attached logger.
        self.save_hyperparameters()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, encoder_size),
            torch.nn.ReLU(),
            torch.nn.Linear(encoder_size, 3)
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(3, encoder_size),
            torch.nn.ReLU(),
            torch.nn.Linear(encoder_size, 28 * 28)
        )

    def _reconstruction_mse(self, batch):
        """Return the MSE between a batch of images and its reconstruction.

        `batch` is unpacked as an `(images, labels)` pair. The labels are
        ignored because the autoencoder only reconstructs its input.
        """
        x, _ = batch
        # Flatten every image to one row so it matches the 28 * 28 input layer.
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return torch.nn.functional.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Log and return the reconstruction MSE for one training batch."""
        train_mse = self._reconstruction_mse(batch)
        self.log("train_mse", train_mse)
        return train_mse

    def validation_step(self, batch, batch_idx):
        """Log and return the reconstruction MSE for one validation batch."""
        val_mse = self._reconstruction_mse(batch)
        self.log("val_mse", val_mse)
        return val_mse

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the `lr` hyperparameter."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

### Datasets and data loaders

In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

# Convert the MNIST PIL images to tensors; both splits share one transform.
transform = transforms.ToTensor()
train_set = MNIST(root="MNIST", download=True, train=True, transform=transform)
validation_set = MNIST(root="MNIST", download=True, train=False, transform=transform)
# Shuffle the training data. The loader yields one sample per batch
# (DataLoader's default batch_size), and the Trainer caps each epoch via
# `limit_train_batches`. Without shuffling, every epoch would revisit exactly
# the same leading samples.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True)
# Validation stays unshuffled so every epoch and every run is scored on the
# same samples, keeping `val_mse` comparable across experiments.
validation_loader = torch.utils.data.DataLoader(validation_set)

## Tracking experiments with DVCLive

In [None]:
from dvclive.lightning import DVCLiveLogger

In [None]:
# Every (encoder_size, lr) combination, in the order the runs are trained.
hyperparameter_grid = [
    (size, rate)
    for size in (64, 128)
    for rate in (1e-3, 0.1)
]

for encoder_size, lr in hyperparameter_grid:
    model = LitAutoEncoder(encoder_size=encoder_size, lr=lr)
    # A fresh DVCLive logger per run, so each combination is logged separately.
    logger = DVCLiveLogger(log_model=True, report="notebook")
    trainer = pl.Trainer(
        limit_train_batches=200,
        limit_val_batches=100,
        max_epochs=5,
        logger=logger,
    )
    trainer.fit(model, train_loader, validation_loader)


## Comparing results

In [None]:
import dvc.api
import pandas as pd

columns = ["Experiment", "encoder_size", "lr", "train.mse", "val.mse"]

# Tabulate the experiments, keeping only name, hyperparameters and metrics.
# Rows missing any of these values are discarded and the index renumbered.
df = (
    pd.DataFrame(dvc.api.exp_show(), columns=columns)
    .dropna()
    .reset_index(drop=True)
)
df


In [None]:
from plotly.express import parallel_coordinates

# Parallel-coordinates axes are numeric, so the string "Experiment" column is
# not used as an axis. Each row of `df` is drawn as one line, colored by
# validation MSE.
dimensions = [name for name in columns if name != "Experiment"]
fig = parallel_coordinates(df, dimensions=dimensions, color="val.mse")
fig.show()

In [None]:
# Render the plots of every experiment returned by `dvc exp list` side by side.
# By default DVC writes the rendered HTML under dvc_plots/.
!dvc plots diff $(dvc exp list --names-only)

In [None]:
from IPython.display import HTML

# HTML page produced by `dvc plots diff` in its default output directory.
plots_page = './dvc_plots/index.html'
HTML(filename=plots_page)