Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/sentry/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import urlparse

from django.utils import timezone
from rest_framework.request import Request

from sentry import options
from sentry.auth.access import get_cached_organization_member
Expand Down Expand Up @@ -182,3 +183,10 @@ def generate_region_url() -> str:
if not region_url_template or not region:
return options.get("system.url-prefix")
return region_url_template.replace("{region}", region)


def generate_url_prefix(request: Request) -> str:
url_prefix = options.get("system.url-prefix")
if request.subdomain:
url_prefix = generate_organization_url(request.subdomain)
return url_prefix
4 changes: 2 additions & 2 deletions src/sentry/auth/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,8 +679,8 @@ def get_initial_state(self):
state.update({"flow": self.flow})
return state

def get_redirect_url(self):
return absolute_uri(reverse("sentry-auth-sso"))
def get_redirect_url(self, url_prefix=None):
return absolute_uri(reverse("sentry-auth-sso"), url_prefix=url_prefix)

def dispatch_to(self, step: View):
return step.dispatch(request=self.request, helper=self)
Expand Down
9 changes: 7 additions & 2 deletions src/sentry/auth/providers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from rest_framework.request import Request
from rest_framework.response import Response

from sentry.api.utils import generate_url_prefix
from sentry.auth.exceptions import IdentityNotValid
from sentry.auth.provider import Provider
from sentry.auth.view import AuthView
Expand Down Expand Up @@ -52,7 +53,10 @@ def dispatch(self, request: Request, helper) -> Response:

state = uuid4().hex

params = self.get_authorize_params(state=state, redirect_uri=helper.get_redirect_url())
url_prefix = generate_url_prefix(request)
params = self.get_authorize_params(
state=state, redirect_uri=helper.get_redirect_url(url_prefix)
)
redirect_uri = f"{self.get_authorize_url()}?{urlencode(params)}"

helper.bind_state("state", state)
Expand Down Expand Up @@ -84,8 +88,9 @@ def get_token_params(self, code, redirect_uri):
}

def exchange_token(self, request: Request, helper, code):
url_prefix = generate_url_prefix(request)
# TODO: this needs the auth yet
data = self.get_token_params(code=code, redirect_uri=helper.get_redirect_url())
data = self.get_token_params(code=code, redirect_uri=helper.get_redirect_url(url_prefix))
req = safe_urlopen(self.access_token_url, data=data)
body = safe_urlread(req)
if req.headers["Content-Type"].startswith("application/x-www-form-urlencoded"):
Expand Down
10 changes: 8 additions & 2 deletions src/sentry/identity/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.views.decorators.csrf import csrf_exempt
from requests.exceptions import SSLError

from sentry.api.utils import generate_url_prefix
from sentry.auth.exceptions import IdentityNotValid
from sentry.http import safe_urlopen, safe_urlread
from sentry.pipeline import PipelineView
Expand Down Expand Up @@ -241,8 +242,10 @@ def dispatch(self, request: Request, pipeline) -> Response:

state = uuid4().hex

url_prefix = generate_url_prefix(request)
params = self.get_authorize_params(
state=state, redirect_uri=absolute_uri(pipeline.redirect_url())
state=state,
redirect_uri=absolute_uri(pipeline.redirect_url(), url_prefix=url_prefix),
)
redirect_uri = f"{self.get_authorize_url()}?{urlencode(params)}"

Expand Down Expand Up @@ -275,8 +278,11 @@ def get_token_params(self, code, redirect_uri):
}

def exchange_token(self, request: Request, pipeline, code):
url_prefix = generate_url_prefix(request)
# TODO: this needs the auth yet
data = self.get_token_params(code=code, redirect_uri=absolute_uri(pipeline.redirect_url()))
data = self.get_token_params(
code=code, redirect_uri=absolute_uri(pipeline.redirect_url(), url_prefix=url_prefix)
)
verify_ssl = pipeline.config.get("verify_ssl", True)
try:
req = safe_urlopen(self.access_token_url, data=data, verify_ssl=verify_ssl)
Expand Down
9 changes: 5 additions & 4 deletions src/sentry/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
ParsedUriMatch = namedtuple("ParsedUriMatch", ["scheme", "domain", "path"])


def absolute_uri(url: Optional[str] = None) -> str:
prefix = options.get("system.url-prefix")
def absolute_uri(url: Optional[str] = None, url_prefix=None) -> str:
if url_prefix is None:
url_prefix = options.get("system.url-prefix")
if not url:
return prefix
return urljoin(prefix.rstrip("/") + "/", url.lstrip("/"))
return url_prefix
return urljoin(url_prefix.rstrip("/") + "/", url.lstrip("/"))


def origin_from_url(url):
Expand Down
122 changes: 109 additions & 13 deletions tests/sentry/identity/test_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from unittest.mock import Mock
from collections import namedtuple
from unittest import mock
from urllib.parse import parse_qs, urlparse

import responses
from django.test import Client, RequestFactory
from exam import fixture
from requests.exceptions import SSLError

import sentry.identity
from sentry.identity.oauth2 import OAuth2CallbackView
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
from sentry.identity.pipeline import IdentityProviderPipeline
from sentry.identity.providers.dummy import DummyProvider
from sentry.testutils import TestCase
from sentry.testutils.silo import control_silo_test
from sentry.utils import json

MockResponse = namedtuple("MockResponse", ["headers", "content"])


@control_silo_test
class OAuth2CallbackViewTest(TestCase):
def setUp(self):
sentry.identity.register(DummyProvider)
super().setUp()
self.request = RequestFactory().get("/")
self.request.subdomain = None

def tearDown(self):
super().tearDown()
Expand All @@ -30,17 +38,49 @@ def view(self):
client_secret="secret-value",
)

@responses.activate
def test_exchange_token_success(self):
responses.add(
responses.POST, "https://example.org/oauth/token", json={"token": "a-fake-token"}
)
pipeline = IdentityProviderPipeline(request=Mock(), provider_key="dummy")
@mock.patch("sentry.identity.oauth2.safe_urlopen")
def test_exchange_token_success(self, safe_urlopen):
headers = {"Content-Type": "application/json"}
safe_urlopen.return_value = MockResponse(headers, json.dumps({"token": "a-fake-token"}))

pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
code = "auth-code"
result = self.view.exchange_token(None, pipeline, code)
result = self.view.exchange_token(self.request, pipeline, code)
assert "token" in result
assert "a-fake-token" == result["token"]

assert safe_urlopen.called
data = safe_urlopen.call_args[1]["data"]
assert data == {
"client_id": 123456,
"client_secret": "secret-value",
"code": "auth-code",
"grant_type": "authorization_code",
"redirect_uri": "http://testserver/extensions/default/setup/",
}

@mock.patch("sentry.identity.oauth2.safe_urlopen")
def test_exchange_token_success_customer_domains(self, safe_urlopen):
headers = {"Content-Type": "application/json"}
safe_urlopen.return_value = MockResponse(headers, json.dumps({"token": "a-fake-token"}))

self.request.subdomain = "albertos-apples"
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
code = "auth-code"
result = self.view.exchange_token(self.request, pipeline, code)
assert "token" in result
assert "a-fake-token" == result["token"]

assert safe_urlopen.called
data = safe_urlopen.call_args[1]["data"]
assert data == {
"client_id": 123456,
"client_secret": "secret-value",
"code": "auth-code",
"grant_type": "authorization_code",
"redirect_uri": "http://albertos-apples.testserver/extensions/default/setup/",
}

@responses.activate
def test_exchange_token_ssl_error(self):
def ssl_error(request):
Expand All @@ -49,9 +89,9 @@ def ssl_error(request):
responses.add_callback(
responses.POST, "https://example.org/oauth/token", callback=ssl_error
)
pipeline = IdentityProviderPipeline(request=Mock(), provider_key="dummy")
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
code = "auth-code"
result = self.view.exchange_token(None, pipeline, code)
result = self.view.exchange_token(self.request, pipeline, code)
assert "token" not in result
assert "error" in result
assert "error_description" in result
Expand All @@ -60,10 +100,66 @@ def ssl_error(request):
@responses.activate
def test_exchange_token_no_json(self):
responses.add(responses.POST, "https://example.org/oauth/token", body="")
pipeline = IdentityProviderPipeline(request=Mock(), provider_key="dummy")
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
code = "auth-code"
result = self.view.exchange_token(None, pipeline, code)
result = self.view.exchange_token(self.request, pipeline, code)
assert "token" not in result
assert "error" in result
assert "error_description" in result
assert "JSON" in result["error_description"]


@control_silo_test
class OAuth2LoginViewTest(TestCase):
def setUp(self):
sentry.identity.register(DummyProvider)
super().setUp()
self.request = RequestFactory().get("/")
self.request.session = Client().session
self.request.subdomain = None

def tearDown(self):
super().tearDown()
sentry.identity.unregister(DummyProvider)

@fixture
def view(self):
return OAuth2LoginView(
authorize_url="https://example.org/oauth2/authorize",
client_id=123456,
scope="all-the-things",
)

def test_simple(self):
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
response = self.view.dispatch(self.request, pipeline)

assert response.status_code == 302
assert response["Location"].startswith("https://example.org/oauth2/authorize")
redirect_url = urlparse(response["Location"])
query = parse_qs(redirect_url.query)

assert query["client_id"][0] == "123456"
assert query["redirect_uri"][0] == "http://testserver/extensions/default/setup/"
assert query["response_type"][0] == "code"
assert query["scope"][0] == "all-the-things"
assert "state" in query

def test_customer_domains(self):
self.request.subdomain = "albertos-apples"
pipeline = IdentityProviderPipeline(request=self.request, provider_key="dummy")
response = self.view.dispatch(self.request, pipeline)

assert response.status_code == 302
assert response["Location"].startswith("https://example.org/oauth2/authorize")
redirect_url = urlparse(response["Location"])
query = parse_qs(redirect_url.query)

assert query["client_id"][0] == "123456"
assert (
query["redirect_uri"][0]
== "http://albertos-apples.testserver/extensions/default/setup/"
)
assert query["response_type"][0] == "code"
assert query["scope"][0] == "all-the-things"
assert "state" in query
36 changes: 31 additions & 5 deletions tests/sentry/web/frontend/test_auth_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,21 @@ def login_path(self):
def sso_path(self):
return reverse("sentry-auth-sso")

def initiate_oauth_flow(self):
resp = self.client.post(self.login_path, {"init": True})
def initiate_oauth_flow(self, http_host=None):
kwargs = {}
if http_host is not None:
kwargs["HTTP_HOST"] = http_host
else:
http_host = "testserver"

resp = self.client.post(self.login_path, {"init": True}, **kwargs)

assert resp.status_code == 302
redirect = urlparse(resp.get("Location", ""))
query = parse_qs(redirect.query)

assert redirect.path == "/authorize_url"
assert query["redirect_uri"][0] == "http://testserver/auth/sso/"
assert query["redirect_uri"][0] == f"http://{http_host}/auth/sso/"
assert query["client_id"][0] == "my_client_id"
assert "state" in query

Expand All @@ -87,8 +93,19 @@ def initiate_callback(self, state, auth_data, urlopen, expect_success=True, **kw
if expect_success:
assert resp.status_code == 200
assert urlopen.called
assert urlopen.call_args[1]["data"]["code"] == "1234"
assert urlopen.call_args[1]["data"]["client_secret"] == "my_client_secret"
data = urlopen.call_args[1]["data"]

http_host = "testserver"
if "HTTP_HOST" in kwargs:
http_host = kwargs["HTTP_HOST"]

assert data == {
"grant_type": "authorization_code",
"code": "1234",
"redirect_uri": f"http://{http_host}/auth/sso/",
"client_id": "my_client_id",
"client_secret": "my_client_secret",
}

return resp

Expand All @@ -100,6 +117,15 @@ def test_oauth2_flow(self):

assert auth_resp.context["existing_user"] == self.user

def test_oauth2_flow_customer_domain(self):
HTTP_HOST = "albertos-apples.testserver"
auth_data = {"id": "oauth_external_id_1234", "email": self.user.email}

state = self.initiate_oauth_flow(http_host=HTTP_HOST)
auth_resp = self.initiate_callback(state, auth_data, HTTP_HOST=HTTP_HOST)

assert auth_resp.context["existing_user"] == self.user

def test_state_mismatch(self):
auth_data = {"id": "oauth_external_id_1234", "email": self.user.email}

Expand Down