Skip to content

Commit

Permalink
Merge branch '3077-agent-fetch-auth-token' into develop
Browse files Browse the repository at this point in the history
Issue #3077
PR #3011
  • Loading branch information
VakarisZ committed Mar 16, 2023
2 parents 3876fc2 + 7185a67 commit d3aa582
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 66 deletions.
62 changes: 31 additions & 31 deletions monkey/infection_monkey/island_api_client/http_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import functools
import logging
from enum import Enum, auto
from http import HTTPStatus
from typing import Any, Dict, Optional

import requests
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from common.common_consts.timeouts import MEDIUM_REQUEST_TIMEOUT
from common.types import JSONSerializable, SocketAddress
from common.types import JSONSerializable

from .island_api_client_errors import (
IslandAPIAuthenticationError,
IslandAPIConnectionError,
IslandAPIError,
IslandAPIRequestError,
Expand Down Expand Up @@ -40,12 +42,17 @@ def decorated(*args, **kwargs):
except (requests.exceptions.ConnectionError, requests.exceptions.TooManyRedirects) as err:
raise IslandAPIConnectionError(err)
except requests.exceptions.HTTPError as err:
if err.response.status_code in [
HTTPStatus.UNAUTHORIZED.value,
HTTPStatus.FORBIDDEN.value,
]:
raise IslandAPIAuthenticationError(err)
if 400 <= err.response.status_code < 500:
raise IslandAPIRequestError(err)
elif 500 <= err.response.status_code < 600:
if 500 <= err.response.status_code < 600:
raise IslandAPIRequestFailedError(err)
else:
raise IslandAPIError(err)

raise IslandAPIError(err)
except TimeoutError as err:
raise IslandAPITimeoutError(err)
except Exception as err:
Expand All @@ -59,27 +66,22 @@ def __init__(self, retries=RETRIES):
self._session = requests.Session()
retry_config = Retry(retries)
self._session.mount("https://", HTTPAdapter(max_retries=retry_config))
self._api_url: Optional[str] = None
self._server_url: Optional[str] = None
self.additional_headers: Optional[Dict[str, Any]] = None

@handle_island_errors
def connect(self, island_server: SocketAddress):
try:
self._api_url = f"https://{island_server}/api"
# Don't use retries here, because we expect to not be able to connect.
response = requests.get( # noqa: DUO123
f"{self._api_url}?action=is-up",
verify=False,
timeout=MEDIUM_REQUEST_TIMEOUT,
)
response.raise_for_status()
except Exception as err:
logger.debug(f"Connection to {island_server} failed: {err}")
self._api_url = None
raise err
@property
def server_url(self):
return self._server_url

@server_url.setter
def server_url(self, server_url: Optional[str]):
if server_url is not None and not server_url.startswith("https://"):
raise ValueError("Only HTTPS protocol is supported by HTTPClient")
self._server_url = server_url

def get(
self,
endpoint: str,
endpoint: str = "",
params: Optional[Dict[str, Any]] = None,
timeout=MEDIUM_REQUEST_TIMEOUT,
*args,
Expand All @@ -91,7 +93,7 @@ def get(

def post(
self,
endpoint: str,
endpoint: str = "",
data: Optional[JSONSerializable] = None,
timeout=MEDIUM_REQUEST_TIMEOUT,
*args,
Expand All @@ -103,7 +105,7 @@ def post(

def put(
self,
endpoint: str,
endpoint: str = "",
data: Optional[JSONSerializable] = None,
timeout=MEDIUM_REQUEST_TIMEOUT,
*args,
Expand All @@ -122,17 +124,15 @@ def _send_request(
*args,
**kwargs,
) -> requests.Response:
if self._api_url is None:
raise RuntimeError(
"HTTP client is not connected to the Island server,"
"establish a connection with 'connect()' before "
"attempting to send any requests"
)
url = f"{self._api_url}/{endpoint}".strip("/")
if self._server_url is None:
raise RuntimeError("HTTP client does not have a server URL set")
url = f"{self._server_url.strip('/')}/{endpoint.strip('/')}".strip("/")
logger.debug(f"{request_type.name} {url}, timeout={timeout}")

method = getattr(self._session, str.lower(request_type.name))
response = method(url, *args, timeout=timeout, verify=False, **kwargs)
response = method(
url, *args, timeout=timeout, verify=False, headers=self.additional_headers, **kwargs
)
response.raise_for_status()

return response
60 changes: 42 additions & 18 deletions monkey/infection_monkey/island_api_client/http_island_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,34 +44,52 @@ class HTTPIslandAPIClient(IIslandAPIClient):
A client for the Island's HTTP API
"""

TOKEN_HEADER_KEY = "Authentication-Token"

def __init__(
self, agent_event_serializer_registry: AgentEventSerializerRegistry, http_client: HTTPClient
self,
agent_event_serializer_registry: AgentEventSerializerRegistry,
http_client: HTTPClient,
otp: str,
):
self._agent_event_serializer_registry = agent_event_serializer_registry
self.http_client = http_client
self._http_client = http_client
self._otp = otp

def connect(
self,
island_server: SocketAddress,
):
self.http_client.connect(island_server)
try:
self._http_client.server_url = f"https://{island_server}/api/"
self._http_client.get(params={"action": "is-up"})
except Exception as err:
self._http_client.server_url = None
raise err

auth_token = self._get_authentication_token()
self._http_client.additional_headers = {HTTPIslandAPIClient.TOKEN_HEADER_KEY: auth_token}

def _get_authentication_token(self) -> str:
response = self._http_client.post("/agent-otp-login", {"otp": self._otp})
return response.json()["token"]

def get_agent_binary(self, operating_system: OperatingSystem) -> bytes:
os_name = operating_system.value
response = self.http_client.get(f"agent-binaries/{os_name}")
response = self._http_client.get(f"/agent-binaries/{os_name}")
return response.content

@handle_response_parsing_errors
def get_otp(self) -> str:
response = self.http_client.get("agent-otp")
response = self._http_client.get("/agent-otp")
return response.json()["otp"]

@handle_response_parsing_errors
def get_agent_plugin(
self, operating_system: OperatingSystem, plugin_type: AgentPluginType, plugin_name: str
) -> AgentPlugin:
response = self.http_client.get(
f"agent-plugins/{operating_system.value}/{plugin_type.value}/{plugin_name}"
response = self._http_client.get(
f"/agent-plugins/{operating_system.value}/{plugin_type.value}/{plugin_name}"
)

return AgentPlugin(**response.json())
Expand All @@ -80,26 +98,32 @@ def get_agent_plugin(
def get_agent_plugin_manifest(
self, plugin_type: AgentPluginType, plugin_name: str
) -> AgentPluginManifest:
response = self.http_client.get(f"agent-plugins/{plugin_type.value}/{plugin_name}/manifest")
response = self._http_client.get(
f"/agent-plugins/{plugin_type.value}/{plugin_name}/manifest"
)

return AgentPluginManifest(**response.json())

@handle_response_parsing_errors
def get_agent_signals(self, agent_id: str) -> AgentSignals:
response = self.http_client.get(f"agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT)
response = self._http_client.get(
f"/agent-signals/{agent_id}", timeout=SHORT_REQUEST_TIMEOUT
)

return AgentSignals(**response.json())

@handle_response_parsing_errors
def get_agent_configuration_schema(self) -> Dict[str, Any]:
response = self.http_client.get("agent-configuration-schema", timeout=SHORT_REQUEST_TIMEOUT)
response = self._http_client.get(
"/agent-configuration-schema", timeout=SHORT_REQUEST_TIMEOUT
)
schema = response.json()

return schema

@handle_response_parsing_errors
def get_config(self) -> AgentConfiguration:
response = self.http_client.get("agent-configuration", timeout=SHORT_REQUEST_TIMEOUT)
response = self._http_client.get("/agent-configuration", timeout=SHORT_REQUEST_TIMEOUT)

config_dict = response.json()
logger.debug(f"Received configuration:\n{pformat(config_dict, sort_dicts=False)}")
Expand All @@ -108,19 +132,19 @@ def get_config(self) -> AgentConfiguration:

@handle_response_parsing_errors
def get_credentials_for_propagation(self) -> Sequence[Credentials]:
response = self.http_client.get("propagation-credentials", timeout=SHORT_REQUEST_TIMEOUT)
response = self._http_client.get("/propagation-credentials", timeout=SHORT_REQUEST_TIMEOUT)

return [Credentials(**credentials) for credentials in response.json()]

def register_agent(self, agent_registration_data: AgentRegistrationData):
self.http_client.post(
"agents",
self._http_client.post(
"/agents",
agent_registration_data.dict(simplify=True),
SHORT_REQUEST_TIMEOUT,
)

def send_events(self, events: Sequence[AbstractAgentEvent]):
self.http_client.post("agent-events", self._serialize_events(events))
self._http_client.post("/agent-events", self._serialize_events(events))

def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSerializable:
serialized_events: List[JSONSerializable] = []
Expand All @@ -136,10 +160,10 @@ def _serialize_events(self, events: Sequence[AbstractAgentEvent]) -> JSONSeriali

def send_heartbeat(self, agent_id: AgentID, timestamp: float):
data = AgentHeartbeat(timestamp=timestamp).dict(simplify=True)
self.http_client.post(f"agent/{agent_id}/heartbeat", data)
self._http_client.post(f"/agent/{agent_id}/heartbeat", data)

def send_log(self, agent_id: AgentID, log_contents: str):
self.http_client.put(
f"agent-logs/{agent_id}",
self._http_client.put(
f"/agent-logs/{agent_id}",
log_contents,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@


class HTTPIslandAPIClientFactory(AbstractIslandAPIClientFactory):
def __init__(
self,
agent_event_serializer_registry: AgentEventSerializerRegistry,
):
def __init__(self, agent_event_serializer_registry: AgentEventSerializerRegistry, otp: str):
self._agent_event_serializer_registry = agent_event_serializer_registry
self._otp = otp

def create_island_api_client(self) -> IIslandAPIClient:
return ConfigurationValidatorDecorator(
HTTPIslandAPIClient(self._agent_event_serializer_registry, HTTPClient())
HTTPIslandAPIClient(self._agent_event_serializer_registry, HTTPClient(), self._otp)
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ class IslandAPIRequestError(IslandAPIError):
pass


class IslandAPIAuthenticationError(IslandAPIError):
"""
Raised when the authentication to the API failed
"""

pass


class IslandAPIRequestFailedError(IslandAPIError):
"""
Raised when the API request fails due to an error on the server
Expand Down
5 changes: 4 additions & 1 deletion monkey/infection_monkey/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __init__(self, args, ipc_logger_queue: multiprocessing.Queue, log_path: Path
self._manager = context.Manager()

self._opts = self._get_arguments(args)
# TODO read the otp from an env variable
self._otp = "hard-coded-otp"

self._ipc_logger_queue = ipc_logger_queue

Expand Down Expand Up @@ -171,7 +173,8 @@ def _get_arguments(args):
def _connect_to_island_api(self) -> Tuple[SocketAddress, IIslandAPIClient]:
logger.debug(f"Trying to wake up with servers: {', '.join(map(str, self._opts.servers))}")
server_clients = find_available_island_apis(
self._opts.servers, HTTPIslandAPIClientFactory(self._agent_event_serializer_registry)
self._opts.servers,
HTTPIslandAPIClientFactory(self._agent_event_serializer_registry, self._otp),
)

server, island_api_client = self._select_server(server_clients)
Expand Down
3 changes: 3 additions & 0 deletions monkey/infection_monkey/network/relay/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IslandAPIError,
IslandAPITimeoutError,
)
from infection_monkey.island_api_client.island_api_client_errors import IslandAPIAuthenticationError
from infection_monkey.network.relay import RELAY_CONTROL_MESSAGE_REMOVE_FROM_WAITLIST
from infection_monkey.utils.threading import (
ThreadSafeIterator,
Expand Down Expand Up @@ -71,6 +72,8 @@ def _check_if_island_server(
logger.error(f"Unable to connect to server/relay {server}: {err}")
except IslandAPITimeoutError as err:
logger.error(f"Timed out while connecting to server/relay {server}: {err}")
except IslandAPIAuthenticationError as err:
logger.error(f"Authentication to the {server} failed: {err}")
except IslandAPIError as err:
logger.error(
f"Exception encountered when trying to connect to server/relay {server}: {err}"
Expand Down
1 change: 1 addition & 0 deletions monkey/tests/data_for_tests/otp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
OTP = "fake_otp"
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
IslandAPITimeoutError,
)
from infection_monkey.island_api_client.http_client import RETRIES, HTTPClient
from infection_monkey.island_api_client.island_api_client_errors import IslandAPIAuthenticationError

SERVER = SocketAddress(ip="1.1.1.1", port=9999)
AGENT_ID = UUID("80988359-a1cd-42a2-9b47-5b94b37cd673")

ISLAND_URI = f"https://{SERVER}/api?action=is-up"
LOG_ENDPOINT = f"agent-logs/{AGENT_ID}"
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api/{LOG_ENDPOINT}"
PROPAGATION_CREDENTIALS_ENDPOINT = "propagation-credentials"
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api/{PROPAGATION_CREDENTIALS_ENDPOINT}"
LOG_ENDPOINT = f"/agent-logs/{AGENT_ID}"
ISLAND_SEND_LOG_URI = f"https://{SERVER}/api{LOG_ENDPOINT}"
PROPAGATION_CREDENTIALS_ENDPOINT = "/propagation-credentials"
ISLAND_GET_PROPAGATION_CREDENTIALS_URI = f"https://{SERVER}/api{PROPAGATION_CREDENTIALS_ENDPOINT}"


@pytest.fixture
Expand All @@ -35,8 +36,8 @@ def request_mock_instance():
@pytest.fixture
def connected_client(request_mock_instance):
http_client = HTTPClient()
http_client.server_url = f"https://{SERVER}/api"
request_mock_instance.get(ISLAND_URI)
http_client.connect(SERVER)
return http_client


Expand All @@ -56,10 +57,20 @@ def test_http_client__error_handling(
connected_client.get(PROPAGATION_CREDENTIALS_ENDPOINT)


@pytest.mark.parametrize("server", ["http://1.1.1.1:5000", ""])
def test_http_client__unsupported_protocol(server):
client = HTTPClient()

with pytest.raises(ValueError):
client.server_url = server


@pytest.mark.parametrize(
"status_code, expected_error",
[
(401, IslandAPIRequestError),
(401, IslandAPIAuthenticationError),
(403, IslandAPIAuthenticationError),
(400, IslandAPIRequestError),
(501, IslandAPIRequestFailedError),
],
)
Expand Down Expand Up @@ -93,7 +104,7 @@ def test_http_client__unconnected():
def test_http_client__retries(monkeypatch):
http_client = HTTPClient()
# skip the connect method
http_client._api_url = f"https://{SERVER}/api"
http_client._server_url = f"https://{SERVER}/api"
mock_send = MagicMock(side_effect=ConnectTimeoutError)
# requests_mock can't be used for this, because it mocks higher level than we are testing
monkeypatch.setattr("urllib3.connectionpool.HTTPSConnectionPool._validate_conn", mock_send)
Expand Down
Loading

0 comments on commit d3aa582

Please sign in to comment.