Skip to content

Commit

Permalink
[Auth] Authenticate with OAuth client-credentials grant (#5456)
Browse files Browse the repository at this point in the history
  • Loading branch information
theSaarco committed May 1, 2024
1 parent 8a13381 commit ea4cdfd
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 19 deletions.
4 changes: 4 additions & 0 deletions mlrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,10 @@
# supported modes: "enabled", "disabled".
"mode": "disabled"
},
"auth_with_client_id": {
"enabled": False,
"request_timeout": 5,
},
}

_is_running_as_api = None
Expand Down
152 changes: 152 additions & 0 deletions mlrun/db/auth_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from datetime import datetime, timedelta

import requests

import mlrun.errors
from mlrun.utils import logger


class TokenProvider(ABC):
    """Interface for objects that supply an authentication token.

    Concrete providers decide how the token is obtained (fixed value, fetched
    from a remote endpoint, ...) and whether it should be treated as an Iguazio
    session.
    """

    @abstractmethod
    def get_token(self):
        """Return the token to authenticate with."""

    @abstractmethod
    def is_iguazio_session(self):
        """Report whether the provided token is an Iguazio session."""


class StaticTokenProvider(TokenProvider):
    """Token provider that always serves the single token it was created with."""

    def __init__(self, token: str):
        """Store ``token`` so it can be handed out on every request."""
        self.token = token

    def get_token(self):
        """Return the token given at construction time, unchanged."""
        stored_token = self.token
        return stored_token

    def is_iguazio_session(self):
        """Return what ``mlrun.platforms.iguazio.is_iguazio_session`` reports for the stored token."""
        session_check = mlrun.platforms.iguazio.is_iguazio_session
        return session_check(self.token)


class OAuthClientIDTokenProvider(TokenProvider):
    """Token provider implementing the OAuth client-credentials grant.

    A token is requested from ``token_endpoint`` when the provider is created and
    is requested again once half of its lifetime has passed. When a refresh
    attempt fails, the current token keeps being served until it actually expires;
    after that ``get_token`` returns ``None`` until a request succeeds.
    """

    def __init__(
        self, token_endpoint: str, client_id: str, client_secret: str, timeout=5
    ):
        """Validate the configuration and try to retrieve a first token.

        :param token_endpoint: URL the token request is POSTed to.
        :param client_id:      value sent as ``client_id`` in the request body.
        :param client_secret:  value sent as ``client_secret`` in the request body.
        :param timeout:        timeout passed to every token request.

        :raises mlrun.errors.MLRunValueError: if any of the endpoint, client-id or
            secret is missing/empty.
        """
        if not token_endpoint or not client_id or not client_secret:
            raise mlrun.errors.MLRunValueError(
                "Invalid client_id configuration for authentication. Must provide token endpoint, client-id and secret"
            )
        self.token_endpoint = token_endpoint
        self.client_id = client_id
        self.client_secret = client_secret
        self.timeout = timeout

        # The only request issued through this session is the token POST, which behaves like a GET
        # (a disguised read), so it is ok to allow retries on it.
        self._session = mlrun.utils.HTTPSessionWithRetry(
            retry_on_post=True,
            verbose=True,
        )

        self._cleanup()
        self._refresh_token_if_needed()

    def get_token(self):
        """Return the current token (refreshing it first when due), or ``None`` if none is available."""
        self._refresh_token_if_needed()
        return self.token

    def is_iguazio_session(self):
        """Tokens produced by this provider are never treated as Iguazio sessions."""
        return False

    def _cleanup(self):
        """Forget the token together with its expiry and refresh times."""
        self.token = self.token_expiry_time = self.token_refresh_time = None

    def _refresh_token_if_needed(self):
        """Request a new token unless the current one has not reached its refresh time yet.

        :return: the token to use after the (possible) refresh, or ``None``.
        """
        now = datetime.now()
        if self.token:
            if self.token_refresh_time and now <= self.token_refresh_time:
                return self.token

            # We only cleanup if token was really expired - even if we fail in refreshing the token, we can still
            # use the existing one given that it's not expired. A token whose expiry time is unknown cannot be
            # reasoned about (and comparing against None would raise), so it is dropped as well.
            if self.token_expiry_time is None or now >= self.token_expiry_time:
                self._cleanup()

        self._issue_token_request()
        return self.token

    def _issue_token_request(self, raise_on_error=False):
        """POST a client-credentials token request and store the result on success.

        :param raise_on_error: when True, failures are raised; otherwise they are
            logged as warnings and the current state is left as is.

        :raises mlrun.errors.MLRunRuntimeError: (only with ``raise_on_error``) if the
            request could not be performed or a successful response was not JSON.
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        request_body = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        try:
            response = self._session.request(
                "POST",
                self.token_endpoint,
                timeout=self.timeout,
                headers=headers,
                data=request_body,
            )
        except requests.RequestException as exc:
            error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
            if raise_on_error:
                raise mlrun.errors.MLRunRuntimeError(error) from exc
            logger.warning(error)
            return

        if not response.ok:
            error = "No error available"
            if response.content:
                try:
                    error = response.json().get("error") or error
                except (ValueError, AttributeError):
                    # Body is not a JSON object - extracting the error is best-effort only.
                    pass
            logger.warning(
                "Retrieving token failed", status=response.status_code, error=error
            )
            if raise_on_error:
                mlrun.errors.raise_for_status(response)
            return

        try:
            data = response.json()
        except ValueError as exc:
            # A successful status with a non-JSON body is handled like any other retrieval failure.
            error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
            if raise_on_error:
                raise mlrun.errors.MLRunRuntimeError(error) from exc
            logger.warning(error)
            return

        self._parse_response(data)

    def _parse_response(self, data: dict):
        """Extract the token and its lifetime from a token response and store them.

        The stored state is only modified when the response is fully valid, so a
        malformed response never discards a token that is still usable.

        :param data: decoded JSON body of the token response.
        """
        # Response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.3
        # According to spec, there isn't a refresh token - just the access token and its expiry time (in seconds).
        is_mapping = isinstance(data, dict)
        token = data.get("access_token") if is_mapping else None
        expires_in = data.get("expires_in") if is_mapping else None
        lifetime = self._to_positive_seconds(expires_in)
        if not token or lifetime is None:
            token_str = "****" if token else "missing"
            logger.warning(
                "Failed to parse token response", token=token_str, expires_in=expires_in
            )
            return

        now = datetime.now()
        self.token = token
        self.token_expiry_time = now + timedelta(seconds=lifetime)
        # Refresh at half of the token's lifetime, leaving the other half for retries.
        self.token_refresh_time = now + timedelta(seconds=lifetime / 2)
        logger.info(
            "Successfully retrieved client-id token",
            expires_in=expires_in,
            expiry=str(self.token_expiry_time),
            refresh=str(self.token_refresh_time),
        )

    @staticmethod
    def _to_positive_seconds(value):
        """Return ``value`` as a positive, finite number of seconds, or ``None`` if it isn't one."""
        try:
            seconds = float(value)
        except (TypeError, ValueError):
            return None
        return seconds if 0 < seconds < float("inf") else None
49 changes: 31 additions & 18 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import mlrun.projects
import mlrun.runtimes.nuclio.api_gateway
import mlrun.utils
from mlrun.db.auth_utils import OAuthClientIDTokenProvider, StaticTokenProvider
from mlrun.errors import MLRunInvalidArgumentError, err_to_str

from ..artifacts import Artifact
Expand Down Expand Up @@ -138,17 +139,27 @@ def _enrich_and_validate(self, url):
endpoint += f":{parsed_url.port}"
base_url = f"{parsed_url.scheme}://{endpoint}{parsed_url.path}"

self.base_url = base_url
username = parsed_url.username or config.httpdb.user
password = parsed_url.password or config.httpdb.password

username, password, token = mlrun.platforms.add_or_refresh_credentials(
parsed_url.hostname, username, password, config.httpdb.token
)

self.base_url = base_url
self.user = username
self.password = password
self.token = token
self.token_provider = None

if config.auth_with_client_id.enabled:
self.token_provider = OAuthClientIDTokenProvider(
token_endpoint=mlrun.get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
client_id=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
client_secret=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
timeout=config.auth_with_client_id.request_timeout,
)
else:
username, password, token = mlrun.platforms.add_or_refresh_credentials(
parsed_url.hostname, username, password, config.httpdb.token
)

if token:
self.token_provider = StaticTokenProvider(token)

def __repr__(self):
cls = self.__class__.__name__
Expand Down Expand Up @@ -218,17 +229,19 @@ def api_call(

if self.user:
kw["auth"] = (self.user, self.password)
elif self.token:
# Iguazio auth doesn't support passing token through bearer, so use cookie instead
if mlrun.platforms.iguazio.is_iguazio_session(self.token):
session_cookie = f'j:{{"sid": "{self.token}"}}'
cookies = {
"session": session_cookie,
}
kw["cookies"] = cookies
else:
if "Authorization" not in kw.setdefault("headers", {}):
kw["headers"].update({"Authorization": "Bearer " + self.token})
elif self.token_provider:
token = self.token_provider.get_token()
if token:
# Iguazio auth doesn't support passing token through bearer, so use cookie instead
if self.token_provider.is_iguazio_session():
session_cookie = f'j:{{"sid": "{token}"}}'
cookies = {
"session": session_cookie,
}
kw["cookies"] = cookies
else:
if "Authorization" not in kw.setdefault("headers", {}):
kw["headers"].update({"Authorization": "Bearer " + token})

if mlrun.common.schemas.HeaderNames.client_version not in kw.setdefault(
"headers", {}
Expand Down
83 changes: 82 additions & 1 deletion tests/rundb/test_httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@

import deepdiff
import pytest
import requests_mock as requests_mock_package

import mlrun.artifacts.base
import mlrun.common.schemas
import mlrun.errors
import mlrun.projects.project
from mlrun import RunObject
from mlrun.db.auth_utils import StaticTokenProvider
from mlrun.db.httpdb import HTTPRunDB
from tests.conftest import tests_root_directory, wait_for_server

Expand Down Expand Up @@ -323,10 +325,89 @@ def test_bearer_auth(create_server):
with pytest.raises(mlrun.errors.MLRunUnauthorizedError):
db.list_runs()

db.token = token
db.token_provider = StaticTokenProvider(token)
db.list_runs()


def test_client_id_auth(requests_mock: requests_mock_package.Mocker, monkeypatch):
    """
    Test the httpdb behavior when using a client-id OAuth token. Test verifies that:
    - Token is retrieved successfully, and kept by the token provider of the httpdb object.
    - Token is added as Bearer token when issuing API calls to BE.
    - Token is refreshed when its expiry time is nearing.
    - Some error flows when token cannot be retrieved - such as that token is still used while it hasn't expired.

    Note: the test relies on real elapsed time (``time.sleep``) relative to the ``expires_in`` values returned by the
    mocked token endpoint, so both the sleep durations and the order of the mock registrations/resets matter.
    """

    token_url = "https://mock/token_endpoint/protocol/openid-connect/token"
    test_env = {
        "MLRUN_AUTH_TOKEN_ENDPOINT": token_url,
        "MLRUN_AUTH_CLIENT_ID": "some-client-id",
        "MLRUN_AUTH_CLIENT_SECRET": "some-client-secret",
    }

    # Enable client-id authentication and supply its configuration through the environment.
    mlrun.mlconf.auth_with_client_id.enabled = True
    for key, value in test_env.items():
        monkeypatch.setenv(key, value)

    expected_token = "my-cool-token"
    # Set a 4-second expiry, so a refresh will happen in 2 seconds
    requests_mock.post(
        token_url, json={"access_token": expected_token, "expires_in": 4}
    )

    db_url = "http://mock-server:1919"
    db = HTTPRunDB(db_url)
    db.connect()
    token = db.token_provider.get_token()
    assert token == expected_token
    # Only the initial token request should have been issued so far.
    assert len(requests_mock.request_history) == 1

    # Roughly 1 second after the token was issued - still before its refresh time.
    time.sleep(1)
    token = db.token_provider.get_token()
    assert token == expected_token
    # verify no additional calls were made (too early)
    assert len(requests_mock.request_history) == 1

    # Roughly 2.5 seconds after the token was issued - past the refresh time, so the next get_token()
    # should fetch the new token registered below (3-second expiry, hence refresh after 1.5 seconds).
    time.sleep(1.5)
    expected_token = "my-other-cool-token"
    requests_mock.post(
        token_url, json={"access_token": expected_token, "expires_in": 3}
    )
    token = db.token_provider.get_token()
    assert token == expected_token

    # Check that httpdb attaches the token to API calls as Authorization header.
    # Using trigger-migrations since it needs no payload and returns nothing, so easy to simulate.
    requests_mock.post(f"{db_url}/api/v1/operations/migrations", status_code=200)
    db.trigger_migrations()

    expected_auth = f"Bearer {expected_token}"
    last_request = requests_mock.last_request
    assert last_request.headers["Authorization"] == expected_auth

    # Check flow where we fail token retrieval while token is still active (not expired).
    requests_mock.reset_mock()
    requests_mock.post(token_url, status_code=401)

    # Roughly 2 seconds after the second token was issued - past its refresh time but before its expiry.
    time.sleep(2)
    db.trigger_migrations()

    request_history = requests_mock.request_history
    # We expect 2 calls - one for the token (which failed but didn't fail the flow) and one for the actual api call.
    assert len(request_history) == 2
    # The token should still be the previous token, since it was not refreshed but it's not expired yet.
    assert request_history[-1].headers["Authorization"] == expected_auth

    # Now let the token expire, and verify commands still go out, only without auth
    time.sleep(2)
    requests_mock.reset_mock()

    db.trigger_migrations()
    # Again 2 calls: the (still failing) token request and the api call itself.
    assert len(requests_mock.request_history) == 2
    assert "Authorization" not in requests_mock.last_request.headers
    assert db.token_provider.token is None


def _generate_runtime(name) -> mlrun.runtimes.KubejobRuntime:
runtime = mlrun.runtimes.KubejobRuntime()
runtime.metadata.name = name
Expand Down

0 comments on commit ea4cdfd

Please sign in to comment.