diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 6e58f630d..36b8f0cb7 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -31,6 +31,7 @@ .. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1 """ +from datetime import datetime import io import json @@ -66,6 +67,7 @@ def __init__( client_secret=None, scopes=None, quota_project_id=None, + expiry=None, ): """ Args: @@ -95,6 +97,7 @@ def __init__( """ super(Credentials, self).__init__() self.token = token + self.expiry = expiry self._refresh_token = refresh_token self._id_token = id_token self._scopes = scopes @@ -128,6 +131,11 @@ def refresh_token(self): """Optional[str]: The OAuth 2.0 refresh token.""" return self._refresh_token + @property + def scopes(self): + """Optional[str]: The OAuth 2.0 permission scopes.""" + return self._scopes + @property def token_uri(self): """Optional[str]: The OAuth 2.0 authorization server's token endpoint @@ -241,16 +249,30 @@ def from_authorized_user_info(cls, info, scopes=None): "fields {}.".format(", ".join(missing)) ) + # access token expiry (datetime obj); auto-expire if not saved + expiry = info.get("expiry") + if expiry: + expiry = datetime.strptime( + expiry.rstrip("Z").split(".")[0], "%Y-%m-%dT%H:%M:%S" + ) + else: + expiry = _helpers.utcnow() - _helpers.CLOCK_SKEW + + # process scopes, which needs to be a seq + if scopes is None and "scopes" in info: + scopes = info.get("scopes") + if isinstance(scopes, str): + scopes = scopes.split(" ") + return cls( - None, # No access token, must be refreshed. - refresh_token=info["refresh_token"], - token_uri=_GOOGLE_OAUTH2_TOKEN_ENDPOINT, + token=info.get("token"), + refresh_token=info.get("refresh_token"), + token_uri=_GOOGLE_OAUTH2_TOKEN_ENDPOINT, # always overrides scopes=scopes, - client_id=info["client_id"], - client_secret=info["client_secret"], - quota_project_id=info.get( - "quota_project_id" - ), # quota project may not exist + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + quota_project_id=info.get("quota_project_id"), # may not exist + expiry=expiry, ) @classmethod @@ -294,8 +316,10 @@ def to_json(self, strip=None): "client_secret": self.client_secret, "scopes": self.scopes, } + if self.expiry: # flatten expiry timestamp + prep["expiry"] = self.expiry.isoformat() + "Z" - # Remove empty entries + # Remove empty entries (those which are None) prep = {k: v for k, v in prep.items() if v is not None} # Remove entries that explicitely need to be removed @@ -316,7 +340,6 @@ class UserAccessTokenCredentials(credentials.CredentialsWithQuotaProject): specified, the current active account will be used. quota_project_id (Optional[str]): The project ID used for quota and billing. - """ def __init__(self, account=None, quota_project_id=None): diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index ceb8cdfd5..ee8b8a211 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -359,6 +359,20 @@ def test_from_authorized_user_info(self): assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT assert creds.scopes == scopes + info["scopes"] = "email" # single non-array scope from file + creds = credentials.Credentials.from_authorized_user_info(info) + assert creds.scopes == [info["scopes"]] + + info["scopes"] = ["email", "profile"] # array scope from file + creds = credentials.Credentials.from_authorized_user_info(info) + assert creds.scopes == info["scopes"] + + expiry = datetime.datetime(2020, 8, 14, 15, 54, 1) + info["expiry"] = expiry.isoformat() + "Z" + creds = credentials.Credentials.from_authorized_user_info(info) + assert creds.expiry == expiry + assert creds.expired + def test_from_authorized_user_file(self): info = AUTH_USER_INFO.copy() @@ -381,7 +395,10 @@ def test_from_authorized_user_file(self): def test_to_json(self): info = AUTH_USER_INFO.copy() + expiry = datetime.datetime(2020, 8, 14, 15, 54, 1) + info["expiry"] = expiry.isoformat() + "Z" creds = credentials.Credentials.from_authorized_user_info(info) + assert creds.expiry == expiry # Test with no `strip` arg json_output = creds.to_json() @@ -392,6 +409,7 @@ def test_to_json(self): assert json_asdict.get("client_id") == creds.client_id assert json_asdict.get("scopes") == creds.scopes assert json_asdict.get("client_secret") == creds.client_secret + assert json_asdict.get("expiry") == info["expiry"] # Test with a `strip` arg json_output = creds.to_json(strip=["client_secret"]) @@ -403,6 +421,12 @@ def test_to_json(self): assert json_asdict.get("scopes") == creds.scopes assert json_asdict.get("client_secret") is None + # Test with no expiry + creds.expiry = None + json_output = creds.to_json() + json_asdict = json.loads(json_output) + assert json_asdict.get("expiry") is None + def test_pickle_and_unpickle(self): creds = self.make_credentials() unpickled = pickle.loads(pickle.dumps(creds))