This repository was archived by the owner on Mar 7, 2023. It is now read-only.
Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
+/data
43 changes: 25 additions & 18 deletions README.md
@@ -19,7 +19,7 @@ This repo has several
[branches](https://github.com/iterative/dvc-checkpoints-mnist/branches) that
show different methods for using checkpoints (using a similar pipeline):

-- The [live](https://github.com/iterative/dvc-checkpoints-mnist/edit/live)
+- The [live](https://github.com/iterative/dvc-checkpoints-mnist/tree/live)
scenario introduces full-featured checkpoint usage — integrating with
[DVCLive](https://github.com/iterative/dvclive).
- The [basic](https://github.com/iterative/dvc-checkpoints-mnist/tree/basic)
@@ -50,16 +50,16 @@ To try it out for yourself:

## Experimenting

-Start training the model with `dvc exp run`. It will train for 10 epochs (you
-can use `Ctrl-C` to cancel at any time and still recover the results of the
-completed epochs), each of which will generate a checkpoint.
+Start training the model with `dvc exp run`. It will train for an unlimited
+number of epochs, each of which will generate a checkpoint. Use `Ctrl-C` to stop
+at the last checkpoint, and simply `dvc exp run` again to resume.

-Dvclive will track performance at each checkpoint. Open `logs.html` in your web
-browser during training to track performance over time (you will need to refresh
-after each epoch completes to see updates). Metrics will also be logged to
-`.tsv` files in the `logs` directory.
+DVCLive will track performance at each checkpoint. Open `dvclive.html` in your
+web browser during training to track performance over time (you will need to
+refresh after each epoch completes to see updates). Metrics will also be logged
+to `.tsv` files in the `dvclive` directory.
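
The logged `.tsv` files can also be read programmatically. The following is a minimal sketch, assuming each metric lives in its own tab-separated file (for example a hypothetical `dvclive/acc.tsv`) with `timestamp`, `step` and metric-name columns; that layout is an assumption for illustration, so check the files DVCLive actually writes before relying on it.

```python
import csv
import io

# Hypothetical contents of a per-metric file such as `dvclive/acc.tsv`.
# The column names are an assumption, not taken from the DVCLive docs.
SAMPLE_TSV = (
    "timestamp\tstep\tacc\n"
    "1614129197192\t0\t0.91\n"
    "1614129198031\t1\t0.94\n"
)


def read_metric(tsv_file, name):
    """Return (step, value) pairs for one metric from a tab-separated file object."""
    reader = csv.DictReader(tsv_file, delimiter="\t")
    return [(int(row["step"]), float(row[name])) for row in reader]


history = read_metric(io.StringIO(SAMPLE_TSV), "acc")
best_step, best_acc = max(history, key=lambda pair: pair[1])
print(f"best acc {best_acc} at step {best_step}")
```

For a real file, pass an open handle instead of the in-memory sample, e.g. `with open("dvclive/acc.tsv", newline="") as f: read_metric(f, "acc")`.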

-Once the training script completes, you can view the results of the experiment
+Once you stop the training script, you can view the results of the experiment
with:

```bash
@@ -90,18 +90,23 @@ You can manage it like any other DVC
* Run `dvc exp run --reset` to drop all the existing checkpoints and start from
scratch.

-## Adding dvclive checkpoints to a DVC project
+## Adding `dvclive` checkpoints to a DVC project

-Using dvclive to add checkpoints to a DVC project requires a few additional
+Using `dvclive` to add checkpoints to a DVC project requires a few additional
lines of code.

In your training script, use `dvclive.log()` to log metrics and
-`dvclive.next_step()` to make a checkpoint with those metrics. See the
-[train.py](train.py) script for an example:
+`dvclive.next_step()` to make a checkpoint with those metrics.
+If you need the current epoch number, use `dvclive.get_step()` (e.g.
+to use a [learning rate
+schedule](https://en.wikipedia.org/wiki/Learning_rate#Learning_rate_schedule)
+or stop training after a fixed number of epochs). See the
+[train.py](https://github.com/iterative/dvc-checkpoints-mnist/blob/live/train.py)
+script for an example:

```python
# Iterate over training epochs.
-for i in range(1, EPOCHS+1):
+for epoch in itertools.count(dvclive.get_step()):
     train(model, x_train, y_train)
     torch.save(model.state_dict(), "model.pt")
     # Evaluate and checkpoint.
@@ -112,7 +117,9 @@ In your training script, use `dvclive.log()` to log metrics and
```
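
The resumable loop and the "stop after a fixed number of epochs" idea can be tried without DVC installed. This is only a sketch: `FakeLive` is a hypothetical stand-in that mimics the three calls named in the README (`log`, `next_step`, `get_step`), and `MAX_EPOCHS` is an assumed cap that the README loop does not have. Note also that `itertools.count` needs `import itertools`, which the excerpts in this diff do not show.

```python
import itertools

MAX_EPOCHS = 5  # hypothetical cap; the loop in the README runs until interrupted


class FakeLive:
    """Stand-in for the `dvclive` module so the sketch runs without DVC."""

    def __init__(self, start_step=0):
        self._step = start_step
        self.history = []

    def get_step(self):
        return self._step

    def log(self, name, value):
        self.history.append((self._step, name, value))

    def next_step(self):
        self._step += 1


def run(live, max_epochs=MAX_EPOCHS):
    """Continue from the current step until `max_epochs` steps are checkpointed."""
    for epoch in itertools.count(live.get_step()):
        if epoch >= max_epochs:
            break
        # The real script would call train(...) and torch.save(...) here.
        live.log("acc", 0.9 + 0.01 * epoch)  # placeholder metric value
        live.next_step()  # closes the checkpoint for this epoch


fresh = FakeLive()
run(fresh)  # runs epochs 0..4

resumed = FakeLive(start_step=3)  # as if three checkpoints already existed
run(resumed)  # runs only epochs 3 and 4
```

Starting the counter at `get_step()` rather than at a constant is what lets a second run pick up where the previous one stopped.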

Then, in `dvc.yaml`, add the `checkpoint: true` option to your model output and
-a `live` section to your stage output. See [dvc.yaml](dvc.yaml) for an example:
+a `live` section to your stage output. See
+[dvc.yaml](https://github.com/iterative/dvc-checkpoints-mnist/blob/live/dvc.yaml)
+for an example:

```yaml
stages:
@@ -124,7 +131,7 @@ stages:
     - model.pt:
         checkpoint: true
     live:
-      logs:
+      dvclive:
         summary: true
         html: true
```
@@ -133,9 +140,9 @@ If you do not already have a `dvc.yaml` stage, you can use [dvc stage
add](https://dvc.org/doc/command-reference/stage/add) to create it:

```bash
-$ dvc stage add -n train -d train.py -c model.pt --live logs python train.py
+$ dvc stage add -n train -d train.py -c model.pt --live dvclive python train.py
```
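
If the stage is created from a setup script rather than typed by hand, the same command can be assembled as an argument list. The sketch below only builds and prints the command using the flags shown above; `stage_add_command` is a hypothetical helper name, and nothing here calls DVC.

```python
import shlex


def stage_add_command(name, dep, checkpoint_out, live_dir, cmd):
    """Build the argument list for the `dvc stage add` call shown in the README."""
    return [
        "dvc", "stage", "add",
        "-n", name,            # stage name
        "-d", dep,             # dependency
        "-c", checkpoint_out,  # checkpoint output
        "--live", live_dir,    # DVCLive directory
        *shlex.split(cmd),     # the training command itself
    ]


args = stage_add_command("train", "train.py", "model.pt", "dvclive", "python train.py")
print(shlex.join(args))
```

Inside an initialized DVC repository, the list could then be passed to `subprocess.run(args, check=True)`.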

That's it! For users already familiar with logging metrics in DVC, note that you
-no longer need a `metrics` section in `dvc.yaml` since dvclive is already
+no longer need a `metrics` section in `dvc.yaml` since `dvclive` is already
logging metrics.
25 changes: 12 additions & 13 deletions train.py
@@ -1,39 +1,39 @@
 """Model training and evaluation."""
 import json
-import yaml
 import os

 import torch
 import torch.nn.functional as F
 import torchvision
-import dvclive
-
+import yaml
+
-EPOCHS = 10
+import dvclive


 class ConvNet(torch.nn.Module):
     """Toy convolutional neural net."""
+
     def __init__(self):
         super().__init__()
         self.conv1 = torch.nn.Conv2d(1, 8, 3, padding=1)
         self.maxpool1 = torch.nn.MaxPool2d(2)
         self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1)
-        self.dense1 = torch.nn.Linear(16*14*14, 32)
+        self.dense1 = torch.nn.Linear(16 * 14 * 14, 32)
         self.dense2 = torch.nn.Linear(32, 10)

     def forward(self, x):
         x = F.relu(self.conv1(x))
         x = self.maxpool1(x)
         x = F.relu(self.conv2(x))
-        x = x.view(-1, 16*14*14)
+        x = x.view(-1, 16 * 14 * 14)
         x = F.relu(self.dense1(x))
         x = self.dense2(x)
         return x


 def transform(dataset):
     """Get inputs and targets from dataset."""
-    x = dataset.data.reshape(len(dataset.data), 1, 28, 28)/255
+    x = dataset.data.reshape(len(dataset.data), 1, 28, 28) / 255
     y = dataset.targets
     return x, y

@@ -62,7 +62,7 @@ def get_metrics(y, y_pred, y_pred_label):
     """Get loss and accuracy metrics."""
     metrics = {}
     criterion = torch.nn.CrossEntropyLoss()
-    metrics["acc"] = (y_pred_label == y).sum().item()/len(y)
+    metrics["acc"] = (y_pred_label == y).sum().item() / len(y)
     return metrics


@@ -82,18 +82,17 @@ def main():
     if os.path.exists("model.pt"):
         model.load_state_dict(torch.load("model.pt"))
     # Load train and test data.
-    mnist_train = torchvision.datasets.MNIST("data", download=True)
+    mnist_train = torchvision.datasets.MNIST("data", download=True, train=True)
     x_train, y_train = transform(mnist_train)
     mnist_test = torchvision.datasets.MNIST("data", download=True, train=False)
     x_test, y_test = transform(mnist_test)
     try:
         # Iterate over training epochs.
-        for i in range(1, EPOCHS+1):
+        for epoch in itertools.count(dvclive.get_step()):
             # Train in batches.
             train_loader = torch.utils.data.DataLoader(
-                dataset=list(zip(x_train, y_train)),
-                batch_size=512,
-                shuffle=True)
+                dataset=list(zip(x_train, y_train)), batch_size=512, shuffle=True
+            )
             for x_batch, y_batch in train_loader:
                 train(model, x_batch, y_batch)
             torch.save(model.state_dict(), "model.pt")
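
As a side note on `ConvNet`, the `16 * 14 * 14` input size of `dense1` can be checked by hand from the layer settings in this diff. The sketch below uses the standard output-size formulas for convolution and pooling in plain Python, so it runs without PyTorch; the helper names `conv2d_out` and `maxpool2d_out` are hypothetical.

```python
def conv2d_out(size, kernel, padding=0, stride=1):
    """Output height/width of a square convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool2d_out(size, kernel):
    """Output height/width of max pooling whose stride equals its kernel size."""
    return (size - kernel) // kernel + 1


size = 28                                      # MNIST images are 28x28
size = conv2d_out(size, kernel=3, padding=1)   # conv1 keeps 28x28
size = maxpool2d_out(size, kernel=2)           # maxpool1 halves it to 14x14
size = conv2d_out(size, kernel=3, padding=1)   # conv2 keeps 14x14
flat_features = 16 * size * size               # conv2 has 16 output channels
print(flat_features)
```

The result matches the value passed to `torch.nn.Linear` and to `x.view` in `forward`.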