Skip to content

Commit

Permalink
OAuth client did not check the 'state' value for /authorization using…
Browse files Browse the repository at this point in the history
… code-challenge - fixes #77
  • Loading branch information
Michael Kane Juncker committed Apr 26, 2020
1 parent 1060958 commit 168751b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
15 changes: 14 additions & 1 deletion fxa/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import base64
import hashlib
from binascii import hexlify
from six import string_types
from six.moves.urllib.parse import urlparse, urlunparse, urlencode, parse_qs

Expand Down Expand Up @@ -148,10 +149,11 @@ def authorize_code(self, sessionOrAssertion, scope=None, client_id=None,
client_id = self.client_id
assertion = self._get_identity_assertion(sessionOrAssertion, client_id)
url = "/authorization"
state = base64.urlsafe_b64encode(os.urandom(23)).decode('utf-8').rstrip("=")
body = {
"client_id": client_id,
"assertion": assertion,
"state": "x", # state is required, but we don't use it
"state": state
}
if scope is not None:
body["scope"] = scope
Expand All @@ -167,6 +169,17 @@ def authorize_code(self, sessionOrAssertion, scope=None, client_id=None,
# This flow is designed for web-based redirects.
# In order to get the code we must parse it from the redirect url.
query_params = parse_qs(urlparse(resp["redirect"]).query)

# Make sure the redirect URL is authentic
if "state" not in query_params:
error_msg = "state missing in OAuth response"
raise OutOfProtocolError(error_msg)

if state != query_params["state"][0]:
error_msg = "state mismatch in OAuth response (wanted: '{}', got: '{}')".format(
state, query_params["state"][0])
raise OutOfProtocolError(error_msg)

try:
return query_params["code"][0]
except (KeyError, IndexError, ValueError):
Expand Down
45 changes: 31 additions & 14 deletions fxa/tests/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,21 @@ class TestAuthClientAuthorizeCode(unittest.TestCase):
server_url = TEST_SERVER_URL

def setUp(self):
def authorization_callback(request):
data = json.loads(_decoded(request.body))
headers = {
'Content-Type': 'application/json'
}
body = {
'redirect': 'https://relier/page?code=qed&state={}'.format(data["state"])
}
return (200, headers, json.dumps(body))

self.client = Client("abc", "xyz", server_url=self.server_url)
body = '{"redirect": "https://relier/page?code=qed&state=blah"}'
responses.add(responses.POST,
'https://server/v1/authorization',
body=body,
content_type='application/json')
responses.add_callback(responses.POST,
'https://server/v1/authorization',
callback=authorization_callback,
content_type='application/json')

@responses.activate
def test_authorize_code_with_default_arguments(self):
Expand All @@ -288,7 +297,7 @@ def test_authorize_code_with_default_arguments(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
})

@responses.activate
Expand All @@ -300,7 +309,7 @@ def test_authorize_code_with_explicit_scope(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
"scope": "profile:email",
})

Expand All @@ -313,7 +322,7 @@ def test_authorize_code_with_explicit_client_id(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": "cba",
"state": "x",
"state": AnyStringValue(),
})

@responses.activate
Expand All @@ -330,7 +339,7 @@ def test_authorize_code_with_pkce_challenge(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
"code_challenge": challenge["code_challenge"],
"code_challenge_method": challenge["code_challenge_method"],
})
Expand All @@ -349,7 +358,7 @@ def test_authorize_code_with_session_object(self):
self.assertEquals(req_body, {
"assertion": "IDENTITY",
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
})


Expand All @@ -373,7 +382,7 @@ def test_authorize_token_with_default_arguments(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
"response_type": "token",
})

Expand All @@ -386,7 +395,7 @@ def test_authorize_token_with_explicit_scope(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
"response_type": "token",
"scope": "storage",
})
Expand All @@ -400,7 +409,7 @@ def test_authorize_token_with_explicit_client_id(self):
self.assertEquals(req_body, {
"assertion": assertion,
"client_id": "cba",
"state": "x",
"state": AnyStringValue(),
"response_type": "token",
})

Expand All @@ -418,7 +427,7 @@ def test_authorize_token_with_session_object(self):
self.assertEquals(req_body, {
"assertion": "IDENTITY",
"client_id": self.client.client_id,
"state": "x",
"state": AnyStringValue(),
"response_type": "token",
})

Expand Down Expand Up @@ -591,3 +600,11 @@ def test_monkey_patch_for_gevent(self):
self.assertEqual(fxa._utils.requests, grequests)

fxa._utils.requests = old_requests

class AnyStringValue:

def __eq__(self, other):
return isinstance(other, str)

def __repr__(self):
return 'any string'

0 comments on commit 168751b

Please sign in to comment.