Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add http_proxy to client & Fix deviceflow #1611

Merged
merged 3 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 24 additions & 7 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,17 @@
Base authenticator for all authentication flows
"""

def __init__(self, endpoint: str, header_key: str, credentials: Credentials = None):
def __init__(

Check warning on line 51 in flytekit/clients/auth/authenticator.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/authenticator.py#L51

Added line #L51 was not covered by tests
self,
endpoint: str,
header_key: str,
credentials: Credentials = None,
http_proxy_url: typing.Optional[str] = None,
):
self._endpoint = endpoint
self._creds = credentials
self._header_key = header_key if header_key else "authorization"
self._http_proxy_url = http_proxy_url

def get_credentials(self) -> Credentials:
return self._creds
Expand Down Expand Up @@ -162,6 +169,7 @@
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
scopes: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
):
if not client_id or not client_secret:
raise ValueError("Client ID and Client SECRET both are required.")
Expand All @@ -171,7 +179,7 @@
self._scopes = scopes or cfg.scopes
self._client_id = client_id
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key)
super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url)

def refresh_credentials(self):
"""
Expand All @@ -187,7 +195,9 @@
# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header)
token, expires_in = token_client.get_token(
token_endpoint, scopes, authorization_header, http_proxy_url=self._http_proxy_url
)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)

Expand All @@ -207,6 +217,7 @@
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
http_proxy_url: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
Expand All @@ -219,21 +230,27 @@
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint)
endpoint=endpoint,
header_key=header_key or cfg.header_key,
credentials=KeyringStore.retrieve(endpoint),
http_proxy_url=http_proxy_url,
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope)
resp = token_client.get_device_code(
self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url
)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
OR copy paste the following URL: {resp.verification_uri_complete}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check for this field's existence, and show it if it's there? Even if it's not part of the standard, it's nice to have.

Copy link
Collaborator Author

@ByronHsu ByronHsu May 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think it should be there because uri_complete is not in standard. resp.verification_uri + resp.user_code is enough for login.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but okta supports it. and lots of people use okta. it's not in the standard but i feel it's okay to support a nice to have.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kumare3 wdyt ^^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just add back and merge?

"""
)
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id)
token, expires_in = token_client.poll_token_endpoint(
resp, self._token_endpoint, client_id=self._client_id, http_proxy_url=self._http_proxy_url
)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
Expand Down
25 changes: 16 additions & 9 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import requests

from flytekit import logger

Check warning on line 12 in flytekit/clients/auth/token_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/token_client.py#L12

Added line #L12 was not covered by tests
from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"
Expand All @@ -31,15 +32,13 @@
{'device_code': 'code',
'user_code': 'BNDJJFXL',
'verification_uri': 'url',
'verification_uri_complete': 'url',
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
'expires_in': 600,
'interval': 5}
"""

device_code: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_in: int
interval: int

Expand All @@ -49,7 +48,6 @@
device_code=j["device_code"],
user_code=j["user_code"],
verification_uri=j["verification_uri"],
verification_uri_complete=j["verification_uri_complete"],
expires_in=j["expires_in"],
interval=j["interval"],
)
Expand Down Expand Up @@ -77,6 +75,7 @@
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
http_proxy_url: typing.Optional[str] = None,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
Expand All @@ -99,7 +98,8 @@
if scopes is not None:
body["scope"] = ",".join(scopes)

response = requests.post(token_endpoint, data=body, headers=headers)
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies)
if not response.ok:
j = response.json()
if "error" in j:
Expand All @@ -118,19 +118,24 @@
client_id: str,
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
http_proxy_url: typing.Optional[str] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
separate device
"""
payload = {"client_id": client_id, "scope": scope, "audience": audience}
resp = requests.post(device_auth_endpoint, payload)
_scope = " ".join(scope) if scope is not None else ""
payload = {"client_id": client_id, "scope": _scope, "audience": audience}
proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None
resp = requests.post(device_auth_endpoint, payload, proxies=proxies)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]:
def poll_token_endpoint(

Check warning on line 136 in flytekit/clients/auth/token_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/token_client.py#L136

Added line #L136 was not covered by tests
resp: DeviceCodeResponse, token_endpoint: str, client_id: str, http_proxy_url: typing.Optional[str] = None
) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
Expand All @@ -141,13 +146,15 @@
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
http_proxy_url=http_proxy_url,
)
print("Authentication successful!")
return access_token, expires_in
except AuthenticationPending:
...
except Exception:
raise
except Exception as e:
logger.error("Authentication attempt failed: ", e)
raise e

Check warning on line 157 in flytekit/clients/auth/token_client.py

View check run for this annotation

Codecov / codecov/patch

flytekit/clients/auth/token_client.py#L155-L157

Added lines #L155 - L157 were not covered by tests
print("Authentication Pending...")
time.sleep(interval.total_seconds())
tick = tick + interval
Expand Down
5 changes: 4 additions & 1 deletion flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
client_secret=cfg.client_credentials_secret,
cfg_store=cfg_store,
scopes=cfg.scopes,
http_proxy_url=cfg.http_proxy_url,
)
elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND:
client_cfg = None
Expand All @@ -82,7 +83,9 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience)
return DeviceCodeAuthenticator(
endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience, http_proxy_url=cfg.http_proxy_url
)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
3 changes: 3 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@
:param scopes: List of scopes to request. This is only applicable to the client credentials flow
:param auth_mode: The OAuth mode to use. Defaults to pkce flow
:param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
:param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests
"""

endpoint: str = "localhost:30080"
Expand All @@ -390,6 +391,7 @@
auth_mode: AuthType = AuthType.STANDARD
audience: typing.Optional[str] = None
rpc_retries: int = 3
http_proxy_url: typing.Optional[str] = None

Check warning on line 394 in flytekit/configuration/__init__.py

View check run for this annotation

Codecov / codecov/patch

flytekit/configuration/__init__.py#L394

Added line #L394 was not covered by tests

@classmethod
def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None) -> PlatformConfig:
Expand Down Expand Up @@ -426,6 +428,7 @@
kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file))
kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file))
kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file))
kwargs = set_if_exists(kwargs, "http_proxy_url", _internal.Platform.HTTP_PROXY_URL.read(config_file))
return PlatformConfig(**kwargs)

@classmethod
Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
CA_CERT_FILE_PATH = ConfigEntry(
LegacyConfigEntry(SECTION, "ca_cert_file_path"), YamlConfigEntry("admin.caCertFilePath")
)
HTTP_PROXY_URL = ConfigEntry(LegacyConfigEntry(SECTION, "http_proxy_url"), YamlConfigEntry("admin.httpProxyURL"))

Check warning on line 114 in flytekit/configuration/internal.py

View check run for this annotation

Codecov / codecov/patch

flytekit/configuration/internal.py#L114

Added line #L114 was not covered by tests


class LocalSDK(object):
Expand Down
14 changes: 7 additions & 7 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def test_command_authenticator(mock_subprocess: MagicMock):
@patch("flytekit.clients.auth.token_client.requests")
def test_client_creds_authenticator(mock_requests):
authn = ClientCredentialsAuthenticator(
ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store
ENDPOINT,
client_id="client",
client_secret="secret",
cfg_store=static_cfg_store,
http_proxy_url="https://my-proxy:31111",
)

response = MagicMock()
Expand Down Expand Up @@ -103,13 +107,9 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(
ENDPOINT,
cfg_store,
audience="x",
)
authn = DeviceCodeAuthenticator(ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000")

device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0)
device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", 100)
authn.refresh_credentials()
assert authn._creds
Expand Down
21 changes: 9 additions & 12 deletions tests/flytekit/unit/clients/auth/test_token_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def test_get_token(mock_requests):
response.status_code = 200
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"])
access, expiration = get_token(
"https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000"
)
assert access == "abc"
assert expiration == 60

Expand All @@ -39,19 +41,18 @@ def test_get_device_code(mock_requests):
response.ok = False
mock_requests.post.return_value = response
with pytest.raises(AuthenticationError):
get_device_code("test.com", "test")
get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")

response.ok = True
response.json.return_value = {
"device_code": "code",
"user_code": "BNDJJFXL",
"verification_uri": "url",
"verification_uri_complete": "url",
"expires_in": 600,
"interval": 5,
}
mock_requests.post.return_value = response
c = get_device_code("test.com", "test")
c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000")
assert c
assert c.device_code == "code"

Expand All @@ -63,19 +64,15 @@ def test_poll_token_endpoint(mock_requests):
response.json.return_value = {"error": error_auth_pending}
mock_requests.post.return_value = response

r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1
)
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1)
with pytest.raises(AuthenticationError):
poll_token_endpoint(r, "test.com", "test")
poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")

response = MagicMock()
response.ok = True
response.json.return_value = {"access_token": "abc", "expires_in": 60}
mock_requests.post.return_value = response
r = DeviceCodeResponse(
device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0
)
t, e = poll_token_endpoint(r, "test.com", "test")
r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0)
t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000")
assert t
assert e