Skip to content

Commit

Permalink
Delete MetricInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova committed Jun 19, 2024
1 parent 9b30cb0 commit 6aa3d6c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .download import *
from .features import *
from .fingerprint import disable_caching, enable_caching, is_caching_enabled, set_caching_enabled
from .info import DatasetInfo, MetricInfo
from .info import DatasetInfo
from .inspect import (
get_dataset_config_info,
get_dataset_config_names,
Expand Down
87 changes: 1 addition & 86 deletions src/datasets/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from huggingface_hub import DatasetCard, DatasetCardData

from . import config
from .features import Features, Value
from .features import Features
from .splits import SplitDict
from .tasks import TaskTemplate, task_template_from_dict
from .utils import Version
Expand Down Expand Up @@ -506,88 +506,3 @@ def to_dataset_card_data(self, dataset_card_data: DatasetCardData) -> None:
dataset_info_yaml_dict.pop("config_name", None)
dataset_info_yaml_dict = {"config_name": config_name, **dataset_info_yaml_dict}
dataset_card_data["dataset_info"].append(dataset_info_yaml_dict)


@dataclass
class MetricInfo:
    """Information about a metric.

    `MetricInfo` documents a metric, including its name, version, and features.
    See the constructor arguments and properties for a full list.

    Note: Not all fields are known on construction and may be updated later.
    """

    # Set in the metric scripts
    description: str
    citation: str
    features: Features
    inputs_description: str = dataclasses.field(default_factory=str)
    homepage: str = dataclasses.field(default_factory=str)
    license: str = dataclasses.field(default_factory=str)
    codebase_urls: List[str] = dataclasses.field(default_factory=list)
    reference_urls: List[str] = dataclasses.field(default_factory=list)
    streamable: bool = False
    format: Optional[str] = None

    # Set later by the builder
    metric_name: Optional[str] = None
    config_name: Optional[str] = None
    experiment_id: Optional[str] = None

    def __post_init__(self):
        """Validate the features when an output format is requested.

        Raises:
            ValueError: If `format` is set and any entry of `features` is not a
                `Value` instance.
        """
        if self.format is None:
            return
        for key, value in self.features.items():
            if not isinstance(value, Value):
                raise ValueError(
                    "When using 'numpy' format, all features should be a `datasets.Value` feature. "
                    f"Here {key} is an instance of {value.__class__.__name__}"
                )

    def write_to_directory(self, metric_info_dir, pretty_print=False):
        """Write `MetricInfo` as JSON to `metric_info_dir`.

        When `license` is non-empty it is also written on its own to the file
        named by `config.LICENSE_FILENAME` in the same directory.

        Args:
            metric_info_dir: Directory in which the files are written.
            pretty_print: If True, the JSON is indented with 4 spaces;
                otherwise it is written without indentation.

        Example:

        ```py
        >>> from datasets import load_metric
        >>> metric = load_metric("accuracy")
        >>> metric.info.write_to_directory("/path/to/directory/")
        ```
        """
        info_path = os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME)
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(asdict(self), f, indent=4 if pretty_print else None)

        # An empty license string means there is nothing to save separately.
        if self.license:
            license_path = os.path.join(metric_info_dir, config.LICENSE_FILENAME)
            with open(license_path, "w", encoding="utf-8") as f:
                f.write(self.license)

    @classmethod
    def from_directory(cls, metric_info_dir) -> "MetricInfo":
        """Create `MetricInfo` from the JSON file in `metric_info_dir`.

        Args:
            metric_info_dir: `str` The directory containing the metric metadata
                file (named by `config.METRIC_INFO_FILENAME`).

        Returns:
            The `MetricInfo` built from the JSON content via `from_dict`.

        Raises:
            ValueError: If `metric_info_dir` is empty or `None`.

        Example:

        ```py
        >>> from datasets import MetricInfo
        >>> metric_info = MetricInfo.from_directory("/path/to/directory/")
        ```
        """
        # Validate first so that no misleading "Loading ..." line is logged
        # for an undefined directory.
        if not metric_info_dir:
            raise ValueError("Calling MetricInfo.from_directory() with undefined metric_info_dir.")
        logger.info(f"Loading Metric info from {metric_info_dir}")

        info_path = os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME)
        with open(info_path, encoding="utf-8") as f:
            metric_info_dict = json.load(f)
        return cls.from_dict(metric_info_dict)

    @classmethod
    def from_dict(cls, metric_info_dict: dict) -> "MetricInfo":
        """Create `MetricInfo` from a dictionary.

        Keys that do not match a dataclass field of this class are ignored.

        Args:
            metric_info_dict: Mapping of field names to values.

        Returns:
            A new instance built from the recognized keys.
        """
        field_names = {f.name for f in dataclasses.fields(cls)}
        known_fields = {k: v for k, v in metric_info_dict.items() if k in field_names}
        return cls(**known_fields)

0 comments on commit 6aa3d6c

Please sign in to comment.