diff --git a/changelogs/unreleased/ensure-agents-improvements.yml b/changelogs/unreleased/ensure-agents-improvements.yml
new file mode 100644
index 0000000000..968afff3dc
--- /dev/null
+++ b/changelogs/unreleased/ensure-agents-improvements.yml
@@ -0,0 +1,10 @@
+description: "Made various improvements to the AutostartedAgentManager._ensure_agents method"
+sections:
+  bugfix: "Fixed a race condition where autostarted agents might become unresponsive for 30s when restarted"
+issue-nr: 7612
+change-type: patch
+destination-branches:
+  - master
+  - iso7
+  - iso6
+
diff --git a/src/inmanta/agent/agent.py b/src/inmanta/agent/agent.py
index 122c0673f5..361c6ebf60 100644
--- a/src/inmanta/agent/agent.py
+++ b/src/inmanta/agent/agent.py
@@ -1168,16 +1168,18 @@ async def _init_agent_map(self) -> None:
         self.agent_map = cfg.agent_map.get()
 
     async def _init_endpoint_names(self) -> None:
-        if self.hostname is not None:
-            await self.add_end_point_name(self.hostname)
-        else:
-            # load agent names from the config file
-            agent_names = cfg.agent_names.get()
-            if agent_names is not None:
-                for name in agent_names:
-                    if "$" in name:
-                        name = name.replace("$node-name", self.node_name)
-                    await self.add_end_point_name(name)
+        assert self.agent_map is not None
+        endpoints: Iterable[str] = (
+            [self.hostname]
+            if self.hostname is not None
+            else (
+                self.agent_map.keys()
+                if cfg.use_autostart_agent_map.get()
+                else (name if "$" not in name else name.replace("$node-name", self.node_name) for name in cfg.agent_names.get())
+            )
+        )
+        for endpoint in endpoints:
+            await self.add_end_point_name(endpoint)
 
     async def stop(self) -> None:
         await super().stop()
@@ -1255,6 +1257,13 @@ async def update_agent_map(self, agent_map: dict[str, str]) -> None:
         await self._update_agent_map(agent_map)
 
     async def _update_agent_map(self, agent_map: dict[str, str]) -> None:
+        if "internal" not in agent_map:
+            LOGGER.warning(
+                "Agent received an update_agent_map() trigger without internal agent in the agent_map %s",
+                agent_map,
+            )
+            agent_map = {"internal": "local:", **agent_map}
+
         async with self._instances_lock:
             self.agent_map = agent_map
             # Add missing agents
diff --git a/src/inmanta/data/__init__.py b/src/inmanta/data/__init__.py
index e4ef262185..6cdf687afe 100644
--- a/src/inmanta/data/__init__.py
+++ b/src/inmanta/data/__init__.py
@@ -29,7 +29,7 @@ import warnings
 from abc import ABC, abstractmethod
 from collections import abc, defaultdict
-from collections.abc import Awaitable, Callable, Iterable, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Sequence, Set
 from configparser import RawConfigParser
 from contextlib import AbstractAsyncContextManager
 from itertools import chain
@@ -1290,7 +1290,7 @@ def get_connection(
         """
         if connection is not None:
             return util.nullcontext(connection)
-        # Make pypi happy
+        # Make mypy happy
         assert cls._connection_pool is not None
         return cls._connection_pool.acquire()
 
@@ -3415,10 +3415,12 @@ def get_valid_field_names(cls) -> list[str]:
         return super().get_valid_field_names() + ["process_name", "status"]
 
     @classmethod
-    async def get_statuses(cls, env_id: uuid.UUID, agent_names: set[str]) -> dict[str, Optional[AgentStatus]]:
+    async def get_statuses(
+        cls, env_id: uuid.UUID, agent_names: Set[str], *, connection: Optional[asyncpg.connection.Connection] = None
+    ) -> dict[str, Optional[AgentStatus]]:
         result: dict[str, Optional[AgentStatus]] = {}
         for agent_name in agent_names:
-            agent = await cls.get_one(environment=env_id, name=agent_name)
+            agent = await cls.get_one(environment=env_id, name=agent_name, connection=connection)
             if agent:
                 result[agent_name] = agent.get_status()
             else:
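Note: the `_update_agent_map` guard added above restores, on the agent side, the invariant that `_make_agent_config` used to enforce on the server side: the internal agent always needs a session, otherwise the agent-map update trigger itself has no endpoint to reach. Reduced to a standalone sketch (the `with_internal_agent` helper name is hypothetical, not part of this PR):

```python
import logging

LOGGER = logging.getLogger(__name__)


def with_internal_agent(agent_map: dict[str, str]) -> dict[str, str]:
    # The internal agent must always have a session, otherwise the
    # agent-map update trigger has no endpoint left to reach.
    if "internal" not in agent_map:
        LOGGER.warning("Agent map %s is missing the internal agent; adding it", agent_map)
        return {"internal": "local:", **agent_map}
    return agent_map
```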
diff --git a/src/inmanta/server/agentmanager.py b/src/inmanta/server/agentmanager.py
index 7915a93bdf..57b518d822 100644
--- a/src/inmanta/server/agentmanager.py
+++ b/src/inmanta/server/agentmanager.py
@@ -23,7 +23,7 @@ import time
 import uuid
 from asyncio import queues, subprocess
-from collections.abc import Iterable, Sequence
+from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence, Set
 from datetime import datetime
 from enum import Enum
 from typing import Any, Optional, Union, cast
@@ -632,20 +632,15 @@ def get_agent_client(self, tid: uuid.UUID, endpoint: str, live_agent_only: bool
         else:
             return None
 
-    async def are_agents_active(self, tid: uuid.UUID, endpoints: list[str]) -> bool:
+    async def expire_sessions_for_agents(self, env_id: uuid.UUID, endpoints: Set[str]) -> None:
         """
-        Return true iff all the given agents are in the up or the paused state.
+        Expire all sessions for any of the requested agent endpoints.
         """
-        return all(active for (_, active) in await self.get_agent_active_status(tid, endpoints))
-
-    async def get_agent_active_status(self, tid: uuid.UUID, endpoints: list[str]) -> list[tuple[str, bool]]:
-        """
-        Return a list of tuples where the first element of the tuple contains the name of an endpoint
-        and the second a boolean indicating where there is an active (up or paused) agent for that endpoint.
-        """
-        all_sids_for_env = [sid for (sid, session) in self.sessions.items() if session.tid == tid]
-        all_active_endpoints_for_env = {ep for sid in all_sids_for_env for ep in self.endpoints_for_sid[sid]}
-        return [(ep, ep in all_active_endpoints_for_env) for ep in endpoints]
+        async with self.session_lock:
+            sessions_to_expire: Iterator[protocol.Session] = (
+                session for session in self.sessions.values() if endpoints & session.endpoint_names and session.tid == env_id
+            )
+            await asyncio.gather(*(s.expire_and_abort(timeout=0) for s in sessions_to_expire))
 
     async def expire_all_sessions_for_environment(self, env_id: uuid.UUID) -> None:
         async with self.session_lock:
@@ -969,7 +964,7 @@ async def restart_agents(self, env: data.Environment) -> None:
         LOGGER.debug("Restarting agents in environment %s", env.id)
         agents = await data.Agent.get_list(environment=env.id)
         agent_list = [a.name for a in agents]
-        await self._ensure_agents(env, agent_list, True)
+        await self._ensure_agents(env, agent_list, restart=True)
 
     async def stop_agents(self, env: data.Environment) -> None:
         """
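Note: `expire_sessions_for_agents` above selects the matching sessions under the session lock and expires them concurrently. The same fan-out pattern in isolation (a sketch with a stand-in `FakeSession`; the real `protocol.Session.expire_and_abort` is only assumed to have the signature used in the diff):

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeSession:
    tid: str
    endpoint_names: set[str] = field(default_factory=set)
    expired: bool = False

    async def expire_and_abort(self, timeout: float) -> None:
        self.expired = True


async def expire_matching(sessions: list[FakeSession], env_id: str, endpoints: set[str]) -> None:
    # Expire every session that serves one of the requested endpoints in this environment.
    to_expire = (s for s in sessions if s.tid == env_id and endpoints & s.endpoint_names)
    await asyncio.gather(*(s.expire_and_abort(timeout=0) for s in to_expire))
```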
@@ -986,6 +981,51 @@ async def stop_agents(self, env: data.Environment) -> None:
         LOGGER.debug("Expiring all sessions for %s", env.id)
         await self._agent_manager.expire_all_sessions_for_environment(env.id)
 
+    async def _stop_autostarted_agents(
+        self,
+        env: data.Environment,
+        *,
+        connection: Optional[asyncpg.connection.Connection] = None,
+    ) -> None:
+        """
+        Stop the autostarted agent process for this environment and expire all its sessions.
+        Does not expire non-autostarted agents' sessions.
+
+        Must be called under the agent lock.
+        """
+        LOGGER.debug("Stopping all autostarted agents for env %s", env.id)
+        if env.id in self._agent_procs:
+            subproc = self._agent_procs[env.id]
+            self._stop_process(subproc)
+            await self._wait_for_proc_bounded([subproc])
+            del self._agent_procs[env.id]
+
+        # fetch the agent map after stopping the process to prevent races with agent map update notifying the process
+        agent_map: Mapping[str, str] = cast(
+            Mapping[str, str], await env.get(data.AUTOSTART_AGENT_MAP, connection=connection)
+        )  # we know the type of this map
+
+        LOGGER.debug("Expiring sessions for autostarted agents %s", sorted(agent_map.keys()))
+        await self._agent_manager.expire_sessions_for_agents(env.id, agent_map.keys())
+
+    def _get_state_dir_for_agent_in_env(self, env_id: uuid.UUID) -> str:
+        """
+        Return the state dir to be used by the auto-started agent in the given environment.
+        """
+        state_dir: str = inmanta.config.state_dir.get()
+        return os.path.join(state_dir, str(env_id))
+
+    def _remove_venv_for_agent_in_env(self, env_id: uuid.UUID) -> None:
+        """
+        Remove the venv for the auto-started agent in the given environment.
+        """
+        agent_state_dir: str = self._get_state_dir_for_agent_in_env(env_id)
+        venv_dir: str = os.path.join(agent_state_dir, "agent", "env")
+        try:
+            shutil.rmtree(venv_dir)
+        except FileNotFoundError:
+            pass
+
     def _stop_process(self, process: subprocess.Process) -> None:
         try:
             process.terminate()
@@ -1006,62 +1046,86 @@ async def _terminate_agents(self) -> None:
     async def _ensure_agents(
         self,
         env: data.Environment,
-        agents: Sequence[str],
-        restart: bool = False,
+        agents: Collection[str],
         *,
+        restart: bool = False,
         connection: Optional[asyncpg.connection.Connection] = None,
     ) -> bool:
         """
         Ensure that all agents defined in the current environment (model) and that should be autostarted, are started.
 
         :param env: The environment to start the agents for
-        :param agents: A list of agent names that possibly should be started in this environment.
+        :param agents: A list of agent names that should be running in this environment. Waits for the agents that are both in
+            this list and in the agent map to be active before returning.
         :param restart: Restart all agents even if the list of agents is up to date.
+        :param connection: The database connection to use. Must not be in a transaction context.
+
+        :return: True iff a new agent process was started.
         """
         if self._stopping:
             raise ShutdownInProgress()
 
-        agent_map: dict[str, str] = cast(
-            dict[str, str], await env.get(data.AUTOSTART_AGENT_MAP, connection=connection)
-        )  # we know the type of this map
+        if connection is not None and connection.is_in_transaction():
+            # Should not be called in a transaction context because it has (immediate) side effects outside of the database
+            # that are tied to the database state. Several inconsistency issues could occur if this runs in a transaction
+            # context:
+            #   - side effects based on uncommitted reads (may even need to be rolled back)
+            #   - race condition with similar side effect flows due to stale reads (e.g. other flow pauses agent and kills
+            #     process, this one brings it back because it reads the agent as unpaused)
+            raise Exception("_ensure_agents should not be called in a transaction context")
+
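Note: the guard above relies only on asyncpg's `Connection.is_in_transaction()`. As a standalone illustration of the pattern (a sketch, not the server code; the helper name is hypothetical):

```python
from typing import Optional

import asyncpg


def assert_not_in_transaction(connection: Optional[asyncpg.connection.Connection]) -> None:
    # Side effects that cannot be rolled back (starting or killing processes)
    # must not be driven by reads made inside an open transaction, which may
    # be uncommitted or stale by the time the side effect happens.
    if connection is not None and connection.is_in_transaction():
        raise Exception("refusing to run an operation with external side effects inside a transaction")
```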
+        async with data.Agent.get_connection(connection) as connection:
+            agent_map: Mapping[str, str] = cast(
+                Mapping[str, str], await env.get(data.AUTOSTART_AGENT_MAP, connection=connection)
+            )  # we know the type of this map
+
+            autostart_agents: Set[str] = set(agents) & agent_map.keys()
+            if len(autostart_agents) == 0:
+                return False
 
-        agents = [agent for agent in agents if agent in agent_map]
-        needsstart = restart
-        if len(agents) == 0:
-            return False
+            async with self.agent_lock:
+                # silently ignore requests if this environment is halted
+                refreshed_env: Optional[data.Environment] = await data.Environment.get_by_id(env.id, connection=connection)
+                if refreshed_env is None:
+                    raise Exception("Can't ensure agent: environment %s does not exist" % env.id)
+                env = refreshed_env
+                if env.halted:
+                    return False
+
+                start_new_process: bool
+                if env.id not in self._agent_procs or self._agent_procs[env.id].returncode is not None:
+                    # Start new process if none is currently running for this environment.
+                    # Otherwise trust that it tracks any changes to the agent map.
+                    LOGGER.info("%s matches agents managed by server, ensuring they are started.", autostart_agents)
+                    start_new_process = True
+                elif restart:
+                    LOGGER.info(
+                        "%s matches agents managed by server, forcing restart: stopping process with PID %s.",
+                        autostart_agents,
+                        self._agent_procs[env.id].pid,
+                    )
+                    await self._stop_autostarted_agents(env, connection=connection)
+                    start_new_process = True
+                else:
+                    start_new_process = False
 
-        async def is_start_agent_required() -> bool:
-            if needsstart:
-                return True
-            return not await self._agent_manager.are_agents_active(env.id, agents)
+                if start_new_process:
+                    self._agent_procs[env.id] = await self.__do_start_agent(env, connection=connection)
 
-        async with self.agent_lock:
-            # silently ignore requests if this environment is halted
-            refreshed_env: Optional[data.Environment] = await data.Environment.get_by_id(env.id, connection=connection)
-            if refreshed_env is None:
-                raise Exception("Can't ensure agent: environment %s does not exist" % env.id)
-            env = refreshed_env
-            if env.halted:
-                return False
-
-            if await is_start_agent_required():
-                LOGGER.info("%s matches agents managed by server, ensuring it is started.", agents)
-                res = await self.__do_start_agent(agents, env, connection=connection)
-                return res
-            return False
+                # Wait for all agents to start
+                try:
+                    await self._wait_for_agents(env, autostart_agents, connection=connection)
+                except asyncio.TimeoutError:
+                    LOGGER.warning("Not all agent instances started successfully")
+                return start_new_process
 
     async def __do_start_agent(
-        self, agents: list[str], env: data.Environment, *, connection: Optional[asyncpg.connection.Connection] = None
-    ) -> bool:
+        self, env: data.Environment, *, connection: Optional[asyncpg.connection.Connection] = None
+    ) -> subprocess.Process:
         """
-        Start an agent process for the given agents in the given environment
-
-        Note: Always call under agent_lock
+        Start an autostarted agent process for the given environment. Should only be called if none is running yet.
""" - agent_map: dict[str, str] - agent_map = cast(dict[str, str], await env.get(data.AUTOSTART_AGENT_MAP, connection=connection)) - config: str - config = await self._make_agent_config(env, agents, agent_map, connection=connection) + config: str = await self._make_agent_config(env, connection=connection) config_dir = os.path.join(self._server_storage["agents"], str(env.id)) if not os.path.exists(config_dir): @@ -1076,94 +1140,29 @@ async def __do_start_agent( agent_log = os.path.join(self._server_storage["logs"], "agent-%s.log" % env.id) - proc: Optional[subprocess.Process] = None - try: - proc = await self._fork_inmanta( - [ - "--log-file-level", - "DEBUG", - "--timed-logs", - "--config", - config_path, - "--config-dir", - Config._config_dir if Config._config_dir is not None else "", - "--log-file", - agent_log, - "agent", - ], - out, - err, - ) - - if env.id in self._agent_procs and self._agent_procs[env.id] is not None: - # If the return code is not None the process is already terminated - if self._agent_procs[env.id].returncode is None: - LOGGER.debug("Terminating old agent with PID %s", self._agent_procs[env.id].pid) - self._agent_procs[env.id].terminate() - await self._wait_for_proc_bounded([self._agent_procs[env.id]]) - self._agent_procs[env.id] = proc - except Exception as e: - # Prevent dangling processes - if proc is not None and proc.returncode is None: - proc.kill() - raise e - - async def _wait_until_agent_instances_are_active() -> None: - """ - Wait until all AgentInstances for the endpoints `agents` are active. - A TimeoutError is raised when not all AgentInstances are active and no new AgentInstance - became active in the last 5 seconds. - """ - agent_statuses: dict[str, Optional[AgentStatus]] = await data.Agent.get_statuses(env.id, set(agents)) - # Only wait for agents that are not paused - expected_agents_in_up_state: set[str] = { - agent_name - for agent_name, status in agent_statuses.items() - if status is not None and status is not AgentStatus.paused - } - actual_agents_in_up_state: set[str] = set() - started = int(time.time()) - last_new_agent_seen = started - last_log = started - - while len(expected_agents_in_up_state) != len(actual_agents_in_up_state): - await asyncio.sleep(0.1) - now = int(time.time()) - if now - last_new_agent_seen > AUTO_STARTED_AGENT_WAIT: - raise asyncio.TimeoutError() - if now - last_log > AUTO_STARTED_AGENT_WAIT_LOG_INTERVAL: - last_log = now - LOGGER.debug( - "Waiting for agent with PID %s, waited %d seconds, %d/%d instances up", - proc.pid, - now - started, - len(actual_agents_in_up_state), - len(expected_agents_in_up_state), - ) - new_actual_agents_in_up_state = { - agent_name - for agent_name in expected_agents_in_up_state - if (env.id, agent_name) in self._agent_manager.tid_endpoint_to_session - } - if len(new_actual_agents_in_up_state) > len(actual_agents_in_up_state): - # Reset timeout timer because a new instance became active - last_new_agent_seen = now - actual_agents_in_up_state = new_actual_agents_in_up_state + proc: subprocess.Process = await self._fork_inmanta( + [ + "--log-file-level", + "DEBUG", + "--timed-logs", + "--config", + config_path, + "--config-dir", + Config._config_dir if Config._config_dir is not None else "", + "--log-file", + agent_log, + "agent", + ], + out, + err, + ) LOGGER.debug("Started new agent with PID %s", proc.pid) - # Wait for all agents to start - try: - await _wait_until_agent_instances_are_active() - LOGGER.debug("Agent with PID %s is up", proc.pid) - except asyncio.TimeoutError: - 
-            LOGGER.warning("Timeout: agent with PID %s took too long to start", proc.pid)
-        return True
+        return proc
 
     async def _make_agent_config(
         self,
         env: data.Environment,
-        agent_names: list[str],
-        agent_map: dict[str, str],
         *,
         connection: Optional[asyncpg.connection.Connection],
     ) -> str:
@@ -1171,8 +1170,6 @@ async def _make_agent_config(
         Generate the config file for the process that hosts the autostarted agents
 
         :param env: The environment for which to autostart agents
-        :param agent_names: The names of the agents
-        :param agent_map: The agent mapping to use
         :return: A string that contains the config file content.
         """
         environment_id = str(env.id)
@@ -1186,16 +1183,11 @@ async def _make_agent_config(
         agent_repair_splay: int = cast(int, await env.get(data.AUTOSTART_AGENT_REPAIR_SPLAY_TIME, connection=connection))
         agent_repair_interval: str = cast(str, await env.get(data.AUTOSTART_AGENT_REPAIR_INTERVAL, connection=connection))
 
-        # The internal agent always needs to have a session. Otherwise the agentmap update trigger doesn't work
-        if "internal" not in agent_names:
-            agent_names.append("internal")
-
         # generate config file
         config = """[config]
 state-dir=%(statedir)s
 
 use_autostart_agent_map=true
-agent-names = %(agents)s
 environment=%(env_id)s
 
 agent-deploy-splay-time=%(agent_deploy_splay)d
@@ -1209,7 +1201,6 @@ async def _make_agent_config(
 port=%(port)s
 host=%(serveradress)s
 """ % {
-            "agents": ",".join(agent_names),
             "env_id": environment_id,
             "port": port,
             "statedir": privatestatedir,
@@ -1273,6 +1264,75 @@ async def _fork_inmanta(
             if errhandle is not None:
                 errhandle.close()
 
+    async def _wait_for_agents(
+        self, env: data.Environment, agents: Set[str], *, connection: Optional[asyncpg.connection.Connection] = None
+    ) -> None:
+        """
+        Wait until all requested autostarted agent instances are active, e.g. after starting a new agent process.
+
+        Must be called under the agent lock.
+
+        :param env: The environment for which to wait for agents.
+        :param agents: Autostarted agent endpoints to wait for.
+
+        :raises TimeoutError: When not all agent instances are active and no new agent instance became active in the last
+            5 seconds.
+        """
+        agent_statuses: dict[str, Optional[AgentStatus]] = await data.Agent.get_statuses(env.id, agents, connection=connection)
+        # Only wait for agents that are not paused
+        expected_agents_in_up_state: Set[str] = {
+            agent_name
+            for agent_name, status in agent_statuses.items()
+            if status is not None and status is not AgentStatus.paused
+        }
+
+        assert env.id in self._agent_procs
+        proc = self._agent_procs[env.id]
+
+        actual_agents_in_up_state: set[str] = set()
+        started = int(time.time())
+        last_new_agent_seen = started
+        last_log = started
+
+        while len(expected_agents_in_up_state) != len(actual_agents_in_up_state):
+            await asyncio.sleep(0.1)
+            now = int(time.time())
+            if now - last_new_agent_seen > AUTO_STARTED_AGENT_WAIT:
+                LOGGER.warning(
+                    "Timeout: agent with PID %s took too long to start: still waiting for agent instances %s",
+                    proc.pid,
+                    ",".join(sorted(expected_agents_in_up_state - actual_agents_in_up_state)),
+                )
+                raise asyncio.TimeoutError()
+            if now - last_log > AUTO_STARTED_AGENT_WAIT_LOG_INTERVAL:
+                last_log = now
+                LOGGER.debug(
+                    "Waiting for agent with PID %s, waited %d seconds, %d/%d instances up",
+                    proc.pid,
+                    now - started,
+                    len(actual_agents_in_up_state),
+                    len(expected_agents_in_up_state),
+                )
+            new_actual_agents_in_up_state = {
+                agent_name
+                for agent_name in expected_agents_in_up_state
+                if (
+                    (session := self._agent_manager.tid_endpoint_to_session.get((env.id, agent_name), None)) is not None
+                    # make sure to check for expiry because sessions are unregistered from the agent manager asynchronously
+                    and not session.expired
+                )
+            }
+            if len(new_actual_agents_in_up_state) > len(actual_agents_in_up_state):
+                # Reset timeout timer because a new instance became active
+                last_new_agent_seen = now
+                actual_agents_in_up_state = new_actual_agents_in_up_state
+
+        LOGGER.debug(
+            "Agent process with PID %s is up for agent instances %s",
+            proc.pid,
+            ",".join(sorted(expected_agents_in_up_state)),
+        )
+
     async def notify_agent_about_agent_map_update(self, env: data.Environment) -> None:
         agent_client = self._agent_manager.get_agent_client(tid=env.id, endpoint="internal", live_agent_only=False)
         if agent_client:
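Note: `_wait_for_agents` implements a progress-based timeout: the clock resets whenever a new instance comes up, so a slow-but-progressing start is tolerated while a stalled one fails after `AUTO_STARTED_AGENT_WAIT` (5) seconds. A minimal standalone sketch of the same pattern (hypothetical `is_up` callback; not the server code):

```python
import asyncio
import time
from collections.abc import Callable

PROGRESS_TIMEOUT = 5.0  # corresponds to AUTO_STARTED_AGENT_WAIT in the server


async def wait_until_up(expected: set[str], is_up: Callable[[str], bool]) -> None:
    """Wait until is_up(name) holds for every name in expected.

    The timeout only fires after PROGRESS_TIMEOUT seconds without any
    progress; it resets whenever at least one new instance comes up.
    """
    up: set[str] = set()
    last_progress = time.monotonic()
    while up != expected:
        await asyncio.sleep(0.1)
        now_up = {name for name in expected if is_up(name)}
        if len(now_up) > len(up):
            last_progress = time.monotonic()  # progress: reset the timer
            up = now_up
        elif time.monotonic() - last_progress > PROGRESS_TIMEOUT:
            raise asyncio.TimeoutError(f"still waiting for {sorted(expected - up)}")
```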
diff --git a/tests/agent_server/test_server_agent.py b/tests/agent_server/test_server_agent.py
index 162658a0df..9254e58e6c 100644
--- a/tests/agent_server/test_server_agent.py
+++ b/tests/agent_server/test_server_agent.py
@@ -21,6 +21,7 @@ import logging
 import time
 import uuid
+from collections.abc import Mapping
 from functools import partial
 from itertools import groupby
 from logging import DEBUG
@@ -527,9 +528,7 @@ async def test_env_setting_wiring_to_autostarted_agent(
     env = await data.Environment.get_by_id(env_id)
     autostarted_agent_manager = server.get_slice(SLICE_AUTOSTARTED_AGENT_MANAGER)
-    config = await autostarted_agent_manager._make_agent_config(
-        env, agent_names=[], agent_map={"internal": ""}, connection=None
-    )
+    config = await autostarted_agent_manager._make_agent_config(env, connection=None)
 
     assert f"agent-deploy-interval={interval}" in config
     assert f"agent-repair-interval={interval}" in config
@@ -1648,6 +1647,53 @@ async def assert_session_state(expected_agent_states: dict[str, AgentStatus], ex
     assert len(new_agent_processes) == 0, new_agent_processes
 
 
+@pytest.mark.parametrize("autostarted", (True, False))
+async def test_autostart_mapping_overrides_config(server, client, environment, async_finalizer, caplog, autostarted: bool):
+    """
+    Verify that the use_autostart_agent_map setting takes precedence over agents configured in the config file.
+    When the option is set, the server's agent map should be the authority for which agents to manage.
+    """
+    # configure the agent as an autostarted agent or not
+    agent_config.use_autostart_agent_map.set(str(autostarted).lower())
+    # also configure an explicit agent name in the config file, which should be ignored when autostarted
+    configured_agent: str = "configured_agent"
+    agent_config.agent_names.set(configured_agent)
+
+    env_uuid = uuid.UUID(environment)
+    agent_manager = server.get_slice(SLICE_AGENT_MANAGER)
+
+    # configure the server's autostarted agent map
+    autostarted_agent: str = "autostarted_agent"
+    result = await client.set_setting(
+        env_uuid,
+        data.AUTOSTART_AGENT_MAP,
+        {"internal": "localhost", autostarted_agent: "localhost"},
+    )
+    assert result.code == 200
+
+    # Start agent
+    a = agent.Agent(environment=env_uuid, code_loader=False)
+    await a.start()
+    async_finalizer(a.stop)
+
+    # Wait until agents are up
+    await retry_limited(lambda: len(agent_manager.tid_endpoint_to_session) == (2 if autostarted else 1), timeout=2)
+
+    endpoint_sessions: Mapping[str, UUID] = {
+        key[1]: session.id for key, session in agent_manager.tid_endpoint_to_session.items()
+    }
+    assert endpoint_sessions == (
+        {
+            "internal": a.sessionid,
+            autostarted_agent: a.sessionid,
+        }
+        if autostarted
+        else {
+            configured_agent: a.sessionid,
+        }
+    )
+
+
 async def test_autostart_mapping_update_uri(server, client, environment, async_finalizer, caplog):
     caplog.set_level(logging.INFO)
     agent_config.use_autostart_agent_map.set("true")
diff --git a/tests/test_agent_manager.py b/tests/test_agent_manager.py
index a87f7fe6e3..940a2ba058 100644
--- a/tests/test_agent_manager.py
+++ b/tests/test_agent_manager.py
@@ -1207,34 +1207,6 @@ async def _dummy_fork_inmanta(
     assert exception_message in str(excinfo.value)
 
 
-async def test_are_agents_active(server, client, environment, agent_factory) -> None:
-    """
-    Ensure that the `AgentManager.are_agents_active()` method returns True when an agent
-    is in the up or the paused state.
-    """
-    agentmanager = server.get_slice(SLICE_AGENT_MANAGER)
-    agent_name = "agent1"
-    env_id = UUID(environment)
-    env = await data.Environment.get_by_id(env_id)
-
-    # The agent is not started yet -> it should not be active
-    assert not await agentmanager.are_agents_active(tid=env_id, endpoints=[agent_name])
-
-    # Start agent
-    await agentmanager.ensure_agent_registered(env, agent_name)
-    await agent_factory(environment=environment, agent_map={agent_name: ""}, agent_names=[agent_name])
-
-    # Verify agent is active
-    await retry_limited(agentmanager.are_agents_active, tid=env_id, endpoints=[agent_name], timeout=10)
-
-    # Pause agent
-    result = await client.agent_action(tid=env_id, name=agent_name, action=AgentAction.pause.value)
-    assert result.code == 200, result.result
-
-    # Ensure the agent is still active
-    await retry_limited(agentmanager.are_agents_active, tid=env_id, endpoints=[agent_name], timeout=10)
-
-
 async def test_dont_start_paused_agent(server, client, environment, caplog) -> None:
     """
     Ensure that the AutostartedAgentManager doesn't try to start an agent that is paused (inmanta/inmanta-core#4398).
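Note: the deleted test covered `are_agents_active()`, which no longer exists; liveness is now driven by explicit session expiry. A sketch of how the replacement API could be exercised from a test (hypothetical helper; `SLICE_AGENT_MANAGER` and `retry_limited` as imported elsewhere in this file):

```python
import uuid


async def expire_and_verify(server, environment: str, endpoints: set[str]) -> None:
    # hypothetical test helper, using the same fixtures as the tests in this file
    agent_manager = server.get_slice(SLICE_AGENT_MANAGER)
    env_id = uuid.UUID(environment)
    await agent_manager.expire_sessions_for_agents(env_id, endpoints)
    # sessions are unregistered asynchronously, so poll instead of asserting immediately
    await retry_limited(
        lambda: all((env_id, ep) not in agent_manager.tid_endpoint_to_session for ep in endpoints),
        timeout=10,
    )
```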
@@ -1269,6 +1241,9 @@ async def test_dont_start_paused_agent(server, client, environment, caplog) -> N
     result = await client.agent_action(tid=env_id, name=agent_name, action=AgentAction.pause.value)
     assert result.code == 200, result.result
 
+    # Pausing an agent should have no direct effect on the autostarted agent manager
+    assert len(autostarted_agent_manager._agent_procs) == 1
+
     # Execute _ensure_agents() again and verify that no restart is triggered
     caplog.clear()
     await autostarted_agent_manager._ensure_agents(env=env, agents=[agent_name])
@@ -1278,6 +1253,99 @@ async def test_dont_start_paused_agent(server, client, environment, caplog) -> N
     assert "took too long to start" not in caplog.text
 
 
+async def test_wait_for_agent_map_update(server, client, environment) -> None:
+    """
+    Verify that when _ensure_agents is called with an agent that is still starting, we wait for it rather than killing
+    the current process and starting a new one.
+    """
+    env_id: UUID = UUID(environment)
+    agent1: str = "agent1"
+    agent2: str = "agent2"
+    agent_manager = server.get_slice(SLICE_AGENT_MANAGER)
+    autostarted_agent_manager = server.get_slice(SLICE_AUTOSTARTED_AGENT_MANAGER)
+
+    # Register agents in model
+    env = await data.Environment.get_by_id(env_id)
+    assert env is not None
+    await agent_manager.ensure_agent_registered(env=env, nodename=agent1)
+    await agent_manager.ensure_agent_registered(env=env, nodename=agent2)
+
+    # Add agent1 to AUTOSTART_AGENT_MAP
+    result = await client.set_setting(tid=environment, id=data.AUTOSTART_AGENT_MAP, value={"internal": "", agent1: ""})
+    assert result.code == 200, result.result
+
+    env: Optional[data.Environment]
+
+    # Start agent1
+    assert (env_id, agent1) not in agent_manager.tid_endpoint_to_session
+    env = await data.Environment.get_by_id(env_id)
+    assert env is not None
+    started: bool = await autostarted_agent_manager._ensure_agents(env=env, agents=[agent1])
+    assert started
+    assert (env_id, agent1) in agent_manager.tid_endpoint_to_session
+    assert len(autostarted_agent_manager._agent_procs) == 1
+
+    # Add agent2 to AUTOSTART_AGENT_MAP
+    result = await client.set_setting(
+        tid=environment,
+        id=data.AUTOSTART_AGENT_MAP,
+        value={"internal": "", agent1: "", agent2: ""},
+    )
+    assert result.code == 200, result.result
+
+    # Call _ensure_agents with agent2 very shortly after adding it to the agent map (before the new instance has connected)
+    env = await data.Environment.get_by_id(env_id)
+    assert env is not None
+    started = await autostarted_agent_manager._ensure_agents(env=env, agents=[agent1, agent2])
+    assert (env_id, agent2) in agent_manager.tid_endpoint_to_session
+    # Verify that we did not start a new process
+    assert not started
+
+
+async def test_expire_sessions_on_restart(server, client, environment) -> None:
+    """
+    Verify that when _ensure_agents is called for an explicit restart, we properly expire the old session instead of
+    letting it time out.
+    """
+    env_id: UUID = UUID(environment)
+    agent_name: str = "agent1"
+    agent_manager = server.get_slice(SLICE_AGENT_MANAGER)
+    autostarted_agent_manager = server.get_slice(SLICE_AUTOSTARTED_AGENT_MANAGER)
+
+    # Register agents in model
+    env = await data.Environment.get_by_id(env_id)
+    assert env is not None
+    await agent_manager.ensure_agent_registered(env=env, nodename=agent_name)
+
+    # Add agent1 to AUTOSTART_AGENT_MAP
+    result = await client.set_setting(tid=environment, id=data.AUTOSTART_AGENT_MAP, value={"internal": "", agent_name: ""})
+    assert result.code == 200, result.result
+
+    env: Optional[data.Environment]
+
+    # Start agent1
+    assert (env_id, agent_name) not in agent_manager.tid_endpoint_to_session
+    env = await data.Environment.get_by_id(env_id)
+    assert env is not None
+    started: bool = await autostarted_agent_manager._ensure_agents(env=env, agents=[agent_name])
+    assert started
+    current_session: Optional[protocol.Session] = agent_manager.tid_endpoint_to_session.get((env_id, agent_name), None)
+    assert current_session is not None
+    current_process: object = autostarted_agent_manager._agent_procs.get(env_id, None)
+    assert current_process is not None
+
+    # Restart agent1
+    started = await autostarted_agent_manager._ensure_agents(env=env, agents=[agent_name], restart=True)
+    assert started
+    new_session: Optional[protocol.Session] = agent_manager.tid_endpoint_to_session.get((env_id, agent_name), None)
+    assert new_session is not None
+    new_process: object = autostarted_agent_manager._agent_procs.get(env_id, None)
+    assert new_process is not None
+    # assertions: a new process should have been started and the old session expired; _ensure_agents then waits
+    # until the new session becomes active
+    assert new_process is not current_process
+    assert current_session.id != new_session.id
+
+
 async def test_auto_started_agent_log_in_debug_mode(server, environment):
     """
     Test the logging of an autostarted agent