diff --git a/codegen-examples/examples/dict_to_schema/run.py b/codegen-examples/examples/dict_to_schema/run.py index e50482ab6..69ae8a3cf 100644 --- a/codegen-examples/examples/dict_to_schema/run.py +++ b/codegen-examples/examples/dict_to_schema/run.py @@ -84,7 +84,7 @@ def run(codebase: Codebase): # Add imports if needed if needs_imports: - file.add_import_from_import_string("from pydantic import BaseModel") + file.add_import("from pydantic import BaseModel") if file_modified: files_modified += 1 diff --git a/codegen-examples/examples/flask_to_fastapi_migration/run.py b/codegen-examples/examples/flask_to_fastapi_migration/run.py index 90db1d39b..ea8823a9a 100644 --- a/codegen-examples/examples/flask_to_fastapi_migration/run.py +++ b/codegen-examples/examples/flask_to_fastapi_migration/run.py @@ -57,7 +57,7 @@ def setup_static_files(file): print(f"šŸ“ Processing file: {file.filepath}") # Add import for StaticFiles - file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles") + file.add_import("from fastapi.staticfiles import StaticFiles") print("āœ… Added import: from fastapi.staticfiles import StaticFiles") # Add app.mount for static file handling diff --git a/codegen-examples/examples/sqlalchemy_soft_delete/README.md b/codegen-examples/examples/sqlalchemy_soft_delete/README.md index 9c2bec6ec..b8c9a22db 100644 --- a/codegen-examples/examples/sqlalchemy_soft_delete/README.md +++ b/codegen-examples/examples/sqlalchemy_soft_delete/README.md @@ -58,7 +58,7 @@ The codemod processes your codebase in several steps: ```python def ensure_and_import(file): if not any("and_" in imp.name for imp in file.imports): - file.add_import_from_import_string("from sqlalchemy import and_") + file.add_import("from sqlalchemy import and_") ``` - Automatically adds required SQLAlchemy imports (`and_`) diff --git a/codegen-examples/examples/sqlalchemy_soft_delete/run.py b/codegen-examples/examples/sqlalchemy_soft_delete/run.py index 4090bfa32..fb248e31a 100644 --- a/codegen-examples/examples/sqlalchemy_soft_delete/run.py +++ b/codegen-examples/examples/sqlalchemy_soft_delete/run.py @@ -51,7 +51,7 @@ def ensure_and_import(file): """Ensure the file has the necessary and_ import.""" if not any("and_" in imp.name for imp in file.imports): print(f"File {file.filepath} does not import and_. Adding import.") - file.add_import_from_import_string("from sqlalchemy import and_") + file.add_import("from sqlalchemy import and_") def clone_repo(repo_url: str, repo_path: Path) -> None: diff --git a/codegen-examples/examples/sqlalchemy_type_annotations/run.py b/codegen-examples/examples/sqlalchemy_type_annotations/run.py index 96574152d..fdfcf5a9a 100644 --- a/codegen-examples/examples/sqlalchemy_type_annotations/run.py +++ b/codegen-examples/examples/sqlalchemy_type_annotations/run.py @@ -100,16 +100,16 @@ def run(codebase: Codebase): # Add necessary imports if not cls.file.has_import("Mapped"): - cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n") + cls.file.add_import("from sqlalchemy.orm import Mapped\n") if "Optional" in new_type and not cls.file.has_import("Optional"): - cls.file.add_import_from_import_string("from typing import Optional\n") + cls.file.add_import("from typing import Optional\n") if "Decimal" in new_type and not cls.file.has_import("Decimal"): - cls.file.add_import_from_import_string("from decimal import Decimal\n") + cls.file.add_import("from decimal import Decimal\n") if "datetime" in new_type and not cls.file.has_import("datetime"): - cls.file.add_import_from_import_string("from datetime import datetime\n") + cls.file.add_import("from datetime import datetime\n") if class_modified: classes_modified += 1 diff --git a/codegen-examples/examples/unittest_to_pytest/run.py b/codegen-examples/examples/unittest_to_pytest/run.py index b4e32a55d..339b583b9 100644 --- a/codegen-examples/examples/unittest_to_pytest/run.py +++ b/codegen-examples/examples/unittest_to_pytest/run.py @@ -24,7 +24,7 @@ def convert_to_pytest_fixtures(file): print(f"šŸ” Processing file: {file.filepath}") if not any(imp.name == "pytest" for imp in file.imports): - file.add_import_from_import_string("import pytest") + file.add_import("import pytest") print(f"āž• Added pytest import to {file.filepath}") for cls in file.classes: diff --git a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md index 7d30ab454..4ec033802 100644 --- a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md +++ b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md @@ -25,7 +25,7 @@ The script automates the entire migration process in a few key steps: ```python import_str = "import { useQuery, useSuspenseQueries } from '@tanstack/react-query'" - file.add_import_from_import_string(import_str) + file.add_import(import_str) ``` - Uses Codegen's import analysis to add required imports diff --git a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py index 392f741eb..0804c7123 100644 --- a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py +++ b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py @@ -26,7 +26,7 @@ def run(codebase: Codebase): print(f"Processing {file.filepath}") # Add the import statement - file.add_import_from_import_string(import_str) + file.add_import(import_str) file_modified = False # Iterate through all functions in the file diff --git a/docs/building-with-codegen/imports.mdx b/docs/building-with-codegen/imports.mdx index 95ecff990..8a707365d 100644 --- a/docs/building-with-codegen/imports.mdx +++ b/docs/building-with-codegen/imports.mdx @@ -120,7 +120,7 @@ for module, imports in module_imports.items(): if len(imports) > 1: # Create combined import symbols = [imp.name for imp in imports] - file.add_import_from_import_string( + file.add_import( f"import {{ {', '.join(symbols)} }} from '{module}'" ) # Remove old imports diff --git a/docs/building-with-codegen/react-and-jsx.mdx b/docs/building-with-codegen/react-and-jsx.mdx index 395a16dd6..1784c8a41 100644 --- a/docs/building-with-codegen/react-and-jsx.mdx +++ b/docs/building-with-codegen/react-and-jsx.mdx @@ -136,5 +136,5 @@ for function in codebase.functions: # Add import if needed if not file.has_import("NewComponent"): - file.add_symbol_import(new_component) + file.add_import(new_component) ``` diff --git a/docs/tutorials/flask-to-fastapi.mdx b/docs/tutorials/flask-to-fastapi.mdx index ae72e8a9f..e8076ceeb 100644 --- a/docs/tutorials/flask-to-fastapi.mdx +++ b/docs/tutorials/flask-to-fastapi.mdx @@ -119,7 +119,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi ```python # Add StaticFiles import -file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles") +file.add_import("from fastapi.staticfiles import StaticFiles") # Mount static directory file.add_symbol_from_source( diff --git a/docs/tutorials/modularity.mdx b/docs/tutorials/modularity.mdx index 6923b3471..84c55835d 100644 --- a/docs/tutorials/modularity.mdx +++ b/docs/tutorials/modularity.mdx @@ -116,17 +116,17 @@ def organize_file_imports(file): # Add imports back in organized groups if std_lib_imports: for imp in std_lib_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline if third_party_imports: for imp in third_party_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline if local_imports: for imp in local_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) # Organize imports in all files for file in codebase.files: diff --git a/docs/tutorials/react-modernization.mdx b/docs/tutorials/react-modernization.mdx index a4036cbea..170999c5a 100644 --- a/docs/tutorials/react-modernization.mdx +++ b/docs/tutorials/react-modernization.mdx @@ -82,7 +82,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) = # Add required imports file = class_def.file if not any("useState" in imp.source for imp in file.imports): - file.add_import_from_import_string("import { useState, useEffect } from 'react';") + file.add_import("import { useState, useEffect } from 'react';") ``` ## Migrating to Modern Hooks @@ -100,7 +100,7 @@ for function in codebase.functions: # Convert withRouter to useNavigate if call.name == "withRouter": # Add useNavigate import - function.file.add_import_from_import_string( + function.file.add_import( "import { useNavigate } from 'react-router-dom';" ) # Add navigate hook diff --git a/src/codegen/cli/mcp/resources/system_prompt.py b/src/codegen/cli/mcp/resources/system_prompt.py index 9535570ab..9c7e23c6b 100644 --- a/src/codegen/cli/mcp/resources/system_prompt.py +++ b/src/codegen/cli/mcp/resources/system_prompt.py @@ -2909,7 +2909,7 @@ def validate_data(data: dict) -> bool: if len(imports) > 1: # Create combined import symbols = [imp.name for imp in imports] - file.add_import_from_import_string( + file.add_import( f"import {{ {', '.join(symbols)} }} from '{module}'" ) # Remove old imports @@ -5180,7 +5180,7 @@ def build_graph(func, depth=0): # Add import if needed if not file.has_import("NewComponent"): - file.add_symbol_import(new_component) + file.add_import(new_component) ``` @@ -7316,17 +7316,17 @@ def organize_file_imports(file): # Add imports back in organized groups if std_lib_imports: for imp in std_lib_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline if third_party_imports: for imp in third_party_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline if local_imports: for imp in local_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) # Organize imports in all files for file in codebase.files: @@ -8593,7 +8593,7 @@ class FeatureFlags: # Add required imports file = class_def.file if not any("useState" in imp.source for imp in file.imports): - file.add_import_from_import_string("import { useState, useEffect } from 'react';") + file.add_import("import { useState, useEffect } from 'react';") ``` ## Migrating to Modern Hooks @@ -8611,7 +8611,7 @@ class FeatureFlags: # Convert withRouter to useNavigate if call.name == "withRouter": # Add useNavigate import - function.file.add_import_from_import_string( + function.file.add_import( "import { useNavigate } from 'react-router-dom';" ) # Add navigate hook @@ -9813,7 +9813,7 @@ def create_user(): ```python # Add StaticFiles import -file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles") +file.add_import("from fastapi.staticfiles import StaticFiles") # Mount static directory file.add_symbol_from_source( diff --git a/src/codegen/sdk/core/class_definition.py b/src/codegen/sdk/core/class_definition.py index 755d46f9b..bbf2682ab 100644 --- a/src/codegen/sdk/core/class_definition.py +++ b/src/codegen/sdk/core/class_definition.py @@ -378,9 +378,9 @@ def add_attribute(self, attribute: Attribute, include_dependencies: bool = False file = self.file for d in deps: if isinstance(d, Import): - file.add_symbol_import(d.imported_symbol) + file.add_import(d.imported_symbol) elif isinstance(d, Symbol): - file.add_symbol_import(d) + file.add_import(d) @property @noapidoc diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index b282942ae..8ad9e1385 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -944,62 +944,56 @@ def update_filepath(self, new_filepath: str) -> None: imp.set_import_module(new_module_name) @writer - def add_symbol_import( - self, - symbol: Symbol, - alias: str | None = None, - import_type: ImportType = ImportType.UNKNOWN, - is_type_import: bool = False, - ) -> Import | None: - """Adds an import to a file for a given symbol. - - This method adds an import statement to the file for a specified symbol. If an import for the - symbol already exists, it returns the existing import instead of creating a new one. + def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None: + """Adds an import to the file. - Args: - symbol (Symbol): The symbol to import. - alias (str | None): Optional alias for the imported symbol. Defaults to None. - import_type (ImportType): The type of import to use. Defaults to ImportType.UNKNOWN. - is_type_import (bool): Whether this is a type-only import. Defaults to False. - - Returns: - Import | None: The existing import for the symbol or None if it was added. - """ - imports = self.imports - match = next((x for x in imports if x.imported_symbol == symbol), None) - if match: - return match - - import_string = symbol.get_import_string(alias, import_type=import_type, is_type_import=is_type_import) - self.add_import_from_import_string(import_string) - - @writer(commit=False) - def add_import_from_import_string(self, import_string: str) -> None: - """Adds import to the file from a string representation of an import statement. - - This method adds a new import statement to the file based on its string representation. + This method adds an import statement to the file. It can handle both string imports and symbol imports. If the import already exists in the file, or is pending to be added, it won't be added again. If there are existing imports, the new import will be added before the first import, otherwise it will be added at the beginning of the file. Args: - import_string (str): The string representation of the import statement to add. + imp (Symbol | str): Either a Symbol to import or a string representation of an import statement. + alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None. + import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN. + is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False. Returns: - None + Import | None: The existing import for the symbol if found, otherwise None. """ - if any(import_string.strip() in imp.source for imp in self.imports): - return + # Handle Symbol imports + if isinstance(imp, str): + # Handle string imports + import_string = imp + # Check for duplicate imports + if any(import_string.strip() in imp.source for imp in self.imports): + return None + else: + # Check for existing imports of this symbol + imports = self.imports + match = next((x for x in imports if x.imported_symbol == imp), None) + if match: + return match + + # Convert symbol to import string + import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import) + if import_string.strip() in self._pending_imports: # Don't add the import string if it will already be added by another symbol - return + return None + + # Add to pending imports and setup undo self._pending_imports.add(import_string.strip()) self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear()) + + # Insert the import at the appropriate location if self.imports: self.imports[0].insert_before(import_string, priority=1) else: self.insert_before(import_string, priority=1) + return None + @writer def add_symbol_from_source(self, source: str) -> None: """Adds a symbol to a file from a string representation. diff --git a/src/codegen/sdk/core/symbol.py b/src/codegen/sdk/core/symbol.py index 559a5cd58..cc0238b45 100644 --- a/src/codegen/sdk/core/symbol.py +++ b/src/codegen/sdk/core/symbol.py @@ -329,19 +329,19 @@ def _move_to_file( # =====[ Imports - copy over ]===== elif isinstance(dep, Import): if dep.imported_symbol: - file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source) + file.add_import(imp=dep.imported_symbol, alias=dep.alias.source) else: - file.add_import_from_import_string(dep.source) + file.add_import(imp=dep.source) else: for dep in self.dependencies: # =====[ Symbols - add back edge ]===== if isinstance(dep, Symbol) and dep.is_top_level: - file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False) + file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False) elif isinstance(dep, Import): if dep.imported_symbol: - file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source) + file.add_import(imp=dep.imported_symbol, alias=dep.alias.source) else: - file.add_import_from_import_string(dep.source) + file.add_import(imp=dep.source) # =====[ Make a new symbol in the new file ]===== file.add_symbol(self) @@ -364,7 +364,7 @@ def _move_to_file( # Here, we will add a "back edge" to the old file importing the symbol elif strategy == "add_back_edge": if is_used_in_file or any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages): - self.file.add_import_from_import_string(import_line) + self.file.add_import(imp=import_line) # Delete the original symbol self.remove() @@ -374,7 +374,7 @@ def _move_to_file( for usage in self.usages: if isinstance(usage.usage_symbol, Import) and usage.usage_symbol.file != file: # Add updated import - usage.usage_symbol.file.add_import_from_import_string(import_line) + usage.usage_symbol.file.add_import(import_line) usage.usage_symbol.remove() elif usage.usage_type == UsageType.CHAINED: # Update all previous usages of import * to the new import name @@ -383,11 +383,11 @@ def _move_to_file( usage.match.get_name().edit(self.name) if isinstance(usage.match, ChainedAttribute): usage.match.edit(self.name) - usage.usage_symbol.file.add_import_from_import_string(import_line) + usage.usage_symbol.file.add_import(imp=import_line) # Add the import to the original file if is_used_in_file: - self.file.add_import_from_import_string(import_line) + self.file.add_import(imp=import_line) # Delete the original symbol self.remove() diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index 3c92feaef..3b1fc9f93 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -5,6 +5,7 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interface import Interface +from codegen.sdk.core.symbol import Symbol from codegen.sdk.enums import ImportType from codegen.sdk.extensions.utils import cached_property from codegen.sdk.python import PyAssignment @@ -20,7 +21,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_context import CodebaseContext - from codegen.sdk.core.import_resolution import WildcardImport + from codegen.sdk.core.import_resolution import Import, WildcardImport from codegen.sdk.python.symbol import PySymbol @@ -119,7 +120,7 @@ def get_import_insert_index(self, import_string) -> int | None: The function determines the optimal position for inserting a new import statement, following Python's import ordering conventions. Future imports are placed at the top of the file, followed by all other imports. - Args: + Args:z import_string (str): The import statement to be inserted. Returns: @@ -146,28 +147,57 @@ def get_import_insert_index(self, import_string) -> int | None: #################################################################################################################### @writer - def add_import_from_import_string(self, import_string: str) -> None: - """Adds an import statement to the file from a string representation. + def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None: + """Adds an import to the file. - This method adds a new import statement to the file, handling placement based on existing imports. - Future imports are placed at the top of the file, followed by regular imports. + This method adds an import statement to the file. It can handle both string imports and symbol imports. + If the import already exists in the file, or is pending to be added, it won't be added again. + Future imports are placed at the top, followed by regular imports. Args: - import_string (str): The string representation of the import statement to add (e.g., 'from module import symbol'). + imp (Symbol | str): Either a Symbol to import or a string representation of an import statement. + alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None. + import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN. + is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False. Returns: - None: This function modifies the file in place. + Import | None: The existing import for the symbol if found, otherwise None. """ + # Handle Symbol imports + if isinstance(imp, Symbol): + imports = self.imports + match = next((x for x in imports if x.imported_symbol == imp), None) + if match: + return match + + # Convert symbol to import string + import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import) + else: + # Handle string imports + import_string = str(imp) + + # Check for duplicate imports + if any(import_string.strip() in str(imp.source) for imp in self.imports): + return None + if import_string.strip() in self._pending_imports: + return None + + # Add to pending imports + self._pending_imports.add(import_string.strip()) + self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear()) + + # Insert at correct location if self.imports: import_insert_index = self.get_import_insert_index(import_string) or 0 if import_insert_index < len(self.imports): self.imports[import_insert_index].insert_before(import_string, priority=1) else: - # If import_insert_index is out of bounds, do insert after the last import self.imports[-1].insert_after(import_string, priority=1) else: self.insert_before(import_string, priority=1) + return None + @noapidoc def remove_unused_exports(self) -> None: """Removes unused exports from the file. NO-OP for python""" diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt index 0a0cff0e9..ad5007004 100644 --- a/src/codegen/sdk/system-prompt.txt +++ b/src/codegen/sdk/system-prompt.txt @@ -29,7 +29,7 @@ codebase.commit() -Codegen handles complex refactors while maintaining correctness, enabling a broad set of advanced code manipulation programs. +Codegen handles complex refactors while maintaining correctness, enabling a broad set of advanced code manipulation programs. Codegen works with both Python and Typescript/JSX codebases. Learn more about language support [here](/building-with-codegen/language-support). @@ -492,7 +492,7 @@ Let's walk through a minimal example of using Codegen in a project: ```bash codegen init ``` - + This creates a `.codegen/` directory with: ```bash .codegen/ @@ -560,7 +560,7 @@ Let's walk through a minimal example of using Codegen in a project: For more help, join our [community Slack](/introduction/community) or check the [FAQ](/introduction/faq). - + --- title: "Using Codegen in Your IDE" @@ -589,7 +589,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your ```bash .codegen/.venv/bin/python ``` - + Alternatively, create a `.vscode/settings.json`: ```json { @@ -611,7 +611,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your .codegen/.venv/bin/python ``` - + @@ -1156,8 +1156,8 @@ iconType: "solid" - Yes - [by design](/introduction/guiding-principles#python-first-composability). - + Yes - [by design](/introduction/guiding-principles#python-first-composability). + Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools. - Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. + Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects. ## Accessing Code @@ -2923,7 +2923,7 @@ for module, imports in module_imports.items(): if len(imports) > 1: # Create combined import symbols = [imp.name for imp in imports] - file.add_import_from_import_string( + file.add_import( f"import {{ {', '.join(symbols)} }} from '{module}'" ) # Remove old imports @@ -2933,7 +2933,7 @@ for module, imports in module_imports.items(): Always check if imports resolve to external modules before modification to avoid breaking third-party package imports. - + ## Import Statements vs Imports @@ -3135,7 +3135,7 @@ for exp in file.exports: # Get original and current symbols current = exp.exported_symbol original = exp.resolved_symbol - + print(f"Re-exporting {original.name} from {exp.from_file.filepath}") print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}") ``` @@ -3185,7 +3185,7 @@ for from_file, exports in file_exports.items(): When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported. - + --- title: "Inheritable Behaviors" @@ -3675,9 +3675,9 @@ If `A` depends on `B`, then `B` is used by `A`. This relationship is tracked in flowchart LR B(BaseClass) - - - + + + A(MyClass) B ---| used by |A A ---|depends on |B @@ -3846,7 +3846,7 @@ class A: def method_a(self): pass class B(A): - def method_b(self): + def method_b(self): self.method_a() class C(B): @@ -4736,7 +4736,7 @@ for attr in class_def.attributes: # Each attribute has an assignment property attr_type = attr.assignment.type # -> TypeAnnotation print(f"{attr.name}: {attr_type.source}") # e.g. "x: int" - + # Set attribute type attr.assignment.set_type("int") @@ -4753,7 +4753,7 @@ Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as c ```python # Get union type -union_type = function.return_type # -> A | B +union_type = function.return_type # -> A | B print(union_type.symbols) # ["A", "B"] # Add/remove options @@ -5271,7 +5271,7 @@ for function in codebase.functions: # Add import if needed if not file.has_import("NewComponent"): - file.add_symbol_import(new_component) + file.add_import(new_component) ``` @@ -5604,13 +5604,13 @@ Here's an example of using flags during code analysis: ```python def analyze_codebase(codebase): - for function in codebase.functions: + for function in codebase.functions: # Check documentation if not function.docstring: function.flag( message="Missing docstring", ) - + # Check error handling if function.is_async and not function.has_try_catch: function.flag( @@ -6320,7 +6320,7 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Update API calls, handle breaking changes, and manage bulk updates across your codebase. - Convert Flask applications to FastAPI, updating routes and dependencies. - Migrate Python 2 code to Python 3, updating syntax and modernizing APIs. @@ -6353,9 +6353,9 @@ Explore our tutorials to learn how to use Codegen for various code transformatio > Restructure files, enforce naming conventions, and improve project layout. - Split large files, extract shared logic, and manage dependencies. @@ -6453,7 +6453,7 @@ The agent has access to powerful code viewing and manipulation tools powered by - `CreateFileTool`: Create new files - `DeleteFileTool`: Delete files - `RenameFileTool`: Rename files -- `EditFileTool`: Edit files +- `EditFileTool`: Edit files @@ -6960,7 +6960,6 @@ Be explicit about the changes, produce a short summary, and point out possible i Focus on facts and technical details, using code snippets where helpful. """ result = agent.run(prompt) - # Clean up the temporary comment comment.delete() ``` @@ -7046,7 +7045,7 @@ While this example demonstrates a basic PR review bot, you can extend it to: > Understand code review patterns and best practices. - + --- title: "Deep Code Research with AI" @@ -7174,21 +7173,21 @@ def research(repo_name: Optional[str] = None, query: Optional[str] = None): """Start a code research session.""" # Initialize codebase codebase = initialize_codebase(repo_name) - + # Create and run the agent agent = create_research_agent(codebase) - + # Main research loop while True: if not query: query = Prompt.ask("[bold cyan]Research query[/bold cyan]") - + result = agent.invoke( {"input": query}, config={"configurable": {"thread_id": 1}} ) console.print(Markdown(result["messages"][-1].content)) - + query = None # Clear for next iteration ``` @@ -7236,7 +7235,7 @@ class CustomAnalysisTool(BaseTool): """Custom tool for specialized code analysis.""" name = "custom_analysis" description = "Performs specialized code analysis" - + def _run(self, query: str) -> str: # Custom analysis logic return results @@ -7514,7 +7513,7 @@ from codegen import Codebase # Initialize codebase codebase = Codebase("path/to/posthog/") -# Create a directed graph for representing call relationships +# Create a directed graph for representing call relationships G = nx.DiGraph() # Configuration flags @@ -7536,7 +7535,7 @@ We'll create a function that will recursively traverse the call trace of a funct ```python def create_downstream_call_trace(src_func: Function, depth: int = 0): """Creates call graph by recursively traversing function calls - + Args: src_func (Function): Starting function for call graph depth (int): Current recursion depth @@ -7544,7 +7543,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Prevent infinite recursion if MAX_DEPTH <= depth: return - + # External modules are not functions if isinstance(src_func, ExternalModule): return @@ -7554,12 +7553,12 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): # Skip self-recursive calls if call.name == src_func.name: continue - + # Get called function definition func = call.function_definition if not func: continue - + # Apply configured filters if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS: continue @@ -7573,7 +7572,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0): func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name # Add node and edge with metadata - G.add_node(func, name=func_name, + G.add_node(func, name=func_name, color=COLOR_PALETTE.get(func.__class__.__name__)) G.add_edge(src_func, func, **generate_edge_meta(call)) @@ -7588,10 +7587,10 @@ We can enrich our edges with metadata about the function calls: ```python def generate_edge_meta(call: FunctionCall) -> dict: """Generate metadata for call graph edges - + Args: call (FunctionCall): Function call information - + Returns: dict: Edge metadata including name and location """ @@ -7610,8 +7609,8 @@ Finally, we can visualize our call graph starting from a specific function: target_class = codebase.get_class('SharingConfigurationViewSet') target_method = target_class.get_method('patch') -# Add root node -G.add_node(target_method, +# Add root node +G.add_node(target_method, name=f"{target_class.name}.{target_method.name}", color=COLOR_PALETTE["StartFunction"]) @@ -7661,7 +7660,7 @@ The core function for building our dependency graph: ```python def create_dependencies_visualization(symbol: Symbol, depth: int = 0): """Creates visualization of symbol dependencies - + Args: symbol (Symbol): Starting symbol to analyze depth (int): Current recursion depth @@ -7669,11 +7668,11 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each dependency for dep in symbol.dependencies: dep_symbol = None - + # Handle different dependency types if isinstance(dep, Symbol): # Direct symbol reference @@ -7684,13 +7683,13 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0): if dep_symbol: # Add node with appropriate styling - G.add_node(dep_symbol, - color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, + G.add_node(dep_symbol, + color=COLOR_PALETTE.get(dep_symbol.__class__.__name__, "#f694ff")) - + # Add dependency relationship G.add_edge(symbol, dep_symbol) - + # Recurse unless it's a class (avoid complexity) if not isinstance(dep_symbol, PyClass): create_dependencies_visualization(dep_symbol, depth + 1) @@ -7702,7 +7701,7 @@ Finally, we can visualize our dependency graph starting from a specific symbol: # Get target symbol target_func = codebase.get_function("get_query_runner") -# Add root node +# Add root node G.add_node(target_func, color=COLOR_PALETTE["StartFunction"]) # Generate dependency graph @@ -7745,16 +7744,16 @@ HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"] def generate_edge_meta(usage: Usage) -> dict: """Generate metadata for graph edges - + Args: usage (Usage): Usage relationship information - + Returns: dict: Edge metadata including name and location """ return { "name": usage.match.source, - "file_path": usage.match.filepath, + "file_path": usage.match.filepath, "start_point": usage.match.start_point, "end_point": usage.match.end_point, "symbol_name": usage.match.__class__.__name__ @@ -7762,10 +7761,10 @@ def generate_edge_meta(usage: Usage) -> dict: def is_http_method(symbol: PySymbol) -> bool: """Check if a symbol is an HTTP endpoint method - + Args: symbol (PySymbol): Symbol to check - + Returns: bool: True if symbol is an HTTP method """ @@ -7779,7 +7778,7 @@ The main function for creating our blast radius visualization: ```python def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): """Create visualization of symbol usage relationships - + Args: symbol (PySymbol): Starting symbol to analyze depth (int): Current recursion depth @@ -7787,11 +7786,11 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Prevent excessive recursion if depth >= MAX_DEPTH: return - + # Process each usage of the symbol for usage in symbol.usages: usage_symbol = usage.usage_symbol - + # Determine node color based on type if is_http_method(usage_symbol): color = COLOR_PALETTE.get("HTTP_METHOD") @@ -7801,7 +7800,7 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0): # Add node and edge to graph G.add_node(usage_symbol, color=color) G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage)) - + # Recursively process usage symbol create_blast_radius_visualization(usage_symbol, depth + 1) ``` @@ -7952,7 +7951,7 @@ for call in old_api.call_sites: f"data={call.get_arg_by_parameter_name('input').value}", f"timeout={call.get_arg_by_parameter_name('wait').value}" ] - + # Replace the old call with the new API call.replace(f"new_process_data({', '.join(args)})") ``` @@ -7966,10 +7965,10 @@ When updating chained method calls, like database queries or builder patterns: for execute_call in codebase.function_calls: if execute_call.name != "execute": continue - + # Get the full chain chain = execute_call.call_chain - + # Example: Add .timeout() before .execute() if "timeout" not in {call.name for call in chain}: execute_call.insert_before("timeout(30)") @@ -7988,45 +7987,45 @@ Here's a comprehensive example: ```python def migrate_api_v1_to_v2(codebase): old_api = codebase.get_function("create_user_v1") - + # Document all existing call patterns call_patterns = {} for call in old_api.call_sites: args = [arg.source for arg in call.args] pattern = ", ".join(args) call_patterns[pattern] = call_patterns.get(pattern, 0) + 1 - + print("Found call patterns:") for pattern, count in call_patterns.items(): print(f" {pattern}: {count} occurrences") - + # Create new API version new_api = old_api.copy() new_api.rename("create_user_v2") - + # Update parameter types new_api.get_parameter("email").type = "EmailStr" new_api.get_parameter("role").type = "UserRole" - + # Add new required parameters new_api.add_parameter("tenant_id: UUID") - + # Update all call sites for call in old_api.call_sites: # Get current arguments email_arg = call.get_arg_by_parameter_name("email") role_arg = call.get_arg_by_parameter_name("role") - + # Build new argument list with type conversions new_args = [ f"email=EmailStr({email_arg.value})", f"role=UserRole({role_arg.value})", "tenant_id=get_current_tenant_id()" ] - + # Replace old call with new version call.replace(f"create_user_v2({', '.join(new_args)})") - + # Add deprecation notice to old version old_api.add_decorator('@deprecated("Use create_user_v2 instead")') @@ -8048,10 +8047,10 @@ migrate_api_v1_to_v2(codebase) ```python # First update parameter names param.rename("new_name") - + # Then update types param.type = "new_type" - + # Finally update call sites for call in api.call_sites: # ... update calls @@ -8061,7 +8060,7 @@ migrate_api_v1_to_v2(codebase) ```python # Add new parameter with default api.add_parameter("new_param: str = None") - + # Later make it required api.get_parameter("new_param").remove_default() ``` @@ -8076,7 +8075,7 @@ migrate_api_v1_to_v2(codebase) Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes. - + --- title: "Organizing Your Codebase" @@ -8640,16 +8639,16 @@ from collections import defaultdict # Create a graph of file dependencies def create_dependency_graph(): G = nx.DiGraph() - + for file in codebase.files: # Add node for this file G.add_node(file.filepath) - + # Add edges for each import for imp in file.imports: if imp.from_file: # Skip external imports G.add_edge(file.filepath, imp.from_file.filepath) - + return G # Create and analyze the graph @@ -8678,18 +8677,18 @@ def break_circular_dependency(cycle): # Get the first two files in the cycle file1 = codebase.get_file(cycle[0]) file2 = codebase.get_file(cycle[1]) - + # Create a shared module for common code shared_dir = "shared" if not codebase.has_directory(shared_dir): codebase.create_directory(shared_dir) - + # Find symbols used by both files shared_symbols = [] for symbol in file1.symbols: if any(usage.file == file2 for usage in symbol.usages): shared_symbols.append(symbol) - + # Move shared symbols to a new file if shared_symbols: shared_file = codebase.create_file(f"{shared_dir}/shared_types.py") @@ -8711,7 +8710,7 @@ def organize_file_imports(file): std_lib_imports = [] third_party_imports = [] local_imports = [] - + for imp in file.imports: if imp.is_standard_library: std_lib_imports.append(imp) @@ -8719,29 +8718,29 @@ def organize_file_imports(file): third_party_imports.append(imp) else: local_imports.append(imp) - + # Sort each group for group in [std_lib_imports, third_party_imports, local_imports]: group.sort(key=lambda x: x.module_name) - + # Remove all existing imports for imp in file.imports: imp.remove() - + # Add imports back in organized groups if std_lib_imports: for imp in std_lib_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if third_party_imports: for imp in third_party_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) file.insert_after_imports("") # Add newline - + if local_imports: for imp in local_imports: - file.add_import_from_import_string(imp.source) + file.add_import(imp.source) # Organize imports in all files for file in codebase.files: @@ -8757,22 +8756,22 @@ from collections import defaultdict def analyze_module_coupling(): coupling_scores = defaultdict(int) - + for file in codebase.files: # Count unique files imported from imported_files = {imp.from_file for imp in file.imports if imp.from_file} coupling_scores[file.filepath] = len(imported_files) - + # Count files that import this file - importing_files = {usage.file for symbol in file.symbols + importing_files = {usage.file for symbol in file.symbols for usage in symbol.usages if usage.file != file} coupling_scores[file.filepath] += len(importing_files) - + # Sort by coupling score - sorted_files = sorted(coupling_scores.items(), - key=lambda x: x[1], + sorted_files = sorted(coupling_scores.items(), + key=lambda x: x[1], reverse=True) - + print("\nšŸ” Module Coupling Analysis:") print("\nMost coupled files:") for filepath, score in sorted_files[:5]: @@ -8790,9 +8789,9 @@ def extract_shared_code(file, min_usages=3): # Find symbols used by multiple files for symbol in file.symbols: # Get unique files using this symbol - using_files = {usage.file for usage in symbol.usages + using_files = {usage.file for usage in symbol.usages if usage.file != file} - + if len(using_files) >= min_usages: # Create appropriate shared module module_name = determine_shared_module(symbol) @@ -8800,7 +8799,7 @@ def extract_shared_code(file, min_usages=3): shared_file = codebase.create_file(f"shared/{module_name}.py") else: shared_file = codebase.get_file(f"shared/{module_name}.py") - + # Move symbol to shared module symbol.move_to_file(shared_file, strategy="update_all_imports") @@ -8854,7 +8853,7 @@ if feature_flag_class: # Initialize usage count for all attributes for attr in feature_flag_class.attributes: feature_flag_usage[attr.name] = 0 - + # Get all usages of the FeatureFlag class for usage in feature_flag_class.usages: usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage) @@ -9599,7 +9598,7 @@ Let's break down how this works: if export.is_reexport() and export.is_default_export(): print(f" šŸ”„ Converting default export '{export.name}'") ``` - + The code identifies default exports by checking: 1. If it's a re-export (`is_reexport()`) 2. If it's a default export (`is_default_export()`) @@ -9707,7 +9706,7 @@ for file in codebase.files: print(f"✨ Fixed exports in {target_file.filepath}") -``` +``` --- title: "Creating Documentation" @@ -9796,11 +9795,11 @@ for directory in codebase.directories: # Skip test, sql and alembic directories if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']): continue - + # Get undecorated functions funcs = [f for f in directory.functions if not f.is_decorated] total = len(funcs) - + # Only analyze dirs with >10 functions if total > 10: documented = sum(1 for f in funcs if f.docstring) @@ -9815,12 +9814,12 @@ for directory in codebase.directories: if dir_stats: lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage']) path, stats = lowest_dir - + print(f"šŸ“‰ Lowest coverage directory: '{path}'") print(f" • Total functions: {stats['total']}") print(f" • Documented: {stats['documented']}") print(f" • Coverage: {stats['coverage']:.1f}%") - + # Print all directory stats for comparison print("\nšŸ“Š All directory coverage rates:") for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']): @@ -10008,7 +10007,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) = # Add required imports file = class_def.file if not any("useState" in imp.source for imp in file.imports): - file.add_import_from_import_string("import { useState, useEffect } from 'react';") + file.add_import("import { useState, useEffect } from 'react';") ``` ## Migrating to Modern Hooks @@ -10026,7 +10025,7 @@ for function in codebase.functions: # Convert withRouter to useNavigate if call.name == "withRouter": # Add useNavigate import - function.file.add_import_from_import_string( + function.file.add_import( "import { useNavigate } from 'react-router-dom';" ) # Add navigate hook @@ -10608,7 +10607,7 @@ iconType: "solid" -Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. +Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain. In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen. @@ -11244,7 +11243,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi ```python # Add StaticFiles import -file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles") +file.add_import("from fastapi.staticfiles import StaticFiles") # Mount static directory file.add_symbol_from_source( @@ -11505,10 +11504,10 @@ Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10 ```cypher Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func)) -Return path +Return path LIMIT 20 ``` - \ No newline at end of file + diff --git a/src/codegen/sdk/typescript/symbol.py b/src/codegen/sdk/typescript/symbol.py index fc41d1ee7..e3cc89828 100644 --- a/src/codegen/sdk/typescript/symbol.py +++ b/src/codegen/sdk/typescript/symbol.py @@ -283,9 +283,9 @@ def _move_to_file( # =====[ Imports - copy over ]===== elif isinstance(dep, TSImport): if dep.imported_symbol: - file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type) + file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type) else: - file.add_import_from_import_string(dep.source) + file.add_import(dep.source) else: msg = f"Unknown dependency type {type(dep)}" @@ -301,7 +301,7 @@ def _move_to_file( # =====[ Symbols - move over ]===== elif isinstance(dep, Symbol) and dep.is_top_level: - file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias)) + file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias)) if not dep.is_exported: dep.file.add_export_to_symbol(dep) @@ -310,9 +310,9 @@ def _move_to_file( # =====[ Imports - copy over ]===== elif isinstance(dep, TSImport): if dep.imported_symbol: - file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import()) + file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import()) else: - file.add_import_from_import_string(dep.source) + file.add_import(dep.source) except Exception as e: print(f"Failed to move dependencies of {self.name}: {e}") @@ -336,12 +336,12 @@ def _move_to_file( # Here, we will add a "back edge" to the old file importing the self elif strategy == "add_back_edge": if is_used_in_file: - self.file.add_import_from_import_string(import_line) + self.file.add_import(import_line) if self.is_exported: - self.file.add_import_from_import_string(f"export {{ {self.name} }}") + self.file.add_import(f"export {{ {self.name} }}") elif self.is_exported: module_name = file.name - self.file.add_import_from_import_string(f"export {{ {self.name} }} from '{module_name}'") + self.file.add_import(f"export {{ {self.name} }} from '{module_name}'") # Delete the original symbol self.remove() @@ -352,7 +352,7 @@ def _move_to_file( if isinstance(usage.usage_symbol, TSImport): # Add updated import if usage.usage_symbol.resolved_symbol is not None and usage.usage_symbol.resolved_symbol.node_type == NodeType.SYMBOL and usage.usage_symbol.resolved_symbol == self: - usage.usage_symbol.file.add_import_from_import_string(import_line) + usage.usage_symbol.file.add_import(import_line) usage.usage_symbol.remove() elif usage.usage_type == UsageType.CHAINED: # Update all previous usages of import * to the new import name @@ -361,9 +361,9 @@ def _move_to_file( usage.match.get_name().edit(self.name) if isinstance(usage.match, ChainedAttribute): usage.match.edit(self.name) - usage.usage_symbol.file.add_import_from_import_string(import_line) + usage.usage_symbol.file.add_import(import_line) if is_used_in_file: - self.file.add_import_from_import_string(import_line) + self.file.add_import(import_line) # Delete the original symbol self.remove() @@ -377,7 +377,7 @@ def _convert_proptype_to_typescript(self, prop_type: Editable, param: Parameter if prop_type.attribute.source == "node": return "T" if prop_type.attribute.source == "element": - self.file.add_import_from_import_string("import React from 'react';\n") + self.file.add_import("import React from 'react';\n") return "React.ReactElement" if prop_type.attribute.source in type_map: return type_map[prop_type.attribute.source] @@ -476,7 +476,7 @@ def convert_to_react_interface(self) -> str | None: if "PropTypes.node" in proptypes.source: generics = "" generic_name = "" - self.file.add_import_from_import_string("import React from 'react';\n") + self.file.add_import("import React from 'react';\n") interface_name = f"{component_name}Props" # Create interface definition interface_def = f"interface {interface_name}{generics} {self._convert_dict(proptypes, 1)}" diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py index 913782f1f..7476e6e8a 100644 --- a/src/codegen/sdk/utils.py +++ b/src/codegen/sdk/utils.py @@ -104,8 +104,6 @@ def find_import_node(node: TSNode) -> TSNode | None: # we only parse imports inside expressions and variable declarations - # import_nodes = [_node for _node in find_all_descendants(node, ["call_expression", "statement_block"], nested=False) if _node.type == "call_expression"] - if member_expression := find_first_descendant(node, ["member_expression"]): # there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module) descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block") diff --git a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py b/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py index cfeb7ca5f..5cb33414b 100644 --- a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py +++ b/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py @@ -48,4 +48,4 @@ def execute(self, codebase: Codebase) -> None: # Ensure the necessary import is present file = function.file if "SessionLocal" not in [imp.name for imp in file.imports]: - file.add_import_from_import_string("from app.db import SessionLocal") + file.add_import("from app.db import SessionLocal") diff --git a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py index caa40b799..de97853a6 100644 --- a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py +++ b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py @@ -56,4 +56,4 @@ def execute(self, codebase: Codebase): element.set_name("PrivateRoutesContainer") # Add the import if it doesn't exist if not file.has_import("PrivateRoutesContainer"): - file.add_symbol_import(PrivateRoutesContainer) + file.add_import(PrivateRoutesContainer) diff --git a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py b/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py index afc67d40d..a0fecf515 100644 --- a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py +++ b/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py @@ -29,7 +29,7 @@ class MySession(SessionInterface): ... That is, it deletes the attribute and adds the appropriate decorator via the `cls.add_decorator` method. - Note that `cls.file.add_import_from_import_string(import_str)` is the method used to add import for the decorator. + Note that `cls.file.add_import(import_str)` is the method used to add import for the decorator. """ language = ProgrammingLanguage.PYTHON @@ -51,7 +51,7 @@ def execute(self, codebase: Codebase) -> None: decorator_name = attr_value_to_decorator[attribute.right.source] # Import the necessary decorators required_import = f"from src.flask.sessions import {decorator_name}" - cls.file.add_import_from_import_string(required_import) + cls.file.add_import(required_import) # Add the appropriate decorator cls.add_decorator(f"@{decorator_name}") diff --git a/src/codemods/canonical/pivot_return_types/pivot_return_types.py b/src/codemods/canonical/pivot_return_types/pivot_return_types.py index 367e40ad2..aeb1cdee8 100644 --- a/src/codemods/canonical/pivot_return_types/pivot_return_types.py +++ b/src/codemods/canonical/pivot_return_types/pivot_return_types.py @@ -41,7 +41,7 @@ def execute(self, codebase: Codebase) -> None: function.set_return_type("FastStr") # Add import for 'FastStr' if it doesn't exist - function.file.add_import_from_import_string("from app.models.fast_str import FastStr") + function.file.add_import("from app.models.fast_str import FastStr") # Modify all return statements within the function for return_stmt in function.code_block.return_statements: diff --git a/src/codemods/canonical/split_large_files/split_large_files.py b/src/codemods/canonical/split_large_files/split_large_files.py index c1e3e295e..33f846421 100644 --- a/src/codemods/canonical/split_large_files/split_large_files.py +++ b/src/codemods/canonical/split_large_files/split_large_files.py @@ -46,4 +46,4 @@ def execute(self, codebase: Codebase): # Move the symbol to the new file symbol.move_to_file(new_file) # Add a back edge to the original file - file.add_symbol_import(symbol) + file.add_import(symbol) diff --git a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py b/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py index ce4d8235f..14c4e96f4 100644 --- a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py +++ b/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py @@ -60,4 +60,4 @@ def execute(self, codebase: Codebase) -> None: legacy_function.remove() # Add import of the new function - call_site.file.add_import_from_import_string(f"from settings.collections import {legacy_function.name}") + call_site.file.add_import(f"from settings.collections import {legacy_function.name}") diff --git a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py b/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py index 1a9bcb8bf..60c520986 100644 --- a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py +++ b/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py @@ -42,7 +42,7 @@ def execute(self, codebase: Codebase) -> None: class_a_param.edit("cache_config: CacheConfig") # Add import of `CacheConfig` to function definition file - function.file.add_symbol_import(class_b_symb) + function.file.add_import(class_b_symb) # Check if the function body is using `cache_config` if len(function.code_block.get_variable_usages(class_a_param.name)) > 0: diff --git a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py b/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py index 2a7699df0..d1ceea089 100644 --- a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py +++ b/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py @@ -51,5 +51,5 @@ def update_type_annotation(type: Type) -> str: new_type = update_type_annotation(parameter.type) if parameter.type != new_type: # Add the future annotations import - file.add_import_from_import_string("from __future__ import annotations\n") + file.add_import("from __future__ import annotations\n") parameter.type.edit(new_type) diff --git a/src/codemods/canonical/wrap_with_component/wrap_with_component.py b/src/codemods/canonical/wrap_with_component/wrap_with_component.py index 715b98216..b1bed4cc1 100644 --- a/src/codemods/canonical/wrap_with_component/wrap_with_component.py +++ b/src/codemods/canonical/wrap_with_component/wrap_with_component.py @@ -48,4 +48,4 @@ def execute(self, codebase: Codebase) -> None: element.edit(f"{element.source}") # Add an import for the Alert component - file.add_symbol_import(alert) + file.add_import(alert) diff --git a/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py b/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py index f5dfe151d..6a74ff2e4 100644 --- a/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py +++ b/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py @@ -141,7 +141,7 @@ def a(): autocommit = codebase.ctx._autocommit file1 = codebase.get_file(file1_name) fun = file1.get_function("a") - file1.add_import_from_import_string("import os") + file1.add_import("import os") assert fun.node_id not in autocommit._nodes if edit_block: block = fun.code_block @@ -200,7 +200,7 @@ def a(a: int): param = fun.parameters[0] assert fun.node_id not in autocommit._nodes param.edit("try_to_break_this: str") - file1.add_import_from_import_string("import os") + file1.add_import("import os") assert fun.node_id in autocommit._nodes if edit_block: block = fun.code_block @@ -230,7 +230,7 @@ def b(a: int): param = fun.parameters[0] assert fun.node_id not in autocommit._nodes param.edit("try_to_break_this: str") - file1.add_import_from_import_string("import os") + file1.add_import("import os") assert fun.node_id in autocommit._nodes block = funb.code_block block.insert_before("a", fix_indentation=True) diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_import.py b/tests/unit/codegen/sdk/python/file/test_file_add_import.py new file mode 100644 index 000000000..f0e353d81 --- /dev/null +++ b/tests/unit/codegen/sdk/python/file/test_file_add_import.py @@ -0,0 +1,276 @@ +import pytest + +from codegen.sdk.codebase.factory.get_session import get_codebase_session +from codegen.shared.enums.programming_language import ProgrammingLanguage + + +def test_file_add_symbol_import_updates_source(tmpdir) -> None: + # language=python + content1 = """ +import datetime + +def foo(): + return datetime.datetime.now() +""" + + # language=python + content2 = """ +def bar(): + return 1 +""" + with get_codebase_session(tmpdir=tmpdir, files={"file1.py": content1, "file2.py": content2}) as codebase: + file1 = codebase.get_file("file1.py") + file2 = codebase.get_file("file2.py") + + file2.add_import(file1.get_symbol("foo")) + + assert "import foo" in file2.content + + +def test_file_add_import_string_no_imports_adds_to_top(tmpdir) -> None: + # language=python + content = """ +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from sqlalchemy.orm import Session") + + file_lines = file.content.split("\n") + assert "from sqlalchemy.orm import Session" in file_lines[0] + + +def test_file_add_import_string_adds_before_first_import(tmpdir) -> None: + # language=python + content = """ +# top level comment + +# adds new import here +from typing import List + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from sqlalchemy.orm import Session") + + file_lines = file.content.split("\n") + assert "from sqlalchemy.orm import Session" in file_lines + assert file_lines.index("from sqlalchemy.orm import Session") == file_lines.index("from typing import List") - 1 + + +@pytest.mark.parametrize("sync", [True, False]) +def test_file_add_import_string_adds_remove(tmpdir, sync) -> None: + # language=python + content = """import b + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content.strip()}, sync_graph=sync) as codebase: + file = codebase.get_file("test.py") + + file.add_import("import antigravity") + file.remove() + if sync: + assert not codebase.get_file(file.filepath, optional=True) + + +def test_file_add_import_typescript_string_adds_before_first_import(tmpdir) -> None: + # language=typescript + content = """ +// top level comment + +// existing imports below +import { Something } from 'somewhere' + +function bar(): number { + return 1; +} + """ + with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: + file = codebase.get_file("test.ts") + + file.add_import("import { NewThing } from 'elsewhere'") + + file_lines = file.content.split("\n") + assert "import { NewThing } from 'elsewhere'" in file_lines + assert file_lines.index("import { NewThing } from 'elsewhere'") < file_lines.index("import { Something } from 'somewhere'") + + +def test_file_add_import_typescript_string_no_imports_adds_to_top(tmpdir) -> None: + # language=typescript + content = """ + function bar(): number { + return 1; + } + """ + with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: + file = codebase.get_file("test.ts") + + file.add_import("import { Something } from 'somewhere';") + + file_lines = file.content.split("\n") + assert "import { Something } from 'somewhere';" in file_lines[0] + + +def test_file_add_import_typescript_multiple_symbols(tmpdir) -> None: + FILE1_FILENAME = "file1.ts" + FILE2_FILENAME = "file2.ts" + + # language=typescript + FILE1_CONTENT = """ + export function foo(): string { + return 'foo'; + } + + export function bar(): string { + return 'bar'; + } + """ + + # language=typescript + FILE2_CONTENT = """ + function test(): number { + return 1; + } + """ + with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={FILE1_FILENAME: FILE1_CONTENT, FILE2_FILENAME: FILE2_CONTENT}) as codebase: + file1 = codebase.get_file(FILE1_FILENAME) + file2 = codebase.get_file(FILE2_FILENAME) + + # Add multiple symbols one after another + file2.add_import(file1.get_symbol("foo")) + file2.add_import(file1.get_symbol("bar")) + + # Updated assertion to check for separate imports since that's the current behavior + assert "import { foo } from 'file1';" in file2.content + assert "import { bar } from 'file1';" in file2.content + + +def test_file_add_import_typescript_default_import(tmpdir) -> None: + # language=typescript + content = """ + function bar(): number { + return 1; + } + """ + with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase: + file = codebase.get_file("test.ts") + + file.add_import("import React from 'react';") + file.add_import("import { useState } from 'react';") + + file_lines = file.content.split("\n") + assert "import React from 'react';" in file_lines + assert "import { useState } from 'react';" in file_lines + + +def test_file_add_import_typescript_duplicate_prevention(tmpdir) -> None: + FILE1_FILENAME = "file1.ts" + FILE2_FILENAME = "file2.ts" + + # language=typescript + FILE1_CONTENT = """ + export function foo(): string { + return 'foo'; + } + """ + + # language=typescript + FILE2_CONTENT = """ + import { foo } from 'file1'; + + function test(): string { + return foo(); + } + """ + with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={FILE1_FILENAME: FILE1_CONTENT, FILE2_FILENAME: FILE2_CONTENT}) as codebase: + file1 = codebase.get_file(FILE1_FILENAME) + file2 = codebase.get_file(FILE2_FILENAME) + + # Try to add the same import again + file2.add_import(file1.get_symbol("foo")) + + # Verify no duplicate import was added + assert file2.content.count("import { foo }") == 1 + + +def test_file_add_import_string_adds_after_future(tmpdir) -> None: + # language=python + content = """ +from __future__ import annotations + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from sqlalchemy.orm import Session") + + file_lines = file.content.split("\n") + assert "from __future__ import annotations" in file_lines[1] + assert "from sqlalchemy.orm import Session" in file_lines[2] + + +def test_file_add_import_string_adds_after_last_future(tmpdir) -> None: + # language=python + content = """ +from __future__ import annotations +from __future__ import division + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from sqlalchemy.orm import Session") + + file_lines = file.content.split("\n") + assert "from __future__ import annotations" in file_lines[1] + assert "from __future__ import division" in file_lines[2] + assert "from sqlalchemy.orm import Session" in file_lines[3] + + +def test_file_add_import_string_adds_after_future_before_non_future(tmpdir) -> None: + # language=python + content = """ +from __future__ import annotations +from typing import List + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from sqlalchemy.orm import Session") + + file_lines = file.content.split("\n") + assert "from __future__ import annotations" in file_lines[1] + assert "from sqlalchemy.orm import Session" in file_lines[2] + assert "from typing import List" in file_lines[3] + + +def test_file_add_import_string_future_import_adds_to_top(tmpdir) -> None: + # language=python + content = """ +from __future__ import annotations + +def foo(): + print("this is foo") +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + + file.add_import("from __future__ import division") + + file_lines = file.content.split("\n") + assert "from __future__ import division" in file_lines[1] + assert "from __future__ import annotations" in file_lines[2] diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py b/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py index 1d905332a..089a4849d 100644 --- a/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py +++ b/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py @@ -14,7 +14,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from sqlalchemy.orm import Session") + file.add_import("from sqlalchemy.orm import Session") file_lines = file.content.split("\n") assert "from __future__ import annotations" in file_lines[1] @@ -33,7 +33,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from sqlalchemy.orm import Session") + file.add_import("from sqlalchemy.orm import Session") file_lines = file.content.split("\n") assert "from __future__ import annotations" in file_lines[1] @@ -53,7 +53,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from sqlalchemy.orm import Session") + file.add_import("from sqlalchemy.orm import Session") file_lines = file.content.split("\n") assert "from __future__ import annotations" in file_lines[1] @@ -72,7 +72,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from __future__ import division") + file.add_import("from __future__ import division") file_lines = file.content.split("\n") assert "from __future__ import division" in file_lines[1] @@ -88,7 +88,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from sqlalchemy.orm import Session") + file.add_import("from sqlalchemy.orm import Session") file_lines = file.content.split("\n") assert "from sqlalchemy.orm import Session" in file_lines[0] @@ -108,7 +108,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("from sqlalchemy.orm import Session") + file.add_import("from sqlalchemy.orm import Session") file_lines = file.content.split("\n") assert "from sqlalchemy.orm import Session" in file_lines @@ -126,7 +126,7 @@ def foo(): with get_codebase_session(tmpdir=tmpdir, files={"test.py": content.strip()}, sync_graph=sync) as codebase: file = codebase.get_file("test.py") - file.add_import_from_import_string("import antigravity") + file.add_import("import antigravity") file.remove() if sync: assert not codebase.get_file(file.filepath, optional=True) diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py b/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py deleted file mode 100644 index 7088f2c1e..000000000 --- a/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py +++ /dev/null @@ -1,24 +0,0 @@ -from codegen.sdk.codebase.factory.get_session import get_codebase_session - - -def test_file_add_symbol_import_updates_source(tmpdir) -> None: - # language=python - content1 = """ -import datetime - -def foo(): - return datetime.datetime.now() -""" - - # language=python - content2 = """ -def bar(): - return 1 -""" - with get_codebase_session(tmpdir=tmpdir, files={"file1.py": content1, "file2.py": content2}) as codebase: - file1 = codebase.get_file("file1.py") - file2 = codebase.get_file("file2.py") - - file2.add_symbol_import(file1.get_symbol("foo")) - - assert "import foo" in file2.content diff --git a/tests/unit/codegen/sdk/python/file/test_file_reparse.py b/tests/unit/codegen/sdk/python/file/test_file_reparse.py index 7f14c79c4..b4314f9f2 100644 --- a/tests/unit/codegen/sdk/python/file/test_file_reparse.py +++ b/tests/unit/codegen/sdk/python/file/test_file_reparse.py @@ -98,7 +98,7 @@ def test_file_reparse_move_global_var(mock_codebase_setup: tuple[Codebase, File, global_var1.remove() global_var2 = file2.get_global_var("GLOBAL_CONSTANT_2") global_var2.insert_before(global_var1.source) - file1.add_symbol_import(global_var1) + file1.add_import(global_var1) # Remove the import to GLOBAL_CONSTANT_1 from file2 imp_to_remove = file2.get_import("GLOBAL_CONSTANT_1") diff --git a/tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py b/tests/unit/codegen/sdk/typescript/file/test_file_add_import.py similarity index 94% rename from tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py rename to tests/unit/codegen/sdk/typescript/file/test_file_add_import.py index 81115ce0b..40fa1ca4f 100644 --- a/tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py +++ b/tests/unit/codegen/sdk/typescript/file/test_file_add_import.py @@ -25,6 +25,6 @@ def test_file_add_symbol_import_updates_source(tmpdir) -> None: file1 = codebase.get_file(FILE1_FILENAME) file2 = codebase.get_file(FILE2_FILENAME) - file2.add_symbol_import(file1.get_symbol("foo")) + file2.add_import(file1.get_symbol("foo")) assert "import { foo } from 'file1';" in file2.content diff --git a/tests/unit/skills/implementations/decorator_skills.py b/tests/unit/skills/implementations/decorator_skills.py index d586b0464..f74877deb 100644 --- a/tests/unit/skills/implementations/decorator_skills.py +++ b/tests/unit/skills/implementations/decorator_skills.py @@ -54,7 +54,7 @@ def python_skill_func(codebase: CodebaseType): # if the file does not have the decorator symbol and the decorator symbol is not in the same file if not file.has_import(decorator_symbol.name) and decorator_symbol.file != file: # import the decorator symbol - file.add_symbol_import(decorator_symbol) + file.add_import(decorator_symbol) # iterate through each function in the file for function in file.functions: diff --git a/tests/unit/skills/implementations/eval_skills.py b/tests/unit/skills/implementations/eval_skills.py index 0a25fa376..99e0b65ac 100644 --- a/tests/unit/skills/implementations/eval_skills.py +++ b/tests/unit/skills/implementations/eval_skills.py @@ -84,7 +84,7 @@ def python_skill_func(codebase: CodebaseType): # if the decorator is not imported or declared in the file if not file.has_import("decorator_function") and decorator_symbol.file != file: # add an import for the decorator function - file.add_symbol_import(decorator_symbol) + file.add_import(decorator_symbol) # add the decorator to the function function.add_decorator(f"@{decorator_symbol.name}") @@ -370,7 +370,7 @@ def typescript_skill_func(codebase: CodebaseType): # if the file does not exist create it new_file = codebase.create_file(str(new_file_path)) # add an import for React - new_file.add_import_from_import_string('import React from "react";') + new_file.add_import('import React from "react";') # move the component to the new file component.move_to_file(new_file) diff --git a/tests/unit/skills/implementations/example_skills.py b/tests/unit/skills/implementations/example_skills.py index aa122000e..e0c025b88 100644 --- a/tests/unit/skills/implementations/example_skills.py +++ b/tests/unit/skills/implementations/example_skills.py @@ -141,13 +141,13 @@ def python_skill_func(codebase: CodebaseType): for file in codebase.files: for function in file.functions: if function.name.startswith("test_"): - file.add_import_from_import_string("import pytest") + file.add_import("import pytest") function.add_decorator('@pytest.mark.skip(reason="This is a test")') for cls in file.classes: for method in cls.methods: if method.name.startswith("test_"): - file.add_import_from_import_string("import pytest") + file.add_import("import pytest") method.add_decorator('@pytest.mark.skip(reason="This is a test")') @staticmethod @@ -181,7 +181,7 @@ def python_skill_func(codebase: CodebaseType): function.set_return_type("None") else: function.set_return_type("Any") - function.file.add_import_from_import_string("from typing import Any") + function.file.add_import("from typing import Any") for param in function.parameters: if not param.is_typed: @@ -191,7 +191,7 @@ def python_skill_func(codebase: CodebaseType): param.set_type_annotation("str") else: param.set_type_annotation("Any") - function.file.add_import_from_import_string("from typing import Any") + function.file.add_import("from typing import Any") @staticmethod @skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT) diff --git a/tests/unit/skills/implementations/guides/increase-type-coverage.py b/tests/unit/skills/implementations/guides/increase-type-coverage.py index a84b09b74..f04d74288 100644 --- a/tests/unit/skills/implementations/guides/increase-type-coverage.py +++ b/tests/unit/skills/implementations/guides/increase-type-coverage.py @@ -318,7 +318,7 @@ def python_skill_func(codebase: CodebaseType): # import c from module c = codebase.get_file("path/to/module.py").get_symbol("c") - target_file.add_symbol_import(c) + target_file.add_import(c) # Add a new option to the return type function.return_type.append("c") @@ -331,7 +331,7 @@ def typescript_skill_func(codebase: CodebaseType): function = target_file.get_function("functionName") # function functionName(): a | b: ... c = codebase.get_file("path/to/module.ts").get_symbol("c") - target_file.add_symbol_import(c) + target_file.add_import(c) # Add a new option to the return type function.return_type.append("c")