diff --git a/src/sentry/preprod/api/endpoints/organization_preprod_artifact_assemble.py b/src/sentry/preprod/api/endpoints/organization_preprod_artifact_assemble.py index e8804e9da0178b..36d8bc2de3c268 100644 --- a/src/sentry/preprod/api/endpoints/organization_preprod_artifact_assemble.py +++ b/src/sentry/preprod/api/endpoints/organization_preprod_artifact_assemble.py @@ -35,6 +35,31 @@ ] +def validate_vcs_parameters(data: dict[str, Any]) -> str | None: + head_sha = data.get("head_sha") + base_sha = data.get("base_sha") + + if head_sha and base_sha and head_sha == base_sha: + return f"Head SHA and base SHA cannot be the same ({head_sha}). Please provide a different base SHA." + + if not head_sha and base_sha: + return "Head SHA is required when base SHA is provided. Please provide a head SHA." + + # If any VCS parameters are provided, all required ones must be present + vcs_params = { + "head_sha": head_sha, + "head_repo_name": data.get("head_repo_name"), + "provider": data.get("provider"), + "head_ref": data.get("head_ref"), + } + + if any(vcs_params.values()) and any(not v for v in vcs_params.values()): + missing_params = [k for k, v in vcs_params.items() if not v] + return f"All required VCS parameters must be provided when using VCS features. Missing parameters: {', '.join(missing_params)}" + + return None + + def validate_preprod_artifact_schema(request_body: bytes) -> tuple[dict[str, Any], str | None]: """ Validate the JSON schema for preprod artifact assembly requests. @@ -144,22 +169,10 @@ def post(self, request: Request, project: Project) -> Response: checksum = str(data.get("checksum", "")) chunks = data.get("chunks", []) - # Validate VCS parameters - if any are provided, all required ones must be present - vcs_params = { - "head_sha": data.get("head_sha"), - "head_repo_name": data.get("head_repo_name"), - "provider": data.get("provider"), - "head_ref": data.get("head_ref"), - } - - if any(vcs_params.values()) and any(not v for v in vcs_params.values()): - missing_params = [k for k, v in vcs_params.items() if not v] - return Response( - { - "error": f"All required VCS parameters must be provided when using VCS features. Missing parameters: {', '.join(missing_params)}" - }, - status=400, - ) + # Validate VCS parameters + vcs_error = validate_vcs_parameters(data) + if vcs_error: + return Response({"error": vcs_error}, status=400) # Check if all requested chunks have been uploaded missing_chunks = find_missing_chunks(project.organization_id, set(chunks)) diff --git a/tests/sentry/preprod/api/endpoints/test_organization_preprod_artifact_assemble.py b/tests/sentry/preprod/api/endpoints/test_organization_preprod_artifact_assemble.py index 5242fea0a7667f..2b9a9f44af7663 100644 --- a/tests/sentry/preprod/api/endpoints/test_organization_preprod_artifact_assemble.py +++ b/tests/sentry/preprod/api/endpoints/test_organization_preprod_artifact_assemble.py @@ -12,6 +12,7 @@ from sentry.models.orgauthtoken import OrgAuthToken from sentry.preprod.api.endpoints.organization_preprod_artifact_assemble import ( validate_preprod_artifact_schema, + validate_vcs_parameters, ) from sentry.preprod.tasks import create_preprod_artifact from sentry.silo.base import SiloMode @@ -171,6 +172,116 @@ def test_additional_properties_rejected(self) -> None: assert result == {} +class ValidateVcsParametersTest(TestCase): + """Unit tests for VCS parameter validation function - no database required.""" + + def test_valid_minimal_no_vcs_params(self) -> None: + """Test that validation passes when no VCS params are provided.""" + data = {"checksum": "a" * 40, "chunks": []} + error = validate_vcs_parameters(data) + assert error is None + + def test_valid_complete_vcs_params(self) -> None: + """Test that validation passes when all required VCS params are provided.""" + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": "e" * 40, + "head_repo_name": "owner/repo", + "provider": "github", + "head_ref": "feature/xyz", + } + error = validate_vcs_parameters(data) + assert error is None + + def test_valid_complete_vcs_params_with_base_sha(self) -> None: + """Test that validation passes when all VCS params including base_sha are provided.""" + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": "e" * 40, + "base_sha": "f" * 40, + "head_repo_name": "owner/repo", + "provider": "github", + "head_ref": "feature/xyz", + } + error = validate_vcs_parameters(data) + assert error is None + + def test_same_head_and_base_sha(self) -> None: + """Test that validation fails when head_sha and base_sha are the same.""" + same_sha = "e" * 40 + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": same_sha, + "base_sha": same_sha, + } + error = validate_vcs_parameters(data) + assert error is not None + assert "Head SHA and base SHA cannot be the same" in error + assert same_sha in error + + def test_base_sha_without_head_sha(self) -> None: + """Test that validation fails when base_sha is provided without head_sha.""" + data = {"checksum": "a" * 40, "chunks": [], "base_sha": "f" * 40} + error = validate_vcs_parameters(data) + assert error is not None + assert "Head SHA is required when base SHA is provided" in error + + def test_missing_head_repo_name(self) -> None: + """Test that validation fails when head_repo_name is missing.""" + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": "e" * 40, + "provider": "github", + "head_ref": "feature/xyz", + } + error = validate_vcs_parameters(data) + assert error is not None + assert "Missing parameters" in error + assert "head_repo_name" in error + + def test_missing_provider(self) -> None: + """Test that validation fails when provider is missing.""" + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": "e" * 40, + "head_repo_name": "owner/repo", + "head_ref": "feature/xyz", + } + error = validate_vcs_parameters(data) + assert error is not None + assert "Missing parameters" in error + assert "provider" in error + + def test_missing_head_ref(self) -> None: + """Test that validation fails when head_ref is missing.""" + data = { + "checksum": "a" * 40, + "chunks": [], + "head_sha": "e" * 40, + "head_repo_name": "owner/repo", + "provider": "github", + } + error = validate_vcs_parameters(data) + assert error is not None + assert "Missing parameters" in error + assert "head_ref" in error + + def test_missing_multiple_params(self) -> None: + """Test that validation fails and reports all missing params.""" + data = {"checksum": "a" * 40, "chunks": [], "head_sha": "e" * 40} + error = validate_vcs_parameters(data) + assert error is not None + assert "Missing parameters" in error + assert "head_repo_name" in error + assert "provider" in error + assert "head_ref" in error + + class ProjectPreprodArtifactAssembleTest(APITestCase): """Integration tests for the full endpoint - requires database.""" @@ -804,3 +915,53 @@ def test_assemble_missing_vcs_parameters(self) -> None: assert "head_repo_name" in response.data["error"] assert "provider" in response.data["error"] assert "head_ref" in response.data["error"] + + def test_assemble_same_head_and_base_sha(self) -> None: + """Test that providing the same value for head_sha and base_sha returns a 400 error.""" + content = b"test same sha" + total_checksum = sha1(content).hexdigest() + + blob = FileBlob.from_file(ContentFile(content)) + FileBlobOwner.objects.get_or_create(organization_id=self.organization.id, blob=blob) + + same_sha = "e" * 40 + + response = self.client.post( + self.url, + data={ + "checksum": total_checksum, + "chunks": [blob.checksum], + "head_sha": same_sha, + "base_sha": same_sha, + "provider": "github", + "head_repo_name": "owner/repo", + "head_ref": "feature/xyz", + }, + HTTP_AUTHORIZATION=f"Bearer {self.token.token}", + ) + assert response.status_code == 400, response.content + assert "error" in response.data + assert "Head SHA and base SHA cannot be the same" in response.data["error"] + assert same_sha in response.data["error"] + + def test_assemble_base_sha_without_head_sha(self) -> None: + """Test that providing base_sha without head_sha returns a 400 error.""" + content = b"test base sha without head sha" + total_checksum = sha1(content).hexdigest() + + blob = FileBlob.from_file(ContentFile(content)) + FileBlobOwner.objects.get_or_create(organization_id=self.organization.id, blob=blob) + + response = self.client.post( + self.url, + data={ + "checksum": total_checksum, + "chunks": [blob.checksum], + "base_sha": "f" * 40, + # Missing head_sha + }, + HTTP_AUTHORIZATION=f"Bearer {self.token.token}", + ) + assert response.status_code == 400, response.content + assert "error" in response.data + assert "Head SHA is required when base SHA is provided" in response.data["error"]