Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

track dvcyaml file #710

Merged
merged 21 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,4 @@ jobs:
pip install -e '.[tests]'

- name: Run tests
run: pytest -v tests --ignore=tests/test_frameworks
run: pytest -v tests --ignore=tests/frameworks
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

.dvc/
.dvcignore
src/dvclive/_dvclive_version.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ from dvclive import Live

params = {"learning_rate": 0.002, "optimizer": "Adam", "epochs": 20}

with Live(save_dvc_exp=True) as live:
with Live() as live:

# log parameters
for param in params:
Expand Down
2 changes: 1 addition & 1 deletion examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=True)],\n",
" callbacks=[DVCLiveCallback(log_model=True, report=\"notebook\")]\n",
" )\n",
" trainer.train()"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/DVCLive-PyTorch-Lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
" limit_train_batches=200,\n",
" limit_val_batches=100,\n",
" max_epochs=5,\n",
" logger=DVCLiveLogger(save_dvc_exp=True, report=\"notebook\", log_model=True),\n",
" logger=DVCLiveLogger(log_model=True, report=\"notebook\"),\n",
" )\n",
" trainer.fit(model, train_loader, validation_loader)\n"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/DVCLive-Quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@
"\n",
"best_test_acc = 0\n",
"\n",
"with Live(save_dvc_exp=True, report=\"notebook\") as live:\n",
"with Live(report=\"notebook\") as live:\n",
"\n",
" live.log_params(params)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/DVCLive-scikit-learn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"\n",
"for n_estimators in (10, 50, 100):\n",
"\n",
" with Live(report=None, save_dvc_exp=True) as live:\n",
" with Live() as live:\n",
"\n",
" live.log_param(\"n_estimators\", n_estimators)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"dvc>=3.17.0",
"dvc>=3.20.0",
"dvc-render>=0.5.0,<1.0",
"dvc-studio-client>=0.15.0,<1",
"funcy",
Expand Down
13 changes: 1 addition & 12 deletions src/dvclive/catalyst.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# ruff: noqa: ARG002
from typing import Optional

from catalyst import utils
from catalyst.core.callback import Callback, CallbackOrder

from dvclive import Live


class DVCLiveCallback(Callback):
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
def __init__(self, live: Optional[Live] = None, **kwargs):
super().__init__(order=CallbackOrder.external)
self.model_file = model_file
self.live = live if live is not None else Live(**kwargs)

def on_epoch_end(self, runner) -> None:
Expand All @@ -19,15 +17,6 @@ def on_epoch_end(self, runner) -> None:
self.live.log_metric(
f"{loader_key}/{key.replace('/', '_')}", float(value)
)

if self.model_file:
checkpoint = utils.pack_checkpoint(
model=runner.model,
criterion=runner.criterion,
optimizer=runner.optimizer,
scheduler=runner.scheduler,
)
utils.save_checkpoint(checkpoint, self.model_file)
self.live.next_step()

def on_experiment_end(self, runner):
Expand Down
76 changes: 58 additions & 18 deletions src/dvclive/dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dvclive.plots import Image, Metric
from dvclive.serialize import dump_yaml
from dvclive.utils import StrPath
from dvclive.utils import StrPath, rel_path

if TYPE_CHECKING:
from dvc.repo import Repo
Expand Down Expand Up @@ -51,38 +51,78 @@ def get_dvc_repo() -> Optional["Repo"]:
return None


def make_dvcyaml(live) -> None:
def make_dvcyaml(live) -> None: # noqa: C901
dvcyaml_dir = Path(live.dvc_file).parent.absolute().as_posix()

dvcyaml = {}
if live._params:
dvcyaml["params"] = [os.path.relpath(live.params_file, live.dir)]
dvcyaml["params"] = [rel_path(live.params_file, dvcyaml_dir)]
if live._metrics or live.summary:
dvcyaml["metrics"] = [os.path.relpath(live.metrics_file, live.dir)]
dvcyaml["metrics"] = [rel_path(live.metrics_file, dvcyaml_dir)]
plots: List[Any] = []
plots_path = Path(live.plots_dir)
metrics_path = plots_path / Metric.subfolder
if metrics_path.exists():
metrics_relpath = metrics_path.relative_to(live.dir).as_posix()
metrics_config = {metrics_relpath: {"x": "step"}}
plots_metrics_path = plots_path / Metric.subfolder
if plots_metrics_path.exists():
metrics_config = {rel_path(plots_metrics_path, dvcyaml_dir): {"x": "step"}}
plots.append(metrics_config)
if live._images:
images_path = (plots_path / Image.subfolder).relative_to(live.dir)
plots.append(images_path.as_posix())
images_path = rel_path(plots_path / Image.subfolder, dvcyaml_dir)
plots.append(images_path)
if live._plots:
for plot in live._plots.values():
plot_path = plot.output_path.relative_to(live.dir)
plots.append({plot_path.as_posix(): plot.plot_config})
plot_path = rel_path(plot.output_path, dvcyaml_dir)
plots.append({plot_path: plot.plot_config})
if plots:
dvcyaml["plots"] = plots

if live._artifacts:
dvcyaml["artifacts"] = copy.deepcopy(live._artifacts)
for artifact in dvcyaml["artifacts"].values(): # type: ignore
abs_path = os.path.abspath(artifact["path"])
abs_dir = os.path.realpath(live.dir)
relative_path = os.path.relpath(abs_path, abs_dir)
artifact["path"] = Path(relative_path).as_posix()

dump_yaml(dvcyaml, live.dvc_file)
artifact["path"] = rel_path(artifact["path"], dvcyaml_dir)

if not os.path.exists(live.dvc_file):
dump_yaml(dvcyaml, live.dvc_file)
else:
update_dvcyaml(live, dvcyaml)


def update_dvcyaml(live, updates):  # noqa: C901
    """Merge freshly generated DVCLive entries into an existing ``dvc.yaml``.

    Entries of the ``params``, ``metrics``, ``plots`` and ``artifacts``
    sections that point inside the DVCLive directory are treated as stale and
    replaced by the ones in ``updates``; every other entry is preserved.
    Sections that end up empty are removed from the file.

    Args:
        live: object exposing ``dvc_file`` (path of the ``dvc.yaml`` to
            modify) and ``dir`` (the DVCLive output directory).
        updates: mapping with the new ``params`` / ``metrics`` / ``plots``
            lists and ``artifacts`` mapping; missing keys mean "no entries".
    """
    from dvc.utils.serialize import modify_yaml

    dvcyaml_dir = os.path.abspath(os.path.dirname(live.dvc_file))
    # Compare against forward-slash paths: `os.path.relpath` returns
    # backslashes on Windows, which would never match posix-style entries.
    # NOTE(review): assumes entries in dvc.yaml are posix-style paths relative
    # to the dvc.yaml directory -- confirm against `rel_path`.
    dvclive_dir = Path(os.path.relpath(live.dir, dvcyaml_dir)).as_posix() + "/"

    def _is_dvclive_path(path):
        # Prefix match (not substring): "mydvclive/x" or "other/dvclive/x"
        # must not be mistaken for an entry under "dvclive/".
        path = str(path)
        if path.startswith("./"):
            path = path[2:]
        return path.startswith(dvclive_dir)

    def _drop_stale_dvclive_entries(entries):
        # Keep everything that is not recognisably a path under the DVCLive dir.
        kept = []
        for entry in entries:
            if isinstance(entry, str):
                stale = _is_dvclive_path(entry)
            elif isinstance(entry, dict) and len(entry) == 1:
                stale = _is_dvclive_path(next(iter(entry)))
            else:
                # Unknown shape: never discard what we do not understand.
                stale = False
            if not stale:
                kept.append(entry)
        return kept

    def _update_entries(old, new, key):
        # `or []` also covers a section that is present but empty (None).
        keepers = _drop_stale_dvclive_entries(old.get(key) or [])
        old[key] = keepers + (new.get(key) or [])
        if not old[key]:
            del old[key]
        return old

    with modify_yaml(live.dvc_file) as orig:
        # `orig` must be mutated in place; it is what gets written back.
        for key in ("params", "metrics", "plots"):
            _update_entries(orig, updates, key)

        # An artifact without a "path" is treated as stale, as before.
        old_artifacts = {
            name: meta
            for name, meta in (orig.get("artifacts") or {}).items()
            if not _is_dvclive_path(meta.get("path", dvclive_dir))
        }
        orig["artifacts"] = {**old_artifacts, **(updates.get("artifacts") or {})}
        if not orig["artifacts"]:
            del orig["artifacts"]


def get_random_exp_name(scm, baseline_rev) -> str:
Expand Down
5 changes: 5 additions & 0 deletions src/dvclive/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def __init__(self, name, val):
super().__init__(f"Data '{name}' has not supported type {val}")


class InvalidDvcyamlError(DvcLiveError):
    """Error whose message states that the `dvcyaml` path must be named 'dvc.yaml'."""

    def __init__(self):
        message = "`dvcyaml` path must have filename 'dvc.yaml'"
        super().__init__(message)


class InvalidPlotTypeError(DvcLiveError):
def __init__(self, name):
from .plots import SKLEARN_PLOTS
Expand Down
5 changes: 0 additions & 5 deletions src/dvclive/fastai.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ def _inside_fine_tune():
class DVCLiveCallback(Callback):
def __init__(
self,
model_file: Optional[str] = None,
with_opt: bool = False,
live: Optional[Live] = None,
**kwargs,
):
super().__init__()
self.model_file = model_file
self.with_opt = with_opt
self.live = live if live is not None else Live(**kwargs)
self.freeze_stage_ended = False
Expand Down Expand Up @@ -66,9 +64,6 @@ def after_epoch(self):
# When resuming (i.e. passing `start_epoch` to learner)
# fast.ai calls after_epoch but we don't want to increase the step.
if logged_metrics:
if self.model_file:
file = self.learn.save(self.model_file, with_opt=self.with_opt)
self.live.log_artifact(str(file))
self.live.next_step()

def after_fit(self):
Expand Down
21 changes: 0 additions & 21 deletions src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ def __init__(
):
super().__init__()
self._log_model = log_model
self.model_file = kwargs.pop("model_file", None)
if self.model_file:
logger.warning(
"model_file is deprecated and will be removed"
" in the next major version, use log_model instead"
)
self.live = live if live is not None else Live(**kwargs)

def on_train_begin(
Expand Down Expand Up @@ -65,21 +59,6 @@ def on_save(
if self._log_model == "all" and state.is_world_process_zero:
self.live.log_artifact(args.output_dir)

def on_epoch_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self.model_file:
model = kwargs["model"]
model.save_pretrained(self.model_file)
tokenizer = kwargs.get("tokenizer")
if tokenizer:
tokenizer.save_pretrained(self.model_file)
self.live.log_artifact(self.model_file)

def on_train_end(
self,
args: TrainingArguments,
Expand Down
20 changes: 0 additions & 20 deletions src/dvclive/keras.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# ruff: noqa: ARG002
import os
from typing import Dict, Optional

import tensorflow as tf
Expand All @@ -11,37 +10,18 @@
class DVCLiveCallback(tf.keras.callbacks.Callback):
def __init__(
self,
model_file=None,
save_weights_only: bool = False,
live: Optional[Live] = None,
**kwargs,
):
super().__init__()
self.model_file = model_file
self.save_weights_only = save_weights_only
self.live = live if live is not None else Live(**kwargs)

def on_train_begin(self, logs=None):
if (
self.live._resume # noqa: SLF001
and self.model_file is not None
and os.path.exists(self.model_file)
):
if self.save_weights_only:
self.model.load_weights(self.model_file)
else:
self.model = tf.keras.models.load_model(self.model_file)

def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
logs = logs or {}
for metric, value in logs.items():
self.live.log_metric(standardize_metric_name(metric, __name__), value)
if self.model_file:
if self.save_weights_only:
self.model.save_weights(self.model_file)
else:
self.model.save(self.model_file)
self.live.log_artifact(self.model_file)
self.live.next_step()

def on_train_end(self, logs: Optional[Dict] = None):
Expand Down
6 changes: 1 addition & 5 deletions src/dvclive/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@


class DVCLiveCallback:
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
def __init__(self, live: Optional[Live] = None, **kwargs):
super().__init__()
self.model_file = model_file
self.live = live if live is not None else Live(**kwargs)

def __call__(self, env):
Expand All @@ -16,7 +15,4 @@ def __call__(self, env):
self.live.log_metric(
f"{data_name}/{eval_name}" if multi_eval else eval_name, result
)

if self.model_file:
env.model.save_model(self.model_file)
self.live.next_step()
2 changes: 1 addition & 1 deletion src/dvclive/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__( # noqa: PLR0913
experiment=None,
dir: Optional[str] = None, # noqa: A002
resume: bool = False,
report: Optional[str] = "auto",
report: Optional[str] = None,
save_dvc_exp: bool = False,
dvcyaml: bool = True,
cache_images: bool = False,
Expand Down