# DVCLive and Lightning Fabric

## Install dvclive

In [None]:
# Install DVCLive with its Lightning extras. The spec is quoted so the shell
# does not treat the [lightning] brackets as a glob pattern.
!pip install "dvclive[lightning]"

## Initialize DVC Repository

In [None]:
# Create a Git repository with a repo-local (--local) commit identity, so the
# commit below does not depend on any global Git configuration.
!git init -q
!git config --local user.email "you@example.com"
!git config --local user.name "Your Name"
# Initialize DVC inside the repository and record an initial commit.
!dvc init -q
!git commit -m "DVC init"

## Imports

In [None]:
import argparse
from os import path
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch.optim.lr_scheduler import StepLR
from torchmetrics.classification import Accuracy
from torchvision.datasets import MNIST

from dvclive.fabric import DVCLiveLogger

# Directory (a path relative to the current working directory) where MNIST is
# downloaded to / loaded from. A plain string, not a tuple.
DATASETS_PATH = "Datasets"

## Setup model code

Adapted from https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/image_classifier/train_fabric.py.

Look for the `logger` statements to see where the DVCLiveLogger calls were added to the original example.

In [None]:
class Net(nn.Module):
    """Small two-convolution CNN classifier producing 10 class scores.

    Both convolutions use 3x3 kernels with stride 1, followed by a 2x2
    max-pool. For a 1x28x28 input this yields 64 * 12 * 12 = 9216 features,
    which is the input size ``fc1`` expects. Inputs are therefore assumed to
    be single-channel 28x28 images such as MNIST digits.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order is kept unchanged: it determines state_dict key
        # order and the order in which seeded parameter initialisation runs.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-probabilities over the 10 classes, shape (batch, 10)."""
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> pool -> dropout.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head: flatten everything except the batch dimension.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)


def run(hparams):
    """Train and evaluate ``Net`` on MNIST with Lightning Fabric, logging to DVCLive.

    Args:
        hparams: Namespace of hyperparameters. Every attribute is logged as a
            hyperparameter. This function reads ``seed``, ``batch_size``,
            ``lr``, ``gamma``, ``epochs``, ``log_interval``, ``dry_run`` and
            ``save_model``.

    Side effects:
        Downloads MNIST into ``DATASETS_PATH`` on global rank 0, prints
        progress, and writes DVCLive metrics. When ``hparams.save_model`` is
        true, it also saves the model weights to ``mnist_cnn.pt`` and logs
        that file as an artifact. It finishes by finalizing the logger with
        status ``"success"``.
    """
    # Create the DVCLive Logger
    logger = DVCLiveLogger(report="notebook")

    # Log dict of hyperparameters
    logger.log_hyperparams(vars(hparams))

    # Create the Lightning Fabric object. No accelerator/strategy/devices are
    # passed here. When launched via `lightning run model` they come from the
    # command line (see `lightning run model --help`); otherwise Fabric's
    # defaults apply.
    fabric = Fabric()

    seed_everything(hparams.seed)  # instead of torch.manual_seed(...)

    transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])

    # Let rank 0 download the data first, then everyone will load MNIST
    with fabric.rank_zero_first(local=False):  # set `local=True` if your filesystem is not shared between machines
        train_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=True, transform=transform)
        test_dataset = MNIST(DATASETS_PATH, download=fabric.is_global_zero, train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=hparams.batch_size,
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size)

    # Don't forget to call `setup_dataloaders` to prepare the dataloaders for distributed training.
    train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader)

    model = Net()  # no call to .to(device): Fabric moves the model in `setup`
    optimizer = optim.Adadelta(model.parameters(), lr=hparams.lr)

    # Don't forget to call `setup` to prepare the model / optimizer for distributed training.
    # The model is moved automatically to the right device.
    model, optimizer = fabric.setup(model, optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=hparams.gamma)

    # use torchmetrics instead of manually computing the accuracy
    test_acc = Accuracy(task="multiclass", num_classes=10).to(fabric.device)

    # EPOCH LOOP
    for epoch in range(1, hparams.epochs + 1):
        # TRAINING LOOP
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # NOTE: no need to call `.to(device)` on the data, target
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            fabric.backward(loss)  # instead of loss.backward()

            optimizer.step()
            if batch_idx == 0 or (batch_idx + 1) % hparams.log_interval == 0:
                # Read the scalar once and reuse it for printing and logging.
                loss_value = loss.item()
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss_value,
                    )
                )

                # Log dict of metrics
                logger.log_metrics({"loss": loss_value})

                if hparams.dry_run:
                    break

        scheduler.step()

        # TESTING LOOP
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, target in test_loader:
                # NOTE: no need to call `.to(device)` on the data, target
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction="sum").item()

                # WITHOUT TorchMetrics
                # pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                # correct += pred.eq(target.view_as(pred)).sum().item()

                # WITH TorchMetrics
                test_acc(output, target)

                if hparams.dry_run:
                    break

        # all_gather is used to aggregate each process's summed loss; dividing by the
        # full test-set size gives the average. `.item()` turns the result into a plain float.
        test_loss = (fabric.all_gather(test_loss).sum() / len(test_loader.dataset)).item()

        # Compute the accuracy once and reuse it for printing and logging.
        test_accuracy = 100 * test_acc.compute().item()

        print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_accuracy:.0f}%)\n")

        # log additional metrics
        logger.log_metrics({"test_loss": test_loss, "test_acc": test_accuracy})

        test_acc.reset()

        if hparams.dry_run:
            break

    # When using distributed training, use `fabric.save`
    # to ensure the current process is allowed to save a checkpoint
    if hparams.save_model:
        fabric.save("mnist_cnn.pt", model.state_dict())

        # `logger.experiment` provides access to the `dvclive.Live` instance where you can use additional logging methods.
        # Check that `rank_zero_only.rank == 0` to avoid logging in other processes.
        if rank_zero_only.rank == 0:
            logger.experiment.log_artifact("mnist_cnn.pt")

    # Call finalize to save final results as a DVC experiment
    logger.finalize("success")

## Train the model

In [None]:
# Hyperparameters for `run`, listed one per line (order is kept, since all of
# them are logged as hyperparameters). Setting dry_run=True stops each loop
# after its first logged batch and ends training after a single epoch.
hparams = SimpleNamespace(
    batch_size=64,
    epochs=5,
    lr=1.0,
    gamma=0.7,
    dry_run=False,
    seed=1,
    log_interval=10,
    save_model=True,
)
run(hparams)