diff --git a/codegen-examples/examples/dict_to_schema/run.py b/codegen-examples/examples/dict_to_schema/run.py
index e50482ab6..69ae8a3cf 100644
--- a/codegen-examples/examples/dict_to_schema/run.py
+++ b/codegen-examples/examples/dict_to_schema/run.py
@@ -84,7 +84,7 @@ def run(codebase: Codebase):
# Add imports if needed
if needs_imports:
- file.add_import_from_import_string("from pydantic import BaseModel")
+ file.add_import("from pydantic import BaseModel")
if file_modified:
files_modified += 1
diff --git a/codegen-examples/examples/flask_to_fastapi_migration/run.py b/codegen-examples/examples/flask_to_fastapi_migration/run.py
index 90db1d39b..ea8823a9a 100644
--- a/codegen-examples/examples/flask_to_fastapi_migration/run.py
+++ b/codegen-examples/examples/flask_to_fastapi_migration/run.py
@@ -57,7 +57,7 @@ def setup_static_files(file):
print(f"š Processing file: {file.filepath}")
# Add import for StaticFiles
- file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
+ file.add_import("from fastapi.staticfiles import StaticFiles")
print("ā
Added import: from fastapi.staticfiles import StaticFiles")
# Add app.mount for static file handling
diff --git a/codegen-examples/examples/sqlalchemy_soft_delete/README.md b/codegen-examples/examples/sqlalchemy_soft_delete/README.md
index 9c2bec6ec..b8c9a22db 100644
--- a/codegen-examples/examples/sqlalchemy_soft_delete/README.md
+++ b/codegen-examples/examples/sqlalchemy_soft_delete/README.md
@@ -58,7 +58,7 @@ The codemod processes your codebase in several steps:
```python
def ensure_and_import(file):
if not any("and_" in imp.name for imp in file.imports):
- file.add_import_from_import_string("from sqlalchemy import and_")
+ file.add_import("from sqlalchemy import and_")
```
- Automatically adds required SQLAlchemy imports (`and_`)
diff --git a/codegen-examples/examples/sqlalchemy_soft_delete/run.py b/codegen-examples/examples/sqlalchemy_soft_delete/run.py
index 4090bfa32..fb248e31a 100644
--- a/codegen-examples/examples/sqlalchemy_soft_delete/run.py
+++ b/codegen-examples/examples/sqlalchemy_soft_delete/run.py
@@ -51,7 +51,7 @@ def ensure_and_import(file):
"""Ensure the file has the necessary and_ import."""
if not any("and_" in imp.name for imp in file.imports):
print(f"File {file.filepath} does not import and_. Adding import.")
- file.add_import_from_import_string("from sqlalchemy import and_")
+ file.add_import("from sqlalchemy import and_")
def clone_repo(repo_url: str, repo_path: Path) -> None:
diff --git a/codegen-examples/examples/sqlalchemy_type_annotations/run.py b/codegen-examples/examples/sqlalchemy_type_annotations/run.py
index 96574152d..fdfcf5a9a 100644
--- a/codegen-examples/examples/sqlalchemy_type_annotations/run.py
+++ b/codegen-examples/examples/sqlalchemy_type_annotations/run.py
@@ -100,16 +100,16 @@ def run(codebase: Codebase):
# Add necessary imports
if not cls.file.has_import("Mapped"):
- cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n")
+ cls.file.add_import("from sqlalchemy.orm import Mapped\n")
if "Optional" in new_type and not cls.file.has_import("Optional"):
- cls.file.add_import_from_import_string("from typing import Optional\n")
+ cls.file.add_import("from typing import Optional\n")
if "Decimal" in new_type and not cls.file.has_import("Decimal"):
- cls.file.add_import_from_import_string("from decimal import Decimal\n")
+ cls.file.add_import("from decimal import Decimal\n")
if "datetime" in new_type and not cls.file.has_import("datetime"):
- cls.file.add_import_from_import_string("from datetime import datetime\n")
+ cls.file.add_import("from datetime import datetime\n")
if class_modified:
classes_modified += 1
diff --git a/codegen-examples/examples/unittest_to_pytest/run.py b/codegen-examples/examples/unittest_to_pytest/run.py
index b4e32a55d..339b583b9 100644
--- a/codegen-examples/examples/unittest_to_pytest/run.py
+++ b/codegen-examples/examples/unittest_to_pytest/run.py
@@ -24,7 +24,7 @@ def convert_to_pytest_fixtures(file):
print(f"š Processing file: {file.filepath}")
if not any(imp.name == "pytest" for imp in file.imports):
- file.add_import_from_import_string("import pytest")
+ file.add_import("import pytest")
print(f"ā Added pytest import to {file.filepath}")
for cls in file.classes:
diff --git a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md
index 7d30ab454..4ec033802 100644
--- a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md
+++ b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/README.md
@@ -25,7 +25,7 @@ The script automates the entire migration process in a few key steps:
```python
import_str = "import { useQuery, useSuspenseQueries } from '@tanstack/react-query'"
- file.add_import_from_import_string(import_str)
+ file.add_import(import_str)
```
- Uses Codegen's import analysis to add required imports
diff --git a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py
index 392f741eb..0804c7123 100644
--- a/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py
+++ b/codegen-examples/examples/usesuspensequery_to_usesuspensequeries/run.py
@@ -26,7 +26,7 @@ def run(codebase: Codebase):
print(f"Processing {file.filepath}")
# Add the import statement
- file.add_import_from_import_string(import_str)
+ file.add_import(import_str)
file_modified = False
# Iterate through all functions in the file
diff --git a/docs/building-with-codegen/imports.mdx b/docs/building-with-codegen/imports.mdx
index 95ecff990..8a707365d 100644
--- a/docs/building-with-codegen/imports.mdx
+++ b/docs/building-with-codegen/imports.mdx
@@ -120,7 +120,7 @@ for module, imports in module_imports.items():
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
- file.add_import_from_import_string(
+ file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
diff --git a/docs/building-with-codegen/react-and-jsx.mdx b/docs/building-with-codegen/react-and-jsx.mdx
index 395a16dd6..1784c8a41 100644
--- a/docs/building-with-codegen/react-and-jsx.mdx
+++ b/docs/building-with-codegen/react-and-jsx.mdx
@@ -136,5 +136,5 @@ for function in codebase.functions:
# Add import if needed
if not file.has_import("NewComponent"):
- file.add_symbol_import(new_component)
+ file.add_import(new_component)
```
diff --git a/docs/tutorials/flask-to-fastapi.mdx b/docs/tutorials/flask-to-fastapi.mdx
index ae72e8a9f..e8076ceeb 100644
--- a/docs/tutorials/flask-to-fastapi.mdx
+++ b/docs/tutorials/flask-to-fastapi.mdx
@@ -119,7 +119,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi
```python
# Add StaticFiles import
-file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
+file.add_import("from fastapi.staticfiles import StaticFiles")
# Mount static directory
file.add_symbol_from_source(
diff --git a/docs/tutorials/modularity.mdx b/docs/tutorials/modularity.mdx
index 6923b3471..84c55835d 100644
--- a/docs/tutorials/modularity.mdx
+++ b/docs/tutorials/modularity.mdx
@@ -116,17 +116,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if third_party_imports:
for imp in third_party_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if local_imports:
for imp in local_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
# Organize imports in all files
for file in codebase.files:
diff --git a/docs/tutorials/react-modernization.mdx b/docs/tutorials/react-modernization.mdx
index a4036cbea..170999c5a 100644
--- a/docs/tutorials/react-modernization.mdx
+++ b/docs/tutorials/react-modernization.mdx
@@ -82,7 +82,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) =
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
- file.add_import_from_import_string("import { useState, useEffect } from 'react';")
+ file.add_import("import { useState, useEffect } from 'react';")
```
## Migrating to Modern Hooks
@@ -100,7 +100,7 @@ for function in codebase.functions:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
- function.file.add_import_from_import_string(
+ function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
diff --git a/src/codegen/cli/mcp/resources/system_prompt.py b/src/codegen/cli/mcp/resources/system_prompt.py
index 9535570ab..9c7e23c6b 100644
--- a/src/codegen/cli/mcp/resources/system_prompt.py
+++ b/src/codegen/cli/mcp/resources/system_prompt.py
@@ -2909,7 +2909,7 @@ def validate_data(data: dict) -> bool:
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
- file.add_import_from_import_string(
+ file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
@@ -5180,7 +5180,7 @@ def build_graph(func, depth=0):
# Add import if needed
if not file.has_import("NewComponent"):
- file.add_symbol_import(new_component)
+ file.add_import(new_component)
```
@@ -7316,17 +7316,17 @@ def organize_file_imports(file):
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if third_party_imports:
for imp in third_party_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
if local_imports:
for imp in local_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
# Organize imports in all files
for file in codebase.files:
@@ -8593,7 +8593,7 @@ class FeatureFlags:
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
- file.add_import_from_import_string("import { useState, useEffect } from 'react';")
+ file.add_import("import { useState, useEffect } from 'react';")
```
## Migrating to Modern Hooks
@@ -8611,7 +8611,7 @@ class FeatureFlags:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
- function.file.add_import_from_import_string(
+ function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
@@ -9813,7 +9813,7 @@ def create_user():
```python
# Add StaticFiles import
-file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
+file.add_import("from fastapi.staticfiles import StaticFiles")
# Mount static directory
file.add_symbol_from_source(
diff --git a/src/codegen/sdk/core/class_definition.py b/src/codegen/sdk/core/class_definition.py
index 755d46f9b..bbf2682ab 100644
--- a/src/codegen/sdk/core/class_definition.py
+++ b/src/codegen/sdk/core/class_definition.py
@@ -378,9 +378,9 @@ def add_attribute(self, attribute: Attribute, include_dependencies: bool = False
file = self.file
for d in deps:
if isinstance(d, Import):
- file.add_symbol_import(d.imported_symbol)
+ file.add_import(d.imported_symbol)
elif isinstance(d, Symbol):
- file.add_symbol_import(d)
+ file.add_import(d)
@property
@noapidoc
diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py
index b282942ae..8ad9e1385 100644
--- a/src/codegen/sdk/core/file.py
+++ b/src/codegen/sdk/core/file.py
@@ -944,62 +944,56 @@ def update_filepath(self, new_filepath: str) -> None:
imp.set_import_module(new_module_name)
@writer
- def add_symbol_import(
- self,
- symbol: Symbol,
- alias: str | None = None,
- import_type: ImportType = ImportType.UNKNOWN,
- is_type_import: bool = False,
- ) -> Import | None:
- """Adds an import to a file for a given symbol.
-
- This method adds an import statement to the file for a specified symbol. If an import for the
- symbol already exists, it returns the existing import instead of creating a new one.
+ def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None:
+ """Adds an import to the file.
- Args:
- symbol (Symbol): The symbol to import.
- alias (str | None): Optional alias for the imported symbol. Defaults to None.
- import_type (ImportType): The type of import to use. Defaults to ImportType.UNKNOWN.
- is_type_import (bool): Whether this is a type-only import. Defaults to False.
-
- Returns:
- Import | None: The existing import for the symbol or None if it was added.
- """
- imports = self.imports
- match = next((x for x in imports if x.imported_symbol == symbol), None)
- if match:
- return match
-
- import_string = symbol.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
- self.add_import_from_import_string(import_string)
-
- @writer(commit=False)
- def add_import_from_import_string(self, import_string: str) -> None:
- """Adds import to the file from a string representation of an import statement.
-
- This method adds a new import statement to the file based on its string representation.
+ This method adds an import statement to the file. It can handle both string imports and symbol imports.
If the import already exists in the file, or is pending to be added, it won't be added again.
If there are existing imports, the new import will be added before the first import,
otherwise it will be added at the beginning of the file.
Args:
- import_string (str): The string representation of the import statement to add.
+ imp (Symbol | str): Either a Symbol to import or a string representation of an import statement.
+ alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None.
+ import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN.
+ is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False.
Returns:
- None
+ Import | None: The existing import for the symbol if found, otherwise None.
"""
- if any(import_string.strip() in imp.source for imp in self.imports):
- return
+ # Handle Symbol imports
+ if isinstance(imp, str):
+ # Handle string imports
+ import_string = imp
+ # Check for duplicate imports
+ if any(import_string.strip() in imp.source for imp in self.imports):
+ return None
+ else:
+ # Check for existing imports of this symbol
+ imports = self.imports
+ match = next((x for x in imports if x.imported_symbol == imp), None)
+ if match:
+ return match
+
+ # Convert symbol to import string
+ import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
+
if import_string.strip() in self._pending_imports:
# Don't add the import string if it will already be added by another symbol
- return
+ return None
+
+ # Add to pending imports and setup undo
self._pending_imports.add(import_string.strip())
self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear())
+
+ # Insert the import at the appropriate location
if self.imports:
self.imports[0].insert_before(import_string, priority=1)
else:
self.insert_before(import_string, priority=1)
+ return None
+
@writer
def add_symbol_from_source(self, source: str) -> None:
"""Adds a symbol to a file from a string representation.
diff --git a/src/codegen/sdk/core/symbol.py b/src/codegen/sdk/core/symbol.py
index 559a5cd58..cc0238b45 100644
--- a/src/codegen/sdk/core/symbol.py
+++ b/src/codegen/sdk/core/symbol.py
@@ -329,19 +329,19 @@ def _move_to_file(
# =====[ Imports - copy over ]=====
elif isinstance(dep, Import):
if dep.imported_symbol:
- file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
+ file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
else:
- file.add_import_from_import_string(dep.source)
+ file.add_import(imp=dep.source)
else:
for dep in self.dependencies:
# =====[ Symbols - add back edge ]=====
if isinstance(dep, Symbol) and dep.is_top_level:
- file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
+ file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=False)
elif isinstance(dep, Import):
if dep.imported_symbol:
- file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source)
+ file.add_import(imp=dep.imported_symbol, alias=dep.alias.source)
else:
- file.add_import_from_import_string(dep.source)
+ file.add_import(imp=dep.source)
# =====[ Make a new symbol in the new file ]=====
file.add_symbol(self)
@@ -364,7 +364,7 @@ def _move_to_file(
# Here, we will add a "back edge" to the old file importing the symbol
elif strategy == "add_back_edge":
if is_used_in_file or any(usage.kind is UsageKind.IMPORTED and usage.usage_symbol not in encountered_symbols for usage in self.usages):
- self.file.add_import_from_import_string(import_line)
+ self.file.add_import(imp=import_line)
# Delete the original symbol
self.remove()
@@ -374,7 +374,7 @@ def _move_to_file(
for usage in self.usages:
if isinstance(usage.usage_symbol, Import) and usage.usage_symbol.file != file:
# Add updated import
- usage.usage_symbol.file.add_import_from_import_string(import_line)
+ usage.usage_symbol.file.add_import(import_line)
usage.usage_symbol.remove()
elif usage.usage_type == UsageType.CHAINED:
# Update all previous usages of import * to the new import name
@@ -383,11 +383,11 @@ def _move_to_file(
usage.match.get_name().edit(self.name)
if isinstance(usage.match, ChainedAttribute):
usage.match.edit(self.name)
- usage.usage_symbol.file.add_import_from_import_string(import_line)
+ usage.usage_symbol.file.add_import(imp=import_line)
# Add the import to the original file
if is_used_in_file:
- self.file.add_import_from_import_string(import_line)
+ self.file.add_import(imp=import_line)
# Delete the original symbol
self.remove()
diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py
index 3c92feaef..3b1fc9f93 100644
--- a/src/codegen/sdk/python/file.py
+++ b/src/codegen/sdk/python/file.py
@@ -5,6 +5,7 @@
from codegen.sdk.core.autocommit import reader, writer
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interface import Interface
+from codegen.sdk.core.symbol import Symbol
from codegen.sdk.enums import ImportType
from codegen.sdk.extensions.utils import cached_property
from codegen.sdk.python import PyAssignment
@@ -20,7 +21,7 @@
if TYPE_CHECKING:
from codegen.sdk.codebase.codebase_context import CodebaseContext
- from codegen.sdk.core.import_resolution import WildcardImport
+ from codegen.sdk.core.import_resolution import Import, WildcardImport
from codegen.sdk.python.symbol import PySymbol
@@ -119,7 +120,7 @@ def get_import_insert_index(self, import_string) -> int | None:
The function determines the optimal position for inserting a new import statement, following Python's import ordering conventions.
Future imports are placed at the top of the file, followed by all other imports.
- Args:
+ Args:z
import_string (str): The import statement to be inserted.
Returns:
@@ -146,28 +147,57 @@ def get_import_insert_index(self, import_string) -> int | None:
####################################################################################################################
@writer
- def add_import_from_import_string(self, import_string: str) -> None:
- """Adds an import statement to the file from a string representation.
+ def add_import(self, imp: Symbol | str, *, alias: str | None = None, import_type: ImportType = ImportType.UNKNOWN, is_type_import: bool = False) -> Import | None:
+ """Adds an import to the file.
- This method adds a new import statement to the file, handling placement based on existing imports.
- Future imports are placed at the top of the file, followed by regular imports.
+ This method adds an import statement to the file. It can handle both string imports and symbol imports.
+ If the import already exists in the file, or is pending to be added, it won't be added again.
+ Future imports are placed at the top, followed by regular imports.
Args:
- import_string (str): The string representation of the import statement to add (e.g., 'from module import symbol').
+ imp (Symbol | str): Either a Symbol to import or a string representation of an import statement.
+ alias (str | None): Optional alias for the imported symbol. Only used when imp is a Symbol. Defaults to None.
+ import_type (ImportType): The type of import to use. Only used when imp is a Symbol. Defaults to ImportType.UNKNOWN.
+ is_type_import (bool): Whether this is a type-only import. Only used when imp is a Symbol. Defaults to False.
Returns:
- None: This function modifies the file in place.
+ Import | None: The existing import for the symbol if found, otherwise None.
"""
+ # Handle Symbol imports
+ if isinstance(imp, Symbol):
+ imports = self.imports
+ match = next((x for x in imports if x.imported_symbol == imp), None)
+ if match:
+ return match
+
+ # Convert symbol to import string
+ import_string = imp.get_import_string(alias, import_type=import_type, is_type_import=is_type_import)
+ else:
+ # Handle string imports
+ import_string = str(imp)
+
+ # Check for duplicate imports
+ if any(import_string.strip() in str(imp.source) for imp in self.imports):
+ return None
+ if import_string.strip() in self._pending_imports:
+ return None
+
+ # Add to pending imports
+ self._pending_imports.add(import_string.strip())
+ self.transaction_manager.pending_undos.add(lambda: self._pending_imports.clear())
+
+ # Insert at correct location
if self.imports:
import_insert_index = self.get_import_insert_index(import_string) or 0
if import_insert_index < len(self.imports):
self.imports[import_insert_index].insert_before(import_string, priority=1)
else:
- # If import_insert_index is out of bounds, do insert after the last import
self.imports[-1].insert_after(import_string, priority=1)
else:
self.insert_before(import_string, priority=1)
+ return None
+
@noapidoc
def remove_unused_exports(self) -> None:
"""Removes unused exports from the file. NO-OP for python"""
diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt
index 0a0cff0e9..ad5007004 100644
--- a/src/codegen/sdk/system-prompt.txt
+++ b/src/codegen/sdk/system-prompt.txt
@@ -29,7 +29,7 @@ codebase.commit()
-Codegen handles complex refactors while maintaining correctness, enabling a broad set of advanced code manipulation programs.
+Codegen handles complex refactors while maintaining correctness, enabling a broad set of advanced code manipulation programs.
Codegen works with both Python and Typescript/JSX codebases. Learn more about language support [here](/building-with-codegen/language-support).
@@ -492,7 +492,7 @@ Let's walk through a minimal example of using Codegen in a project:
```bash
codegen init
```
-
+
This creates a `.codegen/` directory with:
```bash
.codegen/
@@ -560,7 +560,7 @@ Let's walk through a minimal example of using Codegen in a project:
For more help, join our [community Slack](/introduction/community) or check the [FAQ](/introduction/faq).
-
+
---
title: "Using Codegen in Your IDE"
@@ -589,7 +589,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your
```bash
.codegen/.venv/bin/python
```
-
+
Alternatively, create a `.vscode/settings.json`:
```json
{
@@ -611,7 +611,7 @@ Codegen creates a custom Python environment in `.codegen/.venv`. Configure your
.codegen/.venv/bin/python
```
-
+
@@ -1156,8 +1156,8 @@ iconType: "solid"
- Yes - [by design](/introduction/guiding-principles#python-first-composability).
-
+ Yes - [by design](/introduction/guiding-principles#python-first-composability).
+
Codegen works like any other python package. It works alongside your IDE, version control system, and other development tools.
- Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects.
+ Currently, the codebase object can only parse source code files of one language at a time. This means that if you want to work with both Python and TypeScript files, you will need to create two separate codebase objects.
## Accessing Code
@@ -2923,7 +2923,7 @@ for module, imports in module_imports.items():
if len(imports) > 1:
# Create combined import
symbols = [imp.name for imp in imports]
- file.add_import_from_import_string(
+ file.add_import(
f"import {{ {', '.join(symbols)} }} from '{module}'"
)
# Remove old imports
@@ -2933,7 +2933,7 @@ for module, imports in module_imports.items():
Always check if imports resolve to external modules before modification to avoid breaking third-party package imports.
-
+
## Import Statements vs Imports
@@ -3135,7 +3135,7 @@ for exp in file.exports:
# Get original and current symbols
current = exp.exported_symbol
original = exp.resolved_symbol
-
+
print(f"Re-exporting {original.name} from {exp.from_file.filepath}")
print(f"Through: {' -> '.join(e.file.filepath for e in exp.export_chain)}")
```
@@ -3185,7 +3185,7 @@ for from_file, exports in file_exports.items():
When managing exports, consider the impact on your module's public API. Not all symbols that can be exported should be exported.
-
+
---
title: "Inheritable Behaviors"
@@ -3675,9 +3675,9 @@ If `A` depends on `B`, then `B` is used by `A`. This relationship is tracked in
flowchart LR
B(BaseClass)
-
-
-
+
+
+
A(MyClass)
B ---| used by |A
A ---|depends on |B
@@ -3846,7 +3846,7 @@ class A:
def method_a(self): pass
class B(A):
- def method_b(self):
+ def method_b(self):
self.method_a()
class C(B):
@@ -4736,7 +4736,7 @@ for attr in class_def.attributes:
# Each attribute has an assignment property
attr_type = attr.assignment.type # -> TypeAnnotation
print(f"{attr.name}: {attr_type.source}") # e.g. "x: int"
-
+
# Set attribute type
attr.assignment.set_type("int")
@@ -4753,7 +4753,7 @@ Union types ([UnionType](/api-reference/core/UnionType)) can be manipulated as c
```python
# Get union type
-union_type = function.return_type # -> A | B
+union_type = function.return_type # -> A | B
print(union_type.symbols) # ["A", "B"]
# Add/remove options
@@ -5271,7 +5271,7 @@ for function in codebase.functions:
# Add import if needed
if not file.has_import("NewComponent"):
- file.add_symbol_import(new_component)
+ file.add_import(new_component)
```
@@ -5604,13 +5604,13 @@ Here's an example of using flags during code analysis:
```python
def analyze_codebase(codebase):
- for function in codebase.functions:
+ for function in codebase.functions:
# Check documentation
if not function.docstring:
function.flag(
message="Missing docstring",
)
-
+
# Check error handling
if function.is_async and not function.has_try_catch:
function.flag(
@@ -6320,7 +6320,7 @@ Explore our tutorials to learn how to use Codegen for various code transformatio
>
Update API calls, handle breaking changes, and manage bulk updates across your codebase.
-
Convert Flask applications to FastAPI, updating routes and dependencies.
-
Migrate Python 2 code to Python 3, updating syntax and modernizing APIs.
@@ -6353,9 +6353,9 @@ Explore our tutorials to learn how to use Codegen for various code transformatio
>
Restructure files, enforce naming conventions, and improve project layout.
-
Split large files, extract shared logic, and manage dependencies.
@@ -6453,7 +6453,7 @@ The agent has access to powerful code viewing and manipulation tools powered by
- `CreateFileTool`: Create new files
- `DeleteFileTool`: Delete files
- `RenameFileTool`: Rename files
-- `EditFileTool`: Edit files
+- `EditFileTool`: Edit files
@@ -6960,7 +6960,6 @@ Be explicit about the changes, produce a short summary, and point out possible i
Focus on facts and technical details, using code snippets where helpful.
"""
result = agent.run(prompt)
-
# Clean up the temporary comment
comment.delete()
```
@@ -7046,7 +7045,7 @@ While this example demonstrates a basic PR review bot, you can extend it to:
>
Understand code review patterns and best practices.
-
+
---
title: "Deep Code Research with AI"
@@ -7174,21 +7173,21 @@ def research(repo_name: Optional[str] = None, query: Optional[str] = None):
"""Start a code research session."""
# Initialize codebase
codebase = initialize_codebase(repo_name)
-
+
# Create and run the agent
agent = create_research_agent(codebase)
-
+
# Main research loop
while True:
if not query:
query = Prompt.ask("[bold cyan]Research query[/bold cyan]")
-
+
result = agent.invoke(
{"input": query},
config={"configurable": {"thread_id": 1}}
)
console.print(Markdown(result["messages"][-1].content))
-
+
query = None # Clear for next iteration
```
@@ -7236,7 +7235,7 @@ class CustomAnalysisTool(BaseTool):
"""Custom tool for specialized code analysis."""
name = "custom_analysis"
description = "Performs specialized code analysis"
-
+
def _run(self, query: str) -> str:
# Custom analysis logic
return results
@@ -7514,7 +7513,7 @@ from codegen import Codebase
# Initialize codebase
codebase = Codebase("path/to/posthog/")
-# Create a directed graph for representing call relationships
+# Create a directed graph for representing call relationships
G = nx.DiGraph()
# Configuration flags
@@ -7536,7 +7535,7 @@ We'll create a function that will recursively traverse the call trace of a funct
```python
def create_downstream_call_trace(src_func: Function, depth: int = 0):
"""Creates call graph by recursively traversing function calls
-
+
Args:
src_func (Function): Starting function for call graph
depth (int): Current recursion depth
@@ -7544,7 +7543,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
# Prevent infinite recursion
if MAX_DEPTH <= depth:
return
-
+
# External modules are not functions
if isinstance(src_func, ExternalModule):
return
@@ -7554,12 +7553,12 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
# Skip self-recursive calls
if call.name == src_func.name:
continue
-
+
# Get called function definition
func = call.function_definition
if not func:
continue
-
+
# Apply configured filters
if isinstance(func, ExternalModule) and IGNORE_EXTERNAL_MODULE_CALLS:
continue
@@ -7573,7 +7572,7 @@ def create_downstream_call_trace(src_func: Function, depth: int = 0):
func_name = f"{func.parent_class.name}.{func.name}" if func.is_method else func.name
# Add node and edge with metadata
- G.add_node(func, name=func_name,
+ G.add_node(func, name=func_name,
color=COLOR_PALETTE.get(func.__class__.__name__))
G.add_edge(src_func, func, **generate_edge_meta(call))
@@ -7588,10 +7587,10 @@ We can enrich our edges with metadata about the function calls:
```python
def generate_edge_meta(call: FunctionCall) -> dict:
"""Generate metadata for call graph edges
-
+
Args:
call (FunctionCall): Function call information
-
+
Returns:
dict: Edge metadata including name and location
"""
@@ -7610,8 +7609,8 @@ Finally, we can visualize our call graph starting from a specific function:
target_class = codebase.get_class('SharingConfigurationViewSet')
target_method = target_class.get_method('patch')
-# Add root node
-G.add_node(target_method,
+# Add root node
+G.add_node(target_method,
name=f"{target_class.name}.{target_method.name}",
color=COLOR_PALETTE["StartFunction"])
@@ -7661,7 +7660,7 @@ The core function for building our dependency graph:
```python
def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
"""Creates visualization of symbol dependencies
-
+
Args:
symbol (Symbol): Starting symbol to analyze
depth (int): Current recursion depth
@@ -7669,11 +7668,11 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
# Prevent excessive recursion
if depth >= MAX_DEPTH:
return
-
+
# Process each dependency
for dep in symbol.dependencies:
dep_symbol = None
-
+
# Handle different dependency types
if isinstance(dep, Symbol):
# Direct symbol reference
@@ -7684,13 +7683,13 @@ def create_dependencies_visualization(symbol: Symbol, depth: int = 0):
if dep_symbol:
# Add node with appropriate styling
- G.add_node(dep_symbol,
- color=COLOR_PALETTE.get(dep_symbol.__class__.__name__,
+ G.add_node(dep_symbol,
+ color=COLOR_PALETTE.get(dep_symbol.__class__.__name__,
"#f694ff"))
-
+
# Add dependency relationship
G.add_edge(symbol, dep_symbol)
-
+
# Recurse unless it's a class (avoid complexity)
if not isinstance(dep_symbol, PyClass):
create_dependencies_visualization(dep_symbol, depth + 1)
@@ -7702,7 +7701,7 @@ Finally, we can visualize our dependency graph starting from a specific symbol:
# Get target symbol
target_func = codebase.get_function("get_query_runner")
-# Add root node
+# Add root node
G.add_node(target_func, color=COLOR_PALETTE["StartFunction"])
# Generate dependency graph
@@ -7745,16 +7744,16 @@ HTTP_METHODS = ["get", "put", "patch", "post", "head", "delete"]
def generate_edge_meta(usage: Usage) -> dict:
"""Generate metadata for graph edges
-
+
Args:
usage (Usage): Usage relationship information
-
+
Returns:
dict: Edge metadata including name and location
"""
return {
"name": usage.match.source,
- "file_path": usage.match.filepath,
+ "file_path": usage.match.filepath,
"start_point": usage.match.start_point,
"end_point": usage.match.end_point,
"symbol_name": usage.match.__class__.__name__
@@ -7762,10 +7761,10 @@ def generate_edge_meta(usage: Usage) -> dict:
def is_http_method(symbol: PySymbol) -> bool:
"""Check if a symbol is an HTTP endpoint method
-
+
Args:
symbol (PySymbol): Symbol to check
-
+
Returns:
bool: True if symbol is an HTTP method
"""
@@ -7779,7 +7778,7 @@ The main function for creating our blast radius visualization:
```python
def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
"""Create visualization of symbol usage relationships
-
+
Args:
symbol (PySymbol): Starting symbol to analyze
depth (int): Current recursion depth
@@ -7787,11 +7786,11 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
# Prevent excessive recursion
if depth >= MAX_DEPTH:
return
-
+
# Process each usage of the symbol
for usage in symbol.usages:
usage_symbol = usage.usage_symbol
-
+
# Determine node color based on type
if is_http_method(usage_symbol):
color = COLOR_PALETTE.get("HTTP_METHOD")
@@ -7801,7 +7800,7 @@ def create_blast_radius_visualization(symbol: PySymbol, depth: int = 0):
# Add node and edge to graph
G.add_node(usage_symbol, color=color)
G.add_edge(symbol, usage_symbol, **generate_edge_meta(usage))
-
+
# Recursively process usage symbol
create_blast_radius_visualization(usage_symbol, depth + 1)
```
@@ -7952,7 +7951,7 @@ for call in old_api.call_sites:
f"data={call.get_arg_by_parameter_name('input').value}",
f"timeout={call.get_arg_by_parameter_name('wait').value}"
]
-
+
# Replace the old call with the new API
call.replace(f"new_process_data({', '.join(args)})")
```
@@ -7966,10 +7965,10 @@ When updating chained method calls, like database queries or builder patterns:
for execute_call in codebase.function_calls:
if execute_call.name != "execute":
continue
-
+
# Get the full chain
chain = execute_call.call_chain
-
+
# Example: Add .timeout() before .execute()
if "timeout" not in {call.name for call in chain}:
execute_call.insert_before("timeout(30)")
@@ -7988,45 +7987,45 @@ Here's a comprehensive example:
```python
def migrate_api_v1_to_v2(codebase):
old_api = codebase.get_function("create_user_v1")
-
+
# Document all existing call patterns
call_patterns = {}
for call in old_api.call_sites:
args = [arg.source for arg in call.args]
pattern = ", ".join(args)
call_patterns[pattern] = call_patterns.get(pattern, 0) + 1
-
+
print("Found call patterns:")
for pattern, count in call_patterns.items():
print(f" {pattern}: {count} occurrences")
-
+
# Create new API version
new_api = old_api.copy()
new_api.rename("create_user_v2")
-
+
# Update parameter types
new_api.get_parameter("email").type = "EmailStr"
new_api.get_parameter("role").type = "UserRole"
-
+
# Add new required parameters
new_api.add_parameter("tenant_id: UUID")
-
+
# Update all call sites
for call in old_api.call_sites:
# Get current arguments
email_arg = call.get_arg_by_parameter_name("email")
role_arg = call.get_arg_by_parameter_name("role")
-
+
# Build new argument list with type conversions
new_args = [
f"email=EmailStr({email_arg.value})",
f"role=UserRole({role_arg.value})",
"tenant_id=get_current_tenant_id()"
]
-
+
# Replace old call with new version
call.replace(f"create_user_v2({', '.join(new_args)})")
-
+
# Add deprecation notice to old version
old_api.add_decorator('@deprecated("Use create_user_v2 instead")')
@@ -8048,10 +8047,10 @@ migrate_api_v1_to_v2(codebase)
```python
# First update parameter names
param.rename("new_name")
-
+
# Then update types
param.type = "new_type"
-
+
# Finally update call sites
for call in api.call_sites:
# ... update calls
@@ -8061,7 +8060,7 @@ migrate_api_v1_to_v2(codebase)
```python
# Add new parameter with default
api.add_parameter("new_param: str = None")
-
+
# Later make it required
api.get_parameter("new_param").remove_default()
```
@@ -8076,7 +8075,7 @@ migrate_api_v1_to_v2(codebase)
Remember to test thoroughly after making bulk changes to APIs. While Codegen ensures syntactic correctness, you'll want to verify the semantic correctness of the changes.
-
+
---
title: "Organizing Your Codebase"
@@ -8640,16 +8639,16 @@ from collections import defaultdict
# Create a graph of file dependencies
def create_dependency_graph():
G = nx.DiGraph()
-
+
for file in codebase.files:
# Add node for this file
G.add_node(file.filepath)
-
+
# Add edges for each import
for imp in file.imports:
if imp.from_file: # Skip external imports
G.add_edge(file.filepath, imp.from_file.filepath)
-
+
return G
# Create and analyze the graph
@@ -8678,18 +8677,18 @@ def break_circular_dependency(cycle):
# Get the first two files in the cycle
file1 = codebase.get_file(cycle[0])
file2 = codebase.get_file(cycle[1])
-
+
# Create a shared module for common code
shared_dir = "shared"
if not codebase.has_directory(shared_dir):
codebase.create_directory(shared_dir)
-
+
# Find symbols used by both files
shared_symbols = []
for symbol in file1.symbols:
if any(usage.file == file2 for usage in symbol.usages):
shared_symbols.append(symbol)
-
+
# Move shared symbols to a new file
if shared_symbols:
shared_file = codebase.create_file(f"{shared_dir}/shared_types.py")
@@ -8711,7 +8710,7 @@ def organize_file_imports(file):
std_lib_imports = []
third_party_imports = []
local_imports = []
-
+
for imp in file.imports:
if imp.is_standard_library:
std_lib_imports.append(imp)
@@ -8719,29 +8718,29 @@ def organize_file_imports(file):
third_party_imports.append(imp)
else:
local_imports.append(imp)
-
+
# Sort each group
for group in [std_lib_imports, third_party_imports, local_imports]:
group.sort(key=lambda x: x.module_name)
-
+
# Remove all existing imports
for imp in file.imports:
imp.remove()
-
+
# Add imports back in organized groups
if std_lib_imports:
for imp in std_lib_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
-
+
if third_party_imports:
for imp in third_party_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
file.insert_after_imports("") # Add newline
-
+
if local_imports:
for imp in local_imports:
- file.add_import_from_import_string(imp.source)
+ file.add_import(imp.source)
# Organize imports in all files
for file in codebase.files:
@@ -8757,22 +8756,22 @@ from collections import defaultdict
def analyze_module_coupling():
coupling_scores = defaultdict(int)
-
+
for file in codebase.files:
# Count unique files imported from
imported_files = {imp.from_file for imp in file.imports if imp.from_file}
coupling_scores[file.filepath] = len(imported_files)
-
+
# Count files that import this file
- importing_files = {usage.file for symbol in file.symbols
+ importing_files = {usage.file for symbol in file.symbols
for usage in symbol.usages if usage.file != file}
coupling_scores[file.filepath] += len(importing_files)
-
+
# Sort by coupling score
- sorted_files = sorted(coupling_scores.items(),
- key=lambda x: x[1],
+ sorted_files = sorted(coupling_scores.items(),
+ key=lambda x: x[1],
reverse=True)
-
+
print("\nš Module Coupling Analysis:")
print("\nMost coupled files:")
for filepath, score in sorted_files[:5]:
@@ -8790,9 +8789,9 @@ def extract_shared_code(file, min_usages=3):
# Find symbols used by multiple files
for symbol in file.symbols:
# Get unique files using this symbol
- using_files = {usage.file for usage in symbol.usages
+ using_files = {usage.file for usage in symbol.usages
if usage.file != file}
-
+
if len(using_files) >= min_usages:
# Create appropriate shared module
module_name = determine_shared_module(symbol)
@@ -8800,7 +8799,7 @@ def extract_shared_code(file, min_usages=3):
shared_file = codebase.create_file(f"shared/{module_name}.py")
else:
shared_file = codebase.get_file(f"shared/{module_name}.py")
-
+
# Move symbol to shared module
symbol.move_to_file(shared_file, strategy="update_all_imports")
@@ -8854,7 +8853,7 @@ if feature_flag_class:
# Initialize usage count for all attributes
for attr in feature_flag_class.attributes:
feature_flag_usage[attr.name] = 0
-
+
# Get all usages of the FeatureFlag class
for usage in feature_flag_class.usages:
usage_source = usage.usage_symbol.source if hasattr(usage, 'usage_symbol') else str(usage)
@@ -9599,7 +9598,7 @@ Let's break down how this works:
if export.is_reexport() and export.is_default_export():
print(f" š Converting default export '{export.name}'")
```
-
+
The code identifies default exports by checking:
1. If it's a re-export (`is_reexport()`)
2. If it's a default export (`is_default_export()`)
@@ -9707,7 +9706,7 @@ for file in codebase.files:
print(f"⨠Fixed exports in {target_file.filepath}")
-```
+```
---
title: "Creating Documentation"
@@ -9796,11 +9795,11 @@ for directory in codebase.directories:
# Skip test, sql and alembic directories
if any(x in directory.path.lower() for x in ['test', 'sql', 'alembic']):
continue
-
+
# Get undecorated functions
funcs = [f for f in directory.functions if not f.is_decorated]
total = len(funcs)
-
+
# Only analyze dirs with >10 functions
if total > 10:
documented = sum(1 for f in funcs if f.docstring)
@@ -9815,12 +9814,12 @@ for directory in codebase.directories:
if dir_stats:
lowest_dir = min(dir_stats.items(), key=lambda x: x[1]['coverage'])
path, stats = lowest_dir
-
+
print(f"š Lowest coverage directory: '{path}'")
print(f" ⢠Total functions: {stats['total']}")
print(f" ⢠Documented: {stats['documented']}")
print(f" ⢠Coverage: {stats['coverage']:.1f}%")
-
+
# Print all directory stats for comparison
print("\nš All directory coverage rates:")
for path, stats in sorted(dir_stats.items(), key=lambda x: x[1]['coverage']):
@@ -10008,7 +10007,7 @@ const {class_def.name} = ({class_def.get_method("render").parameters[0].name}) =
# Add required imports
file = class_def.file
if not any("useState" in imp.source for imp in file.imports):
- file.add_import_from_import_string("import { useState, useEffect } from 'react';")
+ file.add_import("import { useState, useEffect } from 'react';")
```
## Migrating to Modern Hooks
@@ -10026,7 +10025,7 @@ for function in codebase.functions:
# Convert withRouter to useNavigate
if call.name == "withRouter":
# Add useNavigate import
- function.file.add_import_from_import_string(
+ function.file.add_import(
"import { useNavigate } from 'react-router-dom';"
)
# Add navigate hook
@@ -10608,7 +10607,7 @@ iconType: "solid"
-Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.
+Import loops occur when two or more Python modules depend on each other, creating a circular dependency. While some import cycles can be harmless, others can lead to runtime errors and make code harder to maintain.
In this tutorial, we'll explore how to identify and fix problematic import cycles using Codegen.
@@ -11244,7 +11243,7 @@ FastAPI handles static files differently than Flask. We need to add the StaticFi
```python
# Add StaticFiles import
-file.add_import_from_import_string("from fastapi.staticfiles import StaticFiles")
+file.add_import("from fastapi.staticfiles import StaticFiles")
# Mount static directory
file.add_symbol_from_source(
@@ -11505,10 +11504,10 @@ Match (s: Func )-[r: CALLS]-> (e:Func) RETURN s, e LIMIT 10
```cypher
Match path = (:(Method|Func)) -[:CALLS*5..10]-> (:(Method|Func))
-Return path
+Return path
LIMIT 20
```
-
\ No newline at end of file
+
diff --git a/src/codegen/sdk/typescript/symbol.py b/src/codegen/sdk/typescript/symbol.py
index fc41d1ee7..e3cc89828 100644
--- a/src/codegen/sdk/typescript/symbol.py
+++ b/src/codegen/sdk/typescript/symbol.py
@@ -283,9 +283,9 @@ def _move_to_file(
# =====[ Imports - copy over ]=====
elif isinstance(dep, TSImport):
if dep.imported_symbol:
- file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type)
+ file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type)
else:
- file.add_import_from_import_string(dep.source)
+ file.add_import(dep.source)
else:
msg = f"Unknown dependency type {type(dep)}"
@@ -301,7 +301,7 @@ def _move_to_file(
# =====[ Symbols - move over ]=====
elif isinstance(dep, Symbol) and dep.is_top_level:
- file.add_symbol_import(symbol=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias))
+ file.add_import(imp=dep, alias=dep.name, import_type=ImportType.NAMED_EXPORT, is_type_import=isinstance(dep, TypeAlias))
if not dep.is_exported:
dep.file.add_export_to_symbol(dep)
@@ -310,9 +310,9 @@ def _move_to_file(
# =====[ Imports - copy over ]=====
elif isinstance(dep, TSImport):
if dep.imported_symbol:
- file.add_symbol_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import())
+ file.add_import(dep.imported_symbol, alias=dep.alias.source, import_type=dep.import_type, is_type_import=dep.is_type_import())
else:
- file.add_import_from_import_string(dep.source)
+ file.add_import(dep.source)
except Exception as e:
print(f"Failed to move dependencies of {self.name}: {e}")
@@ -336,12 +336,12 @@ def _move_to_file(
# Here, we will add a "back edge" to the old file importing the self
elif strategy == "add_back_edge":
if is_used_in_file:
- self.file.add_import_from_import_string(import_line)
+ self.file.add_import(import_line)
if self.is_exported:
- self.file.add_import_from_import_string(f"export {{ {self.name} }}")
+ self.file.add_import(f"export {{ {self.name} }}")
elif self.is_exported:
module_name = file.name
- self.file.add_import_from_import_string(f"export {{ {self.name} }} from '{module_name}'")
+ self.file.add_import(f"export {{ {self.name} }} from '{module_name}'")
# Delete the original symbol
self.remove()
@@ -352,7 +352,7 @@ def _move_to_file(
if isinstance(usage.usage_symbol, TSImport):
# Add updated import
if usage.usage_symbol.resolved_symbol is not None and usage.usage_symbol.resolved_symbol.node_type == NodeType.SYMBOL and usage.usage_symbol.resolved_symbol == self:
- usage.usage_symbol.file.add_import_from_import_string(import_line)
+ usage.usage_symbol.file.add_import(import_line)
usage.usage_symbol.remove()
elif usage.usage_type == UsageType.CHAINED:
# Update all previous usages of import * to the new import name
@@ -361,9 +361,9 @@ def _move_to_file(
usage.match.get_name().edit(self.name)
if isinstance(usage.match, ChainedAttribute):
usage.match.edit(self.name)
- usage.usage_symbol.file.add_import_from_import_string(import_line)
+ usage.usage_symbol.file.add_import(import_line)
if is_used_in_file:
- self.file.add_import_from_import_string(import_line)
+ self.file.add_import(import_line)
# Delete the original symbol
self.remove()
@@ -377,7 +377,7 @@ def _convert_proptype_to_typescript(self, prop_type: Editable, param: Parameter
if prop_type.attribute.source == "node":
return "T"
if prop_type.attribute.source == "element":
- self.file.add_import_from_import_string("import React from 'react';\n")
+ self.file.add_import("import React from 'react';\n")
return "React.ReactElement"
if prop_type.attribute.source in type_map:
return type_map[prop_type.attribute.source]
@@ -476,7 +476,7 @@ def convert_to_react_interface(self) -> str | None:
if "PropTypes.node" in proptypes.source:
generics = ""
generic_name = ""
- self.file.add_import_from_import_string("import React from 'react';\n")
+ self.file.add_import("import React from 'react';\n")
interface_name = f"{component_name}Props"
# Create interface definition
interface_def = f"interface {interface_name}{generics} {self._convert_dict(proptypes, 1)}"
diff --git a/src/codegen/sdk/utils.py b/src/codegen/sdk/utils.py
index 913782f1f..7476e6e8a 100644
--- a/src/codegen/sdk/utils.py
+++ b/src/codegen/sdk/utils.py
@@ -104,8 +104,6 @@ def find_import_node(node: TSNode) -> TSNode | None:
# we only parse imports inside expressions and variable declarations
- # import_nodes = [_node for _node in find_all_descendants(node, ["call_expression", "statement_block"], nested=False) if _node.type == "call_expression"]
-
if member_expression := find_first_descendant(node, ["member_expression"]):
# there may be multiple call expressions (for cases such as import(a).then(module => module).then(module => module)
descendants = find_all_descendants(member_expression, ["call_expression"], stop_at_first="statement_block")
diff --git a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py b/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py
index cfeb7ca5f..5cb33414b 100644
--- a/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py
+++ b/src/codemods/canonical/add_function_parameter_type_annotations/add_function_parameter_type_annotations.py
@@ -48,4 +48,4 @@ def execute(self, codebase: Codebase) -> None:
# Ensure the necessary import is present
file = function.file
if "SessionLocal" not in [imp.name for imp in file.imports]:
- file.add_import_from_import_string("from app.db import SessionLocal")
+ file.add_import("from app.db import SessionLocal")
diff --git a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py
index caa40b799..de97853a6 100644
--- a/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py
+++ b/src/codemods/canonical/change_component_tag_names/change_component_tag_names.py
@@ -56,4 +56,4 @@ def execute(self, codebase: Codebase):
element.set_name("PrivateRoutesContainer")
# Add the import if it doesn't exist
if not file.has_import("PrivateRoutesContainer"):
- file.add_symbol_import(PrivateRoutesContainer)
+ file.add_import(PrivateRoutesContainer)
diff --git a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py b/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py
index afc67d40d..a0fecf515 100644
--- a/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py
+++ b/src/codemods/canonical/convert_attribute_to_decorator/convert_attribute_to_decorator.py
@@ -29,7 +29,7 @@ class MySession(SessionInterface):
...
That is, it deletes the attribute and adds the appropriate decorator via the `cls.add_decorator` method.
- Note that `cls.file.add_import_from_import_string(import_str)` is the method used to add import for the decorator.
+ Note that `cls.file.add_import(import_str)` is the method used to add import for the decorator.
"""
language = ProgrammingLanguage.PYTHON
@@ -51,7 +51,7 @@ def execute(self, codebase: Codebase) -> None:
decorator_name = attr_value_to_decorator[attribute.right.source]
# Import the necessary decorators
required_import = f"from src.flask.sessions import {decorator_name}"
- cls.file.add_import_from_import_string(required_import)
+ cls.file.add_import(required_import)
# Add the appropriate decorator
cls.add_decorator(f"@{decorator_name}")
diff --git a/src/codemods/canonical/pivot_return_types/pivot_return_types.py b/src/codemods/canonical/pivot_return_types/pivot_return_types.py
index 367e40ad2..aeb1cdee8 100644
--- a/src/codemods/canonical/pivot_return_types/pivot_return_types.py
+++ b/src/codemods/canonical/pivot_return_types/pivot_return_types.py
@@ -41,7 +41,7 @@ def execute(self, codebase: Codebase) -> None:
function.set_return_type("FastStr")
# Add import for 'FastStr' if it doesn't exist
- function.file.add_import_from_import_string("from app.models.fast_str import FastStr")
+ function.file.add_import("from app.models.fast_str import FastStr")
# Modify all return statements within the function
for return_stmt in function.code_block.return_statements:
diff --git a/src/codemods/canonical/split_large_files/split_large_files.py b/src/codemods/canonical/split_large_files/split_large_files.py
index c1e3e295e..33f846421 100644
--- a/src/codemods/canonical/split_large_files/split_large_files.py
+++ b/src/codemods/canonical/split_large_files/split_large_files.py
@@ -46,4 +46,4 @@ def execute(self, codebase: Codebase):
# Move the symbol to the new file
symbol.move_to_file(new_file)
# Add a back edge to the original file
- file.add_symbol_import(symbol)
+ file.add_import(symbol)
diff --git a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py b/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py
index ce4d8235f..14c4e96f4 100644
--- a/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py
+++ b/src/codemods/canonical/swap_call_site_imports/swap_call_site_imports.py
@@ -60,4 +60,4 @@ def execute(self, codebase: Codebase) -> None:
legacy_function.remove()
# Add import of the new function
- call_site.file.add_import_from_import_string(f"from settings.collections import {legacy_function.name}")
+ call_site.file.add_import(f"from settings.collections import {legacy_function.name}")
diff --git a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py b/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py
index 1a9bcb8bf..60c520986 100644
--- a/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py
+++ b/src/codemods/canonical/swap_class_attribute_usages/swap_class_attribute_usages.py
@@ -42,7 +42,7 @@ def execute(self, codebase: Codebase) -> None:
class_a_param.edit("cache_config: CacheConfig")
# Add import of `CacheConfig` to function definition file
- function.file.add_symbol_import(class_b_symb)
+ function.file.add_import(class_b_symb)
# Check if the function body is using `cache_config`
if len(function.code_block.get_variable_usages(class_a_param.name)) > 0:
diff --git a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py b/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py
index 2a7699df0..d1ceea089 100644
--- a/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py
+++ b/src/codemods/canonical/update_optional_type_annotations/update_optional_type_annotations.py
@@ -51,5 +51,5 @@ def update_type_annotation(type: Type) -> str:
new_type = update_type_annotation(parameter.type)
if parameter.type != new_type:
# Add the future annotations import
- file.add_import_from_import_string("from __future__ import annotations\n")
+ file.add_import("from __future__ import annotations\n")
parameter.type.edit(new_type)
diff --git a/src/codemods/canonical/wrap_with_component/wrap_with_component.py b/src/codemods/canonical/wrap_with_component/wrap_with_component.py
index 715b98216..b1bed4cc1 100644
--- a/src/codemods/canonical/wrap_with_component/wrap_with_component.py
+++ b/src/codemods/canonical/wrap_with_component/wrap_with_component.py
@@ -48,4 +48,4 @@ def execute(self, codebase: Codebase) -> None:
element.edit(f"{element.source}")
# Add an import for the Alert component
- file.add_symbol_import(alert)
+ file.add_import(alert)
diff --git a/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py b/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py
index f5dfe151d..6a74ff2e4 100644
--- a/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py
+++ b/tests/unit/codegen/sdk/python/autocommit/test_autocommit.py
@@ -141,7 +141,7 @@ def a():
autocommit = codebase.ctx._autocommit
file1 = codebase.get_file(file1_name)
fun = file1.get_function("a")
- file1.add_import_from_import_string("import os")
+ file1.add_import("import os")
assert fun.node_id not in autocommit._nodes
if edit_block:
block = fun.code_block
@@ -200,7 +200,7 @@ def a(a: int):
param = fun.parameters[0]
assert fun.node_id not in autocommit._nodes
param.edit("try_to_break_this: str")
- file1.add_import_from_import_string("import os")
+ file1.add_import("import os")
assert fun.node_id in autocommit._nodes
if edit_block:
block = fun.code_block
@@ -230,7 +230,7 @@ def b(a: int):
param = fun.parameters[0]
assert fun.node_id not in autocommit._nodes
param.edit("try_to_break_this: str")
- file1.add_import_from_import_string("import os")
+ file1.add_import("import os")
assert fun.node_id in autocommit._nodes
block = funb.code_block
block.insert_before("a", fix_indentation=True)
diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_import.py b/tests/unit/codegen/sdk/python/file/test_file_add_import.py
new file mode 100644
index 000000000..f0e353d81
--- /dev/null
+++ b/tests/unit/codegen/sdk/python/file/test_file_add_import.py
@@ -0,0 +1,276 @@
+import pytest
+
+from codegen.sdk.codebase.factory.get_session import get_codebase_session
+from codegen.shared.enums.programming_language import ProgrammingLanguage
+
+
+def test_file_add_symbol_import_updates_source(tmpdir) -> None:
+ # language=python
+ content1 = """
+import datetime
+
+def foo():
+ return datetime.datetime.now()
+"""
+
+ # language=python
+ content2 = """
+def bar():
+ return 1
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"file1.py": content1, "file2.py": content2}) as codebase:
+ file1 = codebase.get_file("file1.py")
+ file2 = codebase.get_file("file2.py")
+
+ file2.add_import(file1.get_symbol("foo"))
+
+ assert "import foo" in file2.content
+
+
+def test_file_add_import_string_no_imports_adds_to_top(tmpdir) -> None:
+ # language=python
+ content = """
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from sqlalchemy.orm import Session")
+
+ file_lines = file.content.split("\n")
+ assert "from sqlalchemy.orm import Session" in file_lines[0]
+
+
+def test_file_add_import_string_adds_before_first_import(tmpdir) -> None:
+ # language=python
+ content = """
+# top level comment
+
+# adds new import here
+from typing import List
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from sqlalchemy.orm import Session")
+
+ file_lines = file.content.split("\n")
+ assert "from sqlalchemy.orm import Session" in file_lines
+ assert file_lines.index("from sqlalchemy.orm import Session") == file_lines.index("from typing import List") - 1
+
+
+@pytest.mark.parametrize("sync", [True, False])
+def test_file_add_import_string_adds_remove(tmpdir, sync) -> None:
+ # language=python
+ content = """import b
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content.strip()}, sync_graph=sync) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("import antigravity")
+ file.remove()
+ if sync:
+ assert not codebase.get_file(file.filepath, optional=True)
+
+
+def test_file_add_import_typescript_string_adds_before_first_import(tmpdir) -> None:
+ # language=typescript
+ content = """
+// top level comment
+
+// existing imports below
+import { Something } from 'somewhere'
+
+function bar(): number {
+ return 1;
+}
+ """
+ with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase:
+ file = codebase.get_file("test.ts")
+
+ file.add_import("import { NewThing } from 'elsewhere'")
+
+ file_lines = file.content.split("\n")
+ assert "import { NewThing } from 'elsewhere'" in file_lines
+ assert file_lines.index("import { NewThing } from 'elsewhere'") < file_lines.index("import { Something } from 'somewhere'")
+
+
+def test_file_add_import_typescript_string_no_imports_adds_to_top(tmpdir) -> None:
+ # language=typescript
+ content = """
+ function bar(): number {
+ return 1;
+ }
+ """
+ with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase:
+ file = codebase.get_file("test.ts")
+
+ file.add_import("import { Something } from 'somewhere';")
+
+ file_lines = file.content.split("\n")
+ assert "import { Something } from 'somewhere';" in file_lines[0]
+
+
+def test_file_add_import_typescript_multiple_symbols(tmpdir) -> None:
+ FILE1_FILENAME = "file1.ts"
+ FILE2_FILENAME = "file2.ts"
+
+ # language=typescript
+ FILE1_CONTENT = """
+ export function foo(): string {
+ return 'foo';
+ }
+
+ export function bar(): string {
+ return 'bar';
+ }
+ """
+
+ # language=typescript
+ FILE2_CONTENT = """
+ function test(): number {
+ return 1;
+ }
+ """
+ with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={FILE1_FILENAME: FILE1_CONTENT, FILE2_FILENAME: FILE2_CONTENT}) as codebase:
+ file1 = codebase.get_file(FILE1_FILENAME)
+ file2 = codebase.get_file(FILE2_FILENAME)
+
+ # Add multiple symbols one after another
+ file2.add_import(file1.get_symbol("foo"))
+ file2.add_import(file1.get_symbol("bar"))
+
+ # Updated assertion to check for separate imports since that's the current behavior
+ assert "import { foo } from 'file1';" in file2.content
+ assert "import { bar } from 'file1';" in file2.content
+
+
+def test_file_add_import_typescript_default_import(tmpdir) -> None:
+ # language=typescript
+ content = """
+ function bar(): number {
+ return 1;
+ }
+ """
+ with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={"test.ts": content}) as codebase:
+ file = codebase.get_file("test.ts")
+
+ file.add_import("import React from 'react';")
+ file.add_import("import { useState } from 'react';")
+
+ file_lines = file.content.split("\n")
+ assert "import React from 'react';" in file_lines
+ assert "import { useState } from 'react';" in file_lines
+
+
+def test_file_add_import_typescript_duplicate_prevention(tmpdir) -> None:
+ FILE1_FILENAME = "file1.ts"
+ FILE2_FILENAME = "file2.ts"
+
+ # language=typescript
+ FILE1_CONTENT = """
+ export function foo(): string {
+ return 'foo';
+ }
+ """
+
+ # language=typescript
+ FILE2_CONTENT = """
+ import { foo } from 'file1';
+
+ function test(): string {
+ return foo();
+ }
+ """
+ with get_codebase_session(tmpdir=tmpdir, programming_language=ProgrammingLanguage.TYPESCRIPT, files={FILE1_FILENAME: FILE1_CONTENT, FILE2_FILENAME: FILE2_CONTENT}) as codebase:
+ file1 = codebase.get_file(FILE1_FILENAME)
+ file2 = codebase.get_file(FILE2_FILENAME)
+
+ # Try to add the same import again
+ file2.add_import(file1.get_symbol("foo"))
+
+ # Verify no duplicate import was added
+ assert file2.content.count("import { foo }") == 1
+
+
+def test_file_add_import_string_adds_after_future(tmpdir) -> None:
+ # language=python
+ content = """
+from __future__ import annotations
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from sqlalchemy.orm import Session")
+
+ file_lines = file.content.split("\n")
+ assert "from __future__ import annotations" in file_lines[1]
+ assert "from sqlalchemy.orm import Session" in file_lines[2]
+
+
+def test_file_add_import_string_adds_after_last_future(tmpdir) -> None:
+ # language=python
+ content = """
+from __future__ import annotations
+from __future__ import division
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from sqlalchemy.orm import Session")
+
+ file_lines = file.content.split("\n")
+ assert "from __future__ import annotations" in file_lines[1]
+ assert "from __future__ import division" in file_lines[2]
+ assert "from sqlalchemy.orm import Session" in file_lines[3]
+
+
+def test_file_add_import_string_adds_after_future_before_non_future(tmpdir) -> None:
+ # language=python
+ content = """
+from __future__ import annotations
+from typing import List
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from sqlalchemy.orm import Session")
+
+ file_lines = file.content.split("\n")
+ assert "from __future__ import annotations" in file_lines[1]
+ assert "from sqlalchemy.orm import Session" in file_lines[2]
+ assert "from typing import List" in file_lines[3]
+
+
+def test_file_add_import_string_future_import_adds_to_top(tmpdir) -> None:
+ # language=python
+ content = """
+from __future__ import annotations
+
+def foo():
+ print("this is foo")
+"""
+ with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
+ file = codebase.get_file("test.py")
+
+ file.add_import("from __future__ import division")
+
+ file_lines = file.content.split("\n")
+ assert "from __future__ import division" in file_lines[1]
+ assert "from __future__ import annotations" in file_lines[2]
diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py b/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py
index 1d905332a..089a4849d 100644
--- a/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py
+++ b/tests/unit/codegen/sdk/python/file/test_file_add_import_from_import_string.py
@@ -14,7 +14,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from sqlalchemy.orm import Session")
+ file.add_import("from sqlalchemy.orm import Session")
file_lines = file.content.split("\n")
assert "from __future__ import annotations" in file_lines[1]
@@ -33,7 +33,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from sqlalchemy.orm import Session")
+ file.add_import("from sqlalchemy.orm import Session")
file_lines = file.content.split("\n")
assert "from __future__ import annotations" in file_lines[1]
@@ -53,7 +53,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from sqlalchemy.orm import Session")
+ file.add_import("from sqlalchemy.orm import Session")
file_lines = file.content.split("\n")
assert "from __future__ import annotations" in file_lines[1]
@@ -72,7 +72,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from __future__ import division")
+ file.add_import("from __future__ import division")
file_lines = file.content.split("\n")
assert "from __future__ import division" in file_lines[1]
@@ -88,7 +88,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from sqlalchemy.orm import Session")
+ file.add_import("from sqlalchemy.orm import Session")
file_lines = file.content.split("\n")
assert "from sqlalchemy.orm import Session" in file_lines[0]
@@ -108,7 +108,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("from sqlalchemy.orm import Session")
+ file.add_import("from sqlalchemy.orm import Session")
file_lines = file.content.split("\n")
assert "from sqlalchemy.orm import Session" in file_lines
@@ -126,7 +126,7 @@ def foo():
with get_codebase_session(tmpdir=tmpdir, files={"test.py": content.strip()}, sync_graph=sync) as codebase:
file = codebase.get_file("test.py")
- file.add_import_from_import_string("import antigravity")
+ file.add_import("import antigravity")
file.remove()
if sync:
assert not codebase.get_file(file.filepath, optional=True)
diff --git a/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py b/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py
deleted file mode 100644
index 7088f2c1e..000000000
--- a/tests/unit/codegen/sdk/python/file/test_file_add_symbol_import.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from codegen.sdk.codebase.factory.get_session import get_codebase_session
-
-
-def test_file_add_symbol_import_updates_source(tmpdir) -> None:
- # language=python
- content1 = """
-import datetime
-
-def foo():
- return datetime.datetime.now()
-"""
-
- # language=python
- content2 = """
-def bar():
- return 1
-"""
- with get_codebase_session(tmpdir=tmpdir, files={"file1.py": content1, "file2.py": content2}) as codebase:
- file1 = codebase.get_file("file1.py")
- file2 = codebase.get_file("file2.py")
-
- file2.add_symbol_import(file1.get_symbol("foo"))
-
- assert "import foo" in file2.content
diff --git a/tests/unit/codegen/sdk/python/file/test_file_reparse.py b/tests/unit/codegen/sdk/python/file/test_file_reparse.py
index 7f14c79c4..b4314f9f2 100644
--- a/tests/unit/codegen/sdk/python/file/test_file_reparse.py
+++ b/tests/unit/codegen/sdk/python/file/test_file_reparse.py
@@ -98,7 +98,7 @@ def test_file_reparse_move_global_var(mock_codebase_setup: tuple[Codebase, File,
global_var1.remove()
global_var2 = file2.get_global_var("GLOBAL_CONSTANT_2")
global_var2.insert_before(global_var1.source)
- file1.add_symbol_import(global_var1)
+ file1.add_import(global_var1)
# Remove the import to GLOBAL_CONSTANT_1 from file2
imp_to_remove = file2.get_import("GLOBAL_CONSTANT_1")
diff --git a/tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py b/tests/unit/codegen/sdk/typescript/file/test_file_add_import.py
similarity index 94%
rename from tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py
rename to tests/unit/codegen/sdk/typescript/file/test_file_add_import.py
index 81115ce0b..40fa1ca4f 100644
--- a/tests/unit/codegen/sdk/typescript/file/test_file_add_symbol_import.py
+++ b/tests/unit/codegen/sdk/typescript/file/test_file_add_import.py
@@ -25,6 +25,6 @@ def test_file_add_symbol_import_updates_source(tmpdir) -> None:
file1 = codebase.get_file(FILE1_FILENAME)
file2 = codebase.get_file(FILE2_FILENAME)
- file2.add_symbol_import(file1.get_symbol("foo"))
+ file2.add_import(file1.get_symbol("foo"))
assert "import { foo } from 'file1';" in file2.content
diff --git a/tests/unit/skills/implementations/decorator_skills.py b/tests/unit/skills/implementations/decorator_skills.py
index d586b0464..f74877deb 100644
--- a/tests/unit/skills/implementations/decorator_skills.py
+++ b/tests/unit/skills/implementations/decorator_skills.py
@@ -54,7 +54,7 @@ def python_skill_func(codebase: CodebaseType):
# if the file does not have the decorator symbol and the decorator symbol is not in the same file
if not file.has_import(decorator_symbol.name) and decorator_symbol.file != file:
# import the decorator symbol
- file.add_symbol_import(decorator_symbol)
+ file.add_import(decorator_symbol)
# iterate through each function in the file
for function in file.functions:
diff --git a/tests/unit/skills/implementations/eval_skills.py b/tests/unit/skills/implementations/eval_skills.py
index 0a25fa376..99e0b65ac 100644
--- a/tests/unit/skills/implementations/eval_skills.py
+++ b/tests/unit/skills/implementations/eval_skills.py
@@ -84,7 +84,7 @@ def python_skill_func(codebase: CodebaseType):
# if the decorator is not imported or declared in the file
if not file.has_import("decorator_function") and decorator_symbol.file != file:
# add an import for the decorator function
- file.add_symbol_import(decorator_symbol)
+ file.add_import(decorator_symbol)
# add the decorator to the function
function.add_decorator(f"@{decorator_symbol.name}")
@@ -370,7 +370,7 @@ def typescript_skill_func(codebase: CodebaseType):
# if the file does not exist create it
new_file = codebase.create_file(str(new_file_path))
# add an import for React
- new_file.add_import_from_import_string('import React from "react";')
+ new_file.add_import('import React from "react";')
# move the component to the new file
component.move_to_file(new_file)
diff --git a/tests/unit/skills/implementations/example_skills.py b/tests/unit/skills/implementations/example_skills.py
index aa122000e..e0c025b88 100644
--- a/tests/unit/skills/implementations/example_skills.py
+++ b/tests/unit/skills/implementations/example_skills.py
@@ -141,13 +141,13 @@ def python_skill_func(codebase: CodebaseType):
for file in codebase.files:
for function in file.functions:
if function.name.startswith("test_"):
- file.add_import_from_import_string("import pytest")
+ file.add_import("import pytest")
function.add_decorator('@pytest.mark.skip(reason="This is a test")')
for cls in file.classes:
for method in cls.methods:
if method.name.startswith("test_"):
- file.add_import_from_import_string("import pytest")
+ file.add_import("import pytest")
method.add_decorator('@pytest.mark.skip(reason="This is a test")')
@staticmethod
@@ -181,7 +181,7 @@ def python_skill_func(codebase: CodebaseType):
function.set_return_type("None")
else:
function.set_return_type("Any")
- function.file.add_import_from_import_string("from typing import Any")
+ function.file.add_import("from typing import Any")
for param in function.parameters:
if not param.is_typed:
@@ -191,7 +191,7 @@ def python_skill_func(codebase: CodebaseType):
param.set_type_annotation("str")
else:
param.set_type_annotation("Any")
- function.file.add_import_from_import_string("from typing import Any")
+ function.file.add_import("from typing import Any")
@staticmethod
@skill_impl(test_cases=[], skip_test=True, language=ProgrammingLanguage.TYPESCRIPT)
diff --git a/tests/unit/skills/implementations/guides/increase-type-coverage.py b/tests/unit/skills/implementations/guides/increase-type-coverage.py
index a84b09b74..f04d74288 100644
--- a/tests/unit/skills/implementations/guides/increase-type-coverage.py
+++ b/tests/unit/skills/implementations/guides/increase-type-coverage.py
@@ -318,7 +318,7 @@ def python_skill_func(codebase: CodebaseType):
# import c from module
c = codebase.get_file("path/to/module.py").get_symbol("c")
- target_file.add_symbol_import(c)
+ target_file.add_import(c)
# Add a new option to the return type
function.return_type.append("c")
@@ -331,7 +331,7 @@ def typescript_skill_func(codebase: CodebaseType):
function = target_file.get_function("functionName")
# function functionName(): a | b: ...
c = codebase.get_file("path/to/module.ts").get_symbol("c")
- target_file.add_symbol_import(c)
+ target_file.add_import(c)
# Add a new option to the return type
function.return_type.append("c")