diff --git a/.talismanrc b/.talismanrc index 7e6f757..b443cb3 100644 --- a/.talismanrc +++ b/.talismanrc @@ -393,4 +393,14 @@ fileignoreconfig: checksum: 344aa9e4b3ec399c581a507eceff63ff6faf56ad938475e5f4865f6cb590df68 - filename: tests/unit/contentstack/test_contentstack.py checksum: 98503cbd96cb546a19aed037a6ca28ef54fcea312efcd9bac1171e43760f6e86 +- filename: contentstack_management/contentstack.py + checksum: 591978d70ecbe5fc3e6587544e9c112a6cd85fd8da2051b48ff87ab6a2e9eb57 +- filename: tests/unit/test_oauth_handler.py + checksum: 8b6853ba64c3de4f9097ca506719c5e33c7468ae5985b8adcda3eb6461d76be5 +- filename: contentstack_management/oauth/oauth_handler.py + checksum: e33cfd32d90c0553c4959c0d266fef1247cd0e0fe7bbe85cae98bb205e62c70e +- filename: tests/unit/user_session/test_user_session_totp.py + checksum: 0db30c5a306783b10d345d73cff3c61490d7cbc47273623df47e6849c3e97002 +- filename: tests/unit/contentstack/test_totp_login.py + checksum: cefad0ddc1a2db1bf59d6e04501c4381acc8b44fad1e5e2e24c06e33d827c859 version: "1.0" diff --git a/CHANGELOG.md b/CHANGELOG.md index c8c7614..c7fcd18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ # CHANGELOG ## Content Management SDK For Python + +--- +## v1.7.0 + +#### Date: 15 September 2025 + +- OAuth 2.0 support. --- ## v1.6.0 diff --git a/contentstack_management/__init__.py b/contentstack_management/__init__.py index fdf9348..ebc2aaf 100644 --- a/contentstack_management/__init__.py +++ b/contentstack_management/__init__.py @@ -34,6 +34,8 @@ from .extensions.extension import Extension from .variant_group.variant_group import VariantGroup from .variants.variants import Variants +from .oauth.oauth_handler import OAuthHandler +from .oauth.oauth_interceptor import OAuthInterceptor __all__ = ( @@ -71,14 +73,16 @@ "PublishQueue", "Extension", "VariantGroup", -"Variants" +"Variants", +"OAuthHandler", +"OAuthInterceptor" ) __title__ = 'contentstack-management-python' __author__ = 'dev-ex' __status__ = 'debug' __region__ = 'na' -__version__ = '1.6.0' +__version__ = '1.7.0' __host__ = 'api.contentstack.io' __protocol__ = 'https://' __api_version__ = 'v3' diff --git a/contentstack_management/_api_client.py b/contentstack_management/_api_client.py index a6f7275..5929689 100644 --- a/contentstack_management/_api_client.py +++ b/contentstack_management/_api_client.py @@ -2,7 +2,7 @@ class _APIClient: - def __init__(self, endpoint, headers, timeout=30, max_retries: int = 5): + def __init__(self, endpoint, headers, timeout=30, max_retries: int = 5, oauth_interceptor=None): """ The function is a constructor that initializes the endpoint, headers, timeout, and max_retries attributes of an object. @@ -25,6 +25,8 @@ def __init__(self, endpoint, headers, timeout=30, max_retries: int = 5): self.headers = headers self.timeout = timeout self.max_retries = max_retries + self.oauth_interceptor = oauth_interceptor + self.oauth = {} # OAuth token storage pass def _call_request(self, method, url, headers: dict = None, params=None, data=None, json_data=None, files=None): @@ -52,9 +54,17 @@ def _call_request(self, method, url, headers: dict = None, params=None, data=Non :return: the JSON response from the HTTP request. """ - # headers.update(self.headers) + if self.oauth_interceptor and self.oauth_interceptor.is_oauth_configured(): + return self.oauth_interceptor.execute_request( + method, url, headers=headers, params=params, data=data, + json=json_data, files=files, timeout=self.timeout + ) + + if headers is None: + headers = {} + headers.update(self.headers) # Merge client headers (including authtoken) with request headers response = requests.request( - method, url, headers=headers, params=params, data=data, json=json_data, files=files) + method, url, headers=headers, params=params, data=data, json=json_data, files=files, timeout=self.timeout) # response.raise_for_status() return response diff --git a/contentstack_management/contentstack.py b/contentstack_management/contentstack.py index 74526a6..c9e9cb8 100644 --- a/contentstack_management/contentstack.py +++ b/contentstack_management/contentstack.py @@ -1,9 +1,12 @@ from enum import Enum +import os +import pyotp from ._api_client import _APIClient from contentstack_management.organizations import organization from contentstack_management.stack import stack from contentstack_management.user_session import user_session from contentstack_management.users import user +from contentstack_management.oauth.oauth_handler import OAuthHandler version = '0.0.1' @@ -33,14 +36,16 @@ class Client: def __init__(self, host: str = 'api.contentstack.io', scheme: str = 'https://', authtoken: str = None , management_token=None, headers: dict = None, region: Region = Region.US.value, version='v3', timeout=2, max_retries: int = 18, early_access: list = None, - **kwargs): + oauth_config: dict = None, **kwargs): self.endpoint = 'https://api.contentstack.io/v3/' - if region is not None and host is not None and region is not Region.US.value: - self.endpoint = f'{scheme}{region}-{host}/{version}/' - if region is not None and host is None and region is not Region.US.value: - host = 'api.contentstack.com' - self.endpoint = f'{scheme}{region}-{host}/{version}/' - if host is not None and region is None: + + if region is not None and region is not Region.US.value: + if host is not None and host != 'api.contentstack.io': + self.endpoint = f'{scheme}{region}-api.{host}/{version}/' + else: + host = 'api.contentstack.com' + self.endpoint = f'{scheme}{region}-{host}/{version}/' + elif host is not None and host != 'api.contentstack.io': self.endpoint = f'{scheme}{host}/{version}/' if headers is None: headers = {} @@ -55,6 +60,19 @@ def __init__(self, host: str = 'api.contentstack.io', scheme: str = 'https://', headers['authorization'] = management_token headers = user_agents(headers) self.client = _APIClient(endpoint=self.endpoint, headers=headers, timeout=timeout, max_retries=max_retries) + + # Initialize OAuth if configuration is provided + self.oauth_handler = None + if oauth_config: + self.oauth_handler = OAuthHandler( + app_id=oauth_config.get('app_id'), + client_id=oauth_config.get('client_id'), + redirect_uri=oauth_config.get('redirect_uri'), + response_type=oauth_config.get('response_type', 'code'), + client_secret=oauth_config.get('client_secret'), + scope=oauth_config.get('scope'), + api_client=self.client + ) """ :param host: Optional hostname for the API endpoint. @@ -77,9 +95,36 @@ def __init__(self, host: str = 'api.contentstack.io', scheme: str = 'https://', ------------------------------- """ - def login(self, email: str, password: str, tfa_token: str = None): - return user_session.UserSession(self.client).login(email, password, tfa_token) - pass + def login(self, email: str, password: str, tfa_token: str = None, mfa_secret: str = None): + """ + Login to Contentstack with optional TOTP support. + + :param email: User's email address + :param password: User's password + :param tfa_token: Optional two-factor authentication token + :param mfa_secret: Optional MFA secret for automatic TOTP generation. + If not provided, will check MFA_SECRET environment variable + :return: Response object from the login request + """ + final_tfa_token = tfa_token + + if not mfa_secret: + mfa_secret = os.getenv('MFA_SECRET') + + if mfa_secret and not tfa_token: + final_tfa_token = self._generate_totp(mfa_secret) + + return user_session.UserSession(self.client).login(email, password, final_tfa_token) + + def _generate_totp(self, secret: str) -> str: + """ + Generate a Time-Based One-Time Password (TOTP) from the provided secret. + + :param secret: The MFA secret key for TOTP generation + :return: The current TOTP code as a string + """ + totp = pyotp.TOTP(secret) + return totp.now() def logout(self): return user_session.UserSession(client=self.client).logout() @@ -96,3 +141,41 @@ def organizations(self, organization_uid: str = None): def stack(self, api_key: str = None): return stack.Stack(self.client, api_key) + + def oauth(self, app_id: str, client_id: str, redirect_uri: str, + response_type: str = "code", client_secret: str = None, + scope: list = None): + """ + Create an OAuth handler for OAuth 2.0 authentication. + + Args: + app_id: Your registered App ID + client_id: Your OAuth Client ID + redirect_uri: The URL where the user is redirected after login and consent + response_type: OAuth response type (default: "code") + client_secret: Client secret for standard OAuth flows (optional for PKCE) + scope: Permissions requested (optional) + + Returns: + OAuthHandler instance + + Example: + >>> import contentstack_management + >>> client = contentstack_management.Client() + >>> oauth_handler = client.oauth( + ... app_id='your-app-id', + ... client_id='your-client-id', + ... redirect_uri='http://localhost:3000/callback' + ... ) + >>> auth_url = oauth_handler.authorize() + >>> print(f"Visit this URL to authorize: {auth_url}") + """ + return OAuthHandler( + app_id=app_id, + client_id=client_id, + redirect_uri=redirect_uri, + response_type=response_type, + client_secret=client_secret, + scope=scope, + api_client=self.client + ) diff --git a/contentstack_management/oauth/__init__.py b/contentstack_management/oauth/__init__.py new file mode 100644 index 0000000..c10b906 --- /dev/null +++ b/contentstack_management/oauth/__init__.py @@ -0,0 +1,5 @@ +"""OAuth 2.0 authentication module for Contentstack Management SDK.""" + +from .oauth_handler import OAuthHandler + +__all__ = ["OAuthHandler"] diff --git a/contentstack_management/oauth/oauth_handler.py b/contentstack_management/oauth/oauth_handler.py new file mode 100644 index 0000000..4eaebfa --- /dev/null +++ b/contentstack_management/oauth/oauth_handler.py @@ -0,0 +1,409 @@ +""" +OAuth 2.0 Handler for Contentstack Management SDK. + +This module provides OAuth 2.0 authentication support including: +- Authorization code flow +- PKCE (Proof Key for Code Exchange) +- Token management and refresh +- Secure token storage +""" + +import hashlib +import secrets +import time +import urllib.parse +from typing import Dict, List, Optional, Union +from urllib.parse import urlparse, parse_qs + +import requests + + +class OAuthHandler: + """ + OAuth 2.0 Handler for Contentstack Management SDK. + + This class manages OAuth 2.0 authentication flow including authorization, + token exchange, refresh, and secure storage. + """ + + # Error messages + ACCESS_TOKEN_EXPIRED_MSG = "🔄 Access token expired, refreshing..." + NO_ACCESS_TOKEN_MSG = "No access token available. Please authenticate first." + NO_OAUTH_TOKENS_MSG = "No OAuth tokens available" + NO_REFRESH_TOKEN_MSG = "No refresh token available" + OAUTH_NOT_CONFIGURED_MSG = "OAuth is not configured. Please set up OAuth first." + INVALID_AUTHORIZATION_CODE_MSG = "Authorization code not found in redirect URL" + TOKEN_EXCHANGE_FAILED_MSG = "Token exchange failed" + TOKEN_REFRESH_FAILED_MSG = "Token refresh failed" + OAUTH_BASE_URL_NOT_SET_MSG = "OAuthBaseURL is not set" + OAUTH_BASE_URL = 'https://app.contentstack.com' + DEVELOPER_HUB_BASE_URL = 'https://developerhub-api.contentstack.com' + + def __init__( + self, + app_id: str, + client_id: str, + redirect_uri: str, + response_type: str = "code", + client_secret: Optional[str] = None, + scope: Optional[List[str]] = None, + api_client=None + ): + self.app_id = app_id + self.client_id = client_id + self.redirect_uri = redirect_uri + self.response_type = response_type + self.scope = ' '.join(scope) if scope else '' + self.client_secret = client_secret # Optional, if provided, PKCE will be skipped + self.api_client = api_client + + self._oauth_base_url = self._construct_oauth_base_url() + self._developer_hub_base_url = self._construct_developer_hub_base_url() + + if self.api_client: + if not hasattr(self.api_client, 'oauth'): + self.api_client.oauth = {} + self.api_client.oauth.update({ + 'redirect_uri': redirect_uri, + 'client_id': client_id, + 'app_id': app_id + }) + + # PKCE setup + self.use_pkce = client_secret is None + if self.use_pkce: + self.code_verifier = self._generate_code_verifier() + self.code_challenge = self._generate_code_challenge(self.code_verifier) + else: + self.code_verifier = None + self.code_challenge = None + + def _construct_oauth_base_url(self) -> str: + """ + Construct OAuth base URL based on api_client endpoint using dynamic text replacement. + Returns: + OAuth base URL string + """ + if not self.api_client or not hasattr(self.api_client, 'endpoint'): + return self.OAUTH_BASE_URL + + endpoint = self.api_client.endpoint + + from urllib.parse import urlparse + parsed = urlparse(endpoint) + domain = parsed.netloc + if 'api.contentstack.io' in domain: + oauth_domain = domain.replace('api.contentstack.io', 'app.contentstack.com') + else: + oauth_domain = domain.replace('-api.', '-app.').replace('api.', 'app.') + oauth_url = f'https://{oauth_domain}' + return oauth_url + + def _construct_developer_hub_base_url(self) -> str: + """ + Construct Developer Hub base URL based on api_client endpoint using dynamic text replacement. + Returns: + Developer Hub base URL string + """ + if not self.api_client or not hasattr(self.api_client, 'endpoint'): + return self.DEVELOPER_HUB_BASE_URL + + endpoint = self.api_client.endpoint + from urllib.parse import urlparse + parsed = urlparse(endpoint) + domain = parsed.netloc + + if 'api.contentstack.io' in domain: + dev_hub_domain = domain.replace('api.contentstack.io', 'developerhub-api.contentstack.com') + else: + dev_hub_domain = domain.replace('-api.', '-developerhub-api.') + dev_hub_url = f'https://{dev_hub_domain}' + return dev_hub_url + + def _generate_code_verifier(self, length: int = 128) -> str: + """ + Generate a random code verifier for PKCE. + Returns: + Base64 URL-encoded code verifier + """ + code_verifier = secrets.token_urlsafe(length) + return code_verifier + + def _generate_code_challenge(self, code_verifier: str) -> str: + """ + Generate code challenge from code verifier using SHA256. + Returns: + Base64 URL-encoded SHA256 hash of the code verifier + """ + sha256_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest() + import base64 + code_challenge = base64.urlsafe_b64encode(sha256_hash).decode('utf-8').rstrip('=') + return code_challenge + + def _get_headers(self) -> Dict[str, str]: + """Get common headers for OAuth requests.""" + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + if (self.api_client and hasattr(self.api_client, 'oauth') and + self.api_client.oauth.get('accessToken')): + headers["authorization"] = f"Bearer {self.api_client.oauth['accessToken']}" + return headers + + def authorize(self) -> str: + """ + Start the OAuth authorization flow. + Returns: + Authorization URL for user to visit + """ + try: + print(f"OAuth Handler - authorize() called with app_id: {self.app_id}, client_id: {self.client_id}") # Debug + if not self._oauth_base_url: + raise ValueError(self.OAUTH_BASE_URL_NOT_SET_MSG) + oauth_base = self._oauth_base_url.rstrip('/') + base_url = f"{oauth_base}/#!/apps/{self.app_id}/authorize" + params = { + 'response_type': 'code', # Always use 'code' + 'client_id': self.client_id + } + + if self.redirect_uri and self.redirect_uri.strip() and self.redirect_uri != 'None': + params['redirect_uri'] = self.redirect_uri + + # Add PKCE parameters if using PKCE + if self.use_pkce: + if not self.code_challenge: + self.code_challenge = self._generate_code_challenge(self.code_verifier) + params['code_challenge'] = self.code_challenge + params['code_challenge_method'] = 'S256' + + query_string = urllib.parse.urlencode(params) + final_url = f"{base_url}?{query_string}" + print(f"OAuth Handler - final authorization URL: {final_url}") # Debug + return final_url + + except Exception as e: + raise ValueError(f"Error generating authorization URL: {e}") + + def handle_redirect(self, redirect_url: str) -> Dict: + """ + Handle the redirect from OAuth authorization server. + Returns: + Dictionary containing token information + """ + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + if "code" not in query_params: + raise ValueError(self.INVALID_AUTHORIZATION_CODE_MSG) + authorization_code = query_params["code"][0] + return self.exchange_code_for_token(authorization_code) + + def exchange_code_for_token(self, authorization_code: str) -> Dict: + """ + Exchange authorization code for access token. + Returns: + Dictionary containing token information + """ + if not authorization_code or not authorization_code.strip(): + raise ValueError("Authorization code cannot be empty") + data = { + "grant_type": "authorization_code", + "code": authorization_code.strip(), + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + "app_id": self.app_id + } + + if self.client_secret: + data["client_secret"] = self.client_secret + else: + data["code_verifier"] = self.code_verifier + + headers = self._get_headers() + + try: + token_endpoint = f"{self._developer_hub_base_url}/token" + response = requests.post( + token_endpoint, + data=data, + headers=headers, + timeout=30 + ) + response.raise_for_status() + token_data = response.json() + self._save_tokens(token_data) + return token_data + + except requests.RequestException as e: + raise requests.RequestException(f"{self.TOKEN_EXCHANGE_FAILED_MSG}: {str(e)}") + + def _save_tokens(self, token_data: Dict): + """ + Save tokens and related information. + """ + # Store tokens in api_client.oauth + if self.api_client: + self.api_client.oauth["accessToken"] = token_data.get("access_token") + self.api_client.oauth["refreshToken"] = token_data.get("refresh_token") or self.api_client.oauth.get("refreshToken") + self.api_client.oauth["organizationUID"] = token_data.get("organization_uid") + self.api_client.oauth["userUID"] = token_data.get("user_uid") + expires_in = token_data.get("expires_in", 3600) + self.api_client.oauth["tokenExpiryTime"] = int(time.time() * 1000) + (expires_in - 60) * 1000 # Store expiry time in milliseconds + + self._access_token = token_data.get("access_token") + self._refresh_token = token_data.get("refresh_token") + expires_in = token_data.get("expires_in", 3600) # Default 1 hour + self._token_expiry_time = time.time() + expires_in + self._organization_uid = token_data.get("organization_uid") + self._user_uid = token_data.get("user_uid") + if self.api_client and self._access_token: + self.api_client.headers["Authorization"] = f"Bearer {self._access_token}" + + def get_valid_access_token(self) -> str: + """ + Get valid access token, refreshing if necessary. + Returns: + Valid access token + """ + if self.is_token_expired(): + print(self.ACCESS_TOKEN_EXPIRED_MSG) + self.refresh_access_token() + access_token = self.get_access_token() + if not access_token: + raise ValueError(self.NO_ACCESS_TOKEN_MSG) + return access_token + + def is_token_expired(self) -> bool: + """ + Check if access token is expired. + Returns: + True if token is expired, False otherwise + """ + if not self.api_client or not hasattr(self.api_client, 'oauth'): + return True + + token_expiry_time = self.api_client.oauth.get('tokenExpiryTime') + if not token_expiry_time: + return True + if token_expiry_time > 1e10: + expiry_time = token_expiry_time / 1000 + else: + expiry_time = token_expiry_time + return time.time() >= expiry_time + + def refresh_access_token(self) -> str: + """ + Refresh the access token using refresh token. + Returns: + New access token + """ + if not self.api_client or not hasattr(self.api_client, 'oauth') or not self.api_client.oauth: + raise ValueError(self.NO_OAUTH_TOKENS_MSG) + refresh_token = self.api_client.oauth.get('refreshToken') + if not refresh_token: + raise ValueError(self.NO_REFRESH_TOKEN_MSG) + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": self.client_id, + "app_id": self.app_id + } + if self.client_secret: + data["client_secret"] = self.client_secret + headers = self._get_headers() + try: + response = requests.post( + f"{self._developer_hub_base_url}/token", + data=data, + headers=headers, + timeout=30 + ) + response.raise_for_status() + + token_data = response.json() + self._save_tokens(token_data) + + return self._access_token + except requests.RequestException as e: + raise requests.RequestException(f"{self.TOKEN_REFRESH_FAILED_MSG}: {str(e)}") + + def logout(self, revoke_authorization: bool = True) -> bool: + """ + Logout and clear OAuth tokens. + Returns: + True if logout successful, False otherwise + """ + try: + self._clear_tokens() + return True + + except Exception: + self._clear_tokens() + return False + + def _clear_tokens(self): + """Clear all stored tokens and related information.""" + self._access_token = None + self._refresh_token = None + self._token_expiry_time = None + self._organization_uid = None + self._user_uid = None + + if self.api_client and "Authorization" in self.api_client.headers: + del self.api_client.headers["Authorization"] + + def get_access_token(self) -> Optional[str]: + """Get the current access token.""" + return self._access_token + + def set_access_token(self, token: str): + """Set the access token.""" + self._access_token = token + if self.api_client: + self.api_client.headers["Authorization"] = f"Bearer {token}" + if not hasattr(self.api_client, 'oauth'): + self.api_client.oauth = {} + self.api_client.oauth['accessToken'] = token + + def get_refresh_token(self) -> Optional[str]: + """Get the current refresh token.""" + return self._refresh_token + + def set_refresh_token(self, token: str): + """Set the refresh token.""" + self._refresh_token = token + if self.api_client: + if not hasattr(self.api_client, 'oauth'): + self.api_client.oauth = {} + self.api_client.oauth['refreshToken'] = token + + def get_organization_uid(self) -> Optional[str]: + """Get the organization UID.""" + return self._organization_uid + + def set_organization_uid(self, uid: str): + """Set the organization UID.""" + self._organization_uid = uid + + def get_user_uid(self) -> Optional[str]: + """Get the user UID.""" + return self._user_uid + + def set_user_uid(self, uid: str): + """Set the user UID.""" + self._user_uid = uid + + def get_token_expiry_time(self) -> Optional[float]: + """Get the token expiry time.""" + return self._token_expiry_time + + def set_token_expiry_time(self, expiry_time: float): + """Set the token expiry time.""" + self._token_expiry_time = expiry_time + if self.api_client: + if not hasattr(self.api_client, 'oauth'): + self.api_client.oauth = {} + self.api_client.oauth['tokenExpiryTime'] = expiry_time + diff --git a/contentstack_management/oauth/oauth_interceptor.py b/contentstack_management/oauth/oauth_interceptor.py new file mode 100644 index 0000000..8888666 --- /dev/null +++ b/contentstack_management/oauth/oauth_interceptor.py @@ -0,0 +1,187 @@ +""" +OAuth Interceptor for automatic token management and request handling. +""" + +import time +import threading +from typing import Dict, Any, Optional +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +class OAuthInterceptor: + """ + OAuth interceptor that automatically handles token refresh and request retries. + """ + MAX_RETRIES = 3 + REFRESH_TIMEOUT = 30 + TOKEN_ENDPOINT_PATH = "/token" + TOKEN_REFRESH_FAILED_MSG = "Token refresh failed" + NO_VALID_TOKENS_MSG = "OAuth: No valid tokens available" + TOKEN_REFRESH_FAILED_AFTER_401_MSG = "OAuth: Token refresh failed after 401" + # User agent strings + USER_AGENT = "contentstack-python-management-sdk" + X_USER_AGENT = "contentstack-python-management-sdk" + + def __init__(self, oauth_handler): + """ + Initialize the OAuth interceptor. + Args: + oauth_handler: The OAuthHandler instance + """ + self.oauth_handler = oauth_handler + self.early_access = None + self.refresh_lock = threading.Lock() + + def set_early_access(self, early_access: list): + """Set early access headers.""" + self.early_access = early_access + + def is_oauth_configured(self) -> bool: + """Check if OAuth is properly configured.""" + return (self.oauth_handler is not None and + hasattr(self.oauth_handler, 'app_id') and + self.oauth_handler.app_id is not None) + + def has_valid_tokens(self) -> bool: + """Check if we have valid (non-expired) tokens.""" + if not self.oauth_handler or not hasattr(self.oauth_handler, 'api_client'): + return False + + api_client = self.oauth_handler.api_client + if not hasattr(api_client, 'oauth') or not api_client.oauth: + return False + return not self.oauth_handler.is_token_expired() + + def _get_default_headers(self, request_url: str) -> Dict[str, str]: + """ + Get default headers for requests. + Returns: + Dictionary of headers + """ + headers = { + "X-User-Agent": self.X_USER_AGENT, + "User-Agent": self.USER_AGENT, + "x-header-ea": ",".join(self.early_access) if self.early_access else "true" + } + + if self.TOKEN_ENDPOINT_PATH in request_url: + headers["Content-Type"] = "application/x-www-form-urlencoded" + else: + headers["Content-Type"] = "application/json" + + return headers + + def _add_auth_header(self, headers: Dict[str, str], request_url: str) -> Dict[str, str]: + """ + Add authorization header if appropriate. + Returns: + Updated headers with authorization if needed + """ + if self.TOKEN_ENDPOINT_PATH in request_url: + return headers + if (self.oauth_handler and + hasattr(self.oauth_handler, 'api_client') and + self.oauth_handler.api_client and + hasattr(self.oauth_handler.api_client, 'oauth') and + self.oauth_handler.api_client.oauth and + self.oauth_handler.api_client.oauth.get('accessToken')): + + headers["Authorization"] = f"Bearer {self.oauth_handler.api_client.oauth['accessToken']}" + return headers + + def _ensure_valid_token(self) -> bool: + """ + Ensure we have a valid token, refreshing if necessary. + Returns: + True if we have a valid token, False otherwise + """ + if not self.oauth_handler or not hasattr(self.oauth_handler, 'api_client'): + return False + + api_client = self.oauth_handler.api_client + if not hasattr(api_client, 'oauth') or not api_client.oauth: + return False + + # Check if token is expired and refresh if needed + if self.oauth_handler.is_token_expired(): + with self.refresh_lock: + try: + if self.oauth_handler.is_token_expired(): + self.oauth_handler.refresh_access_token() + return True + except Exception as e: + print(f"{self.TOKEN_REFRESH_FAILED_MSG}: {e}") + return False + + return True + + def execute_request(self, method: str, url: str, **kwargs) -> requests.Response: + """ + Execute a request with OAuth handling and retry logic. + Returns: + Response object + """ + if self.TOKEN_ENDPOINT_PATH in url: + return self._make_request(method, url, **kwargs) + if not self._ensure_valid_token(): + raise requests.RequestException(self.NO_VALID_TOKENS_MSG) + return self._execute_with_retry(method, url, 0, **kwargs) + + def _execute_with_retry(self, method: str, url: str, retry_count: int, **kwargs) -> requests.Response: + """ + Execute request with retry logic. + Returns: + Response object + """ + headers = self._get_default_headers(url) + headers = self._add_auth_header(headers, url) + if 'headers' in kwargs: + headers.update(kwargs['headers']) + kwargs['headers'] = headers + + response = self._make_request(method, url, **kwargs) + if not response.ok and retry_count < self.MAX_RETRIES: + status_code = response.status_code + + if (status_code == 401 and + self.oauth_handler and + hasattr(self.oauth_handler, 'api_client') and + self.oauth_handler.api_client and + hasattr(self.oauth_handler.api_client, 'oauth') and + self.oauth_handler.api_client.oauth and + self.oauth_handler.api_client.oauth.get('refreshToken')): + + with self.refresh_lock: + try: + self.oauth_handler.refresh_access_token() + headers["Authorization"] = f"Bearer {self.oauth_handler.api_client.oauth['accessToken']}" + kwargs['headers'] = headers + return self._execute_with_retry(method, url, retry_count + 1, **kwargs) + except Exception as e: + raise requests.RequestException(f"{self.TOKEN_REFRESH_FAILED_AFTER_401_MSG}: {e}") + + if status_code == 429 or (status_code >= 500 and status_code != 501): + # Calculate delay with exponential backoff + delay = min(1000 * (2 ** retry_count), 30000) / 1000 # Convert to seconds + time.sleep(delay) + return self._execute_with_retry(method, url, retry_count + 1, **kwargs) + return response + + def _make_request(self, method: str, url: str, **kwargs) -> requests.Response: + """ + Make the actual HTTP request. + Returns: + Response object + """ + return requests.request(method, url, **kwargs) + + def get_valid_access_token(self) -> Optional[str]: + """ + Get a valid access token, refreshing if necessary. + Returns: + Valid access token or None if unavailable + """ + if self._ensure_valid_token(): + return self.oauth_handler.api_client.oauth.get('accessToken') + return None diff --git a/contentstack_management/user_session/user_session.py b/contentstack_management/user_session/user_session.py index 8c05ca7..91588cd 100644 --- a/contentstack_management/user_session/user_session.py +++ b/contentstack_management/user_session/user_session.py @@ -63,7 +63,7 @@ def login(self, email=None, password=None, tfa_token=None): } if tfa_token is not None: - data["user"]["tf_token"] = tfa_token + data["user"]["tfa_token"] = tfa_token data = json.dumps(data) response = self.client.post(_path, headers=self.client.headers, data=data, json_data=None) diff --git a/requirements.txt b/requirements.txt index 72b0ca9..cc66228 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ requests>=2.32.0,<3.0.0 pylint>=2.0.0 bson>=0.5.9,<1.0.0 requests-toolbelt>=1.0.0,<2.0.0 +pyotp==2.9.0 diff --git a/tests/unit/contentstack/test_contentstack.py b/tests/unit/contentstack/test_contentstack.py index 3ebef2d..e12cb26 100644 --- a/tests/unit/contentstack/test_contentstack.py +++ b/tests/unit/contentstack/test_contentstack.py @@ -8,25 +8,25 @@ class ContentstackRegionUnitTests(unittest.TestCase): def test_au_region(self): """Test that au region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='au') - expected_endpoint = 'https://au-api.contentstack.io/v3/' + expected_endpoint = 'https://au-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_eu_region(self): """Test that gcp-eu region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='gcp-eu') - expected_endpoint = 'https://gcp-eu-api.contentstack.io/v3/' + expected_endpoint = 'https://gcp-eu-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_azure_eu_region(self): """Test that azure-eu region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='azure-eu') - expected_endpoint = 'https://azure-eu-api.contentstack.io/v3/' + expected_endpoint = 'https://azure-eu-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_azure_na_region(self): """Test that azure-na region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='azure-na') - expected_endpoint = 'https://azure-na-api.contentstack.io/v3/' + expected_endpoint = 'https://azure-na-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_au_region_with_custom_host(self): @@ -34,9 +34,9 @@ def test_au_region_with_custom_host(self): client = contentstack_management.Client( authtoken='your_authtoken', region='au', - host='custom.contentstack.io' + host='example.com' ) - expected_endpoint = 'https://au-custom.contentstack.io/v3/' + expected_endpoint = 'https://au-api.example.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_eu_region_with_custom_host(self): @@ -46,19 +46,19 @@ def test_gcp_eu_region_with_custom_host(self): region='gcp-eu', host='custom.contentstack.io' ) - expected_endpoint = 'https://gcp-eu-custom.contentstack.io/v3/' + expected_endpoint = 'https://gcp-eu-api.custom.contentstack.io/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_au_region_enum_value(self): """Test that au region using enum value creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region=Region.AU.value) - expected_endpoint = 'https://au-api.contentstack.io/v3/' + expected_endpoint = 'https://au-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_eu_region_enum_value(self): """Test that gcp-eu region using enum value creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region=Region.GCP_EU.value) - expected_endpoint = 'https://gcp-eu-api.contentstack.io/v3/' + expected_endpoint = 'https://gcp-eu-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_au_region_with_custom_scheme(self): @@ -68,7 +68,7 @@ def test_au_region_with_custom_scheme(self): region='au', scheme='http://' ) - expected_endpoint = 'http://au-api.contentstack.io/v3/' + expected_endpoint = 'http://au-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_eu_region_with_custom_scheme(self): @@ -78,7 +78,7 @@ def test_gcp_eu_region_with_custom_scheme(self): region='gcp-eu', scheme='http://' ) - expected_endpoint = 'http://gcp-eu-api.contentstack.io/v3/' + expected_endpoint = 'http://gcp-eu-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_au_region_with_custom_version(self): @@ -88,7 +88,7 @@ def test_au_region_with_custom_version(self): region='au', version='v2' ) - expected_endpoint = 'https://au-api.contentstack.io/v2/' + expected_endpoint = 'https://au-api.contentstack.com/v2/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_eu_region_with_custom_version(self): @@ -98,7 +98,7 @@ def test_gcp_eu_region_with_custom_version(self): region='gcp-eu', version='v2' ) - expected_endpoint = 'https://gcp-eu-api.contentstack.io/v2/' + expected_endpoint = 'https://gcp-eu-api.contentstack.com/v2/' self.assertEqual(client.endpoint, expected_endpoint) def test_au_region_headers(self): @@ -222,13 +222,13 @@ def test_us_region_default_behavior(self): def test_eu_region(self): """Test that eu region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='eu') - expected_endpoint = 'https://eu-api.contentstack.io/v3/' + expected_endpoint = 'https://eu-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_gcp_na_region(self): """Test that gcp-na region creates the correct endpoint URL""" client = contentstack_management.Client(authtoken='your_authtoken', region='gcp-na') - expected_endpoint = 'https://gcp-na-api.contentstack.io/v3/' + expected_endpoint = 'https://gcp-na-api.contentstack.com/v3/' self.assertEqual(client.endpoint, expected_endpoint) def test_region_with_none_host(self): diff --git a/tests/unit/contentstack/test_contentstack_integration.py b/tests/unit/contentstack/test_contentstack_integration.py index e1370ef..8adbcef 100644 --- a/tests/unit/contentstack/test_contentstack_integration.py +++ b/tests/unit/contentstack/test_contentstack_integration.py @@ -115,7 +115,7 @@ def test_region_endpoint_construction_logic(self): # Test non-US region with default host client = contentstack_management.Client(region='eu') - self.assertEqual(client.endpoint, 'https://eu-api.contentstack.io/v3/') + self.assertEqual(client.endpoint, 'https://eu-api.contentstack.com/v3/') # Skip custom host tests due to implementation issues # Test custom host without region diff --git a/tests/unit/contentstack/test_totp_login.py b/tests/unit/contentstack/test_totp_login.py new file mode 100644 index 0000000..bf6bd7e --- /dev/null +++ b/tests/unit/contentstack/test_totp_login.py @@ -0,0 +1,146 @@ +import unittest +import os +import sys +from unittest.mock import patch, MagicMock + +# Add the contentstack_management module to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..')) + +import contentstack_management +from contentstack_management.contentstack import Client + + +class TOTPLoginTests(unittest.TestCase): + """Unit tests for TOTP login functionality in Contentstack Management Python SDK""" + + def setUp(self): + """Set up test fixtures before each test method""" + self.client = Client() + self.test_email = "test@example.com" + self.test_password = "test_password" + self.test_secret = "JBSWY3DPEHPK3PXP" # Standard test secret for TOTP + self.test_tfa_token = "123456" + + def tearDown(self): + """Clean up after each test method""" + # Clean up environment variables + if 'MFA_SECRET' in os.environ: + del os.environ['MFA_SECRET'] + + def test_login_method_signature_with_totp(self): + """Test that login method accepts TOTP parameters""" + client = contentstack_management.Client() + # Test that the method exists and can be called with the expected parameters + self.assertTrue(hasattr(client, 'login')) + self.assertTrue(callable(client.login)) + + # Test that the method accepts TOTP parameters without error + try: + client.login(self.test_email, self.test_password, tfa_token=self.test_tfa_token) + client.login(self.test_email, self.test_password, mfa_secret=self.test_secret) + client.login(self.test_email, self.test_password, tfa_token=self.test_tfa_token, mfa_secret=self.test_secret) + except Exception as e: + self.fail(f"Login method should accept TOTP parameters without error: {e}") + + def test_generate_totp_method(self): + """Test the _generate_totp method generates correct TOTP codes""" + # Test with a known secret and verify the TOTP generation + totp_code = self.client._generate_totp(self.test_secret) + + # Verify the TOTP code is a 6-digit string + self.assertIsInstance(totp_code, str) + self.assertEqual(len(totp_code), 6) + self.assertTrue(totp_code.isdigit()) + + def test_login_with_mfa_secret_generates_totp(self): + """Test that login with mfa_secret generates TOTP automatically""" + with patch.object(self.client, 'client') as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + mock_client.post.return_value = mock_response + + # Mock the UserSession class + with patch('contentstack_management.user_session.user_session.UserSession') as mock_user_session: + mock_session_instance = MagicMock() + mock_session_instance.login.return_value = mock_response + mock_user_session.return_value = mock_session_instance + + # Mock the _generate_totp method to return a predictable value + with patch.object(self.client, '_generate_totp', return_value='654321') as mock_generate_totp: + result = self.client.login( + self.test_email, + self.test_password, + mfa_secret=self.test_secret + ) + + # Verify _generate_totp was called with the secret + mock_generate_totp.assert_called_once_with(self.test_secret) + + # Verify UserSession was called with generated TOTP + mock_session_instance.login.assert_called_once_with( + self.test_email, + self.test_password, + '654321' + ) + self.assertEqual(result, mock_response) + + def test_login_with_environment_variable(self): + """Test that login uses MFA_SECRET environment variable when mfa_secret is not provided""" + # Set environment variable + os.environ['MFA_SECRET'] = self.test_secret + + with patch.object(self.client, 'client') as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + mock_client.post.return_value = mock_response + + # Mock the UserSession class + with patch('contentstack_management.user_session.user_session.UserSession') as mock_user_session: + mock_session_instance = MagicMock() + mock_session_instance.login.return_value = mock_response + mock_user_session.return_value = mock_session_instance + + # Mock the _generate_totp method + with patch.object(self.client, '_generate_totp', return_value='789012') as mock_generate_totp: + result = self.client.login(self.test_email, self.test_password) + + # Verify _generate_totp was called with the environment secret + mock_generate_totp.assert_called_once_with(self.test_secret) + + # Verify UserSession was called with generated TOTP + mock_session_instance.login.assert_called_once_with( + self.test_email, + self.test_password, + '789012' + ) + self.assertEqual(result, mock_response) + + def test_backward_compatibility(self): + """Test that existing login patterns continue to work (backward compatibility)""" + with patch.object(self.client, 'client') as mock_client: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + mock_client.post.return_value = mock_response + + # Mock the UserSession class + with patch('contentstack_management.user_session.user_session.UserSession') as mock_user_session: + mock_session_instance = MagicMock() + mock_session_instance.login.return_value = mock_response + mock_user_session.return_value = mock_session_instance + + # Test old pattern: client.login(email, password) + result1 = self.client.login(self.test_email, self.test_password) + + # Test old pattern: client.login(email, password, tfa_token) + result2 = self.client.login(self.test_email, self.test_password, self.test_tfa_token) + + # Both should work without errors + self.assertEqual(result1, mock_response) + self.assertEqual(result2, mock_response) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/test_oauth_handler.py b/tests/unit/test_oauth_handler.py new file mode 100644 index 0000000..5b3564f --- /dev/null +++ b/tests/unit/test_oauth_handler.py @@ -0,0 +1,367 @@ +""" +Unit tests for OAuth Handler functionality. +""" + +import json +import time +import unittest +from unittest.mock import Mock, patch, MagicMock +import requests + +from contentstack_management.oauth.oauth_handler import OAuthHandler + + +class TestOAuthHandler(unittest.TestCase): + """Test cases for OAuthHandler class.""" + + def setUp(self): + """Set up test fixtures.""" + self.app_id = "test-app-id" + self.client_id = "test-client-id" + self.redirect_uri = "http://localhost:3000/callback" + self.client_secret = "test-client-secret" + self.scope = ["read", "write"] + self.api_client = Mock() + self.api_client.headers = {} + self.api_client.endpoint = "https://api.contentstack.io/v3/" + self.api_client.oauth = {} + + # Create OAuth handler with client secret (standard flow) + self.oauth_handler_with_secret = OAuthHandler( + app_id=self.app_id, + client_id=self.client_id, + redirect_uri=self.redirect_uri, + client_secret=self.client_secret, + scope=self.scope, + api_client=self.api_client + ) + + # Create OAuth handler without client secret (PKCE flow) + self.oauth_handler_pkce = OAuthHandler( + app_id=self.app_id, + client_id=self.client_id, + redirect_uri=self.redirect_uri, + scope=self.scope, + api_client=self.api_client + ) + + def test_initialization_with_client_secret(self): + """Test OAuth handler initialization with client secret.""" + self.assertEqual(self.oauth_handler_with_secret.app_id, self.app_id) + self.assertEqual(self.oauth_handler_with_secret.client_id, self.client_id) + self.assertEqual(self.oauth_handler_with_secret.redirect_uri, self.redirect_uri) + self.assertEqual(self.oauth_handler_with_secret.client_secret, self.client_secret) + self.assertEqual(self.oauth_handler_with_secret.scope, 'read write') # Converted to string + self.assertFalse(self.oauth_handler_with_secret.use_pkce) + self.assertIsNone(self.oauth_handler_with_secret.code_verifier) + self.assertIsNone(self.oauth_handler_with_secret.code_challenge) + + def test_initialization_without_client_secret(self): + """Test OAuth handler initialization without client secret (PKCE).""" + self.assertEqual(self.oauth_handler_pkce.app_id, self.app_id) + self.assertEqual(self.oauth_handler_pkce.client_id, self.client_id) + self.assertEqual(self.oauth_handler_pkce.redirect_uri, self.redirect_uri) + self.assertIsNone(self.oauth_handler_pkce.client_secret) + self.assertEqual(self.oauth_handler_pkce.scope, 'read write') # Converted to string + self.assertTrue(self.oauth_handler_pkce.use_pkce) + self.assertIsNotNone(self.oauth_handler_pkce.code_verifier) + self.assertIsNotNone(self.oauth_handler_pkce.code_challenge) + + def test_generate_code_verifier(self): + """Test code verifier generation.""" + verifier = self.oauth_handler_pkce._generate_code_verifier() + self.assertIsInstance(verifier, str) + self.assertGreater(len(verifier), 0) + + # Test with custom length (secrets.token_urlsafe returns base64 encoded string) + verifier_short = self.oauth_handler_pkce._generate_code_verifier(32) + # Base64 encoding of 32 bytes results in ~43 characters + self.assertGreaterEqual(len(verifier_short), 32) + + def test_generate_code_challenge(self): + """Test code challenge generation.""" + verifier = "test-code-verifier" + challenge = self.oauth_handler_pkce._generate_code_challenge(verifier) + self.assertIsInstance(challenge, str) + self.assertGreater(len(challenge), 0) + + def test_get_headers(self): + """Test header generation.""" + headers = self.oauth_handler_with_secret._get_headers() + expected_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + self.assertEqual(headers, expected_headers) + + def test_authorize_with_client_secret(self): + """Test authorization URL generation with client secret.""" + auth_url = self.oauth_handler_with_secret.authorize() + + self.assertIn(self.app_id, auth_url) + self.assertIn(self.client_id, auth_url) + self.assertIn("http%3A%2F%2Flocalhost%3A3000%2Fcallback", auth_url) # URL encoded + self.assertIn("response_type=code", auth_url) + # Note: scope is not included in the authorization URL as per Contentstack OAuth implementation + self.assertNotIn("code_challenge", auth_url) + + def test_authorize_with_pkce(self): + """Test authorization URL generation with PKCE.""" + auth_url = self.oauth_handler_pkce.authorize() + + self.assertIn(self.app_id, auth_url) + self.assertIn(self.client_id, auth_url) + self.assertIn("http%3A%2F%2Flocalhost%3A3000%2Fcallback", auth_url) # URL encoded + self.assertIn("response_type=code", auth_url) + # Note: scope is not included in the authorization URL as per Contentstack OAuth implementation + self.assertIn("code_challenge", auth_url) + self.assertIn("code_challenge_method=S256", auth_url) + + def test_authorize_without_scope(self): + """Test authorization URL generation without scope.""" + oauth_handler = OAuthHandler( + app_id=self.app_id, + client_id=self.client_id, + redirect_uri=self.redirect_uri + ) + auth_url = oauth_handler.authorize() + + self.assertNotIn("scope=", auth_url) + + def test_handle_redirect_success(self): + """Test successful redirect handling.""" + redirect_url = "http://localhost:3000/callback?code=test-auth-code&state=test-state" + + # Mock the token exchange response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "organization_uid": "test-org-uid", + "user_uid": "test-user-uid" + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + token_data = self.oauth_handler_with_secret.handle_redirect(redirect_url) + + self.assertEqual(token_data["access_token"], "test-access-token") + self.assertEqual(token_data["refresh_token"], "test-refresh-token") + self.assertEqual(self.oauth_handler_with_secret.get_access_token(), "test-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_refresh_token(), "test-refresh-token") + self.assertEqual(self.oauth_handler_with_secret.get_organization_uid(), "test-org-uid") + self.assertEqual(self.oauth_handler_with_secret.get_user_uid(), "test-user-uid") + + def test_handle_redirect_no_code(self): + """Test redirect handling with no authorization code.""" + redirect_url = "http://localhost:3000/callback?error=access_denied" + + with self.assertRaises(ValueError) as context: + self.oauth_handler_with_secret.handle_redirect(redirect_url) + + self.assertIn("Authorization code not found", str(context.exception)) + + def test_exchange_code_for_token_success(self): + """Test successful code exchange for token.""" + auth_code = "test-auth-code" + + # Mock the token exchange response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "organization_uid": "test-org-uid", + "user_uid": "test-user-uid" + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + token_data = self.oauth_handler_with_secret.exchange_code_for_token(auth_code) + + self.assertEqual(token_data["access_token"], "test-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_access_token(), "test-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_refresh_token(), "test-refresh-token") + + def test_exchange_code_for_token_with_pkce(self): + """Test code exchange with PKCE flow.""" + auth_code = "test-auth-code" + + # Mock the token exchange response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response) as mock_post: + self.oauth_handler_pkce.exchange_code_for_token(auth_code) + + # Verify that code_verifier was included in the request + call_args = mock_post.call_args + self.assertIn("code_verifier", call_args[1]["data"]) + + def test_exchange_code_for_token_failure(self): + """Test code exchange failure.""" + auth_code = "invalid-auth-code" + + # Mock the failed response + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.RequestException("Token exchange failed") + + with patch('requests.post', return_value=mock_response): + with self.assertRaises(requests.RequestException): + self.oauth_handler_with_secret.exchange_code_for_token(auth_code) + + def test_refresh_access_token_success(self): + """Test successful access token refresh.""" + # Set up initial tokens + self.oauth_handler_with_secret.set_refresh_token("test-refresh-token") + + # Mock the refresh response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + new_token = self.oauth_handler_with_secret.refresh_access_token() + + self.assertEqual(new_token, "new-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_access_token(), "new-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_refresh_token(), "new-refresh-token") + + def test_refresh_access_token_no_refresh_token(self): + """Test refresh access token without refresh token.""" + with self.assertRaises(ValueError) as context: + self.oauth_handler_with_secret.refresh_access_token() + + self.assertIn("No refresh token available", str(context.exception)) + + def test_refresh_access_token_failure(self): + """Test refresh access token failure.""" + self.oauth_handler_with_secret.set_refresh_token("invalid-refresh-token") + + # Mock the failed response + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.RequestException("Refresh failed") + + with patch('requests.post', return_value=mock_response): + with self.assertRaises(requests.RequestException): + self.oauth_handler_with_secret.refresh_access_token() + + def test_is_token_expired(self): + """Test token expiry checking.""" + # No expiry time set + self.assertTrue(self.oauth_handler_with_secret.is_token_expired()) + + # Set future expiry time + future_time = time.time() + 3600 + self.oauth_handler_with_secret.set_token_expiry_time(future_time) + self.assertFalse(self.oauth_handler_with_secret.is_token_expired()) + + # Set past expiry time + past_time = time.time() - 3600 + self.oauth_handler_with_secret.set_token_expiry_time(past_time) + self.assertTrue(self.oauth_handler_with_secret.is_token_expired()) + + def test_get_valid_access_token_with_valid_token(self): + """Test getting valid access token when token is valid.""" + self.oauth_handler_with_secret.set_access_token("valid-token") + self.oauth_handler_with_secret.set_token_expiry_time(time.time() + 3600) + + valid_token = self.oauth_handler_with_secret.get_valid_access_token() + self.assertEqual(valid_token, "valid-token") + + def test_get_valid_access_token_with_refresh(self): + """Test getting valid access token with refresh.""" + self.oauth_handler_with_secret.set_access_token("expired-token") + self.oauth_handler_with_secret.set_refresh_token("valid-refresh-token") + self.oauth_handler_with_secret.set_token_expiry_time(time.time() - 3600) + + # Mock the refresh response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "new-valid-token", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + valid_token = self.oauth_handler_with_secret.get_valid_access_token() + self.assertEqual(valid_token, "new-valid-token") + + def test_get_valid_access_token_no_refresh_token(self): + """Test getting valid access token without refresh token.""" + self.oauth_handler_with_secret.set_access_token("expired-token") + self.oauth_handler_with_secret.set_token_expiry_time(time.time() - 3600) + + with self.assertRaises(ValueError) as context: + self.oauth_handler_with_secret.get_valid_access_token() + + self.assertIn("No refresh token available", str(context.exception)) + + def test_logout_success(self): + """Test successful logout.""" + # Set up tokens + self.oauth_handler_with_secret.set_access_token("test-token") + self.oauth_handler_with_secret.set_refresh_token("test-refresh-token") + + # Mock the revoke response + mock_response = Mock() + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + result = self.oauth_handler_with_secret.logout() + + self.assertTrue(result) + self.assertIsNone(self.oauth_handler_with_secret.get_access_token()) + self.assertIsNone(self.oauth_handler_with_secret.get_refresh_token()) + + def test_logout_without_revoke(self): + """Test logout without revoking authorization.""" + # Set up tokens + self.oauth_handler_with_secret.set_access_token("test-token") + self.oauth_handler_with_secret.set_refresh_token("test-refresh-token") + + result = self.oauth_handler_with_secret.logout(revoke_authorization=False) + + self.assertTrue(result) + self.assertIsNone(self.oauth_handler_with_secret.get_access_token()) + self.assertIsNone(self.oauth_handler_with_secret.get_refresh_token()) + + # Note: Tests for get_oauth_app_authorization and revoke_oauth_app_authorization + # methods removed as these methods were removed from OAuthHandler for simplicity + + def test_getter_setter_methods(self): + """Test all getter and setter methods.""" + # Test access token + self.oauth_handler_with_secret.set_access_token("test-access-token") + self.assertEqual(self.oauth_handler_with_secret.get_access_token(), "test-access-token") + + # Test refresh token + self.oauth_handler_with_secret.set_refresh_token("test-refresh-token") + self.assertEqual(self.oauth_handler_with_secret.get_refresh_token(), "test-refresh-token") + + # Test organization UID + self.oauth_handler_with_secret.set_organization_uid("test-org-uid") + self.assertEqual(self.oauth_handler_with_secret.get_organization_uid(), "test-org-uid") + + # Test user UID + self.oauth_handler_with_secret.set_user_uid("test-user-uid") + self.assertEqual(self.oauth_handler_with_secret.get_user_uid(), "test-user-uid") + + # Test token expiry time + expiry_time = time.time() + 3600 + self.oauth_handler_with_secret.set_token_expiry_time(expiry_time) + self.assertEqual(self.oauth_handler_with_secret.get_token_expiry_time(), expiry_time) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/user_session/test_user_session_totp.py b/tests/unit/user_session/test_user_session_totp.py new file mode 100644 index 0000000..9107fee --- /dev/null +++ b/tests/unit/user_session/test_user_session_totp.py @@ -0,0 +1,144 @@ +import unittest +import os +import sys +import json +from unittest.mock import patch, MagicMock + +# Add the contentstack_management module to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', '..')) + +from contentstack_management.user_session.user_session import UserSession + + +class UserSessionTOTPTests(unittest.TestCase): + """Unit tests for TOTP-related functionality in UserSession class""" + + def setUp(self): + """Set up test fixtures before each test method""" + self.mock_client = MagicMock() + self.user_session = UserSession(self.mock_client) + self.test_email = "test@example.com" + self.test_password = "test_password" + self.test_tfa_token = "123456" + + def test_login_with_tfa_token_uses_correct_field_name(self): + """Test that login with TFA token uses 'tfa_token' field name (not 'tf_token')""" + # Mock the client post method + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + self.mock_client.post.return_value = mock_response + + # Call login with TFA token + result = self.user_session.login(self.test_email, self.test_password, self.test_tfa_token) + + # Verify the request was made correctly + self.mock_client.post.assert_called_once() + call_args = self.mock_client.post.call_args + + # Check the data - this is the critical test for the field name fix + expected_data = { + "user": { + "email": self.test_email, + "password": self.test_password, + "tfa_token": self.test_tfa_token # Should be "tfa_token", not "tf_token" + } + } + actual_data = json.loads(call_args[1]['data']) + self.assertEqual(actual_data, expected_data) + + # Verify the correct field name is used + self.assertIn("tfa_token", actual_data["user"]) + self.assertNotIn("tf_token", actual_data["user"]) + self.assertEqual(actual_data["user"]["tfa_token"], self.test_tfa_token) + + # Check the response + self.assertEqual(result, mock_response) + + def test_login_without_tfa_token(self): + """Test login without TFA token (original behavior)""" + # Mock the client post method + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + self.mock_client.post.return_value = mock_response + + # Call login without TFA token + result = self.user_session.login(self.test_email, self.test_password) + + # Verify the request was made correctly + self.mock_client.post.assert_called_once() + call_args = self.mock_client.post.call_args + + # Check the data + expected_data = { + "user": { + "email": self.test_email, + "password": self.test_password + } + } + actual_data = json.loads(call_args[1]['data']) + self.assertEqual(actual_data, expected_data) + + # Check the response + self.assertEqual(result, mock_response) + + def test_login_with_none_tfa_token(self): + """Test login with None TFA token""" + # Mock the client post method + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'user': {'authtoken': 'test_token'}} + self.mock_client.post.return_value = mock_response + + # Call login with None TFA token + result = self.user_session.login(self.test_email, self.test_password, None) + + # Verify the request was made correctly + self.mock_client.post.assert_called_once() + call_args = self.mock_client.post.call_args + + # Check the data - None should not be included + expected_data = { + "user": { + "email": self.test_email, + "password": self.test_password + } + } + actual_data = json.loads(call_args[1]['data']) + self.assertEqual(actual_data, expected_data) + + # Check the response + self.assertEqual(result, mock_response) + + def test_login_parameter_validation(self): + """Test login parameter validation""" + # Test with empty email + with self.assertRaises(PermissionError) as context: + self.user_session.login("", self.test_password, self.test_tfa_token) + self.assertIn("Email Id is required", str(context.exception)) + + # Test with empty password + with self.assertRaises(PermissionError) as context: + self.user_session.login(self.test_email, "", self.test_tfa_token) + self.assertIn("Password is required", str(context.exception)) + + def test_login_method_signature(self): + """Test that login method has the correct signature""" + import inspect + + # Get the signature of the login method + sig = inspect.signature(self.user_session.login) + params = list(sig.parameters.keys()) + + # Verify the method has the expected parameters + expected_params = ['email', 'password', 'tfa_token'] + for param in expected_params: + self.assertIn(param, params) + + # Verify parameter defaults + self.assertEqual(sig.parameters['tfa_token'].default, None) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/users/test_users.py b/tests/unit/users/test_users.py index 61eb6a3..a789d87 100644 --- a/tests/unit/users/test_users.py +++ b/tests/unit/users/test_users.py @@ -14,7 +14,8 @@ class UserUnitTests(unittest.TestCase): def setUp(self): self.client = contentstack_management.Client(host=host) - self.client.login(username, password) + # Note: Login call removed to avoid network requests in unit tests + # The actual login is not needed for testing request structure def test_get_user(self): response = self.client.user().fetch()