diff --git a/.github/workflows/ado/sources-upload.yml b/.github/workflows/ado/sources-upload.yml index 17c46d1b3f5..23e3a1e47ca 100644 --- a/.github/workflows/ado/sources-upload.yml +++ b/.github/workflows/ado/sources-upload.yml @@ -23,7 +23,7 @@ # Name: "ControlTower-PRCheck" # Required variables: # - ApiAudience : Entra ID audience URI for the Control Tower app -# - ApiBaseUrl : Base URL of the Control Tower service +# - ApiBaseDirectUrl : Direct base URL of the Control Tower APIM endpoint (bypasses Azure Front Door) # - AzldevCommit : Commit hash for azldev (go install ...@) # Trigger controlled by ADO branch policy — not YAML triggers. diff --git a/.github/workflows/ado/templates/sources-upload-stages.yml b/.github/workflows/ado/templates/sources-upload-stages.yml index dcc0a36edfc..3ac9432be20 100644 --- a/.github/workflows/ado/templates/sources-upload-stages.yml +++ b/.github/workflows/ado/templates/sources-upload-stages.yml @@ -111,7 +111,7 @@ stages: echo "##[endgroup]" echo "##[group]Python dependencies" - pip install -r .github/workflows/scripts/control-tower-prcheck/requirements.txt + pip install -r .github/workflows/scripts/control-tower/requirements.txt echo "##[endgroup]" displayName: "Install dependencies" env: @@ -200,7 +200,7 @@ stages: inlineScript: | set -euo pipefail - python3 .github/workflows/scripts/control-tower-prcheck/run_control_tower_prcheck.py \ + python3 .github/workflows/scripts/control-tower/run_prcheck.py \ --api-audience "$API_AUDIENCE" \ --api-base-url "$API_BASE_URL" \ --build-reason "$BUILD_REASON" \ @@ -209,7 +209,7 @@ stages: --repo-uri "$UPSTREAM_REPO_URL" env: API_AUDIENCE: $(ApiAudience) - API_BASE_URL: $(ApiBaseUrl) + API_BASE_URL: $(ApiBaseDirectUrl) BUILD_REASON: $(Build.Reason) COMPONENTS: $(components) SOURCE_COMMIT: $(sourceCommit) diff --git a/.github/workflows/scripts/control-tower-prcheck/run_control_tower_prcheck.py b/.github/workflows/scripts/control-tower/client.py similarity index 54% rename from .github/workflows/scripts/control-tower-prcheck/run_control_tower_prcheck.py rename to .github/workflows/scripts/control-tower/client.py index 825e3ad9f50..db9181ab8ed 100644 --- a/.github/workflows/scripts/control-tower-prcheck/run_control_tower_prcheck.py +++ b/.github/workflows/scripts/control-tower/client.py @@ -1,23 +1,20 @@ -"""Call the Control Tower prcheck API and wait for the resulting job to finish. +"""Shared HTTP client for Azure Linux Control Tower scenario calls. -Flow: - 1. POST ``/api/Scenario/prcheck`` with the PR context. The service responds - with a ``WorkflowJobStatusDto`` describing the job it just queued. - 2. Poll ``/api/Workflow/jobs/status/{jobId}`` until the job reaches a - terminal state (Completed / Failed / Cancelled / CancelledByAdmin / - TimedOut / Unknown) or the local poll timeout elapses. - 3. Exit 0 only if the terminal status is ``Completed``; otherwise surface - the error details from the job status payload and exit 1. +Provides: + * A retry-aware ``requests.Session``. + * Bearer-token acquisition + transparent single-shot refresh on 401. + * POST helpers for ``/api/Scenario/*`` endpoints. + * Job-status polling against ``/api/Workflow/jobs/status/{jobId}``. + * Diagnostic formatting that tolerates the three error shapes Control + Tower returns (middleware, controller, ASP.NET validation). Authentication: - Requires an active Azure CLI session (e.g. via an AzureCLI@2 pipeline + Requires an active Azure CLI session (e.g. via an ``AzureCLI@2`` pipeline task with a Workload Identity Federation service connection). ``DefaultAzureCredential`` discovers the session automatically. """ -import argparse import json -import sys import time from dataclasses import dataclass from typing import Any, Optional @@ -29,77 +26,21 @@ # JobStatus values from the Control Tower service # (azl-ControlTower/ControlTower/Shared/Models/Jobs/JobStatus.cs). -_NON_TERMINAL_STATUSES = frozenset({"Queued", "Pending", "Running"}) -_SUCCESS_STATUS = "Completed" -_TERMINAL_FAILURE_STATUSES = frozenset( +NON_TERMINAL_STATUSES = frozenset({"Queued", "Pending", "Running"}) +SUCCESS_STATUS = "Completed" +TERMINAL_FAILURE_STATUSES = frozenset( {"Failed", "Cancelled", "CancelledByAdmin", "Unknown", "TimedOut"} ) @dataclass -class _TokenHolder: +class TokenHolder: """Mutable bearer-token holder so helpers can observe in-place refreshes.""" token: str -def _parse_components(value: str) -> list[str]: - """Parse a comma-separated string into a list of stripped, non-empty names.""" - return [c.strip() for c in value.split(",") if c.strip()] - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Call the Control Tower prcheck API and wait for the job to finish.", - ) - parser.add_argument( - "--api-audience", - required=True, - help="Entra ID audience URI (e.g. api://)", - ) - parser.add_argument( - "--api-base-url", required=True, help="Base URL of the Control Tower service" - ) - parser.add_argument( - "--build-reason", - required=True, - help="ADO build reason (PullRequest, IndividualCI, …)", - ) - parser.add_argument( - "--components", - required=True, - type=_parse_components, - help="Comma-separated list of affected component names", - ) - parser.add_argument("--source-commit", default=None, help="Source commit SHA") - parser.add_argument( - "--source-branch", - default=None, - help="Source branch name (alternative to --source-commit)", - ) - parser.add_argument("--target-commit", default=None, help="Target commit SHA") - parser.add_argument( - "--target-branch", - default=None, - help="Target branch name (alternative to --target-commit)", - ) - parser.add_argument("--repo-uri", required=True, help="Upstream repository URI") - parser.add_argument( - "--poll-interval-seconds", - type=int, - default=10, - help="How often to poll the job status endpoint (default: 10).", - ) - parser.add_argument( - "--poll-timeout-seconds", - type=int, - default=7200, - help="Maximum time to wait for the job to reach a terminal state (default: 7200 = 2h).", - ) - return parser.parse_args() - - -def _make_session() -> requests.Session: +def make_session() -> requests.Session: """Create a ``requests.Session`` with retries for idempotent GETs only. Retry budget is tuned to complete quickly relative to the 10s default poll @@ -123,7 +64,7 @@ def _make_session() -> requests.Session: return session -def _get_token(credential: DefaultAzureCredential, audience: str) -> str: +def get_token(credential: DefaultAzureCredential, audience: str) -> str: """Acquire a bearer token for the given audience.""" return credential.get_token(f"{audience}/.default").token @@ -136,7 +77,7 @@ def _auth_headers(token: str) -> dict[str, str]: } -def _format_error(response: requests.Response) -> str: +def format_error(response: requests.Response) -> str: """Render a detailed diagnostic string for a failed CT response. Tolerates the three error shapes used by Control Tower: @@ -203,7 +144,7 @@ def _request_with_refresh( url: str, credential: DefaultAzureCredential, audience: str, - token_holder: _TokenHolder, + token_holder: TokenHolder, *, json_payload: Optional[dict] = None, ) -> requests.Response: @@ -217,9 +158,10 @@ def _request_with_refresh( ) if response.status_code == 401: print( - "Bearer token rejected (401) — refreshing and retrying once...", flush=True + "Bearer token rejected (401) — refreshing and retrying once...", + flush=True, ) - token_holder.token = _get_token(credential, audience) + token_holder.token = get_token(credential, audience) response = session.request( method, url, @@ -247,32 +189,48 @@ def _parse_json_object(response: requests.Response, context: str) -> dict: return body -def _post_prcheck( +def post_scenario( session: requests.Session, base_url: str, + path: str, credential: DefaultAzureCredential, audience: str, - token_holder: _TokenHolder, + token_holder: TokenHolder, payload: dict, + *, + context: str, ) -> dict: - """POST the prcheck request and return the parsed response dict.""" - url = f"{base_url}/api/Scenario/prcheck" + """POST a scenario request and return the parsed response dict. + + ``path`` is the API path (e.g. ``/api/Scenario/prcheck``) appended to + ``base_url``. A leading ``/`` is added if missing. ``context`` is used in + error messages to identify the call. + """ + if not path.startswith("/"): + path = "/" + path + url = f"{base_url}{path}" response = _request_with_refresh( - session, "POST", url, credential, audience, token_holder, json_payload=payload + session, + "POST", + url, + credential, + audience, + token_holder, + json_payload=payload, ) if not response.ok: raise RuntimeError( - "Control Tower 'prcheck' request failed.\n" + _format_error(response) + f"Control Tower '{context}' request failed.\n" + format_error(response) ) - return _parse_json_object(response, "Control Tower 'prcheck'") + return _parse_json_object(response, f"Control Tower '{context}'") -def _get_job_status( +def get_job_status( session: requests.Session, base_url: str, credential: DefaultAzureCredential, audience: str, - token_holder: _TokenHolder, + token_holder: TokenHolder, job_id: str, ) -> dict: """GET the job status. Refreshes the bearer token on 401 and retries once.""" @@ -282,7 +240,7 @@ def _get_job_status( ) if not response.ok: raise RuntimeError( - "Control Tower job status request failed.\n" + _format_error(response) + "Control Tower job status request failed.\n" + format_error(response) ) return _parse_json_object(response, "Control Tower job status") @@ -301,12 +259,12 @@ def _summarize_tasks(tasks: Any) -> str: return f"{total} tasks ({parts})" -def _poll_job_until_terminal( +def poll_until_terminal( session: requests.Session, base_url: str, credential: DefaultAzureCredential, audience: str, - token_holder: _TokenHolder, + token_holder: TokenHolder, job_id: str, poll_interval_seconds: int, poll_timeout_seconds: int, @@ -322,7 +280,7 @@ def _poll_job_until_terminal( job_status_object: Optional[dict] = None while True: - job_status_object = _get_job_status( + job_status_object = get_job_status( session, base_url, credential, audience, token_holder, job_id ) current_status = job_status_object.get("status", "Unknown") @@ -331,7 +289,9 @@ def _poll_job_until_terminal( if current_status != previous_status: task_summary = _summarize_tasks(job_status_object.get("tasks")) transition = ( - f"{previous_status} -> {current_status}" if previous_status is not None else current_status + f"{previous_status} -> {current_status}" + if previous_status is not None + else current_status ) suffix = f" | {task_summary}" if task_summary else "" print( @@ -346,7 +306,7 @@ def _poll_job_until_terminal( flush=True, ) - if current_status not in _NON_TERMINAL_STATUSES: + if current_status not in NON_TERMINAL_STATUSES: return job_status_object remaining = deadline - time.monotonic() @@ -360,13 +320,13 @@ def _poll_job_until_terminal( time.sleep(min(poll_interval_seconds, max(1, int(remaining)))) -def _print_final_status(final: dict) -> None: +def print_final_status(final: dict) -> None: """Pretty-print the final job status payload.""" print("Final job status payload:") print(json.dumps(final, indent=2, default=str)) -def _report_failure(final: dict) -> None: +def report_failure(final: dict) -> None: """Emit ADO-style error lines with the most actionable fields from ``final``.""" status = final.get("status", "Unknown") error_message = final.get("errorMessage") @@ -381,7 +341,7 @@ def _report_failure(final: dict) -> None: failed = [ t for t in tasks - if isinstance(t, dict) and t.get("status") in _TERMINAL_FAILURE_STATUSES + if isinstance(t, dict) and t.get("status") in TERMINAL_FAILURE_STATUSES ] for task in failed: name = task.get("taskName") or task.get("taskId") @@ -389,118 +349,3 @@ def _report_failure(final: dict) -> None: f"##[error]task '{name}' status={task.get('status')} " f"attempt={task.get('attemptNumber')}" ) - - -def main() -> None: - args = _parse_args() - - if args.poll_interval_seconds <= 0: - print("##[error]--poll-interval-seconds must be a positive integer.") - sys.exit(2) - if args.poll_timeout_seconds <= 0: - print("##[error]--poll-timeout-seconds must be a positive integer.") - sys.exit(2) - - # Normalize base URL to avoid accidental double slashes if the operator - # configured `ApiBaseUrl` with a trailing '/'. - base_url = args.api_base_url.rstrip("/") - - # ── Build payload ──────────────────────────────────────────────── - payload: dict = { - "components": args.components, - "buildReason": args.build_reason, - "repoUri": args.repo_uri, - } - if args.source_commit is not None: - payload["sourceCommitSha"] = args.source_commit - if args.source_branch is not None: - payload["sourceBranch"] = args.source_branch - if args.target_commit is not None: - payload["targetCommitSha"] = args.target_commit - if args.target_branch is not None: - payload["targetBranch"] = args.target_branch - - print("Calling Control Tower 'prcheck' endpoint...") - print("Payload:") - print(json.dumps(payload, indent=2)) - - if args.build_reason == "PullRequest": - print( - "Skipping Control Tower call - pull request triggers are not supported, yet." - ) - return - - if not args.components: - print( - "No affected components detected between source and target commits; " - "skipping Control Tower call." - ) - return - - # ── Acquire bearer token ───────────────────────────────────────── - credential = DefaultAzureCredential() - token_holder = _TokenHolder(token=_get_token(credential, args.api_audience)) - - session = _make_session() - - # ── Call prcheck API ───────────────────────────────────────────── - try: - prcheck_response = _post_prcheck( - session, base_url, credential, args.api_audience, token_holder, payload - ) - except RuntimeError as exc: - print(f"##[error]{exc}") - sys.exit(1) - - print("prcheck response:") - print(json.dumps(prcheck_response, indent=2, default=str)) - - job_id = prcheck_response.get("jobId") - if not job_id: - print( - "##[error]Control Tower 'prcheck' response did not include a 'jobId'. " - "Cannot poll for job status." - ) - sys.exit(1) - - # ── Poll for job completion ────────────────────────────────────── - print( - f"Polling job {job_id} every {args.poll_interval_seconds}s " - f"(timeout {args.poll_timeout_seconds}s)..." - ) - try: - final = _poll_job_until_terminal( - session, - base_url, - credential, - args.api_audience, - token_holder, - job_id, - args.poll_interval_seconds, - args.poll_timeout_seconds, - ) - except RuntimeError as exc: - print(f"##[error]{exc}") - sys.exit(1) - - if final is None: - # Local timeout — job may still be running on the service side. - print( - f"##[error]Timed out locally after {args.poll_timeout_seconds}s " - f"waiting for job {job_id} to finish. Inspect the job in Control Tower." - ) - sys.exit(1) - - _print_final_status(final) - - status = final.get("status") - if status == _SUCCESS_STATUS: - print(f"Control Tower job {job_id} completed successfully.") - return - - _report_failure(final) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/scripts/control-tower-prcheck/requirements.txt b/.github/workflows/scripts/control-tower/requirements.txt similarity index 100% rename from .github/workflows/scripts/control-tower-prcheck/requirements.txt rename to .github/workflows/scripts/control-tower/requirements.txt diff --git a/.github/workflows/scripts/control-tower/run_prcheck.py b/.github/workflows/scripts/control-tower/run_prcheck.py new file mode 100644 index 00000000000..3bc8fb75f05 --- /dev/null +++ b/.github/workflows/scripts/control-tower/run_prcheck.py @@ -0,0 +1,201 @@ +"""Call the Control Tower 'prcheck' API and wait for the resulting job to finish. + +Flow: + 1. POST ``/api/Scenario/prcheck`` with the PR context. The service responds + with a ``WorkflowJobStatusDto`` describing the job it just queued. + 2. Poll ``/api/Workflow/jobs/status/{jobId}`` until the job reaches a + terminal state (Completed / Failed / Cancelled / CancelledByAdmin / + TimedOut / Unknown) or the local poll timeout elapses. + 3. Exit 0 only if the terminal status is ``Completed``; otherwise surface + the error details from the job status payload and exit 1. +""" + +import argparse +import json +import sys + +from azure.identity import DefaultAzureCredential + +import client as ct + + +def _parse_components(value: str) -> list[str]: + """Parse a comma-separated string into a list of stripped, non-empty names.""" + return [c.strip() for c in value.split(",") if c.strip()] + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Call the Control Tower prcheck API and wait for the job to finish.", + ) + parser.add_argument( + "--api-audience", + required=True, + help="Entra ID audience URI (e.g. api://)", + ) + parser.add_argument( + "--api-base-url", required=True, help="Base URL of the Control Tower service" + ) + parser.add_argument( + "--build-reason", + required=True, + help="ADO build reason (PullRequest, IndividualCI, …)", + ) + + parser.add_argument( + "--components", + required=True, + type=_parse_components, + help="Comma-separated list of affected component names", + ) + + parser.add_argument("--source-commit", default=None, help="Source commit SHA") + parser.add_argument( + "--source-branch", + default=None, + help="Source branch name (alternative to --source-commit)", + ) + parser.add_argument("--target-commit", default=None, help="Target commit SHA") + parser.add_argument( + "--target-branch", + default=None, + help="Target branch name (alternative to --target-commit)", + ) + parser.add_argument("--repo-uri", required=True, help="Upstream repository URI") + parser.add_argument( + "--poll-interval-seconds", + type=int, + default=10, + help="How often to poll the job status endpoint (default: 10).", + ) + parser.add_argument( + "--poll-timeout-seconds", + type=int, + default=7200, + help="Maximum time to wait for the job to reach a terminal state (default: 7200 = 2h).", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + if args.poll_interval_seconds <= 0: + print("##[error]--poll-interval-seconds must be a positive integer.") + sys.exit(2) + if args.poll_timeout_seconds <= 0: + print("##[error]--poll-timeout-seconds must be a positive integer.") + sys.exit(2) + + components: list[str] = args.components + + # Normalize the base URL to avoid accidental double slashes if it was + # configured with a trailing '/'. + base_url = args.api_base_url.rstrip("/") + + # ── Build payload ──────────────────────────────────────────────── + payload: dict = { + "components": components, + "buildReason": args.build_reason, + "repoUri": args.repo_uri, + } + if args.source_commit is not None: + payload["sourceCommitSha"] = args.source_commit + if args.source_branch is not None: + payload["sourceBranch"] = args.source_branch + if args.target_commit is not None: + payload["targetCommitSha"] = args.target_commit + if args.target_branch is not None: + payload["targetBranch"] = args.target_branch + + print("Calling Control Tower 'prcheck' endpoint...") + print("Payload:") + print(json.dumps(payload, indent=2)) + + if args.build_reason == "PullRequest": + print( + "Skipping Control Tower call - pull request triggers are not supported, yet." + ) + return + + if not components: + print( + "No affected components detected between source and target commits; " + "skipping Control Tower call." + ) + return + + # ── Acquire bearer token ───────────────────────────────────────── + credential = DefaultAzureCredential() + token_holder = ct.TokenHolder(token=ct.get_token(credential, args.api_audience)) + + session = ct.make_session() + + # ── Call prcheck API ───────────────────────────────────────────── + try: + prcheck_response = ct.post_scenario( + session, + base_url, + "/api/Scenario/prcheck", + credential, + args.api_audience, + token_holder, + payload, + context="prcheck", + ) + except RuntimeError as exc: + print(f"##[error]{exc}") + sys.exit(1) + + print("prcheck response:") + print(json.dumps(prcheck_response, indent=2, default=str)) + + job_id = prcheck_response.get("jobId") + if not job_id: + print( + "##[error]Control Tower 'prcheck' response did not include a 'jobId'. " + "Cannot poll for job status." + ) + sys.exit(1) + + # ── Poll for job completion ────────────────────────────────────── + print( + f"Polling job {job_id} every {args.poll_interval_seconds}s " + f"(timeout {args.poll_timeout_seconds}s)..." + ) + try: + final = ct.poll_until_terminal( + session, + base_url, + credential, + args.api_audience, + token_holder, + job_id, + args.poll_interval_seconds, + args.poll_timeout_seconds, + ) + except RuntimeError as exc: + print(f"##[error]{exc}") + sys.exit(1) + + if final is None: + # Local timeout — job may still be running on the service side. + print( + f"##[error]Timed out locally after {args.poll_timeout_seconds}s " + f"waiting for job {job_id} to finish. Inspect the job in Control Tower." + ) + sys.exit(1) + + ct.print_final_status(final) + + status = final.get("status") + if status == ct.SUCCESS_STATUS: + print(f"Control Tower job {job_id} completed successfully.") + return + + ct.report_failure(final) + sys.exit(1) + + +if __name__ == "__main__": + main()