diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0ae8a3c9..b3c062fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,8 +56,10 @@ jobs: run: pip install . - name: Validate JSON-LD files + # wiki-text is excluded at the moment. See: https://github.com/mlcommons/croissant/issues/101. + # movielens is excluded at the moment. See: https://github.com/mlcommons/croissant/issues/103. run: | - JSON_FILES=$(python -c "import os; from etils import epath; [print(os.fspath(path)) for path in epath.Path('../../datasets').glob('*/*.json')]") + JSON_FILES=$(find ../../datasets/ -type f -name "*.json" ! -path '*wiki-text*' ! -path '*movielens*') for file in ${JSON_FILES} do echo "Validating ${file}..." diff --git a/datasets/movielens/metadata.json b/datasets/movielens/metadata.json index d16272f6..9e173886 100644 --- a/datasets/movielens/metadata.json +++ b/datasets/movielens/metadata.json @@ -183,7 +183,7 @@ ] }, { - "name": "movies+ratings+tags", + "name": "movies_with_ratings_with_tags", "@type": "ml:RecordSet", "source": "#{movies}", "key": "#{movie_id}", @@ -209,7 +209,6 @@ "dataType": "ml:RecordSet", "source": "#{ratings}", "parentField": { - "@type": "ml:Field", "source": "#{ratings/movie_id}", "references": "#{movies}" }, @@ -237,7 +236,6 @@ "dataType": "ml:RecordSet", "source": "#{tags}", "parentField": { - "@type": "ml:Field", "source": "#{tags/movie_id}", "references": "#{movies}" }, diff --git a/datasets/recipes/compressed_archive.json b/datasets/recipes/compressed_archive.json index 582918a4..f794d5ed 100644 --- a/datasets/recipes/compressed_archive.json +++ b/datasets/recipes/compressed_archive.json @@ -10,7 +10,7 @@ "source": "ml:source" }, "@type": "sc:Dataset", - "name": "Compressed archive example", + "name": "compressed_archive_example", "description": "This is a fairly minimal example, showing a way to describe archive files.", "url": "https://example.com/datasets/recipes/compressed_archive/about", "distribution": [ diff --git a/datasets/recipes/enum.json b/datasets/recipes/enum.json index 8c851866..e27bc88c 100644 --- a/datasets/recipes/enum.json +++ b/datasets/recipes/enum.json @@ -11,7 +11,7 @@ "references": "ml:references" }, "@type": "sc:Dataset", - "name": "Enum example", + "name": "enum_example", "description": "This is a fairly minimal example, showing a way to describe enumerations.", "url": "https://example.com/datasets/enum/about", "distribution": [ diff --git a/datasets/recipes/minimal.json b/datasets/recipes/minimal.json index a6366c6e..db3e1b60 100644 --- a/datasets/recipes/minimal.json +++ b/datasets/recipes/minimal.json @@ -4,7 +4,7 @@ "sc": "https://schema.org/" }, "@type": "sc:Dataset", - "name": "Minimal example", + "name": "minimal_example", "description": "This is a very minimal example, with only the required fields.", "url": "https://example.com/dataset/minimal/about" } diff --git a/datasets/recipes/minimal_recommended.json b/datasets/recipes/minimal_recommended.json index d963de41..56d9973c 100644 --- a/datasets/recipes/minimal_recommended.json +++ b/datasets/recipes/minimal_recommended.json @@ -10,7 +10,7 @@ "references": "ml:references" }, "@type": "sc:Dataset", - "name": "Minimal example with recommended fields", + "name": "minimal_example_with_recommended_fields", "description": "This is a minimal example, including the required and the recommended fields.", "url": "https://example.com/dataset/recipes/minimal-recommended", "license": "https://creativecommons.org/licenses/by/4.0/", diff --git 
a/datasets/wiki-text/metadata.json b/datasets/wiki-text/metadata.json index 7b312503..b0b37a93 100644 --- a/datasets/wiki-text/metadata.json +++ b/datasets/wiki-text/metadata.json @@ -14,6 +14,7 @@ "applyTransform": "ml:applyTransform", "format": "ml:format", "regex": "ml:regex", + "replace": "ml:replace", "separator": "ml:separator", "references": "ml:references" }, diff --git a/python/ml_croissant/README.md b/python/ml_croissant/README.md index 948f2870..ca87ae12 100644 --- a/python/ml_croissant/README.md +++ b/python/ml_croissant/README.md @@ -35,10 +35,52 @@ python -m pip install ".[dev]" pytest . ``` -## Roadmap +## Design -Refer to the [design doc](https://docs.google.com/document/d/1zYQIUX9ae1sZOOBq9OCsJ8JW8-Ejy3NLSeqaI5LtOEM/edit?resourcekey=0-CK78DfFvF7fnufyZqF3h3Q) for an overview of the implementation. +The most important modules in the library are: -Refer to the [GitHub project](https://github.com/orgs/mlcommons/projects/26) for more detailed user stories. +- [`ml_croissant/_src/structure_graph`](./ml_croissant/_src/structure_graph/graph.py) is responsible for the **static analysis** of the Croissant files. We convert Croissant files to a Python representation called "**structure graph**" (using [NetworkX](https://networkx.org/)). In the process, we catch any static analysis issues (e.g., a missing mandatory field or a logic problem in the file). +- [`ml_croissant/_src/operation_graph`](./ml_croissant/_src/operation_graph/graph.py) is responsible for the **dynamic analysis** of the Croissant files (i.e., actually loading the dataset by yielding examples). We convert the structure graph into an "**operation graph**". Operations are the unit transformations that allow building the dataset (like [`Download`](./ml_croissant/_src/operation_graph/operations/download.py), [`Extract`](./ml_croissant/_src/operation_graph/operations/extract.py), etc.). -All contributions are welcome! We even have [good first issues](https://github.com/mlcommons/croissant/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) to start in the project. +Other important modules are: + +- [`ml_croissant/_src/core`](./ml_croissant/_src/core) defines all needed core internals. For instance, [`Issues`](./ml_croissant/_src/core/issues.py) are a way to track errors and warnings during the analysis of Croissant files. +- [`ml_croissant/__init__`](./ml_croissant/__init__.py) declares the public API with [`ml_croissant.Dataset`](./ml_croissant/_src/datasets.py). + +For the full design, refer to the [design doc](https://docs.google.com/document/d/1zYQIUX9ae1sZOOBq9OCsJ8JW8-Ejy3NLSeqaI5LtOEM/edit?resourcekey=0-CK78DfFvF7fnufyZqF3h3Q). + +## Contribute + +All contributions are welcome! We even have [good first issues](https://github.com/mlcommons/croissant/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) to start in the project. Refer to the [GitHub project](https://github.com/orgs/mlcommons/projects/26) for more detailed user stories. + +The development workflow goes as follows: + +- [Fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo) the repository: https://github.com/mlcommons/croissant. +- Clone the newly forked repository: + ```bash + git clone git@github.com:<your-username>/croissant.git + ``` +- Create a new branch: + ```bash + cd croissant/ + git checkout -b feature/my-awesome-new-feature + ``` +- Code the feature. We support [VS Code](https://code.visualstudio.com) with pre-set settings. +- Push to GitHub: + ```bash + git add .
+ git push --set-upstream origin feature/my-awesome-new-feature + ``` +- Open a pull request (PR) with the main branch of https://github.com/mlcommons/croissant, and ask for feedback! + +## Debug + +You can debug the validation of the file with the `--debug` flag: + +```bash +python scripts/validate.py --file ../../datasets/titanic/metadata.json --debug +``` + +This will: +1. print extra information, like the generated nodes; +2. save the generated structure graph to a folder indicated in the logs. diff --git a/python/ml_croissant/ml_croissant/_src/computations.py b/python/ml_croissant/ml_croissant/_src/computations.py deleted file mode 100644 index 0457fd53..00000000 --- a/python/ml_croissant/ml_croissant/_src/computations.py +++ /dev/null @@ -1,352 +0,0 @@ -"""graph module.""" - -from collections.abc import Mapping -import dataclasses - -from etils import epath -from ml_croissant._src.core import constants -from ml_croissant._src.core.issues import Issues -from ml_croissant._src.structure_graph.nodes import ( - Field, - FileObject, - FileSet, - Metadata, - RecordSet, -) -from ml_croissant._src.operation_graph.base_operation import Operation -from ml_croissant._src.operation_graph.operations import ( - Data, - Download, - GroupRecordSet, - InitOperation, - Join, - Merge, - Untar, - ReadCsv, - ReadField, -) -from ml_croissant._src.structure_graph.base_node import Node -import networkx as nx -from rdflib import namespace - - -def concatenate_uid(source: tuple[str]) -> str: - return "/".join(source) - - -def _find_record_set(graph: nx.MultiDiGraph, node: Node) -> RecordSet: - """Finds the record set to which a field is attached. - - The record set will be typically either the parent or the parent's parent. - """ - parent_node = graph.nodes[node].get("parent") - if isinstance(parent_node, RecordSet): - return parent_node - elif parent_node is None: - raise ValueError(f"Node {node} is not in a RecordSet.") - # Recursively returns the parent's the parent. - return _find_record_set(graph, parent_node) - - -def _add_operations_for_field_with_source( - issues: Issues, - graph: nx.MultiDiGraph, - operations: nx.MultiDiGraph, - last_operation: Mapping[Node, Operation], - node: Field, - rdf_namespace_manager: namespace.NamespaceManager, -): - """Adds all operations for a node of type `Field`. - - Operations are: - - - `Join` if the field comes from several sources. - - `ReadField` to specify how the field is read. - - `GroupRecordSet` to structure the final dict that is sent back to the user. - """ - # Attach the field to a record set - record_set = _find_record_set(graph, node) - group_record_set = GroupRecordSet(node=record_set) - parent_node = graph.nodes[node].get("parent") - join = Join(node=parent_node) - # `Join()` takes left=Source and right=Source as kwargs. - if node.references is not None and len(node.references.reference) > 1: - kwargs = { - "left": node.source, - "right": node.references, - } - operations.add_node(join, kwargs=kwargs) - else: - # Else, we add a dummy JOIN operation. 
- operations.add_node(join) - operations.add_edge(join, group_record_set) - for predecessor in graph.predecessors(node): - operations.add_edge(last_operation[predecessor], join) - if len(node.source.reference) != 2: - issues.add_error(f'Wrong source in node "{node.uid}"') - return - # Read/extract the field - read_field = ReadField(node=node, rdf_namespace_manager=rdf_namespace_manager) - operations.add_edge(group_record_set, read_field) - last_operation[node] = read_field - - -def _add_operations_for_field_with_data( - graph: nx.MultiDiGraph, - operations: nx.MultiDiGraph, - last_operation: Mapping[Node, Operation], - node: Field, -): - """Adds a `Data` operation for a node of type `Field` with data. - - Those nodes return a DataFrame representing the lines in `data`. - """ - operation = Data(node=node) - for predecessor in graph.predecessors(node): - operations.add_edge(last_operation[predecessor], operation) - last_operation[node] = operation - - -def _add_operations_for_file_object( - graph: nx.MultiDiGraph, - operations: nx.MultiDiGraph, - last_operation: Mapping[Node, Operation], - node: Node, - croissant_folder: epath.Path, -): - """Adds all operations for a node of type `FileObject`. - - Operations are: - - - `Download`. - - `Untar` if the file needs to be extracted. - - `Merge` to merge several dataframes into one. - - `ReadCsv` to read the file if it's a CSV. - """ - # Download the file - operation = Download(node=node, url=node.content_url) - operations.add_node(operation) - for successor in graph.successors(node): - # Extract the file if needed - if ( - node.encoding_format == "application/x-tar" - and isinstance(successor, (FileObject, FileSet)) - and successor.encoding_format != "application/x-tar" - ): - untar = Untar(node=node, target_node=successor) - operations.add_edge(operation, untar) - last_operation[node] = untar - operation = untar - if isinstance(successor, FileSet): - merge = Merge(node=successor) - operations.add_edge(operation, merge) - operation = merge - # Read the file - if node.encoding_format == "text/csv": - read_csv = ReadCsv( - node=node, - url=node.content_url, - croissant_folder=croissant_folder, - ) - operations.add_edge(operation, read_csv) - operation = read_csv - last_operation[node] = operation - - -@dataclasses.dataclass(frozen=True) -class ComputationGraph: - """Graph of dependent operations to execute to generate the dataset.""" - - issues: Issues - graph: nx.MultiDiGraph - - @classmethod - def from_nodes( - cls, - issues: Issues, - metadata: Node, - graph: nx.MultiDiGraph, - croissant_folder: epath.Path, - rdf_namespace_manager: namespace.NamespaceManager, - ) -> "ComputationGraph": - """Builds the ComputationGraph from the nodes. - - This is done by: - - 1. Building the structure graph. - 2. Building the computation graph by exploring the structure graph layers by - layers in a breadth-first search. - """ - last_operation: Mapping[Node, Operation] = {} - operations = nx.MultiDiGraph() - # Find all fields - for node in nx.topological_sort(graph): - predecessors = graph.predecessors(node) - # Transfer operation from predecessor -> node. 
- for predecessor in predecessors: - if predecessor in last_operation: - last_operation[node] = last_operation[predecessor] - if isinstance(node, Field): - if node.source and not node.has_sub_fields: - _add_operations_for_field_with_source( - issues, - graph, - operations, - last_operation, - node, - rdf_namespace_manager, - ) - elif node.data: - _add_operations_for_field_with_data( - graph, - operations, - last_operation, - node, - ) - elif isinstance(node, FileObject): - _add_operations_for_file_object( - graph, operations, last_operation, node, croissant_folder - ) - - # Attach all entry nodes to a single `start` node - entry_operations = get_entry_nodes(issues, operations) - init_operation = InitOperation(node=metadata) - for entry_operation in entry_operations: - operations.add_edge(init_operation, entry_operation) - return ComputationGraph(issues=issues, graph=operations) - - def check_graph(self): - """Checks the computation graph for issues.""" - if not self.graph.is_directed(): - self.issues.add_error("Computation graph is not directed.") - selfloops = [operation.uid for operation, _ in nx.selfloop_edges(self.graph)] - if selfloops: - self.issues.add_error( - f"The following operations refered to themselves: {selfloops}" - ) - - -def get_entry_nodes(issues: Issues, graph: nx.MultiDiGraph) -> list[Node]: - """Retrieves the entry nodes (without predecessors) in a graph.""" - entry_nodes = [] - for node, indegree in graph.in_degree(graph.nodes()): - if indegree == 0: - entry_nodes.append(node) - # Fields should usually not be entry nodes, except if they have subFields. So we - # check for this: - for node in entry_nodes: - if isinstance(node, Field) and not node.has_sub_fields: - issues.add_error( - f'Node "{node.uid}" is a field and has no source. Please, use' - f" {constants.ML_COMMONS_SOURCE} to specify the source." - ) - return entry_nodes - - -def _check_no_duplicate(issues: Issues, nodes: list[Node]) -> Mapping[str, Node]: - """Checks that no node has duplicated UID and returns the mapping `uid`->`Node`.""" - uid_to_node: Mapping[str, Node] = {} - for node in nodes: - if node.uid in uid_to_node: - issues.add_error(f"Duplicate node with the same identifier: {node.uid}") - uid_to_node[node.uid] = node - return uid_to_node - - -def add_node_as_entry_node(issues: Issues, graph: nx.MultiDiGraph, node: Node): - """Add `node` as the entry node of the graph by updating `graph` in place.""" - graph.add_node(node, parent=None) - entry_nodes = get_entry_nodes(issues, graph) - for entry_node in entry_nodes: - if isinstance(node, (FileObject, FileSet)): - graph.add_edge(entry_node, node) - - -def add_edge( - issues: Issues, - graph: nx.MultiDiGraph, - uid_to_node: Mapping[str, Node], - uid: str, - node: Node, - expected_types: type | tuple[type], -): - if uid not in uid_to_node: - issues.add_error( - f'There is a reference to node named "{uid}" in node "{node.uid}", but this' - " node doesn't exist." - ) - return - if not isinstance(uid_to_node[uid], expected_types): - issues.add_error( - f'There is a reference to node named "{uid}" in node "{node.uid}", but this' - f" node doesn't have the expected type: {expected_types}." - ) - return - graph.add_edge(uid_to_node[uid], node) - - -def build_structure_graph( - issues: Issues, nodes: list[Node] -) -> tuple[Node, nx.MultiDiGraph]: - """Builds the structure graph from the nodes. 
- - The structure graph represents the relationship between the nodes: - - - For ml:Fields without ml:subField, the predecessors in the structure graph are the - sources. - - For sc:FileSet or sc:FileObject with a `containedIn`, the predecessors in the - structure graph are those `containedId`. - - For other objects, the predecessors are their parents (i.e., predecessors in the - JSON-LD). For example: for ml:Field with subField, the predecessors are the - ml:RecordSet in which they are contained. - """ - graph = nx.MultiDiGraph() - uid_to_node = _check_no_duplicate(issues, nodes) - for node in nodes: - if isinstance(node, Metadata): - continue - parent = uid_to_node[node.parent_uid] - graph.add_node(node, parent=parent) - # Distribution - if isinstance(node, (FileObject, FileSet)) and node.contained_in: - for uid in node.contained_in: - add_edge(issues, graph, uid_to_node, uid, node, (FileObject, FileSet)) - # Fields - elif isinstance(node, Field): - references = [] - if node.source is not None: - references.append(node.source.reference) - if node.references is not None: - references.append(node.references.reference) - for reference in references: - # The source can be either another field... - if (uid := concatenate_uid(reference)) in uid_to_node: - # Record sets are not valid parents here. - # The case can arise when a Field references a record set to have a - # machine-readable explanation of the field (see datasets/titanic - # for example). - if not isinstance(uid_to_node[uid], RecordSet): - add_edge(issues, graph, uid_to_node, uid, node, Node) - # ...or the source can be a metadata. - elif (uid := reference[0]) in uid_to_node: - if not isinstance(uid_to_node[uid], RecordSet): - add_edge( - issues, graph, uid_to_node, uid, node, (FileObject, FileSet) - ) - else: - issues.add_error( - "Source refers to an unknown node" - f' "{concatenate_uid(reference)}".' - ) - # Other nodes - elif node.parent_uid is not None: - add_edge(issues, graph, uid_to_node, node.parent_uid, node, Node) - # `Metadata` are used as the entry node. - metadata = next((node for node in nodes if isinstance(node, Metadata)), None) - if metadata is None: - issues.add_error("No metadata is defined in the dataset.") - return None, graph - add_node_as_entry_node(issues, graph, metadata) - if not graph.is_directed(): - issues.add_error("Structure graph is not directed.") - return metadata, graph diff --git a/python/ml_croissant/ml_croissant/_src/constants.py b/python/ml_croissant/ml_croissant/_src/constants.py deleted file mode 100644 index ef7f0790..00000000 --- a/python/ml_croissant/ml_croissant/_src/constants.py +++ /dev/null @@ -1,77 +0,0 @@ -"""constants module.""" - -from etils import epath -import rdflib - -# MLCommons-defined URIs (still draft). 
-ML_COMMONS_APPLY_TRANSFORM = rdflib.term.URIRef( - "http://mlcommons.org/schema/applyTransform" -) -ML_COMMONS_DATA = rdflib.term.URIRef("http://mlcommons.org/schema/data") -ML_COMMONS_DATA_TYPE = rdflib.term.URIRef("http://mlcommons.org/schema/dataType") -ML_COMMONS_FORMAT = rdflib.term.URIRef("http://mlcommons.org/schema/format") -ML_COMMONS_FIELD = rdflib.term.URIRef("http://mlcommons.org/schema/Field") -ML_COMMONS_INCLUDES = rdflib.term.URIRef("http://mlcommons.org/schema/includes") -ML_COMMONS_RECORD_SET = rdflib.term.URIRef("http://mlcommons.org/schema/RecordSet") -ML_COMMONS_REFERENCES = rdflib.term.URIRef("http://mlcommons.org/schema/references") -ML_COMMONS_REGEX = rdflib.term.URIRef("http://mlcommons.org/schema/regex") -ML_COMMONS_SOURCE = rdflib.term.URIRef("http://mlcommons.org/schema/source") -ML_COMMONS_SUB_FIELD = rdflib.term.URIRef("http://mlcommons.org/schema/SubField") - -# RDF standard URIs. -# For "@type" key: -RDF_TYPE = rdflib.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type") - -# Schema.org standard URIs. -SCHEMA_ORG_CITATION = rdflib.term.URIRef("https://schema.org/citation") -SCHEMA_ORG_CONTAINED_IN = rdflib.term.URIRef("https://schema.org/containedIn") -SCHEMA_ORG_CONTENT_SIZE = rdflib.term.URIRef("https://schema.org/contentSize") -SCHEMA_ORG_CONTENT_URL = rdflib.term.URIRef("https://schema.org/contentUrl") -SCHEMA_ORG_DATASET = rdflib.URIRef("https://schema.org/Dataset") -SCHEMA_ORG_DATA_TYPE_BOOL = rdflib.term.URIRef("https://schema.org/Boolean") -SCHEMA_ORG_DATA_TYPE_DATE = rdflib.term.URIRef("https://schema.org/Date") -SCHEMA_ORG_DATA_TYPE_FLOAT = rdflib.term.URIRef("https://schema.org/Float") -SCHEMA_ORG_DATA_TYPE_INTEGER = rdflib.term.URIRef("https://schema.org/Integer") -SCHEMA_ORG_DATA_TYPE_TEXT = rdflib.term.URIRef("https://schema.org/Text") -SCHEMA_ORG_DATA_TYPE_URL = rdflib.term.URIRef("https://schema.org/URL") -SCHEMA_ORG_DESCRIPTION = rdflib.term.URIRef("https://schema.org/description") -SCHEMA_ORG_DISTRIBUTION = rdflib.term.URIRef("https://schema.org/distribution") -SCHEMA_ORG_EMAIL = rdflib.term.URIRef("https://schema.org/email") -SCHEMA_ORG_ENCODING_FORMAT = rdflib.term.URIRef("https://schema.org/encodingFormat") -SCHEMA_ORG_LICENSE = rdflib.term.URIRef("https://schema.org/license") -SCHEMA_ORG_NAME = rdflib.term.URIRef("https://schema.org/name") -SCHEMA_ORG_SHA256 = rdflib.term.URIRef("https://schema.org/sha256") -SCHEMA_ORG_URL = rdflib.term.URIRef("https://schema.org/url") - -# Schema.org URIs that do not exist yet in the standard. 
-SCHEMA_ORG_FILE_OBJECT = rdflib.term.URIRef("https://schema.org/FileObject") -SCHEMA_ORG_FILE_SET = rdflib.term.URIRef("https://schema.org/FileSet") -SCHEMA_ORG_MD5 = rdflib.term.URIRef("https://schema.org/md5") - -TO_CROISSANT = { - ML_COMMONS_APPLY_TRANSFORM: "apply_transform", - ML_COMMONS_DATA_TYPE: "data_type", - ML_COMMONS_DATA: "data", - ML_COMMONS_FORMAT: "format", - ML_COMMONS_INCLUDES: "includes", - ML_COMMONS_REFERENCES: "references", - ML_COMMONS_REGEX: "regex", - ML_COMMONS_SOURCE: "source", - SCHEMA_ORG_CITATION: "citation", - SCHEMA_ORG_CONTAINED_IN: "contained_in", - SCHEMA_ORG_CONTENT_SIZE: "content_size", - SCHEMA_ORG_CONTENT_URL: "content_url", - SCHEMA_ORG_DESCRIPTION: "description", - SCHEMA_ORG_ENCODING_FORMAT: "encoding_format", - SCHEMA_ORG_LICENSE: "license", - SCHEMA_ORG_MD5: "md5", - SCHEMA_ORG_NAME: "name", - SCHEMA_ORG_SHA256: "sha256", - SCHEMA_ORG_URL: "url", -} - -FROM_CROISSANT = {v: k for k, v in TO_CROISSANT.items()} - -CROISSANT_CACHE = epath.Path("~/.cache/croissant").expanduser() -DOWNLOAD_PATH = CROISSANT_CACHE / "download" -EXTRACT_PATH = CROISSANT_CACHE / "extract" diff --git a/python/ml_croissant/ml_croissant/_src/core/constants.py b/python/ml_croissant/ml_croissant/_src/core/constants.py index ef7f0790..8b6a9791 100644 --- a/python/ml_croissant/ml_croissant/_src/core/constants.py +++ b/python/ml_croissant/ml_croissant/_src/core/constants.py @@ -15,6 +15,8 @@ ML_COMMONS_RECORD_SET = rdflib.term.URIRef("http://mlcommons.org/schema/RecordSet") ML_COMMONS_REFERENCES = rdflib.term.URIRef("http://mlcommons.org/schema/references") ML_COMMONS_REGEX = rdflib.term.URIRef("http://mlcommons.org/schema/regex") +ML_COMMONS_REPLACE = rdflib.term.URIRef("http://mlcommons.org/schema/replace") +ML_COMMONS_SEPARATOR = rdflib.term.URIRef("http://mlcommons.org/schema/separator") ML_COMMONS_SOURCE = rdflib.term.URIRef("http://mlcommons.org/schema/source") ML_COMMONS_SUB_FIELD = rdflib.term.URIRef("http://mlcommons.org/schema/SubField") @@ -56,6 +58,8 @@ ML_COMMONS_INCLUDES: "includes", ML_COMMONS_REFERENCES: "references", ML_COMMONS_REGEX: "regex", + ML_COMMONS_REPLACE: "replace", + ML_COMMONS_SEPARATOR: "separator", ML_COMMONS_SOURCE: "source", SCHEMA_ORG_CITATION: "citation", SCHEMA_ORG_CONTAINED_IN: "contained_in", diff --git a/python/ml_croissant/ml_croissant/_src/core/constants_test.py b/python/ml_croissant/ml_croissant/_src/core/constants_test.py new file mode 100644 index 00000000..087c3869 --- /dev/null +++ b/python/ml_croissant/ml_croissant/_src/core/constants_test.py @@ -0,0 +1,14 @@ +"""constants_test module.""" + +from ml_croissant._src.core.constants import TO_CROISSANT + + +def test_to_croissant_values_are_unique(): + deja_vu = {} + for key, value in TO_CROISSANT.items(): + if value in deja_vu: + raise ValueError( + f"Keys {key} and {deja_vu[value]} define the same Croissant value:" + f" {value}." + ) + deja_vu[value] = key diff --git a/python/ml_croissant/ml_croissant/_src/core/issues.py b/python/ml_croissant/ml_croissant/_src/core/issues.py index bc594279..cc8a2b52 100644 --- a/python/ml_croissant/ml_croissant/_src/core/issues.py +++ b/python/ml_croissant/ml_croissant/_src/core/issues.py @@ -1,6 +1,5 @@ """issues module.""" -import contextlib import dataclasses @@ -8,7 +7,7 @@ class ValidationError(Exception): """Error during the validation of the format.""" -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Context: """Context to identify an issue. 
@@ -39,33 +38,32 @@ class Issues: errors: set[str] = dataclasses.field(default_factory=set, hash=False) warnings: set[str] = dataclasses.field(default_factory=set, hash=False) - _local_context: Context = dataclasses.field(default_factory=Context, hash=False) - def _wrap_in_local_context(self, issue: str) -> str: + def _wrap_in_context(self, context: Context | None, issue: str) -> str: + if context is None: + return issue local_context = [] - if self._local_context.dataset_name is not None: - local_context.append(f"dataset({self._local_context.dataset_name})") - if self._local_context.distribution_name is not None: - local_context.append( - f"distribution({self._local_context.distribution_name})" - ) - if self._local_context.record_set_name is not None: - local_context.append(f"record_set({self._local_context.record_set_name})") - if self._local_context.field_name is not None: - local_context.append(f"field({self._local_context.field_name})") - if self._local_context.sub_field_name is not None: - local_context.append(f"sub_field({self._local_context.sub_field_name})") + if context.dataset_name is not None: + local_context.append(f"dataset({context.dataset_name})") + if context.distribution_name is not None: + local_context.append(f"distribution({context.distribution_name})") + if context.record_set_name is not None: + local_context.append(f"record_set({context.record_set_name})") + if context.field_name is not None: + local_context.append(f"field({context.field_name})") + if context.sub_field_name is not None: + local_context.append(f"sub_field({context.sub_field_name})") if not local_context: return issue return f"[{' > '.join(local_context)}] {issue}" - def add_error(self, error: str): + def add_error(self, error: str, context: Context | None = None): """Mutates self.errors with a new error.""" - self.errors.add(self._wrap_in_local_context(error)) + self.errors.add(self._wrap_in_context(context, error)) - def add_warning(self, warning: str): + def add_warning(self, warning: str, context: Context | None = None): """Mutates self.warnings with a new warning.""" - self.warnings.add(self._wrap_in_local_context(warning)) + self.warnings.add(self._wrap_in_context(context, warning)) def report(self) -> str: """Reports errors and warnings in a string.""" @@ -84,54 +82,3 @@ def report(self) -> str: for issue in issues: message += f" - {issue}\n" return message.strip() - - @contextlib.contextmanager - def context( - self, - *, - dataset_name: str | None = None, - distribution_name: str | None = None, - record_set_name: str | None = None, - field_name: str | None = None, - sub_field_name: str | None = None, - ): - """Context manager to add a string context to each error/warning. 
- - Usage: - This prints the error "[dataset(abc)] xyz": - ``` - with issues.context(dataset_id=abc"): - issues.add_error("xyz") - ``` - """ - tmp_dataset_name = self._local_context.dataset_name - tmp_distribution_name = self._local_context.distribution_name - tmp_record_set_name = self._local_context.record_set_name - tmp_field_name = self._local_context.field_name - tmp_sub_field_name = self._local_context.sub_field_name - - self._local_context.dataset_name = ( - dataset_name if dataset_name is not None else tmp_dataset_name - ) - self._local_context.distribution_name = ( - distribution_name - if distribution_name is not None - else tmp_distribution_name - ) - self._local_context.record_set_name = ( - record_set_name if record_set_name is not None else tmp_record_set_name - ) - self._local_context.field_name = ( - field_name if field_name is not None else tmp_field_name - ) - self._local_context.sub_field_name = ( - sub_field_name if sub_field_name is not None else tmp_sub_field_name - ) - - yield - - self._local_context.dataset_name = tmp_dataset_name - self._local_context.distribution_name = tmp_distribution_name - self._local_context.record_set_name = tmp_record_set_name - self._local_context.field_name = tmp_field_name - self._local_context.sub_field_name = tmp_sub_field_name diff --git a/python/ml_croissant/ml_croissant/_src/core/issues_test.py b/python/ml_croissant/ml_croissant/_src/core/issues_test.py index f0c59ff0..ead94b06 100644 --- a/python/ml_croissant/ml_croissant/_src/core/issues_test.py +++ b/python/ml_croissant/ml_croissant/_src/core/issues_test.py @@ -2,7 +2,7 @@ import textwrap -from ml_croissant._src.core.issues import Issues +from ml_croissant._src.core.issues import Context, Issues def test_issues(): @@ -11,47 +11,29 @@ def test_issues(): assert not issues.warnings # With context - with issues.context(dataset_name="abc"): - issues.add_error("foo") - issues.add_warning("bar") + issues.add_error("foo", Context(dataset_name='abc')) + issues.add_warning("bar", Context(dataset_name='abc', distribution_name='xyz')) assert issues.errors == {"[dataset(abc)] foo"} - assert issues.warnings == {"[dataset(abc)] bar"} - - # With nested context - with issues.context(dataset_name="abc", distribution_name="xyz"): - issues.add_error("foo") - issues.add_warning("bar") - assert issues.errors == { - "[dataset(abc)] foo", - "[dataset(abc) > distribution(xyz)] foo", - } - assert issues.warnings == { - "[dataset(abc)] bar", - "[dataset(abc) > distribution(xyz)] bar", - } + assert issues.warnings == {"[dataset(abc) > distribution(xyz)] bar"} # Without context issues.add_error("foo") issues.add_warning("bar") assert issues.errors == { "[dataset(abc)] foo", - "[dataset(abc) > distribution(xyz)] foo", "foo", } assert issues.warnings == { - "[dataset(abc)] bar", "[dataset(abc) > distribution(xyz)] bar", "bar", } # Final report assert issues.report() == textwrap.dedent( - """Found the following 3 error(s) during the validation: - - [dataset(abc) > distribution(xyz)] foo + """Found the following 2 error(s) during the validation: - [dataset(abc)] foo - foo -Found the following 3 warning(s) during the validation: +Found the following 2 warning(s) during the validation: - [dataset(abc) > distribution(xyz)] bar - - [dataset(abc)] bar - bar""" ) diff --git a/python/ml_croissant/ml_croissant/_src/data_types.py b/python/ml_croissant/ml_croissant/_src/data_types.py deleted file mode 100644 index 35e6e9e7..00000000 --- a/python/ml_croissant/ml_croissant/_src/data_types.py +++ /dev/null @@ -1,15 +0,0 @@ 
-"""data_types module.""" - -from ml_croissant._src.core import constants -import pandas as pd - -EXPECTED_DATA_TYPES: dict[str, type] = { - constants.SCHEMA_ORG_DATA_TYPE_BOOL: bool, - constants.SCHEMA_ORG_DATA_TYPE_DATE: pd.DatetimeTZDtype, - constants.SCHEMA_ORG_DATA_TYPE_FLOAT: float, - constants.SCHEMA_ORG_DATA_TYPE_INTEGER: int, - constants.SCHEMA_ORG_DATA_TYPE_TEXT: str, - constants.SCHEMA_ORG_DATA_TYPE_URL: str, - constants.SCHEMA_ORG_EMAIL: str, - constants.SCHEMA_ORG_URL: str, -} diff --git a/python/ml_croissant/ml_croissant/_src/datasets.py b/python/ml_croissant/ml_croissant/_src/datasets.py index 55125ded..9dad227c 100644 --- a/python/ml_croissant/ml_croissant/_src/datasets.py +++ b/python/ml_croissant/ml_croissant/_src/datasets.py @@ -3,40 +3,28 @@ from collections.abc import Mapping import dataclasses -import json from typing import Any from absl import logging from etils import epath +from ml_croissant._src.core.graphs import utils as graphs_utils from ml_croissant._src.core.issues import Issues, ValidationError -from ml_croissant._src.rdf_graph import graph from ml_croissant._src.operation_graph import ( - build_structure_graph, ComputationGraph, ) from ml_croissant._src.operation_graph.operations import ( GroupRecordSet, ReadField, ) +from ml_croissant._src.structure_graph.graph import ( + from_file_to_json, + from_json_to_jsonld, + from_jsonld_to_nodes, + from_nodes_to_structure_graph, +) import networkx as nx -def _load_file(filepath: epath.PathLike) -> tuple[epath.Path, dict]: - """Loads the file. - - Args: - filepath: the path to the Croissant file. - - Returns: - A tuple with the path to the file and the file content. - """ - filepath = epath.Path(filepath).expanduser().resolve() - if not filepath.exists(): - raise ValueError(f"File {filepath} does not exist.") - with filepath.open() as filedescriptor: - return filepath, json.load(filedescriptor) - - @dataclasses.dataclass class Validator: """Static analysis of the issues in the Croissant file.""" @@ -46,14 +34,22 @@ class Validator: file: dict = dataclasses.field(init=False) operations: ComputationGraph | None = None - def run_static_analysis(self): + def run_static_analysis(self, debug: bool = False): try: - file_path, self.file = _load_file(self.file_or_file_path) - rdf_graph, rdf_nx_graph = graph.load_rdf_graph(self.file) - rdf_namespace_manager = rdf_graph.namespace_manager - nodes = graph.check_rdf_graph(self.issues, rdf_nx_graph) - - entry_node, structure_graph = build_structure_graph(self.issues, nodes) + file_path, self.file = from_file_to_json(self.file_or_file_path) + ns, json_ld = from_json_to_jsonld(self.file) + nodes, parents = from_jsonld_to_nodes(self.issues, json_ld) + # Print all nodes for debugging purposes. + if debug: + logging.info('Found the following nodes during static analysis.') + for node in nodes: + logging.info(node) + entry_node, structure_graph = from_nodes_to_structure_graph( + self.issues, nodes, parents + ) + # Draw the structure graph for debugging purposes. + if debug: + graphs_utils.pretty_print_graph(structure_graph, simplify=True) # Feature toggling: do not check for MovieLens, because we need more # features. 
if entry_node.uid == "Movielens-25M": @@ -63,7 +59,7 @@ def run_static_analysis(self): metadata=entry_node, graph=structure_graph, croissant_folder=file_path.parent, - rdf_namespace_manager=rdf_namespace_manager, + rdf_namespace_manager=ns, ) self.operations.check_graph() except Exception as exception: @@ -86,11 +82,12 @@ class Dataset: file: epath.PathLike operations: ComputationGraph | None = None + debug: bool = False def __post_init__(self): """Runs the static analysis of `file`.""" self.validator = Validator(self.file) - self.validator.run_static_analysis() + self.validator.run_static_analysis(debug=self.debug) self.file = self.validator.file self.operations = self.validator.operations diff --git a/python/ml_croissant/ml_croissant/_src/errors.py b/python/ml_croissant/ml_croissant/_src/errors.py deleted file mode 100644 index a1e21489..00000000 --- a/python/ml_croissant/ml_croissant/_src/errors.py +++ /dev/null @@ -1,137 +0,0 @@ -"""errors module.""" - -import contextlib -import dataclasses - - -class ValidationError(Exception): - """Error during the validation of the format.""" - - -@dataclasses.dataclass -class Context: - """Context to identify an issue. - - This allows to add context to an issue by tracing it back: - - within a given dataset, - - within a given distribution, - - within a given record set, - - within a given field, - - within a given sub field. - """ - - dataset_name: str | None = None - distribution_name: str | None = None - record_set_name: str | None = None - field_name: str | None = None - sub_field_name: str | None = None - - -@dataclasses.dataclass(frozen=True) -class Issues: - """ - Issues during the validation of the format. - - Issues can either be errors (blocking) or warnings (informative). - - We use sets to represent errors and warnings to avoid repeated strings. - """ - - errors: set[str] = dataclasses.field(default_factory=set, hash=False) - warnings: set[str] = dataclasses.field(default_factory=set, hash=False) - _local_context: Context = dataclasses.field(default_factory=Context, hash=False) - - def _wrap_in_local_context(self, issue: str) -> str: - local_context = [] - if self._local_context.dataset_name is not None: - local_context.append(f"dataset({self._local_context.dataset_name})") - if self._local_context.distribution_name is not None: - local_context.append( - f"distribution({self._local_context.distribution_name})" - ) - if self._local_context.record_set_name is not None: - local_context.append(f"record_set({self._local_context.record_set_name})") - if self._local_context.field_name is not None: - local_context.append(f"field({self._local_context.field_name})") - if self._local_context.sub_field_name is not None: - local_context.append(f"sub_field({self._local_context.sub_field_name})") - if not local_context: - return issue - return f"[{' > '.join(local_context)}] {issue}" - - def add_error(self, error: str): - """Mutates self.errors with a new error.""" - self.errors.add(self._wrap_in_local_context(error)) - - def add_warning(self, warning: str): - """Mutates self.warnings with a new warning.""" - self.warnings.add(self._wrap_in_local_context(warning)) - - def report(self) -> str: - """Reports errors and warnings in a string.""" - message = "" - # Sort before printing because sets are not ordered. 
- for issues, issue_type in [ - (sorted(self.errors), "error(s)"), - (sorted(self.warnings), "warning(s)"), - ]: - num_issues = len(issues) - if num_issues: - message += ( - f"Found the following {len(issues)} {issue_type} during the" - " validation:\n" - ) - for issue in issues: - message += f" - {issue}\n" - return message.strip() - - @contextlib.contextmanager - def context( - self, - *, - dataset_name: str | None = None, - distribution_name: str | None = None, - record_set_name: str | None = None, - field_name: str | None = None, - sub_field_name: str | None = None, - ): - """Context manager to add a string context to each error/warning. - - Usage: - This prints the error "[dataset(abc)] xyz": - ``` - with issues.context(dataset_id=abc"): - issues.add_error("xyz") - ``` - """ - tmp_dataset_name = self._local_context.dataset_name - tmp_distribution_name = self._local_context.distribution_name - tmp_record_set_name = self._local_context.record_set_name - tmp_field_name = self._local_context.field_name - tmp_sub_field_name = self._local_context.sub_field_name - - self._local_context.dataset_name = ( - dataset_name if dataset_name is not None else tmp_dataset_name - ) - self._local_context.distribution_name = ( - distribution_name - if distribution_name is not None - else tmp_distribution_name - ) - self._local_context.record_set_name = ( - record_set_name if record_set_name is not None else tmp_record_set_name - ) - self._local_context.field_name = ( - field_name if field_name is not None else tmp_field_name - ) - self._local_context.sub_field_name = ( - sub_field_name if sub_field_name is not None else tmp_sub_field_name - ) - - yield - - self._local_context.dataset_name = tmp_dataset_name - self._local_context.distribution_name = tmp_distribution_name - self._local_context.record_set_name = tmp_record_set_name - self._local_context.field_name = tmp_field_name - self._local_context.sub_field_name = tmp_sub_field_name diff --git a/python/ml_croissant/ml_croissant/_src/errors_test.py b/python/ml_croissant/ml_croissant/_src/errors_test.py deleted file mode 100644 index f5168440..00000000 --- a/python/ml_croissant/ml_croissant/_src/errors_test.py +++ /dev/null @@ -1,57 +0,0 @@ -"""errors_test module.""" - -import textwrap - -from ml_croissant._src.core.issues import Issues - - -def test_issues(): - issues = Issues() - assert not issues.errors - assert not issues.warnings - - # With context - with issues.context(dataset_name="abc"): - issues.add_error("foo") - issues.add_warning("bar") - assert issues.errors == {"[dataset(abc)] foo"} - assert issues.warnings == {"[dataset(abc)] bar"} - - # With nested context - with issues.context(dataset_name="abc", distribution_name="xyz"): - issues.add_error("foo") - issues.add_warning("bar") - assert issues.errors == { - "[dataset(abc)] foo", - "[dataset(abc) > distribution(xyz)] foo", - } - assert issues.warnings == { - "[dataset(abc)] bar", - "[dataset(abc) > distribution(xyz)] bar", - } - - # Without context - issues.add_error("foo") - issues.add_warning("bar") - assert issues.errors == { - "[dataset(abc)] foo", - "[dataset(abc) > distribution(xyz)] foo", - "foo", - } - assert issues.warnings == { - "[dataset(abc)] bar", - "[dataset(abc) > distribution(xyz)] bar", - "bar", - } - - # Final report - assert issues.report() == textwrap.dedent( - """Found the following 3 error(s) during the validation: - - [dataset(abc) > distribution(xyz)] foo - - [dataset(abc)] foo - - foo -Found the following 3 warning(s) during the validation: - - [dataset(abc) > 
distribution(xyz)] bar - - [dataset(abc)] bar - - bar""" - ) diff --git a/python/ml_croissant/ml_croissant/_src/graphs.py b/python/ml_croissant/ml_croissant/_src/graphs.py deleted file mode 100644 index 8820a859..00000000 --- a/python/ml_croissant/ml_croissant/_src/graphs.py +++ /dev/null @@ -1,84 +0,0 @@ -"""graphs module.""" - -from __future__ import annotations - -import typing - -from ml_croissant._src.core import constants -from ml_croissant._src.core.issues import Issues -from ml_croissant._src.structure_graph.graph import ( - children_nodes, - from_rdf_graph, -) -import networkx as nx -import rdflib -from rdflib.extras import external_graph_libs - -if typing.TYPE_CHECKING: - from ml_croissant._src.structure_graph.base_node import Node - - -def load_rdf_graph(dict_dataset: dict) -> tuple[rdflib.Graph, nx.MultiDiGraph]: - """Parses RDF graph with NetworkX from a dict.""" - graph = rdflib.Graph() - graph.parse( - data=dict_dataset, - format="json-ld", - ) - return graph, external_graph_libs.rdflib_to_networkx_multidigraph(graph) - - -def _find_entry_object(issues: Issues, graph: nx.MultiDiGraph) -> rdflib.term.BNode: - """Finds the source entry node without any parent.""" - sources = [ - node - for node, indegree in graph.in_degree(graph.nodes()) - if indegree == 0 and isinstance(node, rdflib.term.BNode) - ] - if len(sources) != 1: - issues.add_error("Trying to define more than one dataset in the file.") - return sources[0] - - -def check_rdf_graph(issues: Issues, graph: nx.MultiDiGraph) -> list[Node]: - """Validates the graph and populates issues with errors/warnings. - - We first build a NetworkX graph where edges are subject->object with the attribute - `property`. - - Subject/object/property are RDF triples: - - `subject`is an ID instanciated by RDFLib. - - `property` (aka predicate) denotes the relationship (e.g., - `https://schema.org/description`). - - `object` is either the value (e.g., the description) or another `subject`. - - Refer to https://www.w3.org/TR/rdf-concepts to learn more. - - Args: - issues: the issues that will be modified in-place. - graph: The NetworkX RDF graph to validate. - """ - # Check RDF properties in nodes - source = _find_entry_object(issues, graph) - metadata = from_rdf_graph(issues, graph, source, None) - nodes = [metadata] - dataset_name = metadata.name - with issues.context(dataset_name=dataset_name, distribution_name=""): - distributions = children_nodes(metadata, constants.SCHEMA_ORG_DISTRIBUTION) - nodes += distributions - record_sets = children_nodes(metadata, constants.ML_COMMONS_RECORD_SET) - nodes += record_sets - for record_set in record_sets: - with issues.context( - dataset_name=dataset_name, - record_set_name=record_set.name, - field_name="", - ): - fields = children_nodes(record_set, constants.ML_COMMONS_FIELD) - nodes += fields - if len(fields) == 0: - issues.add_error("The node doesn't define any field.") - for field in fields: - sub_fields = children_nodes(field, constants.ML_COMMONS_SUB_FIELD) - nodes += sub_fields - return nodes diff --git a/python/ml_croissant/ml_croissant/_src/graphs_utils.py b/python/ml_croissant/ml_croissant/_src/graphs_utils.py deleted file mode 100644 index dc4e8da3..00000000 --- a/python/ml_croissant/ml_croissant/_src/graphs_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -"""graphs_utils module.""" - -import time - -import networkx as nx - - -def pretty_print_graph(graph: nx.Graph, simplify=False): - """Pretty prints a NetworkX graph. - - Args: - graph: Any NetworkX graph. 
- - Warning: this function is for debugging purposes only.""" - if simplify: - simple_graph = nx.Graph() - for x, y in graph.edges(): - x = getattr(x, "uid", x) - y = getattr(y, "uid", y) - simple_graph.add_edge(x, y) - graph = simple_graph - agraph = nx.nx_agraph.to_agraph(graph) - agraph.layout(prog="dot") - temporary_file = f"/tmp/graph_{time.time()}.png" - agraph.draw(temporary_file, args="-Gnodesep=0.01 -Gfont_size=1", prog="dot") - print(f"Generated a graph and saved it in: {temporary_file}") - - -def print_graph_traversal(graph: nx.Graph): - """Pretty prints a NetworkX graph. - - Args: - graph: Any NetworkX graph. - - Warning: this function is for debugging purposes only.""" - visited = {} - print("--- Graph traversal ---") - for start, end, _ in nx.edge_bfs(graph): - for node in [start, end]: - if node.name not in visited: - print(f"Visited: {node.name}") - visited[node.name] = True - print("Done traversing the graph.") diff --git a/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py b/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py index bd6e6aaa..4999c63a 100644 --- a/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py +++ b/python/ml_croissant/ml_croissant/_src/operation_graph/__init__.py @@ -1,7 +1,5 @@ -from ml_croissant._src.operation_graph.graph import build_structure_graph from ml_croissant._src.operation_graph.graph import ComputationGraph __all__ = [ - "build_structure_graph", "ComputationGraph", ] diff --git a/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py b/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py index 0457fd53..abee23df 100644 --- a/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py +++ b/python/ml_croissant/ml_croissant/_src/operation_graph/graph.py @@ -4,13 +4,11 @@ import dataclasses from etils import epath -from ml_croissant._src.core import constants from ml_croissant._src.core.issues import Issues from ml_croissant._src.structure_graph.nodes import ( Field, FileObject, FileSet, - Metadata, RecordSet, ) from ml_croissant._src.operation_graph.base_operation import Operation @@ -26,14 +24,11 @@ ReadField, ) from ml_croissant._src.structure_graph.base_node import Node +from ml_croissant._src.structure_graph.graph import get_entry_nodes import networkx as nx from rdflib import namespace -def concatenate_uid(source: tuple[str]) -> str: - return "/".join(source) - - def _find_record_set(graph: nx.MultiDiGraph, node: Node) -> RecordSet: """Finds the record set to which a field is attached. @@ -224,129 +219,3 @@ def check_graph(self): self.issues.add_error( f"The following operations refered to themselves: {selfloops}" ) - - -def get_entry_nodes(issues: Issues, graph: nx.MultiDiGraph) -> list[Node]: - """Retrieves the entry nodes (without predecessors) in a graph.""" - entry_nodes = [] - for node, indegree in graph.in_degree(graph.nodes()): - if indegree == 0: - entry_nodes.append(node) - # Fields should usually not be entry nodes, except if they have subFields. So we - # check for this: - for node in entry_nodes: - if isinstance(node, Field) and not node.has_sub_fields: - issues.add_error( - f'Node "{node.uid}" is a field and has no source. Please, use' - f" {constants.ML_COMMONS_SOURCE} to specify the source." 
- ) - return entry_nodes - - -def _check_no_duplicate(issues: Issues, nodes: list[Node]) -> Mapping[str, Node]: - """Checks that no node has duplicated UID and returns the mapping `uid`->`Node`.""" - uid_to_node: Mapping[str, Node] = {} - for node in nodes: - if node.uid in uid_to_node: - issues.add_error(f"Duplicate node with the same identifier: {node.uid}") - uid_to_node[node.uid] = node - return uid_to_node - - -def add_node_as_entry_node(issues: Issues, graph: nx.MultiDiGraph, node: Node): - """Add `node` as the entry node of the graph by updating `graph` in place.""" - graph.add_node(node, parent=None) - entry_nodes = get_entry_nodes(issues, graph) - for entry_node in entry_nodes: - if isinstance(node, (FileObject, FileSet)): - graph.add_edge(entry_node, node) - - -def add_edge( - issues: Issues, - graph: nx.MultiDiGraph, - uid_to_node: Mapping[str, Node], - uid: str, - node: Node, - expected_types: type | tuple[type], -): - if uid not in uid_to_node: - issues.add_error( - f'There is a reference to node named "{uid}" in node "{node.uid}", but this' - " node doesn't exist." - ) - return - if not isinstance(uid_to_node[uid], expected_types): - issues.add_error( - f'There is a reference to node named "{uid}" in node "{node.uid}", but this' - f" node doesn't have the expected type: {expected_types}." - ) - return - graph.add_edge(uid_to_node[uid], node) - - -def build_structure_graph( - issues: Issues, nodes: list[Node] -) -> tuple[Node, nx.MultiDiGraph]: - """Builds the structure graph from the nodes. - - The structure graph represents the relationship between the nodes: - - - For ml:Fields without ml:subField, the predecessors in the structure graph are the - sources. - - For sc:FileSet or sc:FileObject with a `containedIn`, the predecessors in the - structure graph are those `containedId`. - - For other objects, the predecessors are their parents (i.e., predecessors in the - JSON-LD). For example: for ml:Field with subField, the predecessors are the - ml:RecordSet in which they are contained. - """ - graph = nx.MultiDiGraph() - uid_to_node = _check_no_duplicate(issues, nodes) - for node in nodes: - if isinstance(node, Metadata): - continue - parent = uid_to_node[node.parent_uid] - graph.add_node(node, parent=parent) - # Distribution - if isinstance(node, (FileObject, FileSet)) and node.contained_in: - for uid in node.contained_in: - add_edge(issues, graph, uid_to_node, uid, node, (FileObject, FileSet)) - # Fields - elif isinstance(node, Field): - references = [] - if node.source is not None: - references.append(node.source.reference) - if node.references is not None: - references.append(node.references.reference) - for reference in references: - # The source can be either another field... - if (uid := concatenate_uid(reference)) in uid_to_node: - # Record sets are not valid parents here. - # The case can arise when a Field references a record set to have a - # machine-readable explanation of the field (see datasets/titanic - # for example). - if not isinstance(uid_to_node[uid], RecordSet): - add_edge(issues, graph, uid_to_node, uid, node, Node) - # ...or the source can be a metadata. - elif (uid := reference[0]) in uid_to_node: - if not isinstance(uid_to_node[uid], RecordSet): - add_edge( - issues, graph, uid_to_node, uid, node, (FileObject, FileSet) - ) - else: - issues.add_error( - "Source refers to an unknown node" - f' "{concatenate_uid(reference)}".' 
- ) - # Other nodes - elif node.parent_uid is not None: - add_edge(issues, graph, uid_to_node, node.parent_uid, node, Node) - # `Metadata` are used as the entry node. - metadata = next((node for node in nodes if isinstance(node, Metadata)), None) - if metadata is None: - issues.add_error("No metadata is defined in the dataset.") - return None, graph - add_node_as_entry_node(issues, graph, metadata) - if not graph.is_directed(): - issues.add_error("Structure graph is not directed.") - return metadata, graph diff --git a/python/ml_croissant/ml_croissant/_src/rdf_graph/__init__.py b/python/ml_croissant/ml_croissant/_src/rdf_graph/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python/ml_croissant/ml_croissant/_src/rdf_graph/graph.py b/python/ml_croissant/ml_croissant/_src/rdf_graph/graph.py deleted file mode 100644 index 9de74845..00000000 --- a/python/ml_croissant/ml_croissant/_src/rdf_graph/graph.py +++ /dev/null @@ -1,84 +0,0 @@ -"""RDF graph module.""" - -from __future__ import annotations - -import typing - -from ml_croissant._src.core import constants -from ml_croissant._src.core.issues import Issues -from ml_croissant._src.structure_graph.graph import ( - children_nodes, - from_rdf_graph, -) -import networkx as nx -import rdflib -from rdflib.extras import external_graph_libs - -if typing.TYPE_CHECKING: - from ml_croissant._src.structure_graph.base_node import Node - - -def load_rdf_graph(dict_dataset: dict) -> tuple[rdflib.Graph, nx.MultiDiGraph]: - """Parses RDF graph with NetworkX from a dict.""" - graph = rdflib.Graph() - graph.parse( - data=dict_dataset, - format="json-ld", - ) - return graph, external_graph_libs.rdflib_to_networkx_multidigraph(graph) - - -def _find_entry_object(issues: Issues, graph: nx.MultiDiGraph) -> rdflib.term.BNode: - """Finds the source entry node without any parent.""" - sources = [ - node - for node, indegree in graph.in_degree(graph.nodes()) - if indegree == 0 and isinstance(node, rdflib.term.BNode) - ] - if len(sources) != 1: - issues.add_error("Trying to define more than one dataset in the file.") - return sources[0] - - -def check_rdf_graph(issues: Issues, graph: nx.MultiDiGraph) -> list[Node]: - """Validates the graph and populates issues with errors/warnings. - - We first build a NetworkX graph where edges are subject->object with the attribute - `property`. - - Subject/object/property are RDF triples: - - `subject`is an ID instanciated by RDFLib. - - `property` (aka predicate) denotes the relationship (e.g., - `https://schema.org/description`). - - `object` is either the value (e.g., the description) or another `subject`. - - Refer to https://www.w3.org/TR/rdf-concepts to learn more. - - Args: - issues: the issues that will be modified in-place. - graph: The NetworkX RDF graph to validate. 
- """ - # Check RDF properties in nodes - source = _find_entry_object(issues, graph) - metadata = from_rdf_graph(issues, graph, source, None) - nodes = [metadata] - dataset_name = metadata.name - with issues.context(dataset_name=dataset_name, distribution_name=""): - distributions = children_nodes(metadata, constants.SCHEMA_ORG_DISTRIBUTION) - nodes += distributions - record_sets = children_nodes(metadata, constants.ML_COMMONS_RECORD_SET) - nodes += record_sets - for record_set in record_sets: - with issues.context( - dataset_name=dataset_name, - record_set_name=record_set.name, - field_name="", - ): - fields = children_nodes(record_set, constants.ML_COMMONS_FIELD) - nodes += fields - if len(fields) == 0: - issues.add_error("The node doesn't define any field.") - for field in fields: - sub_fields = children_nodes(field, constants.ML_COMMONS_SUB_FIELD) - nodes += sub_fields - return nodes diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/base_node.py b/python/ml_croissant/ml_croissant/_src/structure_graph/base_node.py index 8c166bb1..7c16c2f6 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/base_node.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/base_node.py @@ -1,5 +1,6 @@ """Base node module.""" +import abc import dataclasses import re @@ -7,14 +8,14 @@ import rdflib from ml_croissant._src.core import constants -from ml_croissant._src.core.issues import Issues +from ml_croissant._src.core.issues import Context, Issues ID_REGEX = "[a-zA-Z0-9\\-_\\.]+" _MAX_ID_LENGTH = 255 -@dataclasses.dataclass(frozen=True) -class Node: +@dataclasses.dataclass(frozen=True, repr=False) +class Node(abc.ABC): """Structure node in Croissant. This generic class will be inherited by the actual Croissant nodes: @@ -31,8 +32,12 @@ class Node: graph: The NetworkX RDF graph to validate. node: The node in the graph to convert. name: The name of the node. - parent_uid: UID of the parent node if it exists. This is the parent in the - JSON-LD structure, whereas `sources` are the parents in the resource tree. + rdf_id: The RDF @id created by RDFLib. + uid: Croissant unique identifier. It's the concatenation of the path within + the Croissant hierarchy. For instance, for a field: + dataset.name/record_set.name/field.name. + context: Context of the node in the Croissant hierarchy (dataset, distribution, + record set, field). Usage: @@ -53,10 +58,12 @@ class Node: """ issues: Issues - graph: nx.MultiDiGraph - node: rdflib.term.BNode + graph: nx.MultiDiGraph = None + node: rdflib.term.BNode = None name: str = "" - parent_uid: str | None = None + rdf_id: str | None = None + uid: str | None = None + context: Context | None = None def __post_init__(self): """Checks for `name` (common property between all nodes).""" @@ -67,20 +74,6 @@ def __post_init__(self): def _edges_from_node(self): return self.graph.edges(self.node, keys=True) - @property - def uid(self): - """Creates a UID from the name. - - For fields, the UID cannot be the name, as a dataset - can contain two fields with the same name if they are - in different record sets for instancd. - """ - is_field = hasattr(self, 'has_sub_fields') - if is_field: - # Concatenate all names except the dataset name. - return f"{self.parent_uid}/{self.name}" - return self.name - def assert_has_mandatory_properties(self, *mandatory_properties: list[str]): """Checks a node in the graph for existing properties with constraints. 
@@ -95,7 +88,7 @@ def assert_has_mandatory_properties(self, *mandatory_properties: list[str]): f'Property "{constants.FROM_CROISSANT.get(mandatory_property)}" is' " mandatory, but does not exist." ) - self.issues.add_error(error) + self.add_error(error) def assert_has_optional_properties(self, *optional_properties: list[str]): """Checks a node in the graph for existing properties with constraints. @@ -111,7 +104,7 @@ def assert_has_optional_properties(self, *optional_properties: list[str]): f'Property "{constants.FROM_CROISSANT.get(optional_property)}" is' " recommended, but does not exist." ) - self.issues.add_warning(error) + self.add_warning(error) def assert_has_exclusive_properties(self, *exclusive_properties: list[list[str]]): """Checks a node in the graph for existing properties with constraints. @@ -128,7 +121,29 @@ def assert_has_exclusive_properties(self, *exclusive_properties: list[list[str]] "At least one of these properties should be defined:" f" {possible_exclusive_properties}." ) - self.issues.add_error(error) + self.add_error(error) + + def add_error(self, error: str): + """Adds a new error.""" + self.issues.add_error(error, self.context) + + def add_warning(self, warning: str): + """Adds a new warning.""" + self.issues.add_warning(warning, self.context) + + @abc.abstractmethod + def check(self): + raise NotImplementedError + + def __repr__(self): + attributes = self.__dict__.copy() + attributes_to_remove = ["context", "graph", "issues", "node"] + for attribute in attributes_to_remove: + if attribute in attributes: + del attributes[attribute] + attributes_items = sorted(list(attributes.items())) + attributes_str = ", ".join(f"{key}={value}" for key, value in attributes_items) + return f"{self.__class__.__name__}({attributes_str})" def validate_name(issues: Issues, name: str): @@ -144,6 +159,7 @@ def validate_name(issues: Issues, name: str): issues.add_error(f'The identifier "{name}" contains forbidden characters.') return name + def there_exists_at_least_one_property(node: Node, possible_properties: list[str]): """Checks for the existence of one of `possible_exclusive_properties` in `keys`.""" for possible_property in possible_properties: diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/base_node_test.py b/python/ml_croissant/ml_croissant/_src/structure_graph/base_node_test.py index 31caf528..d982c830 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/base_node_test.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/base_node_test.py @@ -1,9 +1,8 @@ """base_node_test module.""" import dataclasses -import json -from etils import epath +from ml_croissant._src.core.issues import Issues from ml_croissant._src.structure_graph import base_node @@ -14,7 +13,20 @@ class Node: property2: str node = Node(property1="property1", property2="property2") - # pylint:disable=protected-access - assert base_node.there_exists_at_least_one_property(node, ["property0", "property1"]) + assert base_node.there_exists_at_least_one_property( + node, ["property0", "property1"] + ) assert not base_node.there_exists_at_least_one_property(node, []) assert not base_node.there_exists_at_least_one_property(node, ["property0"]) + + +def test_repr(): + @dataclasses.dataclass(frozen=True, repr=False) + class MyNode(base_node.Node): + foo: str = "" + + def check(self): + pass + + node = MyNode(issues=Issues(), name="NAME", foo="bar", rdf_id="RDF_IR", uid="UID") + assert str(node) == "MyNode(foo=bar, name=NAME, rdf_id=RDF_IR, uid=UID)" diff --git 
a/python/ml_croissant/ml_croissant/_src/structure_graph/graph.py b/python/ml_croissant/ml_croissant/_src/structure_graph/graph.py index 17da7878..e371bcb3 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/graph.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/graph.py @@ -1,9 +1,28 @@ -"""Structure graph module.""" +"""Structure graph module. -from typing import Any, Mapping +The goal of this module is the static analysis of the JSON file. We convert the initial +JSON to a so-called "structure graph", which is a Python representation of the JSON +containing the nodes (Metadata, FileObject, etc) and the hierarchy between them. In the +process of parsing all the nodes, we also check that no information is missing and raise +issues (errors or warnings) when necessary. See the docstring of +`from_nodes_to_structure_graph` for more information. +The important functions of this module are: +- from_file_to_json file -> JSON +- from_json_to_jsonld JSON -> JSON-LD +- from_jsonld_to_nodes JSON-LD -> nodes +- from_nodes_to_structure_graph nodes -> structure graph +""" + +from collections.abc import Mapping +import dataclasses +import json +from typing import Any + +from etils import epath from ml_croissant._src.core import constants -from ml_croissant._src.core.issues import Issues +from ml_croissant._src.core.issues import Context, Issues +from ml_croissant._src.structure_graph.base_node import Node from ml_croissant._src.structure_graph.nodes import ( Field, FileObject, @@ -15,160 +34,378 @@ from ml_croissant._src.structure_graph.nodes.source import parse_reference import networkx as nx import rdflib +from rdflib import namespace + +Json = dict[str, Any] | list["Json"] + +_EXPECTED_TYPES = [ + constants.SCHEMA_ORG_DATASET, + constants.SCHEMA_ORG_FILE_OBJECT, + constants.SCHEMA_ORG_FILE_SET, + constants.ML_COMMONS_RECORD_SET, + constants.ML_COMMONS_FIELD, + constants.ML_COMMONS_SUB_FIELD, +] -def _extract_properties( - issues: Issues, graph: nx.MultiDiGraph, node: rdflib.term.BNode -) -> Mapping[str, Any]: - """Extracts properties RDF->Python nodes. +def from_file_to_json(filepath: epath.PathLike) -> tuple[epath.Path, Json]: + """Loads the file as a JSON. - Note: we could find a better way to extract information from the RDF graph. + Args: + filepath: The path to the file as a str or a path. The path can be absolute or + relative. + + Returns: + A tuple with the absolute path to the file and the JSON. """ - properties: Mapping[str, str | tuple[str]] = {} - # pylint:disable=invalid-name - for _, _object, _property in graph.edges(node, keys=True): - if isinstance(_object, rdflib.term.BNode): - # `source` needs a special treatment when it is a dict. - if _property == constants.ML_COMMONS_SOURCE: - source = _extract_properties(issues, graph, _object) - properties["source"] = source - if _property == constants.ML_COMMONS_SUB_FIELD: - properties["has_sub_fields"] = True - continue + filepath = epath.Path(filepath).expanduser().resolve() + if not filepath.exists(): + raise ValueError(f"File {filepath} does not exist.") + with filepath.open() as f: + return filepath, json.load(f) - # Normalize values to strings. - if isinstance(_object, rdflib.term.Literal): - _object = str(_object) - # Normalize properties to Croissant values if it exists. - _property = constants.TO_CROISSANT.get(_property, _property) +def from_json_to_jsonld( + data: Json, +) -> tuple[namespace.NamespaceManager, Json]: + """Expands JSON->JSON-LD using RDFLib. - # Add `property` to existing properties. 
- if _property not in properties: - properties[_property] = _object - elif isinstance(properties[_property], tuple): - # Use tuple, because we need immutable types in order - # for the objects to be hashable and used by NetworkX. - properties[_property] = properties[_property] + (_object,) - else: - # In the loop, we just found out that there are several values for the same - # property. `self.properties[property]` should be transformed to a tuple. - properties[_property] = (properties[_property], _object) - - # Normalize `source`. - if (source := properties.get("source")) is not None: - properties["source"] = Source.from_json_ld(issues, source) - # Normalize `references`. - if (references := properties.get("references")) is not None: - properties["references"] = Source.from_json_ld(issues, references) - # Normalize `contained_in`. - if (contained_in := properties.get("contained_in")) is not None: + We use RDFLib instead of reinventing a JSON-LD parser. This may be more cumbersome + short-term, but will prove handy long-term, when we integrate more advanced feature + of RDF/JSON-LD, or other parsers (e.g., YAML-LD). + + Args: + data: The JSON dict. + + Returns: + A tuple with the RDF namespace manager (see: + https://rdflib.readthedocs.io/en/stable/namespaces_and_bindings.html) and + the expanded JSON-LD. + """ + graph = rdflib.Graph() + graph.parse( + data=data, + format="json-ld", + ) + ns = graph.namespace_manager + json_ld = graph.serialize(format="json-ld") + json_ld = json.loads(json_ld) + return ns, json_ld + + +def _get_uid(predecessors: list[Node]) -> str: + """Concatenates the names of all predecessors to get the UID.""" + if not predecessors: + raise ValueError( + "This should not happen, as predecessors of a node also contain the node" + " itself." + ) + node = predecessors[-1] + if isinstance(node, Metadata): + return node.name + names: list[str] = [] + for predecessor in predecessors: + predecessor_name = predecessor.name + if not isinstance(predecessor, Metadata): + names.append(predecessor_name) + return "/".join(names) + + +def _get_predecessors( + nodes: list[Node], node: Node, parents: Mapping[str, str] +) -> list[Node]: + """Lists predecessors in the Croissant hierarchy. + + For a field for example, the predecessors are: metadata > record set > field. 
+ """ + node_id = node.rdf_id + if node_id not in parents: + return [node] + parent_id = parents[node_id] + parents = [_node for _node in nodes if _node.rdf_id == parent_id] + if not parents: + raise ValueError(f"Node {node} has no parent {parent_id}") + parent = parents[0] + predecessors_of_parent = _get_predecessors(nodes, parent, parents) + return predecessors_of_parent + [node] + + +def _get_context(predecessors: list[Node]) -> Context: + """Forms the context from the predecessors.""" + params = {} + for predecessor in predecessors: + if isinstance(predecessor, Metadata): + params["dataset_name"] = predecessor.name + elif isinstance(predecessor, (FileObject, FileSet)): + params["distribution_name"] = predecessor.name + elif isinstance(predecessor, RecordSet): + params["record_set_name"] = predecessor.name + elif isinstance(predecessor, Field): + params["field_name"] = predecessor.name + return Context(**params) + + +def _get_type(node: Json) -> str | None: + node_type = node.get("@type") + if not (isinstance(node_type, list) and node_type): + return None + return rdflib.term.URIRef(node_type[0]) + + +def _get_value(issues: Issues, json_ld: Json, value: Any): + """Helper for _parse_node_params.""" + values = [] + for element in value: + if "@value" in element: + values.append(element["@value"]) + elif "@id" in element: + # In that case, we reference another node, so we have to parse its params: + other_id = element["@id"] + other_node = next( + _node for _node in json_ld if _node.get("@id") == other_id + ) + values.append(_parse_node_params(issues, json_ld, other_node)) + # TODO(marcenacp): integrate the target type in TO_CROISSANT. + if len(values) == 1: + return values[0] + return tuple(values) + + +def _parse_node_params(issues: Issues, json_ld: Json, node: Json) -> Json: + """Recursively parses all information from a node to Croissant.""" + node_params = {} + node_type = _get_type(node) + if node_type == constants.ML_COMMONS_FIELD: + node_params["has_sub_fields"] = str(constants.ML_COMMONS_SUB_FIELD) in node + # Parse values. + for key, value in node.items(): + key = rdflib.term.URIRef(key) + if key in constants.TO_CROISSANT: + croissant_key = constants.TO_CROISSANT[key] + node_params[croissant_key] = _get_value(issues, json_ld, value) + # Parse `source`. + if (source := node_params.get("source")) is not None: + node_params["source"] = Source.from_json_ld(issues, source) + # Parse `references`. + if (references := node_params.get("references")) is not None: + node_params["references"] = Source.from_json_ld(issues, references) + # Parse `contained_in`. + if (contained_in := node_params.get("contained_in")) is not None: if isinstance(contained_in, str): - properties["contained_in"] = parse_reference(issues, contained_in) + node_params["contained_in"] = parse_reference(issues, contained_in) else: - properties["contained_in"] = ( + node_params["contained_in"] = ( parse_reference(issues, reference)[0] for reference in contained_in ) - return properties + return node_params + + +def from_jsonld_to_nodes( + issues: Issues, json_ld: Json +) -> tuple[list[Node], Mapping[str, str]]: + """Converts JSON-LD to a list of Python-readable nodes. + Args: + issues: The issues to populate in case of problem. + json_ld: The parsed JSON-LD with expanded properties. -def from_rdf_graph( + Returns: + A tuple with the nodes and the parents (a dictionary: rdf_id -> parent_rdf_id). 
+ """ + nodes: list[Node] = [] + parents: Mapping[str, str] = {} + for node in json_ld: + child_node_ids = [] + node_id = node.get("@id") + for possible_child in [ + constants.SCHEMA_ORG_DISTRIBUTION, + constants.ML_COMMONS_RECORD_SET, + constants.ML_COMMONS_FIELD, + constants.ML_COMMONS_SUB_FIELD, + constants.ML_COMMONS_SOURCE, + ]: + possible_child = str(possible_child) + if possible_child in node: + for id in node[possible_child]: + child_node_ids.append(id.get("@id")) + for child_node_id in child_node_ids: + parents[child_node_id] = node_id + for node in json_ld: + node_type = _get_type(node) + if node_type is None: + continue + if node_type == constants.SCHEMA_ORG_DATASET: + node_cls = Metadata + elif node_type == constants.SCHEMA_ORG_FILE_OBJECT: + node_cls = FileObject + elif node_type == constants.SCHEMA_ORG_FILE_SET: + node_cls = FileSet + elif node_type == constants.ML_COMMONS_FIELD: + node_cls = Field + elif node_type == constants.ML_COMMONS_RECORD_SET: + node_cls = RecordSet + else: + issues.add_error( + f'Node should have an attribute `"@type" in "{_EXPECTED_TYPES}"`. Got' + f' "{node_type}".' + ) + continue + + node_id = node.get("@id") + node_params = _parse_node_params(issues, json_ld, node) + try: + new_node = node_cls(issues=issues, rdf_id=node_id, **node_params) + nodes.append(new_node) + except TypeError: + # TODO(marcenacp): handle the exception with dataclasses.dataclass. + continue + # Recreate the nodes with the whole hierarchy. + nodes_with_parents: list[Node] = [] + for node in nodes: + predecessors = _get_predecessors(nodes, node, parents) + context = _get_context(predecessors) + node_with_parents = dataclasses.replace( + node, uid=_get_uid(predecessors), context=context + ) + # Static analysis of the node: + node_with_parents.check() + nodes_with_parents.append(node_with_parents) + return nodes_with_parents, parents + + +def get_entry_nodes(issues: Issues, graph: nx.MultiDiGraph) -> list[Node]: + """Retrieves the entry nodes (without predecessors) in a graph.""" + entry_nodes = [] + for node, indegree in graph.in_degree(graph.nodes()): + if indegree == 0: + entry_nodes.append(node) + # Fields should usually not be entry nodes, except if they have subFields. So we + # check for this: + for node in entry_nodes: + if isinstance(node, Field) and not node.has_sub_fields: + issues.add_error( + f'Node "{node.uid}" is a field and has no source. Please, use' + f" {constants.ML_COMMONS_SOURCE} to specify the source." 
+ ) + return entry_nodes + + +def _check_no_duplicate(nodes: list[Node]) -> Mapping[str, Node]: + """Checks that no node has duplicated UID and returns the mapping `uid`->`Node`.""" + uid_to_node: Mapping[str, Node] = {} + for node in nodes: + if node.uid in uid_to_node: + node.add_error( + f"Duplicate nodes with the same identifier: {uid_to_node[node.uid]}" + ) + uid_to_node[node.uid] = node + return uid_to_node + + +def _add_node_as_entry_node(issues: Issues, graph: nx.MultiDiGraph, node: Node): + """Adds `node` as the entry node of the graph by updating `graph` in place.""" + graph.add_node(node, parent=None) + entry_nodes = get_entry_nodes(issues, graph) + for entry_node in entry_nodes: + if isinstance(entry_node, (FileObject, FileSet)): + graph.add_edge(node, entry_node) + + +def _add_edge( issues: Issues, graph: nx.MultiDiGraph, - node: rdflib.term.BNode, - parent_uid: str, + uid_to_node: Mapping[str, Node], + uid: str, + node: Node, + expected_types: type | tuple[type], ): - """Builds a Node from the provided graph.""" - properties = _extract_properties(issues, graph, node) - name = properties.get("name") - - # Check @type. - rdf_type = properties.get(constants.RDF_TYPE) - expected_types = [ - constants.SCHEMA_ORG_DATASET, - constants.SCHEMA_ORG_FILE_OBJECT, - constants.SCHEMA_ORG_FILE_SET, - constants.ML_COMMONS_RECORD_SET, - constants.ML_COMMONS_FIELD, - constants.ML_COMMONS_SUB_FIELD, - ] - if rdf_type not in expected_types: + """Adds an edge in the structure graph.""" + if uid not in uid_to_node: issues.add_error( - f'Node should have an attribute `"@type" in "{expected_types}"`.' + f'There is a reference to node named "{uid}" in node "{node.uid}", but this' + " node doesn't exist." ) - - # Return proper node in each case. - args = [issues, graph, node, name, parent_uid] - if rdf_type == constants.SCHEMA_ORG_DATASET: - with issues.context(dataset_name=name): - return Metadata( - *args, - citation=properties.get("citation"), - description=properties.get("description"), - license=properties.get("license"), - url=properties.get("url"), - ) - elif rdf_type == constants.SCHEMA_ORG_FILE_OBJECT: - with issues.context(distribution_name=name): - return FileObject( - *args, - contained_in=properties.get("contained_in"), - content_url=properties.get("content_url"), - description=properties.get("description"), - encoding_format=properties.get("encoding_format"), - md5=properties.get("md5"), - sha256=properties.get("sha256"), - ) - elif rdf_type == constants.SCHEMA_ORG_FILE_SET: - with issues.context(distribution_name=name): - return FileSet( - *args, - contained_in=properties.get("contained_in"), - description=properties.get("description"), - includes=properties.get("includes"), - encoding_format=properties.get("encoding_format"), - ) - elif rdf_type == constants.ML_COMMONS_RECORD_SET: - with issues.context(record_set_name=name): - return RecordSet( - *args, - data=properties.get("data"), - description=properties.get("description"), - key=properties.get("key"), - ) - elif rdf_type == constants.ML_COMMONS_FIELD: - with issues.context(field_name=name): - return Field( - *args, - data_type=properties.get("data_type"), - description=properties.get("description"), - has_sub_fields=properties.get("has_sub_fields"), - references=properties.get("references"), - source=properties.get("source"), - ) - raise ValueError(f"Wrong RDF type: {rdf_type}.") - - -def children_nodes(node: Any, expected_property: str) -> list[Any]: - """Finds all children objects/nodes.""" - nodes = [] - # pylint:disable=invalid-name
- for _, _object, _property in node._edges_from_node: - if isinstance(_object, rdflib.term.BNode) and expected_property == _property: - nodes.append( - from_rdf_graph( - issues=node.issues, - graph=node.graph, - node=_object, - parent_uid=node.uid, - ) - ) - if not nodes and expected_property in [ - constants.ML_COMMONS_RECORD_SET, - constants.SCHEMA_ORG_DISTRIBUTION, - ]: - node.issues.add_warning( - "The current dataset doesn't declare any node of type:" - f' "{expected_property}"' + return + if not isinstance(uid_to_node[uid], expected_types): + issues.add_error( + f'There is a reference to node named "{uid}" in node "{node.uid}", but this' + f" node doesn't have the expected type: {expected_types}." ) - return nodes + return + graph.add_edge(uid_to_node[uid], node) + + +def _concatenate_uid(source: tuple[str]) -> str: + return "/".join(source) + + +def from_nodes_to_structure_graph( + issues: Issues, nodes: list[Node], parents: Mapping[str, str] +) -> tuple[Metadata | None, nx.MultiDiGraph]: + """Converts the list of nodes to a structure graph. + + In the structure graph: + - Nodes are Metadata, FileObjects, FileSets and Fields. + - Nodes must have a parent property, which is their direct parent in the Croissant + JSON. + - Nodes can have predecessors, which are the sources their data comes from (for + a field: the source of the data or a join, for instance). + + Args: + issues: The issues to populate in case of problem. + nodes: The list of Python nodes. + parents: The mapping from a node's rdf_id to its parent's rdf_id. + + Returns: + A tuple with the metadata node and the structure graph with the proper hierarchy. + """ + graph = nx.MultiDiGraph() + uid_to_node = _check_no_duplicate(nodes) + metadata = None + for node in nodes: + # Metadata + if isinstance(node, Metadata): + metadata = node + continue + parent_id = parents[node.rdf_id] + parent = next(_node for _node in nodes if _node.rdf_id == parent_id) + graph.add_node(node, parent=parent) + # Distribution + if isinstance(node, (FileObject, FileSet)) and node.contained_in: + for uid in node.contained_in: + _add_edge(issues, graph, uid_to_node, uid, node, (FileObject, FileSet)) + # Fields + elif isinstance(node, Field): + references = [] + if node.source: + references.append(node.source.reference) + if node.references: + references.append(node.references.reference) + for reference in references: + # The source can be either another field... + if (uid := _concatenate_uid(reference)) in uid_to_node: + # Record sets are not valid parents here. + # The case can arise when a Field references a record set to have a + # machine-readable explanation of the field (see datasets/titanic + # for example). + if not isinstance(uid_to_node[uid], RecordSet): + _add_edge(issues, graph, uid_to_node, uid, node, Node) + # ...or the source can be a file object/file set. + elif reference and (uid := reference[0]) in uid_to_node: + if not isinstance(uid_to_node[uid], RecordSet): + _add_edge( + issues, graph, uid_to_node, uid, node, (FileObject, FileSet) + ) + else: + issues.add_error( + "Source refers to an unknown node" + f' "{_concatenate_uid(reference)}".' + ) + # `Metadata` is used as the entry node.
+ if metadata is None: + issues.add_error("No metadata is defined in the dataset.") + return None, graph + _add_node_as_entry_node(issues, graph, metadata) + if not graph.is_directed(): + issues.add_error("Structure graph is not directed.") + return metadata, graph diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/graph_test.py b/python/ml_croissant/ml_croissant/_src/structure_graph/graph_test.py new file mode 100644 index 00000000..8e57b3c9 --- /dev/null +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/graph_test.py @@ -0,0 +1,198 @@ +"""graph_test module.""" + +from ml_croissant._src.core.issues import Context, Issues +from ml_croissant._src.structure_graph.graph import ( + from_json_to_jsonld, + from_jsonld_to_nodes, + Json, +) +from ml_croissant._src.structure_graph.nodes import ( + FileObject, + Metadata, +) + + +def _remove_keys(json: Json, keys: list[str]) -> Json: + """Edits a dict in place by removing the `keys` recursively.""" + if isinstance(json, dict): + for key, value in json.copy().items(): + if key in keys: + del json[key] + elif isinstance(value, (dict, list)): + json[key] = _remove_keys(json[key], keys) + elif isinstance(json, list): + return [_remove_keys(element, keys) for element in json] + return json + + +def assert_jsons_are_equal(json1: Json, json2: Json, ignore_keys: list[str] = []): + """Asserts two JSONs are equal by ignoring some keys.""" + _remove_keys(json1, ignore_keys) + _remove_keys(json2, ignore_keys) + assert json1 == json2 + + +def test_remove_keys(): + json = {"foo": "foo", "bar": [{"foo": "foo", "bar": "bar"}]} + _remove_keys(json, ["foo"]) + assert json == {"bar": [{"bar": "bar"}]} + + +def test_from_json_to_jsonld(): + json = { + "@context": { + "@vocab": "https://schema.org/", + "sc": "https://schema.org/", + "ml": "http://mlcommons.org/schema/", + "includes": "ml:includes", + "recordSet": "ml:RecordSet", + "field": "ml:Field", + "subField": "ml:SubField", + "dataType": "ml:dataType", + "source": "ml:source", + "data": "ml:data", + "applyTransform": "ml:applyTransform", + "format": "ml:format", + "regex": "ml:regex", + "separator": "ml:separator", + }, + "@type": "sc:Dataset", + "@language": "en", + "name": "mydataset", + "url": "https://www.google.com/dataset", + "description": "This is a description.", + "license": "This is a license.", + "citation": "This is a citation.", + "distribution": [ + { + "name": "a-csv-table", + "@type": "sc:FileObject", + "contentUrl": "ratings.csv", + "encodingFormat": "text/csv", + "sha256": "xxx", + } + ], + "recordSet": [ + { + "name": "annotations", + "@type": "ml:RecordSet", + "field": [ + { + "name": "bbox", + "@type": "ml:Field", + "description": "The bounding box around annotated object[s].", + "dataType": "ml:BoundingBox", + "source": { + "data": "#{a-csv-table/annotations}", + "format": "XYWH", + }, + }, + ], + }, + ], + } + expected_json_ld = [ + { + "@type": ["https://schema.org/Dataset"], + "http://mlcommons.org/schema/RecordSet": [{}], + "https://schema.org/@language": [{"@value": "en"}], + "https://schema.org/citation": [{"@value": "This is a citation."}], + "https://schema.org/description": [{"@value": "This is a description."}], + "https://schema.org/distribution": [{}], + "https://schema.org/license": [{"@value": "This is a license."}], + "https://schema.org/name": [{"@value": "mydataset"}], + "https://schema.org/url": [{"@value": "https://www.google.com/dataset"}], + }, + { + "@type": ["https://schema.org/FileObject"], + "https://schema.org/contentUrl": [{"@value": 
"ratings.csv"}], + "https://schema.org/encodingFormat": [{"@value": "text/csv"}], + "https://schema.org/name": [{"@value": "a-csv-table"}], + "https://schema.org/sha256": [{"@value": "xxx"}], + }, + { + "@type": ["http://mlcommons.org/schema/RecordSet"], + "http://mlcommons.org/schema/Field": [{}], + "https://schema.org/name": [{"@value": "annotations"}], + }, + { + "@type": ["http://mlcommons.org/schema/Field"], + "http://mlcommons.org/schema/dataType": [{"@value": "ml:BoundingBox"}], + "http://mlcommons.org/schema/source": [{}], + "https://schema.org/description": [ + {"@value": "The bounding box around annotated object[s]."} + ], + "https://schema.org/name": [{"@value": "bbox"}], + }, + { + "http://mlcommons.org/schema/data": [ + {"@value": "#{a-csv-table/annotations}"} + ], + "http://mlcommons.org/schema/format": [{"@value": "XYWH"}], + }, + ] + _, json_ld = from_json_to_jsonld(json) + # We ignore `@id`, because they can change. + assert_jsons_are_equal(expected_json_ld, json_ld, ignore_keys=["@id"]) + + +def test_from_jsonld_to_nodes(): + issues = Issues() + json_ld = [ + { + "@id": "ID_DATASET", + "@type": ["https://schema.org/Dataset"], + "https://schema.org/@language": [{"@value": "en"}], + "https://schema.org/citation": [{"@value": "This is a citation."}], + "https://schema.org/description": [{"@value": "This is a description."}], + "https://schema.org/license": [{"@value": "This is a license."}], + "https://schema.org/name": [{"@value": "mydataset"}], + "https://schema.org/url": [{"@value": "https://www.google.com/dataset"}], + "https://schema.org/distribution": [{"@id": "ID_FILE_OBJECT"}], + }, + { + "@id": "ID_FILE_OBJECT", + "@type": ["https://schema.org/FileObject"], + "https://schema.org/name": [{"@value": "a-csv-table"}], + "https://schema.org/contentUrl": [{"@value": "ratings.csv"}], + "https://schema.org/encodingFormat": [{"@value": "text/csv"}], + "https://schema.org/sha256": [{"@value": "xxx"}], + }, + ] + expected_nodes = [ + Metadata( + issues=issues, + name="mydataset", + rdf_id="ID_DATASET", + uid="mydataset", + citation="This is a citation.", + description="This is a description.", + license="This is a license.", + url="https://www.google.com/dataset", + context=Context( + dataset_name="mydataset", + distribution_name=None, + record_set_name=None, + field_name=None, + sub_field_name=None, + ), + ), + FileObject( + issues=issues, + name="a-csv-table", + rdf_id="ID_FILE_OBJECT", + uid="a-csv-table", + content_url="ratings.csv", + encoding_format="text/csv", + sha256="xxx", + context=Context( + dataset_name="mydataset", + distribution_name="a-csv-table", + record_set_name=None, + field_name=None, + sub_field_name=None, + ), + ), + ] + nodes, _ = from_jsonld_to_nodes(issues, json_ld) + assert nodes == expected_nodes diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/field.py b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/field.py index 58a08590..9a0c5b94 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/field.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/field.py @@ -7,7 +7,7 @@ from ml_croissant._src.structure_graph.base_node import Node from ml_croissant._src.structure_graph.nodes.source import Source -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, repr=False) class Field(Node): """Nodes to describe a dataset Field.""" @@ -19,7 +19,7 @@ class Field(Node): references: Source = dataclasses.field(default_factory=Source) source: Source = 
dataclasses.field(default_factory=Source) - def __post_init__(self): + def check(self): self.assert_has_mandatory_properties("data_type", "name") self.assert_has_optional_properties("description") # TODO(marcenacp): check that `data` has the expected form if it exists. diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_object.py b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_object.py index cf5100dc..db8ca912 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_object.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_object.py @@ -6,11 +6,12 @@ from ml_croissant._src.structure_graph.nodes.source import Source -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, repr=False) class FileObject(Node): """Nodes to describe a dataset FileObject (distribution).""" content_url: str = "" + content_size: str = "" contained_in: tuple[str] = () description: str | None = None encoding_format: str = "" @@ -19,7 +20,7 @@ class FileObject(Node): sha256: str | None = None source: Source | None = None - def __post_init__(self): + def check(self): self.assert_has_mandatory_properties("content_url", "encoding_format", "name") if not self.contained_in: self.assert_has_exclusive_properties(["md5", "sha256"]) diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_set.py b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_set.py index a0ca6ebb..e73ca255 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_set.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/file_set.py @@ -5,7 +5,7 @@ from ml_croissant._src.structure_graph.base_node import Node -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, repr=False) class FileSet(Node): """Nodes to describe a dataset FileSet (distribution).""" @@ -15,5 +15,5 @@ class FileSet(Node): includes: str = "" name: str = "" - def __post_init__(self): + def check(self): self.assert_has_mandatory_properties("includes", "encoding_format", "name") diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/metadata.py b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/metadata.py index d4ea3bd3..632296f8 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/metadata.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/metadata.py @@ -5,7 +5,7 @@ from ml_croissant._src.structure_graph.base_node import Node -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, repr=False) class Metadata(Node): """Nodes to describe a dataset metadata.""" @@ -15,6 +15,6 @@ class Metadata(Node): name: str = "" url: str = "" - def __post_init__(self): + def check(self): self.assert_has_mandatory_properties("name", "url") self.assert_has_optional_properties("citation", "license") diff --git a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/record_set.py b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/record_set.py index fd6267a3..e143fa49 100644 --- a/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/record_set.py +++ b/python/ml_croissant/ml_croissant/_src/structure_graph/nodes/record_set.py @@ -5,9 +5,10 @@ from typing import Any from ml_croissant._src.structure_graph.base_node import Node +from ml_croissant._src.structure_graph.nodes.source import Source -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, repr=False) class RecordSet(Node): """Nodes to 
describe a dataset RecordSet.""" @@ -17,7 +18,8 @@ class RecordSet(Node): description: str | None = None key: str | None = None name: str = "" + source: Source | None = None - def __post_init__(self): + def check(self): self.assert_has_mandatory_properties("name") self.assert_has_optional_properties("description") diff --git a/python/ml_croissant/ml_croissant/_src/tests/nodes.py b/python/ml_croissant/ml_croissant/_src/tests/nodes.py index 88dbe70b..b7564d96 100644 --- a/python/ml_croissant/ml_croissant/_src/tests/nodes.py +++ b/python/ml_croissant/ml_croissant/_src/tests/nodes.py @@ -3,10 +3,16 @@ from ml_croissant._src.core.issues import Issues from ml_croissant._src.structure_graph.base_node import Node -empty_node = Node( + +class _EmptyNode(Node): + def check(self): + pass + + +empty_node = _EmptyNode( issues=Issues(), graph=None, node=None, name="node_name", - parent_uid="parent_name", + uid="node_name", ) diff --git a/python/ml_croissant/scripts/validate.py b/python/ml_croissant/scripts/validate.py index 232641ff..d8fa67f3 100644 --- a/python/ml_croissant/scripts/validate.py +++ b/python/ml_croissant/scripts/validate.py @@ -13,6 +13,12 @@ "Path to the file to validate.", ) +flags.DEFINE_bool( + "debug", + False, + "Whether to print debug hints.", +) + flags.mark_flag_as_required("file") @@ -22,8 +28,9 @@ def main(argv): del argv file = FLAGS.file + debug = FLAGS.debug try: - Dataset(file) + Dataset(file, debug=debug) logging.info("Done.") except ValidationError as exception: logging.error(exception)