diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py index 76df5870bac..4606e0b56d5 100644 --- a/src/datasets/__init__.py +++ b/src/datasets/__init__.py @@ -22,7 +22,7 @@ from .download import * from .features import * from .fingerprint import disable_caching, enable_caching, is_caching_enabled, set_caching_enabled -from .info import DatasetInfo, MetricInfo +from .info import DatasetInfo from .inspect import ( get_dataset_config_info, get_dataset_config_names, diff --git a/src/datasets/info.py b/src/datasets/info.py index 557f5b77d3f..c64d321e238 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -43,7 +43,7 @@ from huggingface_hub import DatasetCard, DatasetCardData from . import config -from .features import Features, Value +from .features import Features from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version @@ -506,88 +506,3 @@ def to_dataset_card_data(self, dataset_card_data: DatasetCardData) -> None: dataset_info_yaml_dict.pop("config_name", None) dataset_info_yaml_dict = {"config_name": config_name, **dataset_info_yaml_dict} dataset_card_data["dataset_info"].append(dataset_info_yaml_dict) - - -@dataclass -class MetricInfo: - """Information about a metric. - - `MetricInfo` documents a metric, including its name, version, and features. - See the constructor arguments and properties for a full list. - - Note: Not all fields are known on construction and may be updated later. - """ - - # Set in the dataset scripts - description: str - citation: str - features: Features - inputs_description: str = dataclasses.field(default_factory=str) - homepage: str = dataclasses.field(default_factory=str) - license: str = dataclasses.field(default_factory=str) - codebase_urls: List[str] = dataclasses.field(default_factory=list) - reference_urls: List[str] = dataclasses.field(default_factory=list) - streamable: bool = False - format: Optional[str] = None - - # Set later by the builder - metric_name: Optional[str] = None - config_name: Optional[str] = None - experiment_id: Optional[str] = None - - def __post_init__(self): - if self.format is not None: - for key, value in self.features.items(): - if not isinstance(value, Value): - raise ValueError( - f"When using 'numpy' format, all features should be a `datasets.Value` feature. " - f"Here {key} is an instance of {value.__class__.__name__}" - ) - - def write_to_directory(self, metric_info_dir, pretty_print=False): - """Write `MetricInfo` as JSON to `metric_info_dir`. - Also save the license separately in LICENCE. - If `pretty_print` is True, the JSON will be pretty-printed with the indent level of 4. - - Example: - - ```py - >>> from datasets import load_metric - >>> metric = load_metric("accuracy") - >>> metric.info.write_to_directory("/path/to/directory/") - ``` - """ - with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), "w", encoding="utf-8") as f: - json.dump(asdict(self), f, indent=4 if pretty_print else None) - - if self.license: - with open(os.path.join(metric_info_dir, config.LICENSE_FILENAME), "w", encoding="utf-8") as f: - f.write(self.license) - - @classmethod - def from_directory(cls, metric_info_dir) -> "MetricInfo": - """Create MetricInfo from the JSON file in `metric_info_dir`. - - Args: - metric_info_dir: `str` The directory containing the metadata file. This - should be the root directory of a specific dataset version. - - Example: - - ```py - >>> from datasets import MetricInfo - >>> metric_info = MetricInfo.from_directory("/path/to/directory/") - ``` - """ - logger.info(f"Loading Metric info from {metric_info_dir}") - if not metric_info_dir: - raise ValueError("Calling MetricInfo.from_directory() with undefined metric_info_dir.") - - with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), encoding="utf-8") as f: - metric_info_dict = json.load(f) - return cls.from_dict(metric_info_dict) - - @classmethod - def from_dict(cls, metric_info_dict: dict) -> "MetricInfo": - field_names = {f.name for f in dataclasses.fields(cls)} - return cls(**{k: v for k, v in metric_info_dict.items() if k in field_names})