Skip to content

Commit

Permalink
feat: implement better Oauth algorithm graph and refresh token
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjermiah committed Feb 5, 2024
1 parent 289f846 commit 3d81d2b
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 101 deletions.
131 changes: 78 additions & 53 deletions src/nbiatoolkit/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,90 +98,115 @@ def __init__(
else:
self.base_url = base_url

self.access_token = None
self.api_headers = None
self._access_token = None
self.expiry_time = None
self.refresh_token = None
self.refresh_expiry = None
self.refresh_token = "" # Fix: Assign an empty string instead of None
self.scope = None

@property
def access_token(self) -> str | None:
# Check if access token is not set or it's expired
if not self._access_token or self.is_token_expired():
self.refresh_token_or_request_new()

def getToken(self) -> Union[dict, None]:
"""
Retrieves the access token from the API.
return self._access_token

Returns
-------
api_headers : dict
The authentication headers containing the access token.
def is_token_expired(self) -> bool:
# Check if the token expiration time is set and if it's expired
return self.expiry_time is not None and time.time() > self.expiry_time

Example Usage
-------------
>>> from nbiatoolkit import OAuth2
>>> oauth = OAuth2()
>>> api_headers = oauth.getToken()
def refresh_token_or_request_new(self) -> None:
if self.refresh_token != "":
self._refresh_access_token()
else:
self.request_new_access_token()

>>> requests.get(url=query_url, headers=api_headers)
"""
# Check if the access token is valid and not expired
if self.access_token is not None:
return None if self.access_token == None else self.access_token
def _refresh_access_token(self) -> None:
assert self.refresh_token != "", "Refresh token is not set"

# Prepare the request data
data: dict[str, str] = {
"refresh_token": self.refresh_token,
"client_id": self.client_id,
"grant_type": "refresh_token",
}

token_url: str = self.base_url + "oauth/token"

response = requests.post(token_url, data=data)

try:
response.raise_for_status()
except requests.exceptions.HTTPError as err:
raise err
else:
token_data = response.json()
self.set_token_data(token_data)



def request_new_access_token(self):
# Implement logic to request a new access token using client credentials
# Set the new access token and update the expiration time
# Example:
# new_access_token, expires_in = your_token_request_logic()
# self.access_token = new_access_token
# self.token_expiration_time = time.time() + expires_in

# # Prepare the request data
data: dict[str, str] = {
"username": self.username,
"password": self.password,
"client_id": self.client_id,
"grant_type": "password",
}

token_url: str = self.base_url + "oauth/token"

response : requests.models.Response
response = requests.post(token_url, data=data)

try:
response = requests.post(token_url, data=data)
response.raise_for_status() # Raise an HTTPError for bad responses
except requests.exceptions.RequestException as e:
self.access_token = None
raise requests.exceptions.RequestException(
f"Failed to get access token. Status code:\
{response.status_code}"
) from e
response.raise_for_status()
except requests.exceptions.HTTPError as err:
raise SystemExit(err)
else:
# Code to execute if there is no exception
token_data = response.json()
self.access_token = token_data.get("access_token")
self.set_token_data(token_data)

self.api_headers = {"Authorization": f"Bearer {self.access_token}"}

self.expiry_time = time.ctime(time.time() + token_data.get("expires_in"))
self.refresh_token = token_data.get("refresh_token")
self.refresh_expiry = token_data.get("refresh_expires_in")
self.scope = token_data.get("scope")

return self.api_headers
def set_token_data(self, token_data: dict):
self._access_token = token_data["access_token"]
self.expiry_time = time.time() + int(token_data.get("expires_in") or 0)
self.refresh_token: str = token_data["refresh_token"]
self.refresh_expiry = token_data.get("refresh_expires_in")
self.scope = token_data.get("scope")

@property
def token(self):
"""
Returns the access token.
def api_headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/json",}

Returns
-------
access_token : str or None
The access token retrieved from the API.
"""
return self.access_token

@property
def headers(self):
"""
Returns the API headers.
def token_expiration_time(self):
return self.expiry_time

@property
def refresh_expiration_time(self):
return self.refresh_expiry

@property
def token_scope(self):
return self.scope

def __repr__(self):
return f"OAuth2(username={self.username}, client_id={self.client_id})"

def __str__(self):
return f"OAuth2(username={self.username}, client_id={self.client_id})"

Returns
-------
api_headers : dict or None
The authentication headers containing the access token.
"""
return self.api_headers
6 changes: 4 additions & 2 deletions src/nbiatoolkit/nbia.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ def __init__(
self._oauth2_client = OAuth2(username=username, password=password)

try:
self._api_headers = self._oauth2_client.getToken()
self._api_headers = {
"Authorization": f"Bearer {self._oauth2_client.access_token}",
"Content-Type": "application/json",
}
except Exception as e:
self._log.error("Error retrieving access token: %s", e)
self._api_headers = None
raise e

self._base_url : NBIA_ENDPOINTS = NBIA_ENDPOINTS.BASE_URL

@property
Expand Down
103 changes: 57 additions & 46 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,62 @@
import time
import requests

@pytest.fixture(scope="session")
def oauth2():
oauth = OAuth2()
oauth.getToken()
return oauth

@pytest.fixture(scope="session")
def failed_oauth2():
oauth = OAuth2(username="bad_username", password="bad_password")
return oauth

def test_getToken(oauth2):
assert oauth2.access_token is not None
assert oauth2.token is not None

def test_expiry(oauth2):
# expiry should be in the form of :'Tue Jun 29 13:58:57 2077'
# and test for roughly 2 hours from now
print(oauth2.expiry_time)
assert oauth2.expiry_time <= time.ctime(time.time() + 7200)

def test_failed_oauth(failed_oauth2):
# should raise requests.exceptions.RequestException
with pytest.raises(requests.exceptions.RequestException):
failed_oauth2.getToken()
assert failed_oauth2.getToken() is None
assert failed_oauth2.access_token is None
assert failed_oauth2.token is None
assert failed_oauth2.api_headers is None
assert failed_oauth2.expiry_time is None
assert failed_oauth2.refresh_token is None
assert failed_oauth2.refresh_expiry is None
assert failed_oauth2.scope is None


def test_getToken_valid_token(oauth2):
# Test if the access token is valid and not expired
assert oauth2.getToken() == oauth2.access_token
assert oauth2.getToken() is not None
assert oauth2.access_token is not None
assert oauth2.token is not None
assert oauth2.headers is not None
assert oauth2.expiry_time is not None
assert oauth2.refresh_token is not None
assert oauth2.refresh_expiry is not None
assert oauth2.scope is not None
@pytest.fixture
def oauth() -> OAuth2:
return OAuth2()

def test_oauth2(oauth: OAuth2) -> None:
assert oauth.client_id == "NBIA"
assert oauth.username == "nbia_guest"
assert oauth.password == ""
assert oauth.access_token is not None
assert oauth.api_headers is not None
assert oauth.expiry_time is not None
assert oauth.refresh_token is not None
assert oauth.refresh_expiry is not None
assert oauth.scope is not None

def test_is_token_expired(oauth: OAuth2) -> None:
assert oauth.is_token_expired() == False
oauth.expiry_time = time.time() - 100
assert oauth.is_token_expired() == True

def test_refresh_token_or_request_new(oauth: OAuth2) -> None:
oauth.refresh_token_or_request_new()
assert oauth.access_token is not None
assert oauth.refresh_token is not None
assert oauth.refresh_expiry is not None
assert oauth.expiry_time is not None

def test_refresh_after_expiry(oauth: OAuth2) -> None:
access_token_before = oauth.access_token

oauth.expiry_time = time.time() - 100
oauth.refresh_token_or_request_new()
assert oauth.access_token is not None
assert oauth.access_token != access_token_before

assert oauth.refresh_token is not None
assert oauth.refresh_expiry is not None
assert oauth.expiry_time is not None
assert oauth.is_token_expired() == False

def test_failed_refresh(oauth: OAuth2) -> None:
oauth.refresh_token = ""
with pytest.raises(AssertionError):
oauth._refresh_access_token()

oauth.refresh_token = "invalid_refresh_token"

with pytest.raises(requests.exceptions.HTTPError):
oauth._refresh_access_token()

def test_request_new_access_token(oauth: OAuth2) -> None:
oauth.refresh_token = ""
oauth.request_new_access_token()
assert oauth.access_token is not None
assert oauth.refresh_token is not None
assert oauth.refresh_expiry is not None
assert oauth.expiry_time is not None
assert oauth.is_token_expired() == False

0 comments on commit 3d81d2b

Please sign in to comment.