Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions src/sentry/api/endpoints/prompts_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from django.db import IntegrityError, router, transaction
from django.db.models import Q
from django.http import HttpResponse
from django.utils import timezone
from rest_framework import serializers
from rest_framework.request import Request
Expand Down Expand Up @@ -42,7 +41,7 @@ class PromptsActivityEndpoint(OrganizationEndpoint):
"PUT": ApiPublishStatus.UNKNOWN,
}

def get(self, request: Request, **kwargs) -> Response:
def get(self, request: Request, organization: Organization, **kwargs) -> Response:
"""Return feature prompt status if dismissed or in snoozed period"""

if not request.user.is_authenticated:
Expand All @@ -58,14 +57,26 @@ def get(self, request: Request, **kwargs) -> Response:
return Response({"detail": "Invalid feature name " + feature}, status=400)

required_fields = prompt_config.required_fields(feature)
for field in required_fields:
if field not in request.GET:
return Response({"detail": 'Missing required field "%s"' % field}, status=400)
filters = {k: request.GET.get(k) for k in required_fields}
filters: dict[str, Any] = {}

# project_id must be provided and belong to the organization
if "project_id" in required_fields:
project_id = request.GET.get("project_id")
if not project_id:
return Response({"detail": 'Missing required field "project_id"'}, status=400)
if not Project.objects.filter(
id=project_id, organization_id=organization.id
).exists():
return Response({"detail": "Project not found"}, status=404)
filters["project_id"] = project_id

condition = Q(feature=feature, **filters)
conditions = condition if conditions is None else (conditions | condition)

result_qs = PromptsActivity.objects.filter(conditions, user_id=request.user.id)
# Always scope by organization from URL - passed directly to filter() to prevent override
result_qs = PromptsActivity.objects.filter(
conditions, user_id=request.user.id, organization_id=organization.id
)
featuredata = {k.feature: k.data for k in result_qs}
if len(features) == 1:
result = result_qs.first()
Expand All @@ -74,7 +85,7 @@ def get(self, request: Request, **kwargs) -> Response:
else:
return Response({"features": featuredata})

def put(self, request: Request, **kwargs):
def put(self, request: Request, organization: Organization, **kwargs) -> Response:
serializer = PromptsActivitySerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=400)
Expand All @@ -89,26 +100,26 @@ def put(self, request: Request, **kwargs):
if any(elem is None for elem in fields.values()):
return Response({"detail": "Missing required field"}, status=400)

# if project_id or organization_id in required fields make sure they exist
# if NOT in required fields, insert dummy value so dups aren't recorded
# Validate organization_id is present and matches URL organization
if "organization_id" not in required_fields or str(fields["organization_id"]) != str(
organization.id
):
return Response({"detail": "Organization missing or mismatched"}, status=400)
# Override with URL organization to prevent IDOR
fields["organization_id"] = organization.id

# Validate project_id if required, otherwise use dummy value to prevent duplicates
if "project_id" in required_fields:
if not Project.objects.filter(
id=fields["project_id"], organization_id=request.organization.id
).exists():
project_id = fields["project_id"]
if not project_id:
return Response({"detail": "Invalid project_id"}, status=400)
if not Project.objects.filter(id=project_id, organization_id=organization.id).exists():
return Response(
{"detail": "Project does not belong to this organization"}, status=400
)
else:
fields["project_id"] = 0

if "organization_id" in required_fields and str(fields["organization_id"]) == str(
request.organization.id
):
if not Organization.objects.filter(id=fields["organization_id"]).exists():
return Response({"detail": "Organization no longer exists"}, status=400)
else:
return Response({"detail": "Organization missing or mismatched"}, status=400)

data: dict[str, Any] = {}
now = calendar.timegm(timezone.now().utctimetuple())
if status == "snoozed":
Expand All @@ -126,4 +137,4 @@ def put(self, request: Request, **kwargs):
)
except IntegrityError:
pass
return HttpResponse(status=201)
return Response(status=201)
52 changes: 45 additions & 7 deletions tests/sentry/api/endpoints/test_prompts_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_batched_invalid_feature(self) -> None:
def test_invalid_project(self) -> None:
# Invalid project id
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand All @@ -98,7 +97,6 @@ def test_invalid_project(self) -> None:

def test_dismiss(self) -> None:
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand Down Expand Up @@ -135,7 +133,6 @@ def test_dismiss_str_id(self) -> None:
assert resp.status_code == 201, resp.content

data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand All @@ -147,7 +144,6 @@ def test_dismiss_str_id(self) -> None:

def test_snooze(self) -> None:
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand All @@ -173,7 +169,6 @@ def test_snooze(self) -> None:

def test_visible(self) -> None:
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand All @@ -199,7 +194,6 @@ def test_visible(self) -> None:

def test_visible_after_dismiss(self) -> None:
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": "releases",
}
Expand Down Expand Up @@ -235,7 +229,6 @@ def test_visible_after_dismiss(self) -> None:

def test_batched(self) -> None:
data = {
"organization_id": self.org.id,
"project_id": self.project.id,
"feature": ["releases", "alert_stream"],
}
Expand Down Expand Up @@ -290,3 +283,48 @@ def test_project_from_different_organization(self) -> None:

assert resp.status_code == 400
assert resp.data["detail"] == "Project does not belong to this organization"

def test_idor_get_project_from_different_org(self) -> None:
"""Regression test: GET cannot access projects from other organizations (IDOR)."""
other_org = self.create_organization()
other_project = self.create_project(organization=other_org)

resp = self.client.get(
self.path,
{
"project_id": str(other_project.id),
"feature": "releases",
},
)

# Should return 404 to prevent ID enumeration
assert resp.status_code == 404
assert resp.data["detail"] == "Project not found"

def test_get_empty_project_id(self) -> None:
"""Test that empty string project_id returns 400 instead of 500."""
resp = self.client.get(
self.path,
{
"project_id": "",
"feature": "releases",
},
)

assert resp.status_code == 400
assert resp.data["detail"] == 'Missing required field "project_id"'

def test_put_empty_project_id(self) -> None:
"""Test that empty string project_id in PUT returns 400 instead of 500."""
resp = self.client.put(
self.path,
{
"organization_id": self.org.id,
"project_id": "",
"feature": "releases",
"status": "dismissed",
},
)

assert resp.status_code == 400
assert resp.data["detail"] == "Invalid project_id"
Loading