Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions pymongo/auth_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE

if TYPE_CHECKING:
from pymongo.auth import MongoCredential
Expand All @@ -37,7 +38,7 @@
@dataclass
class OIDCIdPInfo:
issuer: str
clientId: str
clientId: Optional[str] = field(default=None)
requestScopes: Optional[list[str]] = field(default=None)


Expand Down Expand Up @@ -189,30 +190,43 @@ def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:

def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
# If there is a cached access token, try to authenticate with it. If
# authentication fails, it's possible the cached access token is expired. In
# that case, invalidate the access token, fetch a new access token, and try
# to authenticate again.
# authentication fails with error code 18, invalidate the access token,
# fetch a new access token, and try to authenticate again. If authentication
# fails for any other reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_machine(conn)
raise
return self._sasl_start_jwt(conn)

def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
# If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token,
# and try to authenticate again. If authentication fails for any other
# reason, raise the error to the user.
if self.access_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
return self._authenticate_human(conn)
raise

# If we have a cached refresh token, try a JwtStepRequest with that.
# If authentication fails with error code 18, invalidate the access and
# refresh tokens, and try to authenticate again. If authentication fails for
# any other reason, raise the error to the user.
if self.refresh_token:
try:
return self._sasl_start_jwt(conn)
except Exception: # noqa: S110
pass
except OperationFailure as e:
if self._is_auth_error(e):
self.refresh_token = None
return self._authenticate_human(conn)
raise

# Start a new Two-Step SASL conversation.
# Run a PrincipalStepRequest to get the IdpInfo.
Expand Down Expand Up @@ -280,10 +294,16 @@ def _get_access_token(self) -> Optional[str]:
def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
try:
return conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg]
except OperationFailure:
self._invalidate(conn)
except OperationFailure as e:
if self._is_auth_error(e):
self._invalidate(conn)
raise

def _is_auth_error(self, err: Exception) -> bool:
if not isinstance(err, OperationFailure):
return False
return err.code == _AUTHENTICATION_FAILURE_CODE

def _invalidate(self, conn: Connection) -> None:
# Ignore the invalidation if a token gen id is given and is less than our
# current token gen id.
Expand Down
1 change: 0 additions & 1 deletion pymongo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,6 @@ def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]
"AWS_SESSION_TOKEN",
"ENVIRONMENT",
"TOKEN_RESOURCE",
"ALLOWED_HOSTS",
]
)

Expand Down
3 changes: 3 additions & 0 deletions pymongo/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@
# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391

# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use this constant instead of the magic number 18 in auth_oidc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done



def _gen_index_name(keys: _IndexList) -> str:
"""Generate an index name from the set of fields it is over."""
Expand Down
8 changes: 7 additions & 1 deletion test/auth/legacy/connection-string.json
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,12 @@
"valid": false,
"credential": null
},
{
"description": "should throw an exception custom callback is chosen but no callback is provided (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:custom",
"valid": false,
"credential": null
},
{
"description": "should throw an exception if neither provider nor callbacks specified (MONGODB-OIDC)",
"uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC",
Expand Down Expand Up @@ -573,4 +579,4 @@
"credential": null
}
]
}
}
Loading