-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
util.py
1836 lines (1548 loc) · 65.9 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import functools
import importlib
import importlib.util
import inspect
import itertools
import logging
import os
import pkgutil
import re
import shlex
import shutil
import socket
import stat
import subprocess
import sys
import tempfile
import warnings
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
NoReturn,
Optional,
Pattern,
Set,
Tuple,
Type,
Union,
cast,
)
import catalogue
import langcodes
import numpy
import srsly
import thinc
from catalogue import Registry, RegistryError
from packaging.requirements import Requirement
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import InvalidVersion, Version
from thinc.api import (
Adam,
Config,
ConfigValidationError,
Model,
NumpyOps,
Optimizer,
get_current_ops,
)
try:
import cupy.random
except ImportError:
cupy = None
# These are functions that were previously (v2.x) available from spacy.util
# and have since moved to Thinc. We're importing them here so people's code
# doesn't break, but they should always be imported from Thinc from now on,
# not from spacy.util.
from thinc.api import compounding, decaying, fix_random_seed # noqa: F401
from . import about
from .compat import CudaStream, cupy, importlib_metadata, is_windows
from .errors import OLD_MODEL_SHORTCUTS, Errors, Warnings
from .symbols import ORTH
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from .language import Language, PipeCallable # noqa: F401
from .tokens import Doc, Span # noqa: F401
from .vocab import Vocab # noqa: F401
# fmt: off
# Sentinel rank: the largest value representable as numpy.uint64.
# NOTE(review): judging by the name it marks out-of-vocabulary entries; the
# consuming code is outside this section — confirm.
OOV_RANK = numpy.iinfo(numpy.uint64).max
# Default probability for out-of-vocabulary entries.
# NOTE(review): presumably a log probability given the negative value — confirm.
DEFAULT_OOV_PROB = -20
# Language codes that get lexeme-norm handling.
# NOTE(review): inferred from the name only; verify against the callers.
LEXEME_NORM_LANGS = ["cs", "da", "de", "el", "en", "grc", "id", "lb", "mk", "pt", "ru", "sr", "ta", "th"]
# Default order of sections in the config file. Not all sections need to exist,
# and additional sections are added at the end, in alphabetical order.
CONFIG_SECTION_ORDER = ["paths", "variables", "system", "nlp", "components", "corpora", "training", "pretraining", "initialize"]
# fmt: on
# Library-wide logger named "spacy". A StreamHandler (which writes to
# sys.stderr by default) is attached with a "[time] [level] message" format.
logger = logging.getLogger("spacy")
logger_stream_handler = logging.StreamHandler()
logger_stream_handler.setFormatter(
    logging.Formatter("[%(asctime)s] [%(levelname)s] %(message)s")
)
logger.addHandler(logger_stream_handler)
class ENV_VARS:
    """Namespace holding the names of environment variables spaCy looks at."""

    # Name of the variable carrying config overrides (its reader is not in
    # this part of the file).
    CONFIG_OVERRIDES = "SPACY_CONFIG_OVERRIDES"
class registry(thinc.registry):
    """spaCy's function registries.

    Extends Thinc's registry with spaCy-specific catalogue registries, and
    overrides ``get``, ``find`` and ``has`` to produce more specific error
    messages and to fall back to functions registered under the
    ``spacy-legacy.`` prefix.
    """

    languages = catalogue.create("spacy", "languages", entry_points=True)
    architectures = catalogue.create("spacy", "architectures", entry_points=True)
    tokenizers = catalogue.create("spacy", "tokenizers", entry_points=True)
    lemmatizers = catalogue.create("spacy", "lemmatizers", entry_points=True)
    lookups = catalogue.create("spacy", "lookups", entry_points=True)
    displacy_colors = catalogue.create("spacy", "displacy_colors", entry_points=True)
    misc = catalogue.create("spacy", "misc", entry_points=True)
    # Callback functions used to manipulate nlp object etc.
    callbacks = catalogue.create("spacy", "callbacks", entry_points=True)
    batchers = catalogue.create("spacy", "batchers", entry_points=True)
    readers = catalogue.create("spacy", "readers", entry_points=True)
    augmenters = catalogue.create("spacy", "augmenters", entry_points=True)
    loggers = catalogue.create("spacy", "loggers", entry_points=True)
    scorers = catalogue.create("spacy", "scorers", entry_points=True)
    vectors = catalogue.create("spacy", "vectors", entry_points=True)
    # These are factories registered via third-party packages and the
    # spacy_factories entry point. This registry only exists so we can easily
    # load them via the entry points. The "true" factories are added via the
    # Language.factory decorator (in the spaCy code base and user code) and those
    # are the factories used to initialize components via registry.resolve.
    _entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
    factories = catalogue.create("spacy", "internal_factories")
    # This is mostly used to get a list of all installed models in the current
    # environment. spaCy models packaged with `spacy package` will "advertise"
    # themselves via entry points.
    models = catalogue.create("spacy", "models", entry_points=True)
    cli = catalogue.create("spacy", "cli", entry_points=True)

    @classmethod
    def get_registry_names(cls) -> List[str]:
        """List all available registries.

        RETURNS (List[str]): Sorted names of the public attributes that are
            catalogue registries.
        """
        return sorted(
            name
            for name, value in inspect.getmembers(cls)
            if not name.startswith("_") and isinstance(value, Registry)
        )

    @staticmethod
    def _spacy_legacy_name(func_name: str) -> Optional[str]:
        """Map a "spacy."-prefixed function name to its "spacy-legacy." name.

        Only the leading prefix is rewritten, so "spacy.foo_spacy.v1" becomes
        "spacy-legacy.foo_spacy.v1" rather than having every inner "spacy."
        replaced as well.

        func_name (str): The registered function name.
        RETURNS (str / None): The legacy name, or None if func_name doesn't
            start with "spacy.".
        """
        prefix = "spacy."
        if not func_name.startswith(prefix):
            return None
        return "spacy-legacy." + func_name[len(prefix):]

    @classmethod
    def _spacy_get_registry(cls, registry_name: str) -> Registry:
        """Return the catalogue registry stored under registry_name.

        registry_name (str): Name of the catalogue registry.
        RETURNS (Registry): The registry.
        RAISES (RegistryError): If no registry exists under that name. Other
            attributes (e.g. the method "get") are not accepted as registries.
        """
        reg = getattr(cls, registry_name, None)
        if not isinstance(reg, Registry):
            names = ", ".join(cls.get_registry_names()) or "none"
            raise RegistryError(Errors.E892.format(name=registry_name, available=names))
        return reg

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> Callable:
        """Get a registered function from the registry.

        registry_name (str): Name of the catalogue registry.
        func_name (str): Name of the registered function. "spacy." names fall
            back to their "spacy-legacy." equivalent if not found.
        RETURNS (Callable): The registered function.
        RAISES (RegistryError): If the registry or the function doesn't exist.
        """
        # We're overwriting this classmethod so we're able to provide more
        # specific error messages and implement a fallback to spacy-legacy.
        reg = cls._spacy_get_registry(registry_name)
        try:
            return reg.get(func_name)
        except RegistryError:
            legacy_name = cls._spacy_legacy_name(func_name)
            if legacy_name is not None:
                try:
                    return reg.get(legacy_name)
                except RegistryError:
                    pass
            available = ", ".join(sorted(reg.get_all().keys())) or "none"
            raise RegistryError(
                Errors.E893.format(
                    name=func_name, reg_name=registry_name, available=available
                )
            ) from None

    @classmethod
    def find(
        cls, registry_name: str, func_name: str
    ) -> Dict[str, Optional[Union[str, int]]]:
        """Find information about a registered function, including the
        module and path to the file it's defined in, the line number and the
        docstring, if available.

        registry_name (str): Name of the catalogue registry.
        func_name (str): Name of the registered function. "spacy." names fall
            back to their "spacy-legacy." equivalent if not found.
        RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
        RAISES (RegistryError): If the registry or the function doesn't exist.
        """
        # We're overwriting this classmethod so we're able to provide more
        # specific error messages and implement a fallback to spacy-legacy.
        reg = cls._spacy_get_registry(registry_name)
        try:
            return reg.find(func_name)
        except RegistryError:
            legacy_name = cls._spacy_legacy_name(func_name)
            if legacy_name is not None:
                try:
                    return reg.find(legacy_name)
                except RegistryError:
                    pass
            available = ", ".join(sorted(reg.get_all().keys())) or "none"
            raise RegistryError(
                Errors.E893.format(
                    name=func_name, reg_name=registry_name, available=available
                )
            ) from None

    @classmethod
    def has(cls, registry_name: str, func_name: str) -> bool:
        """Check whether a function is available in a registry.

        registry_name (str): Name of the catalogue registry.
        func_name (str): Name of the function; "spacy." names also match their
            "spacy-legacy." equivalent.
        RETURNS (bool): False if the registry doesn't exist or lacks the name.
        """
        reg = getattr(cls, registry_name, None)
        if not isinstance(reg, Registry):
            return False
        if func_name in reg:
            return True
        legacy_name = cls._spacy_legacy_name(func_name)
        return legacy_name is not None and legacy_name in reg
class SimpleFrozenDict(dict):
    """Simplified implementation of a frozen dict, mainly used as default
    function or method argument (for arguments that should default to empty
    dictionary). Will raise an error if user or spaCy attempts to add to,
    remove from or otherwise modify the dict.
    """

    def __init__(self, *args, error: str = Errors.E095, **kwargs) -> None:
        """Initialize the frozen dict. Can be initialized with pre-defined
        values.

        error (str): The error message when user tries to modify the dict.
        """
        super().__init__(*args, **kwargs)
        self.error = error

    def __setitem__(self, key, value):
        raise NotImplementedError(self.error)

    def __delitem__(self, key):
        raise NotImplementedError(self.error)

    def __ior__(self, other):
        # dict's in-place `|=` does not route through `update`, so it must be
        # blocked on its own.
        raise NotImplementedError(self.error)

    def pop(self, key, default=None):
        raise NotImplementedError(self.error)

    def popitem(self):
        raise NotImplementedError(self.error)

    def setdefault(self, key, default=None):
        raise NotImplementedError(self.error)

    def clear(self):
        raise NotImplementedError(self.error)

    def update(self, *args, **kwargs):
        # Accept every call shape dict.update does (including keyword
        # arguments), so callers always get the custom error, not a TypeError.
        raise NotImplementedError(self.error)
class SimpleFrozenList(list):
    """Wrapper class around a list that lets us raise custom errors if certain
    attributes/methods are accessed. Mostly used for properties like
    Language.pipeline that return an immutable list (and that we don't want to
    convert to a tuple to not break too much backwards compatibility). If a user
    accidentally calls nlp.pipeline.append(), we can raise a more helpful error.
    """

    def __init__(self, *args, error: str = Errors.E927) -> None:
        """Initialize the frozen list.

        error (str): The error message when user tries to mutate the list.
        """
        self.error = error
        super().__init__(*args)

    def append(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def clear(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def extend(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def insert(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def pop(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def remove(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def reverse(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def sort(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    # Item assignment/deletion and in-place operators mutate a list without
    # going through the methods above. Without these overrides, a shared
    # default instance could still be modified silently.
    def __setitem__(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def __delitem__(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def __iadd__(self, *args, **kwargs):
        raise NotImplementedError(self.error)

    def __imul__(self, *args, **kwargs):
        raise NotImplementedError(self.error)
def lang_class_is_loaded(lang: str) -> bool:
    """Report whether a Language class is already registered for `lang`.

    Language classes are loaded lazily to avoid the expensive setup that comes
    with the language data, so a class may not be registered yet.

    lang (str): Two-letter language code, e.g. 'en'.
    RETURNS (bool): True if `lang` is in the languages registry.
    """
    known_languages = registry.languages
    return lang in known_languages
def find_matching_language(lang: str) -> Optional[str]:
    """
    Given an IETF language code, find a supported spaCy language that is a
    close match for it (according to Unicode CLDR language-matching rules).
    This allows for language aliases, ISO 639-2 codes, more detailed language
    tags, and close matches.

    Returns the language code if a matching language is available, or None
    if there is no matching language.

    >>> find_matching_language('en')
    'en'
    >>> find_matching_language('pt-BR')  # Brazilian Portuguese
    'pt'
    >>> find_matching_language('fra')  # an ISO 639-2 code for French
    'fr'
    >>> find_matching_language('iw')  # obsolete alias for Hebrew
    'he'
    >>> find_matching_language('no')  # Norwegian
    'nb'
    >>> find_matching_language('mo')  # old code for ro-MD
    'ro'
    >>> find_matching_language('zh-Hans')  # Simplified Chinese
    'zh'
    >>> find_matching_language('zxx') is None
    True
    """
    import spacy.lang  # noqa: F401

    if lang == "xx":
        return "xx"

    # Collect the codes of the language modules we have. 'xx' is spaCy's own
    # multi-language code, so it is offered to the matcher as 'mul' and
    # translated back afterwards.
    supported: List[str] = []
    for module_info in pkgutil.iter_modules(spacy.lang.__path__):  # type: ignore[attr-defined]
        code = module_info.name
        if code == "xx":
            supported.append("mul")
        elif langcodes.tag_is_valid(code):
            supported.append(code)

    # Distances from 1-9 allow near misses like Bosnian -> Croatian and
    # Norwegian -> Norwegian Bokmål. A distance of 10 would include several
    # more possibilities, like variants of Chinese like 'wuu', but text that
    # is labeled that way is probably trying to be distinct from 'zh' and
    # shouldn't automatically match.
    best = langcodes.closest_supported_match(lang, supported, max_distance=9)
    # Convert 'mul' back to spaCy's 'xx'
    return "xx" if best == "mul" else best
def get_lang_class(lang: str) -> Type["Language"]:
    """Import (if necessary) and return the Language class for a language.

    Registered languages (including entry points) are returned directly.
    Otherwise the class is imported from spacy.lang.<lang>. If that fails, a
    close language match is tried (e.g. 'no' -> 'nb'), and the class found is
    registered via set_lang_class.

    lang (str): IETF language code, such as 'en'.
    RETURNS (Language): Language class.
    RAISES (ImportError): If neither the code nor a close match can be imported.
    """
    if lang in registry.languages:
        return registry.languages.get(lang)
    try:
        module = importlib.import_module(f".lang.{lang}", "spacy")
    except ImportError as err:
        try:
            match = find_matching_language(lang)
        except langcodes.tag_parser.LanguageTagError:
            # An unparsable tag is reported as an ImportError below
            match = None
        if not match:
            raise ImportError(Errors.E048.format(lang=lang, err=err)) from err
        lang = match
        module = importlib.import_module(f".lang.{lang}", "spacy")
    set_lang_class(lang, getattr(module, module.__all__[0]))  # type: ignore[attr-defined]
    return registry.languages.get(lang)
def set_lang_class(name: str, cls: Type["Language"]) -> None:
    """Register a custom Language class so get_lang_class can load it.

    name (str): Name to register the class under.
    cls (Language): The Language class.
    """
    languages = registry.languages
    languages.register(name, func=cls)
def ensure_path(path: Any) -> Any:
    """Turn a string into a Path and pass everything else through.

    path (Any): Anything. Only str values are converted.
    RETURNS: A Path for str input, otherwise the argument unchanged.
    """
    return Path(path) if isinstance(path, str) else path
def load_language_data(path: Union[str, Path]) -> Union[dict, list]:
    """Read JSON language data, falling back to a gzipped copy.

    path (str / Path): Location of the JSON file. If it doesn't exist, the
        same path with ".gz" appended to its suffix is tried.
    RETURNS: The loaded data.
    RAISES (ValueError): If neither file exists (reports the .gz path).
    """
    json_path = ensure_path(path)
    if json_path.exists():
        return srsly.read_json(json_path)
    gz_path = json_path.with_suffix(json_path.suffix + ".gz")
    if not gz_path.exists():
        raise ValueError(Errors.E160.format(path=gz_path))
    return srsly.read_gzip_json(gz_path)
def get_module_path(module: ModuleType) -> Path:
    """Return the directory of the source file where `module` was defined.

    The object's __module__ attribute is looked up in sys.modules, and the
    parent directory of that module's __file__ is returned.

    module (ModuleType): An object with a __module__ attribute.
    RETURNS (Path): The directory containing the defining module.
    RAISES (ValueError): If the object has no __module__ attribute.
    """
    if not hasattr(module, "__module__"):
        raise ValueError(Errors.E169.format(module=repr(module)))
    source_file = sys.modules[module.__module__].__file__
    return Path(cast(os.PathLike, source_file)).parent
# Default value for the enable/disable/exclude parameters below. One shared
# SimpleFrozenList instance serves as the default; its overridden mutator
# methods (append, extend, insert, ...) raise instead of modifying it.
_DEFAULT_EMPTY_PIPES = SimpleFrozenList()
def load_model(
    name: Union[str, Path],
    *,
    vocab: Union["Vocab", bool] = True,
    disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
    """Load a pipeline from an installed package, a data path or a shortcut.

    Strings are tried, in order, as "blank:<lang>", then as an installed
    package name, then as an existing directory. Non-string objects with an
    `exists` attribute are treated as paths.

    name (str / Path): Package name, "blank:<lang>" or model path.
    vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
        a new Vocab object will be created.
    disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable.
    enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All others will be disabled.
    exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude.
    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
        keyed by section values in dot notation.
    RETURNS (Language): The loaded nlp object.
    RAISES (IOError): If the name can't be resolved (E941 for old v2
        shortcuts, E050 otherwise).
    """
    loader_kwargs = dict(
        vocab=vocab,
        disable=disable,
        enable=enable,
        exclude=exclude,
        config=config,
    )
    if not isinstance(name, str):
        if hasattr(name, "exists"):  # Path or Path-like to model data
            return load_model_from_path(name, **loader_kwargs)  # type: ignore[arg-type]
    elif name.startswith("blank:"):  # shortcut for blank model
        return get_lang_class(name.replace("blank:", ""))()
    elif is_package(name):  # installed as package
        return load_model_from_package(name, **loader_kwargs)  # type: ignore[arg-type]
    elif Path(name).exists():  # path to model data directory
        return load_model_from_path(Path(name), **loader_kwargs)  # type: ignore[arg-type]
    if name in OLD_MODEL_SHORTCUTS:
        raise IOError(Errors.E941.format(name=name, full=OLD_MODEL_SHORTCUTS[name]))  # type: ignore[index]
    raise IOError(Errors.E050.format(name=name))
def load_model_from_package(
    name: str,
    *,
    vocab: Union["Vocab", bool] = True,
    disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
    """Import an installed pipeline package and call its `load()` function.

    name (str): The package name.
    vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
        a new Vocab object will be created.
    disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable. Disabled
        pipes will be loaded but they won't be run unless you explicitly
        enable them by calling nlp.enable_pipe.
    enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other
        pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
    exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. Excluded
        components won't be loaded.
    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
        keyed by section values in dot notation.
    RETURNS (Language): The loaded nlp object.
    """
    package_module = importlib.import_module(name)
    load_kwargs = {
        "vocab": vocab,
        "disable": disable,
        "enable": enable,
        "exclude": exclude,
        "config": config,
    }
    return package_module.load(**load_kwargs)  # type: ignore[attr-defined]
def load_model_from_path(
    model_path: Path,
    *,
    meta: Optional[Dict[str, Any]] = None,
    vocab: Union["Vocab", bool] = True,
    disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
    """Build a pipeline from <model_path>/config.cfg, then load its data.

    The config is read with `config` applied as overrides, an nlp object is
    created via load_model_from_config, and from_disk() fills it from
    model_path.

    model_path (Path): Model path.
    meta (Dict[str, Any]): Optional model meta. Read from model_path if empty.
    vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
        a new Vocab object will be created.
    disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable. Disabled
        pipes will be loaded but they won't be run unless you explicitly
        enable them by calling nlp.enable_pipe.
    enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other
        pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
    exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. Excluded
        components won't be loaded.
    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
        keyed by section values in dot notation.
    RETURNS (Language): The loaded nlp object.
    RAISES (IOError): If model_path doesn't exist.
    """
    if not model_path.exists():
        raise IOError(Errors.E052.format(path=model_path))
    meta = meta or get_model_meta(model_path)
    overrides = dict_to_dot(config, for_overrides=True)
    loaded_config = load_config(model_path / "config.cfg", overrides=overrides)
    nlp = load_model_from_config(
        loaded_config,
        vocab=vocab,
        disable=disable,
        enable=enable,
        exclude=exclude,
        meta=meta,
    )
    return nlp.from_disk(model_path, exclude=exclude, overrides=overrides)
def load_model_from_config(
    config: Union[Dict[str, Any], Config],
    *,
    meta: Dict[str, Any] = SimpleFrozenDict(),
    vocab: Union["Vocab", bool] = True,
    disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    auto_fill: bool = False,
    validate: bool = True,
) -> "Language":
    """Create an nlp object from a config. Expects the full config file including
    a section "nlp" containing the settings for the nlp object.

    config (Dict[str, Any] / Config): The full config. Its "nlp" section must
        set "lang" to a non-None language code.
    meta (Dict[str, Any]): Optional model meta.
    vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
        a new Vocab object will be created.
    disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable. Disabled
        pipes will be loaded but they won't be run unless you explicitly
        enable them by calling nlp.enable_pipe.
    enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other
        pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
    exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. Excluded
        components won't be loaded.
    auto_fill (bool): Whether to auto-fill config with missing defaults.
    validate (bool): Whether to show config validation errors.
    RETURNS (Language): The loaded nlp object.
    RAISES (ValueError): If the "nlp" section or its "lang" setting is missing.
    """
    if "nlp" not in config:
        raise ValueError(Errors.E985.format(config=config))
    nlp_config = config["nlp"]
    if "lang" not in nlp_config or nlp_config["lang"] is None:
        raise ValueError(Errors.E993.format(config=nlp_config))
    # This will automatically handle all codes registered via the languages
    # registry, including custom subclasses provided via entry points
    lang_cls = get_lang_class(nlp_config["lang"])
    return lang_cls.from_config(
        config,
        vocab=vocab,
        disable=disable,
        enable=enable,
        exclude=exclude,
        auto_fill=auto_fill,
        validate=validate,
        meta=meta,
    )
def get_sourced_components(
    config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
    """Find the components in a config that are sourced from another pipeline.

    A component counts as sourced when its block has a "source" key (e.g.
    {"source": "en_core_web_sm"}) and no "factory" key. If "factory" is
    present, we assume the block refers to a component factory.

    config (Union[Dict[str, Any], Config]): The config to inspect.
    RETURNS (Dict[str, Dict[str, Any]]): The sourced components' config
        blocks, keyed by component name.
    """
    components = config.get("components", {})
    return {
        name: cfg
        for name, cfg in components.items()
        if "factory" not in cfg and "source" in cfg
    }
def resolve_dot_names(
    config: Config, dot_names: List[Optional[str]]
) -> Tuple[Any, ...]:
    """Resolve one or more "dot notation" names, e.g. corpora.train.
    The paths could point anywhere into the config, so we don't know which
    top-level section we'll be looking within.

    We resolve the whole top-level section, although we could resolve less --
    we could find the lowest part of the tree.

    config (Config): The config containing the referenced sections.
    dot_names (List[Optional[str]]): Names to resolve. None entries are passed
        through as None.
    RETURNS (Tuple[Any, ...]): The resolved objects, in the order of dot_names.
    RAISES (ConfigValidationError): If any name doesn't point into the config,
        including names whose top-level section doesn't exist.
    """
    # TODO: include schema?
    resolved: Dict[str, Any] = {}
    output: List[Any] = []
    errors: List[Dict[str, Any]] = []
    for name in dot_names:
        if name is None:
            output.append(name)
            continue
        section = name.split(".")[0]
        # We want to avoid resolving the same thing twice
        if section not in resolved:
            if section not in config:
                # Report a missing top-level section like any other invalid
                # reference instead of letting a bare KeyError escape.
                msg = f"not a valid section reference: {name}"
                errors.append({"loc": name.split("."), "msg": msg})
                continue
            if registry.is_promise(config[section]):
                # Otherwise we can't resolve [corpus] if it's a promise
                result = registry.resolve({"config": config[section]})["config"]
            else:
                result = registry.resolve(config[section])
            resolved[section] = result
        try:
            output.append(dot_to_object(resolved, name))  # type: ignore[arg-type]
        except KeyError:
            msg = f"not a valid section reference: {name}"
            errors.append({"loc": name.split("."), "msg": msg})
    if errors:
        raise ConfigValidationError(config=config, errors=errors)
    return tuple(output)
def load_model_from_init_py(
    init_file: Union[Path, str],
    *,
    vocab: Union["Vocab", bool] = True,
    disable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    enable: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    exclude: Union[str, Iterable[str]] = _DEFAULT_EMPTY_PIPES,
    config: Union[Dict[str, Any], Config] = SimpleFrozenDict(),
) -> "Language":
    """Helper function to use in the `load()` method of a model package's
    __init__.py.

    The pipeline data is loaded from the "<lang>_<name>-<version>" directory
    next to init_file, with the directory name built from the package meta.

    init_file (Path / str): Path to the package's __init__.py.
    vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
        a new Vocab object will be created.
    disable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to disable. Disabled
        pipes will be loaded but they won't be run unless you explicitly
        enable them by calling nlp.enable_pipe.
    enable (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to enable. All other
        pipes will be disabled (and can be enabled using `nlp.enable_pipe`).
    exclude (Union[str, Iterable[str]]): Name(s) of pipeline component(s) to exclude. Excluded
        components won't be loaded.
    config (Dict[str, Any] / Config): Config overrides as nested dict or dict
        keyed by section values in dot notation.
    RETURNS (Language): The loaded nlp object.
    RAISES (IOError): If the versioned data directory doesn't exist.
    """
    model_path = Path(init_file).parent
    meta = get_model_meta(model_path)
    data_dir = f"{meta['lang']}_{meta['name']}-{meta['version']}"
    data_path = model_path / data_dir
    # Check the directory that is actually loaded and reported in the error.
    # model_path is just the parent of init_file.
    if not data_path.exists():
        raise IOError(Errors.E052.format(path=data_path))
    return load_model_from_path(
        data_path,
        vocab=vocab,
        meta=meta,
        disable=disable,
        enable=enable,
        exclude=exclude,
        config=config,
    )
def load_config(
    path: Union[str, Path],
    overrides: Dict[str, Any] = SimpleFrozenDict(),
    interpolate: bool = False,
) -> Config:
    """Read a config file (or stdin) into a Config with spaCy's section order.

    path (Union[str, Path]): Path to the config file, or "-" for stdin.
    overrides: (Dict[str, Any]): Config overrides as nested dict or
        dict keyed by section values in dot notation.
    interpolate (bool): Whether to interpolate and resolve variables.
    RETURNS (Config): The loaded config.
    RAISES (IOError): If the path is not "-" and isn't an existing file.
    """
    cfg_path = ensure_path(path)
    cfg = Config(section_order=CONFIG_SECTION_ORDER)
    read_kwargs = {"overrides": overrides, "interpolate": interpolate}
    if str(cfg_path) == "-":  # read from standard input
        return cfg.from_str(sys.stdin.read(), **read_kwargs)
    if not cfg_path or not cfg_path.is_file():
        raise IOError(Errors.E053.format(path=cfg_path, name="config file"))
    return cfg.from_disk(cfg_path, **read_kwargs)
def load_config_from_str(
    text: str, overrides: Dict[str, Any] = SimpleFrozenDict(), interpolate: bool = False
) -> Config:
    """Load a full config from a string. Wrapper around Thinc's Config.from_str
    that applies spaCy's default section order.

    text (str): The string config to load.
    overrides (Dict[str, Any]): Config overrides as nested dict or dict keyed
        by section values in dot notation.
    interpolate (bool): Whether to interpolate and resolve variables.
    RETURNS (Config): The loaded config.
    """
    return Config(section_order=CONFIG_SECTION_ORDER).from_str(
        text, overrides=overrides, interpolate=interpolate
    )
def get_installed_models() -> List[str]:
    """Return the names of all model packages found in the models registry.

    RETURNS (List[str]): The string names of the models.
    """
    all_models = registry.models.get_all()
    return list(all_models)
def get_package_version(name: str) -> Optional[str]:
    """Look up the installed version of a Python package (e.g. a model package).

    name (str): The name of the installed Python package.
    RETURNS (str / None): The version string, or None if the package isn't
        installed.
    """
    try:
        installed_version = importlib_metadata.version(name)  # type: ignore[attr-defined]
    except importlib_metadata.PackageNotFoundError:  # type: ignore[attr-defined]
        installed_version = None
    return installed_version
def is_compatible_version(
    version: str, constraint: str, prereleases: bool = True
) -> Optional[bool]:
    """Check if a version (e.g. "2.0.0") is compatible given a version
    constraint (e.g. ">=1.9.0,<2.2.1"). If the constraint is a specific version,
    it's interpreted as =={version}. An empty constraint places no restriction.

    version (str): The version to check.
    constraint (str): The constraint string.
    prereleases (bool): Whether to allow prereleases. If set to False,
        prerelease versions will be considered incompatible.
    RETURNS (bool / None): Whether the version is compatible, or None if the
        version or constraint are invalid.
    """
    # Handle cases where exact version is provided as constraint. Slicing
    # instead of indexing keeps an empty constraint from raising IndexError.
    if constraint[:1].isdigit():
        constraint = f"=={constraint}"
    try:
        spec = SpecifierSet(constraint)
        parsed_version = Version(version)
    except (TypeError, InvalidSpecifier, InvalidVersion):
        # TypeError covers non-string versions, matching get_minor_version
        return None
    spec.prereleases = prereleases
    return parsed_version in spec
def is_unconstrained_version(
    constraint: str, prereleases: bool = True
) -> Optional[bool]:
    """Check whether a version constraint leaves the version open.

    A constraint counts as constrained if it is an exact version (a bare
    version such as "1.2.3", or a "==" / "===" specifier), or if it has both
    an upper and a lower bound. Everything else, including an empty
    constraint, counts as unconstrained.

    constraint (str): The constraint string, e.g. ">=1.0.0,<2.0.0".
    prereleases (bool): Stored on the parsed specifier set. It doesn't change
        the result, which only looks at the specifier operators.
    RETURNS (bool / None): Whether the constraint is unconstrained, or None if
        the constraint is invalid.
    """
    # We have an exact version, this is the ultimate constrained version.
    # Slicing keeps an empty constraint from raising IndexError.
    if constraint[:1].isdigit():
        return False
    try:
        spec = SpecifierSet(constraint)
    except InvalidSpecifier:
        return None
    spec.prereleases = prereleases
    specs = list(spec)
    # We only have one version spec and it defines > or >=
    if len(specs) == 1 and specs[0].operator in (">", ">="):
        return True
    # One specifier is an exact version (a tuple, so this is a membership
    # test rather than a substring test)
    if any(sp.operator in ("==", "===") for sp in specs):
        return False
    has_upper = any(sp.operator in ("<", "<=") for sp in specs)
    has_lower = any(sp.operator in (">", ">=") for sp in specs)
    # We have a version spec that defines an upper and lower bound
    if has_upper and has_lower:
        return False
    # Everything else, like only an upper version, only a lower version etc.
    return True
def split_requirement(requirement: str) -> Tuple[str, str]:
    """Break a requirement string into its name and specifier parts.

    requirement (str): E.g. "spacy>=1.2.3".
    RETURNS (Tuple[str, str]): E.g. ("spacy", ">=1.2.3").
    """
    parsed = Requirement(requirement)
    package_name = parsed.name
    specifier_text = str(parsed.specifier)
    return package_name, specifier_text
def get_minor_version_range(version: str) -> str:
    """Generate a version range like >=1.2.3,<1.3.0 based on a given version
    (e.g. of spaCy).

    version (str): The lower-bound version, e.g. "1.2.3".
    RETURNS (str): A range from version (inclusive) up to the next minor
        release (exclusive). A missing minor component counts as 0, so "3"
        gives ">=3,<3.1.0".
    """
    release = Version(version).release
    # Pad the release tuple so single-component versions like "3" don't raise
    # IndexError when reading the minor component.
    major, minor = (release + (0, 0))[:2]
    return f">={version},<{major}.{minor + 1}.0"
def get_model_lower_version(constraint: str) -> Optional[str]:
    """Extract the lower pin from a version range.

    constraint (str): A range such as ">=1.2.3,<1.3.0".
    RETURNS (str / None): The version of the first ">=", "==" or "~="
        specifier (e.g. "1.2.3"), or None if there is none or the constraint
        can't be parsed.
    """
    try:
        specifiers = SpecifierSet(constraint)
        return next(
            (s.version for s in specifiers if s.operator in (">=", "==", "~=")),
            None,
        )
    except Exception:
        return None
def is_prerelease_version(version: str) -> bool:
    """Tell whether a version string denotes a prerelease.

    version (str): The version, e.g. "3.0.0.dev1".
    RETURNS (bool): True for prerelease versions, False otherwise.
    """
    parsed = Version(version)
    return parsed.is_prerelease
def get_base_version(version: str) -> str:
    """Strip prerelease identifiers from a version, keeping its base release.

    version (str): The version, e.g. "3.0.0.dev1".
    RETURNS (str): The base version, e.g. "3.0.0".
    """
    parsed = Version(version)
    return parsed.base_version
def get_minor_version(version: str) -> Optional[str]:
    """Reduce a version to "major.minor", dropping patch and prerelease parts.

    version (str): The version.
    RETURNS (Optional[str]): The major + minor version, or None if the version
        can't be parsed.
    """
    try:
        parsed = Version(version)
    except (TypeError, InvalidVersion):
        return None
    else:
        return f"{parsed.major}.{parsed.minor}"
def is_minor_version_match(version_a: str, version_b: str) -> bool:
    """Check whether two versions share the same major and minor components,
    ignoring patch and prerelease identifiers. Used internally for
    compatibility checks that should be insensitive to patch releases.

    version_a (str): The first version.
    version_b (str): The second version.
    RETURNS (bool): True only if both versions parse and their major.minor
        parts are equal.
    """
    minors = (get_minor_version(version_a), get_minor_version(version_b))
    # An unparseable version (None) never counts as a match.
    return None not in minors and minors[0] == minors[1]
def load_meta(path: Union[str, Path]) -> Dict[str, Any]:
    """Load a model meta.json from a path and validate its contents.

    Besides validating the required settings, this warns (W095) when the
    meta's "spacy_version" is not compatible with the running spaCy version,
    and (W094) when that constraint is unconstrained.

    path (Union[str, Path]): Path to meta.json.
    RETURNS (Dict[str, Any]): The loaded meta.
    RAISES IOError: If the parent directory or the file does not exist.
    RAISES ValueError: If "lang", "name" or "version" is missing or empty.
    """
    path = ensure_path(path)
    if not path.parent.exists():
        raise IOError(Errors.E052.format(path=path.parent))
    if not path.exists() or not path.is_file():
        raise IOError(Errors.E053.format(path=path.parent, name="meta.json"))
    meta = srsly.read_json(path)
    for setting in ["lang", "name", "version"]:
        if setting not in meta or not meta[setting]:
            raise ValueError(Errors.E054.format(setting=setting))
    if "spacy_version" in meta:
        if not is_compatible_version(about.__version__, meta["spacy_version"]):
            lower_version = get_model_lower_version(meta["spacy_version"])
            # get_model_lower_version returns None when the constraint has no
            # lower pin (e.g. "<3.0.0") or can't be parsed. Only normalize a
            # pin that was actually found: get_base_version(None) would fail,
            # which previously crashed loading instead of using the fallbacks.
            if lower_version is not None:
                try:
                    lower_version = get_base_version(lower_version)
                except InvalidVersion:
                    # e.g. a wildcard pin like "==3.0.*": report it verbatim
                    pass
                lower_version = "v" + lower_version
            elif "spacy_git_version" in meta:
                lower_version = "git commit " + meta["spacy_git_version"]
            else:
                lower_version = "version unknown"
            warn_msg = Warnings.W095.format(
                model=f"{meta['lang']}_{meta['name']}",
                model_version=meta["version"],
                version=lower_version,
                current=about.__version__,
            )
            warnings.warn(warn_msg)
        if is_unconstrained_version(meta["spacy_version"]):
            warn_msg = Warnings.W094.format(
                model=f"{meta['lang']}_{meta['name']}",
                model_version=meta["version"],
                version=meta["spacy_version"],
                example=get_minor_version_range(about.__version__),
            )
            warnings.warn(warn_msg)
    return meta
def get_model_meta(path: Union[str, Path]) -> Dict[str, Any]:
    """Read and validate the meta.json stored inside a model directory.

    path (str / Path): Path to the model directory.
    RETURNS (Dict[str, Any]): The model's meta data, as returned by load_meta.
    """
    meta_path = ensure_path(path) / "meta.json"
    return load_meta(meta_path)
def is_package(name: str) -> bool:
    """Check if string maps to a package installed via pip.

    name (str): Name of package.
    RETURNS (bool): True if an installed distribution with that name can be
        resolved, False otherwise.
    """
    try:
        importlib_metadata.distribution(name)  # type: ignore[attr-defined]
    except Exception:
        # Best-effort lookup: any failure to resolve the distribution (most
        # commonly PackageNotFoundError) means "not installed". Catching
        # Exception instead of using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True
def get_package_path(name: str) -> Path:
    """Locate the directory of an installed package.

    name (str): Package name.
    RETURNS (Path): The directory containing the package's module file.
    """
    # The module is imported purely to discover where it lives. That is an
    # indirect route, but finding the package any other way is hard.
    module = importlib.import_module(name)
    module_file = cast(Union[str, os.PathLike], module.__file__)
    return Path(module_file).parent
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
    """Swap `target` for `replacement` inside `model`, updating refs too.

    Makes two passes over model.walk(): the first rewrites sublayer lists,
    the second redirects named node references that point at `target`.

    model (Model): The parent model.
    target (Model): The target node.
    replacement (Model): The node to replace the target with.
    """
    # Pass 1: sublayers. Only the first occurrence of the target in each
    # node's layer list is swapped, since list.index returns the first match.
    for parent in model.walk():
        sublayers = parent.layers
        if target in sublayers:
            slot = sublayers.index(target)
            sublayers[slot] = replacement
    # Pass 2: node references still pointing at the target (identity check).
    for parent in model.walk():
        for ref in parent.ref_names:
            if parent.maybe_get_ref(ref) is target:
                parent.set_ref(ref, replacement)
def split_command(command: str) -> List[str]:
    """Tokenize a command string with shlex, respecting the platform.

    POSIX splitting rules are used everywhere except on Windows.

    command (str): The command to split.
    RETURNS (List[str]): The split command.
    """
    use_posix = not is_windows
    return shlex.split(command, posix=use_posix)
def run_command(
command: Union[str, List[str]],
*,
stdin: Optional[Any] = None,
capture: bool = False,
) -> subprocess.CompletedProcess:
"""Run a command on the command line as a subprocess. If the subprocess
returns a non-zero exit code, a system exit is performed.
command (str / List[str]): The command. If provided as a string, the
string will be split using shlex.split.
stdin (Optional[Any]): stdin to read from or None.
capture (bool): Whether to capture the output and errors. If False,
the stdout and stderr will not be redirected, and if there's an error,
sys.exit will be called with the return code. You should use capture=False
when you want to turn over execution to the command, and capture=True
when you want to run the command more like a function.
RETURNS (Optional[CompletedProcess]): The process object.
"""