From 183d3de48a0d72be5f81e74eb3032df368080f97 Mon Sep 17 00:00:00 2001 From: Wouter De Borger Date: Mon, 4 Mar 2024 09:05:07 +0100 Subject: [PATCH] Improve deploy performance for very large models (Issue #7262, PR #7278) Pull request opened by the merge tool on behalf of #7278 --- changelogs/unreleased/7262-performance.yml | 7 + src/inmanta/agent/config.py | 2 +- src/inmanta/config.py | 28 +-- src/inmanta/data/__init__.py | 102 ++++++++--- src/inmanta/server/agentmanager.py | 2 +- src/inmanta/server/config.py | 10 +- .../server/services/orchestrationservice.py | 64 ++++--- .../server/services/resourceservice.py | 172 ++++++++++++------ src/inmanta/util/__init__.py | 12 ++ tests/agent_server/test_server_agent.py | 20 +- tests/test_server.py | 29 ++- tests/utils.py | 10 +- 12 files changed, 304 insertions(+), 154 deletions(-) create mode 100644 changelogs/unreleased/7262-performance.yml diff --git a/changelogs/unreleased/7262-performance.yml b/changelogs/unreleased/7262-performance.yml new file mode 100644 index 0000000000..485daeb668 --- /dev/null +++ b/changelogs/unreleased/7262-performance.yml @@ -0,0 +1,7 @@ +description: Improve deploy performance for very large models +issue-nr: 7262 +change-type: minor +destination-branches: [master, iso7] +sections: + minor-improvement: "{{description}}" + diff --git a/src/inmanta/agent/config.py b/src/inmanta/agent/config.py index a0ab4eaec7..3001876e3d 100644 --- a/src/inmanta/agent/config.py +++ b/src/inmanta/agent/config.py @@ -91,7 +91,7 @@ is_time, ) -agent_deploy_interval = Option( +agent_deploy_interval: Option[int | str] = Option( "config", "agent-deploy-interval", 0, diff --git a/src/inmanta/config.py b/src/inmanta/config.py index f5d3019c7f..dcec6b863b 100644 --- a/src/inmanta/config.py +++ b/src/inmanta/config.py @@ -133,19 +133,19 @@ def get(cls, section: Optional[str] = None, name: Optional[str] = None, default_ @classmethod def get_for_option(cls, option: "Option[T]") -> T: - raw_value: Optional[str | T] = cls._get_value(option.section, option.name, option.get_default_value()) + raw_value: str | T = cls._get_value(option.section, option.name, option.get_default_value()) return option.validate(raw_value) @classmethod - def _get_value(cls, section: str, name: str, default_value: Optional[T] = None) -> Optional[str | T]: + def _get_value(cls, section: str, name: str, default_value: T) -> str | T: cfg: ConfigParser = cls.get_instance() val: Optional[str] = _get_from_env(section, name) if val is not None: LOGGER.debug(f"Setting {section}:{name} was set using an environment variable") - else: - val = cfg.get(section, name, fallback=default_value) - - return val + return val + # Typing of this method in the sdk is not entirely accurate + # It just returns the fallback, whatever its type + return cfg.get(section, name, fallback=default_value) @classmethod def is_set(cls, section: str, name: str) -> bool: @@ -205,12 +205,12 @@ def is_float(value: str) -> float: return float(value) -def is_time(value: str) -> int: +def is_time(value: str | int) -> int: """Time, the number of seconds represented as an integer value""" return int(value) -def is_time_or_cron(value: str) -> Union[int, str]: +def is_time_or_cron(value: str | int) -> Union[int, str]: """Time, the number of seconds represented as an integer value or a cron-like expression""" try: return is_time(value) @@ -232,8 +232,10 @@ def is_bool(value: Union[bool, str]) -> bool: return boolean_states[value.lower()] -def is_list(value: str) -> list[str]: +def is_list(value: str | list[str]) -> list[str]: """List of comma-separated values""" + if isinstance(value, list): + return value return [] if value == "" else [x.strip() for x in value.split(",")] @@ -304,9 +306,9 @@ def __init__( self, section: str, name: str, - default: Union[T, None, Callable[[], T]], + default: Union[T, Callable[[], T]], documentation: str, - validator: Callable[[Optional[str | T]], T] = is_str, + validator: Callable[[str | T], T] = is_str, predecessor_option: Optional["Option"] = None, ) -> None: self.section = section @@ -342,10 +344,10 @@ def get_default_desc(self) -> str: else: return f"``{defa}``" - def validate(self, value: Optional[str | T]) -> T: + def validate(self, value: str | T) -> T: return self.validator(value) - def get_default_value(self) -> Optional[T]: + def get_default_value(self) -> T: defa = self.default if callable(defa): return defa() diff --git a/src/inmanta/data/__init__.py b/src/inmanta/data/__init__.py index 35988fd12f..41d323ecd3 100644 --- a/src/inmanta/data/__init__.py +++ b/src/inmanta/data/__init__.py @@ -4621,6 +4621,40 @@ def convert_or_ignore(rvid): ) return out + @classmethod + async def set_deployed_multi( + cls, + environment: uuid.UUID, + resource_ids: Sequence[m.ResourceIdStr], + version: int, + connection: Optional[asyncpg.connection.Connection] = None, + ) -> None: + query = "UPDATE resource SET status='deployed' WHERE environment=$1 AND model=$2 AND resource_id =ANY($3) " + async with cls.get_connection(connection) as connection: + await connection.execute(query, environment, version, resource_ids) + + @classmethod + async def get_resource_ids_with_status( + cls, + environment: uuid.UUID, + resource_version_ids: list[m.ResourceIdStr], + version: int, + statuses: Sequence[const.ResourceState], + lock: Optional[RowLockMode] = None, + connection: Optional[asyncpg.connection.Connection] = None, + ) -> list[m.ResourceIdStr]: + query = ( + "SELECT resource_id as resource_id FROM resource WHERE " + "environment=$1 AND model=$2 AND status = ANY($3) and resource_id =ANY($4) " + ) + if lock: + query += lock.value + async with cls.get_connection(connection) as connection: + return [ + m.ResourceIdStr(cast(str, r["resource_id"])) + for r in await connection.fetch(query, environment, version, statuses, resource_version_ids) + ] + @classmethod async def get_undeployable(cls, environment: uuid.UUID, version: int) -> list["Resource"]: """ @@ -4794,12 +4828,18 @@ async def get_resources_for_version_raw_with_persistent_state( cls, environment: uuid.UUID, version: int, - projection: Optional[list[str]], - projection_presistent: Optional[list[str]], + projection: Optional[list[typing.LiteralString]], + projection_presistent: Optional[list[typing.LiteralString]], + project_attributes: Optional[list[typing.LiteralString]] = None, *, connection: Optional[Connection] = None, ) -> list[dict[str, object]]: - """This method performs none of the mangling required to produce valid resources!""" + """This method performs none of the mangling required to produce valid resources! + + project_attributes performs a projection on the json attributes of the resources table + + all projections must be disjoint, as they become named fields in the output record + """ def collect_projection(projection: Optional[list[str]], prefix: str) -> str: if not projection: @@ -4807,16 +4847,23 @@ def collect_projection(projection: Optional[list[str]], prefix: str) -> str: else: return ",".join(f"{prefix}.{field}" for field in projection) + if project_attributes: + json_projection = "," + ",".join(f"r.attributes->'{v}' as {v}" for v in project_attributes) + else: + json_projection = "" + query = f""" - SELECT {collect_projection(projection, 'r')}, {collect_projection(projection_presistent, 'ps')} + SELECT {collect_projection(projection, 'r')}, {collect_projection(projection_presistent, 'ps')} {json_projection} FROM {cls.table_name()} r JOIN resource_persistent_state ps ON r.resource_id = ps.resource_id WHERE r.environment=$1 AND ps.environment = $1 and r.model = $2;""" resource_records = await cls._fetch_query(query, environment, version, connection=connection) resources = [dict(record) for record in resource_records] for res in resources: - if "attributes" in res: - res["attributes"] = json.loads(res["attributes"]) + if project_attributes: + for k in project_attributes: + if res[k]: + res[k] = json.loads(res[k]) return resources @classmethod @@ -5403,6 +5450,7 @@ async def get_list( no_obj: Optional[bool] = None, lock: Optional[RowLockMode] = None, connection: Optional[asyncpg.connection.Connection] = None, + no_status: bool = False, # don't load the status field **query: object, ) -> list["ConfigurationModel"]: # sanitize and validate order parameters @@ -5446,14 +5494,21 @@ async def get_list( {lock_statement}""" query_result = await cls._fetch_query(query_string, *values, connection=connection) result = [] - for record in query_result: - record = dict(record) + for in_record in query_result: + record = dict(in_record) if no_obj: - record["status"] = await cls._get_status_field(record["environment"], record["status"]) + if no_status: + record["status"] = {} + else: + record["status"] = await cls._get_status_field(record["environment"], record["status"]) result.append(record) else: done = record.pop("done") - status = await cls._get_status_field(record["environment"], record.pop("status")) + if no_status: + status = {} + record.pop("status") + else: + status = await cls._get_status_field(record["environment"], record.pop("status")) obj = cls(from_postgres=True, **record) obj._done = done obj._status = status @@ -5703,23 +5758,23 @@ async def get_increment( deployed and different hash -> increment """ # Depends on deploying - projection_a_resource = [ + projection_a_resource: list[typing.LiteralString] = [ "resource_id", "attribute_hash", - "attributes", "status", ] - projection_a_state = [ + projection_a_state: list[typing.LiteralString] = [ "last_success", "last_produced_events", "last_deployed_attribute_hash", "last_non_deploying_status", ] - projection = ["resource_id", "status", "attribute_hash"] + projection_a_attributes: list[typing.LiteralString] = ["requires", "send_event"] + projection: list[typing.LiteralString] = ["resource_id", "status", "attribute_hash"] # get resources for agent resources = await Resource.get_resources_for_version_raw_with_persistent_state( - environment, version, projection_a_resource, projection_a_state, connection=connection + environment, version, projection_a_resource, projection_a_state, projection_a_attributes, connection=connection ) # to increment @@ -5740,20 +5795,11 @@ async def get_increment( continue # Now outstanding events last_success = resource["last_success"] or DATETIME_MIN_UTC - attributes = resource["attributes"] - assert isinstance(attributes, dict) # mypy - for req in attributes["requires"]: + for req in resource["requires"]: req_res = id_to_resource[req] assert req_res is not None # todo - req_res_attributes = req_res["attributes"] - assert isinstance(req_res_attributes, dict) # mypy last_produced_events = req_res["last_produced_events"] - if ( - last_produced_events is not None - and last_produced_events > last_success - and "send_event" in req_res_attributes - and req_res_attributes["send_event"] - ): + if last_produced_events is not None and last_produced_events > last_success and req_res["send_event"]: in_increment = True break @@ -5839,9 +5885,9 @@ async def get_increment( # build lookup tables for res in resources: - for req in res["attributes"]["requires"]: + for req in res["requires"]: original_provides[req].append(res["resource_id"]) - if "send_event" in res["attributes"] and res["attributes"]["send_event"]: + if res["send_event"]: send_events.append(res["resource_id"]) # recursively include stuff potentially receiving events from nodes in the increment diff --git a/src/inmanta/server/agentmanager.py b/src/inmanta/server/agentmanager.py index 60c26119aa..7915a93bdf 100644 --- a/src/inmanta/server/agentmanager.py +++ b/src/inmanta/server/agentmanager.py @@ -1006,7 +1006,7 @@ async def _terminate_agents(self) -> None: async def _ensure_agents( self, env: data.Environment, - agents: list[str], + agents: Sequence[str], restart: bool = False, *, connection: Optional[asyncpg.connection.Connection] = None, diff --git a/src/inmanta/server/config.py b/src/inmanta/server/config.py index 329c058768..b01a75c283 100644 --- a/src/inmanta/server/config.py +++ b/src/inmanta/server/config.py @@ -244,7 +244,7 @@ def validate_fact_renew(value: object) -> int: "server", "purge-resource-action-logs-interval", 3600, "The number of seconds between resource-action log purging", is_time ) -server_resource_action_log_prefix = Option( +server_resource_action_log_prefix: Option[str] = Option( "server", "resource_action_log_prefix", "resource-actions-", @@ -252,10 +252,10 @@ def validate_fact_renew(value: object) -> int: is_str, ) -server_enabled_extensions = Option( +server_enabled_extensions: Option[list[str]] = Option( "server", "enabled_extensions", - "", + list, "A list of extensions the server must load. Core is always loaded." "If an extension listed in this list is not available, the server will refuse to start.", is_list, @@ -271,9 +271,9 @@ def validate_fact_renew(value: object) -> int: ) -def default_hangtime() -> str: +def default_hangtime() -> int: """:inmanta.config:option:`server.agent-timeout` *3/4""" - return str(int(agent_timeout.get() * 3 / 4)) + return int(agent_timeout.get() * 3 / 4) agent_hangtime = Option( diff --git a/src/inmanta/server/services/orchestrationservice.py b/src/inmanta/server/services/orchestrationservice.py index 7de4d7f9e0..5220dc04e8 100644 --- a/src/inmanta/server/services/orchestrationservice.py +++ b/src/inmanta/server/services/orchestrationservice.py @@ -16,6 +16,7 @@ Contact: code@inmanta.com """ +import asyncio import datetime import logging import uuid @@ -26,8 +27,8 @@ import asyncpg.connection import asyncpg.exceptions import pydantic -from asyncpg import Connection +import inmanta.util from inmanta import const, data from inmanta.const import ResourceState from inmanta.data import ( @@ -69,6 +70,8 @@ from inmanta.types import Apireturn, JsonType, PrimitiveTypes, ReturnTupple LOGGER = logging.getLogger(__name__) +PLOGGER = logging.getLogger("performance") + PERFORM_CLEANUP: bool = True # Kill switch for cleanup, for use when working with historical data @@ -411,7 +414,9 @@ async def _purge_versions(self) -> None: # get available versions n_versions = await env_item.get(AVAILABLE_VERSIONS_TO_KEEP, connection=connection) assert isinstance(n_versions, int) - versions = await data.ConfigurationModel.get_list(environment=env_item.id, connection=connection) + versions = await data.ConfigurationModel.get_list( + environment=env_item.id, connection=connection, no_status=True + ) if len(versions) > n_versions: LOGGER.info("Removing %s available versions from environment %s", len(versions) - n_versions, env_item.id) version_dict = {x.version: x for x in versions} @@ -652,7 +657,7 @@ async def _put_version( pip_config: Optional[PipConfig] = None, *, connection: asyncpg.connection.Connection, - ) -> None: + ) -> abc.Collection[str]: """ :param rid_to_resource: This parameter should contain all the resources when a full compile is done. When a partial compile is done, it should contain all the resources that belong to the @@ -666,6 +671,8 @@ async def _put_version( sets that are removed by the partial compile. When no resource sets are removed by a partial compile or when a full compile is done, this parameter can be set to None. + :return: all agents affected + Pre-conditions: * The requires and provides relationships of the resources in rid_to_resource must be set correctly. For a partial compile, this means it is assumed to be valid with respect to all absolute constraints that apply to @@ -818,18 +825,22 @@ async def _put_version( await ra.insert(connection=connection) LOGGER.debug("Successfully stored version %d", version) + return list(all_agents) async def _trigger_auto_deploy( self, env: data.Environment, version: int, *, - connection: Optional[Connection], + agents: Optional[abc.Sequence[str]] = None, ) -> None: """ Triggers auto-deploy for stored resources. Must be called only after transaction that stores resources has been allowed to commit. If not respected, the auto deploy might work on stale data, likely resulting in resources hanging in the deploying state. + + :argument agents: the list of agents we should restrict our notifications to. if it is None, we notify all agents if + PUSH_ON_AUTO_DEPLOY is set """ auto_deploy = await env.get(data.AUTO_DEPLOY) if auto_deploy: @@ -837,8 +848,8 @@ async def _trigger_auto_deploy( push_on_auto_deploy = cast(bool, await env.get(data.PUSH_ON_AUTO_DEPLOY)) agent_trigger_method_on_autodeploy = cast(str, await env.get(data.AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY)) agent_trigger_method_on_autodeploy = const.AgentTriggerMethod[agent_trigger_method_on_autodeploy] - await self.release_version( - env, version, push_on_auto_deploy, agent_trigger_method_on_autodeploy, connection=connection + self.add_background_task( + self.release_version(env, version, push_on_auto_deploy, agent_trigger_method_on_autodeploy, agents=agents) ) def _create_unknown_parameter_daos_from_api_unknowns( @@ -903,6 +914,8 @@ async def put_version( ) async with data.Resource.get_connection() as con: + # We don't allow connection reuse here, because the last line in this block can't tolerate a transaction + # assert not con.is_in_transaction() async with con.transaction(): # Acquire a lock that conflicts with the lock acquired by put_partial but not with itself await env.put_version_lock(shared=True, connection=con) @@ -916,13 +929,9 @@ async def put_version( pip_config=pip_config, connection=con, ) - try: - await self._trigger_auto_deploy(env, version, connection=con) - except Conflict as e: - # this should be an api warning, but this is not supported here - LOGGER.warning( - "Could not perform auto deploy on version %d in environment %s, because %s", version, env.id, e.log_message - ) + # This must be outside all transactions, as it relies on the result of _put_version + # and it starts a background task, so it can't re-use this connection + await self._trigger_auto_deploy(env, version) return 200 @@ -1033,14 +1042,13 @@ async def put_partial( # add shared resources merged_resources = partial_update_merger.merge_updated_and_shared_resources(list(rid_to_resource.values())) - await data.Code.copy_versions(env.id, base_version, version, connection=con) merged_unknowns = await partial_update_merger.merge_unknowns( unknowns_in_partial_compile=self._create_unknown_parameter_daos_from_api_unknowns(env.id, version, unknowns) ) - await self._put_version( + all_agents = await self._put_version( env, version, merged_resources, @@ -1054,14 +1062,7 @@ async def put_partial( ) returnvalue: ReturnValue[int] = ReturnValue[int](200, response=version) - try: - await self._trigger_auto_deploy(env, version, connection=con) - except Conflict as e: - # It is unclear if this condition can ever happen - LOGGER.warning( - "Could not perform auto deploy on version %d in environment %s, because %s", version, env.id, e.log_message - ) - returnvalue.add_warnings([f"Could not perform auto deploy: {e.log_message} {e.details}"]) + await self._trigger_auto_deploy(env, version, agents=all_agents) return returnvalue @@ -1074,9 +1075,14 @@ async def release_version( agent_trigger_method: Optional[const.AgentTriggerMethod] = None, *, connection: Optional[asyncpg.connection.Connection] = None, + agents: Optional[abc.Sequence[str]] = None, ) -> ReturnTupple: + """ + :param agents: agents that have to be notified by the push, defaults to all + """ async with data.ConfigurationModel.get_connection(connection) as connection: - async with connection.transaction(): + version_run_ahead_lock = asyncio.Event() + async with connection.transaction(), inmanta.util.FinallySet(version_run_ahead_lock): with ConnectionInTransaction(connection) as connection_holder: # explicit lock to allow patching of increments for stale failures # (locks out patching stage of deploy_done to avoid races) @@ -1144,15 +1150,15 @@ async def release_version( ) if latest_version: - increments: tuple[abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]] = ( + version, increment_ids, neg_increment, neg_increment_per_agent = ( await self.resource_service.get_increment( env, version_id, connection=connection, + run_ahead_lock=version_run_ahead_lock, ) ) - increment_ids, neg_increment = increments await self.resource_service.mark_deployed( env, neg_increment, now, version_id, connection=connection_holder ) @@ -1170,8 +1176,10 @@ async def release_version( # We can't be in a transaction here, or the agent will not see the data that as committed # This assert prevents anyone from wrapping this method in a transaction by accident assert not connection.is_in_transaction() - # fetch all resource in this cm and create a list of distinct agents - agents = await data.ConfigurationModel.get_agents(env.id, version_id, connection=connection) + + if agents is None: + # fetch all resource in this cm and create a list of distinct agents + agents = await data.ConfigurationModel.get_agents(env.id, version_id, connection=connection) await self.autostarted_agent_manager._ensure_agents(env, agents, connection=connection) for agent in agents: diff --git a/src/inmanta/server/services/resourceservice.py b/src/inmanta/server/services/resourceservice.py index e58c486ad6..f7db1aa5b5 100644 --- a/src/inmanta/server/services/resourceservice.py +++ b/src/inmanta/server/services/resourceservice.py @@ -20,6 +20,7 @@ import datetime import logging import os +import re import uuid from collections import abc, defaultdict from collections.abc import Sequence @@ -115,8 +116,19 @@ def __init__(self) -> None: self._resource_action_loggers: dict[uuid.UUID, logging.Logger] = {} self._resource_action_handlers: dict[uuid.UUID, logging.Handler] = {} - # Dict: environment_id: (model_version, increment, negative_increment) - self._increment_cache: dict[uuid.UUID, Optional[tuple[int, abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]]]] = {} + # Dict: environment_id: (model_version, increment, negative_increment, negative_increment_per_agent, run_ahead_lock) + self._increment_cache: dict[ + uuid.UUID, + Optional[ + tuple[ + int, + abc.Set[ResourceIdStr], + abc.Set[ResourceIdStr], + abc.Mapping[str, abc.Set[ResourceIdStr]], + Optional[asyncio.Event], + ] + ], + ] = {} # lock to ensure only one inflight request self._increment_cache_locks: dict[uuid.UUID, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) @@ -330,14 +342,10 @@ async def get_resource_increment_for_agent(self, env: data.Environment, agent: s if version is None: return 404, {"message": "No version available"} - increments: tuple[abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]] = await self.get_increment(env, version) - increment_ids, neg_increment = increments + version, increment_ids, neg_increment, neg_increment_per_agent = await self.get_increment(env, version) now = datetime.datetime.now().astimezone() - - def on_agent(res: ResourceIdStr) -> bool: - idr = Id.parse_id(res) - return idr.get_agent_name() == agent + ON_AGENT_REGEX = re.compile(rf"^[a-zA-Z0-9_:]+\[{re.escape(agent)},") # This is a bit subtle. # Any resource we consider deployed has to be marked as such. @@ -350,11 +358,10 @@ def on_agent(res: ResourceIdStr) -> bool: # As such, it should not race with backpropagation on failure. await self.mark_deployed( env, - neg_increment, + neg_increment_per_agent[agent], now, version, - filter=on_agent, - only_update_from_states={const.ResourceState.available, const.ResourceState.deploying}, + only_update_from_states=[const.ResourceState.available, const.ResourceState.deploying], ) resources = await data.Resource.get_resources_for_version(env.id, version, agent) @@ -366,12 +373,10 @@ def on_agent(res: ResourceIdStr) -> bool: if rv.resource_id not in increment_ids: continue - # TODO double parsing of ID def in_requires(req: ResourceIdStr) -> bool: if req in increment_ids: return True - idr = Id.parse_id(req) - return idr.get_agent_name() != agent + return ON_AGENT_REGEX.match(req) is None rv.attributes["requires"] = [r for r in rv.attributes["requires"] if in_requires(r)] deploy_model.append(rv.to_dict()) @@ -404,7 +409,7 @@ async def mark_deployed( version: int, filter: Callable[[ResourceIdStr], bool] = lambda x: True, connection: ConnectionMaybeInTransaction = ConnectionNotInTransaction(), - only_update_from_states: Optional[set[const.ResourceState]] = None, + only_update_from_states: Optional[Sequence[const.ResourceState]] = None, ) -> None: """ Set the status of the provided resources as deployed @@ -414,34 +419,75 @@ async def mark_deployed( :param version: Version of the resources to consider. :param filter: Filter function that takes a resource id as an argument and returns True if it should be kept. """ - resources_version_ids: list[ResourceVersionIdStr] = [ - ResourceVersionIdStr(f"{res_id},v={version}") for res_id in resources_id if filter(res_id) - ] - logline = { - "level": "INFO", - "msg": "Setting deployed due to known good status", - "timestamp": util.datetime_iso_format(timestamp), - "args": [], - } + if not resources_id: + return - await self.resource_action_update( - env, - resources_version_ids, - action_id=uuid.uuid4(), - started=timestamp, - finished=timestamp, - status=const.ResourceState.deployed, - # does this require a different ResourceAction? - action=const.ResourceAction.deploy, - changes={}, - messages=[logline], - change=const.Change.nochange, - send_events=False, - keep_increment_cache=True, - is_increment_notification=True, - only_update_from_states=only_update_from_states, - connection=connection, - ) + # performance-critical path: avoid parsing cost if we can + resources_id_filtered = [res_id for res_id in resources_id if filter(res_id)] + if not resources_id_filtered: + return + + action_id = uuid.uuid4() + + async with data.Resource.get_connection(connection.connection) as inner_connection: + async with inner_connection.transaction(): + # validate resources + if only_update_from_states is not None: + resources = await data.Resource.get_resource_ids_with_status( + env.id, + resources_id_filtered, + version, + only_update_from_states, + # acquire lock on Resource before read and before lock on ResourceAction to prevent conflicts with + # cascading deletes + lock=data.RowLockMode.FOR_NO_KEY_UPDATE, + connection=inner_connection, + ) + if not resources: + return None + + resources_version_ids: list[ResourceVersionIdStr] = [ + ResourceVersionIdStr(f"{res_id},v={version}") for res_id in resources_id_filtered + ] + + resource_action = data.ResourceAction( + environment=env.id, + version=version, + resource_version_ids=resources_version_ids, + action_id=action_id, + action=const.ResourceAction.deploy, + started=timestamp, + messages=[ + { + "level": "INFO", + "msg": "Setting deployed due to known good status", + "args": [], + "timestamp": timestamp.isoformat(timespec="microseconds"), + } + ], + changes={}, + status=const.ResourceState.deployed, + change=const.Change.nochange, + finished=timestamp, + ) + await resource_action.insert(connection=inner_connection) + self.log_resource_action( + env.id, + resources_version_ids, + const.LogLevel.INFO.to_int, + timestamp, + "Setting deployed due to known good status", + ) + + await data.Resource.set_deployed_multi(env.id, resources_id_filtered, version, connection=inner_connection) + # Resource persistent state should not be affected + + def post_deploy_update() -> None: + # Make sure tasks are scheduled AFTER the tx is done. + # This method is only called if the transaction commits successfully. + self.add_background_task(data.ConfigurationModel.mark_done_if_done(env.id, version)) + + connection.call_after_tx(post_deploy_update) async def _update_deploy_state( self, @@ -552,8 +598,12 @@ async def _update_deploy_state( ) async def get_increment( - self, env: data.Environment, version: int, connection: Optional[Connection] = None - ) -> tuple[abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]]: + self, + env: data.Environment, + version: int, + connection: Optional[Connection] = None, + run_ahead_lock: Optional[asyncio.Event] = None, + ) -> tuple[int, abc.Set[ResourceIdStr], abc.Set[ResourceIdStr], abc.Mapping[str, abc.Set[ResourceIdStr]]]: """ Get the increment for a given environment and a given version of the model from the _increment_cache if possible. In case of cache miss, the increment calculation is performed behind a lock to make sure it is only done once per @@ -563,31 +613,47 @@ async def get_increment( :param version: The version of the model to consider. :param connection: connection to use towards the DB. When the connection is in a transaction, we will always invalidate the cache + :param run_ahead_lock: lock used to keep agents hanging while building up the latest version """ - def _get_cache_entry() -> Optional[tuple[abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]]]: + async def _get_cache_entry() -> ( + Optional[tuple[int, abc.Set[ResourceIdStr], abc.Set[ResourceIdStr], abc.Mapping[str, abc.Set[ResourceIdStr]]]] + ): """ - Returns a tuple (increment, negative_increment) if a cache entry exists for the given environment and version + Returns a tuple (increment, negative_increment, negative_increment_per_agent) + if a cache entry exists for the given environment and version or None if no such cache entry exists. """ cache_entry = self._increment_cache.get(env.id, None) if cache_entry is None: # No cache entry found return None - (version_cache_entry, incr, neg_incr) = cache_entry - if version_cache_entry != version: + (version_cache_entry, incr, neg_incr, neg_incr_per_agent, cached_run_ahead_lock) = cache_entry + if version_cache_entry >= version: + assert not run_ahead_lock # We only expect a lock if WE are ahead + # Cache is ahead or equal + if cached_run_ahead_lock is not None: + await cached_run_ahead_lock.wait() + elif version_cache_entry != version: # Cache entry exists for another version + # Expire return None - return incr, neg_incr + return version_cache_entry, incr, neg_incr, neg_incr_per_agent - increment: Optional[tuple[abc.Set[ResourceIdStr], abc.Set[ResourceIdStr]]] = _get_cache_entry() + increment: Optional[ + tuple[int, abc.Set[ResourceIdStr], abc.Set[ResourceIdStr], abc.Mapping[str, abc.Set[ResourceIdStr]]] + ] = await _get_cache_entry() if increment is None or (connection is not None and connection.is_in_transaction()): lock = self._increment_cache_locks[env.id] async with lock: - increment = _get_cache_entry() + increment = await _get_cache_entry() if increment is None: - increment = await data.ConfigurationModel.get_increment(env.id, version, connection=connection) - self._increment_cache[env.id] = (version, *increment) + positive, negative = await data.ConfigurationModel.get_increment(env.id, version, connection=connection) + negative_per_agent: dict[str, set[ResourceIdStr]] = defaultdict(set) + for rid in negative: + negative_per_agent[Id.parse_id(rid).agent_name].add(rid) + increment = (version, positive, negative, negative_per_agent) + self._increment_cache[env.id] = (version, positive, negative, negative_per_agent, run_ahead_lock) return increment @handle(methods_v2.resource_deploy_done, env="tid", resource_id="rvid") diff --git a/src/inmanta/util/__init__.py b/src/inmanta/util/__init__.py index ee6a42b82b..def234b3ba 100644 --- a/src/inmanta/util/__init__.py +++ b/src/inmanta/util/__init__.py @@ -776,6 +776,18 @@ async def __aexit__(self, *excinfo: object) -> None: pass +class FinallySet(contextlib.AbstractAsyncContextManager[asyncio.Event]): + + def __init__(self, event: asyncio.Event) -> None: + self.event = event + + async def __aenter__(self) -> asyncio.Event: + return self.event + + async def __aexit__(self, *exc_info: object) -> None: + self.event.set() + + async def join_threadpools(threadpools: list[ThreadPoolExecutor]) -> None: """ Asynchronously join a set of threadpools diff --git a/tests/agent_server/test_server_agent.py b/tests/agent_server/test_server_agent.py index 39f3a516b1..b7a5af3c1c 100644 --- a/tests/agent_server/test_server_agent.py +++ b/tests/agent_server/test_server_agent.py @@ -1403,14 +1403,14 @@ def get_resources(version, value_resource_two): result = await client.set_setting(environment, data.AGENT_TRIGGER_METHOD_ON_AUTO_DEPLOY, agent_trigger_method) assert result.code == 200 - await clienthelper.put_version_simple(resources, version) + await clienthelper.put_version_simple(resources, version, wait_for_released=True) # check deploy result = await client.get_version(environment, version) assert result.code == 200 assert result.result["model"]["released"] assert result.result["model"]["total"] == 3 - assert result.result["model"]["result"] == "deploying" + assert result.result["model"]["result"] in ["deploying", "success"] await _wait_until_deployment_finishes(client, environment, version) @@ -1422,10 +1422,15 @@ def get_resources(version, value_resource_two): assert resource_container.Provider.get("agent1", "key2") == value_resource_two assert not resource_container.Provider.isset("agent1", "key3") - assert resource_container.Provider.readcount("agent1", "key1") == read_resource1 - assert resource_container.Provider.changecount("agent1", "key1") == change_resource1 - assert resource_container.Provider.readcount("agent1", "key2") == read_resource2 - assert resource_container.Provider.changecount("agent1", "key2") == change_resource2 + async def check_final() -> bool: + return ( + (resource_container.Provider.readcount("agent1", "key1") == read_resource1) + and (resource_container.Provider.changecount("agent1", "key1") == change_resource1) + and (resource_container.Provider.readcount("agent1", "key2") == read_resource2) + and (resource_container.Provider.changecount("agent1", "key2") == change_resource2) + ) + + await retry_limited(check_final, 1) async def test_auto_deploy_no_splay(server, client, clienthelper, resource_container, environment, no_agent_backoff): @@ -1556,8 +1561,7 @@ async def test_autostart_mapping(server, client, clienthelper, resource_containe }, ] - await clienthelper.put_version_simple(resources, version) - + await clienthelper.put_version_simple(resources, version, wait_for_released=True) # check deploy result = await client.get_version(environment, version) assert result.code == 200 diff --git a/tests/test_server.py b/tests/test_server.py index 5daa3eb2ee..9fc8851a29 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -18,6 +18,7 @@ import asyncio import base64 +import functools import json import logging import os @@ -1845,7 +1846,7 @@ async def test_put_stale_version(client, server, environment, clienthelper, capl v1 = await clienthelper.get_version() v2 = await clienthelper.get_version() - async def put_version(version): + async def put_version(version: int) -> int: partial = (version == v1 and v1_partial) or (version == v2 and v2_partial) if partial: @@ -1871,7 +1872,7 @@ async def put_version(version): version_info={}, ) assert result.code == 200 - + return result.result["data"] else: result = await client.put_version( tid=environment, @@ -1882,17 +1883,13 @@ async def put_version(version): compiler_version=get_compiler_version(), ) assert result.code == 200 - - await put_version(v0) - - with caplog.at_level(logging.WARNING): - await put_version(v2) - await put_version(v1) - log_contains( - caplog, - "inmanta", - logging.WARNING, - f"Could not perform auto deploy on version 2 in environment {environment}, " - f"because Request conflicts with the current state of the resource: " - f"The version 2 on environment {environment} is older then the latest released version", - ) + return version + + v0 = await put_version(v0) + await retry_limited(functools.partial(clienthelper.is_released, v0), timeout=1, interval=0.05) + v2 = await put_version(v2) + await retry_limited(functools.partial(clienthelper.is_released, v2), timeout=1, interval=0.05) + v1 = await put_version(v1) + # give it time to attempt to be release + await asyncio.sleep(0.1) + assert not await clienthelper.is_released(v1) diff --git a/tests/utils.py b/tests/utils.py index 5d395204a1..47c256f22b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -337,7 +337,7 @@ async def get_version(self) -> int: assert res.code == 200 return res.result["data"] - async def put_version_simple(self, resources: dict[str, Any], version: int) -> None: + async def put_version_simple(self, resources: dict[str, Any], version: int, wait_for_released: bool = False) -> None: res = await self.client.put_version( tid=self.environment, version=version, @@ -347,6 +347,14 @@ async def put_version_simple(self, resources: dict[str, Any], version: int) -> N compiler_version=get_compiler_version(), ) assert res.code == 200, res.result + if wait_for_released: + await retry_limited(functools.partial(self.is_released, version), timeout=0.2, interval=0.05) + + async def is_released(self, version: int) -> bool: + versions = await self.client.list_versions(tid=self.environment) + assert versions.code == 200 + lookup = {v["version"]: v["released"] for v in versions.result["versions"]} + return lookup[version] def get_resource(version: int, key: str = "key1", agent: str = "agent1", value: str = "value1") -> dict[str, Any]: