Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Allow clients to supply access_tokens as headers #1098

Merged
merged 4 commits into from
Oct 25, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,8 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
return bool(query_params) or bool(auth_headers)


def get_access_token_from_request(request, token_not_found_http_status=401):
Expand All @@ -1176,13 +1177,40 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
Raises:
AuthError: If there isn't an access_token in the request.
"""

auth_headers = request.requestHeaders.getRawHeaders("Authorization")
query_params = request.args.get("access_token")
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)

return query_params[0]
return query_params[0]
18 changes: 9 additions & 9 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.types import UserID
from tests.utils import setup_test_homeserver
from tests.utils import setup_test_homeserver, mock_getRawHeaders

import pymacaroons

Expand Down Expand Up @@ -51,7 +51,7 @@ def test_get_user_by_req_user_valid_token(self):

request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)

Expand All @@ -61,7 +61,7 @@ def test_get_user_by_req_user_bad_token(self):

request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

Expand All @@ -74,7 +74,7 @@ def test_get_user_by_req_user_missing_token(self):
self.store.get_user_by_access_token = Mock(return_value=user_info)

request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

Expand All @@ -86,7 +86,7 @@ def test_get_user_by_req_appservice_valid_token(self):

request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)

Expand All @@ -96,7 +96,7 @@ def test_get_user_by_req_appservice_bad_token(self):

request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

Expand All @@ -106,7 +106,7 @@ def test_get_user_by_req_appservice_missing_token(self):
self.store.get_user_by_access_token = Mock(return_value=None)

request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

Expand All @@ -121,7 +121,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id)

Expand All @@ -135,7 +135,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

Expand Down
3 changes: 2 additions & 1 deletion tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def test_started_typing_remote_recv(self):
"user_id": self.u_onion.to_string(),
"typing": True,
}
)
),
federation_auth=True,
)

self.on_new_event.assert_has_calls([
Expand Down
2 changes: 2 additions & 0 deletions tests/rest/client/v1/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from twisted.internet import defer
from mock import Mock
from tests import unittest
from tests.utils import mock_getRawHeaders
import json


Expand All @@ -30,6 +31,7 @@ def setUp(self):
path='/_matrix/client/api/v1/createUser'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()

self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
Expand Down
2 changes: 2 additions & 0 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from twisted.internet import defer
from mock import Mock
from tests import unittest
from tests.utils import mock_getRawHeaders
import json


Expand All @@ -16,6 +17,7 @@ def setUp(self):
path='/_matrix/api/v2_alpha/register'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()

self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
Expand Down
18 changes: 14 additions & 4 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)


def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}

def getRawHeaders(name, default=None):
return headers.get(name, default)

return getRawHeaders


# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):

Expand All @@ -127,7 +136,7 @@ def trigger_get(self, path):

@patch('twisted.web.http.Request')
@defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request):
def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event.

Args:
Expand Down Expand Up @@ -155,9 +164,10 @@ def trigger(self, http_method, path, content, mock_request):

mock_request.getClientIP.return_value = "-"

mock_request.requestHeaders.getRawHeaders.return_value = [
"X-Matrix origin=test,key=,sig="
]
headers = {}
if federation_auth:
headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)

# return the right path if the event requires it
mock_request.path = path
Expand Down