diff --git a/monkey/infection_monkey/island_api_client/http_client.py b/monkey/infection_monkey/island_api_client/http_client.py index d55c4ba5bf0..113e8bbbf53 100644 --- a/monkey/infection_monkey/island_api_client/http_client.py +++ b/monkey/infection_monkey/island_api_client/http_client.py @@ -1,6 +1,7 @@ import functools import logging from enum import Enum, auto +from http import HTTPStatus from typing import Any, Dict, Optional import requests @@ -8,9 +9,10 @@ from urllib3 import Retry from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT -from common.types import JSONSerializable, SocketAddress +from common.types import JSONSerializable from .island_api_client_errors import ( + IslandAPIAuthenticationError, IslandAPIConnectionError, IslandAPIError, IslandAPIRequestError, @@ -40,12 +42,17 @@ def decorated(*args, **kwargs): except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err: raise IslandAPIConnectionError(err) except requests.exceptions.HTTPError as err: + if err.response.status_code in [ + HTTPStatus.UNAUTHORIZED.value, + HTTPStatus.FORBIDDEN.value, + ]: + raise IslandAPIAuthenticationError(err) if 400 <= err.response.status_code < 500: raise IslandAPIRequestError(err) - elif 500 <= err.response.status_code < 600: + if 500 <= err.response.status_code < 600: raise IslandAPIRequestFailedError(err) - else: - raise IslandAPIError(err) + + raise IslandAPIError(err) except TimeoutError as err: raise IslandAPITimeoutError(err) except Exception as err: @@ -59,27 +66,22 @@ def __init__(self, retries=RETRIES): self._session = requests.Session() retry_config = Retry(retries) self._session.mount("https://", HTTPAdapter(max_retries=retry_config)) - self._api_url: Optional[str] = None + self._server_url: Optional[str] = None + self.additional_headers: Optional[Dict[str, Any]] = None - @handle_island_errors - def connect(self, island_server: SocketAddress): - try: - self._api_url = f"https://{island_server}/api" - # Don't use retries here, because we expect to not be able to connect. - response = requests.get( # noqa: DUO123 - f"{self._api_url}?action=is-up", - verify=False, - timeout=MEDIUM_REQUEST_TIMEOUT, - ) - response.raise_for_status() - except Exception as err: - logger.debug(f"Connection to {island_server} failed: {err}") - self._api_url = None - raise err + @property + def server_url(self): + return self._server_url + + @server_url.setter + def server_url(self, server_url: Optional[str]): + if server_url is not None and not server_url.startswith("https://"): + raise ValueError("Only HTTPS protocol is supported by HTTPClient") + self._server_url = server_url def get( self, - endpoint: str, + endpoint: str = "", params: Optional[Dict[str, Any]] = None, timeout=MEDIUM_REQUEST_TIMEOUT, *args, @@ -91,7 +93,7 @@ def get( def post( self, - endpoint: str, + endpoint: str = "", data: Optional[JSONSerializable] = None, timeout=MEDIUM_REQUEST_TIMEOUT, *args, @@ -103,7 +105,7 @@ def post( def put( self, - endpoint: str, + endpoint: str = "", data: Optional[JSONSerializable] = None, timeout=MEDIUM_REQUEST_TIMEOUT, *args, @@ -122,17 +124,15 @@ def _send_request( *args, **kwargs, ) -> requests.Response: - if self._api_url is None: - raise RuntimeError( - "HTTP client is not connected to the Island server," - "establish a connection with 'connect()' before " - "attempting to send any requests" - ) - url = f"{self._api_url}/{endpoint}".strip("/") + if self._server_url is None: + raise RuntimeError("HTTP client does not have a server URL set") + url = f"{self._server_url.strip('/')}/{endpoint.strip('/')}".strip("/") logger.debug(f"{request_type.name} {url}, timeout={timeout}") method = getattr(self._session, str.lower(request_type.name)) - response = method(url, *args, timeout=timeout, verify=False, **kwargs) + response = method( + url, *args, timeout=timeout, verify=False, headers=self.additional_headers, **kwargs + ) response.raise_for_status() return response diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client.py b/monkey/infection_monkey/island_api_client/http_island_api_client.py index eebb6e40546..71b9a807cb0 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client.py @@ -44,34 +44,52 @@ class HTTPIslandAPIClient(IIslandAPIClient): A client for the Island's HTTP API """ + TOKEN_HEADER_KEY = "Authentication-Token" + def __init__( - self, agent_event_serializer_registry: AgentEventSerializerRegistry, http_client: HTTPClient + self, + agent_event_serializer_registry: AgentEventSerializerRegistry, + http_client: HTTPClient, + otp: str, ): self._agent_event_serializer_registry = agent_event_serializer_registry - self.http_client = http_client + self._http_client = http_client + self._otp = otp def connect( self, island_server: SocketAddress, ): - self.http_client.connect(island_server) + try: + self._http_client.server_url = f"https://{island_server}/api/" + self._http_client.get(params={"action": "is-up"}) + except Exception as err: + self._http_client.server_url = None + raise err + + auth_token = self._get_authentication_token() + self._http_client.additional_headers = {HTTPIslandAPIClient.TOKEN_HEADER_KEY: auth_token} + + def _get_authentication_token(self) -> str: + response = self._http_client.post("/agent-otp-login", {"otp": self._otp}) + return response.json()["token"] def get_agent_binary(self, operating_system: OperatingSystem) -> bytes: os_name = operating_system.value - response = self.http_client.get(f"agent-binaries/{os_name}") + response = self._http_client.get(f"/agent-binaries/{os_name}") return response.content @handle_response_parsing_errors def get_otp(self) -> str: - response = self.http_client.get("agent-otp") + response = self._http_client.get("/agent-otp") return response.json()["otp"] @handle_response_parsing_errors def get_agent_plugin( self, operating_system: OperatingSystem, plugin_type: AgentPluginType, plugin_name: str ) -> AgentPlugin: - response = self.http_client.get( - f"agent-plugins/{operating_system.value}/{plugin_type.value}/{plugin_name}" + response = self._http_client.get( + f"/agent-plugins/{operating_system.value}/{plugin_type.value}/{plugin_name}" ) return AgentPlugin(**response.json()) @@ -80,26 +98,32 @@ def get_agent_plugin( def get_agent_plugin_manifest( self, plugin_type: AgentPluginType, plugin_name: str ) -> AgentPluginManifest: - response = self.http_client.get(f"agent-plugins/{plugin_type.value}/{plugin_name}/manifest") + response = self._http_client.get( + f"/agent-plugins/{plugin_type.value}/{plugin_name}/manifest" + ) return AgentPluginManifest(**response.json()) @handle_response_parsing_errors def get_agent_signals(self, agent_id: str) -> AgentSignals: - response = self.http_client.get(f"agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT) + response = self._http_client.get( + f"/agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT + ) return AgentSignals(**response.json()) @handle_response_parsing_errors def get_agent_configuration_schema(self) -> Dict[str, Any]: - response = self.http_client.get("agent-configuration-schema", timeout=SHORT_REQUEST_TIMEOUT) + response = self._http_client.get( + "/agent-configuration-schema", timeout=SHORT_REQUEST_TIMEOUT + ) schema = response.json() return schema @handle_response_parsing_errors def get_config(self) -> AgentConfiguration: - response = self.http_client.get("agent-configuration", timeout=SHORT_REQUEST_TIMEOUT) + response = self._http_client.get("/agent-configuration", timeout=SHORT_REQUEST_TIMEOUT) config_dict = response.json() logger.debug(f"Received configuration:\n{pformat(config_dict, sort_dicts=False)}") @@ -108,19 +132,19 @@ def get_config(self) -> AgentConfiguration: @handle_response_parsing_errors def get_credentials_for_propagation(self) -> Sequence[Credentials]: - response = self.http_client.get("propagation-credentials", timeout=SHORT_REQUEST_TIMEOUT) + response = self._http_client.get("/propagation-credentials", timeout=SHORT_REQUEST_TIMEOUT) return [Credentials(**credentials) for credentials in response.json()] def register_agent(self, agent_registration_data: AgentRegistrationData): - self.http_client.post( - "agents", + self._http_client.post( + "/agents", agent_registration_data.dict(simplify=True), SHORT_REQUEST_TIMEOUT, ) def send_events(self, events: Sequence[AbstractAgentEvent]): - self.http_client.post("agent-events", self._serialize_events(events)) + self._http_client.post("/agent-events", self._serialize_events(events)) def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable: serialized_events: List[JSONSerializable] = [] @@ -136,10 +160,10 @@ def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSeriali def send_heartbeat(self, agent_id: AgentID, timestamp: float): data = AgentHeartbeat(timestamp=timestamp).dict(simplify=True) - self.http_client.post(f"agent/{agent_id}/heartbeat", data) + self._http_client.post(f"/agent/{agent_id}/heartbeat", data) def send_log(self, agent_id: AgentID, log_contents: str): - self.http_client.put( - f"agent-logs/{agent_id}", + self._http_client.put( + f"/agent-logs/{agent_id}", log_contents, ) diff --git a/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py b/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py index 9570b9814ac..83ef0118ca2 100644 --- a/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py +++ b/monkey/infection_monkey/island_api_client/http_island_api_client_factory.py @@ -10,13 +10,11 @@ class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory): - def __init__( - self, - agent_event_serializer_registry: AgentEventSerializerRegistry, - ): + def __init__(self, agent_event_serializer_registry: AgentEventSerializerRegistry, otp: str): self._agent_event_serializer_registry = agent_event_serializer_registry + self._otp = otp def create_island_api_client(self) -> IIslandAPIClient: return ConfigurationValidatorDecorator( - HTTPIslandAPIClient(self._agent_event_serializer_registry, HTTPClient()) + HTTPIslandAPIClient(self._agent_event_serializer_registry, HTTPClient(), self._otp) ) diff --git a/monkey/infection_monkey/island_api_client/island_api_client_errors.py b/monkey/infection_monkey/island_api_client/island_api_client_errors.py index 9556d53800d..3c8b606a3f2 100644 --- a/monkey/infection_monkey/island_api_client/island_api_client_errors.py +++ b/monkey/infection_monkey/island_api_client/island_api_client_errors.py @@ -30,6 +30,14 @@ class IslandAPIRequestError(IslandAPIError): pass +class IslandAPIAuthenticationError(IslandAPIError): + """ + Raised when the authentication to the API failed + """ + + pass + + class IslandAPIRequestFailedError(IslandAPIError): """ Raised when the API request fails due to an error on the server diff --git a/monkey/infection_monkey/monkey.py b/monkey/infection_monkey/monkey.py index c9752469253..f474ae86964 100644 --- a/monkey/infection_monkey/monkey.py +++ b/monkey/infection_monkey/monkey.py @@ -114,6 +114,8 @@ def __init__(self, args, ipc_logger_queue: multiprocessing.Queue, log_path: Path self._manager = context.Manager() self._opts = self._get_arguments(args) + # TODO read the otp from an env variable + self._otp = "hard-coded-otp" self._ipc_logger_queue = ipc_logger_queue @@ -171,7 +173,8 @@ def _get_arguments(args): def _connect_to_island_api(self) -> Tuple[SocketAddress, IIslandAPIClient]: logger.debug(f"Trying to wake up with servers: {', '.join(map(str, self._opts.servers))}") server_clients = find_available_island_apis( - self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry) + self._opts.servers, + HTTPIslandAPIClientFactory(self._agent_event_serializer_registry, self._otp), ) server, island_api_client = self._select_server(server_clients) diff --git a/monkey/infection_monkey/network/relay/utils.py b/monkey/infection_monkey/network/relay/utils.py index 1435ae3dcb3..f0fa42376a3 100644 --- a/monkey/infection_monkey/network/relay/utils.py +++ b/monkey/infection_monkey/network/relay/utils.py @@ -12,6 +12,7 @@ IslandAPIError, IslandAPITimeoutError, ) +from infection_monkey.island_api_client.island_api_client_errors import IslandAPIAuthenticationError from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST from infection_monkey.utils.threading import ( ThreadSafeIterator, @@ -71,6 +72,8 @@ def _check_if_island_server( logger.error(f"Unable to connect to server/relay {server}: {err}") except IslandAPITimeoutError as err: logger.error(f"Timed out while connecting to server/relay {server}: {err}") + except IslandAPIAuthenticationError as err: + logger.error(f"Authentication to the {server} failed: {err}") except IslandAPIError as err: logger.error( f"Exception encountered when trying to connect to server/relay {server}: {err}" diff --git a/monkey/tests/data_for_tests/otp.py b/monkey/tests/data_for_tests/otp.py new file mode 100644 index 00000000000..3ce1289dc8d --- /dev/null +++ b/monkey/tests/data_for_tests/otp.py @@ -0,0 +1 @@ +OTP = "fake_otp" diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_client.py index 929e4478867..1c509b25fab 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_client.py @@ -15,15 +15,16 @@ IslandAPITimeoutError, ) from infection_monkey.island_api_client.http_client import RETRIES, HTTPClient +from infection_monkey.island_api_client.island_api_client_errors import IslandAPIAuthenticationError SERVER = SocketAddress(ip="1.1.1.1", port=9999) AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673") ISLAND_URI = f"https://{SERVER}/api?action=is-up" -LOG_ENDPOINT = f"agent-logs/{AGENT_ID}" -ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/{LOG_ENDPOINT}" -PROPAGATION_CREDENTIALS_ENDPOINT = "propagation-credentials" -ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/{PROPAGATION_CREDENTIALS_ENDPOINT}" +LOG_ENDPOINT = f"/agent-logs/{AGENT_ID}" +ISLAND_SEND_LOG_URI = f"https://{SERVER}/api{LOG_ENDPOINT}" +PROPAGATION_CREDENTIALS_ENDPOINT = "/propagation-credentials" +ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api{PROPAGATION_CREDENTIALS_ENDPOINT}" @pytest.fixture @@ -35,8 +36,8 @@ def request_mock_instance(): @pytest.fixture def connected_client(request_mock_instance): http_client = HTTPClient() + http_client.server_url = f"https://{SERVER}/api" request_mock_instance.get(ISLAND_URI) - http_client.connect(SERVER) return http_client @@ -56,10 +57,20 @@ def test_http_client__error_handling( connected_client.get(PROPAGATION_CREDENTIALS_ENDPOINT) +@pytest.mark.parametrize("server", ["http://1.1.1.1:5000", ""]) +def test_http_client__unsupported_protocol(server): + client = HTTPClient() + + with pytest.raises(ValueError): + client.server_url = server + + @pytest.mark.parametrize( "status_code, expected_error", [ - (401, IslandAPIRequestError), + (401, IslandAPIAuthenticationError), + (403, IslandAPIAuthenticationError), + (400, IslandAPIRequestError), (501, IslandAPIRequestFailedError), ], ) @@ -93,7 +104,7 @@ def test_http_client__unconnected(): def test_http_client__retries(monkeypatch): http_client = HTTPClient() # skip the connect method - http_client._api_url = f"https://{SERVER}/api" + http_client._server_url = f"https://{SERVER}/api" mock_send = MagicMock(side_effect=ConnectTimeoutError) # requests_mock can't be used for this, because it mocks higher level than we are testing monkeypatch.setattr("urllib3.connectionpool.HTTPSConnectionPool._validate_conn", mock_send) diff --git a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py index e06755b70f7..1274b50f169 100644 --- a/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py +++ b/monkey/tests/unit_tests/infection_monkey/island_api_client/test_http_island_api_client.py @@ -6,6 +6,7 @@ import pytest import requests from tests.common.example_agent_configuration import AGENT_CONFIGURATION +from tests.data_for_tests.otp import OTP from tests.data_for_tests.propagation_credentials import CREDENTIALS_DICTS from tests.unit_tests.common.agent_plugins.test_agent_plugin_manifest import ( FAKE_AGENT_MANIFEST_DICT, @@ -82,7 +83,7 @@ def agent_event_serializer_registry(): def build_api_client(http_client): - return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client) + return HTTPIslandAPIClient(agent_event_serializer_registry(), http_client, OTP) def _build_client_with_json_response(response): @@ -91,6 +92,44 @@ def _build_client_with_json_response(response): return build_api_client(client_stub) +def test_connect__connection_error(): + http_client_stub = MagicMock() + http_client_stub.get = MagicMock(side_effect=RuntimeError) + + api_client = build_api_client(http_client_stub) + + with pytest.raises(RuntimeError): + api_client.connect(SERVER) + assert api_client._http_client.server_url is None + + +def test_connect__authentication_error(): + http_client_stub = MagicMock() + http_client_stub.get = MagicMock() + http_client_stub.post = MagicMock(side_effect=RuntimeError) + api_client = build_api_client(http_client_stub) + with pytest.raises(RuntimeError): + api_client.connect(SERVER) + assert api_client._http_client.server_url is not None + + +def test_connect(): + fake_auth_token = "fake_auth_token" + http_client_stub = MagicMock() + http_client_stub.get = MagicMock() + http_client_stub.post = MagicMock() + http_client_stub.post.return_value.json.return_value = {"token": fake_auth_token} + api_client = build_api_client(http_client_stub) + + api_client.connect(SERVER) + + assert api_client._http_client.server_url is not None + assert ( + api_client._http_client.additional_headers[HTTPIslandAPIClient.TOKEN_HEADER_KEY] + == fake_auth_token + ) + + def test_island_api_client__get_agent_binary(): fake_binary = b"agent-binary" os = OperatingSystem.LINUX @@ -101,7 +140,7 @@ def test_island_api_client__get_agent_binary(): api_client = build_api_client(http_client_stub) assert api_client.get_agent_binary(os) == fake_binary - assert http_client_stub.get.called_with("agent-binaries/linux") + assert http_client_stub.get.called_with("/agent-binaries/linux") def test_island_api_client_send_events__serialization(): @@ -132,7 +171,7 @@ def test_island_api_client_send_events__serialization(): api_client.send_events(events=events_to_send) - assert client_spy.post.call_args[0] == ("agent-events", expected_json) + assert client_spy.post.call_args[0] == ("/agent-events", expected_json) def test_island_api_client_send_events__serialization_failed(): diff --git a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py index 6a7aba04d7e..e01871b2227 100644 --- a/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py +++ b/monkey/tests/unit_tests/infection_monkey/network/relay/test_utils.py @@ -1,5 +1,6 @@ import pytest import requests_mock +from tests.data_for_tests.otp import OTP from common.agent_event_serializers import AgentEventSerializerRegistry from common.types import SocketAddress @@ -21,7 +22,7 @@ @pytest.fixture def island_api_client_factory(): - return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry()) + return HTTPIslandAPIClientFactory(AgentEventSerializerRegistry(), OTP) @pytest.mark.parametrize( @@ -41,6 +42,7 @@ def test_find_available_island_apis( with requests_mock.Mocker() as mock: for server, response in server_response_pairs: mock.get(f"https://{server}/api?action=is-up", **response) + mock.post(f"https://{server}/api/agent-otp-login", json={"token": "fake-token"}) available_apis = find_available_island_apis(servers, island_api_client_factory) @@ -58,6 +60,7 @@ def test_find_available_island_apis__multiple_successes(island_api_client_factor with requests_mock.Mocker() as mock: mock.get(f"https://{SERVER_1}/api?action=is-up", exc=IslandAPIConnectionError) for server in available_servers: + mock.post(f"https://{server}/api/agent-otp-login", json={"token": "fake-token"}) mock.get(f"https://{server}/api?action=is-up", text="") available_apis = find_available_island_apis(servers, island_api_client_factory)