In [6]:
import os

from dvclive import Live
from dvclive.lightning import DVCLiveLogger
import numpy as np
import pytorch_lightning as pl
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [7]:
# LightningModule wrapping a small fully-connected autoencoder
class LitAutoEncoder(pl.LightningModule):
    """Fully-connected autoencoder over flattened 28 * 28 inputs.

    The encoder maps 28 * 28 features -> ``encoder_size`` -> 3, and the
    decoder mirrors it back (3 -> ``encoder_size`` -> 28 * 28).

    Args:
        encoder_size: Width of the hidden layer used in both the encoder
            and the decoder.
    """

    def __init__(self, encoder_size):
        super().__init__()

        # Record the constructor arguments (here: encoder_size) as hyperparameters.
        # Kept as the first statement after super().__init__(), before any locals exist.
        self.save_hyperparameters()

        n_features = 28 * 28
        bottleneck = 3

        # Encoder is built before the decoder so parameter creation order is unchanged.
        self.encoder = nn.Sequential(
            nn.Linear(n_features, encoder_size),
            nn.ReLU(),
            nn.Linear(encoder_size, bottleneck),
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck, encoder_size),
            nn.ReLU(),
            nn.Linear(encoder_size, n_features),
        )

    def training_step(self, batch, batch_idx):
        """Run one training iteration and return the reconstruction loss.

        ``batch`` is unpacked into two items; only the first one is used.
        It is flattened to ``(batch, -1)``, passed through the encoder and
        the decoder, and compared with itself using mean squared error.
        """
        inputs, _targets = batch
        flat = inputs.view(inputs.size(0), -1)

        # Encode, then immediately decode the latent code.
        reconstruction = self.decoder(self.encoder(flat))
        loss = nn.functional.mse_loss(reconstruction, flat)

        # Report the value under the "train_loss" key via Lightning's logging hook.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return optim.Adam(self.parameters(), lr=1e-3)

In [8]:
# Prepare the training data.
# MNIST is downloaded into the current working directory and each sample is
# converted with ToTensor(). The DataLoader is created with the library
# defaults: no batch size or shuffling is specified here.
data_root = os.getcwd()
dataset = MNIST(root=data_root, download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset=dataset)

In [9]:
# train the model

# Seed the Python, NumPy and PyTorch RNGs so that the random image and the
# weight initialisation below are repeatable when the cell is re-run.
SEED = 42
pl.seed_everything(SEED)

# DVCLive run object; it is shared with the Lightning logger below so the
# image and the training metrics end up in the same run.
live = Live(save_dvc_exp=True)

# Log a random grayscale image. The upper bound of np.random.randint is
# exclusive, so 256 (not 255) is required to cover the full uint8 range.
img = np.random.randint(0, 256, (500, 500), np.uint8)
live.log_image("numpy.png", img)

autoencoder = LitAutoEncoder(encoder_size=64)
trainer = pl.Trainer(
    limit_train_batches=200,  # only the first 200 batches of each epoch
    max_epochs=5,
    logger=DVCLiveLogger(experiment=live),
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (mps), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Epoch 4: 100%|██████████| 200/200 [00:01<00:00, 107.11it/s, loss=0.0495, v_num=_run]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 200/200 [00:01<00:00, 106.82it/s, loss=0.0495, v_num=_run]


	dvclive-exp-tracking.ipynb, MNIST/raw/t10k-images-idx3-ubyte, MNIST/raw/t10k-labels-idx1-ubyte, MNIST/raw/train-images-idx3-ubyte, MNIST/raw/t10k-images-idx3-ubyte.gz, MNIST/raw/train-images-idx3-ubyte.gz, MNIST/raw/train-labels-idx1-ubyte.gz, MNIST/raw/train-labels-idx1-ubyte, MNIST/raw/t10k-labels-idx1-ubyte.gz, DvcLiveLogger/dvclive_run/checkpoints/epoch=4-step=1000.ckpt, DvcLiveLogger/dvclive_run/checkpoints/epoch=4-step=1000-v1.ckpt, dvclive/metrics.json, dvclive/params.yaml, dvclive/dvc.yaml, dvclive/report.html, dvclive/plots/metrics/epoch.tsv, dvclive/plots/metrics/train_loss.tsv, dvclive/plots/images/numpy.png
	dvclive-exp-tracking.ipynb, MNIST/raw/t10k-images-idx3-ubyte, MNIST/raw/t10k-labels-idx1-ubyte, MNIST/raw/train-images-idx3-ubyte, MNIST/raw/t10k-images-idx3-ubyte.gz, MNIST/raw/train-images-idx3-ubyte.gz, MNIST/raw/train-labels-idx1-ubyte.gz, MNIST/raw/train-labels-idx1-ubyte, MNIST/raw/t10k-labels-idx1-ubyte.gz, DvcLiveLogger/dvclive_run/checkpoints/epoch=4-step=1000

In [10]:
# Shell out to the DVC CLI to print the experiments table (metrics and params per experiment).
!dvc exp show

 ──────────────────────────────────────────────────────────────────────────────────── 
  Experiment                 Created        train_loss   epoch   step   encoder_size  
 ──────────────────────────────────────────────────────────────────────────────────── 
  workspace                  -                0.059008       4    999   64            
  main                       Dec 14, 2022            -       -      -   -             
  └── 3350052 [tenty-taco]   02:08 PM         0.059008       4    999   64            
 ──────────────────────────────────────────────────────────────────────────────────── 
