diff --git a/README.md b/README.md index 29cbb28..20d0c2b 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ For more information see the [self-hosted runner security docs](https://docs.git | aws_tags | The AWS tags to use for your runner, formatted as a JSON list. See `README` for more details. | false | | | extra_gh_labels | Any extra GitHub labels to tag your runners with. Passed as a comma-separated list with no spaces. | false | | | instance_count | The number of instances to create, defaults to 1 | false | 1 | +| gh_timeout | The timeout in seconds to wait for the runner to come online as seen by the GitHub API. Defaults to 1200 seconds. | false | 1200 | ### AWS `stop` Inputs | Input | Description | Required for stop| Default | Note | diff --git a/action.yml b/action.yml index 9e34c6f..af6a5e5 100644 --- a/action.yml +++ b/action.yml @@ -1,8 +1,8 @@ name: gha-runner description: A simple GitHub Action for creating self-hosted runners. runs: - using: 'docker' - image: 'Dockerfile' + using: "docker" + image: "Dockerfile" inputs: action: description: 'Whether to start or stop. Options: "start", "stop"' @@ -37,9 +37,9 @@ inputs: instance_count: description: "The number of instances to create, defaults to 1" required: false - default: '1' + default: "1" instance_mapping: - description: 'A JSON object mapping instance ids to unique GitHub runner labels. Required to stop created instances.' + description: "A JSON object mapping instance ids to unique GitHub runner labels. Required to stop created instances." required: false provider: description: 'The cloud provider to use to provision a runner. Will not start if not set. Example: "aws"' @@ -47,6 +47,9 @@ inputs: repo: description: "The repo to run against. Will use the the current repo if not specified." required: false + gh_timeout: + description: "The timeout in seconds to wait for the runner to come online as seen by the GitHub API. Defaults to 1200 seconds." + required: false outputs: mapping: description: "A JSON object mapping instance IDs to unique GitHub runner labels. This is used in conjection with the the `instance_mapping` input when stopping." diff --git a/src/gha_runner/__main__.py b/src/gha_runner/__main__.py index 0924e6f..aa0e6b6 100644 --- a/src/gha_runner/__main__.py +++ b/src/gha_runner/__main__.py @@ -50,7 +50,11 @@ def get_instance_mapping() -> dict[str, str]: def start_runner_instances( - provider: str, cloud_params: dict, gh: GitHubInstance, count: int + provider: str, + cloud_params: dict, + gh: GitHubInstance, + count: int, + timeout: int, ): release = gh.get_latest_runner_release(platform="linux", architecture="x64") cloud_params["runner_release"] = release @@ -75,7 +79,7 @@ def start_runner_instances( print("Instance is ready!") for label in github_labels: print(f"Waiting for {label}...") - gh.wait_for_runner(label) + gh.wait_for_runner(label, timeout) def stop_runner_instances( @@ -142,6 +146,8 @@ def main(): # pragma: no cover provider = os.environ.get("INPUT_PROVIDER") if provider is None: raise Exception("Missing required input variable INPUT_PROVIDER") + # Set the default timeout to 20 minutes + gh_timeout = int(os.environ.get("INPUT_GH_TIMEOUT", 1200)) gha_params = { "token": os.environ["GH_PAT"], @@ -165,6 +171,7 @@ def main(): # pragma: no cover cloud_params, gh, instance_count, + gh_timeout, ) elif action == "stop": stop_runner_instances(provider, cloud_params, gh) diff --git a/src/gha_runner/gh.py b/src/gha_runner/gh.py index 50d48c1..9924224 100644 --- a/src/gha_runner/gh.py +++ b/src/gha_runner/gh.py @@ -12,6 +12,7 @@ class TokenRetrievalError(Exception): """Exception raised when there is an error retrieving a token from GitHub.""" + class MissingRunnerLabel(Exception): """Exception raised when a runner does not exist in the repository.""" @@ -189,7 +190,7 @@ def get_runners(self, label: str) -> list[SelfHostedActionsRunner] | None: ] return matched_runners if matched_runners else None - def wait_for_runner(self, label: str, wait: int = 15): + def wait_for_runner(self, label: str, timeout: int, wait: int = 15): """Wait for the runner with the given label to be online. Parameters @@ -198,10 +199,14 @@ def wait_for_runner(self, label: str, wait: int = 15): The label of the runner to wait for. wait : int The time in seconds to wait between checks. Defaults to 15 seconds. - + timeout : int + The maximum time in seconds to wait for the runner to be online. """ + max = time.time() + timeout runner = self.get_runner(label) while runner is None: + if time.time() > max: + raise RuntimeError(f"Timeout reached: Runner {label} not found") print(f"Runner {label} not found. Waiting...") runner = self.get_runner(label) time.sleep(wait) diff --git a/tests/test_gh.py b/tests/test_gh.py index e7f6b37..c0869a2 100644 --- a/tests/test_gh.py +++ b/tests/test_gh.py @@ -2,7 +2,11 @@ import pytest from unittest.mock import patch, MagicMock, Mock -from gha_runner.gh import TokenRetrievalError, GitHubInstance, MissingRunnerLabel +from gha_runner.gh import ( + TokenRetrievalError, + GitHubInstance, + MissingRunnerLabel, +) from github.SelfHostedActionsRunner import SelfHostedActionsRunner @@ -383,7 +387,7 @@ def mock_get_runner(monkeypatch): def test_wait_for_runner(github_release_mock, mock_get_runner, capsys): instance, _, _ = github_release_mock get_runner_mock, label, expected_calls = mock_get_runner - instance.wait_for_runner(label, wait=1) + instance.wait_for_runner(label, 10, wait=1) captured = capsys.readouterr() # Combine all expected calls into a single string combined = "".join(expected_calls) @@ -392,3 +396,32 @@ def test_wait_for_runner(github_release_mock, mock_get_runner, capsys): assert captured.out == combined # Validate that the get_runner method was called the correct number of times assert get_runner_mock.call_count == len(expected_calls) + +@pytest.fixture +def mock_get_runner_timeout(monkeypatch): + label = "runner-linux-x64" + side_effect = [None, None] + # Dynamically build out the expected calls based on the side_effect + expected_calls = [ + f"Runner {label} not found. Waiting...\n ", + ] + + get_runner_mock = MagicMock() + # Setup the side_effect for the get_runner_mock + get_runner_mock.side_effect = side_effect + monkeypatch.setattr(GitHubInstance, "get_runner", get_runner_mock) + return get_runner_mock, label, expected_calls + +def test_wait_for_runner_timeout(github_release_mock, mock_get_runner_timeout, capsys): + instance, _, _ = github_release_mock + get_runner_mock, label, expected_calls = mock_get_runner_timeout + with pytest.raises(RuntimeError, match=f"Timeout reached: Runner {label} not found"): + instance.wait_for_runner(label, timeout=1, wait=1) + captured = capsys.readouterr() + # Combine all expected calls into a single string + combined = "".join(expected_calls) + + # Validate that the expected output matches the captured output + assert captured.out == combined + # Validate that the get_runner method was called the correct number of times + assert get_runner_mock.call_count == len(expected_calls) diff --git a/tests/test_main.py b/tests/test_main.py index b9d75e2..7f41b3d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -181,6 +181,7 @@ def test_start_runner_instances_smoke( gh=mock_gh, count=1, cloud_params={}, + timeout=0, ) except Exception as e: pytest.fail(f"Exception raised: {e}")