Skip to content

Commit

Permalink
gto like metadata fields: labels and type
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Sep 6, 2022
1 parent 0b2b59a commit 4aa8a28
Show file tree
Hide file tree
Showing 25 changed files with 353 additions and 92 deletions.
32 changes: 32 additions & 0 deletions dvc/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dataclasses import asdict, dataclass, field, fields
from typing import ClassVar, Dict, List, Optional

from funcy import compact


@dataclass
class Annotation:
    """Optional, user-supplied metadata attached to an output.

    The ``PARAM_*`` constants are the keys under which the matching
    attributes are serialized; they are ``ClassVar``s and therefore not
    dataclass fields.
    """

    PARAM_DESC: ClassVar[str] = "desc"
    PARAM_TYPE: ClassVar[str] = "type"
    PARAM_LABELS: ClassVar[str] = "labels"

    # Free-form description of the data.
    desc: Optional[str] = None
    # Type of the data.
    type: Optional[str] = None
    # Labels for the data; a fresh list is created per instance.
    labels: List[str] = field(default_factory=list)

    def update(self, **kwargs) -> "Annotation":
        """Overwrite annotation fields from ``kwargs`` and return ``self``.

        Only keys naming a dataclass field (``desc``, ``type``, ``labels``)
        are applied, and only when the value is truthy, so ``None``/empty
        values never erase existing metadata. Every other key is ignored,
        which lets callers forward unrelated keyword arguments.
        """
        # A plain ``hasattr`` check would also match methods and the
        # ``PARAM_*`` class constants, letting e.g. ``update(to_dict=...)``
        # shadow them on the instance; restrict to real dataclass fields.
        field_names = {f.name for f in fields(self)}
        for attr, value in kwargs.items():
            if value and attr in field_names:
                setattr(self, attr, value)
        return self

    def to_dict(self) -> Dict[str, object]:
        """Return the fields as a dict with falsy values removed.

        Values are strings for ``desc``/``type`` and a list of strings for
        ``labels``.
        """
        return compact(asdict(self))


# Names of the metadata attributes declared on ``Annotation``.
# ``dataclasses.fields`` does not report the ``ClassVar`` constants.
ANNOTATION_FIELDS = [spec.name for spec in fields(Annotation)]

# Expected value type for each serialized annotation key.
ANNOTATION_SCHEMA = {
    Annotation.PARAM_DESC: str,
    Annotation.PARAM_TYPE: str,
    Annotation.PARAM_LABELS: [str],
}
10 changes: 10 additions & 0 deletions dvc/cli/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from argparse import _AppendAction


class CommaSeparatedArgs(_AppendAction):  # pylint: disable=protected-access
    """Argparse action that gathers comma-separated values into one list.

    The option may be repeated; every occurrence is split on ``,``, each
    piece is stripped of surrounding whitespace, and the pieces are merged
    into a single list with duplicates removed (first occurrence wins).
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Copy instead of extending in place: before the first occurrence
        # the namespace attribute is the parser's ``default`` object itself,
        # so mutating it would leak values into subsequent parses.
        items = list(getattr(namespace, self.dest, None) or [])
        # Skip empty pieces produced by inputs such as "a,,b" or "a,".
        items.extend(filter(None, map(str.strip, values.split(","))))
        # dict.fromkeys drops duplicates while keeping insertion order.
        setattr(namespace, self.dest, list(dict.fromkeys(items)))
39 changes: 30 additions & 9 deletions dvc/commands/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,38 @@
import logging

from dvc.cli import completion
from dvc.cli.actions import CommaSeparatedArgs
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link

logger = logging.getLogger(__name__)


def _add_annotating_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``--desc``, ``--labels`` and ``--type`` options.

    The options are added to ``parser`` in that order; all of them are
    optional string arguments.
    """
    desc_help = (
        "User description of the data (optional). "
        "This doesn't affect any DVC operations."
    )
    # (flag, extra keyword arguments) pairs, in registration order.
    option_specs = (
        ("--desc", {"metavar": "<text>", "help": desc_help}),
        (
            "--labels",
            {
                "action": CommaSeparatedArgs,
                "metavar": "<str>",
                "help": (
                    "Comma separated list of labels for the data (optional)."
                ),
            },
        ),
        (
            "--type",
            {"metavar": "<str>", "help": "Type of the data (optional)."},
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, type=str, **options)


class CmdAdd(CmdBase):
def run(self):
from dvc.exceptions import (
Expand All @@ -28,6 +54,8 @@ def run(self):
glob=self.args.glob,
desc=self.args.desc,
out=self.args.out,
type=self.args.type,
labels=self.args.labels,
remote=self.args.remote,
to_remote=self.args.to_remote,
jobs=self.args.jobs,
Expand Down Expand Up @@ -109,15 +137,8 @@ def add_parser(subparsers, parent_parser):
),
metavar="<number>",
)
parser.add_argument(
"--desc",
type=str,
metavar="<text>",
help=(
"User description of the data (optional). "
"This doesn't affect any DVC operations."
),
)

_add_annotating_args(parser)
parser.add_argument(
"targets", nargs="+", help="Input files/directories to add."
).complete = completion.FILE
Expand Down
15 changes: 6 additions & 9 deletions dvc/commands/imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def run(self):
no_exec=self.args.no_exec,
no_download=self.args.no_download,
desc=self.args.desc,
type=self.args.type,
labels=self.args.labels,
jobs=self.args.jobs,
)
except DvcException:
Expand All @@ -34,6 +36,8 @@ def run(self):


def add_parser(subparsers, parent_parser):
from .add import _add_annotating_args

IMPORT_HELP = (
"Download file or directory tracked by DVC or by Git "
"into the workspace, and track it."
Expand Down Expand Up @@ -84,15 +88,6 @@ def add_parser(subparsers, parent_parser):
help="Create .dvc file including target data hash value(s)"
" but do not actually download the file(s).",
)
import_parser.add_argument(
"--desc",
type=str,
metavar="<text>",
help=(
"User description of the data (optional). "
"This doesn't affect any DVC operations."
),
)
import_parser.add_argument(
"-j",
"--jobs",
Expand All @@ -103,4 +98,6 @@ def add_parser(subparsers, parent_parser):
),
metavar="<number>",
)

_add_annotating_args(import_parser)
import_parser.set_defaults(func=CmdImport)
14 changes: 5 additions & 9 deletions dvc/commands/imp_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def run(self):
remote=self.args.remote,
to_remote=self.args.to_remote,
desc=self.args.desc,
type=self.args.type,
labels=self.args.labels,
jobs=self.args.jobs,
)
except DvcException:
Expand All @@ -35,6 +37,8 @@ def run(self):


def add_parser(subparsers, parent_parser):
from .add import _add_annotating_args

IMPORT_HELP = (
"Download or copy file from URL and take it under DVC control."
)
Expand Down Expand Up @@ -103,13 +107,5 @@ def add_parser(subparsers, parent_parser):
),
metavar="<number>",
)
import_parser.add_argument(
"--desc",
type=str,
metavar="<text>",
help=(
"User description of the data (optional). "
"This doesn't affect any DVC operations."
),
)
_add_annotating_args(import_parser)
import_parser.set_defaults(func=CmdImportUrl)
30 changes: 17 additions & 13 deletions dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dvc_data.transfer import transfer as otransfer
from dvc_objects.errors import ObjectFormatError

from .annotations import ANNOTATION_FIELDS, ANNOTATION_SCHEMA, Annotation
from .fs import LocalFileSystem, RemoteMissingDepsError, Schemes, get_cloud_fs
from .utils import relpath
from .utils.fs import path_isin
Expand Down Expand Up @@ -80,9 +81,9 @@ def loadd_from(stage, d_list):
plot = d.pop(Output.PARAM_PLOT, False)
persist = d.pop(Output.PARAM_PERSIST, False)
checkpoint = d.pop(Output.PARAM_CHECKPOINT, False)
desc = d.pop(Output.PARAM_DESC, False)
live = d.pop(Output.PARAM_LIVE, False)
remote = d.pop(Output.PARAM_REMOTE, None)
annot = {field: d.pop(field, None) for field in ANNOTATION_FIELDS}
ret.append(
_get(
stage,
Expand All @@ -93,9 +94,9 @@ def loadd_from(stage, d_list):
plot=plot,
persist=persist,
checkpoint=checkpoint,
desc=desc,
live=live,
remote=remote,
**annot,
)
)
return ret
Expand Down Expand Up @@ -187,7 +188,7 @@ def load_from_pipeline(stage, data, typ="outs"):
Output.PARAM_PERSIST,
Output.PARAM_CHECKPOINT,
Output.PARAM_REMOTE,
Output.PARAM_DESC,
Annotation.PARAM_DESC,
],
)

Expand Down Expand Up @@ -253,7 +254,6 @@ class Output:
PARAM_PLOT_TITLE = "title"
PARAM_PLOT_HEADER = "header"
PARAM_PERSIST = "persist"
PARAM_DESC = "desc"
PARAM_LIVE = "live"
PARAM_LIVE_SUMMARY = "summary"
PARAM_LIVE_HTML = "html"
Expand Down Expand Up @@ -285,9 +285,12 @@ def __init__(
checkpoint=False,
live=False,
desc=None,
type=None, # pylint: disable=redefined-builtin
labels=None,
remote=None,
repo=None,
):
self.annot = Annotation(desc=desc, type=type, labels=labels or [])
self.repo = stage.repo if not repo and stage else repo
meta = Meta.from_dict(info)
# NOTE: when version_aware is not passed into get_cloud_fs, it will be
Expand Down Expand Up @@ -337,7 +340,6 @@ def __init__(
self.persist = persist
self.checkpoint = checkpoint
self.live = live
self.desc = desc

self.fs_path = self._parse_path(self.fs, fs_path)
self.obj = None
Expand Down Expand Up @@ -696,9 +698,7 @@ def dumpd(self):
if self.IS_DEPENDENCY:
return ret

if self.desc:
ret[self.PARAM_DESC] = self.desc

ret.update(self.annot.to_dict())
if not self.use_cache:
ret[self.PARAM_CACHE] = self.use_cache

Expand Down Expand Up @@ -1114,22 +1114,26 @@ def is_plot(self) -> bool:
return bool(self.plot) or bool(self.live)


META_SCHEMA = {
Meta.PARAM_SIZE: int,
Meta.PARAM_NFILES: int,
Meta.PARAM_ISEXEC: bool,
Meta.PARAM_VERSION_ID: str,
}

ARTIFACT_SCHEMA = {
**CHECKSUMS_SCHEMA,
**META_SCHEMA,
Required(Output.PARAM_PATH): str,
Output.PARAM_PLOT: bool,
Output.PARAM_PERSIST: bool,
Output.PARAM_CHECKPOINT: bool,
Meta.PARAM_SIZE: int,
Meta.PARAM_NFILES: int,
Meta.PARAM_ISEXEC: bool,
Meta.PARAM_VERSION_ID: str,
}

SCHEMA = {
**ARTIFACT_SCHEMA,
**ANNOTATION_SCHEMA,
Output.PARAM_CACHE: bool,
Output.PARAM_METRIC: Output.METRIC_SCHEMA,
Output.PARAM_DESC: str,
Output.PARAM_REMOTE: str,
}
5 changes: 3 additions & 2 deletions dvc/repo/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def create_stages(
outs=[out],
external=external,
)
if kwargs.get("desc"):
stage.outs[0].desc = kwargs["desc"]

out_obj = stage.outs[0]
out_obj.annot.update(**kwargs)
yield stage
7 changes: 4 additions & 3 deletions dvc/repo/imp_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def imp_url(
remote=None,
to_remote=False,
desc=None,
type=None, # pylint: disable=redefined-builtin
labels=None,
jobs=None,
):
from dvc.dvcfile import Dvcfile
Expand Down Expand Up @@ -61,9 +63,8 @@ def imp_url(
)
restore_fields(stage)

if desc:
stage.outs[0].desc = desc

out_obj = stage.outs[0]
out_obj.annot.update(desc=desc, type=type, labels=labels)
dvcfile = Dvcfile(self, stage.path)
dvcfile.remove()

Expand Down
3 changes: 2 additions & 1 deletion dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from voluptuous import Any, Optional, Required, Schema

from dvc import dependency, output
from dvc.annotations import Annotation
from dvc.output import CHECKSUMS_SCHEMA, Output
from dvc.parsing import DO_KWD, FOREACH_KWD, VARS_KWD
from dvc.parsing.versions import SCHEMA_KWD, lockfile_version_schema
Expand Down Expand Up @@ -50,7 +51,7 @@
Output.PARAM_CACHE: bool,
Output.PARAM_PERSIST: bool,
Output.PARAM_CHECKPOINT: bool,
Output.PARAM_DESC: str,
Annotation.PARAM_DESC: str,
Output.PARAM_REMOTE: str,
}
}
Expand Down
9 changes: 3 additions & 6 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,13 @@ def restore_fields(stage):
# will be used to restore comments later
# noqa, pylint: disable=protected-access
stage._stage_text = old._stage_text

stage.meta = old.meta
stage.desc = old.desc

old_fields = {out.def_path: (out.desc, out.remote) for out in old.outs}

old_fields = {out.def_path: (out.annot, out.remote) for out in old.outs}
for out in stage.outs:
out_fields = old_fields.get(out.def_path, None)
if out_fields:
out.desc, out.remote = out_fields
if out_fields := old_fields.get(out.def_path, None):
out.annot, out.remote = out_fields


class Stage(params.StageParams):
Expand Down
8 changes: 4 additions & 4 deletions dvc/stage/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from funcy import post_processing

from dvc.dependency import ParamsDependency
from dvc.output import Output
from dvc.output import Annotation, Output
from dvc.utils.collections import apply_diff
from dvc.utils.serialize import parse_yaml_for_update

Expand All @@ -27,7 +27,7 @@
PARAM_PLOT = Output.PARAM_PLOT
PARAM_PERSIST = Output.PARAM_PERSIST
PARAM_CHECKPOINT = Output.PARAM_CHECKPOINT
PARAM_DESC = Output.PARAM_DESC
PARAM_DESC = Annotation.PARAM_DESC
PARAM_REMOTE = Output.PARAM_REMOTE

DEFAULT_PARAMS_FILE = ParamsDependency.DEFAULT_PARAMS_FILE
Expand All @@ -38,8 +38,8 @@

@post_processing(OrderedDict)
def _get_flags(out):
if out.desc:
yield PARAM_DESC, out.desc
if out.annot.desc:
yield PARAM_DESC, out.annot.desc
if not out.use_cache:
yield PARAM_CACHE, False
if out.checkpoint:
Expand Down
Loading

0 comments on commit 4aa8a28

Please sign in to comment.