diff --git a/.codegen/config.toml b/.codegen/config.toml index c8b8658f8..d6f3a6363 100644 --- a/.codegen/config.toml +++ b/.codegen/config.toml @@ -3,8 +3,12 @@ github_token = "" openai_api_key = "" [repository] -organization_name = "codegen-sh" -repo_name = "codegen-sdk" +repo_path = "" +repo_name = "" +full_name = "" +user_name = "" +user_email = "" +language = "" [feature_flags.codebase] debug = false diff --git a/src/codegen/cli/api/client.py b/src/codegen/cli/api/client.py index 0150299d0..7225ed47a 100644 --- a/src/codegen/cli/api/client.py +++ b/src/codegen/cli/api/client.py @@ -34,7 +34,6 @@ LookupOutput, PRLookupInput, PRLookupResponse, - PRSchema, RunCodemodInput, RunCodemodOutput, RunOnPRInput, @@ -57,9 +56,9 @@ class RestAPI: _session: ClassVar[requests.Session] = requests.Session() - auth_token: str | None = None + auth_token: str - def __init__(self, auth_token: str | None = None): + def __init__(self, auth_token: str): self.auth_token = auth_token def _get_headers(self) -> dict[str, str]: @@ -133,11 +132,10 @@ def run( template_context: Context variables to pass to the codemod """ - session = CodegenSession() - + session = CodegenSession.from_active_session() base_input = { "codemod_name": function.name, - "repo_full_name": session.repo_name, + "repo_full_name": session.config.repository.full_name, "codemod_run_type": run_type, } @@ -158,13 +156,13 @@ def run( RunCodemodOutput, ) - def get_docs(self) -> dict: + def get_docs(self) -> DocsResponse: """Search documentation.""" - session = CodegenSession() + session = CodegenSession.from_active_session() return self._make_request( "GET", DOCS_ENDPOINT, - DocsInput(docs_input=DocsInput.BaseDocsInput(repo_full_name=session.repo_name)), + DocsInput(docs_input=DocsInput.BaseDocsInput(repo_full_name=session.config.repository.full_name)), DocsResponse, ) @@ -179,11 +177,12 @@ def ask_expert(self, query: str) -> AskExpertResponse: def create(self, name: str, query: str) -> CreateResponse: """Get AI-generated starter code for a codemod.""" - session = CodegenSession() + session = CodegenSession.from_active_session() + language = ProgrammingLanguage(session.config.repository.language) return self._make_request( "GET", CREATE_ENDPOINT, - CreateInput(input=CreateInput.BaseCreateInput(name=name, query=query, language=session.language)), + CreateInput(input=CreateInput.BaseCreateInput(name=name, query=query, language=language)), CreateResponse, ) @@ -197,10 +196,16 @@ def identify(self) -> IdentifyResponse | None: ) def deploy( - self, codemod_name: str, codemod_source: str, lint_mode: bool = False, lint_user_whitelist: list[str] | None = None, message: str | None = None, arguments_schema: dict | None = None + self, + codemod_name: str, + codemod_source: str, + lint_mode: bool = False, + lint_user_whitelist: list[str] | None = None, + message: str | None = None, + arguments_schema: dict | None = None, ) -> DeployResponse: """Deploy a codemod to the Modal backend.""" - session = CodegenSession() + session = CodegenSession.from_active_session() return self._make_request( "POST", DEPLOY_ENDPOINT, @@ -208,7 +213,7 @@ def deploy( input=DeployInput.BaseDeployInput( codemod_name=codemod_name, codemod_source=codemod_source, - repo_full_name=session.repo_name, + repo_full_name=session.config.repository.full_name, lint_mode=lint_mode, lint_user_whitelist=lint_user_whitelist or [], message=message, @@ -220,11 +225,11 @@ def deploy( def lookup(self, codemod_name: str) -> LookupOutput: """Look up a codemod by name.""" - session = CodegenSession() + session = CodegenSession.from_active_session() return self._make_request( "GET", LOOKUP_ENDPOINT, - LookupInput(input=LookupInput.BaseLookupInput(codemod_name=codemod_name, repo_full_name=session.repo_name)), + LookupInput(input=LookupInput.BaseLookupInput(codemod_name=codemod_name, repo_full_name=session.config.repository.full_name)), LookupOutput, ) @@ -244,7 +249,7 @@ def run_on_pr(self, codemod_name: str, repo_full_name: str, github_pr_number: in RunOnPRResponse, ) - def lookup_pr(self, repo_full_name: str, github_pr_number: int) -> PRSchema: + def lookup_pr(self, repo_full_name: str, github_pr_number: int) -> PRLookupResponse: """Look up a PR by repository and PR number.""" return self._make_request( "GET", diff --git a/src/codegen/cli/auth/auth_session.py b/src/codegen/cli/auth/auth_session.py new file mode 100644 index 000000000..e2984f470 --- /dev/null +++ b/src/codegen/cli/auth/auth_session.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass +from pathlib import Path + +from codegen.cli.api.client import RestAPI +from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.token_manager import get_current_token +from codegen.cli.errors import AuthError, NoTokenError + + +@dataclass +class User: + full_name: str + email: str + github_username: str + + +@dataclass +class Identity: + token: str + expires_at: str + status: str + user: "User" + + +class CodegenAuthenticatedSession(CodegenSession): + """Represents an authenticated codegen session with user and repository context""" + + # =====[ Instance attributes ]===== + _token: str | None = None + + # =====[ Lazy instance attributes ]===== + _identity: Identity | None = None + + def __init__(self, token: str | None = None, repo_path: Path | None = None): + # TODO: fix jank. + # super().__init__(repo_path) + self._token = token + + @property + def token(self) -> str | None: + """Get the current authentication token""" + if self._token: + return self._token + return get_current_token() + + @property + def identity(self) -> Identity | None: + """Get the identity of the user, if a token has been provided""" + if self._identity: + return self._identity + if not self.token: + msg = "No authentication token found" + raise NoTokenError(msg) + + identity = RestAPI(self.token).identify() + if not identity: + return None + + self._identity = Identity( + token=self.token, + expires_at=identity.auth_context.expires_at, + status=identity.auth_context.status, + user=User( + full_name=identity.user.full_name, + email=identity.user.email, + github_username=identity.user.github_username, + ), + ) + return self._identity + + def is_authenticated(self) -> bool: + """Check if the session is fully authenticated, including token expiration""" + return bool(self.identity and self.identity.status == "active") + + def assert_authenticated(self) -> None: + """Raise an AuthError if the session is not fully authenticated""" + if not self.identity: + msg = "No identity found for session" + raise AuthError(msg) + if self.identity.status != "active": + msg = "Current session is not active. API Token may be invalid or may have expired." + raise AuthError(msg) diff --git a/src/codegen/cli/auth/decorators.py b/src/codegen/cli/auth/decorators.py index b983fea89..640bb98c5 100644 --- a/src/codegen/cli/auth/decorators.py +++ b/src/codegen/cli/auth/decorators.py @@ -4,9 +4,10 @@ import click import rich +from codegen.cli.auth.auth_session import CodegenAuthenticatedSession from codegen.cli.auth.login import login_routine -from codegen.cli.auth.session import CodegenSession from codegen.cli.errors import AuthError, InvalidTokenError, NoTokenError +from codegen.cli.rich.pretty_print import pretty_print_error def requires_auth(f: Callable) -> Callable: @@ -14,7 +15,12 @@ def requires_auth(f: Callable) -> Callable: @functools.wraps(f) def wrapper(*args, **kwargs): - session = CodegenSession() + session = CodegenAuthenticatedSession.from_active_session() + + # Check for valid session + if not session.is_valid(): + pretty_print_error(f"The session at path {session.repo_path} is missing or corrupt.\nPlease run 'codegen init' to re-initialize the project.") + raise click.Abort() try: if not session.is_authenticated(): diff --git a/src/codegen/cli/auth/login.py b/src/codegen/cli/auth/login.py index 195263615..7e4b0e7ff 100644 --- a/src/codegen/cli/auth/login.py +++ b/src/codegen/cli/auth/login.py @@ -4,13 +4,13 @@ import rich_click as click from codegen.cli.api.webapp_routes import USER_SECRETS_ROUTE -from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.auth_session import CodegenAuthenticatedSession from codegen.cli.auth.token_manager import TokenManager from codegen.cli.env.global_env import global_env from codegen.cli.errors import AuthError -def login_routine(token: str | None = None) -> CodegenSession: +def login_routine(token: str | None = None) -> CodegenAuthenticatedSession: """Guide user through login flow and return authenticated session. Args: @@ -39,7 +39,7 @@ def login_routine(token: str | None = None) -> CodegenSession: # Validate and store token token_manager = TokenManager() - session = CodegenSession(_token) + session = CodegenAuthenticatedSession(token=_token) try: session.assert_authenticated() diff --git a/src/codegen/cli/auth/session.py b/src/codegen/cli/auth/session.py index 8e0090a11..915399c6c 100644 --- a/src/codegen/cli/auth/session.py +++ b/src/codegen/cli/auth/session.py @@ -1,107 +1,42 @@ -from dataclasses import dataclass from pathlib import Path from pygit2.repository import Repository -from codegen.cli.auth.constants import CODEGEN_DIR -from codegen.cli.auth.token_manager import get_current_token -from codegen.cli.errors import AuthError, NoTokenError from codegen.cli.git.repo import get_git_repo -from codegen.cli.utils.config import Config, get_config, write_config -from codegen.shared.enums.programming_language import ProgrammingLanguage - - -@dataclass -class Identity: - token: str - expires_at: str - status: str - user: "User" - - -@dataclass -class User: - full_name: str - email: str - github_username: str - - -@dataclass -class UserProfile: - """User profile populated from /identity endpoint""" - - name: str - email: str - username: str +from codegen.shared.configs.config import load +from codegen.shared.configs.constants import CODEGEN_DIR_NAME, CONFIG_FILENAME +from codegen.shared.configs.global_config import config as global_config +from codegen.shared.configs.models import Config class CodegenSession: """Represents an authenticated codegen session with user and repository context""" - # =====[ Instance attributes ]===== - _token: str | None = None - - # =====[ Lazy instance attributes ]===== - _config: Config | None = None - _identity: Identity | None = None - _profile: UserProfile | None = None - - @property - def token(self) -> str | None: - """Get the current authentication token""" - if self._token: - return self._token - return get_current_token() - - @property - def config(self) -> Config: - """Get the config for the current session""" - if self._config: - return self._config - self._config = get_config(self.codegen_dir) - return self._config - - @property - def identity(self) -> Identity | None: - """Get the identity of the user, if a token has been provided""" - if self._identity: - return self._identity - if not self.token: - msg = "No authentication token found" - raise NoTokenError(msg) - - from codegen.cli.api.client import RestAPI - - identity = RestAPI(self.token).identify() - if not identity: + repo_path: Path # TODO: rename to root_path + codegen_dir: Path + config: Config + existing: bool + + def __init__(self, repo_path: Path): + self.repo_path = repo_path + self.codegen_dir = repo_path / CODEGEN_DIR_NAME + self.existing = global_config.get_session(repo_path) is not None + self.config = load(self.codegen_dir / CONFIG_FILENAME) + global_config.set_active_session(repo_path) + + @classmethod + def from_active_session(cls) -> "CodegenSession | None": + active_session = global_config.get_active_session() + if not active_session: return None - self._identity = Identity( - token=self.token, - expires_at=identity.auth_context.expires_at, - status=identity.auth_context.status, - user=User( - full_name=identity.user.full_name, - email=identity.user.email, - github_username=identity.user.github_username, - ), - ) - return self._identity + return cls(active_session) - @property - def profile(self) -> UserProfile | None: - """Get the user profile information""" - if self._profile: - return self._profile - if not self.identity: - return None - - self._profile = UserProfile( - name=self.identity.user.full_name, - email=self.identity.user.email, - username=self.identity.user.github_username, - ) - return self._profile + def is_valid(self) -> bool: + """Validates that the session configuration is correct""" + # TODO: also make sure all the expected prompt, jupyter, codemods are present + # TODO: make sure there is still a git instance here. + return self.repo_path.exists() and self.codegen_dir.exists() and Path(self.config.file_path).exists() @property def git_repo(self) -> Repository: @@ -111,39 +46,5 @@ def git_repo(self) -> Repository: raise ValueError(msg) return git_repo - @property - def repo_name(self) -> str: - """Get the current repository name""" - return self.config.repo_full_name - - @property - def language(self) -> ProgrammingLanguage: - """Get the current language""" - # TODO(jayhack): This is a temporary solution to get the language. - # We should eventually get the language on init. - return self.config.programming_language or ProgrammingLanguage.PYTHON - - @property - def codegen_dir(self) -> Path: - """Get the path to the codegen-sh directory""" - return Path.cwd() / CODEGEN_DIR - def __str__(self) -> str: - return f"CodegenSession(user={self.profile.name}, repo={self.repo_name})" - - def is_authenticated(self) -> bool: - """Check if the session is fully authenticated, including token expiration""" - return bool(self.identity and self.identity.status == "active") - - def assert_authenticated(self) -> None: - """Raise an AuthError if the session is not fully authenticated""" - if not self.identity: - msg = "No identity found for session" - raise AuthError(msg) - if self.identity.status != "active": - msg = "Current session is not active. API Token may be invalid or may have expired." - raise AuthError(msg) - - def write_config(self) -> None: - """Write the config to the codegen-sh/config.toml file""" - write_config(self.config, self.codegen_dir) + return f"CodegenSession(user={self.config.repository.user_name}, repo={self.config.repository.repo_name})" diff --git a/src/codegen/cli/commands/config/main.py b/src/codegen/cli/commands/config/main.py index 05fbe8f2c..fdb13a0a5 100644 --- a/src/codegen/cli/commands/config/main.py +++ b/src/codegen/cli/commands/config/main.py @@ -5,7 +5,8 @@ import rich_click as click from rich.table import Table -from codegen.shared.configs.config import config +from codegen.cli.auth.session import CodegenSession +from codegen.cli.workspace.decorators import requires_init @click.group(name="config") @@ -15,7 +16,8 @@ def config_command(): @config_command.command(name="list") -def list_command(): +@requires_init +def list_command(session: CodegenSession): """List current configuration values.""" table = Table(title="Configuration Values", border_style="blue", show_header=True) table.add_column("Key", style="cyan", no_wrap=True) @@ -35,7 +37,7 @@ def flatten_dict(data: dict, prefix: str = "") -> dict: return items # Get flattened config and sort by keys - flat_config = flatten_dict(config.model_dump()) + flat_config = flatten_dict(session.config.model_dump()) sorted_items = sorted(flat_config.items(), key=lambda x: x[0]) # Group by top-level prefix @@ -54,10 +56,11 @@ def get_prefix(item): @config_command.command(name="get") +@requires_init @click.argument("key") -def get_command(key: str): +def get_command(session: CodegenSession, key: str): """Get a configuration value.""" - value = config.get(key) + value = session.config.get(key) if value is None: rich.print(f"[red]Error: Configuration key '{key}' not found[/red]") return @@ -66,18 +69,19 @@ def get_command(key: str): @config_command.command(name="set") +@requires_init @click.argument("key") @click.argument("value") -def set_command(key: str, value: str): +def set_command(session: CodegenSession, key: str, value: str): """Set a configuration value and write to config.toml.""" - cur_value = config.get(key) + cur_value = session.config.get(key) if cur_value is None: rich.print(f"[red]Error: Configuration key '{key}' not found[/red]") return if cur_value.lower() != value.lower(): try: - config.set(key, value) + session.config.set(key, value) except Exception as e: logging.exception(e) rich.print(f"[red]{e}[/red]") diff --git a/src/codegen/cli/commands/create/main.py b/src/codegen/cli/commands/create/main.py index e4804c07e..a7a6f0b13 100644 --- a/src/codegen/cli/commands/create/main.py +++ b/src/codegen/cli/commands/create/main.py @@ -93,7 +93,7 @@ def create_command(session: CodegenSession, name: str, path: Path, description: # Use API to generate implementation with create_spinner("Generating function (using LLM, this will take ~10s)") as status: response = RestAPI(session.token).create(name=name, query=description) - code = convert_to_cli(response.code, session.language, name) + code = convert_to_cli(response.code, session.config.repository.language, name) prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text(response.context) else: diff --git a/src/codegen/cli/commands/init/main.py b/src/codegen/cli/commands/init/main.py index 4a7a1e261..e52f53a8b 100644 --- a/src/codegen/cli/commands/init/main.py +++ b/src/codegen/cli/commands/init/main.py @@ -1,5 +1,3 @@ -import os -import subprocess import sys from pathlib import Path @@ -13,9 +11,7 @@ from codegen.cli.rich.codeblocks import format_command from codegen.cli.workspace.initialize_workspace import initialize_codegen from codegen.git.repo_operator.local_git_repo import LocalGitRepo -from codegen.shared.configs.config import config -from codegen.shared.configs.constants import CONFIG_PATH -from codegen.shared.enums.programming_language import ProgrammingLanguage +from codegen.git.utils.path import get_git_root_path @click.command(name="init") @@ -26,12 +22,9 @@ def init_command(path: str | None = None, token: str | None = None, language: str | None = None, fetch_docs: bool = False): """Initialize or update the Codegen folder.""" # Print a message if not in a git repo - path = str(Path.cwd()) if path is None else path - try: - os.chdir(path) - output = subprocess.run(["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True) - local_git = LocalGitRepo(repo_path=output.stdout.strip()) - except (subprocess.CalledProcessError, FileNotFoundError): + path = Path.cwd() if path is None else Path(path) + repo_path = get_git_root_path(path) + if repo_path is None: rich.print(f"\n[bold red]Error:[/bold red] Path={path} is not in a git repository") rich.print("[white]Please run this command from within a git repository.[/white]") rich.print("\n[dim]To initialize a new git repository:[/dim]") @@ -39,6 +32,7 @@ def init_command(path: str | None = None, token: str | None = None, language: st rich.print(format_command("codegen init")) sys.exit(1) + local_git = LocalGitRepo(repo_path=repo_path) if local_git.origin_remote is None: rich.print("\n[bold red]Error:[/bold red] No remote found for repository") rich.print("[white]Please add a remote to the repository.[/white]") @@ -46,6 +40,9 @@ def init_command(path: str | None = None, token: str | None = None, language: st rich.print(format_command("git remote add origin ")) sys.exit(1) + session = CodegenSession(repo_path=repo_path) + config = session.config + if token is None: token = config.secrets.github_token else: @@ -67,7 +64,7 @@ def init_command(path: str | None = None, token: str | None = None, language: st sys.exit(1) # Save repo config - config.repository.repo_path = local_git.repo_path + config.repository.repo_path = str(local_git.repo_path) config.repository.repo_name = local_git.name config.repository.full_name = local_git.full_name config.repository.user_name = local_git.user_name @@ -75,10 +72,8 @@ def init_command(path: str | None = None, token: str | None = None, language: st config.repository.language = (language or local_git.get_language(access_token=token)).upper() config.save() - session = CodegenSession() - action = "Updating" if CONFIG_PATH.exists() else "Initializing" - rich.print("") - codegen_dir, docs_dir, examples_dir = initialize_codegen(action, session=session, fetch_docs=fetch_docs, programming_language=ProgrammingLanguage(config.repository.language)) + action = "Updating" if session.existing else "Initializing" + codegen_dir, docs_dir, examples_dir = initialize_codegen(status=action, session=session, fetch_docs=fetch_docs) # Print success message rich.print(f"✅ {action} complete\n") diff --git a/src/codegen/cli/commands/login/main.py b/src/codegen/cli/commands/login/main.py index 67c7d46db..da096bfe3 100644 --- a/src/codegen/cli/commands/login/main.py +++ b/src/codegen/cli/commands/login/main.py @@ -1,9 +1,11 @@ +import sys + import rich import rich_click as click +from codegen.cli.auth.auth_session import CodegenAuthenticatedSession from codegen.cli.auth.login import login_routine -from codegen.cli.auth.session import CodegenSession -from codegen.cli.auth.token_manager import TokenManager +from codegen.cli.auth.token_manager import TokenManager, get_current_token @click.command(name="login") @@ -11,22 +13,23 @@ def login_command(token: str): """Store authentication token.""" # Check if already authenticated - token_manager = TokenManager() - if token_manager.get_token(): + if get_current_token(): msg = "Already authenticated. Use 'codegen logout' to clear the token." raise click.ClickException(msg) + if not token: + login_routine() + sys.exit(1) + # Use provided token or go through login flow - if token: - session = CodegenSession(token=token) - try: - session.assert_authenticated() - token_manager.save_token(token) - rich.print(f"[green]✓ Stored token to:[/green] {token_manager.token_file}") - rich.print("[cyan]📊 Hey![/cyan] We collect anonymous usage data to improve your experience 🔒") - rich.print("To opt out, set [green]telemetry_enabled = false[/green] in [cyan]~/.config/codegen-sh/analytics.json[/cyan] ✨") - except ValueError as e: - msg = f"Error: {e!s}" - raise click.ClickException(msg) - else: - login_routine(token) + token_manager = TokenManager() + session = CodegenAuthenticatedSession(token=token) + try: + session.assert_authenticated() + token_manager.save_token(token) + rich.print(f"[green]✓ Stored token to:[/green] {token_manager.token_file}") + rich.print("[cyan]📊 Hey![/cyan] We collect anonymous usage data to improve your experience 🔒") + rich.print("To opt out, set [green]telemetry_enabled = false[/green] in [cyan]~/.config/codegen-sh/analytics.json[/cyan] ✨") + except ValueError as e: + msg = f"Error: {e!s}" + raise click.ClickException(msg) diff --git a/src/codegen/cli/commands/profile/main.py b/src/codegen/cli/commands/profile/main.py index 856b68b7a..55fcc5c82 100644 --- a/src/codegen/cli/commands/profile/main.py +++ b/src/codegen/cli/commands/profile/main.py @@ -13,9 +13,10 @@ @requires_init def profile_command(session: CodegenSession): """Display information about the currently authenticated user.""" + repo_config = session.config.repository rich.print( Panel( - f"[cyan]Name:[/cyan] {session.profile.name}\n[cyan]Email:[/cyan] {session.profile.email}\n[cyan]Repo:[/cyan] {session.repo_name}", + f"[cyan]Name:[/cyan] {repo_config.user_name}\n[cyan]Email:[/cyan] {repo_config.user_email}\n[cyan]Repo:[/cyan] {repo_config.repo_name}", title="🔑 [bold]Current Profile[/bold]", border_style="cyan", box=box.ROUNDED, diff --git a/src/codegen/cli/commands/run/run_local.py b/src/codegen/cli/commands/run/run_local.py index b87e15e74..71c3a07f7 100644 --- a/src/codegen/cli/commands/run/run_local.py +++ b/src/codegen/cli/commands/run/run_local.py @@ -35,7 +35,7 @@ def run_local( diff_preview: Number of lines of diff to preview (None for all) """ # Parse codebase and run - repo_root = Path(session.git_repo.workdir) + repo_root = session.repo_path with Status("[bold]Parsing codebase...", spinner="dots") as status: codebase = parse_codebase(repo_root) diff --git a/src/codegen/cli/commands/run_on_pr/main.py b/src/codegen/cli/commands/run_on_pr/main.py index d3b28c471..97bd003e9 100644 --- a/src/codegen/cli/commands/run_on_pr/main.py +++ b/src/codegen/cli/commands/run_on_pr/main.py @@ -20,7 +20,7 @@ def run_on_pr(session: CodegenSession, codemod_name: str, pr_number: int) -> Non try: response = RestAPI(session.token).run_on_pr( codemod_name=codemod_name, - repo_full_name=session.repo_name, + repo_full_name=session.config.repository.full_name, github_pr_number=pr_number, ) status.stop() diff --git a/src/codegen/cli/sdk/function.py b/src/codegen/cli/sdk/function.py index 417e72e39..345cf977d 100644 --- a/src/codegen/cli/sdk/function.py +++ b/src/codegen/cli/sdk/function.py @@ -3,7 +3,7 @@ from codegen.cli.api.client import RestAPI from codegen.cli.api.schemas import CodemodRunType, RunCodemodOutput -from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.token_manager import get_current_token from codegen.cli.utils.codemods import Codemod from codegen.cli.utils.schema import CodemodConfig @@ -29,8 +29,7 @@ class Function: @classmethod def lookup(cls, name: str) -> "Function": """Look up a deployed function by name.""" - session = CodegenSession() - api_client = RestAPI(session.token) + api_client = RestAPI(get_current_token()) response = api_client.lookup(name) return cls(name=name, codemod_id=response.codemod_id, version_id=response.version_id, _api_client=api_client) @@ -52,8 +51,7 @@ def run(self, pr: bool = False, **kwargs) -> RunCodemodOutput: """ if self._api_client is None: - session = CodegenSession() - self._api_client = RestAPI(session.token) + self._api_client = RestAPI(get_current_token()) # Create a temporary codemod object to use with the API config = CodemodConfig( diff --git a/src/codegen/cli/sdk/functions.py b/src/codegen/cli/sdk/functions.py index f95141843..5f38bd4c3 100644 --- a/src/codegen/cli/sdk/functions.py +++ b/src/codegen/cli/sdk/functions.py @@ -3,7 +3,7 @@ from codegen.cli.api.client import RestAPI from codegen.cli.api.schemas import CodemodRunType, RunCodemodOutput -from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.token_manager import get_current_token from codegen.cli.utils.codemods import Codemod from codegen.cli.utils.schema import CodemodConfig @@ -28,8 +28,7 @@ def lookup(cls, name: str) -> "Function": A Function instance that can be used to run the codemod """ - session = CodegenSession() - api_client = RestAPI(session.token) + api_client = RestAPI(get_current_token()) response = api_client.lookup(name) return cls(name=name, codemod_id=response.codemod_id, version_id=response.version_id, _api_client=api_client) @@ -51,8 +50,7 @@ def run(self, pr: bool = False, **kwargs) -> RunCodemodOutput: """ if self._api_client is None: - session = CodegenSession() - self._api_client = RestAPI(session.token) + self._api_client = RestAPI(get_current_token()) # Create a temporary codemod object to use with the API config = CodemodConfig( diff --git a/src/codegen/cli/sdk/pull_request.py b/src/codegen/cli/sdk/pull_request.py index 7caca0ada..4f54a9f0d 100644 --- a/src/codegen/cli/sdk/pull_request.py +++ b/src/codegen/cli/sdk/pull_request.py @@ -2,6 +2,7 @@ from codegen.cli.api.client import RestAPI from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.token_manager import get_current_token class CodegenPullRequest: @@ -33,9 +34,9 @@ def lookup(cls, number: int) -> "CodegenPullRequest": A CodegenPullRequest instance representing the PR """ - session = CodegenSession() - api_client = RestAPI(session.token) - response = api_client.lookup_pr(repo_full_name=session.repo_name, github_pr_number=number) + session = CodegenSession.from_active_session() + api_client = RestAPI(get_current_token()) + response = api_client.lookup_pr(repo_full_name=session.config.repository.full_name, github_pr_number=number) pr = response.pr return cls( diff --git a/src/codegen/cli/utils/config.py b/src/codegen/cli/utils/config.py deleted file mode 100644 index 372540c13..000000000 --- a/src/codegen/cli/utils/config.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -import toml -from pydantic import BaseModel - - -class Config(BaseModel): - repo_name: str = "" - organization_name: str = "" - programming_language: str | None = None - - @property - def repo_full_name(self) -> str: - return f"{self.organization_name}/{self.repo_name}" - - -CONFIG_PATH = "config.toml" - - -def read_model[T: BaseModel](model: type[T], path: Path) -> T: - if not path.exists(): - return model() - return model.model_validate(toml.load(path)) - - -def get_config(codegen_dir: Path) -> Config: - config_path = codegen_dir / CONFIG_PATH - return read_model(Config, config_path) - - -def write_model(model: BaseModel, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w") as f: - toml.dump(model.model_dump(), f) - - -def write_config(config: Config, codegen_dir: Path) -> None: - config_path = codegen_dir / CONFIG_PATH - write_model(config, config_path) diff --git a/src/codegen/cli/workspace/decorators.py b/src/codegen/cli/workspace/decorators.py index 33fa0d8a3..b707adae6 100644 --- a/src/codegen/cli/workspace/decorators.py +++ b/src/codegen/cli/workspace/decorators.py @@ -1,14 +1,11 @@ import functools +import sys from collections.abc import Callable import click -import rich -from rich.status import Status -from codegen.cli.auth.constants import CODEGEN_DIR from codegen.cli.auth.session import CodegenSession from codegen.cli.rich.pretty_print import pretty_print_error -from codegen.cli.workspace.initialize_workspace import initialize_codegen def requires_init(f: Callable) -> Callable: @@ -17,38 +14,17 @@ def requires_init(f: Callable) -> Callable: @functools.wraps(f) def wrapper(*args, **kwargs): # Create a session if one wasn't provided - session = kwargs.get("session") - if not session: - session = CodegenSession() - kwargs["session"] = session - - if not session.codegen_dir.exists(): - rich.print("Codegen not initialized. Running init command first...") - with Status("[bold]Initializing Codegen...", spinner="dots", spinner_style="purple") as status: - initialize_codegen(status) - - # Check for config.toml existence and validity - config_path = session.codegen_dir / "config.toml" - if not config_path.exists(): - pretty_print_error(f"{CODEGEN_DIR}/config.toml is missing.\nPlease run 'codegen init' to initialize the project.") - raise click.Abort() - - try: - # This will attempt to parse the config - _ = session.config - except Exception as e: - pretty_print_error(f"{CODEGEN_DIR}/config.toml is corrupted or invalid.\nDetails: {e!s}\n\n\nPlease run 'codegen init' to reinitialize the project.") - raise click.Abort() - - try: - # Verify git repo exists before proceeding - _ = session.git_repo - except ValueError: - pretty_print_error( - "This command must be run from within a git repository.\n\nPlease either:\n1. Navigate to a git repository directory\n2. Initialize a new git repository with 'git init'" - ) + session = kwargs.get("session") or CodegenSession.from_active_session() + if session is None: + pretty_print_error("Codegen not initialized. Please run `codegen init` from a git repo workspace.") + sys.exit(1) + + # Check for valid session + if not session.is_valid(): + pretty_print_error(f"The session at path {session.repo_path} is missing or corrupt.\nPlease run 'codegen init' to re-initialize the project.") raise click.Abort() + kwargs["session"] = session return f(*args, **kwargs) return wrapper diff --git a/src/codegen/cli/workspace/examples_workspace.py b/src/codegen/cli/workspace/examples_workspace.py index 54c6bfbd3..9cd21599f 100644 --- a/src/codegen/cli/workspace/examples_workspace.py +++ b/src/codegen/cli/workspace/examples_workspace.py @@ -18,9 +18,8 @@ def populate_examples(session: CodegenSession, dest: Path, examples: list[Serial for example in examples: dest_file = dest / f"{example.name}.py" dest_file.parent.mkdir(parents=True, exist_ok=True) - session.config.programming_language = example.language - session.write_config() - formatted = format_example(example, session.config.programming_language) + session.config.set("repository.language", str(example.language)) + formatted = format_example(example, session.config.repository.language) dest_file.write_text(formatted) diff --git a/src/codegen/cli/workspace/initialize_workspace.py b/src/codegen/cli/workspace/initialize_workspace.py index 70657e6ef..3831ea75d 100644 --- a/src/codegen/cli/workspace/initialize_workspace.py +++ b/src/codegen/cli/workspace/initialize_workspace.py @@ -1,7 +1,6 @@ import shutil from contextlib import nullcontext from pathlib import Path -from typing import Optional import requests import rich @@ -11,6 +10,7 @@ from codegen.cli.api.client import RestAPI from codegen.cli.auth.constants import CODEGEN_DIR, DOCS_DIR, EXAMPLES_DIR, PROMPTS_DIR from codegen.cli.auth.session import CodegenSession +from codegen.cli.auth.token_manager import get_current_token from codegen.cli.git.repo import get_git_repo from codegen.cli.git.url import get_git_organization_and_repo from codegen.cli.rich.spinners import create_spinner @@ -18,29 +18,24 @@ from codegen.cli.workspace.docs_workspace import populate_api_docs from codegen.cli.workspace.examples_workspace import populate_examples from codegen.cli.workspace.venv_manager import VenvManager -from codegen.shared.enums.programming_language import ProgrammingLanguage -def initialize_codegen( - status: Status | str = "Initializing", session: CodegenSession | None = None, fetch_docs: bool = False, programming_language: Optional[ProgrammingLanguage] = None -) -> tuple[Path, Path, Path]: +def initialize_codegen(session: CodegenSession, status: Status | str = "Initializing", fetch_docs: bool = False) -> CodegenSession: """Initialize or update the codegen directory structure and content. Args: status: Either a Status object to update, or a string action being performed ("Initializing" or "Updating") session: Optional CodegenSession for fetching docs and examples fetch_docs: Whether to fetch docs and examples (requires auth) - programming_language: Optional override for the programming language Returns: Tuple of (codegen_folder, docs_folder, examples_folder) """ repo = get_git_repo() - REPO_PATH = Path(repo.workdir) - CODEGEN_FOLDER = REPO_PATH / CODEGEN_DIR - PROMPTS_FOLDER = REPO_PATH / PROMPTS_DIR - DOCS_FOLDER = REPO_PATH / DOCS_DIR - EXAMPLES_FOLDER = REPO_PATH / EXAMPLES_DIR + CODEGEN_FOLDER = session.repo_path / CODEGEN_DIR + PROMPTS_FOLDER = session.repo_path / PROMPTS_DIR + DOCS_FOLDER = session.repo_path / DOCS_DIR + EXAMPLES_FOLDER = session.repo_path / EXAMPLES_DIR CONFIG_PATH = CODEGEN_FOLDER / "config.toml" JUPYTER_DIR = CODEGEN_FOLDER / "jupyter" CODEMODS_DIR = CODEGEN_FOLDER / "codemods" @@ -107,18 +102,10 @@ def initialize_codegen( DOCS_FOLDER.mkdir(parents=True, exist_ok=True) EXAMPLES_FOLDER.mkdir(parents=True, exist_ok=True) - response = RestAPI(session.token).get_docs() + response = RestAPI(get_current_token()).get_docs() populate_api_docs(DOCS_FOLDER, response.docs, status_obj) populate_examples(session, EXAMPLES_FOLDER, response.examples, status_obj) - # Set programming language - if programming_language: - session.config.programming_language = programming_language - else: - session.config.programming_language = str(response.language) - - session.write_config() - return CODEGEN_FOLDER, DOCS_FOLDER, EXAMPLES_FOLDER diff --git a/src/codegen/git/repo_operator/local_git_repo.py b/src/codegen/git/repo_operator/local_git_repo.py index 8674e5b33..864f4ae17 100644 --- a/src/codegen/git/repo_operator/local_git_repo.py +++ b/src/codegen/git/repo_operator/local_git_repo.py @@ -1,5 +1,6 @@ import os from functools import cached_property +from pathlib import Path from git import Repo from git.remote import Remote @@ -11,9 +12,9 @@ # TODO: merge this with RepoOperator class LocalGitRepo: - repo_path: str + repo_path: Path - def __init__(self, repo_path: str): + def __init__(self, repo_path: Path): self.repo_path = repo_path @cached_property @@ -61,13 +62,13 @@ def user_email(self) -> str | None: def get_language(self, access_token: str | None = None) -> str: """Returns the majority language of the repository""" if access_token is not None: - repo_config = RepoConfig.from_repo_path(repo_path=self.repo_path) + repo_config = RepoConfig.from_repo_path(repo_path=str(self.repo_path)) repo_config.full_name = self.full_name remote_git = GitRepoClient(repo_config=repo_config, access_token=access_token) if (language := remote_git.repo.language) is not None: return language.upper() - return str(determine_project_language(self.repo_path)) + return str(determine_project_language(str(self.repo_path))) def has_remote(self) -> bool: return bool(self.git_cli.remotes) diff --git a/src/codegen/git/utils/path.py b/src/codegen/git/utils/path.py new file mode 100644 index 000000000..c4cd063f8 --- /dev/null +++ b/src/codegen/git/utils/path.py @@ -0,0 +1,13 @@ +import os +import subprocess +from pathlib import Path + + +def get_git_root_path(path: Path) -> Path | None: + """Get the closest root of the git repository containing the given path""" + try: + os.chdir(path) + output = subprocess.run(["git", "rev-parse", "--show-toplevel"], capture_output=True, check=True, text=True) + return Path(output.stdout.strip()) + except (subprocess.CalledProcessError, FileNotFoundError): + return None diff --git a/src/codegen/shared/configs/config.py b/src/codegen/shared/configs/config.py index 3bf6382d6..4caf3940c 100644 --- a/src/codegen/shared/configs/config.py +++ b/src/codegen/shared/configs/config.py @@ -1,3 +1,4 @@ +# TODO: rename this file to local.py from pathlib import Path import tomllib @@ -6,23 +7,27 @@ from codegen.shared.configs.models import Config -def load(config_path: Path | None = None) -> Config: +def load(config_path: Path) -> Config: """Loads configuration from various sources.""" # Load from .env file - env_config = _load_from_env() + env_config = _load_from_env(config_path) # Load from .codegen/config.toml file - toml_config = _load_from_toml(config_path or CONFIG_PATH) + toml_config = _load_from_toml(config_path) # Merge configurations recursively config_dict = _merge_configs(env_config.model_dump(), toml_config.model_dump()) + loaded_config = Config(**config_dict) - return Config(**config_dict) + # Save the configuration to file if it doesn't exist + if not config_path.exists(): + loaded_config.save() + return loaded_config -def _load_from_env() -> Config: +def _load_from_env(config_path: Path) -> Config: """Load configuration from the environment variables.""" - return Config() + return Config(file_path=str(config_path)) def _load_from_toml(config_path: Path) -> Config: @@ -30,9 +35,10 @@ def _load_from_toml(config_path: Path) -> Config: if config_path.exists(): with open(config_path, "rb") as f: toml_config = tomllib.load(f) + toml_config["file_path"] = str(config_path) return Config.model_validate(toml_config, strict=False) - return Config() + return Config(file_path=str(config_path)) def _merge_configs(base: dict, override: dict) -> dict: @@ -48,7 +54,7 @@ def _merge_configs(base: dict, override: dict) -> dict: return merged -config = load() +config = load(CONFIG_PATH) if __name__ == "__main__": print(config) diff --git a/src/codegen/shared/configs/constants.py b/src/codegen/shared/configs/constants.py index d9f5d6915..dc58327d0 100644 --- a/src/codegen/shared/configs/constants.py +++ b/src/codegen/shared/configs/constants.py @@ -1,11 +1,20 @@ from pathlib import Path -# Config file -CODEGEN_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent CODEGEN_DIR_NAME = ".codegen" CONFIG_FILENAME = "config.toml" -CONFIG_PATH = CODEGEN_REPO_ROOT / CODEGEN_DIR_NAME / CONFIG_FILENAME -# Environment variables -ENV_FILENAME = ".env" -ENV_PATH = CODEGEN_REPO_ROOT / "src" / "codegen" / ENV_FILENAME +# ====[ Codegen internal config ]==== +CODEGEN_REPO_ROOT = Path(__file__).parent.parent.parent.parent.parent +CODEGEN_DIR_PATH = CODEGEN_REPO_ROOT / CODEGEN_DIR_NAME +CONFIG_PATH = CODEGEN_DIR_PATH / CONFIG_FILENAME + +# ====[ User session config ]==== +PROMPTS_DIR = Path(CODEGEN_DIR_NAME) / "prompts" +DOCS_DIR = Path(CODEGEN_DIR_NAME) / "docs" +EXAMPLES_DIR = Path(CODEGEN_DIR_NAME) / "examples" + + +# ====[ User global config paths ]==== +GLOBAL_CONFIG_DIR = Path("~/.config/codegen-sh").expanduser() +AUTH_FILE = GLOBAL_CONFIG_DIR / "auth.json" +SESSION_FILE = GLOBAL_CONFIG_DIR / "session.json" diff --git a/src/codegen/shared/configs/global_config.py b/src/codegen/shared/configs/global_config.py new file mode 100644 index 000000000..a1b4ea493 --- /dev/null +++ b/src/codegen/shared/configs/global_config.py @@ -0,0 +1,56 @@ +"""Global config to manage different codegen sessions, as well as user auth.""" + +# TODO: rename this file to global.py +import json +from pathlib import Path + +from pydantic_settings import BaseSettings + +from codegen.shared.configs.constants import SESSION_FILE + + +class GlobalSessionConfig(BaseSettings): + active_session_path: str | None = None + sessions: list[str] + + def get_session(self, session_root_path: Path) -> str | None: + return next((s for s in self.sessions if s == str(session_root_path)), None) + + def get_active_session(self) -> Path | None: + if not self.active_session_path: + return None + + return Path(self.active_session_path) + + def set_active_session(self, session_root_path: Path) -> None: + if not session_root_path.exists(): + msg = f"Session path does not exist: {session_root_path}" + raise ValueError(msg) + + self.active_session_path = str(session_root_path) + if session_root_path.name not in self.sessions: + self.sessions.append(str(session_root_path)) + + self.save() + + def save(self) -> None: + if not SESSION_FILE.parent.exists(): + SESSION_FILE.parent.mkdir(parents=True, exist_ok=True) + + with open(SESSION_FILE, "w") as f: + json.dump(self.model_dump(), f) + + +def _load_global_config() -> GlobalSessionConfig: + """Load configuration from the JSON file.""" + if SESSION_FILE.exists(): + with open(SESSION_FILE) as f: + json_config = json.load(f) + return GlobalSessionConfig.model_validate(json_config, strict=False) + + new_config = GlobalSessionConfig(sessions=[]) + new_config.save() + return new_config + + +config = _load_global_config() diff --git a/src/codegen/shared/configs/models.py b/src/codegen/shared/configs/models.py index 6e8a5ec7c..674d8fdb9 100644 --- a/src/codegen/shared/configs/models.py +++ b/src/codegen/shared/configs/models.py @@ -5,19 +5,17 @@ from pydantic import BaseModel, Field from pydantic_settings import BaseSettings, SettingsConfigDict -from codegen.shared.configs.constants import CONFIG_PATH, ENV_PATH - def _get_setting_config(group_name: str) -> SettingsConfigDict: return SettingsConfigDict( env_prefix=f"CODEGEN_{group_name}__", - env_file=ENV_PATH, case_sensitive=False, extra="ignore", exclude_defaults=False, ) +# TODO: break up these models into separate files and nest in shared/configs/models class TypescriptConfig(BaseSettings): model_config = _get_setting_config("FEATURE_FLAGS_TYPESCRIPT") @@ -49,9 +47,9 @@ class RepositoryConfig(BaseSettings): model_config = _get_setting_config("REPOSITORY") - repo_path: str | None = None + repo_path: str | None = None # replace with base_dir repo_name: str | None = None - full_name: str | None = None + full_name: str | None = None # replace with org_name language: str | None = None user_name: str | None = None user_email: str | None = None @@ -68,22 +66,23 @@ class FeatureFlagsConfig(BaseModel): codebase: CodebaseFeatureFlags = Field(default_factory=CodebaseFeatureFlags) +# TODO: rename to SessionConfig class Config(BaseSettings): model_config = SettingsConfigDict( extra="ignore", exclude_defaults=False, ) + file_path: str secrets: SecretsConfig = Field(default_factory=SecretsConfig) repository: RepositoryConfig = Field(default_factory=RepositoryConfig) feature_flags: FeatureFlagsConfig = Field(default_factory=FeatureFlagsConfig) - def save(self, config_path: Path | None = None) -> None: + def save(self) -> None: """Save configuration to the config file.""" - path = config_path or CONFIG_PATH - - path.parent.mkdir(parents=True, exist_ok=True) + config_dir = Path(self.file_path).parent + config_dir.mkdir(parents=True, exist_ok=True) - with open(path, "w") as f: + with open(self.file_path, "w") as f: toml.dump(self.model_dump(exclude_none=True), f) def get(self, full_key: str) -> str | None: