Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .codegen/codemods/test_language/test_language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import codegen
from codegen.sdk.core.codebase import Codebase
from codegen.shared.enums.programming_language import ProgrammingLanguage


@codegen.function("test-language", subdirectories=["src/codegen/cli"], language=ProgrammingLanguage.PYTHON)
def run(codebase: Codebase):
    """Print the path of ``src/codegen/cli/errors.py`` and the name of every symbol in it.

    Args:
        codebase: The codebase from which the file is looked up.
    """
    errors_file = codebase.get_file("src/codegen/cli/errors.py")
    print(f"File: {errors_file.path}")
    # One symbol name per line, in the order the file reports them.
    for symbol in errors_file.symbols:
        print(symbol.name)


if __name__ == "__main__":
    # Script entry point: build a Codebase from the current directory,
    # then hand it to the codemod defined above.
    print("Parsing codebase...")
    codebase = Codebase("./")

    # Announce the second phase only after construction has finished.
    print("Running...")
    run(codebase)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ dependencies = [
"langchain_core",
"langchain_openai",
"langgraph",
"langgraph-prebuilt",
"numpy>=2.2.2",
"mcp[cli]",
"neo4j",
Expand Down Expand Up @@ -239,6 +240,7 @@ DEP002 = [
"pip",
"python-levenshtein",
"pytest-snapshot",
"langgraph-prebuilt",
]
DEP003 = "sqlalchemy"
DEP004 = "pytest"
Expand Down
2 changes: 2 additions & 0 deletions src/codegen/cli/commands/list/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def list_command():
table.add_column("Type", style="magenta")
table.add_column("Path", style="dim")
table.add_column("Subdirectories", style="dim")
table.add_column("Language", style="dim")

for func in functions:
func_type = "Webhook" if func.lint_mode else "Function"
Expand All @@ -26,6 +27,7 @@ def list_command():
func_type,
str(func.filepath.relative_to(Path.cwd())) if func.filepath else "<unknown>",
", ".join(func.subdirectories) if func.subdirectories else "",
func.language or "",
)

rich.print(table)
Expand Down
8 changes: 6 additions & 2 deletions src/codegen/cli/commands/run/run_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from codegen.cli.utils.function_finder import DecoratedFunction
from codegen.git.repo_operator.repo_operator import RepoOperator
from codegen.git.schemas.repo_config import RepoConfig
from codegen.git.utils.language import determine_project_language
from codegen.sdk.codebase.config import ProjectConfig
from codegen.sdk.core.codebase import Codebase
from codegen.shared.enums.programming_language import ProgrammingLanguage


def parse_codebase(
repo_path: Path,
subdirectories: list[str] | None = None,
language: ProgrammingLanguage | None = None,
) -> Codebase:
"""Parse the codebase at the given root.

Expand All @@ -29,6 +32,7 @@ def parse_codebase(
ProjectConfig(
repo_operator=RepoOperator(repo_config=RepoConfig.from_repo_path(repo_path=repo_path)),
subdirectories=subdirectories,
programming_language=language or determine_project_language(repo_path),
)
]
)
Expand All @@ -48,8 +52,8 @@ def run_local(
diff_preview: Number of lines of diff to preview (None for all)
"""
# Parse codebase and run
with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} ...", spinner="dots") as status:
codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories)
with Status(f"[bold]Parsing codebase at {session.repo_path} with subdirectories {function.subdirectories or 'ALL'} and language {function.language or 'AUTO'} ...", spinner="dots") as status:
codebase = parse_codebase(repo_path=session.repo_path, subdirectories=function.subdirectories, language=function.language)
status.update("[bold green]✓ Parsed codebase")

status.update("[bold]Running codemod...")
Expand Down
8 changes: 6 additions & 2 deletions src/codegen/cli/sdk/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import wraps
from typing import Literal, ParamSpec, TypeVar, get_type_hints

from codegen.shared.enums.programming_language import ProgrammingLanguage

P = ParamSpec("P")
T = TypeVar("T")
WebhookType = Literal["pr", "push", "issue", "release"]
Expand All @@ -16,12 +18,14 @@ def __init__(
name: str,
*,
subdirectories: list[str] | None = None,
language: ProgrammingLanguage | None = None,
webhook_config: dict | None = None,
lint_mode: bool = False,
lint_user_whitelist: Sequence[str] | None = None,
):
self.name = name
self.subdirectories = subdirectories
self.language = language
self.func: Callable | None = None
self.params_type = None
self.webhook_config = webhook_config
Expand All @@ -44,7 +48,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return wrapper


def function(name: str, subdirectories: list[str] | None = None) -> DecoratedFunction:
def function(name: str, subdirectories: list[str] | None = None, language: ProgrammingLanguage | None = None) -> DecoratedFunction:
"""Decorator for codegen functions.

Args:
Expand All @@ -56,7 +60,7 @@ def run(codebase):
pass

"""
return DecoratedFunction(name=name, subdirectories=subdirectories)
return DecoratedFunction(name=name, subdirectories=subdirectories, language=language)


def webhook(
Expand Down
27 changes: 16 additions & 11 deletions src/codegen/cli/utils/function_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from pathlib import Path

from codegen.shared.enums.programming_language import ProgrammingLanguage


@dataclass
class DecoratedFunction:
Expand All @@ -15,6 +17,7 @@ class DecoratedFunction:
lint_mode: bool
lint_user_whitelist: list[str]
subdirectories: list[str] | None = None
language: ProgrammingLanguage | None = None
filepath: Path | None = None
parameters: list[tuple[str, str | None]] = dataclasses.field(default_factory=list)
arguments_type_schema: dict | None = None
Expand Down Expand Up @@ -98,6 +101,14 @@ def get_subdirectories(self, node: ast.Call) -> list[str] | None:
return ast.literal_eval(node.args[1])
return None

def get_language(self, node: ast.Call) -> ProgrammingLanguage | None:
    """Extract the ``language`` argument from a decorator call.

    The argument is read from the ``language=`` keyword when present and
    otherwise from the third positional argument. Both forms are resolved
    the same way, so ``ProgrammingLanguage.PYTHON``, a string literal and
    ``None`` are all accepted in either position.

    Args:
        node: The ``ast.Call`` node of the decorator.

    Returns:
        The resolved ``ProgrammingLanguage`` member, or ``None`` when no
        language argument was given or it is the literal ``None``.

    Raises:
        ValueError: If the argument is an expression that cannot be
            resolved statically (neither an attribute access nor a
            literal), or if it does not name a ``ProgrammingLanguage``.
    """
    keywords = {k.arg: k.value for k in node.keywords}
    value = keywords.get("language")
    if value is None and len(node.args) > 2:
        value = node.args[2]
    if value is None:
        return None

    if isinstance(value, ast.Attribute):
        # e.g. ``ProgrammingLanguage.PYTHON``: only the trailing attribute
        # name is available statically. The lookup is by enum *value*, as
        # before -- assumes member values equal their names; TODO confirm.
        return ProgrammingLanguage(value.attr)

    # Anything else must be a literal: ``None`` or a string such as
    # "PYTHON". ``ast.literal_eval`` raises ValueError for non-literals.
    literal = ast.literal_eval(value)
    if literal is None:
        return None
    return ProgrammingLanguage(literal)

def get_function_body(self, node: ast.FunctionDef) -> str:
"""Extract and unindent the function body."""
# Get the start and end positions of the function body
Expand Down Expand Up @@ -202,10 +213,6 @@ def visit_FunctionDef(self, node):
(isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Attribute) and self._has_codegen_root(decorator.func.value))
)
):
# Get the function name from the decorator argument
func_name = self.get_function_name(decorator)
subdirectories = self.get_subdirectories(decorator)

# Get additional metadata for webhook
lint_mode = decorator.func.attr == "webhook"
lint_user_whitelist = []
Expand All @@ -214,17 +221,15 @@ def visit_FunctionDef(self, node):
if keyword.arg == "users" and isinstance(keyword.value, ast.List):
lint_user_whitelist = [ast.literal_eval(elt).lstrip("@") for elt in keyword.value.elts]

# Get just the function body, unindented
body_source = self.get_function_body(node)
parameters = self.get_function_parameters(node)
self.functions.append(
DecoratedFunction(
name=func_name,
subdirectories=subdirectories,
source=body_source,
name=self.get_function_name(decorator),
subdirectories=self.get_subdirectories(decorator),
language=self.get_language(decorator),
source=self.get_function_body(node),
lint_mode=lint_mode,
lint_user_whitelist=lint_user_whitelist,
parameters=parameters,
parameters=self.get_function_parameters(node),
)
)

Expand Down
Loading