diff --git a/colcon_core/extension_point.py b/colcon_core/extension_point.py index c724d221..4bea3fc6 100644 --- a/colcon_core/extension_point.py +++ b/colcon_core/extension_point.py @@ -3,8 +3,11 @@ # Licensed under the Apache License, Version 2.0 from collections import defaultdict +from itertools import chain import os +import sys import traceback +import warnings try: from importlib.metadata import distributions @@ -26,7 +29,6 @@ logger = colcon_logger.getChild(__name__) - """ The group name for entry points identifying colcon extension points. @@ -36,6 +38,8 @@ """ EXTENSION_POINT_GROUP_NAME = 'colcon_core.extension_point' +_ENTRY_POINTS_CACHE = [] + def _get_unique_distributions(): seen = set() @@ -46,6 +50,50 @@ def _get_unique_distributions(): yield dist +def _get_entry_points(): + for dist in _get_unique_distributions(): + for entry_point in dist.entry_points: + # Modern EntryPoint instances should already have this set + if not hasattr(entry_point, 'dist'): + entry_point.dist = dist + yield entry_point + + +def _get_cached_entry_points(): + if not _ENTRY_POINTS_CACHE: + if sys.version_info >= (3, 10): + # We prefer using importlib.metadata.entry_points because it + # has an internal optimization which allows us to load the entry + # points without reading the individual PKG-INFO files, while + # still visiting each unique distribution only once. + all_entry_points = entry_points() + if isinstance(all_entry_points, dict): + # Prior to Python 3.12, entry_points returned a (deprecated) + # dict. Unfortunately, the "future-proof" recommended + # pattern is to add filter parameters, but we actually + # want to cache everything so that doesn't work here. + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + 'SelectableGroups dict interface is deprecated', + DeprecationWarning, + module=__name__) + all_entry_points = chain.from_iterable( + all_entry_points.values()) + _ENTRY_POINTS_CACHE.extend(all_entry_points) + else: + # If we don't have Python 3.10, we must read each PKG-INFO to + # get the name of the distribution so that we can skip the + # "shadowed" distributions properly. + _ENTRY_POINTS_CACHE.extend(_get_entry_points()) + return _ENTRY_POINTS_CACHE + + +def clear_entry_point_cache(): + """Purge the entry point cache.""" + _ENTRY_POINTS_CACHE.clear() + + def get_all_extension_points(): """ Get all extension points related to `colcon` and any of its extensions. @@ -59,23 +107,24 @@ def get_all_extension_points(): colcon_extension_points = get_extension_points(EXTENSION_POINT_GROUP_NAME) colcon_extension_points.setdefault(EXTENSION_POINT_GROUP_NAME, None) - entry_points = defaultdict(dict) - for dist in _get_unique_distributions(): - for entry_point in dist.entry_points: - # skip groups which are not registered as extension points - if entry_point.group not in colcon_extension_points: - continue - - if entry_point.name in entry_points[entry_point.group]: - previous = entry_points[entry_point.group][entry_point.name] - logger.error( - f"Entry point '{entry_point.group}.{entry_point.name}' is " - f"declared multiple times, '{entry_point.value}' " - f"from '{dist._path}' " - f"overwriting '{previous}'") - entry_points[entry_point.group][entry_point.name] = \ - (entry_point.value, dist.metadata['Name'], dist.version) - return entry_points + extension_points = defaultdict(dict) + for entry_point in _get_cached_entry_points(): + if entry_point.group not in colcon_extension_points: + continue + + dist_metadata = entry_point.dist.metadata + ep_tuple = ( + entry_point.value, + dist_metadata['Name'], dist_metadata['Version'], + ) + if entry_point.name in extension_points[entry_point.group]: + previous = extension_points[entry_point.group][entry_point.name] + logger.error( + f"Entry point '{entry_point.group}.{entry_point.name}' is " + f"declared multiple times, '{ep_tuple}' " + f"overwriting '{previous}'") + extension_points[entry_point.group][entry_point.name] = ep_tuple + return extension_points def get_extension_points(group): @@ -87,16 +136,9 @@ def get_extension_points(group): :rtype: dict """ extension_points = {} - try: - # Python 3.10 and newer - query = entry_points(group=group) - except TypeError: - query = ( - entry_point - for dist in _get_unique_distributions() - for entry_point in dist.entry_points - if entry_point.group == group) - for entry_point in query: + for entry_point in _get_cached_entry_points(): + if entry_point.group != group: + continue if entry_point.name in extension_points: previous_entry_point = extension_points[entry_point.name] logger.error( diff --git a/test/spell_check.words b/test/spell_check.words index d44f7a23..96051a0f 100644 --- a/test/spell_check.words +++ b/test/spell_check.words @@ -50,6 +50,7 @@ importlib importorskip isatty iterdir +itertools junit levelname libexec diff --git a/test/test_extension_point.py b/test/test_extension_point.py index 6d626961..63f89edb 100644 --- a/test/test_extension_point.py +++ b/test/test_extension_point.py @@ -12,6 +12,7 @@ # TODO: Drop this with Python 3.7 support from importlib_metadata import Distribution +from colcon_core.extension_point import clear_entry_point_cache from colcon_core.extension_point import EntryPoint from colcon_core.extension_point import EXTENSION_POINT_GROUP_NAME from colcon_core.extension_point import get_all_extension_points @@ -73,6 +74,8 @@ def test_all_extension_points(): 'colcon_core.extension_point.distributions', side_effect=_distributions ): + clear_entry_point_cache() + # successfully load a known entry point extension_points = get_all_extension_points() assert set(extension_points.keys()) == { @@ -94,12 +97,14 @@ def test_extension_point_blocklist(): 'colcon_core.extension_point.distributions', side_effect=_distributions ): + clear_entry_point_cache() extension_points = get_extension_points('group1') assert 'extA' in extension_points.keys() extension_point = extension_points['extA'] assert extension_point == 'eA' with patch.object(EntryPoint, 'load', return_value=None) as load: + clear_entry_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 @@ -108,12 +113,14 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extB', 'group2.extC']) ): + clear_entry_point_cache() load_extension_point('extA', 'eA', 'group1') assert load.call_count == 1 # entry point in a blocked group can't be loaded load.reset_mock() with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST='group1'): + clear_entry_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point group name is listed in the environment ' \ @@ -124,6 +131,7 @@ def test_extension_point_blocklist(): with EnvironmentContext(COLCON_EXTENSION_BLOCKLIST=os.pathsep.join([ 'group1.extA', 'group1.extB']) ): + clear_entry_point_cache() with pytest.raises(RuntimeError) as e: load_extension_point('extA', 'eA', 'group1') assert 'The entry point name is listed in the environment ' \ @@ -131,6 +139,38 @@ def test_extension_point_blocklist(): assert load.call_count == 0 +def test_redefined_extension_point(): + def _duped_distributions(): + yield from _distributions() + yield _FakeDistribution({ + 'group2': [('extC', 'eC-prime')], + }) + + def _duped_entry_points(): + for dist in _duped_distributions(): + yield from dist.entry_points + + with patch('colcon_core.extension_point.logger.error') as error: + with patch( + 'colcon_core.extension_point.entry_points', + side_effect=_duped_entry_points + ): + with patch( + 'colcon_core.extension_point.distributions', + side_effect=_duped_distributions + ): + clear_entry_point_cache() + extension_points = get_all_extension_points() + assert 'eC-prime' == extension_points['group2']['extC'][0] + assert error.call_count == 1 + + error.reset_mock() + clear_entry_point_cache() + extension_points = get_extension_points('group2') + assert 'eC-prime' == extension_points.get('extC') + assert error.call_count == 1 + + def entry_point_load(self, *args, **kwargs): if self.name == 'exception': raise Exception('entry point raising exception')