From cc90081b8303843c63c9fa65f2e10a79f336a150 Mon Sep 17 00:00:00 2001 From: Sivaselvan32 Date: Wed, 8 Oct 2025 12:13:05 +0530 Subject: [PATCH] Features providing Policy Check API Specs --- examples/policy_check.py | 164 ++++++++++++++++++++++++++++++ src/tfe/client.py | 2 + src/tfe/errors.py | 8 ++ src/tfe/models/policy_check.py | 117 +++++++++++++++++++-- src/tfe/resources/policy_check.py | 108 ++++++++++++++++++++ 5 files changed, 390 insertions(+), 9 deletions(-) create mode 100644 examples/policy_check.py create mode 100644 src/tfe/resources/policy_check.py diff --git a/examples/policy_check.py b/examples/policy_check.py new file mode 100644 index 0000000..f47a180 --- /dev/null +++ b/examples/policy_check.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import argparse +import os + +from tfe import TFEClient, TFEConfig +from tfe.models.policy_check import PolicyCheckListOptions + + +def _print_header(title: str): + print("\n" + "=" * 80) + print(title) + print("=" * 80) + + +def main(): + parser = argparse.ArgumentParser( + description="Policy Checks demo for python-tfe SDK" + ) + parser.add_argument( + "--address", default=os.getenv("TFE_ADDRESS", "https://app.terraform.io") + ) + parser.add_argument("--token", default=os.getenv("TFE_TOKEN", "")) + parser.add_argument( + "--run-id", required=True, help="Run ID to list policy checks for" + ) + parser.add_argument( + "--policy-check-id", help="Specific policy check ID to read/override" + ) + parser.add_argument( + "--override", action="store_true", help="Override the specified policy check" + ) + parser.add_argument( + "--get-logs", + action="store_true", + help="Get logs for the specified policy check", + ) + parser.add_argument("--page", type=int, default=1) + parser.add_argument("--page-size", type=int, default=20) + args = parser.parse_args() + + if not args.token: + print("Error: TFE_TOKEN environment variable or --token argument is required") + return + + cfg = TFEConfig(address=args.address, token=args.token) + client = TFEClient(cfg) + + # 1) List all policy checks for the given run + _print_header(f"Listing policy checks for run: {args.run_id}") + + options = PolicyCheckListOptions( + page_number=args.page, + page_size=args.page_size, + ) + + try: + pc_list = client.policy_checks.list(args.run_id, options) + + print(f"Total policy checks: {pc_list.total_count}") + print(f"Page {pc_list.current_page} of {pc_list.total_pages}") + print() + + if not pc_list.items: + print("No policy checks found for this run.") + else: + for pc in pc_list.items: + print(f"- ID: {pc.id}") + print(f" Status: {pc.status}") + print(f" Scope: {pc.scope}") + if pc.result: + print( + f" Result: passed={pc.result.passed}, failed={pc.result.total_failed}" + ) + print(f" Duration: {pc.result.duration}ms") + if pc.actions: + print(f" Can Override: {pc.actions.is_overridable}") + if pc.permissions: + print(f" Has Override Permission: {pc.permissions.can_override}") + print() + + except Exception as e: + print(f"Error listing policy checks: {e}") + return + + # 2) Read a specific policy check (if policy-check-id is provided) + if args.policy_check_id: + _print_header(f"Reading policy check: {args.policy_check_id}") + + try: + pc = client.policy_checks.read(args.policy_check_id) + + print(f"ID: {pc.id}") + print(f"Status: {pc.status}") + print(f"Scope: {pc.scope}") + + if pc.result: + print("Result Summary:") + print(f" - Passed: {pc.result.passed}") + print(f" - Hard Failed: {pc.result.hard_failed}") + print(f" - Soft Failed: {pc.result.soft_failed}") + print(f" - Advisory Failed: {pc.result.advisory_failed}") + print(f" - Total Failed: {pc.result.total_failed}") + print(f" - Duration: {pc.result.duration}ms") + print(f" - Overall Result: {pc.result.result}") + + if pc.actions: + print("Actions:") + print(f" - Is Overridable: {pc.actions.is_overridable}") + + if pc.permissions: + print("Permissions:") + print(f" - Can Override: {pc.permissions.can_override}") + + if pc.status_timestamps: + print("Status Timestamps:") + if pc.status_timestamps.queued_at: + print(f" - Queued At: {pc.status_timestamps.queued_at}") + if pc.status_timestamps.passed_at: + print(f" - Passed At: {pc.status_timestamps.passed_at}") + if pc.status_timestamps.soft_failed_at: + print(f" - Soft Failed At: {pc.status_timestamps.soft_failed_at}") + if pc.status_timestamps.hard_failed_at: + print(f" - Hard Failed At: {pc.status_timestamps.hard_failed_at}") + if pc.status_timestamps.errored_at: + print(f" - Errored At: {pc.status_timestamps.errored_at}") + + except Exception as e: + print(f"Error reading policy check: {e}") + return + + # 3) Override the policy check (if requested and possible) + if args.override: + _print_header(f"Overriding policy check: {args.policy_check_id}") + + try: + overridden_pc = client.policy_checks.override(args.policy_check_id) + print(f"Policy check {overridden_pc.id} successfully overridden!") + print(f"New status: {overridden_pc.status}") + + except Exception as e: + print(f"Error overriding policy check: {e}") + + # 4) Get logs for the policy check (if requested) + if args.get_logs: + _print_header(f"Getting logs for policy check: {args.policy_check_id}") + + try: + print( + "Fetching logs (this may take a moment if the policy check is still running)..." + ) + logs = client.policy_checks.logs(args.policy_check_id) + + print("Policy Check Logs:") + print("-" * 60) + print(logs) + print("-" * 60) + + except Exception as e: + print(f"Error getting policy check logs: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/tfe/client.py b/src/tfe/client.py index 90e0f80..5c7be10 100644 --- a/src/tfe/client.py +++ b/src/tfe/client.py @@ -11,6 +11,7 @@ from .resources.organizations import Organizations from .resources.plan import Plans from .resources.policy import Policies +from .resources.policy_check import PolicyChecks from .resources.projects import Projects from .resources.query_run import QueryRuns from .resources.registry_module import RegistryModules @@ -74,6 +75,7 @@ def __init__(self, config: TFEConfig | None = None): self.query_runs = QueryRuns(self._transport) self.run_events = RunEvents(self._transport) self.policies = Policies(self._transport) + self.policy_checks = PolicyChecks(self._transport) # SSH Keys self.ssh_keys = SSHKeys(self._transport) diff --git a/src/tfe/errors.py b/src/tfe/errors.py index 6069a7e..fcab482 100644 --- a/src/tfe/errors.py +++ b/src/tfe/errors.py @@ -415,3 +415,11 @@ class RequiredEnforceError(RequiredFieldMissing): def __init__(self, message: str = "enforce or enforcement-level is required"): super().__init__(message) + + +# Policy Check errors +class InvalidPolicyCheckIDError(InvalidValues): + """Raised when an invalid policy check ID is provided.""" + + def __init__(self, message: str = "invalid value for policy check ID"): + super().__init__(message) diff --git a/src/tfe/models/policy_check.py b/src/tfe/models/policy_check.py index 7e8366c..7b97274 100644 --- a/src/tfe/models/policy_check.py +++ b/src/tfe/models/policy_check.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, Field @@ -8,17 +10,114 @@ from .run import Run -# PolicyCheck represents a Terraform Enterprise policy check.. +class PolicyScope(str, Enum): + """The scope of the policy check.""" + + POLICY_SCOPE_ORGANIZATION = "organization" + POLICY_SCOPE_WORKSPACE = "workspace" + + +class PolicyStatus(str, Enum): + """The status of the policy check.""" + + POLICY_CANCELED = "canceled" + POLICY_ERRORED = "errored" + POLICY_HARD_FAILED = "hard_failed" + POLICY_OVERRIDDEN = "overridden" + POLICY_PASSES = "passed" + POLICY_PENDING = "pending" + POLICY_QUEUED = "queued" + POLICY_SOFT_FAILED = "soft_failed" + POLICY_UNREACHABLE = "unreachable" + + +class PolicyCheckIncludeOpt(str, Enum): + """A list of relations to include""" + + POLICY_CHECK_RUN_WORKSPACE = "run.workspace" + POLICY_CHECK_RUN = "run" + + class PolicyCheck(BaseModel): + """PolicyCheck represents a Terraform Enterprise policy check.""" + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) id: str - # actions: PolicyActions = Field(..., alias="actions") - # permissions: PolicyPermissions = Field(..., alias="permissions") - # result: PolicyResult = Field(..., alias="result") - # scope: PolicyScope = Field(..., alias="scope") - # status: PolicyStatus = Field(..., alias="status") - # status_timestamps: PolicyStatusTimestamps = Field(..., alias="status-timestamps") + actions: PolicyActions | None = Field(None, alias="actions") + permissions: PolicyPermissions | None = Field(None, alias="permissions") + result: PolicyResult | None = Field(None, alias="result") + scope: PolicyScope | None = Field(None, alias="scope") + status: PolicyStatus | None = Field(None, alias="status") + status_timestamps: PolicyStatusTimestamps | None = Field( + None, alias="status-timestamps" + ) # Relations - run: Run = Field(..., alias="run") + run: Run | None = Field(None, alias="run") + + +class PolicyActions(BaseModel): + """PolicyActions represents the policy check actions.""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + is_overridable: bool | None = Field(None, alias="is-overridable") + + +class PolicyPermissions(BaseModel): + """PolicyPermissions represents the policy check permissions.""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + can_override: bool | None = Field(None, alias="can-override") + + +class PolicyResult(BaseModel): + """PolicyResult represents the complete policy check result""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + advisory_failed: int | None = Field(None, alias="advisory-failed") + duration: int | None = Field(None, alias="duration") + hard_failed: int | None = Field(None, alias="hard-failed") + soft_failed: int | None = Field(None, alias="soft-failed") + total_failed: int | None = Field(None, alias="total-failed") + passed: int | None = Field(None, alias="passed") + result: bool | None = Field(None, alias="result") + sentinel: Any | None = Field(None, alias="sentinel") + + +class PolicyStatusTimestamps(BaseModel): + """PolicyStatusTimestamps holds the timestamps for individual policy check statuses.""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + errored_at: datetime | None = Field(None, alias="errored-at") + hard_failed_at: datetime | None = Field(None, alias="hard-failed-at") + passed_at: datetime | None = Field(None, alias="passed-at") + queued_at: datetime | None = Field(None, alias="queued-at") + soft_failed_at: datetime | None = Field(None, alias="soft-failed-at") + + +class PolicyCheckListOptions(BaseModel): + """PolicyCheckListOptions represents the options for listing policy checks.""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + include: list[PolicyCheckIncludeOpt] | None = Field(None, alias="include") + page_number: int | None = Field(None, alias="page[number]") + page_size: int | None = Field(None, alias="page[size]") + + +class PolicyCheckList(BaseModel): + """PolicyCheckList represents a list of policy checks.""" + + model_config = ConfigDict(populate_by_name=True, validate_by_name=True) + + items: list[PolicyCheck] = Field(default_factory=list, alias="items") + current_page: int | None = Field(None, alias="current_page") + total_pages: int | None = Field(None, alias="total_pages") + prev_page: int | None = Field(None, alias="prev_page") + next_page: int | None = Field(None, alias="next_page") + total_count: int | None = Field(None, alias="total_count") diff --git a/src/tfe/resources/policy_check.py b/src/tfe/resources/policy_check.py new file mode 100644 index 0000000..7529f61 --- /dev/null +++ b/src/tfe/resources/policy_check.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import time + +from ..errors import ( + InvalidPolicyCheckIDError, + InvalidRunIDError, +) +from ..models.policy_check import ( + PolicyCheck, + PolicyCheckList, + PolicyCheckListOptions, + PolicyStatus, +) +from ..utils import valid_string_id +from ._base import _Service + + +class PolicyChecks(_Service): + """ + PolicyChecks describes all the policy check related methods that the Terraform Enterprise API supports. + TFE API docs: https://developer.hashicorp.com/terraform/cloud-docs/api-docs/policy-checks + """ + + def list( + self, run_id: str, options: PolicyCheckListOptions | None = None + ) -> PolicyCheckList: + """List all policy checks of the given run.""" + if not valid_string_id(run_id): + raise InvalidRunIDError() + params = ( + options.model_dump(by_alias=True, exclude_none=True) if options else None + ) + r = self.t.request( + "GET", + f"/api/v2/runs/{run_id}/policy-checks", + params=params, + ) + jd = r.json() + items = [] + meta = jd.get("meta", {}) + pagination = meta.get("pagination", {}) + for d in jd.get("data", []): + attrs = d.get("attributes", {}) + attrs["id"] = d.get("id") + attrs["run"] = d.get("relationships", {}).get("run", {}) + items.append(PolicyCheck.model_validate(attrs)) + return PolicyCheckList( + items=items, + current_page=pagination.get("current-page"), + total_pages=pagination.get("total-pages"), + prev_page=pagination.get("prev-page"), + next_page=pagination.get("next-page"), + total_count=pagination.get("total-count"), + ) + + def read(self, policy_check_id: str) -> PolicyCheck: + """Read a policy check by its ID.""" + if not valid_string_id(policy_check_id): + raise InvalidPolicyCheckIDError() + r = self.t.request( + "GET", + f"/api/v2/policy-checks/{policy_check_id}", + ) + jd = r.json() + d = jd.get("data", {}) + attrs = d.get("attributes", {}) + attrs["id"] = d.get("id") + attrs["run"] = d.get("relationships", {}).get("run", {}) + return PolicyCheck.model_validate(attrs) + + def override(self, policy_check_id: str) -> PolicyCheck: + """Override a soft-mandatory or warning policy.""" + if not valid_string_id(policy_check_id): + raise InvalidPolicyCheckIDError() + r = self.t.request( + "POST", + f"/api/v2/policy-checks/{policy_check_id}/actions/override", + ) + jd = r.json() + d = jd.get("data", {}) + attrs = d.get("attributes", {}) + attrs["id"] = d.get("id") + attrs["run"] = d.get("relationships", {}).get("run", {}) + return PolicyCheck.model_validate(attrs) + + def logs(self, policy_check_id: str) -> str: + """Logs retrieves the logs of a policy check.""" + if not valid_string_id(policy_check_id): + raise InvalidPolicyCheckIDError() + + # Loop until the policy check is finished running. + # The policy check logs are not streamed and so only available + # once the check is finished. + while True: + pc = self.read(policy_check_id) + + # Continue polling if the policy check is still pending or queued + if pc.status in (PolicyStatus.POLICY_PENDING, PolicyStatus.POLICY_QUEUED): + time.sleep(0.5) # 500ms wait, equivalent to Go's 500 * time.Millisecond + continue + + # Policy check is finished, get the logs + r = self.t.request( + "GET", + f"/api/v2/policy-checks/{policy_check_id}/output", + ) + return r.text