diff --git a/spacy/errors.py b/spacy/errors.py index 5fe550145cd..4e3b2022a57 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -887,6 +887,7 @@ class Errors(metaclass=ErrorsWithCodes): E1021 = ("`pos` value \"{pp}\" is not a valid Universal Dependencies tag. " "Non-UD tags should use the `tag` property.") E1022 = ("Words must be of type str or int, but input is of type '{wtype}'") + E1023 = ("Couldn't read EntityRuler from the {path}. This file doesn't exist.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 2c3db257530..78d7a0be27b 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -431,10 +431,16 @@ def from_disk( path = ensure_path(path) self.clear() depr_patterns_path = path.with_suffix(".jsonl") - if depr_patterns_path.is_file(): + if path.suffix == ".jsonl": # user provides a jsonl + if path.is_file: + patterns = srsly.read_jsonl(path) + self.add_patterns(patterns) + else: + raise ValueError(Errors.E1023.format(path=path)) + elif depr_patterns_path.is_file(): patterns = srsly.read_jsonl(depr_patterns_path) self.add_patterns(patterns) - else: + elif path.is_dir(): # path is a valid directory cfg = {} deserializers_patterns = { "patterns": lambda p: self.add_patterns( @@ -451,6 +457,8 @@ def from_disk( self.nlp.vocab, attr=self.phrase_matcher_attr ) from_disk(path, deserializers_patterns, {}) + else: # path is not a valid directory or file + raise ValueError(Errors.E146.format(path=path)) return self def to_disk( diff --git a/spacy/tests/pipeline/test_entity_ruler.py b/spacy/tests/pipeline/test_entity_ruler.py index dc0ca030138..e66b49518da 100644 --- a/spacy/tests/pipeline/test_entity_ruler.py +++ b/spacy/tests/pipeline/test_entity_ruler.py @@ -5,6 +5,8 @@ from spacy.language import Language from spacy.pipeline import EntityRuler from spacy.errors import MatchPatternError +from spacy.tests.util import make_tempdir + from thinc.api import NumpyOps, get_current_ops @@ -238,3 +240,23 @@ def test_entity_ruler_multiprocessing(nlp, n_process): for doc in nlp.pipe(texts, n_process=2): for ent in doc.ents: assert ent.ent_id_ == "1234" + + +def test_entity_ruler_serialize_jsonl(nlp, patterns): + ruler = nlp.add_pipe("entity_ruler") + ruler.add_patterns(patterns) + with make_tempdir() as d: + ruler.to_disk(d / "test_ruler.jsonl") + ruler.from_disk(d / "test_ruler.jsonl") # read from an existing jsonl file + with pytest.raises(ValueError): + ruler.from_disk(d / "non_existing.jsonl") # read from a bad jsonl file + + +def test_entity_ruler_serialize_dir(nlp, patterns): + ruler = nlp.add_pipe("entity_ruler") + ruler.add_patterns(patterns) + with make_tempdir() as d: + ruler.to_disk(d / "test_ruler") + ruler.from_disk(d / "test_ruler") # read from an existing directory + with pytest.raises(ValueError): + ruler.from_disk(d / "non_existing_dir") # read from a bad directory