diff --git a/pyproject.toml b/pyproject.toml
index b633effaf..8c8ebef25 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -112,6 +112,7 @@ dev-dependencies = [
     # "scalene>=1.5.45",
     "filelock<4.0.0,>=3.15.4",
     "pytest>=8.3.3",
+    "pytest-asyncio>=0.23.5",
     "pytest-cov>=6.0.0,<6.0.1",
     "ruff>=0.6.8",
     "mypy[mypyc,faster-cache]>=1.13.0",
diff --git a/tests/unit/codegen/runner/sandbox/conftest.py b/tests/unit/codegen/runner/sandbox/conftest.py
new file mode 100644
index 000000000..e8d012dc7
--- /dev/null
+++ b/tests/unit/codegen/runner/sandbox/conftest.py
@@ -0,0 +1,58 @@
+from collections.abc import Generator
+from unittest.mock import patch
+
+import pytest
+
+from codegen.git.repo_operator.local_repo_operator import LocalRepoOperator
+from codegen.git.schemas.repo_config import RepoConfig
+from codegen.runner.models.configs import RunnerFeatureFlags
+from codegen.runner.sandbox.executor import SandboxExecutor
+from codegen.runner.sandbox.runner import SandboxRunner
+from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags, ProjectConfig
+from codegen.sdk.core.codebase import Codebase
+from codegen.sdk.enums import ProgrammingLanguage
+from codegen.sdk.secrets import Secrets
+
+
+@pytest.fixture
+def codebase(tmpdir, request) -> Codebase:  # Codebase over a repo created under tmpdir that holds one file, test.py
+    repo_id = getattr(request, "param", 1)  # indirect parametrization may supply the repo id; otherwise 1
+    repo_config = RepoConfig(id=repo_id, name="test-repo", full_name="test-org/test-repo", organization_id=1, organization_name="test-org")
+    op = LocalRepoOperator.create_from_files(repo_path=tmpdir, files={"test.py": "a = 1"}, bot_commit=True, repo_config=repo_config)
+    projects = [ProjectConfig(repo_operator=op, programming_language=ProgrammingLanguage.PYTHON)]
+    codebase = Codebase(projects=projects)
+    return codebase
+
+
+@pytest.fixture
+def executor(codebase: Codebase) -> Generator[SandboxExecutor]:  # SandboxExecutor built on the `codebase` fixture
+    with patch("codegen.runner.sandbox.executor.get_runner_feature_flags") as mock_ff:  # NOTE(review): same target and return value as the autouse mock_runner_flags fixture below; looks redundant - confirm before removing
+        mock_ff.return_value = RunnerFeatureFlags(syntax_highlight=False)
+
+        yield SandboxExecutor(codebase)
+
+
+@pytest.fixture
+def runner(codebase: Codebase, tmpdir):  # SandboxRunner with RemoteRepoOperator and _build_graph stubbed out by the `codebase` fixture
+    with patch("codegen.runner.sandbox.runner.RemoteRepoOperator") as mock_op:
+        with patch.object(SandboxRunner, "_build_graph") as mock_init_codebase:
+            mock_init_codebase.return_value = codebase  # _build_graph returns the prebuilt codebase
+            mock_op.return_value = codebase.op  # constructing RemoteRepoOperator yields the fixture's operator
+
+            yield SandboxRunner(container_id="ta-123", repo_config=codebase.op.repo_config)
+
+
+@pytest.fixture(autouse=True)
+def mock_runner_flags():  # autouse: get_runner_feature_flags, as looked up in the executor module, returns flags with syntax_highlight=False
+    with patch("codegen.runner.sandbox.executor.get_runner_feature_flags") as mock_ff:
+        mock_ff.return_value = RunnerFeatureFlags(syntax_highlight=False)
+        yield mock_ff
+
+
+@pytest.fixture(autouse=True)
+def mock_codebase_config():  # autouse: get_codebase_config, as looked up in the runner module, returns the canned config built below
+    with patch("codegen.runner.sandbox.runner.get_codebase_config") as mock_config:
+        gs_ffs = GSFeatureFlags(**RunnerFeatureFlags().model_dump())  # default runner flags, dumped and passed as keyword arguments
+        secrets = Secrets(openai_key="test-key")
+        mock_config.return_value = CodebaseConfig(secrets=secrets, feature_flags=gs_ffs)
+        yield mock_config
diff --git a/tests/unit/codegen/runner/sandbox/test_executor.py b/tests/unit/codegen/runner/sandbox/test_executor.py
new file mode 100644
index 000000000..e10e8df20
--- /dev/null
+++ b/tests/unit/codegen/runner/sandbox/test_executor.py
@@ -0,0 +1,227 @@
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from codegen.git.models.codemod_context import CodemodContext
+from codegen.runner.models.codemod import GroupingConfig
+from codegen.runner.sandbox.executor import SandboxExecutor
+from codegen.sdk.codebase.config import SessionOptions
+from codegen.sdk.codebase.flagging.code_flag import CodeFlag
+from codegen.sdk.codebase.flagging.groupers.enums import GroupBy
+from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock
+
+
+@pytest.mark.asyncio
+async def test_execute_func_pass_in_codemod_context_takes_priority(executor: SandboxExecutor):  # the `context` supplied via custom_scope is what the codeblock sees
+    codemod_context = CodemodContext(
+        CODEMOD_LINK="http://codegen.sh/codemod/5678",
+    )
+    mock_source = """
+print(context.CODEMOD_LINK)
+"""
+    custom_scope = {"context": codemod_context}
+    code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source, custom_scope=custom_scope)
+    mock_log = MagicMock()
+    executor.codebase.log = mock_log  # the codeblock's print() output is asserted through codebase.log below
+
+    result = await executor.execute(code_to_exec)
+
+    assert result is not None
+
+    assert mock_log.call_count == 1
+    assert mock_log.call_args_list[0][0][0] == "http://codegen.sh/codemod/5678"
+
+
+# @pytest.mark.asyncio
+# async def test_init_execute_func_with_pull_request_context(executor: SandboxExecutor):
+#     mock_source = """
+# print(context.PULL_REQUEST.head.ref)
+# print(context.PULL_REQUEST.base.ref)
+# """
+#     mock_cm_run = MagicMock(epic=MagicMock(id=1234, link="link", user=MagicMock(github_username="user")), codemod_version=MagicMock(source=mock_source))
+#     mock_pull = MagicMock(spec=GithubWebhookPR, head=MagicMock(ref="test-head"), base=MagicMock(ref="test-base"))
+#     codemod_context = get_codemod_context(cm_run=mock_cm_run, pull_request=mock_pull)
+#     custom_scope = {"context": codemod_context}
+#     code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source, custom_scope=custom_scope)
+#     mock_log = MagicMock()
+#     executor.codebase.log = mock_log
+#
+#     result = await executor.execute(code_to_exec)
+#
+#     assert result is not None
+#     assert mock_log.call_count == 2
+#     assert mock_log.call_args_list[0][0][0] == "test-head"
+#     assert mock_log.call_args_list[1][0][0] == "test-base"
+#
+#
+# @pytest.mark.asyncio
+# async def test_init_execute_func_with_pull_request_context_mock_codebase(executor: SandboxExecutor):
+#     mock_source = """
+# print(context.PULL_REQUEST.head.ref)
+# print(context.PULL_REQUEST.base.ref)
+# """
+#     mock_cm_run = MagicMock(epic=MagicMock(id=1234, link="link", user=MagicMock(github_username="user")), codemod_version=MagicMock(source=mock_source))
+#     mock_pull = MagicMock(spec=GithubWebhookPR,
head=MagicMock(ref="test-head"), base=MagicMock(ref="test-base"))
+#     codemod_context = get_codemod_context(cm_run=mock_cm_run, pull_request=mock_pull)
+#     custom_scope = {"context": codemod_context}
+#     code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source, custom_scope=custom_scope)
+#
+#     result = await executor.execute(code_to_exec)
+#
+#     # validate
+#     assert result is not None
+#     assert (
+#         result.logs
+#         == """
+# test-head
+# test-base
+# """.lstrip()
+#     )
+
+
+@pytest.mark.asyncio
+async def test_run_max_preview_time_exceeded_sets_observation_meta(executor: SandboxExecutor):  # max_seconds=0 should end the run with MaxPreviewTimeExceeded recorded
+    mock_source = """
+codebase.files[0].edit("a = 2")
+"""
+    code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source)
+    result = await executor.execute(code_to_exec, session_options=SessionOptions(max_seconds=0))
+
+    assert result.is_complete
+    assert result.observation_meta == {"flags": [], "stop_codemod_exception_type": "MaxPreviewTimeExceeded", "threshold": 0}
+
+
+@pytest.mark.asyncio
+async def test_run_max_ai_requests_error_sets_observation_meta(executor: SandboxExecutor):  # max_ai_requests=0 should end the run with MaxAIRequestsError recorded
+    mock_source = """
+codebase.ai("tell me a joke")
+"""
+    code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source)
+    result = await executor.execute(code_to_exec, session_options=SessionOptions(max_ai_requests=0))
+
+    assert result.is_complete
+    assert result.observation_meta == {"flags": [], "stop_codemod_exception_type": "MaxAIRequestsError", "threshold": 0}
+
+
+@pytest.mark.asyncio
+async def test_run_max_transactions_exceeded_sets_observation_meta(executor: SandboxExecutor):  # max_transactions=0 should end the run with MaxTransactionsExceeded recorded
+    mock_source = """
+codebase.files[0].edit("a = 2")
+"""
+
+    code_to_exec = create_execute_function_from_codeblock(codeblock=mock_source)
+    result = await executor.execute(code_to_exec, session_options=SessionOptions(max_transactions=0))
+
+    assert result.is_complete
+    assert result.observation_meta == {"flags": [], "stop_codemod_exception_type": "MaxTransactionsExceeded",
"threshold": 0}
+
+
+@pytest.mark.asyncio
+async def test_find_flag_groups_with_subdirectories(executor: SandboxExecutor):  # expects one group holding only the flags under the listed subdirectories
+    groups = await executor.find_flag_groups(
+        code_flags=[
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir1/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir2/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir3/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir3/file2.py")),
+                message="message",
+            ),
+        ],
+        grouping_config=GroupingConfig(subdirectories=["subdir1", "subdir2"]),
+    )
+    assert len(groups) == 1
+    assert len(groups[0].flags) == 2
+    assert groups[0].flags[0].filepath == "subdir1/file1.py"
+    assert groups[0].flags[1].filepath == "subdir2/file1.py"
+
+
+@pytest.mark.asyncio
+async def test_find_flag_groups_with_group_by(executor: SandboxExecutor):  # GroupBy.FILE: expects one group per distinct filepath
+    groups = await executor.find_flag_groups(
+        code_flags=[
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir1/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir2/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir3/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir1/file1.py")),
+                message="message",
+            ),
+        ],
+        grouping_config=GroupingConfig(group_by=GroupBy.FILE),
+    )
+    assert len(groups) == 3
+    assert groups[0].segment == "subdir1/file1.py"
+    assert groups[1].segment == "subdir2/file1.py"
+    assert groups[2].segment == "subdir3/file1.py"
+
+    assert len(groups[0].flags) == 2
+    assert len(groups[1].flags) == 1
+    assert len(groups[2].flags) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("codebase", [121], indirect=True)  # 121 reaches the `codebase` fixture as request.param, i.e. the repo id
+async def test_find_flag_groups_with_group_by_app(executor: SandboxExecutor):  # GroupBy.APP: expects segments a/b/app1, a/b/app2, a/b/app3
+    groups = await executor.find_flag_groups(
+        code_flags=[
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="a/b/app1/test1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="a/b/app2/test1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="a/b/app3/test1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="a/b/app2/test2.py")),
+                message="message",
+            ),
+        ],
+        grouping_config=GroupingConfig(group_by=GroupBy.APP),
+    )
+    count_by_segment = {group.segment: len(group.flags) for group in groups}
+    assert count_by_segment == {"a/b/app1": 1, "a/b/app2": 2, "a/b/app3": 1}
+
+
+@pytest.mark.skip(reason="TODO: add max_prs as part of find_flag_groups")
+@pytest.mark.asyncio
+async def test_find_flag_groups_with_max_prs(executor: SandboxExecutor):  # skipped until find_flag_groups supports max_prs
+    groups = await executor.find_flag_groups(
+        code_flags=[
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir1/file1.py")),
+                message="message",
+            ),
+            CodeFlag(
+                symbol=MagicMock(file=MagicMock(filepath="subdir2/file1.py")),
+                message="message",
+            ),
+        ],
+        grouping_config=GroupingConfig(group_by=GroupBy.FILE, max_prs=0),
+    )
+    assert len(groups) == 0
diff --git a/tests/unit/codegen/runner/sandbox/test_runner.py b/tests/unit/codegen/runner/sandbox/test_runner.py
new file mode 100644
index 000000000..0abd5c557
--- /dev/null
+++ b/tests/unit/codegen/runner/sandbox/test_runner.py
@@ -0,0 +1,127 @@
+from unittest.mock import PropertyMock, patch
+
+import pytest
+
+from codegen.runner.sandbox.runner import SandboxRunner
+
+
+@pytest.mark.asyncio
+@patch("codegen.runner.sandbox.runner.SandboxExecutor")  # same target as the sibling warmup tests: the SandboxExecutor name in the runner module
+async def test_sandbox_runner_warmup_builds_graph(mock_executor, runner: SandboxRunner):  # after warmup the codebase holds the fixture's single file
+    await runner.warmup()
+    assert runner.codebase.files
+    assert len(runner.codebase.files) == 1
+
+
+@pytest.mark.asyncio
+@patch("codegen.runner.sandbox.runner.SandboxRunner._build_graph")
+async def test_sandbox_runner_warmup_builds_graph_throws(mock_build_graph, runner:
SandboxRunner):  # an exception raised by _build_graph is expected to propagate out of warmup()
+    mock_build_graph.side_effect = Exception("Test exception")
+
+    with pytest.raises(Exception):
+        await runner.warmup()
+
+
+@pytest.mark.asyncio
+@patch("codegen.runner.sandbox.runner.logger")
+@patch("codegen.runner.sandbox.runner.SandboxExecutor")
+async def test_sandbox_runner_warmup_logs_repo_id(mock_executor, mock_logger, runner: SandboxRunner):  # warmup logs one info line naming the repo and its id
+    await runner.warmup()
+    assert runner.codebase.files
+    assert len(runner.codebase.files) == 1
+    assert mock_logger.info.call_count == 1
+    assert "Warming runner for test-org/test-repo (ID=1)" in mock_logger.info.call_args_list[0][0][0]
+
+
+@pytest.mark.asyncio
+@patch("codegen.runner.sandbox.runner.SandboxExecutor")
+async def test_sandbox_runner_warmup_starts_with_default_branch(mock_executor, runner: SandboxRunner):  # after warmup HEAD is attached to the default branch at runner.commit
+    await runner.warmup()  # assert True is returned
+    # assert len(runner.codebase._op.git_cli.branches) == 1 TODO: fix GHA creating master and main branch
+    assert not runner.codebase._op.git_cli.head.is_detached
+    assert runner.codebase._op.git_cli.active_branch.name == runner.codebase.default_branch
+    assert runner.codebase._op.git_cli.head.commit == runner.commit
+
+
+@pytest.mark.asyncio
+@patch("codegen.runner.sandbox.runner.logger")
+@patch("codegen.runner.sandbox.runner.SandboxExecutor")
+@patch("codegen.sdk.core.codebase.Codebase.default_branch", new_callable=PropertyMock)
+async def test_sandbox_runner_reset_runner_deletes_branches(mock_branch, mock_executor, mock_logger, runner: SandboxRunner):  # reset_runner leaves a single head: the default branch at runner.commit
+    mock_branch.return_value = "main"
+    await runner.warmup()
+    num_branches = len(runner.codebase._op.git_cli.heads)  # TODO: fix GHA creating master and main branch and assert the len is 1 at the start
+    runner.codebase.checkout(branch="test-branch-a", create_if_missing=True)
+    runner.codebase.checkout(branch="test-branch-b", create_if_missing=True)
+    assert len(runner.codebase._op.git_cli.heads) == num_branches + 2
+    runner.reset_runner()
+    assert len(runner.codebase._op.git_cli.heads)
== 1  # now should be on default branch at self.commit
+    assert runner.codebase._op.git_cli.active_branch.name == runner.codebase.default_branch
+    assert runner.codebase._op.git_cli.head.commit == runner.commit
+
+
+# @pytest.mark.asyncio
+# @patch("codegen.runner.sandbox.executor.get_runner_feature_flags")
+# @patch("codegen.runner.sandbox.executor.SandboxExecutor.execute")
+# async def test_sandbox_runner_run_reset_runner_same_branch_state(
+#     mock_run_execute_flag_groups,
+#     mock_ffs,
+#     runner: SandboxRunner,
+#     db_mock: DBMock,
+#     db_mock_connection: MockConnectionProvider,
+# ):
+#     """Test that the branch post warm-up state and the post reset_runner state is the same"""
+#     mock_ffs.return_value = RunnerFeatureFlags()
+#     mock_run_execute_flag_groups.return_value = CodemodRunResult()
+#     session = db_mock_connection.get_session()
+#     mock_source = """
+# codebase.files[0].edit("a = 2")
+# """
+#
+#     # after warmup sandbox is on default branch at self.commit
+#     await runner.warmup()
+#     assert not runner.codebase._op.git_cli.head.is_detached
+#     assert runner.codebase._op.git_cli.active_branch.name == runner.codebase.default_branch
+#     assert runner.codebase._op.git_cli.head.commit == runner.commit
+#
+#     mock_instances = [*create_mock_codemod_run(create_epic=True, codemod_version_source=mock_source)]
+#     with db_mock.from_orm(mock_instances) as mocked_data:
+#         cm_run = mocked_data[CodemodRunModel][0]
+#         cm_version = mocked_data[CodemodVersionModel][0]
+#         epic = mocked_data[TaskEpicModel][0]
+#         codemod = serialize_mock_cm_run(cm_run, cm_version, epic)
+#         branch_config = BranchConfig(base_branch=runner.codebase.default_branch)
+#         request = CreateBranchRequest(codemod=codemod, grouping_config=GroupingConfig(), branch_config=branch_config)
+#         await runner.create_branch(request=request)
+#
+#     # assert a PR branch was created
+#     assert "codegen-codemod-1-version-1-run-1-group-0" in runner.codebase._op.git_cli.heads
+#
+#     # after running and resetting runner, sandbox is again
on default branch at self.commit
+#     runner.reset_runner()
+#     assert len(runner.codebase._op.git_cli.heads) == 1  # now should be on default branch at self.commit
+#     assert not runner.codebase._op.git_cli.head.is_detached
+#     assert runner.codebase._op.git_cli.active_branch.name == runner.codebase.default_branch
+#     assert runner.codebase._op.git_cli.head.commit == runner.commit
+#
+#
+# @pytest.mark.asyncio
+# @patch("codegen.runner.sandbox.runner.logger")
+# async def test_run_user_code_exception_sets_failure_returns_empty_codemod_run_result(mock_logger, runner: SandboxRunner):
+#     with pytest.raises(InvalidUserCodeException):
+#         mock_syntax_error_source = """
+# = 1
+# """
+#         mock_db = MagicMock()
+#         mock_db.get().repo.language = ProgrammingLanguage.PYTHON
+#         mock_cm_run = MagicMock(
+#             spec=CodemodRunModel, epic=MagicMock(id=1234, link="link", user=MagicMock(github_username="user"), title="test-epic"), codemod_version=MagicMock(id=123, source=mock_syntax_error_source)
+#         )
+#         req = GetDiffRequest(codemod=serialize_cm_run(mock_cm_run))
+#
+#         await runner.get_diff(request=req)
+#
+#     assert mock_logger.error.call_count == 1
+#     assert "InvalidUserCodeException caught compiling codemod version: 123" in mock_logger.error.call_args_list[0][0][0]
+#     assert "SyntaxError" in mock_cm_run.error
+#     assert "invalid syntax" in mock_cm_run.error
diff --git a/tests/unit/codegen/runner/utils/test_branch_name.py b/tests/unit/codegen/runner/utils/test_branch_name.py
new file mode 100644
index 000000000..6b3d807a5
--- /dev/null
+++ b/tests/unit/codegen/runner/utils/test_branch_name.py
@@ -0,0 +1,16 @@
+from unittest.mock import MagicMock
+
+from codegen.runner.utils.branch_name import get_head_branch_name
+
+
+def test_get_head_branch_name_no_group():  # group=None is expected to yield the "-group-0" suffix
+    codemod = MagicMock(epic_id=123, version_id=456, run_id=789)
+    branch_name = get_head_branch_name(codemod=codemod, group=None)
+    assert branch_name == "codegen-codemod-123-version-456-run-789-group-0"
+
+
+def
test_get_head_branch_name_with_group():  # the group's id (2) is expected to appear as the "-group-2" suffix
+    codemod = MagicMock(epic_id=123, version_id=456, run_id=789)
+    group = MagicMock(id=2)
+    branch_name = get_head_branch_name(codemod=codemod, group=group)
+    assert branch_name == "codegen-codemod-123-version-456-run-789-group-2"
diff --git a/tests/unit/codegen/sdk/codebase/code_flag/test_code_flag.py b/tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py
similarity index 100%
rename from tests/unit/codegen/sdk/codebase/code_flag/test_code_flag.py
rename to tests/unit/codegen/sdk/codebase/flagging/test_code_flag.py
diff --git a/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py b/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py
new file mode 100644
index 000000000..179a6ffb2
--- /dev/null
+++ b/tests/unit/codegen/sdk/codebase/flagging/test_group_all.py
@@ -0,0 +1,31 @@
+from unittest.mock import MagicMock
+
+from codegen.sdk.codebase.flagging.code_flag import CodeFlag, MessageType
+from codegen.sdk.codebase.flagging.groupers.all_grouper import AllGrouper
+
+
+def test_group_all():  # AllGrouper.create_all_groups is expected to return one group holding every flag, in input order
+    flag1 = CodeFlag(
+        symbol=MagicMock(
+            file=MagicMock(filepath="test.py"),
+            node_id="12345",
+        ),
+        message="test message",
+        message_type=MessageType.GITHUB,
+        message_recipient="12345",
+    )
+    flag2 = CodeFlag(
+        symbol=MagicMock(
+            file=MagicMock(filepath="test.py"),
+            node_id="12345",
+        ),
+        message="test message",
+        message_type=MessageType.GITHUB,
+        message_recipient="12345",
+    )
+    flags = [flag1, flag2]  # NOTE(review): both flags are built from identical arguments; if CodeFlag compares by value the order asserts below cannot tell them apart - confirm
+    groups = AllGrouper.create_all_groups(flags)
+    assert len(groups) == 1
+    assert len(groups[0].flags) == 2
+    assert groups[0].flags[0] == flag1
+    assert groups[0].flags[1] == flag2
diff --git a/tests/unit/codegen/shared/compilation/test_codeblock_validation.py b/tests/unit/codegen/shared/compilation/test_codeblock_validation.py
new file mode 100644
index 000000000..0f69ad463
--- /dev/null
+++ b/tests/unit/codegen/shared/compilation/test_codeblock_validation.py
@@ -0,0 +1,22 @@
+import pytest
+
+from
codegen.shared.compilation.codeblock_validation import check_for_dangerous_operations
+from codegen.shared.exceptions.compilation import DangerousUserCodeException
+
+
+def test_no_dangerous_operations():  # a plain print() must pass the dangerous-operation check
+    codeblock = """
+print("not dangerous")
+"""
+    try:
+        check_for_dangerous_operations(codeblock)
+    except DangerousUserCodeException:
+        pytest.fail("Unexpected DangerousUserCodeException raised")  # message now names the exception actually caught
+
+
+def test_dangerous_operations():  # reading os.environ must be rejected
+    codeblock = """
+print(os.environ["ENV"])
+"""
+    with pytest.raises(DangerousUserCodeException):
+        check_for_dangerous_operations(codeblock)
diff --git a/tests/unit/codegen/shared/compilation/test_function_compilation.py b/tests/unit/codegen/shared/compilation/test_function_compilation.py
new file mode 100644
index 000000000..da6ccf44a
--- /dev/null
+++ b/tests/unit/codegen/shared/compilation/test_function_compilation.py
@@ -0,0 +1,126 @@
+import pytest
+
+from codegen.shared.compilation.function_compilation import safe_compile_function_string
+from codegen.shared.exceptions.compilation import InvalidUserCodeException
+
+
+def test_valid_func_str_should_not_raise():  # a well-formed execute() definition compiles without error
+    func_str = """
+from codegen.sdk.core.codebase import Codebase
+
+def execute(codebase: Codebase):
+    print(len(codebase.files))
+"""
+    try:
+        safe_compile_function_string(custom_scope={}, func_name="execute", func_str=func_str)
+    except InvalidUserCodeException:
+        pytest.fail("Unexpected InvalidUserCodeException raised")
+
+
+def test_valid_func_str_with_nested_should_not_raise():  # a nested def inside execute() is accepted as well
+    func_str = """
+from codegen.sdk.core.codebase import Codebase
+
+def execute(codebase: Codebase):
+    def nested():
+        return "I'm nested!"
+    print("calling nested")
+    nested()
+"""
+    try:
+        safe_compile_function_string(custom_scope={}, func_name="execute", func_str=func_str)
+    except InvalidUserCodeException:
+        pytest.fail("Unexpected InvalidUserCodeException raised")
+
+
+def test_compile_syntax_error_indent_error_raises():  # the error message should name IndentationError and point at line 3
+    func_str = """
+def execute(codebase: Codebase):
+a = 1
+    print(a)
+"""
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        safe_compile_function_string(custom_scope={}, func_name="execute", func_str=func_str)
+    assert exc_info  # exc_info.value is only populated once the raises block has exited
+    error_msg = str(exc_info.value)
+    assert "IndentationError" in error_msg  # an example of a SyntaxError
+    assert "> 3: a = 1" in error_msg
+
+
+def test_compile_syntax_error_raises():  # the error message should name SyntaxError and point at line 3
+    func_str = """
+def execute(codebase: Codebase):
+    print "syntax error"
+"""
+
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        safe_compile_function_string(custom_scope={}, func_str=func_str, func_name="execute")
+    assert exc_info
+    error_msg = str(exc_info.value)
+    assert "SyntaxError" in error_msg
+    assert '> 3: print "syntax error"' in error_msg
+
+
+def test_compile_non_syntax_error_unicode_error_raises():  # compile-time failures other than SyntaxError are wrapped too
+    func_str = """
+def execute(codebase: Codebase):
+    print("hello")\udcff
+"""
+
+    # U+DCFF is a lone surrogate, so the failure asserted below is a UnicodeEncodeError rather than a SyntaxError.
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        safe_compile_function_string(custom_scope={}, func_str=func_str, func_name="execute")
+    assert exc_info
+    error_msg = str(exc_info.value)
+    assert "UnicodeEncodeError" in error_msg
+    # TODO: why is this missing the error context lines?
+    # TODO: also the error line number is the line in the source code not in the func_str
+    assert "'utf-8' codec can't encode character '\\udcff'" in error_msg
+
+
+def test_exec_error_non_syntax_error_zero_division_raises():
+    """This is to test that we're handling errors (ex: ZeroDivisionError) that are raised during `exec` properly.
+
+    NOTE: this case wouldn't happen with an actual func_str from create_function_str_from_codeblock b/c the func_str would just take in a codebase.
+    """
+    func_str = """
+def execute(codebase: Codebase, exec_error: int = 1/0):
+    print("zero division error")
+"""
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        safe_compile_function_string(custom_scope={}, func_str=func_str, func_name="execute")
+    assert exc_info
+    error_msg = str(exc_info.value)
+    assert "ZeroDivisionError" in error_msg
+    assert "> 2: def execute(codebase: Codebase, exec_error: int = 1/0):" in error_msg
+
+
+def test_exec_error_non_syntax_error_name_error_raises():
+    """This is to test that we're handling errors (ex: NameError) that are raised during `exec` properly.
+
+    NOTE: this case wouldn't happen with an actual func_str from create_function_str_from_codeblock b/c the func_str would not have any patches.
+    """
+    func_str = """
+@patch("foo", return_value="bar")
+def execute(codebase: Codebase):
+    print("zero division error")
+"""
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        safe_compile_function_string(custom_scope={}, func_str=func_str, func_name="execute")
+    assert exc_info
+    error_msg = str(exc_info.value)
+    assert "NameError" in error_msg
+    assert '> 2: @patch("foo", return_value="bar")' in error_msg
+
+
+def test_func_str_uses_custom_scope_var_does_not_raise():
+    """This tests if a func_str references a var that is included in custom scope, it will not raise a NameError.
+    This is to test the case when a group of codemods is run and a later one relies on a local defined in a previous one.
+    """
+    func_str = """
+print(local_a)
+"""
+    try:
+        safe_compile_function_string(custom_scope={"local_a": "this is local_a"}, func_str=func_str, func_name="execute")
+    except InvalidUserCodeException:
+        pytest.fail("Unexpected InvalidUserCodeException raised")  # message now names the exception actually caught
diff --git a/tests/unit/codegen/shared/compilation/test_function_construction.py b/tests/unit/codegen/shared/compilation/test_function_construction.py
new file mode 100644
index 000000000..1ba1fe6c9
--- /dev/null
+++ b/tests/unit/codegen/shared/compilation/test_function_construction.py
@@ -0,0 +1,51 @@
+from unittest.mock import patch
+
+from codegen.shared.compilation.function_construction import create_function_str_from_codeblock
+
+
+def test_no_execute_func_wraps():  # a bare codeblock is wrapped in a generated execute() that rebinds print to codebase.log
+    codeblock = """
+print(len(codebase.files))
+"""
+    func = create_function_str_from_codeblock(codeblock, func_name="execute")
+    assert (
+        """
+def execute(codebase: Codebase, pr_options: PROptions | None = None, pr = None, **kwargs):
+    print = codebase.log
+    print(len(codebase.files))
+"""
+        in func
+    )
+
+
+def test_func_name_already_exists():  # a codeblock that already defines execute() is kept verbatim
+    codeblock = """
+def execute(codebase: Codebase):
+    print(len(codebase.files))
+"""
+    func = create_function_str_from_codeblock(codeblock, func_name="execute")
+    assert codeblock in func
+
+
+def test_func_name_not_execute():  # func_name controls the name of the generated wrapper
+    codeblock = """
+print(len(codebase.files))
+"""
+    func = create_function_str_from_codeblock(codeblock, func_name="not_execute")
+    assert (
+        """
+def not_execute(codebase: Codebase, pr_options: PROptions | None = None, pr = None, **kwargs):
+    print = codebase.log
+    print(len(codebase.files))
+"""
+        in func
+    )
+
+
+def test_function_str_includes_imports():  # whatever get_generated_imports returns must appear in the generated source
+    codeblock = """
+print(len(codebase.files))
+"""
+    with patch("codegen.shared.compilation.function_construction.get_generated_imports", return_value="from foo import bar"):
+        func = create_function_str_from_codeblock(codeblock, func_name="execute")
+    assert "from foo import bar" in func
diff --git
a/tests/unit/codegen/shared/compilation/test_string_to_code.py b/tests/unit/codegen/shared/compilation/test_string_to_code.py
new file mode 100644
index 000000000..7d71ab50e
--- /dev/null
+++ b/tests/unit/codegen/shared/compilation/test_string_to_code.py
@@ -0,0 +1,117 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from codegen.git.models.pr_options import PROptions
+from codegen.shared.compilation.string_to_code import create_execute_function_from_codeblock
+from codegen.shared.exceptions.compilation import DangerousUserCodeException, InvalidUserCodeException
+from codegen.shared.exceptions.control_flow import StopCodemodException
+
+
+def test_syntax_error_raises():  # a codeblock that does not parse is rejected when the function is created
+    codeblock = """
+print "syntax error"
+"""
+    with pytest.raises(InvalidUserCodeException) as exc_info:
+        create_execute_function_from_codeblock(codeblock=codeblock)
+    assert exc_info  # exc_info.value is only populated once the raises block has exited
+    error_msg = str(exc_info.value)
+    assert "SyntaxError" in error_msg
+    assert 'print "syntax error"' in error_msg
+
+
+def test_print_os_environ_raises():  # reading os.environ is rejected when the function is created
+    codeblock = """
+print(os.environ["ENV"])
+"""
+    with pytest.raises(DangerousUserCodeException):
+        create_execute_function_from_codeblock(codeblock=codeblock)
+
+
+def test_print_calls_codebase_log():
+    """Test print is monkey patched to call codebase.log"""
+    codeblock = """
+print("actually codebase.log")
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock)
+    mock_log = MagicMock()
+    func(codebase=MagicMock(log=mock_log), pr_options=PROptions())
+    assert mock_log.call_count == 1
+    assert mock_log.call_args_list[0][0][0] == "actually codebase.log"
+
+
+def test_set_custom_scope_does_not_raise():
+    """Test if the custom scope is set and the codeblock uses a var defined in the scope, it does not raise a NameError."""
+    codeblock = """
+print(local_a)
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock, custom_scope={"local_a": "this is local_a"})
+    mock_log = MagicMock()
+    func(codebase=MagicMock(log=mock_log), pr_options=PROptions())
+    assert mock_log.call_count == 1
+    assert mock_log.call_args_list[0][0][0] == "this is local_a"
+
+
+@patch("codegen.shared.compilation.string_to_code.logger")
+def test_stop_codemod_execution_logs_and_raises(mock_logger):  # StopCodemodException must propagate and be logged at info level
+    codeblock = """
+local_a = "this is local_a"
+raise StopCodemodException("test exception")
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock)
+    with pytest.raises(StopCodemodException):
+        func(codebase=MagicMock(), pr_options=PROptions())
+    assert mock_logger.info.call_count == 2  # was a bare comparison: evaluated and discarded, so never checked
+    assert mock_logger.info.call_args_list[1][0][0] == "Stopping codemod due to StopCodemodException: test exception"
+
+
+def test_references_import_from_generated_imports_does_not_raise():  # names covered by the generated imports resolve inside the codeblock
+    codeblock = """
+print(os.getcwd()) # test external import
+print(MessageType.GITHUB) # test gs private import
+print(Export.__name__) # test gs public import
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock)
+    mock_log = MagicMock()
+    func(codebase=MagicMock(log=mock_log), pr_options=PROptions())
+    assert mock_log.call_count == 3
+
+
+def test_references_import_not_in_generated_imports_raises_runtime_error():  # an unknown name surfaces as RuntimeError carrying the NameError text
+    codeblock = """
+print(Chainable.__name__)
+"""
+    with pytest.raises(RuntimeError) as exc_info:
+        func = create_execute_function_from_codeblock(codeblock=codeblock)
+        func(codebase=MagicMock(), pr_options=PROptions())
+    assert exc_info
+    error_msg = str(exc_info.value)
+    assert "NameError: name 'Chainable' is not defined."
 in error_msg
+    assert "> 1: print(Chainable.__name__)" in error_msg
+
+
+def test_error_during_execution_raises_runtime_error():  # errors raised while the codeblock runs are re-raised as RuntimeError
+    codeblock = """
+print(var_that_does_not_exist)
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock)
+    with pytest.raises(RuntimeError) as exc_info:
+        func(codebase=MagicMock(), pr_options=PROptions())
+    assert exc_info
+    assert exc_info.typename == "RuntimeError"
+    error_msg = str(exc_info.value)
+    assert "NameError: name 'var_that_does_not_exist' is not defined" in error_msg
+    assert "> 1: print(var_that_does_not_exist)" in error_msg
+
+
+@pytest.mark.xfail(reason="TODO(CG-9581): fix codeblocks with return statements")
+def test_return_statement_still_returns_locals():
+    """Test if there is a return statement in a customer code block, the function should still return the locals"""
+    codeblock = """
+local_a = "this is local_a"
+return "this is a return statement"
+"""
+    func = create_execute_function_from_codeblock(codeblock=codeblock)
+    res = func(codebase=MagicMock(), pr_options=PROptions())
+    assert isinstance(res, dict)
+    assert res == {"local_a": "this is local_a"}
diff --git a/uv.lock b/uv.lock
index c74fd17fa..e3c8eb3a7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -452,6 +452,7 @@ dev = [
     { name = "pre-commit" },
     { name = "pre-commit-uv" },
     { name = "pytest" },
+    { name = "pytest-asyncio" },
     { name = "pytest-benchmark", extra = ["histogram"] },
     { name = "pytest-cov" },
     { name = "pytest-mock" },
@@ -545,6 +546,7 @@ dev = [
     { name = "pre-commit", specifier = ">=4.0.1" },
     { name = "pre-commit-uv", specifier = ">=4.1.4" },
     { name = "pytest", specifier = ">=8.3.3" },
+    { name = "pytest-asyncio", specifier = ">=0.23.5" },
     { name = "pytest-benchmark", extras = ["histogram"], specifier = ">=5.1.0" },
     { name = "pytest-cov", specifier = ">=6.0.0,<6.0.1" },
     { name = "pytest-mock", specifier = ">=3.14.0,<4.0.0" },
@@ -2010,6 +2012,18 @@ wheels = [
     { url =
"https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/df/adcc0d60f1053d74717d21d58c0048479e9cab51464ce0d2965b086bd0e2/pytest_asyncio-0.25.2.tar.gz", hash = "sha256:3f8ef9a98f45948ea91a0ed3dc4268b5326c0e7bce73892acc654df4262ad45f", size = 53950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/d8/defa05ae50dcd6019a95527200d3b3980043df5aa445d40cb0ef9f7f98ab/pytest_asyncio-0.25.2-py3-none-any.whl", hash = "sha256:0d0bb693f7b99da304a0634afc0a4b19e49d5e0de2d670f38dc4bfa5727c5075", size = 19400 }, +] + [[package]] name = "pytest-benchmark" version = "5.1.0"