Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tito97sp committed Mar 17, 2024
1 parent c05dd83 commit 46a5fd0
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1160,13 +1160,13 @@ def update_params_for_auth(
security_scheme_instance = self.configuration.security_scheme_info[security_scheme_component_name]
oauth_server_client_info = self.configuration.oauth_server_client_info
security_scheme_instance.apply_auth(
headers,
resource_path,
method,
body,
query_params_suffix,
scope_names,
oauth_server_client_info
headers=headers,
resource_path=resource_path,
method=method,
body=body,
query_params_suffix=query_params_suffix,
scope_names=scope_names,
client_info=oauth_server_client_info
)
except KeyError as ex:
raise ex
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
pass

Expand All @@ -96,7 +96,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
if self.in_location is ApiKeyInLocation.COOKIE:
headers.add('Cookie', self.api_key)
Expand Down Expand Up @@ -134,7 +134,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
user_pass = f"{self.user_id}:{self.password}"
b64_user_pass = base64.b64encode(user_pass.encode(encoding=self.encoding))
Expand All @@ -156,7 +156,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
headers.add('Authorization', f"Bearer {self.access_token}")

Expand All @@ -174,7 +174,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
raise NotImplementedError("HTTPDigestSecurityScheme not yet implemented")

Expand All @@ -191,13 +191,13 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
raise NotImplementedError("MutualTLSSecurityScheme not yet implemented")


class OAuthFlowBase:
def apply_auth(
def set_token(
self,
headers: _collections.HTTPHeaderDict,
token: OauthToken
Expand All @@ -209,7 +209,7 @@ def apply_auth(


@dataclasses.dataclass
class ImplicitOAuthFlow(OAuthFlowBase):
class ImplicitOAuthFlow(__SecuritySchemeBase, OAuthFlowBase):
authorization_url: parse.ParseResult
scopes: typing.Dict[str, str]
refresh_url: typing.Optional[str] = None
Expand All @@ -224,8 +224,9 @@ def apply_auth(
resource_path: str,
method: str,
body: typing.Optional[typing.Union[str, bytes]],
scope_names: typing.Tuple[str] = (),
client_info: OauthClientInfo = {}
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
client_info: OauthServerClientInfo = {}
) -> None:
"""
Not implemented because this flow requires the user to visit a webpage and grant access
Expand All @@ -235,7 +236,7 @@ def apply_auth(


@dataclasses.dataclass
class PasswordOauthFlow(OAuthFlowBase):
class PasswordOauthFlow(__SecuritySchemeBase, OAuthFlowBase):
token_url: parse.ParseResult
scopes: typing.Dict[str, str]
refresh_url: typing.Optional[str] = None
Expand All @@ -254,12 +255,13 @@ def apply_auth(
resource_path: str,
method: str,
body: typing.Optional[typing.Union[str, bytes]],
client_info: OauthClientInfo,
scope_names: typing.Tuple[str] = (),
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
client_info: OauthServerClientInfo = {}
) -> None:
token = self._scope_names_to_token.get(scope_names)
if token:
super().apply_auth(headers, token)
super().set_token(headers, token)
return
client = self._scope_names_to_client.get(scope_names)
if client is None:
Expand All @@ -269,18 +271,18 @@ def apply_auth(
scope=scope_names
)
self._scope_names_to_token[scope_names] = client
token: OauthToken = client.fetch_token(
token = client.fetch_token(
url=parse.urlunparse(self.token_url),
username=self.username,
password=self.password
)
self._scope_names_to_token[scope_names] = token
print(token)
super().apply_auth(headers, token)
super().set_token(headers, token)


@dataclasses.dataclass
class ClientCredentialsOauthFlow(OAuthFlowBase):
class ClientCredentialsOauthFlow(__SecuritySchemeBase, OAuthFlowBase):
token_url: parse.ParseResult
scopes: typing.Dict[str, str]
refresh_url: typing.Optional[str] = None
Expand All @@ -297,12 +299,13 @@ def apply_auth(
resource_path: str,
method: str,
body: typing.Optional[typing.Union[str, bytes]],
scope_names: typing.Tuple[str] = (),
client_info: OauthClientInfo = {}
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
client_info: OauthServerClientInfo = {}
) -> None:
token = self._scope_names_to_token.get(scope_names)
if token:
super().apply_auth(headers, token)
super().set_token(headers, token)
return
client = self._scope_names_to_client.get(scope_names)
if client is None:
Expand All @@ -312,17 +315,17 @@ def apply_auth(
scope=scope_names
)
self._scope_names_to_token[scope_names] = client
token: OauthToken = client.fetch_token(
token = client.fetch_token(
url=parse.urlunparse(self.token_url),
grant_type='client_credentials'
)
self._scope_names_to_token[scope_names] = token
print(token)
super().apply_auth(headers, token)
super().set_token(headers, token)


@dataclasses.dataclass
class AuthorizationCodeOauthFlow(OAuthFlowBase):
class AuthorizationCodeOauthFlow(__SecuritySchemeBase, OAuthFlowBase):
authorization_url: parse.ParseResult
token_url: parse.ParseResult
scopes: typing.Dict[str, str]
Expand All @@ -338,8 +341,9 @@ def apply_auth(
resource_path: str,
method: str,
body: typing.Optional[typing.Union[str, bytes]],
scope_names: typing.Tuple[str] = (),
client_info: OauthClientInfo = {}
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
client_info: OauthServerClientInfo = {}
) -> None:
"""
Not implemented because this flow requires the user to visit a webpage and grant access
Expand Down Expand Up @@ -368,14 +372,14 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
if not self.flows:
raise exceptions.ApiValueError('flows are not defined and are required, define them')
if not scope_names:
raise exceptions.ApiValueError('scope_names are not defined and are required, define them')
if not oauth_server_client_info:
raise exceptions.ApiValueError('oauth_server_client_info is not defined and is required, define it')
if not client_info:
raise exceptions.ApiValueError('client_info is not defined and is required, define it')
chosen_flows = []
for flow in [self.flows.implicit, self.flows.password, self.flows.client_credentials, self.flows.authorization_code]:
if flow is None:
Expand All @@ -396,20 +400,21 @@ def apply_auth(
"flow may contain the scopes"
)
chosen_flow = chosen_flows[0]
if chosen_flow.auth_or_token_url.netloc not in oauth_server_client_info:
if chosen_flow.auth_or_token_url.netloc not in client_info:
raise exceptions.ApiValueError(
f"oauth_server_client_info is missing info for oauth server "
f"client_info is missing info for oauth server "
"hostname={chosen_flow.auth_or_token_url.netloc}. Add it to you api_configuration"
)
client_info = oauth_server_client_info[chosen_flow.auth_or_token_url.netloc]
client_info = client_info[chosen_flow.auth_or_token_url.netloc]
# note: scope input must be sorted tuple
chosen_flow.apply_auth(
headers=headers,
resource_path=resource_path,
method=method,
body=body,
query_params_suffix=query_params_suffix,
scope_names=scope_names,
client_info=client_info
client_info=client_info,
)


Expand All @@ -425,7 +430,7 @@ def apply_auth(
body: typing.Optional[typing.Union[str, bytes]],
query_params_suffix: typing.Optional[str],
scope_names: typing.Tuple[str, ...] = (),
oauth_server_client_info: OauthServerClientInfo = {}
client_info: OauthServerClientInfo = {}
) -> None:
raise NotImplementedError("OpenIdConnectSecurityScheme not yet implemented")

Expand Down
14 changes: 7 additions & 7 deletions samples/client/petstore/python/src/petstore_api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,13 +1162,13 @@ def update_params_for_auth(
security_scheme_instance = self.configuration.security_scheme_info[security_scheme_component_name]
oauth_server_client_info = self.configuration.oauth_server_client_info
security_scheme_instance.apply_auth(
headers,
resource_path,
method,
body,
query_params_suffix,
scope_names,
oauth_server_client_info
headers=headers,
resource_path=resource_path,
method=method,
body=body,
query_params_suffix=query_params_suffix,
scope_names=scope_names,
client_info=oauth_server_client_info
)
except KeyError as ex:
raise ex
Expand Down
Loading

0 comments on commit 46a5fd0

Please sign in to comment.