From 67847429d8167b34f4788be9c230d48b225770ac Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 27 Aug 2021 14:02:00 +0100 Subject: [PATCH 1/6] unlimited epochs - fixes #13 --- train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/train.py b/train.py index 3d4e546..0fa34e4 100644 --- a/train.py +++ b/train.py @@ -8,8 +8,6 @@ import dvclive -EPOCHS = 10 - class ConvNet(torch.nn.Module): """Toy convolutional neural net.""" @@ -88,7 +86,7 @@ def main(): x_test, y_test = transform(mnist_test) try: # Iterate over training epochs. - for i in range(1, EPOCHS+1): + for epoch in itertools.count(dvclive.get_step()): # Train in batches. train_loader = torch.utils.data.DataLoader( dataset=list(zip(x_train, y_train)), From e997e00ef360fc23faeaebf39db27cf19e6b6f8a Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 27 Aug 2021 14:25:49 +0100 Subject: [PATCH 2/6] fix & update README --- README.md | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 48fc9ab..f209761 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ This repo has several [branches](https://github.com/iterative/dvc-checkpoints-mnist/branches) that show different methods for using checkpoints (using a similar pipeline): -- The [live](https://github.com/iterative/dvc-checkpoints-mnist/edit/live) +- The [live](https://github.com/iterative/dvc-checkpoints-mnist/tree/live) scenario introduces full-featured checkpoint usage — integrating with [DVCLive](https://github.com/iterative/dvclive). - The [basic](https://github.com/iterative/dvc-checkpoints-mnist/tree/basic) @@ -50,16 +50,16 @@ To try it out for yourself: ## Experimenting -Start training the model with `dvc exp run`. It will train for 10 epochs (you -can use `Ctrl-C` to cancel at any time and still recover the results of the -completed epochs), each of which will generate a checkpoint. +Start training the model with `dvc exp run`. It will train for an unlimited +number of epochs, each of which will generate a checkpoint. Use `Ctrl-C` to stop +at the last checkpoint, and simply `dvc exp run` again to resume. -Dvclive will track performance at each checkpoint. Open `logs.html` in your web -browser during training to track performance over time (you will need to refresh -after each epoch completes to see updates). Metrics will also be logged to -`.tsv` files in the `logs` directory. +DVCLive will track performance at each checkpoint. Open `dvclive.html` in your +web browser during training to track performance over time (you will need to +refresh after each epoch completes to see updates). Metrics will also be logged +to `.tsv` files in the `dvclive` directory. -Once the training script completes, you can view the results of the experiment +Once you stop the training script, you can view the results of the experiment with: ```bash @@ -90,18 +90,20 @@ You can manage it like any other DVC * Run `dvc exp run --reset` to drop all the existing checkpoints and start from scratch. -## Adding dvclive checkpoints to a DVC project +## Adding `dvclive` checkpoints to a DVC project -Using dvclive to add checkpoints to a DVC project requires a few additional +Using `dvclive` to add checkpoints to a DVC project requires a few additional lines of code. -In your training script, use `dvclive.log()` to log metrics and -`dvclive.next_step()` to make a checkpoint with those metrics. See the -[train.py](train.py) script for an example: +In your training script, use `dvclive.get_step()` for the current step number; +`dvclive.log()` to log metrics, and `dvclive.next_step()` to make a checkpoint +with those metrics. See the +[train.py](https://github.com/iterative/dvc-checkpoints-mnist/blob/live/train.py) +script for an example: ```python # Iterate over training epochs. - for i in range(1, EPOCHS+1): + for epoch in itertools.count(dvclive.get_step()): train(model, x_train, y_train) torch.save(model.state_dict(), "model.pt") # Evaluate and checkpoint. @@ -112,7 +114,9 @@ In your training script, use `dvclive.log()` to log metrics and ``` Then, in `dvc.yaml`, add the `checkpoint: true` option to your model output and -a `live` section to your stage output. See [dvc.yaml](dvc.yaml) for an example: +a `live` section to your stage output. See +[dvc.yaml](https://github.com/iterative/dvc-checkpoints-mnist/blob/live/dvc.yaml) +for an example: ```yaml stages: @@ -124,7 +128,7 @@ stages: - model.pt: checkpoint: true live: - logs: + dvclive: summary: true html: true ``` @@ -133,9 +137,9 @@ If you do not already have a `dvc.yaml` stage, you can use [dvc stage add](https://dvc.org/doc/command-reference/stage/add) to create it: ```bash -$ dvc stage add -n train -d train.py -c model.pt --live logs python train.py +$ dvc stage add -n train -d train.py -c model.pt --live dvclive python train.py ``` That's it! For users already familiar with logging metrics in DVC, note that you -no longer need a `metrics` section in `dvc.yaml` since dvclive is already +no longer need a `metrics` section in `dvc.yaml` since `dvclive` is already logging metrics. From ae70f58de6a16373c1852feaaa99aa22e6f798db Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 27 Aug 2021 14:31:29 +0100 Subject: [PATCH 3/6] minor potential data fix --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 0fa34e4..e8e353e 100644 --- a/train.py +++ b/train.py @@ -80,7 +80,7 @@ def main(): if os.path.exists("model.pt"): model.load_state_dict(torch.load("model.pt")) # Load train and test data. - mnist_train = torchvision.datasets.MNIST("data", download=True) + mnist_train = torchvision.datasets.MNIST("data", download=True, train=True) x_train, y_train = transform(mnist_train) mnist_test = torchvision.datasets.MNIST("data", download=True, train=False) x_test, y_test = transform(mnist_test) From 3b52739d2444b8b1ff9ede3891172e65dc510a87 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 27 Aug 2021 14:31:42 +0100 Subject: [PATCH 4/6] lint --- train.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/train.py b/train.py index e8e353e..b484373 100644 --- a/train.py +++ b/train.py @@ -1,29 +1,31 @@ """Model training and evaluation.""" import json -import yaml import os + import torch import torch.nn.functional as F import torchvision -import dvclive +import yaml +import dvclive class ConvNet(torch.nn.Module): """Toy convolutional neural net.""" + def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(1, 8, 3, padding=1) self.maxpool1 = torch.nn.MaxPool2d(2) self.conv2 = torch.nn.Conv2d(8, 16, 3, padding=1) - self.dense1 = torch.nn.Linear(16*14*14, 32) + self.dense1 = torch.nn.Linear(16 * 14 * 14, 32) self.dense2 = torch.nn.Linear(32, 10) def forward(self, x): x = F.relu(self.conv1(x)) x = self.maxpool1(x) x = F.relu(self.conv2(x)) - x = x.view(-1, 16*14*14) + x = x.view(-1, 16 * 14 * 14) x = F.relu(self.dense1(x)) x = self.dense2(x) return x @@ -31,7 +33,7 @@ def forward(self, x): def transform(dataset): """Get inputs and targets from dataset.""" - x = dataset.data.reshape(len(dataset.data), 1, 28, 28)/255 + x = dataset.data.reshape(len(dataset.data), 1, 28, 28) / 255 y = dataset.targets return x, y @@ -60,7 +62,7 @@ def get_metrics(y, y_pred, y_pred_label): """Get loss and accuracy metrics.""" metrics = {} criterion = torch.nn.CrossEntropyLoss() - metrics["acc"] = (y_pred_label == y).sum().item()/len(y) + metrics["acc"] = (y_pred_label == y).sum().item() / len(y) return metrics @@ -89,9 +91,8 @@ def main(): for epoch in itertools.count(dvclive.get_step()): # Train in batches. train_loader = torch.utils.data.DataLoader( - dataset=list(zip(x_train, y_train)), - batch_size=512, - shuffle=True) + dataset=list(zip(x_train, y_train)), batch_size=512, shuffle=True + ) for x_batch, y_batch in train_loader: train(model, x_batch, y_batch) torch.save(model.state_dict(), "model.pt") From 661a8ec824ef908e05a97b3b2eca19cea646c566 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Fri, 27 Aug 2021 14:51:00 +0100 Subject: [PATCH 5/6] ignore /data --- .gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3af0ccb --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/data From 47d930c08be9a4b62e458d356b138146394a9528 Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Fri, 27 Aug 2021 14:46:31 -0400 Subject: [PATCH 6/6] Update README.md Co-authored-by: Casper da Costa-Luis --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f209761..4d0d290 100644 --- a/README.md +++ b/README.md @@ -95,9 +95,12 @@ You can manage it like any other DVC Using `dvclive` to add checkpoints to a DVC project requires a few additional lines of code. -In your training script, use `dvclive.get_step()` for the current step number; -`dvclive.log()` to log metrics, and `dvclive.next_step()` to make a checkpoint -with those metrics. See the +In your training script, use `dvclive.log()` to log metrics and +`dvclive.next_step()` to make a checkpoint with those metrics. +If you need the current epoch number, use `dvclive.get_step()` (e.g. +to use a [learning rate +schedule](https://en.wikipedia.org/wiki/Learning_rate#Learning_rate_schedule) +or stop training after a fixed number of epochs). See the [train.py](https://github.com/iterative/dvc-checkpoints-mnist/blob/live/train.py) script for an example: