From 8d79694bc2c497537ea779f353e5789a29f11c9a Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 11:25:36 -0700 Subject: [PATCH 01/10] Add allow_external flag --- src/codegen/configs/models/codebase.py | 1 + src/codegen/sdk/codebase/codebase_context.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/codegen/configs/models/codebase.py b/src/codegen/configs/models/codebase.py index b8484e63b..8846d245a 100644 --- a/src/codegen/configs/models/codebase.py +++ b/src/codegen/configs/models/codebase.py @@ -21,6 +21,7 @@ def __init__(self, prefix: str = "CODEBASE", *args, **kwargs) -> None: import_resolution_paths: list[str] = Field(default_factory=lambda: []) import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {}) py_resolve_syspath: bool = False + allow_external: bool = False ts_dependency_manager: bool = False ts_language_engine: bool = False v8_ts_engine: bool = False diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py index 8840adce5..92d441e2f 100644 --- a/src/codegen/sdk/codebase/codebase_context.py +++ b/src/codegen/sdk/codebase/codebase_context.py @@ -381,7 +381,7 @@ def get_directory(self, directory_path: PathLike, create_on_missing: bool = Fals """ # If not part of repo path, return None absolute_path = self.to_absolute(directory_path) - if not self.is_subdir(absolute_path): + if not self.is_subdir(absolute_path) and not self.config.allow_external: assert False, f"Directory {absolute_path} is not part of repo path {self.repo_path}" return None @@ -611,7 +611,7 @@ def get_edges(self) -> list[tuple[NodeId, NodeId, EdgeType, Usage | None]]: def get_file(self, file_path: os.PathLike, ignore_case: bool = False) -> SourceFile | None: # If not part of repo path, return None absolute_path = self.to_absolute(file_path) - if not self.is_subdir(absolute_path): + if not self.is_subdir(absolute_path) and not self.config.allow_external: assert False, f"File {file_path} is not part of the repository path" # Check if file exists in graph From 246624f7aecb72374e15d211f2c9f0694f01d747 Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 11:25:44 -0700 Subject: [PATCH 02/10] Add docs for allow_external --- docs/introduction/advanced-settings.mdx | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/introduction/advanced-settings.mdx b/docs/introduction/advanced-settings.mdx index e7297bd38..629ec1a0a 100644 --- a/docs/introduction/advanced-settings.mdx +++ b/docs/introduction/advanced-settings.mdx @@ -309,6 +309,19 @@ Controls import path overrides during import resolution. Enables and disables resolution of imports from `sys.path`. + +For this to properly work, you may also need to enable `allow_external`. + + +## Flag: `allow_external` +> **Default: `False`** + +Enables resolving imports, files, modules, and directories from outside of the repo path. + + +Turning this flag off may allow for bad actors to access files outside of the repo path! Use with caution! + + ## Flag: `ts_dependency_manager` > **Default: `False`** From 779a7ad9b7396fa108b298f642f4063b1de0e428 Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 11:25:58 -0700 Subject: [PATCH 03/10] Fix tests (cwd fix) --- src/codegen/sdk/utils.py | 27 ++++ .../test_import_resolution.py | 147 ++++++++++-------- 2 files changed, 107 insertions(+), 67 deletions(-) diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py index 7476e6e8a..7a989aac6 100644 --- a/src/codegen/sdk/utils.py +++ b/src/codegen/sdk/utils.py @@ -339,3 +339,30 @@ def is_minified_js(content): except Exception as e: print(f"Error analyzing content: {e}") return False + + +@contextmanager +def use_cwd(path): + """Context manager that temporarily changes the current working directory. + + Args: + path (str): The directory path to change to. + + Yields: + str: The new current working directory. + + Example: + ```python + with use_cwd('/path/to/directory'): + # Code here runs with the working directory set to '/path/to/directory' + ... + # Working directory is restored to the original + ``` + """ + old_cwd = os.getcwd() + try: + os.chdir(path) + yield path + finally: + os.chdir(old_cwd) + diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index ebbe5d724..65d6e94fa 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -2,6 +2,7 @@ from codegen.sdk.codebase.config import TestFlags from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.sdk.utils import use_cwd if TYPE_CHECKING: from codegen.sdk.core.file import SourceFile @@ -272,27 +273,31 @@ def func(): """, }, ) as codebase: - src_file: SourceFile = codebase.get_file("a/b/c/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") + # Wrap CWD since we are using py_resolve_syspath + with use_cwd(tmpdir): + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") - # Enable resolution via sys.path - codebase.ctx.config.py_resolve_syspath = True + # Enable resolution via sys.path + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True - # =====[ Imports cannot be found without sys.path being set ]===== - assert len(consumer_file.imports) == 1 - src_import: Import = consumer_file.imports[0] - src_import_resolution: ImportResolution = src_import.resolve_import() - assert src_import_resolution is None + # =====[ Imports cannot be found without sys.path being set ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution: ImportResolution = src_import.resolve_import() + assert src_import_resolution is None - # Modify sys.path for this test only - monkeypatch.syspath_prepend("a") + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") - # =====[ Imports can be found with sys.path set and active ]===== - codebase.ctx.config.py_resolve_syspath = True - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is True + # =====[ Imports can be found with sys.path set and active ]===== + codebase.ctx.config.py_resolve_syspath = True + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True def test_import_resolution_file_custom_resolve_path(tmpdir: str) -> None: @@ -366,28 +371,32 @@ def func(): """, }, ) as codebase: - src_file: SourceFile = codebase.get_file("a/b/c/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") - - # Ensure we don't have overrites and enable syspath resolution - codebase.ctx.config.import_resolution_paths = [] - codebase.ctx.config.py_resolve_syspath = True - - # =====[ Import with sys.path set can be found ]===== - assert len(consumer_file.imports) == 1 - # Modify sys.path for this test only - monkeypatch.syspath_prepend("a") - src_import: Import = consumer_file.imports[0] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file.file_path == "a/c/src.py" - - # =====[ Imports can be found with custom resolve over sys.path ]===== - codebase.ctx.config.import_resolution_paths = ["a/b"] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is True + # Wrap CWD since we are using py_resolve_syspath + with use_cwd(tmpdir): + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.import_resolution_paths = [] + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True + + # =====[ Import with sys.path set can be found ]===== + assert len(consumer_file.imports) == 1 + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file.file_path == "a/c/src.py" + + # =====[ Imports can be found with custom resolve over sys.path ]===== + codebase.ctx.config.import_resolution_paths = ["a/b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True def test_import_resolution_default_conflicts_overrite(tmpdir: str, monkeypatch) -> None: @@ -412,34 +421,38 @@ def func(): """, }, ) as codebase: - src_file: SourceFile = codebase.get_file("a/src.py") - src_file_overrite: SourceFile = codebase.get_file("b/a/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") - - # Ensure we don't have overrites and enable syspath resolution - codebase.ctx.config.import_resolution_paths = [] - codebase.ctx.config.py_resolve_syspath = True - - # =====[ Default import works ]===== - assert len(consumer_file.imports) == 1 - src_import: Import = consumer_file.imports[0] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - - # =====[ Sys.path overrite has precedence ]===== - monkeypatch.syspath_prepend("b") - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is not src_file - assert src_import_resolution.from_file is src_file_overrite - - # =====[ Custom overrite has precedence ]===== - codebase.ctx.config.import_resolution_paths = ["b"] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is not src_file - assert src_import_resolution.from_file is src_file_overrite + # Wrap CWD since we are using py_resolve_syspath + with use_cwd(tmpdir): + src_file: SourceFile = codebase.get_file("a/src.py") + src_file_overrite: SourceFile = codebase.get_file("b/a/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.import_resolution_paths = [] + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True + + # =====[ Default import works ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + + # =====[ Sys.path overrite has precedence ]===== + monkeypatch.syspath_prepend("b") + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite + + # =====[ Custom overrite has precedence ]===== + codebase.ctx.config.import_resolution_paths = ["b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite def test_import_resolution_init_wildcard(tmpdir: str) -> None: From 350c4f2948fc32ea36f52150e9fc7b982565976d Mon Sep 17 00:00:00 2001 From: EdwardJXLi <20020059+EdwardJXLi@users.noreply.github.com> Date: Wed, 12 Mar 2025 18:27:25 +0000 Subject: [PATCH 04/10] Automated pre-commit update --- src/codegen/sdk/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py index 7a989aac6..fd9b8ea32 100644 --- a/src/codegen/sdk/utils.py +++ b/src/codegen/sdk/utils.py @@ -353,7 +353,7 @@ def use_cwd(path): Example: ```python - with use_cwd('/path/to/directory'): + with use_cwd("/path/to/directory"): # Code here runs with the working directory set to '/path/to/directory' ... # Working directory is restored to the original @@ -365,4 +365,3 @@ def use_cwd(path): yield path finally: os.chdir(old_cwd) - From 1a94c723afe528e0415855e6bbdd49a77ebdd80c Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 11:30:36 -0700 Subject: [PATCH 05/10] Upgrade allow_external warning to error --- docs/introduction/advanced-settings.mdx | 6 +++--- src/codegen/sdk/core/codebase.py | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/introduction/advanced-settings.mdx b/docs/introduction/advanced-settings.mdx index 629ec1a0a..08f0e5472 100644 --- a/docs/introduction/advanced-settings.mdx +++ b/docs/introduction/advanced-settings.mdx @@ -309,9 +309,9 @@ Controls import path overrides during import resolution. Enables and disables resolution of imports from `sys.path`. - -For this to properly work, you may also need to enable `allow_external`. - + +For this to properly work, you must also set `allow_external` to `True`. + ## Flag: `allow_external` > **Default: `False`** diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 34d899441..f5b853a6e 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -213,6 +213,13 @@ def __init__( self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress) self.console = Console(record=True, soft_wrap=True) + # Assert config assertions + # External import resolution must be enabled if syspath is enabled + if self.ctx.config.py_resolve_syspath: + if not self.ctx.config.allow_external: + msg = "allow_external must be set to True when py_resolve_syspath is enabled" + raise ValueError(msg) + @noapidoc def __str__(self) -> str: return f"" From edd4fc6a66f2d59cd104014c745783aa377a595a Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 13:20:15 -0700 Subject: [PATCH 06/10] Remove use_cwd --- src/codegen/sdk/utils.py | 26 --- tests/unit/codegen/conftest.py | 2 + .../test_import_resolution.py | 153 +++++++++--------- 3 files changed, 75 insertions(+), 106 deletions(-) diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py index fd9b8ea32..7476e6e8a 100644 --- a/src/codegen/sdk/utils.py +++ b/src/codegen/sdk/utils.py @@ -339,29 +339,3 @@ def is_minified_js(content): except Exception as e: print(f"Error analyzing content: {e}") return False - - -@contextmanager -def use_cwd(path): - """Context manager that temporarily changes the current working directory. - - Args: - path (str): The directory path to change to. - - Yields: - str: The new current working directory. - - Example: - ```python - with use_cwd("/path/to/directory"): - # Code here runs with the working directory set to '/path/to/directory' - ... - # Working directory is restored to the original - ``` - """ - old_cwd = os.getcwd() - try: - os.chdir(path) - yield path - finally: - os.chdir(old_cwd) diff --git a/tests/unit/codegen/conftest.py b/tests/unit/codegen/conftest.py index f27c3a12a..8b5bf53eb 100644 --- a/tests/unit/codegen/conftest.py +++ b/tests/unit/codegen/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest from codegen.sdk.codebase.factory.get_session import get_codebase_session diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 65d6e94fa..085b285a9 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -2,7 +2,6 @@ from codegen.sdk.codebase.config import TestFlags from codegen.sdk.codebase.factory.get_session import get_codebase_session -from codegen.sdk.utils import use_cwd if TYPE_CHECKING: from codegen.sdk.core.file import SourceFile @@ -273,31 +272,29 @@ def func(): """, }, ) as codebase: - # Wrap CWD since we are using py_resolve_syspath - with use_cwd(tmpdir): - src_file: SourceFile = codebase.get_file("a/b/c/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") - # Enable resolution via sys.path - codebase.ctx.config.py_resolve_syspath = True - # Allow resolving files and modules outside of the repo path - codebase.ctx.config.allow_external = True + # Enable resolution via sys.path + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True - # =====[ Imports cannot be found without sys.path being set ]===== - assert len(consumer_file.imports) == 1 - src_import: Import = consumer_file.imports[0] - src_import_resolution: ImportResolution = src_import.resolve_import() - assert src_import_resolution is None + # =====[ Imports cannot be found without sys.path being set ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution: ImportResolution = src_import.resolve_import() + assert src_import_resolution is None - # Modify sys.path for this test only - monkeypatch.syspath_prepend("a") + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") - # =====[ Imports can be found with sys.path set and active ]===== - codebase.ctx.config.py_resolve_syspath = True - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is True + # =====[ Imports can be found with sys.path set and active ]===== + codebase.ctx.config.py_resolve_syspath = True + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True def test_import_resolution_file_custom_resolve_path(tmpdir: str) -> None: @@ -371,32 +368,30 @@ def func(): """, }, ) as codebase: - # Wrap CWD since we are using py_resolve_syspath - with use_cwd(tmpdir): - src_file: SourceFile = codebase.get_file("a/b/c/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") - - # Ensure we don't have overrites and enable syspath resolution - codebase.ctx.config.import_resolution_paths = [] - codebase.ctx.config.py_resolve_syspath = True - # Allow resolving files and modules outside of the repo path - codebase.ctx.config.allow_external = True - - # =====[ Import with sys.path set can be found ]===== - assert len(consumer_file.imports) == 1 - # Modify sys.path for this test only - monkeypatch.syspath_prepend("a") - src_import: Import = consumer_file.imports[0] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file.file_path == "a/c/src.py" - - # =====[ Imports can be found with custom resolve over sys.path ]===== - codebase.ctx.config.import_resolution_paths = ["a/b"] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is True + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.import_resolution_paths = [] + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True + + # =====[ Import with sys.path set can be found ]===== + assert len(consumer_file.imports) == 1 + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file.file_path == "a/c/src.py" + + # =====[ Imports can be found with custom resolve over sys.path ]===== + codebase.ctx.config.import_resolution_paths = ["a/b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True def test_import_resolution_default_conflicts_overrite(tmpdir: str, monkeypatch) -> None: @@ -421,38 +416,36 @@ def func(): """, }, ) as codebase: - # Wrap CWD since we are using py_resolve_syspath - with use_cwd(tmpdir): - src_file: SourceFile = codebase.get_file("a/src.py") - src_file_overrite: SourceFile = codebase.get_file("b/a/src.py") - consumer_file: SourceFile = codebase.get_file("consumer.py") - - # Ensure we don't have overrites and enable syspath resolution - codebase.ctx.config.import_resolution_paths = [] - codebase.ctx.config.py_resolve_syspath = True - # Allow resolving files and modules outside of the repo path - codebase.ctx.config.allow_external = True - - # =====[ Default import works ]===== - assert len(consumer_file.imports) == 1 - src_import: Import = consumer_file.imports[0] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - - # =====[ Sys.path overrite has precedence ]===== - monkeypatch.syspath_prepend("b") - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is not src_file - assert src_import_resolution.from_file is src_file_overrite - - # =====[ Custom overrite has precedence ]===== - codebase.ctx.config.import_resolution_paths = ["b"] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is not src_file - assert src_import_resolution.from_file is src_file_overrite + src_file: SourceFile = codebase.get_file("a/src.py") + src_file_overrite: SourceFile = codebase.get_file("b/a/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.import_resolution_paths = [] + codebase.ctx.config.py_resolve_syspath = True + # Allow resolving files and modules outside of the repo path + codebase.ctx.config.allow_external = True + + # =====[ Default import works ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + + # =====[ Sys.path overrite has precedence ]===== + monkeypatch.syspath_prepend("b") + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite + + # =====[ Custom overrite has precedence ]===== + codebase.ctx.config.import_resolution_paths = ["b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite def test_import_resolution_init_wildcard(tmpdir: str) -> None: From c4c83a3c283f52e07a13d2cc619aed5993db9a5c Mon Sep 17 00:00:00 2001 From: EdwardJXLi <20020059+EdwardJXLi@users.noreply.github.com> Date: Wed, 12 Mar 2025 20:24:36 +0000 Subject: [PATCH 07/10] Automated pre-commit update --- tests/unit/codegen/conftest.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/codegen/conftest.py b/tests/unit/codegen/conftest.py index 8b5bf53eb..f27c3a12a 100644 --- a/tests/unit/codegen/conftest.py +++ b/tests/unit/codegen/conftest.py @@ -1,5 +1,3 @@ -import os - import pytest from codegen.sdk.codebase.factory.get_session import get_codebase_session From e756e1571789bb3e908c7bd6ae33952f170952db Mon Sep 17 00:00:00 2001 From: Edo Pujol Date: Wed, 12 Mar 2025 13:17:16 -0400 Subject: [PATCH 08/10] misc: better-error (#805) # Motivation # Content # Testing # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: kopekC <28070492+kopekC@users.noreply.github.com> --- src/codegen/extensions/langchain/graph.py | 243 +++++++++++++++++++- src/codegen/extensions/langchain/tools.py | 33 ++- src/codegen/extensions/tools/create_file.py | 4 +- 3 files changed, 272 insertions(+), 8 deletions(-) diff --git a/src/codegen/extensions/langchain/graph.py b/src/codegen/extensions/langchain/graph.py index e5116630f..3da422560 100644 --- a/src/codegen/extensions/langchain/graph.py +++ b/src/codegen/extensions/langchain/graph.py @@ -71,9 +71,250 @@ def create(self, checkpointer: Optional[MemorySaver] = None, debug: bool = False jitter=True, ) + # Custom error handler for tool validation errors + def handle_tool_errors(exception): + error_msg = str(exception) + + # Extract tool name and input from the exception if possible + tool_name = "unknown" + tool_input = {} + + # Helper function to get field descriptions from any tool + def get_field_descriptions(tool_obj): + field_descriptions = {} + if not tool_obj or not hasattr(tool_obj, "args_schema"): + return field_descriptions + + try: + schema_cls = tool_obj.args_schema + + # Handle Pydantic v2 + if hasattr(schema_cls, "model_fields"): + for field_name, field in schema_cls.model_fields.items(): + field_descriptions[field_name] = field.description or f"Required parameter for {tool_obj.name}" + + # Handle Pydantic v1 with warning suppression + elif hasattr(schema_cls, "__fields__"): + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + for field_name, field in schema_cls.__fields__.items(): + field_descriptions[field_name] = field.field_info.description or f"Required parameter for {tool_obj.name}" + except Exception: + pass + + return field_descriptions + + # Try to extract tool name and input from the exception + import re + + tool_match = re.search(r"for (\w+)Input", error_msg) + if tool_match: + # Get the extracted name but preserve original case by finding the matching tool + extracted_name = tool_match.group(1).lower() + for t in self.tools: + if t.name.lower() == extracted_name: + tool_name = t.name # Use the original case from the tool + break + + # Try to extract the input values + input_match = re.search(r"input_value=(\{.*?\})", error_msg) + if input_match: + input_str = input_match.group(1) + # Simple parsing of the dict-like string + try: + # Clean up the string to make it more parseable + input_str = input_str.replace("'", '"') + import json + + tool_input = json.loads(input_str) + except: + pass + + # Handle validation errors with more helpful messages + if "validation error" in error_msg.lower(): + # Find the tool in our tools list to get its schema + tool = next((t for t in self.tools if t.name == tool_name), None) + + # If we couldn't find the tool by extracted name, try to find it by looking at all tools + if tool is None: + # Try to extract tool name from the error message + for t in self.tools: + if t.name.lower() in error_msg.lower(): + tool = t + tool_name = t.name + break + + # If still not found, check if any tool's schema name matches + if tool is None: + for t in self.tools: + if hasattr(t, "args_schema") and t.args_schema.__name__.lower() in error_msg.lower(): + tool = t + tool_name = t.name + break + + # Check for type errors + type_errors = [] + if "type_error" in error_msg.lower(): + import re + + # Try to extract type error information + type_error_matches = re.findall(r"'(\w+)'.*?type_error\.(.*?)(?:;|$)", error_msg, re.IGNORECASE) + for field_name, error_type in type_error_matches: + if "json" in error_type: + type_errors.append(f"'{field_name}' must be a string, not a JSON object or dictionary") + elif "str_type" in error_type: + type_errors.append(f"'{field_name}' must be a string") + elif "int_type" in error_type: + type_errors.append(f"'{field_name}' must be an integer") + elif "bool_type" in error_type: + type_errors.append(f"'{field_name}' must be a boolean") + elif "list_type" in error_type: + type_errors.append(f"'{field_name}' must be a list") + else: + type_errors.append(f"'{field_name}' has an incorrect type") + + if type_errors: + errors_str = "\n- ".join(type_errors) + return f"Error using {tool_name} tool: Parameter type errors:\n- {errors_str}\n\nYou provided: {tool_input}\n\nPlease try again with the correct parameter types." + + # Get missing fields by comparing tool input with required fields + missing_fields = [] + if tool and hasattr(tool, "args_schema"): + try: + # Get the schema class + schema_cls = tool.args_schema + + # Handle Pydantic v2 (preferred) or v1 with warning suppression + if hasattr(schema_cls, "model_fields"): # Pydantic v2 + for field_name, field in schema_cls.model_fields.items(): + # Check if field is required and missing from input + if field.is_required() and field_name not in tool_input: + missing_fields.append(field_name) + else: # Pydantic v1 with warning suppression + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + for field_name, field in schema_cls.__fields__.items(): + # Check if field is required and missing from input + if field.required and field_name not in tool_input: + missing_fields.append(field_name) + except Exception as e: + # If we can't extract schema info, we'll fall back to regex + pass + + # If we couldn't get missing fields from schema, try to extract from error message + if not missing_fields: + # Extract the missing field name if possible using regex + import re + + field_matches = re.findall(r"'(\w+)'(?:\s+|.*?)field required", error_msg, re.IGNORECASE) + if field_matches: + missing_fields = field_matches + else: + # Try another pattern + field_match = re.search(r"(\w+)\s+Field required", error_msg) + if field_match: + missing_fields = [field_match.group(1)] + + # If we have identified missing fields, create a helpful error message + if missing_fields: + fields_str = ", ".join([f"'{f}'" for f in missing_fields]) + + # Get tool documentation if available + tool_docs = "" + if tool: + if hasattr(tool, "description") and tool.description: + tool_docs = f"\nTool description: {tool.description}\n" + + # Try to get parameter descriptions from the schema + param_docs = [] + try: + # Get all field descriptions from the tool + field_descriptions = get_field_descriptions(tool) + + # Add descriptions for missing fields + for field_name in missing_fields: + if field_name in field_descriptions: + param_docs.append(f"- {field_name}: {field_descriptions[field_name]}") + else: + param_docs.append(f"- {field_name}: Required parameter") + + if param_docs: + tool_docs += "\nParameter descriptions:\n" + "\n".join(param_docs) + except Exception: + # Fallback to simple parameter list + param_docs = [f"- {field}: Required parameter" for field in missing_fields] + if param_docs: + tool_docs += "\nMissing parameters:\n" + "\n".join(param_docs) + + # Add usage examples for common tools + example = "" + if tool_name == "create_file": + example = "\nExample: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + elif tool_name == "replace_edit": + example = "\nExample: replace_edit(filepath='path/to/file.py', old_text='def old_function()', new_text='def new_function()')" + elif tool_name == "view_file": + example = "\nExample: view_file(filepath='path/to/file.py')" + elif tool_name == "search": + example = "\nExample: search(query='function_name', file_extensions=['.py'])" + + return ( + f"Error using {tool_name} tool: Missing required parameter(s): {fields_str}\n\nYou provided: {tool_input}\n{tool_docs}{example}\nPlease try again with all required parameters." + ) + + # Common error patterns for specific tools (as fallback) + if tool_name == "create_file": + if "content" not in tool_input: + return ( + "Error: When using the create_file tool, you must provide both 'filepath' and 'content' parameters.\n" + "The 'content' parameter is missing. Please try again with both parameters.\n\n" + "Example: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + ) + elif "filepath" not in tool_input: + return ( + "Error: When using the create_file tool, you must provide both 'filepath' and 'content' parameters.\n" + "The 'filepath' parameter is missing. Please try again with both parameters.\n\n" + "Example: create_file(filepath='path/to/file.py', content='print(\"Hello world\")')" + ) + + elif tool_name == "replace_edit": + if "filepath" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'filepath' parameter is missing. Please try again with all required parameters." + ) + elif "old_text" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'old_text' parameter is missing. Please try again with all required parameters." + ) + elif "new_text" not in tool_input: + return ( + "Error: When using the replace_edit tool, you must provide 'filepath', 'old_text', and 'new_text' parameters.\n" + "The 'new_text' parameter is missing. Please try again with all required parameters." + ) + + # Generic validation error with better formatting + if tool: + return ( + f"Error using {tool_name} tool: {error_msg}\n\n" + f"You provided these parameters: {tool_input}\n\n" + f"Please check the tool's required parameters and try again with all required fields." + ) + else: + # If we couldn't identify the tool, list all available tools + available_tools = "\n".join([f"- {t.name}" for t in self.tools]) + return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters." + + # For other types of errors + return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters." + # Add nodes builder.add_node("reasoner", self.reasoner, retry=retry_policy) - builder.add_node("tools", ToolNode(self.tools), retry=retry_policy) + builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy) # Add edges builder.add_edge(START, "reasoner") diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 0ce6b97a1..877b59f05 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -136,7 +136,7 @@ class SearchTool(BaseTool): def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, query: str, target_directories: Optional[list[str]] = None, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: + def _run(self, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex) return result.render() @@ -171,7 +171,6 @@ class EditFileTool(BaseTool): 1. Simple text: "function calculateTotal" (matches exactly, case-insensitive) 2. Regex: "def.*calculate.*\(.*\)" (with use_regex=True) 3. File-specific: "TODO" with file_extensions=[".py", ".ts"] - 4. Directory-specific: "api" with target_directories=["src/backend"] """ args_schema: ClassVar[type[BaseModel]] = EditFileInput codebase: Codebase = Field(exclude=True) @@ -188,21 +187,45 @@ class CreateFileInput(BaseModel): """Input for creating a file.""" filepath: str = Field(..., description="Path where to create the file") - content: str = Field(default="", description="Initial file content") + content: str = Field( + ..., + description=""" +Content for the new file (REQUIRED). + +⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type. +Example: content="print('Hello world')" +NOT: content={"code": "print('Hello world')"} + """, + ) class CreateFileTool(BaseTool): """Tool for creating files.""" name: ClassVar[str] = "create_file" - description: ClassVar[str] = "Create a new file in the codebase" + description: ClassVar[str] = """ +Create a new file in the codebase. Always provide content for the new file, even if minimal. + +⚠️ CRITICAL WARNING ⚠️ +Both parameters MUST be provided as STRINGS: +The content for the new file always needs to be provided. + +1. filepath: The path where to create the file (as a string) +2. content: The content for the new file (as a STRING, NOT as a dictionary or JSON object) + +✅ CORRECT usage: +create_file(filepath="path/to/file.py", content="print('Hello world')") + +The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about +missing content, you are likely trying to pass a dictionary instead of a string. +""" args_schema: ClassVar[type[BaseModel]] = CreateFileInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, filepath: str, content: str = "") -> str: + def _run(self, filepath: str, content: str) -> str: result = create_file(self.codebase, filepath, content) return result.render() diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index b10d01f52..3a54303ff 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -23,13 +23,13 @@ class CreateFileObservation(Observation): str_template: ClassVar[str] = "Created file {filepath}" -def create_file(codebase: Codebase, filepath: str, content: str = "") -> CreateFileObservation: +def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation: """Create a new file. Args: codebase: The codebase to operate on filepath: Path where to create the file - content: Initial file content + content: Content for the new file (required) Returns: CreateFileObservation containing new file state, or error if file exists From ee3de3b2e6a903e8e3a1d87fa31b1b2d719ae878 Mon Sep 17 00:00:00 2001 From: tomcodgen Date: Wed, 12 Mar 2025 19:29:47 +0100 Subject: [PATCH 09/10] [CG-10935] fix: issues with assigment (#737) # Motivation # Content # Testing # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Co-authored-by: tomcodegen --- src/codegen/sdk/core/expressions/name.py | 29 +++- src/codegen/sdk/core/file.py | 42 +++++- src/codegen/sdk/core/function.py | 7 +- .../sdk/core/interfaces/conditional_block.py | 17 +++ src/codegen/sdk/core/interfaces/editable.py | 7 +- .../sdk/core/statements/catch_statement.py | 3 +- .../sdk/core/statements/for_loop_statement.py | 18 ++- .../sdk/core/statements/if_block_statement.py | 28 +++- .../sdk/core/statements/switch_case.py | 7 +- .../core/statements/try_catch_statement.py | 3 +- src/codegen/sdk/python/function.py | 12 +- src/codegen/sdk/python/import_resolution.py | 6 +- .../sdk/python/statements/catch_statement.py | 5 + .../python/statements/if_block_statement.py | 1 - .../sdk/python/statements/match_case.py | 5 + .../sdk/python/statements/match_statement.py | 2 +- .../python/statements/try_catch_statement.py | 14 ++ src/codegen/sdk/typescript/export.py | 4 +- src/codegen/sdk/typescript/function.py | 12 +- .../typescript/statements/catch_statement.py | 6 +- .../typescript/statements/switch_statement.py | 2 +- .../statements/try_catch_statement.py | 17 +++ .../test_if_block_statement_properties.py | 140 ++++++++++++++++++ .../test_try_catch_statement.py | 20 +++ .../test_match_statement.py | 24 +++ 25 files changed, 390 insertions(+), 41 deletions(-) create mode 100644 src/codegen/sdk/core/interfaces/conditional_block.py diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 3ee1b6411..df5ef6872 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -5,13 +5,15 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.expression import Expression +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.resolvable import Resolvable from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from codegen.sdk.core.import_resolution import Import, WildcardImport from codegen.sdk.core.interfaces.has_name import HasName - + from codegen.sdk.core.symbol import Symbol Parent = TypeVar("Parent", bound="Expression") @@ -29,10 +31,9 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): @override def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: """Resolve the types used by this symbol.""" - if used := self.resolve_name(self.source, self.start_byte): + for used in self.resolve_name(self.source, self.start_byte): yield from self.with_resolution_frame(used) - @noapidoc @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: """Compute the dependencies of the export object.""" @@ -48,3 +49,25 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | def rename_if_matching(self, old: str, new: str): if self.source == old: self.edit(new) + + @noapidoc + @reader + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: + resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None) + if resolved_name: + yield resolved_name + else: + return + + if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): + top_of_conditional = conditional_parent.start_byte + if self.parent_of_type(ConditionalBlock) == conditional_parent: + # Use in the same block, should only depend on the inside of the block + return + for other_conditional in conditional_parent.other_possible_blocks: + if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): + if cond_name.start_byte >= other_conditional.start_byte: + yield cond_name + top_of_conditional = min(top_of_conditional, other_conditional.start_byte) + + yield from self.resolve_name(name, top_of_conditional, strict=False) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 8ad9e1385..12bcab303 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -3,7 +3,7 @@ import resource import sys from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Generator, Sequence from functools import cached_property from os import PathLike from pathlib import Path @@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None: Returns: Symbol | None: The found symbol, or None if not found. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Symbol): return symbol return next((x for x in self.symbols if x.name == name), None) @@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None: Returns: TClass | None: The matching Class object if found, None otherwise. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Class): return symbol @@ -880,13 +880,41 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: + """Resolves a name to a symbol, import, or wildcard import within the file's scope. + + Performs name resolution by first checking the file's valid symbols and imports. When a start_byte + is provided, ensures proper scope handling by only resolving to symbols that are defined before + that position in the file. + + Args: + name (str): The name to resolve. + start_byte (int | None): If provided, only resolves to symbols defined before this byte position + in the file. Used for proper scope handling. Defaults to None. + strict (bool): When True and using start_byte, only yields symbols if found in the correct scope. + When False, allows falling back to global scope. Defaults to True. + + Yields: + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches + the name and scope requirements. Yields at most one result. + """ if resolved := self.valid_symbol_names.get(name): + # If we have a start_byte and the resolved symbol is after it, + # we need to look for earlier definitions of the symbol if start_byte is not None and resolved.end_byte > start_byte: - for symbol in self.symbols: + # Search backwards through symbols to find the most recent definition + # that comes before our start_byte position + for symbol in reversed(self.symbols): if symbol.start_byte <= start_byte and symbol.name == name: - return symbol - return resolved + yield symbol + return + # If strict mode and no valid symbol found, return nothing + if not strict: + return + # Either no start_byte constraint or symbol is before start_byte + yield resolved + return + return @property @reader diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index 21aa3c1df..408c15a84 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -141,13 +141,14 @@ def is_async(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: from codegen.sdk.core.class_definition import Class for symbol in self.valid_symbol_names: if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte): - return symbol - return super().resolve_name(name, start_byte) + yield symbol + return + yield from super().resolve_name(name, start_byte) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py new file mode 100644 index 000000000..a11990908 --- /dev/null +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from codegen.sdk.core.statements.statement import Statement + + +class ConditionalBlock(Statement, ABC): + """An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect.""" + + @property + @abstractmethod + def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: + """Should return all other "branches" that might be executed instead.""" + + @property + def end_byte_for_condition_block(self) -> int: + return self.end_byte diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index f59037144..22ae37f51 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1003,10 +1003,11 @@ def viz(self) -> VizNode: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.parent is not None: - return self.parent.resolve_name(name, start_byte or self.start_byte) - return self.file.resolve_name(name, start_byte or self.start_byte) + yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict) + else: + yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/statements/catch_statement.py b/src/codegen/sdk/core/statements/catch_statement.py index e9e96fa09..6d7b36071 100644 --- a/src/codegen/sdk/core/statements/catch_statement.py +++ b/src/codegen/sdk/core/statements/catch_statement.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -17,7 +18,7 @@ @apidoc -class CatchStatement(BlockStatement[Parent], Generic[Parent]): +class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation catch clause. Attributes: diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index e6c6bc4b4..d884a52d0 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -12,6 +12,8 @@ from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Generator + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.expressions import Expression from codegen.sdk.core.import_resolution import Import, WildcardImport @@ -36,19 +38,23 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.item and isinstance(self.iterable, Chainable): if start_byte is None or start_byte > self.iterable.end_byte: if name == self.item: for frame in self.iterable.resolved_type_frames: if frame.generics: - return next(iter(frame.generics.values())) - return frame.top.node + yield next(iter(frame.generics.values())) + return + yield frame.top.node + return elif isinstance(self.item, Collection): for idx, item in enumerate(self.item): if item == name: for frame in self.iterable.resolved_type_frames: if frame.generics and len(frame.generics) > idx: - return list(frame.generics.values())[idx] - return frame.top.node - return super().resolve_name(name, start_byte) + yield list(frame.generics.values())[idx] + return + yield frame.top.node + return + yield from super().resolve_name(name, start_byte) diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index 31e2fcbe8..e3becc13c 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -8,11 +8,14 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.function import Function +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.statement import Statement, StatementType from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Sequence + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.expressions import Expression @@ -26,7 +29,7 @@ @apidoc -class IfBlockStatement(Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): +class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): """Abstract representation of the if/elif/else if/else statement block. For example, if there is a code block like: @@ -271,3 +274,26 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - self.remove_byte_range(self.ts_node.start_byte, remove_end) else: self.remove() + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.is_if_statement: + return self._main_if_block.alternative_blocks + elif self.is_elif_statement: + main = self._main_if_block + statements = [main] + if main.else_statement: + statements.append(main.else_statement) + for statement in main.elif_statements: + if statement != self: + statements.append(statement) + return statements + else: + main = self._main_if_block + return [main, *main.elif_statements] + + @property + def end_byte_for_condition_block(self) -> int: + if self.is_if_statement: + return self.consequence_block.end_byte + return self.end_byte diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py index 3ebafb57e..d293ad034 100644 --- a/src/codegen/sdk/core/statements/switch_case.py +++ b/src/codegen/sdk/core/statements/switch_case.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -18,7 +19,7 @@ @apidoc -class SwitchCase(BlockStatement[Parent], Generic[Parent]): +class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation for a switch case. Attributes: @@ -34,3 +35,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if self.condition: self.condition._compute_dependencies(usage_type, dest) super()._compute_dependencies(usage_type, dest) + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py index 1371a2d76..177ddde68 100644 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ b/src/codegen/sdk/core/statements/try_catch_statement.py @@ -3,6 +3,7 @@ from abc import ABC from typing import TYPE_CHECKING, Generic, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_block import HasBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.core.statements.statement import StatementType @@ -16,7 +17,7 @@ @apidoc -class TryCatchStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): +class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): """Abstract representation of the try catch statement block. Attributes: diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index 02a9dd55b..77d7e623d 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -119,15 +121,17 @@ def is_class_method(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.is_method: if not self.is_static_method: if len(self.parameters.symbols) > 0: if name == self.parameters[0].name: - return self.parent_class + yield self.parent_class + return if name == "super()": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @noapidoc @commiter diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index f8066c583..5c2a1f640 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -211,7 +211,11 @@ def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) """ for resolve_path in resolve_paths: filepath_new: str = os.path.join(resolve_path, filepath) - if file := self.ctx.get_file(filepath_new): + try: + file = self.ctx.get_file(filepath_new) + except AssertionError as e: + file = None + if file: return file return None diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py index 3bbea1b46..9ebee3f3f 100644 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ b/src/codegen/sdk/python/statements/catch_statement.py @@ -11,6 +11,7 @@ from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId @@ -26,3 +27,7 @@ class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.children[0] + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent] diff --git a/src/codegen/sdk/python/statements/if_block_statement.py b/src/codegen/sdk/python/statements/if_block_statement.py index 54585b9e7..dc73b21dd 100644 --- a/src/codegen/sdk/python/statements/if_block_statement.py +++ b/src/codegen/sdk/python/statements/if_block_statement.py @@ -14,7 +14,6 @@ from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - Parent = TypeVar("Parent", bound="PyCodeBlock") diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index 69528fbba..d5e1298fc 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.python.statements.match_statement import PyMatchStatement @@ -20,3 +21,7 @@ class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement) def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") + + @property + def other_possible_blocks(self) -> list["ConditionalBlock"]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/python/statements/match_statement.py b/src/codegen/sdk/python/statements/match_statement.py index 804ff2029..59f01164c 100644 --- a/src/codegen/sdk/python/statements/match_statement.py +++ b/src/codegen/sdk/python/statements/match_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.children_by_field_name("alternative"): - self.cases.append(PyMatchCase(node, file_node_id, ctx, self.parent, self.index)) + self.cases.append(PyMatchCase(node, file_node_id, ctx, self, self.index)) diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py index c4a4827b3..b54051f96 100644 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ b/src/codegen/sdk/python/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -96,3 +99,14 @@ def nested_code_blocks(self) -> list[PyCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + return self.except_clauses + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py index 44703749c..36c499358 100644 --- a/src/codegen/sdk/typescript/export.py +++ b/src/codegen/sdk/typescript/export.py @@ -204,7 +204,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if frame.parent_frame: frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) elif self._exported_symbol: - if not self.resolve_name(self._exported_symbol.source): + if not next(self.resolve_name(self._exported_symbol.source), None): self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) elif self.value: self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) @@ -218,7 +218,7 @@ def compute_export_dependencies(self) -> None: self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) elif self._exported_symbol is not None: symbol_name = self._exported_symbol.source - if (used_node := self.resolve_name(symbol_name)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): + if (used_node := next(self.resolve_name(symbol_name), None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) elif self.value is not None: if isinstance(self.value, Chainable): diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index a7be7b28f..5882bec74 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -358,7 +360,7 @@ def arrow_to_named(self, name: str | None = None) -> None: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: """Resolves the name of a symbol in the function. This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. @@ -367,14 +369,16 @@ def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Imp Args: name (str): The name of the symbol to resolve. start_byte (int | None): The start byte of the symbol to resolve. + strict (bool): If True considers candidates that don't satisfy start byte if none do. Returns: - Symbol | Import | WildcardImport | None: The resolved symbol, import, or wildcard import, or None if not found. + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import, or None if not found. """ if self.is_method: if name == "this": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @staticmethod def is_valid_node(node: TSNode) -> bool: diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py index c6dc10bae..ed46d2efc 100644 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ b/src/codegen/sdk/typescript/statements/catch_statement.py @@ -10,10 +10,10 @@ from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - Parent = TypeVar("Parent", bound="TSCodeBlock") @@ -29,3 +29,7 @@ class TSCatchStatement(CatchStatement[Parent], TSBlockStatement, Generic[Parent] def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("parameter") + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py index 914bde227..0dbec180f 100644 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ b/src/codegen/sdk/typescript/statements/switch_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent)) + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self)) diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index aa24178d2..8f499da04 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -91,3 +94,17 @@ def nested_code_blocks(self) -> list[TSCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.catch: + return [self.catch] + else: + return [] + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 23a3a7e6a..22f8af23f 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -126,3 +126,143 @@ def foo(): assert len(alt_blocks[2].alternative_blocks) == 0 assert len(alt_blocks[2].elif_statements) == 0 assert alt_blocks[2].else_statement is None + + +def test_if_else_reassigment_handling(tmpdir) -> None: + content = """ + + if True: + PYSPARK = True + elif False: + PYSPARK = False + else: + PYSPARK = None + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_function(tmpdir) -> None: + content = """ + if True: + def foo(): + print('t') + elif False: + def foo(): + print('t') + else: + def foo(): + print('t') + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + funct_call = file.function_calls[3] + for funct in file.functions: + usage = funct.usages[0] + assert usage.match == funct_call + + +def test_if_else_reassigment_handling_inside_func(tmpdir) -> None: + content = """ + def foo(a): + a = 1 + if xyz: + b = 1 + else: + b = 2 + f(a) # a resolves to 1 name + f(b) # b resolves to 2 possible names + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + assert foo + assert len(foo.parameters[0].usages) == 0 + funct_call_a = foo.function_calls[0].args[0] + funct_call_b = foo.function_calls[1] + for symbol in file.symbols(True): + if symbol.name == "a": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_a + elif symbol.name == "b": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_b + + +def test_if_else_reassigment_handling_partial_if(tmpdir) -> None: + content = """ + PYSPARK = "TEST" + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_double(tmpdir) -> None: + content = """ + if False: + PYSPARK = "TEST1" + elif True: + PYSPARK = "TEST2" + + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: + content = """ + if True: + PYSPARK = True + elif None: + PYSPARK = False + print(PYSPARK) + + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + first = file.symbols[0] + second = file.symbols[1] + assert len(first.usages) == 0 + assert second.usages[0].match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py index 0a1928551..76bb5d0f4 100644 --- a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py +++ b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py @@ -75,3 +75,23 @@ def risky(): assert not file.function_calls[0].is_wrapped_in(TryCatchStatement) assert file.function_calls[1].is_wrapped_in(TryCatchStatement) assert file.function_calls[2].is_wrapped_in(TryCatchStatement) + + +def test_try_except_reassigment_handling(tmpdir) -> None: + content = """ + try: + PYSPARK = True # This gets removed even though there is a later use + except ImportError: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py index 3a07f1d81..be972cffd 100644 --- a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py +++ b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py @@ -53,3 +53,27 @@ def risky(): assert len(dependencies) == 1 global_var = file.get_global_var("risky_var") assert dependencies[0] == global_var + + +def test_match_reassigment_handling(tmpdir) -> None: + content = """ +filter = 1 +match filter: + case 1: + PYSPARK=True + case 2: + PYSPARK=False + case _: + PYSPARK=None + +print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols[1:]: + usage = symb.usages[0] + assert usage.match == pyspark_arg From 2d256543306cd93fcfc6faecc5d059c174775cf0 Mon Sep 17 00:00:00 2001 From: Edward Li Date: Wed, 12 Mar 2025 13:35:16 -0700 Subject: [PATCH 10/10] Add JWT as explicit dependency (#810) # Motivation # Content # Testing # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --- pyproject.toml | 1 + uv.lock | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 441fe4954..1b5e2e607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ dependencies = [ "colorlog>=6.9.0", "langsmith", "langchain-xai>=0.2.1", + "jwt>=1.3.1", ] license = { text = "Apache-2.0" } diff --git a/uv.lock b/uv.lock index a5a858c41..f5a64b7a0 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.12, <3.14" resolution-markers = [ "python_full_version >= '3.12.4'", @@ -549,6 +548,7 @@ dependencies = [ { name = "hatchling" }, { name = "httpx" }, { name = "humanize" }, + { name = "jwt" }, { name = "langchain", extra = ["openai"] }, { name = "langchain-anthropic" }, { name = "langchain-core" }, @@ -679,6 +679,7 @@ requires-dist = [ { name = "hatchling", specifier = ">=1.25.0" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "humanize", specifier = ">=4.10.0,<5.0.0" }, + { name = "jwt" }, { name = "langchain", extras = ["openai"] }, { name = "langchain-anthropic", specifier = ">=0.3.7" }, { name = "langchain-core" }, @@ -742,7 +743,6 @@ requires-dist = [ { name = "wrapt", specifier = ">=1.16.0,<2.0.0" }, { name = "xmltodict", specifier = ">=0.13.0,<1.0.0" }, ] -provides-extras = ["lsp", "types"] [package.metadata.requires-dev] dev = [ @@ -2002,6 +2002,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700 }, ] +[[package]] +name = "jwt" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/66/1e792aef36645b96271b4d27c2a8cc9fc7bbbaf06277a849b9e1a6360e6a/jwt-1.3.1-py3-none-any.whl", hash = "sha256:61c9170f92e736b530655e75374681d4fcca9cfa8763ab42be57353b2b203494", size = 18192 }, +] + [[package]] name = "langchain" version = "0.3.20"