From a059aa7dcb58759b24ca1eabd10f2da0262c01fb Mon Sep 17 00:00:00 2001 From: Eoin Power-Moran Date: Thu, 22 Dec 2022 13:35:01 +0000 Subject: [PATCH] feat: support aiohttp requests for outbound relay --- evervault/__init__.py | 13 +++++++-- evervault/client.py | 3 ++ evervault/http/requestintercept.py | 44 ++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/evervault/__init__.py b/evervault/__init__.py index 7283255..5c58fce 100644 --- a/evervault/__init__.py +++ b/evervault/__init__.py @@ -2,6 +2,7 @@ from .client import Client from .errors.evervault_errors import AuthenticationError, UnsupportedCurveError import os +import sys from warnings import warn __version__ = "1.4.0" @@ -98,14 +99,22 @@ def create_run_token(function_name, data): return __client().create_run_token(function_name, data) -def enable_outbound_relay(decryption_domains=None, debug_requests=False): +def enable_outbound_relay(decryption_domains=None, debug_requests=False, client_session=None): + if client_session is not None : + _warn_if_python_version_unsupported_for_async() + if decryption_domains is None: - __client().enable_outbound_relay(debug_requests, enable_outbound_relay=True) + __client().enable_outbound_relay(debug_requests, enable_outbound_relay=True, client_session=client_session) else: __client().enable_outbound_relay( debug_requests, decryption_domains=decryption_domains ) +def _warn_if_python_version_unsupported_for_async() : + if sys.version_info.minor < 11 : + warn( + "Using Outbound Relay with Asynchronous Python is only supported in Python >= 3.11" + ) def __client(): if not _api_key: diff --git a/evervault/client.py b/evervault/client.py index a229014..4e20b6c 100644 --- a/evervault/client.py +++ b/evervault/client.py @@ -61,6 +61,7 @@ def enable_outbound_relay( ignore_domains=[], decryption_domains=[], enable_outbound_relay=False, + client_session=None, ): if len(decryption_domains) > 0: self.cert.setup_decryption_domains(decryption_domains, debug_requests) @@ -69,6 +70,8 @@ def enable_outbound_relay( else: self.cert.setup_ignore_domains(ignore_domains, debug_requests) self.cert.setup() + if client_session : + self.cert.setup_aiohttp(client_session) def create_run_token(self, cage_name, data): return self.post(f"v2/functions/{cage_name}/run-token", data, {}) diff --git a/evervault/http/requestintercept.py b/evervault/http/requestintercept.py index 9d64798..6c9bf76 100644 --- a/evervault/http/requestintercept.py +++ b/evervault/http/requestintercept.py @@ -5,6 +5,7 @@ import warnings import certifi import tempfile +import ssl from evervault.errors.evervault_errors import CertDownloadError from evervault.http.outboundrelayconfig import RelayOutboundConfig @@ -22,6 +23,12 @@ def is_ignore_domain(domain, decryption_domains, always_ignore_domains): for decryption_domain in decryption_domains ) +def hostname_from_str_or_url(str_or_url) : + if hasattr(str_or_url, 'host') : + return str_or_url.host + + return urlparse(str_or_url).netloc + class RequestIntercept(object): def __init__( @@ -93,6 +100,43 @@ def set_relay_outbound_config(self, debug_requests=False): host, RelayOutboundConfig.get_decryption_domains(), always_ignore_domains ) + def setup_aiohttp(self, client_session): + self.__get_cert() + default_ssl_context = ssl.create_default_context(cafile=certifi.where()) + evervault_ssl_context = ssl.create_default_context(cafile=self.cert_path) + api_key = self.api_key + relay_url = self.relay_url + + old_request = client_session._request + + def new_req_func(method, str_or_url, **kwargs) : + domain = hostname_from_str_or_url(str_or_url) + print(domain) + should_proxy = self.should_proxy_domain(domain) + + if not 'headers' in kwargs : + kwargs['headers'] = {} + + if self.debug_requests and not any( + map( + lambda evervault_domain: domain.endswith(evervault_domain), + EVERVAULT_DOMAINS, + ) + ): + print( + f"Request to domain: {domain}, Outbound Proxy enabled: {should_proxy}" + ) + if should_proxy : + kwargs['proxy'] = relay_url + kwargs['headers']['Proxy-Authorization'] = api_key + kwargs['ssl'] = evervault_ssl_context + else : + kwargs['ssl'] = default_ssl_context + + return old_request(method, str_or_url, **kwargs) + + client_session._request = new_req_func + def setup(client_self): client_self.__get_cert()