From d36a975bab588d45443ce45cfd1a66cc080ce1c1 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sat, 28 Jun 2025 14:14:21 +0200 Subject: [PATCH 1/5] chore: migrate to Ruff, deprecate .gitingest, docs & repo-wide cleanup --- .pre-commit-config.yaml | 23 +- README.md | 31 ++- pyproject.toml | 37 ++- requirements-dev.txt | 4 +- requirements.txt | 1 - src/gitingest/__init__.py | 6 +- src/gitingest/cli.py | 172 +++++-------- src/gitingest/{cloning.py => clone.py} | 45 ++-- src/gitingest/entrypoint.py | 242 +++++++++++------- src/gitingest/ingestion.py | 127 ++------- ...tput_formatters.py => output_formatter.py} | 72 +++--- .../{query_parsing.py => query_parser.py} | 106 ++++---- src/gitingest/schemas/__init__.py | 8 +- .../{filesystem_schema.py => filesystem.py} | 80 +++--- src/gitingest/schemas/ingestion.py | 131 ++++++++++ src/gitingest/schemas/ingestion_schema.py | 89 ------- src/gitingest/utils/__init__.py | 1 + src/gitingest/utils/async_compat.py | 35 +++ src/gitingest/utils/auth.py | 22 ++ src/gitingest/utils/compat_typing.py | 13 + src/gitingest/utils/exceptions.py | 13 +- src/gitingest/utils/file_utils.py | 81 +++--- src/gitingest/utils/git_utils.py | 178 +++++++------ src/gitingest/utils/ignore_patterns.py | 98 ++++--- src/gitingest/utils/ingestion_utils.py | 80 +++--- .../utils/{notebook_utils.py => notebook.py} | 70 ++--- src/gitingest/utils/os_utils.py | 10 +- src/gitingest/utils/path_utils.py | 29 +-- src/gitingest/utils/query_parser_utils.py | 62 +++-- src/gitingest/utils/timeout_wrapper.py | 19 +- src/server/__init__.py | 1 + src/server/form_types.py | 16 ++ src/server/main.py | 31 +-- src/server/models.py | 71 +++++ src/server/query_processor.py | 94 ++++--- src/server/routers/__init__.py | 2 +- src/server/routers/download.py | 58 ++--- src/server/routers/dynamic.py | 56 ++-- src/server/routers/index.py | 51 ++-- src/server/server_config.py | 8 +- src/server/server_utils.py | 137 +++++----- tests/__init__.py | 1 + tests/conftest.py | 61 +++-- tests/query_parser/__init__.py | 1 + tests/query_parser/test_git_host_agnostic.py | 18 +- tests/query_parser/test_query_parser.py | 159 ++++++------ tests/test_cli.py | 43 ++-- ...test_repository_clone.py => test_clone.py} | 133 +++++----- tests/test_flow_integration.py | 38 +-- tests/test_git_utils.py | 104 +++++--- tests/test_gitignore_feature.py | 34 ++- tests/test_ingestion.py | 59 +++-- tests/test_notebook_utils.py | 140 +++++----- 53 files changed, 1742 insertions(+), 1459 deletions(-) rename src/gitingest/{cloning.py => clone.py} (67%) rename src/gitingest/{output_formatters.py => output_formatter.py} (79%) rename src/gitingest/{query_parsing.py => query_parser.py} (72%) rename src/gitingest/schemas/{filesystem_schema.py => filesystem.py} (62%) create mode 100644 src/gitingest/schemas/ingestion.py delete mode 100644 src/gitingest/schemas/ingestion_schema.py create mode 100644 src/gitingest/utils/async_compat.py create mode 100644 src/gitingest/utils/auth.py create mode 100644 src/gitingest/utils/compat_typing.py rename src/gitingest/utils/{notebook_utils.py => notebook.py} (64%) create mode 100644 src/server/form_types.py create mode 100644 src/server/models.py rename tests/{test_repository_clone.py => test_clone.py} (78%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aa40e20d..731ae623 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,11 +38,6 @@ repos: - id: absolufy-imports description: "Automatically convert relative imports to absolute. (Use `args: [--never]` to revert.)" - - repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black - - repo: https://github.com/asottile/pyupgrade rev: v3.20.0 hooks: @@ -79,12 +74,18 @@ repos: description: "Lint markdown files." args: ["--disable=line-length"] - - repo: https://github.com/terrencepreilly/darglint - rev: v1.8.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.0 + hooks: + - id: ruff-check + - id: ruff-format + + - repo: https://github.com/jsh9/pydoclint + rev: 0.6.7 hooks: - - id: darglint - name: darglint for source - args: [--docstring-style=numpy] + - id: pydoclint + name: pydoclint for source + args: [--style=numpy] files: ^src/ - repo: https://github.com/pycqa/pylint @@ -104,7 +105,6 @@ repos: slowapi, starlette>=0.40.0, tiktoken, - tomli, pathspec, uvicorn>=0.11.7, ] @@ -124,7 +124,6 @@ repos: slowapi, starlette>=0.40.0, tiktoken, - tomli, pathspec, uvicorn>=0.11.7, ] diff --git a/README.md b/README.md index 3e29673c..81b177f3 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,36 @@ # Gitingest -[![Image](./docs/frontpage.png "Gitingest main page")](https://gitingest.com) +[![Screenshot of Gitingest front page](./docs/frontpage.png)](https://gitingest.com) -[![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/cyclotruc/gitingest/blob/main/LICENSE) -[![PyPI version](https://badge.fury.io/py/gitingest.svg)](https://badge.fury.io/py/gitingest) -[![GitHub stars](https://img.shields.io/github/stars/cyclotruc/gitingest?style=social.svg)](https://github.com/cyclotruc/gitingest) -[![Downloads](https://pepy.tech/badge/gitingest)](https://pepy.tech/project/gitingest) - -[![Discord](https://dcbadge.limes.pink/api/server/https://discord.com/invite/zerRaGK9EC)](https://discord.com/invite/zerRaGK9EC) + + +

+ + PyPI + Python Versions +
+ + CI + Ruff + OpenSSF Scorecard +
+ License + Downloads + GitHub Stars + Discord +
+ Trendshift +

+ Turn any Git repository into a prompt-friendly text ingest for LLMs. You can also replace `hub` with `ingest` in any GitHub URL to access the corresponding digest. + [gitingest.com](https://gitingest.com) · [Chrome Extension](https://chromewebstore.google.com/detail/adfjahbijlkjfoicpjkhjicpjpjfaood) · [Firefox Add-on](https://addons.mozilla.org/firefox/addon/gitingest) - + [Deutsch](https://www.readme-i18n.com/cyclotruc/gitingest?lang=de) | [Español](https://www.readme-i18n.com/cyclotruc/gitingest?lang=es) | [Français](https://www.readme-i18n.com/cyclotruc/gitingest?lang=fr) | diff --git a/pyproject.toml b/pyproject.toml index 399adc9a..af518476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,8 @@ dependencies = [ "slowapi", "starlette>=0.40.0", # Vulnerable to https://osv.dev/vulnerability/GHSA-f96h-pmfr-66vw "tiktoken>=0.7.0", # Support for o200k_base encoding - "tomli", "pathspec>=0.12.1", - "typing_extensions; python_version < '3.10'", + "typing_extensions>= 4.0.0; python_version < '3.10'", "uvicorn>=0.11.7", # Vulnerable to https://osv.dev/vulnerability/PYSEC-2020-150 ] @@ -62,20 +61,46 @@ disable = [ "fixme", ] +[tool.ruff] +line-length = 119 +fix = true + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ # https://docs.astral.sh/ruff/rules/... + "D107", # undocumented-public-init + "FIX002", # line-contains-todo + "TD002", # missing-todo-author + "PLR0913", # too-many-arguments, + + # TODO: fix the following issues: + "TD003", # missing-todo-link, TODO: add issue links + "T201", # print, TODO: replace with logging + "S108", # hardcoded-temp-file, TODO: replace with tempfile + "BLE001", # blind-except, TODO: replace with specific exceptions + "FAST003", # fast-api-unused-path-parameter, TODO: fix +] +per-file-ignores = { "tests/**/*.py" = ["S101"] } # Skip the “assert used” warning + +[tool.ruff.lint.pylint] +max-returns = 10 + +[tool.ruff.lint.isort] +order-by-type = true +case-sensitive = true + [tool.pycln] all = true +# TODO: Remove this once we figure out how to use ruff-isort [tool.isort] profile = "black" line_length = 119 remove_redundant_aliases = true -float_to_top = true +float_to_top = true # https://github.com/astral-sh/ruff/issues/6514 order_by_type = true filter_files = true -[tool.black] -line-length = 119 - # Test configuration [tool.pytest.ini_options] pythonpath = ["src"] diff --git a/requirements-dev.txt b/requirements-dev.txt index b8fd868a..60dc6b5f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,6 @@ -r requirements.txt -black -djlint +eval-type-backport pre-commit -pylint pytest pytest-asyncio pytest-mock diff --git a/requirements.txt b/requirements.txt index bb2956a5..74042c48 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,4 @@ python-dotenv slowapi starlette>=0.40.0 # Vulnerable to https://osv.dev/vulnerability/GHSA-f96h-pmfr-66vw tiktoken>=0.7.0 # Support for o200k_base encoding -tomli uvicorn>=0.11.7 # Vulnerable to https://osv.dev/vulnerability/PYSEC-2020-150 diff --git a/src/gitingest/__init__.py b/src/gitingest/__init__.py index 46ea09ab..0248ad0e 100644 --- a/src/gitingest/__init__.py +++ b/src/gitingest/__init__.py @@ -1,8 +1,8 @@ """Gitingest: A package for ingesting data from Git repositories.""" -from gitingest.cloning import clone_repo +from gitingest.clone import clone_repo from gitingest.entrypoint import ingest, ingest_async from gitingest.ingestion import ingest_query -from gitingest.query_parsing import parse_query +from gitingest.query_parser import parse_query -__all__ = ["ingest_query", "clone_repo", "parse_query", "ingest", "ingest_async"] +__all__ = ["clone_repo", "ingest", "ingest_async", "ingest_query", "parse_query"] diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index 1fb8a785..7fffb6ae 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -1,54 +1,51 @@ -"""Command-line interface for the Gitingest package.""" +"""Command-line interface (CLI) for Gitingest.""" # pylint: disable=no-value-for-parameter +from __future__ import annotations import asyncio -from typing import Optional, Tuple +from typing import TypedDict import click +from typing_extensions import Unpack from gitingest.config import MAX_FILE_SIZE, OUTPUT_FILE_NAME from gitingest.entrypoint import ingest_async +class _CLIArgs(TypedDict): + source: str + max_size: int + exclude_pattern: tuple[str, ...] + include_pattern: tuple[str, ...] + branch: str | None + include_gitignored: bool + token: str | None + output: str | None + + @click.command() @click.argument("source", type=str, default=".") -@click.option( - "--output", - "-o", - default=None, - help="Output file path (default: .txt in current directory)", -) @click.option( "--max-size", "-s", default=MAX_FILE_SIZE, + show_default=True, help="Maximum file size to process in bytes", ) -@click.option( - "--exclude-pattern", - "-e", - multiple=True, - help=( - "Patterns to exclude. Handles Python's arbitrary subset of Unix shell-style " - "wildcards. See: https://docs.python.org/3/library/fnmatch.html" - ), -) +@click.option("--exclude-pattern", "-e", multiple=True, help="Shell-style patterns to exclude.") @click.option( "--include-pattern", "-i", multiple=True, - help=( - "Patterns to include. Handles Python's arbitrary subset of Unix shell-style " - "wildcards. See: https://docs.python.org/3/library/fnmatch.html" - ), + help="Shell-style patterns to include.", ) @click.option("--branch", "-b", default=None, help="Branch to clone and ingest") @click.option( "--include-gitignored", is_flag=True, default=False, - help="Include files matched by .gitignore", + help="Include files matched by .gitignore and .gitingestignore", ) @click.option( "--token", @@ -60,68 +57,29 @@ "If omitted, the CLI will look for the GITHUB_TOKEN environment variable." ), ) -def main( - source: str, - output: Optional[str], - max_size: int, - exclude_pattern: Tuple[str, ...], - include_pattern: Tuple[str, ...], - branch: Optional[str], - include_gitignored: bool, - token: Optional[str], -): - """ - Main entry point for the CLI. This function is called when the CLI is run as a script. - - It calls the async main function to run the command. - - Parameters - ---------- - source : str - A directory path or a Git repository URL. - output : str, optional - The path where the output file will be written. If not specified, the output will be written - to a file named `.txt` in the current directory. Use '-' to output to stdout. - max_size : int - Maximum file size (in bytes) to consider. - exclude_pattern : Tuple[str, ...] - Glob patterns for pruning the file set. - include_pattern : Tuple[str, ...] - Glob patterns for including files in the output. - branch : str, optional - Specific branch to ingest (defaults to the repository's default). - include_gitignored : bool - If provided, include files normally ignored by .gitignore. - token: str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. - """ - asyncio.run( - _async_main( - source=source, - output=output, - max_size=max_size, - exclude_pattern=exclude_pattern, - include_pattern=include_pattern, - branch=branch, - include_gitignored=include_gitignored, - token=token, - ) - ) +@click.option( + "--output", + "-o", + default=None, + help="Write to PATH (or '-' for stdout, default: .txt).", +) +def main(**cli_kwargs: Unpack[_CLIArgs]) -> None: + """Run the CLI entry point to analyze a repo / directory and dump its contents.""" + asyncio.run(_async_main(**cli_kwargs)) async def _async_main( source: str, - output: Optional[str], - max_size: int, - exclude_pattern: Tuple[str, ...], - include_pattern: Tuple[str, ...], - branch: Optional[str], - include_gitignored: bool, - token: Optional[str], + *, + max_size: int = MAX_FILE_SIZE, + exclude_pattern: tuple[str, ...] | None = None, + include_pattern: tuple[str, ...] | None = None, + branch: str | None = None, + include_gitignored: bool = False, + token: str | None = None, + output: str | None = None, ) -> None: - """ - Analyze a directory or repository and create a text dump of its contents. + """Analyze a directory or repository and create a text dump of its contents. This command analyzes the contents of a specified source directory or repository, applies custom include and exclude patterns, and generates a text summary of the analysis which is then written to an output file @@ -130,33 +88,34 @@ async def _async_main( Parameters ---------- source : str - A directory path or a Git repository URL. - output : str, optional - The path where the output file will be written. If not specified, the output will be written - to a file named `.txt` in the current directory. Use '-' to output to stdout. + Directory path or Git repository URL. max_size : int - Maximum file size (in bytes) to consider. - exclude_pattern : Tuple[str, ...] + Maximum file size in bytes to ingest (default: 10 MB). + exclude_pattern : tuple[str, ...] | None Glob patterns for pruning the file set. - include_pattern : Tuple[str, ...] + include_pattern : tuple[str, ...] | None Glob patterns for including files in the output. - branch : str, optional - Specific branch to ingest (defaults to the repository's default). + branch : str | None + Git branch to ingest. If *None*, the repository's default branch is used. include_gitignored : bool - If provided, include files normally ignored by .gitignore. - token: str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + If *True*, also ingest files matched by ``.gitignore`` or ``.gitingestignore`` (default: ``False``). + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the ``GITHUB_TOKEN`` env var. + output : str | None + Destination file path. If *None*, the output is written to ``.txt`` in the current directory. + Use ``"-"`` to write to *stdout*. Raises ------ - Abort + click.Abort If there is an error during the execution of the command, this exception is raised to abort the process. + """ try: # Normalise pattern containers (the ingest layer expects sets) - exclude_patterns = set(exclude_pattern) - include_patterns = set(include_pattern) + exclude_patterns = set(exclude_pattern) if exclude_pattern else set() + include_patterns = set(include_pattern) if include_pattern else set() output_target = output if output is not None else OUTPUT_FILE_NAME @@ -175,21 +134,20 @@ async def _async_main( include_gitignored=include_gitignored, token=token, ) - - if output_target == "-": # stdout - click.echo("\n--- Summary ---", err=True) - click.echo(summary, err=True) - click.echo("--- End Summary ---", err=True) - click.echo("Analysis complete! Output sent to stdout.", err=True) - else: # file - click.echo(f"Analysis complete! Output written to: {output_target}") - click.echo("\nSummary:") - click.echo(summary) - except Exception as exc: # Convert any exception into Click.Abort so that exit status is non-zero click.echo(f"Error: {exc}", err=True) - raise click.Abort() from exc + raise click.Abort from exc + + if output_target == "-": # stdout + click.echo("\n--- Summary ---", err=True) + click.echo(summary, err=True) + click.echo("--- End Summary ---", err=True) + click.echo("Analysis complete! Output sent to stdout.", err=True) + else: # file + click.echo(f"Analysis complete! Output written to: {output_target}") + click.echo("\nSummary:") + click.echo(summary) if __name__ == "__main__": diff --git a/src/gitingest/cloning.py b/src/gitingest/clone.py similarity index 67% rename from src/gitingest/cloning.py rename to src/gitingest/clone.py index 1d4487fb..ab14d261 100644 --- a/src/gitingest/cloning.py +++ b/src/gitingest/clone.py @@ -1,28 +1,30 @@ -"""This module contains functions for cloning a Git repository to a local path.""" +"""Module containing functions for cloning a Git repository to a local path.""" + +from __future__ import annotations from pathlib import Path -from typing import Optional -from urllib.parse import urlparse +from typing import TYPE_CHECKING from gitingest.config import DEFAULT_TIMEOUT -from gitingest.schemas import CloneConfig from gitingest.utils.git_utils import ( - _is_github_host, check_repo_exists, create_git_auth_header, create_git_command, ensure_git_installed, + is_github_host, run_command, validate_github_token, ) from gitingest.utils.os_utils import ensure_directory from gitingest.utils.timeout_wrapper import async_timeout +if TYPE_CHECKING: + from gitingest.schemas import CloneConfig + @async_timeout(DEFAULT_TIMEOUT) -async def clone_repo(config: CloneConfig, token: Optional[str] = None) -> None: - """ - Clone a repository to a local path based on the provided configuration. +async def clone_repo(config: CloneConfig, token: str | None = None) -> None: + """Clone a repository to a local path based on the provided configuration. This function handles the process of cloning a Git repository to the local file system. It can clone a specific branch or commit if provided, and it raises exceptions if @@ -32,25 +34,25 @@ async def clone_repo(config: CloneConfig, token: Optional[str] = None) -> None: ---------- config : CloneConfig The configuration for cloning the repository. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. - Must start with 'github_pat_' or 'gph_' for GitHub repositories. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the `GITHUB_TOKEN` env var. Raises ------ ValueError If the repository is not found, if the provided URL is invalid, or if the token format is invalid. + """ # Extract and validate query parameters url: str = config.url local_path: str = config.local_path - commit: Optional[str] = config.commit - branch: Optional[str] = config.branch + commit: str | None = config.commit + branch: str | None = config.branch partial_clone: bool = config.subpath != "/" # Validate token if provided - if token and _is_github_host(url): + if token and is_github_host(url): validate_github_token(token) # Create parent directory if it doesn't exist @@ -58,17 +60,12 @@ async def clone_repo(config: CloneConfig, token: Optional[str] = None) -> None: # Check if the repository exists if not await check_repo_exists(url, token=token): - raise ValueError("Repository not found. Make sure it is public or that you have provided a valid token.") + msg = "Repository not found. Make sure it is public or that you have provided a valid token." + raise ValueError(msg) clone_cmd = ["git"] - if token and _is_github_host(url): - # Only pass URL if it's not the default github.com to maintain backward compatibility - - parsed = urlparse(url) - if parsed.hostname == "github.com": - clone_cmd += ["-c", create_git_auth_header(token)] - else: - clone_cmd += ["-c", create_git_auth_header(token, url)] + if token and is_github_host(url): + clone_cmd += ["-c", create_git_auth_header(token, url)] clone_cmd += ["clone", "--single-branch"] # TODO: Re-enable --recurse-submodules when submodule support is needed diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index 5ccf6a5e..cbc041a2 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -1,31 +1,34 @@ """Main entry point for ingesting a source and processing its contents.""" +from __future__ import annotations + import asyncio import inspect -import os import shutil import sys -from typing import Optional, Set, Tuple, Union +from pathlib import Path -from gitingest.cloning import clone_repo +from gitingest.clone import clone_repo from gitingest.config import TMP_BASE_PATH from gitingest.ingestion import ingest_query -from gitingest.query_parsing import IngestionQuery, parse_query -from gitingest.utils.ignore_patterns import load_gitignore_patterns +from gitingest.query_parser import IngestionQuery, parse_query +from gitingest.utils.async_compat import to_thread +from gitingest.utils.auth import resolve_token +from gitingest.utils.ignore_patterns import load_ignore_patterns async def ingest_async( source: str, + *, max_file_size: int = 10 * 1024 * 1024, # 10 MB - include_patterns: Optional[Union[str, Set[str]]] = None, - exclude_patterns: Optional[Union[str, Set[str]]] = None, - branch: Optional[str] = None, + include_patterns: str | set[str] | None = None, + exclude_patterns: str | set[str] | None = None, + branch: str | None = None, include_gitignored: bool = False, - token: Optional[str] = None, - output: Optional[str] = None, -) -> Tuple[str, str, str]: - """ - Main entry point for ingesting a source and processing its contents. + token: str | None = None, + output: str | None = None, +) -> tuple[str, str, str]: + """Ingest a source and process its contents. This function analyzes a source (URL or local path), clones the corresponding repository (if applicable), and processes its files according to the specified query parameters. It returns a summary, a tree-like @@ -36,81 +39,55 @@ async def ingest_async( source : str The source to analyze, which can be a URL (for a Git repository) or a local directory path. max_file_size : int - Maximum allowed file size for file ingestion. Files larger than this size are ignored, by default - 10*1024*1024 (10 MB). - include_patterns : Union[str, Set[str]], optional - Pattern or set of patterns specifying which files to include. If `None`, all files are included. - exclude_patterns : Union[str, Set[str]], optional - Pattern or set of patterns specifying which files to exclude. If `None`, no files are excluded. - branch : str, optional - The branch to clone and ingest. If `None`, the default branch is used. + Maximum allowed file size for file ingestion. Files larger than this size are ignored (default: 10 MB). + include_patterns : str | set[str] | None + Pattern or set of patterns specifying which files to include. If *None*, all files are included. + exclude_patterns : str | set[str] | None + Pattern or set of patterns specifying which files to exclude. If *None*, no files are excluded. + branch : str | None + The branch to clone and ingest. If *None*, the default branch is used. include_gitignored : bool - If ``True``, include files ignored by ``.gitignore``. Defaults to ``False``. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. - output : str, optional - File path where the summary and content should be written. If `None`, the results are not written to a file. + If *True*, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the ``GITHUB_TOKEN`` env var. + output : str | None + File path where the summary and content should be written. If *"-"* (dash), the results are written to stdout. + If *None*, the results are not written to a file. Returns ------- - Tuple[str, str, str] + tuple[str, str, str] A tuple containing: - A summary string of the analyzed repository or directory. - A tree-like string representation of the file structure. - The content of the files in the repository or directory. - Raises - ------ - TypeError - If `clone_repo` does not return a coroutine, or if the `source` is of an unsupported type. """ - repo_cloned = False - - if not token: - token = os.getenv("GITHUB_TOKEN") - - try: - query: IngestionQuery = await parse_query( - source=source, - max_file_size=max_file_size, - from_web=False, - include_patterns=include_patterns, - ignore_patterns=exclude_patterns, - token=token, - ) - - if not include_gitignored: - gitignore_patterns = load_gitignore_patterns(query.local_path) - query.ignore_patterns.update(gitignore_patterns) - - if query.url: - selected_branch = branch if branch else query.branch # prioritize branch argument - query.branch = selected_branch + token = resolve_token(token) + + query: IngestionQuery = await parse_query( + source=source, + max_file_size=max_file_size, + from_web=False, + include_patterns=include_patterns, + ignore_patterns=exclude_patterns, + token=token, + ) - clone_config = query.extract_clone_config() - clone_coroutine = clone_repo(clone_config, token=token) + if not include_gitignored: + _apply_gitignores(query) - if inspect.iscoroutine(clone_coroutine): - if asyncio.get_event_loop().is_running(): - await clone_coroutine - else: - asyncio.run(clone_coroutine) - else: - raise TypeError("clone_repo did not return a coroutine as expected.") + if branch: + query.branch = branch - repo_cloned = True + repo_cloned = False + try: + await _clone_if_remote(query, token) + repo_cloned = bool(query.url) summary, tree, content = ingest_query(query) - - if output == "-": - loop = asyncio.get_running_loop() - output_data = tree + "\n" + content - await loop.run_in_executor(None, sys.stdout.write, output_data) - await loop.run_in_executor(None, sys.stdout.flush) - elif output is not None: - with open(output, "w", encoding="utf-8") as f: - f.write(tree + "\n" + content) + await _write_output(tree, content=content, target=output) return summary, tree, content finally: @@ -121,16 +98,16 @@ async def ingest_async( def ingest( source: str, + *, max_file_size: int = 10 * 1024 * 1024, # 10 MB - include_patterns: Optional[Union[str, Set[str]]] = None, - exclude_patterns: Optional[Union[str, Set[str]]] = None, - branch: Optional[str] = None, + include_patterns: str | set[str] | None = None, + exclude_patterns: str | set[str] | None = None, + branch: str | None = None, include_gitignored: bool = False, - token: Optional[str] = None, - output: Optional[str] = None, -) -> Tuple[str, str, str]: - """ - Synchronous version of ingest_async. + token: str | None = None, + output: str | None = None, +) -> tuple[str, str, str]: + """Provide a synchronous wrapper around ``ingest_async``. This function analyzes a source (URL or local path), clones the corresponding repository (if applicable), and processes its files according to the specified query parameters. It returns a summary, a tree-like @@ -141,25 +118,25 @@ def ingest( source : str The source to analyze, which can be a URL (for a Git repository) or a local directory path. max_file_size : int - Maximum allowed file size for file ingestion. Files larger than this size are ignored, by default - 10*1024*1024 (10 MB). - include_patterns : Union[str, Set[str]], optional - Pattern or set of patterns specifying which files to include. If `None`, all files are included. - exclude_patterns : Union[str, Set[str]], optional - Pattern or set of patterns specifying which files to exclude. If `None`, no files are excluded. - branch : str, optional - The branch to clone and ingest. If `None`, the default branch is used. + Maximum allowed file size for file ingestion. Files larger than this size are ignored (default: 10 MB). + include_patterns : str | set[str] | None + Pattern or set of patterns specifying which files to include. If *None*, all files are included. + exclude_patterns : str | set[str] | None + Pattern or set of patterns specifying which files to exclude. If *None*, no files are excluded. + branch : str | None + The branch to clone and ingest (default: the default branch). include_gitignored : bool - If ``True``, include files ignored by ``.gitignore``. Defaults to ``False``. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. - output : str, optional - File path where the summary and content should be written. If `None`, the results are not written to a file. + If *True*, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the ``GITHUB_TOKEN`` env var. + output : str | None + File path where the summary and content should be written. If *"-"* (dash), the results are written to stdout. + If *None*, the results are not written to a file. Returns ------- - Tuple[str, str, str] + tuple[str, str, str] A tuple containing: - A summary string of the analyzed repository or directory. - A tree-like string representation of the file structure. @@ -167,7 +144,8 @@ def ingest( See Also -------- - ingest_async : The asynchronous version of this function. + ``ingest_async`` : The asynchronous version of this function. + """ return asyncio.run( ingest_async( @@ -179,5 +157,73 @@ def ingest( include_gitignored=include_gitignored, token=token, output=output, - ) + ), ) + + +def _apply_gitignores(query: IngestionQuery) -> None: + """Update `query.ignore_patterns` in-place. + + Parameters + ---------- + query : IngestionQuery + The query to update. + + """ + for fname in (".gitignore", ".gitingestignore"): + query.ignore_patterns.update(load_ignore_patterns(query.local_path, filename=fname)) + + +async def _clone_if_remote(query: IngestionQuery, token: str | None) -> None: + """Clone the repo if *query* points to a remote URL. + + Parameters + ---------- + query : IngestionQuery + The query to clone. + token : str | None + The token to use for the query. + + Raises + ------ + TypeError + If `clone_repo` does not return a coroutine. + + """ + if not query.url: # local path ingestion + return + + # let CLI arg override, else keep the parsed branch + clone_cfg = query.extract_clone_config() + clone_coroutine = clone_repo(clone_cfg, token=token) + if not inspect.iscoroutine(clone_coroutine): + msg = "clone_repo did not return a coroutine as expected." + raise TypeError(msg) + + if asyncio.get_event_loop().is_running(): + await clone_coroutine + else: # running under sync context (unit-test, etc.) + asyncio.run(clone_coroutine) + + +async def _write_output(tree: str, content: str, target: str | None) -> None: + """Write combined output to *target* (`'-'` ⇒ stdout). + + Parameters + ---------- + tree : str + The tree-like string representation of the file structure. + content : str + The content of the files in the repository or directory. + target : str | None + The path to the output file. If *None*, the results are not written to a file. + + """ + if target == "-": + loop = asyncio.get_running_loop() + data = f"{tree}\n{content}" + await loop.run_in_executor(None, sys.stdout.write, data) + await loop.run_in_executor(None, sys.stdout.flush) + elif target is not None: + path = Path(target) + await to_thread(path.write_text, f"{tree}\n{content}", encoding="utf-8") diff --git a/src/gitingest/ingestion.py b/src/gitingest/ingestion.py index ec378978..9d47c603 100644 --- a/src/gitingest/ingestion.py +++ b/src/gitingest/ingestion.py @@ -1,24 +1,21 @@ """Functions to ingest and analyze a codebase directory or single file.""" -import warnings +from __future__ import annotations + from pathlib import Path -from typing import Tuple +from typing import TYPE_CHECKING from gitingest.config import MAX_DIRECTORY_DEPTH, MAX_FILES, MAX_TOTAL_SIZE_BYTES -from gitingest.output_formatters import format_node -from gitingest.query_parsing import IngestionQuery +from gitingest.output_formatter import format_node from gitingest.schemas import FileSystemNode, FileSystemNodeType, FileSystemStats from gitingest.utils.ingestion_utils import _should_exclude, _should_include -try: - import tomllib # type: ignore[import] -except ImportError: - import tomli as tomllib +if TYPE_CHECKING: + from gitingest.query_parser import IngestionQuery -def ingest_query(query: IngestionQuery) -> Tuple[str, str, str]: - """ - Run the ingestion process for a parsed query. +def ingest_query(query: IngestionQuery) -> tuple[str, str, str]: + """Run the ingestion process for a parsed query. This is the main entry point for analyzing a codebase directory or single file. It processes the query parameters, reads the file or directory content, and generates a summary, directory structure, and file content, @@ -31,26 +28,27 @@ def ingest_query(query: IngestionQuery) -> Tuple[str, str, str]: Returns ------- - Tuple[str, str, str] + tuple[str, str, str] A tuple containing the summary, directory structure, and file contents. Raises ------ ValueError If the path cannot be found, is not a file, or the file has no content. + """ subpath = Path(query.subpath.strip("/")).as_posix() path = query.local_path / subpath - apply_gitingest_file(path, query) - if not path.exists(): - raise ValueError(f"{query.slug} cannot be found") + msg = f"{query.slug} cannot be found" + raise ValueError(msg) if (query.type and query.type == "blob") or query.local_path.is_file(): # TODO: We do this wrong! We should still check the branch and commit! if not path.is_file(): - raise ValueError(f"Path {path} is not a file") + msg = f"Path {path} is not a file" + raise ValueError(msg) relative_path = path.relative_to(query.local_path) @@ -64,7 +62,8 @@ def ingest_query(query: IngestionQuery) -> Tuple[str, str, str]: ) if not file_node.content: - raise ValueError(f"File {file_node.name} has no content") + msg = f"File {file_node.name} has no content" + raise ValueError(msg) return format_node(file_node, query) @@ -86,78 +85,8 @@ def ingest_query(query: IngestionQuery) -> Tuple[str, str, str]: return format_node(root_node, query) -def apply_gitingest_file(path: Path, query: IngestionQuery) -> None: - """ - Apply the .gitingest file to the query object. - - This function reads the .gitingest file in the specified path and updates the query object with the ignore - patterns found in the file. - - Parameters - ---------- - path : Path - The path of the directory to ingest. - query : IngestionQuery - The parsed query object containing information about the repository and query parameters. - It should have an attribute `ignore_patterns` which is either None or a set of strings. - """ - path_gitingest = path / ".gitingest" - - if not path_gitingest.is_file(): - return - - try: - with path_gitingest.open("rb") as f: - data = tomllib.load(f) - except tomllib.TOMLDecodeError as exc: - warnings.warn(f"Invalid TOML in {path_gitingest}: {exc}", UserWarning) - return - - config_section = data.get("config", {}) - ignore_patterns = config_section.get("ignore_patterns") - - if not ignore_patterns: - return - - # If a single string is provided, make it a list of one element - if isinstance(ignore_patterns, str): - ignore_patterns = [ignore_patterns] - - if not isinstance(ignore_patterns, (list, set)): - warnings.warn( - f"Expected a list/set for 'ignore_patterns', got {type(ignore_patterns)} in {path_gitingest}. Skipping.", - UserWarning, - ) - return - - # Filter out duplicated patterns - ignore_patterns = set(ignore_patterns) - - # Filter out any non-string entries - valid_patterns = {pattern for pattern in ignore_patterns if isinstance(pattern, str)} - invalid_patterns = ignore_patterns - valid_patterns - - if invalid_patterns: - warnings.warn(f"Ignore patterns {invalid_patterns} are not strings. Skipping.", UserWarning) - - if not valid_patterns: - return - - if query.ignore_patterns is None: - query.ignore_patterns = valid_patterns - else: - query.ignore_patterns.update(valid_patterns) - - return - - -def _process_node( - node: FileSystemNode, - query: IngestionQuery, - stats: FileSystemStats, -) -> None: - """ - Process a file or directory item within a directory. +def _process_node(node: FileSystemNode, query: IngestionQuery, stats: FileSystemStats) -> None: + """Process a file or directory item within a directory. This function handles each file or directory item, checking if it should be included or excluded based on the provided patterns. It handles symlinks, directories, and files accordingly. @@ -170,13 +99,12 @@ def _process_node( The parsed query object containing information about the repository and query parameters. stats : FileSystemStats Statistics tracking object for the total file count and size. - """ + """ if limit_exceeded(stats, node.depth): return for sub_path in node.path.iterdir(): - if query.ignore_patterns and _should_exclude(sub_path, query.local_path, query.ignore_patterns): continue @@ -188,7 +116,6 @@ def _process_node( elif sub_path.is_file(): _process_file(path=sub_path, parent_node=node, stats=stats, local_path=query.local_path) elif sub_path.is_dir(): - child_directory_node = FileSystemNode( name=sub_path.name, type=FileSystemNodeType.DIRECTORY, @@ -217,8 +144,7 @@ def _process_node( def _process_symlink(path: Path, parent_node: FileSystemNode, stats: FileSystemStats, local_path: Path) -> None: - """ - Process a symlink in the file system. + """Process a symlink in the file system. This function checks the symlink's target. @@ -232,6 +158,7 @@ def _process_symlink(path: Path, parent_node: FileSystemNode, stats: FileSystemS Statistics tracking object for the total file count and size. local_path : Path The base path of the repository or directory being processed. + """ child = FileSystemNode( name=path.name, @@ -246,8 +173,7 @@ def _process_symlink(path: Path, parent_node: FileSystemNode, stats: FileSystemS def _process_file(path: Path, parent_node: FileSystemNode, stats: FileSystemStats, local_path: Path) -> None: - """ - Process a file in the file system. + """Process a file in the file system. This function checks the file's size, increments the statistics, and reads its content. If the file size exceeds the maximum allowed, it raises an error. @@ -262,6 +188,7 @@ def _process_file(path: Path, parent_node: FileSystemNode, stats: FileSystemStat Statistics tracking object for the total file count and size. local_path : Path The base path of the repository or directory being processed. + """ file_size = path.stat().st_size if stats.total_size + file_size > MAX_TOTAL_SIZE_BYTES: @@ -291,8 +218,7 @@ def _process_file(path: Path, parent_node: FileSystemNode, stats: FileSystemStat def limit_exceeded(stats: FileSystemStats, depth: int) -> bool: - """ - Check if any of the traversal limits have been exceeded. + """Check if any of the traversal limits have been exceeded. This function checks if the current traversal has exceeded any of the configured limits: maximum directory depth, maximum number of files, or maximum total size in bytes. @@ -307,7 +233,8 @@ def limit_exceeded(stats: FileSystemStats, depth: int) -> bool: Returns ------- bool - True if any limit has been exceeded, False otherwise. + ``True`` if any limit has been exceeded, ``False`` otherwise. + """ if depth > MAX_DIRECTORY_DEPTH: print(f"Maximum depth limit ({MAX_DIRECTORY_DEPTH}) reached") @@ -318,7 +245,7 @@ def limit_exceeded(stats: FileSystemStats, depth: int) -> bool: return True # TODO: end recursion if stats.total_size >= MAX_TOTAL_SIZE_BYTES: - print(f"Maxumum total size limit ({MAX_TOTAL_SIZE_BYTES/1024/1024:.1f}MB) reached") + print(f"Maxumum total size limit ({MAX_TOTAL_SIZE_BYTES / 1024 / 1024:.1f}MB) reached") return True # TODO: end recursion return False diff --git a/src/gitingest/output_formatters.py b/src/gitingest/output_formatter.py similarity index 79% rename from src/gitingest/output_formatters.py rename to src/gitingest/output_formatter.py index 9ca3d474..4714fccd 100644 --- a/src/gitingest/output_formatters.py +++ b/src/gitingest/output_formatter.py @@ -1,16 +1,24 @@ """Functions to ingest and analyze a codebase directory or single file.""" -from typing import Optional, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING import tiktoken -from gitingest.query_parsing import IngestionQuery from gitingest.schemas import FileSystemNode, FileSystemNodeType +if TYPE_CHECKING: + from gitingest.query_parser import IngestionQuery -def format_node(node: FileSystemNode, query: IngestionQuery) -> Tuple[str, str, str]: - """ - Generate a summary, directory structure, and file contents for a given file system node. +_TOKEN_THRESHOLDS: list[tuple[int, str]] = [ + (1_000_000, "M"), + (1_000, "k"), +] + + +def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str, str]: + """Generate a summary, directory structure, and file contents for a given file system node. If the node represents a directory, the function will recursively process its contents. @@ -23,8 +31,9 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> Tuple[str, str, Returns ------- - Tuple[str, str, str] + tuple[str, str, str] A tuple containing the summary, directory structure, and file contents. + """ is_single_file = node.type == FileSystemNodeType.FILE summary = _create_summary_prefix(query, single_file=is_single_file) @@ -35,8 +44,7 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> Tuple[str, str, summary += f"File: {node.name}\n" summary += f"Lines: {len(node.content.splitlines()):,}\n" - tree = "Directory structure:\n" + _create_tree_structure(query, node) - _create_tree_structure(query, node) + tree = "Directory structure:\n" + _create_tree_structure(query, node=node) content = _gather_file_contents(node) @@ -47,9 +55,8 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> Tuple[str, str, return summary, tree, content -def _create_summary_prefix(query: IngestionQuery, single_file: bool = False) -> str: - """ - Create a prefix string for summarizing a repository or local directory. +def _create_summary_prefix(query: IngestionQuery, *, single_file: bool = False) -> str: + """Create a prefix string for summarizing a repository or local directory. Includes repository name (if provided), commit/branch details, and subpath if relevant. @@ -58,12 +65,13 @@ def _create_summary_prefix(query: IngestionQuery, single_file: bool = False) -> query : IngestionQuery The parsed query object containing information about the repository and query parameters. single_file : bool - A flag indicating whether the summary is for a single file, by default False. + A flag indicating whether the summary is for a single file (default: ``False``). Returns ------- str A summary prefix string containing repository, commit, branch, and subpath details. + """ parts = [] @@ -85,8 +93,7 @@ def _create_summary_prefix(query: IngestionQuery, single_file: bool = False) -> def _gather_file_contents(node: FileSystemNode) -> str: - """ - Recursively gather contents of all files under the given node. + """Recursively gather contents of all files under the given node. This function recursively processes a directory node and gathers the contents of all files under that node. It returns the concatenated content of all files as a single string. @@ -100,6 +107,7 @@ def _gather_file_contents(node: FileSystemNode) -> str: ------- str The concatenated content of all files under the given node. + """ if node.type != FileSystemNodeType.DIRECTORY: return node.content_string @@ -108,9 +116,14 @@ def _gather_file_contents(node: FileSystemNode) -> str: return "\n".join(_gather_file_contents(child) for child in node.children) -def _create_tree_structure(query: IngestionQuery, node: FileSystemNode, prefix: str = "", is_last: bool = True) -> str: - """ - Generate a tree-like string representation of the file structure. +def _create_tree_structure( + query: IngestionQuery, + node: FileSystemNode, + *, + prefix: str = "", + is_last: bool = True, +) -> str: + """Generate a tree-like string representation of the file structure. This function generates a string representation of the directory structure, formatted as a tree with appropriate indentation for nested directories and files. @@ -122,14 +135,15 @@ def _create_tree_structure(query: IngestionQuery, node: FileSystemNode, prefix: node : FileSystemNode The current directory or file node being processed. prefix : str - A string used for indentation and formatting of the tree structure, by default "". + A string used for indentation and formatting of the tree structure (default: ``""``). is_last : bool - A flag indicating whether the current node is the last in its directory, by default True. + A flag indicating whether the current node is the last in its directory (default: ``True``). Returns ------- str A string representing the directory structure formatted as a tree. + """ if not node.name: # If no name is present, use the slug as the top-level directory name @@ -154,11 +168,8 @@ def _create_tree_structure(query: IngestionQuery, node: FileSystemNode, prefix: return tree_str -def _format_token_count(text: str) -> Optional[str]: - """ - Return a human-readable string representing the token count of the given text. - - E.g., '120' -> '120', '1200' -> '1.2k', '1200000' -> '1.2M'. +def _format_token_count(text: str) -> str | None: + """Return a human-readable token-count string (e.g. 1.2k, 1.2 M). Parameters ---------- @@ -167,8 +178,9 @@ def _format_token_count(text: str) -> Optional[str]: Returns ------- - str, optional - The formatted number of tokens as a string (e.g., '1.2k', '1.2M'), or `None` if an error occurs. + str | None + The formatted number of tokens as a string (e.g., ``1.2k``, ``1.2M``), or ``None`` if an error occurs. + """ try: encoding = tiktoken.get_encoding("o200k_base") # gpt-4o, gpt-4o-mini @@ -177,10 +189,8 @@ def _format_token_count(text: str) -> Optional[str]: print(exc) return None - if total_tokens >= 1_000_000: - return f"{total_tokens / 1_000_000:.1f}M" - - if total_tokens >= 1_000: - return f"{total_tokens / 1_000:.1f}k" + for threshold, suffix in _TOKEN_THRESHOLDS: + if total_tokens >= threshold: + return f"{total_tokens / threshold:.1f}{suffix}" return str(total_tokens) diff --git a/src/gitingest/query_parsing.py b/src/gitingest/query_parser.py similarity index 72% rename from src/gitingest/query_parsing.py rename to src/gitingest/query_parser.py index 089a6f96..0078043b 100644 --- a/src/gitingest/query_parsing.py +++ b/src/gitingest/query_parser.py @@ -1,10 +1,11 @@ -"""This module contains functions to parse and validate input sources and patterns.""" +"""Module containing functions to parse and validate input sources and patterns.""" + +from __future__ import annotations import re import uuid import warnings from pathlib import Path -from typing import List, Optional, Set, Union from urllib.parse import unquote, urlparse from gitingest.config import TMP_BASE_PATH @@ -25,18 +26,14 @@ async def parse_query( source: str, + *, max_file_size: int, from_web: bool, - include_patterns: Optional[Union[str, Set[str]]] = None, - ignore_patterns: Optional[Union[str, Set[str]]] = None, - token: Optional[str] = None, + include_patterns: set[str] | str | None = None, + ignore_patterns: set[str] | str | None = None, + token: str | None = None, ) -> IngestionQuery: - """ - Parse the input source (URL or path) to extract relevant details for the query. - - This function parses the input source to extract details such as the username, repository name, - commit hash, branch name, and other relevant information. It also processes the include and ignore - patterns to filter the files and directories to include or exclude from the query. + """Parse the input source to extract details for the query and process the include and ignore patterns. Parameters ---------- @@ -46,20 +43,20 @@ async def parse_query( The maximum file size in bytes to include. from_web : bool Flag indicating whether the source is a web URL. - include_patterns : Union[str, Set[str]], optional - Patterns to include, by default None. Can be a set of strings or a single string. - ignore_patterns : Union[str, Set[str]], optional - Patterns to ignore, by default None. Can be a set of strings or a single string. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. - Must start with 'github_pat_' or 'gph_' for GitHub repositories. + include_patterns : set[str] | str | None + Patterns to include. Can be a set of strings or a single string. + ignore_patterns : set[str] | str | None + Patterns to ignore. Can be a set of strings or a single string. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the `GITHUB_TOKEN` env var. + Returns ------- IngestionQuery A dataclass object containing the parsed details of the repository or file path. - """ + """ # Determine the parsing method based on the source type if from_web or urlparse(source).scheme in ("https", "http") or any(h in source for h in KNOWN_GIT_HOSTS): # We either have a full URL or a domain-less slug @@ -98,27 +95,27 @@ async def parse_query( ) -async def _parse_remote_repo(source: str, token: Optional[str] = None) -> IngestionQuery: - """ - Parse a repository URL into a structured query dictionary. +async def _parse_remote_repo(source: str, token: str | None = None) -> IngestionQuery: + """Parse a repository URL into a structured query dictionary. If source is: - - A fully qualified URL (https://gitlab.com/...), parse & verify that domain - - A URL missing 'https://' (gitlab.com/...), add 'https://' and parse - - A 'slug' (like 'pandas-dev/pandas'), attempt known domains until we find one that exists. + - A fully qualified URL (`https://gitlab.com/...`), parse & verify that domain + - A URL missing `https://` (`gitlab.com/...`), add `https://` and parse + - A `slug` (`pandas-dev/pandas`), attempt known domains until we find one that exists. Parameters ---------- source : str The URL or domain-less slug to parse. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the ``GITHUB_TOKEN`` env var. Returns ------- IngestionQuery A dictionary containing the parsed details of the repository. + """ source = unquote(source) @@ -190,26 +187,27 @@ async def _parse_remote_repo(source: str, token: Optional[str] = None) -> Ingest return parsed -async def _configure_branch_and_subpath(remaining_parts: List[str], url: str) -> Optional[str]: - """ - Configure the branch and subpath based on the remaining parts of the URL. +async def _configure_branch_and_subpath(remaining_parts: list[str], url: str) -> str | None: + """Configure the branch and subpath based on the remaining parts of the URL. + Parameters ---------- - remaining_parts : List[str] + remaining_parts : list[str] The remaining parts of the URL path. url : str The URL of the repository. + Returns ------- - str, optional - The branch name if found, otherwise None. + str | None + The branch name if found, otherwise ``None``. """ try: # Fetch the list of branches from the remote repository - branches: List[str] = await fetch_remote_branch_list(url) + branches: list[str] = await fetch_remote_branch_list(url) except RuntimeError as exc: - warnings.warn(f"Warning: Failed to fetch branch list: {exc}", RuntimeWarning) + warnings.warn(f"Warning: Failed to fetch branch list: {exc}", RuntimeWarning, stacklevel=2) return remaining_parts.pop(0) branch = [] @@ -222,21 +220,20 @@ async def _configure_branch_and_subpath(remaining_parts: List[str], url: str) -> return None -def _parse_patterns(pattern: Union[str, Set[str]]) -> Set[str]: - """ - Parse and validate file/directory patterns for inclusion or exclusion. +def _parse_patterns(pattern: set[str] | str) -> set[str]: + """Parse and validate file/directory patterns for inclusion or exclusion. Takes either a single pattern string or set of pattern strings and processes them into a normalized list. Patterns are split on commas and spaces, validated for allowed characters, and normalized. Parameters ---------- - pattern : Set[str] | str + pattern : set[str] | str Pattern(s) to parse - either a single string or set of strings Returns ------- - Set[str] + set[str] A set of normalized patterns. Raises @@ -245,10 +242,11 @@ def _parse_patterns(pattern: Union[str, Set[str]]) -> Set[str]: If any pattern contains invalid characters. Only alphanumeric characters, dash (-), underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed. + """ patterns = pattern if isinstance(pattern, set) else {pattern} - parsed_patterns: Set[str] = set() + parsed_patterns: set[str] = set() for p in patterns: parsed_patterns = parsed_patterns.union(set(re.split(",| ", p))) @@ -267,8 +265,7 @@ def _parse_patterns(pattern: Union[str, Set[str]]) -> Set[str]: def _parse_local_dir_path(path_str: str) -> IngestionQuery: - """ - Parse the given file path into a structured query dictionary. + """Parse the given file path into a structured query dictionary. Parameters ---------- @@ -279,6 +276,7 @@ def _parse_local_dir_path(path_str: str) -> IngestionQuery: ------- IngestionQuery A dictionary containing the parsed details of the file path. + """ path_obj = Path(path_str).resolve() slug = path_obj.name if path_str == "." else path_str.strip("/") @@ -292,9 +290,8 @@ def _parse_local_dir_path(path_str: str) -> IngestionQuery: ) -async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: Optional[str] = None) -> str: - """ - Attempt to find a valid repository host for the given user_name and repo_name. +async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: str | None = None) -> str: + """Attempt to find a valid repository host for the given ``user_name`` and ``repo_name``. Parameters ---------- @@ -302,9 +299,9 @@ async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: O The username or owner of the repository. repo_name : str The name of the repository. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the ``GITHUB_TOKEN`` env var. Returns ------- @@ -314,10 +311,13 @@ async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: O Raises ------ ValueError - If no valid repository host is found for the given user_name and repo_name. + If no valid repository host is found for the given ``user_name`` and ``repo_name``. + """ for domain in KNOWN_GIT_HOSTS: candidate = f"https://{domain}/{user_name}/{repo_name}" if await check_repo_exists(candidate, token=token if domain == "github.com" else None): return domain - raise ValueError(f"Could not find a valid repository host for '{user_name}/{repo_name}'.") + + msg = f"Could not find a valid repository host for '{user_name}/{repo_name}'." + raise ValueError(msg) diff --git a/src/gitingest/schemas/__init__.py b/src/gitingest/schemas/__init__.py index c3869864..efe2dd70 100644 --- a/src/gitingest/schemas/__init__.py +++ b/src/gitingest/schemas/__init__.py @@ -1,6 +1,6 @@ -"""This module contains the schemas for the Gitingest package.""" +"""Module containing the schemas for the Gitingest package.""" -from gitingest.schemas.filesystem_schema import FileSystemNode, FileSystemNodeType, FileSystemStats -from gitingest.schemas.ingestion_schema import CloneConfig, IngestionQuery +from gitingest.schemas.filesystem import FileSystemNode, FileSystemNodeType, FileSystemStats +from gitingest.schemas.ingestion import CloneConfig, IngestionQuery -__all__ = ["FileSystemNode", "FileSystemNodeType", "FileSystemStats", "CloneConfig", "IngestionQuery"] +__all__ = ["CloneConfig", "FileSystemNode", "FileSystemNodeType", "FileSystemStats", "IngestionQuery"] diff --git a/src/gitingest/schemas/filesystem_schema.py b/src/gitingest/schemas/filesystem.py similarity index 62% rename from src/gitingest/schemas/filesystem_schema.py rename to src/gitingest/schemas/filesystem.py index 22cff569..821c46e0 100644 --- a/src/gitingest/schemas/filesystem_schema.py +++ b/src/gitingest/schemas/filesystem.py @@ -5,10 +5,13 @@ import os from dataclasses import dataclass, field from enum import Enum, auto -from pathlib import Path +from typing import TYPE_CHECKING -from gitingest.utils.file_utils import get_preferred_encodings, is_text_file -from gitingest.utils.notebook_utils import process_notebook +from gitingest.utils.file_utils import _decodes, _get_preferred_encodings, _read_chunk +from gitingest.utils.notebook import process_notebook + +if TYPE_CHECKING: + from pathlib import Path SEPARATOR = "=" * 48 # Tiktoken, the tokenizer openai uses, counts 2 tokens if we have more than 48 @@ -32,8 +35,7 @@ class FileSystemStats: @dataclass class FileSystemNode: # pylint: disable=too-many-instance-attributes - """ - Class representing a node in the file system (either a file or directory). + """Class representing a node in the file system (either a file or directory). Tracks properties of files/directories for comprehensive analysis. """ @@ -49,8 +51,7 @@ class FileSystemNode: # pylint: disable=too-many-instance-attributes children: list[FileSystemNode] = field(default_factory=list) def sort_children(self) -> None: - """ - Sort the children nodes of a directory according to a specific order. + """Sort the children nodes of a directory according to a specific order. Order of sorting: 2. Regular files (not starting with dot) @@ -64,16 +65,18 @@ def sort_children(self) -> None: ------ ValueError If the node is not a directory. + """ if self.type != FileSystemNodeType.DIRECTORY: - raise ValueError("Cannot sort children of a non-directory node") + msg = "Cannot sort children of a non-directory node" + raise ValueError(msg) def _sort_key(child: FileSystemNode) -> tuple[int, str]: # returns the priority order for the sort function, 0 is first # Groups: 0=README, 1=regular file, 2=hidden file, 3=regular dir, 4=hidden dir name = child.name.lower() if child.type == FileSystemNodeType.FILE: - if name == "readme.md": + if name == "readme" or name.startswith("readme."): return (0, name) return (1 if not name.startswith(".") else 2, name) return (3 if not name.startswith(".") else 4, name) @@ -82,13 +85,13 @@ def _sort_key(child: FileSystemNode) -> tuple[int, str]: @property def content_string(self) -> str: - """ - Return the content of the node as a string, including path and content. + """Return the content of the node as a string, including path and content. Returns ------- str A string representation of the node's content. + """ parts = [ SEPARATOR, @@ -102,8 +105,10 @@ def content_string(self) -> str: @property def content(self) -> str: # pylint: disable=too-many-return-statements - """ - Read the content of a file if it's text (or a notebook). Return an error message otherwise. + """Return file content (if text / notebook) or an explanatory placeholder. + + Heuristically decides whether the file is text or binary by decoding a small chunk of the file + with multiple encodings and checking for common binary markers. Returns ------- @@ -114,32 +119,43 @@ def content(self) -> str: # pylint: disable=too-many-return-statements ------ ValueError If the node is a directory. + """ if self.type == FileSystemNodeType.DIRECTORY: - raise ValueError("Cannot read content of a directory node") + msg = "Cannot read content of a directory node" + raise ValueError(msg) if self.type == FileSystemNodeType.SYMLINK: - return "" + return "" # TODO: are we including the empty content of symlinks? - if not is_text_file(self.path): - return "[Non-text file]" - - if self.path.suffix == ".ipynb": + if self.path.suffix == ".ipynb": # Notebook try: return process_notebook(self.path) except Exception as exc: return f"Error processing notebook: {exc}" - # Try multiple encodings - for encoding in get_preferred_encodings(): - try: - with self.path.open(encoding=encoding) as f: - return f.read() - except UnicodeDecodeError: - continue - except UnicodeError: - continue - except OSError as exc: - return f"Error reading file: {exc}" - - return "Error: Unable to decode file with available encodings" + chunk = _read_chunk(self.path) + + if chunk is None: + return "Error reading file" + + if chunk == b"": + return "[Empty file]" + + if not _decodes(chunk, "utf-8"): + return "[Binary file]" + + # Find the first encoding that decodes the sample + good_enc: str | None = next( + (enc for enc in _get_preferred_encodings() if _decodes(chunk, encoding=enc)), + None, + ) + + if good_enc is None: + return "Error: Unable to decode file with available encodings" + + try: + with self.path.open(encoding=good_enc) as fp: + return fp.read() + except (OSError, UnicodeDecodeError) as exc: + return f"Error reading file with {good_enc!r}: {exc}" diff --git a/src/gitingest/schemas/ingestion.py b/src/gitingest/schemas/ingestion.py new file mode 100644 index 00000000..efa21319 --- /dev/null +++ b/src/gitingest/schemas/ingestion.py @@ -0,0 +1,131 @@ +"""Module containing the dataclasses for the ingestion process.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path # noqa: TC003 (typing-only-standard-library-import) needed for type checking (pydantic) + +from pydantic import BaseModel, Field + +from gitingest.config import MAX_FILE_SIZE + + +@dataclass +class CloneConfig: + """Configuration for cloning a Git repository. + + This class holds the necessary parameters for cloning a repository to a local path, including + the repository's URL, the target local path, and optional parameters for a specific commit or branch. + + Attributes + ---------- + url : str + The URL of the Git repository to clone. + local_path : str + The local directory where the repository will be cloned. + commit : str | None + The specific commit hash to check out after cloning. + branch : str | None + The branch to clone. + subpath : str + The subpath to clone from the repository (default: ``"/"``). + blob: bool + Whether the repository is a blob (default: ``False``). + + """ + + url: str + local_path: str + commit: str | None = None + branch: str | None = None + subpath: str = "/" + blob: bool = False + + +class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes + """Pydantic model to store the parsed details of the repository or file path. + + Attributes + ---------- + user_name : str | None + The username or owner of the repository. + repo_name : str | None + The name of the repository. + local_path : Path + The local path to the repository or file. + url : str | None + The URL of the repository. + slug : str + The slug of the repository. + id : str + The ID of the repository. + subpath : str + The subpath to the repository or file (default: ``"/"``). + type : str | None + The type of the repository or file. + branch : str | None + The branch of the repository. + commit : str | None + The commit of the repository. + max_file_size : int + The maximum file size to ingest (default: 10 MB). + ignore_patterns : set[str] + The patterns to ignore (default: ``set()``). + include_patterns : set[str] | None + The patterns to include. + + """ + + user_name: str | None = None + repo_name: str | None = None + local_path: Path + url: str | None = None + slug: str + id: str + subpath: str = "/" + type: str | None = None + branch: str | None = None + commit: str | None = None + max_file_size: int = Field(default=MAX_FILE_SIZE) + ignore_patterns: set[str] = set() # TODO: ignore_patterns and include_patterns have the same type + include_patterns: set[str] | None = None + + def extract_clone_config(self) -> CloneConfig: + """Extract the relevant fields for the CloneConfig object. + + Returns + ------- + CloneConfig + A CloneConfig object containing the relevant fields. + + Raises + ------ + ValueError + If the `url` parameter is not provided. + + """ + if not self.url: + msg = "The 'url' parameter is required." + raise ValueError(msg) + + return CloneConfig( + url=self.url, + local_path=str(self.local_path), + commit=self.commit, + branch=self.branch, + subpath=self.subpath, + blob=self.type == "blob", + ) + + def ensure_url(self) -> None: + """Raise if the parsed query has no URL (invalid user input). + + Raises + ------ + ValueError + If the parsed query has no URL (invalid user input). + + """ + if not self.url: + msg = "The 'url' parameter is required." + raise ValueError(msg) diff --git a/src/gitingest/schemas/ingestion_schema.py b/src/gitingest/schemas/ingestion_schema.py deleted file mode 100644 index 43ea6c42..00000000 --- a/src/gitingest/schemas/ingestion_schema.py +++ /dev/null @@ -1,89 +0,0 @@ -"""This module contains the dataclasses for the ingestion process.""" - -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Set - -from pydantic import BaseModel, ConfigDict, Field - -from gitingest.config import MAX_FILE_SIZE - - -@dataclass -class CloneConfig: - """ - Configuration for cloning a Git repository. - - This class holds the necessary parameters for cloning a repository to a local path, including - the repository's URL, the target local path, and optional parameters for a specific commit or branch. - - Attributes - ---------- - url : str - The URL of the Git repository to clone. - local_path : str - The local directory where the repository will be cloned. - commit : str, optional - The specific commit hash to check out after cloning (default is None). - branch : str, optional - The branch to clone (default is None). - subpath : str - The subpath to clone from the repository (default is "/"). - blob: bool - Whether the repository is a blob (default is False). - """ - - url: str - local_path: str - commit: Optional[str] = None - branch: Optional[str] = None - subpath: str = "/" - blob: bool = False - - -class IngestionQuery(BaseModel): # pylint: disable=too-many-instance-attributes - """ - Pydantic model to store the parsed details of the repository or file path. - """ - - user_name: Optional[str] = None - repo_name: Optional[str] = None - local_path: Path - url: Optional[str] = None - slug: str - id: str - subpath: str = "/" - type: Optional[str] = None - branch: Optional[str] = None - commit: Optional[str] = None - max_file_size: int = Field(default=MAX_FILE_SIZE) - ignore_patterns: Optional[Set[str]] = None - include_patterns: Optional[Set[str]] = None - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def extract_clone_config(self) -> CloneConfig: - """ - Extract the relevant fields for the CloneConfig object. - - Returns - ------- - CloneConfig - A CloneConfig object containing the relevant fields. - - Raises - ------ - ValueError - If the 'url' parameter is not provided. - """ - if not self.url: - raise ValueError("The 'url' parameter is required.") - - return CloneConfig( - url=self.url, - local_path=str(self.local_path), - commit=self.commit, - branch=self.branch, - subpath=self.subpath, - blob=self.type == "blob", - ) diff --git a/src/gitingest/utils/__init__.py b/src/gitingest/utils/__init__.py index e69de29b..fda3ee4b 100644 --- a/src/gitingest/utils/__init__.py +++ b/src/gitingest/utils/__init__.py @@ -0,0 +1 @@ +"""Utility functions for the gitingest package.""" diff --git a/src/gitingest/utils/async_compat.py b/src/gitingest/utils/async_compat.py new file mode 100644 index 00000000..18123799 --- /dev/null +++ b/src/gitingest/utils/async_compat.py @@ -0,0 +1,35 @@ +"""Compatibility utilities wrapping asyncio features that are missing on older Python versions (e.g. 3.8).""" + +from __future__ import annotations + +import asyncio +import contextvars +import functools +import sys +from typing import Callable, TypeVar + +from gitingest.utils.compat_typing import ParamSpec + +P = ParamSpec("P") +R = TypeVar("R") + + +if sys.version_info >= (3, 9): + from asyncio import to_thread +else: + + async def to_thread(func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R: + """Back-port :func:`asyncio.to_thread` for Python < 3.9. + + Run `func` in the default thread-pool executor and return the result. + """ + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) + + # Patch stdlib so that *existing* imports of ``asyncio`` see the shim. + if not hasattr(asyncio, "to_thread"): + asyncio.to_thread = to_thread + +__all__ = ["to_thread"] diff --git a/src/gitingest/utils/auth.py b/src/gitingest/utils/auth.py new file mode 100644 index 00000000..6fe241d1 --- /dev/null +++ b/src/gitingest/utils/auth.py @@ -0,0 +1,22 @@ +"""Utilities for handling authentication.""" + +from __future__ import annotations + +import os + + +def resolve_token(token: str | None) -> str | None: + """Resolve the token to use for the query. + + Parameters + ---------- + token : str | None + The token to use for the query. + + Returns + ------- + str | None + The resolved token. + + """ + return token or os.getenv("GITHUB_TOKEN") diff --git a/src/gitingest/utils/compat_typing.py b/src/gitingest/utils/compat_typing.py new file mode 100644 index 00000000..0ffd9555 --- /dev/null +++ b/src/gitingest/utils/compat_typing.py @@ -0,0 +1,13 @@ +"""Compatibility layer for typing.""" + +try: + from typing import ParamSpec, TypeAlias # Py ≥ 3.10 +except ImportError: + from typing_extensions import ParamSpec, TypeAlias # Py 3.8 / 3.9 + +try: + from typing import Annotated # Py ≥ 3.9 +except ImportError: + from typing_extensions import Annotated # Py 3.8 + +__all__ = ["Annotated", "ParamSpec", "TypeAlias"] diff --git a/src/gitingest/utils/exceptions.py b/src/gitingest/utils/exceptions.py index 5b9f33b4..dac985db 100644 --- a/src/gitingest/utils/exceptions.py +++ b/src/gitingest/utils/exceptions.py @@ -2,28 +2,29 @@ class InvalidPatternError(ValueError): - """ - Exception raised when a pattern contains invalid characters. + """Exception raised when a pattern contains invalid characters. + This exception is used to signal that a pattern provided for some operation contains characters that are not allowed. The valid characters for the pattern include alphanumeric characters, dash (-), underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*). + Parameters ---------- pattern : str The invalid pattern that caused the error. + """ def __init__(self, pattern: str) -> None: super().__init__( f"Pattern '{pattern}' contains invalid characters. Only alphanumeric characters, dash (-), " - "underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed." + "underscore (_), dot (.), forward slash (/), plus (+), and asterisk (*) are allowed.", ) class AsyncTimeoutError(Exception): - """ - Exception raised when an async operation exceeds its timeout limit. + """Exception raised when an async operation exceeds its timeout limit. This exception is used by the `async_timeout` decorator to signal that the wrapped asynchronous function has exceeded the specified time limit for execution. @@ -43,5 +44,5 @@ class InvalidGitHubTokenError(ValueError): def __init__(self) -> None: super().__init__( "Invalid GitHub token format. Token should start with 'github_pat_' or 'ghp_' " - "followed by at least 36 characters of letters, numbers, and underscores." + "followed by at least 36 characters of letters, numbers, and underscores.", ) diff --git a/src/gitingest/utils/file_utils.py b/src/gitingest/utils/file_utils.py index 28c3d4eb..ac2b8940 100644 --- a/src/gitingest/utils/file_utils.py +++ b/src/gitingest/utils/file_utils.py @@ -1,74 +1,77 @@ """Utility functions for working with files and directories.""" +from __future__ import annotations + import locale import platform -from pathlib import Path -from typing import List +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path try: locale.setlocale(locale.LC_ALL, "") except locale.Error: locale.setlocale(locale.LC_ALL, "C") +_CHUNK_SIZE = 1024 # bytes -def get_preferred_encodings() -> List[str]: - """ - Get list of encodings to try, prioritized for the current platform. + +def _get_preferred_encodings() -> list[str]: + """Get list of encodings to try, prioritized for the current platform. Returns ------- - List[str] + list[str] List of encoding names to try in priority order, starting with the platform's default encoding followed by common fallback encodings. + """ encodings = [locale.getpreferredencoding(), "utf-8", "utf-16", "utf-16le", "utf-8-sig", "latin"] if platform.system() == "Windows": encodings += ["cp1252", "iso-8859-1"] - return encodings + return list(dict.fromkeys(encodings)) -def is_text_file(path: Path) -> bool: - """ - Determine if the file is likely a text file by trying to decode a small chunk - with multiple encodings, and checking for common binary markers. +def _read_chunk(path: Path) -> bytes | None: + """Attempt to read the first *size* bytes of *path* in binary mode. Parameters ---------- path : Path - The path to the file to check. + The path to the file to read. Returns ------- - bool - True if the file is likely textual; False if it appears to be binary. - """ + bytes | None + The first *_CHUNK_SIZE* bytes of *path*, or None on any `OSError`. - # Attempt to read a portion of the file in binary mode + """ try: - with path.open("rb") as f: - chunk = f.read(1024) + with path.open("rb") as fp: + return fp.read(_CHUNK_SIZE) except OSError: - return False + return None - # If file is empty, treat as text - if not chunk: - return True - # Check obvious binary bytes - if b"\x00" in chunk or b"\xff" in chunk: - return False +def _decodes(chunk: bytes, encoding: str) -> bool: + """Return *True* if *chunk* decodes cleanly with *encoding*. - # Attempt multiple encodings - for enc in get_preferred_encodings(): - try: - with path.open(encoding=enc) as f: - f.read() - return True - except UnicodeDecodeError: - continue - except UnicodeError: - continue - except OSError: - return False - - return False + Parameters + ---------- + chunk : bytes + The chunk of bytes to decode. + encoding : str + The encoding to use to decode the chunk. + + Returns + ------- + bool + True if the chunk decodes cleanly with the encoding, False otherwise. + + """ + try: + chunk.decode(encoding) + except UnicodeDecodeError: + return False + return True diff --git a/src/gitingest/utils/git_utils.py b/src/gitingest/utils/git_utils.py index 5735c27e..53f324e4 100644 --- a/src/gitingest/utils/git_utils.py +++ b/src/gitingest/utils/git_utils.py @@ -1,9 +1,10 @@ """Utility functions for interacting with Git repositories.""" +from __future__ import annotations + import asyncio import base64 import re -from typing import List, Optional, Tuple from urllib.parse import urlparse from gitingest.utils.exceptions import InvalidGitHubTokenError @@ -11,9 +12,8 @@ GITHUB_PAT_PATTERN = r"^(?:github_pat_|ghp_)[A-Za-z0-9_]{36,}$" -def _is_github_host(url: str) -> bool: - """ - Check if a URL is from a GitHub host (github.com or GitHub Enterprise). +def is_github_host(url: str) -> bool: + """Check if a URL is from a GitHub host (github.com or GitHub Enterprise). Parameters ---------- @@ -24,15 +24,14 @@ def _is_github_host(url: str) -> bool: ------- bool True if the URL is from a GitHub host, False otherwise + """ - parsed = urlparse(url) - hostname = parsed.hostname or "" + hostname = urlparse(url).hostname or "" return hostname == "github.com" or hostname.startswith("github.") -async def run_command(*args: str) -> Tuple[bytes, bytes]: - """ - Execute a shell command asynchronously and return (stdout, stderr) bytes. +async def run_command(*args: str) -> tuple[bytes, bytes]: + """Execute a shell command asynchronously and return (stdout, stderr) bytes. Parameters ---------- @@ -41,13 +40,14 @@ async def run_command(*args: str) -> Tuple[bytes, bytes]: Returns ------- - Tuple[bytes, bytes] + tuple[bytes, bytes] A tuple containing the stdout and stderr of the command. Raises ------ RuntimeError If command exits with a non-zero status. + """ # Execute the requested command proc = await asyncio.create_subprocess_exec( @@ -58,37 +58,38 @@ async def run_command(*args: str) -> Tuple[bytes, bytes]: stdout, stderr = await proc.communicate() if proc.returncode != 0: error_message = stderr.decode().strip() - raise RuntimeError(f"Command failed: {' '.join(args)}\nError: {error_message}") + msg = f"Command failed: {' '.join(args)}\nError: {error_message}" + raise RuntimeError(msg) return stdout, stderr async def ensure_git_installed() -> None: - """ - Ensure Git is installed and accessible on the system. + """Ensure Git is installed and accessible on the system. Raises ------ RuntimeError If Git is not installed or not accessible. + """ try: await run_command("git", "--version") except RuntimeError as exc: - raise RuntimeError("Git is not installed or not accessible. Please install Git first.") from exc + msg = "Git is not installed or not accessible. Please install Git first." + raise RuntimeError(msg) from exc -async def check_repo_exists(url: str, token: Optional[str] = None) -> bool: - """ - Check if a Git repository exists at the provided URL. +async def check_repo_exists(url: str, token: str | None = None) -> bool: + """Check if a Git repository exists at the provided URL. Parameters ---------- url : str The URL of the Git repository to check. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the `GITHUB_TOKEN` env var. Returns ------- @@ -99,9 +100,11 @@ async def check_repo_exists(url: str, token: Optional[str] = None) -> bool: ------ RuntimeError If the curl command returns an unexpected status code. + """ - if token and _is_github_host(url): - return await _check_github_repo_exists(url, token) + expected_path_length = 2 + if token and is_github_host(url): + return await _check_github_repo_exists(url, token=token) proc = await asyncio.create_subprocess_exec( "curl", @@ -118,26 +121,26 @@ async def check_repo_exists(url: str, token: Optional[str] = None) -> bool: response = stdout.decode() status_line = response.splitlines()[0].strip() parts = status_line.split(" ") - if len(parts) >= 2: + if len(parts) >= expected_path_length: status_code_str = parts[1] if status_code_str in ("200", "301"): return True if status_code_str in ("302", "404"): return False - raise RuntimeError(f"Unexpected status line: {status_line}") + msg = f"Unexpected status line: {status_line}" + raise RuntimeError(msg) -async def _check_github_repo_exists(url: str, token: Optional[str] = None) -> bool: - """ - Return True iff the authenticated user can see `url`. +async def _check_github_repo_exists(url: str, token: str | None = None) -> bool: + """Return True iff the authenticated user can see ``url``. Parameters ---------- url : str The URL of the GitHub repository to check. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the `GITHUB_TOKEN` env var. Returns ------- @@ -146,23 +149,17 @@ async def _check_github_repo_exists(url: str, token: Optional[str] = None) -> bo Raises ------ - ValueError - If the URL is not a valid GitHub repository URL. RuntimeError If the repository is not found, if the provided URL is invalid, or if the token format is invalid. + """ - m = re.match(r"https?://github\.([^/]*)/([^/]+)/([^/]+?)(?:\.git)?/?$", url) - if not m: - m = re.match(r"https?://github\.com/([^/]+)/([^/]+?)(?:\.git)?/?$", url) - if not m: - raise ValueError(f"Un-recognised GitHub URL: {url!r}") - owner, repo = m.groups() + host, owner, repo = _parse_github_url(url) + + if host == "github.com": api = f"https://api.github.com/repos/{owner}/{repo}" - else: - _, owner, repo = m.groups() + else: # GitHub Enterprise + api = f"https://{host}/api/v3/repos/{owner}/{repo}" - parsed = urlparse(url) - api = f"https://{parsed.hostname}/api/v3/repos/{owner}/{repo}" cmd = [ "curl", "--silent", @@ -191,38 +188,72 @@ async def _check_github_repo_exists(url: str, token: Optional[str] = None) -> bo if status == "404": return False if status in ("401", "403"): - raise RuntimeError("Token invalid or lacks permissions") - raise RuntimeError(f"GitHub API returned unexpected HTTP {status}") + msg = "Token invalid or lacks permissions" + raise RuntimeError(msg) + msg = f"GitHub API returned unexpected HTTP {status}" + raise RuntimeError(msg) + + +def _parse_github_url(url: str) -> tuple[str, str, str]: + """Parse a GitHub URL and return (hostname, owner, repo). + + Parameters + ---------- + url : str + The URL of the GitHub repository to parse. + Returns + ------- + tuple[str, str, str] + A tuple containing the hostname, owner, and repository name. + + Raises + ------ + ValueError + If the URL is not a valid GitHub repository URL. -async def fetch_remote_branch_list(url: str, token: Optional[str] = None) -> List[str]: """ - Fetch the list of branches from a remote Git repository. + expected_path_length = 2 + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + msg = f"URL must start with http:// or https://: {url!r}" + raise ValueError(msg) + + if not parsed.hostname or not parsed.hostname.startswith("github."): + msg = f"Un-recognised GitHub hostname: {parsed.hostname!r}" + raise ValueError(msg) + + parts = parsed.path.strip("/").removesuffix(".git").split("/") + if len(parts) != expected_path_length: + msg = f"Path must look like //: {parsed.path!r}" + raise ValueError(msg) + + owner, repo = parts + return parsed.hostname, owner, repo + + +async def fetch_remote_branch_list(url: str, token: str | None = None) -> list[str]: + """Fetch the list of branches from a remote Git repository. Parameters ---------- url : str The URL of the Git repository to fetch branches from. - token : str, optional - GitHub personal-access token (PAT). Needed when *source* refers to a - **private** repository. Can also be set via the ``GITHUB_TOKEN`` env var. + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. + Can also be set via the `GITHUB_TOKEN` env var. Returns ------- - List[str] + list[str] A list of branch names available in the remote repository. + """ fetch_branches_command = ["git"] # Add authentication if needed - if token and _is_github_host(url): - # Only pass URL if it's not the default github.com to maintain backward compatibility - - parsed = urlparse(url) - if parsed.hostname == "github.com": - fetch_branches_command += ["-c", create_git_auth_header(token)] - else: - fetch_branches_command += ["-c", create_git_auth_header(token, url)] + if token and is_github_host(url): + fetch_branches_command += ["-c", create_git_auth_header(token, url)] fetch_branches_command += ["ls-remote", "--heads", url] @@ -237,35 +268,30 @@ async def fetch_remote_branch_list(url: str, token: Optional[str] = None) -> Lis ] -def create_git_command(base_cmd: List[str], local_path: str, url: str, token: Optional[str] = None) -> List[str]: +def create_git_command(base_cmd: list[str], local_path: str, url: str, token: str | None = None) -> list[str]: """Create a git command with authentication if needed. Parameters ---------- - base_cmd : List[str] + base_cmd : list[str] The base git command to start with local_path : str The local path where the git command should be executed url : str The repository URL to check if it's a GitHub repository - token : Optional[str] + token : str | None GitHub personal access token for authentication Returns ------- - List[str] + list[str] The git command with authentication if needed + """ - cmd = base_cmd + ["-C", local_path] - if token and _is_github_host(url): + cmd = [*base_cmd, "-C", local_path] + if token and is_github_host(url): validate_github_token(token) - # Only pass URL if it's not the default github.com to maintain backward compatibility - - parsed = urlparse(url) - if parsed.hostname == "github.com": - cmd += ["-c", create_git_auth_header(token)] - else: - cmd += ["-c", create_git_auth_header(token, url)] + cmd += ["-c", create_git_auth_header(token, url=url)] return cmd @@ -278,16 +304,15 @@ def create_git_auth_header(token: str, url: str = "https://github.com") -> str: GitHub personal access token url : str The GitHub URL to create the authentication header for. - Defaults to "https://github.com". + Defaults to "https://github.com" if not provided. Returns ------- str The git config command for setting the authentication header - """ - parsed = urlparse(url) - hostname = parsed.hostname or "github.com" + """ + hostname = urlparse(url).hostname basic = base64.b64encode(f"x-oauth-basic:{token}".encode()).decode() return f"http.https://{hostname}/.extraheader=Authorization: Basic {basic}" @@ -304,6 +329,7 @@ def validate_github_token(token: str) -> None: ------ InvalidGitHubTokenError If the token format is invalid + """ if not re.match(GITHUB_PAT_PATTERN, token): - raise InvalidGitHubTokenError() + raise InvalidGitHubTokenError diff --git a/src/gitingest/utils/ignore_patterns.py b/src/gitingest/utils/ignore_patterns.py index 7f4d6454..079ce758 100644 --- a/src/gitingest/utils/ignore_patterns.py +++ b/src/gitingest/utils/ignore_patterns.py @@ -1,10 +1,10 @@ """Default ignore patterns for Gitingest.""" -import os +from __future__ import annotations + from pathlib import Path -from typing import Set -DEFAULT_IGNORE_PATTERNS: Set[str] = { +DEFAULT_IGNORE_PATTERNS: set[str] = { # Python "*.pyc", "*.pyo", @@ -94,6 +94,7 @@ ".svn", ".hg", ".gitignore", + ".gitingestignore", # Ignore rules specific to Gitingest ".gitattributes", ".gitmodules", # Images and media @@ -164,45 +165,74 @@ } -def load_gitignore_patterns(root: Path) -> Set[str]: +def load_ignore_patterns(root: Path, filename: str) -> set[str]: + """Load ignore patterns from ``filename`` found under ``root``. + + The loader walks the directory tree, looks for the supplied ``filename``, + and returns a unified ``set`` of patterns. It implements the same parsing rules + we use for ``.gitignore`` and ``.gitingestignore`` (git-wildmatch syntax with + support for negation and root-relative paths). + + Parameters + ---------- + root : Path + Directory to walk. + filename : str + The filename to look for in each directory. + + Returns + ------- + set[str] + A set of ignore patterns extracted from the ``filename`` file found under the ``root`` directory. + """ - Recursively load ignore patterns from all .gitignore files under the given root directory. + patterns: set[str] = set() + + for ignore_file in root.rglob(filename): + if ignore_file.is_file(): + patterns.update(_parse_ignore_file(ignore_file, root)) + + return patterns + + +def _parse_ignore_file(ignore_file: Path, root: Path) -> set[str]: + """Parse an ignore file and return a set of ignore patterns. Parameters ---------- + ignore_file : Path + The path to the ignore file. root : Path - The root directory to search for .gitignore files. + The root directory of the repository. Returns ------- - Set[str] - A set of ignore patterns extracted from all .gitignore files found under the root directory. + set[str] + A set of ignore patterns. + """ - patterns: Set[str] = set() - for dirpath, _, filenames in os.walk(root): - if ".gitignore" not in filenames: - continue - - gitignore_path = Path(dirpath) / ".gitignore" - with gitignore_path.open("r", encoding="utf-8") as f: - for line in f: - stripped = line.strip() - - if not stripped or stripped.startswith("#"): - continue - - negated = stripped.startswith("!") - if negated: - stripped = stripped[1:] - - rel_dir = os.path.relpath(dirpath, root) - if stripped.startswith("/"): - pattern_body = os.path.join(rel_dir, stripped.lstrip("/")) - else: - pattern_body = os.path.join(rel_dir, stripped) if rel_dir != "." else stripped - - pattern_body = pattern_body.replace("\\", "/") - pattern = f"!{pattern_body}" if negated else pattern_body - patterns.add(pattern) + patterns: set[str] = set() + + # Path of the ignore file relative to the repository root + rel_dir = ignore_file.parent.relative_to(root) + base_dir = Path() if rel_dir == Path() else rel_dir + + with ignore_file.open(encoding="utf-8") as fh: + for raw in fh: + line = raw.strip() + if not line or line.startswith("#"): # comments / blank lines + continue + + # Handle negation (`!foobar`) + negated = line.startswith("!") + if negated: + line = line[1:] + + # Handle leading slash (`/foobar`) + if line.startswith("/"): + line = line.lstrip("/") + + pattern_body = (base_dir / line).as_posix() + patterns.add(f"!{pattern_body}" if negated else pattern_body) return patterns diff --git a/src/gitingest/utils/ingestion_utils.py b/src/gitingest/utils/ingestion_utils.py index 924490a9..c4ba24ae 100644 --- a/src/gitingest/utils/ingestion_utils.py +++ b/src/gitingest/utils/ingestion_utils.py @@ -1,55 +1,47 @@ """Utility functions for the ingestion process.""" -from pathlib import Path -from typing import Set +from __future__ import annotations + +from typing import TYPE_CHECKING from pathspec import PathSpec +if TYPE_CHECKING: + from pathlib import Path -def _should_include(path: Path, base_path: Path, include_patterns: Set[str]) -> bool: - """ - Determine if the given file or directory path matches any of the include patterns. - This function checks whether the relative path of a file or directory matches any of the specified patterns. If a - match is found, it returns `True`, indicating that the file or directory should be included in further processing. +def _should_include(path: Path, base_path: Path, include_patterns: set[str]) -> bool: + """Return ``True`` if ``path`` matches any of ``include_patterns``. Parameters ---------- path : Path The absolute path of the file or directory to check. + base_path : Path The base directory from which the relative path is calculated. - include_patterns : Set[str] + + include_patterns : set[str] A set of patterns to check against the relative path. Returns ------- bool - `True` if the path matches any of the include patterns, `False` otherwise. + ``True`` if the path matches any of the include patterns, ``False`` otherwise. + """ - try: - rel_path = path.relative_to(base_path) - except ValueError: - # If path is not under base_path at all + rel_path = _relative_or_none(path, base_path) + if rel_path is None: # outside repo → do *not* include return False - - rel_str = str(rel_path) - - # if path is a directory, include it by default - if path.is_dir(): + if path.is_dir(): # keep directories so children are visited return True spec = PathSpec.from_lines("gitwildmatch", include_patterns) - return spec.match_file(rel_str) - + return spec.match_file(str(rel_path)) -def _should_exclude(path: Path, base_path: Path, ignore_patterns: Set[str]) -> bool: - """ - Determine if the given file or directory path matches any of the ignore patterns. - This function checks whether the relative path of a file or directory matches - any of the specified ignore patterns. If a match is found, it returns `True`, indicating - that the file or directory should be excluded from further processing. +def _should_exclude(path: Path, base_path: Path, ignore_patterns: set[str]) -> bool: + """Return ``True`` if ``path`` matches any of ``ignore_patterns``. Parameters ---------- @@ -57,20 +49,40 @@ def _should_exclude(path: Path, base_path: Path, ignore_patterns: Set[str]) -> b The absolute path of the file or directory to check. base_path : Path The base directory from which the relative path is calculated. - ignore_patterns : Set[str] + ignore_patterns : set[str] A set of patterns to check against the relative path. Returns ------- bool - `True` if the path matches any of the ignore patterns, `False` otherwise. + ``True`` if the path matches any of the ignore patterns, ``False`` otherwise. + """ - try: - rel_path = path.relative_to(base_path) - except ValueError: - # If path is not under base_path at all + rel_path = _relative_or_none(path, base_path) + if rel_path is None: # outside repo → already “excluded” return True - rel_str = str(rel_path) spec = PathSpec.from_lines("gitwildmatch", ignore_patterns) - return spec.match_file(rel_str) + return spec.match_file(str(rel_path)) + + +def _relative_or_none(path: Path, base: Path) -> Path | None: + """Return *path* relative to *base* or ``None`` if *path* is outside *base*. + + Parameters + ---------- + path : Path + The absolute path of the file or directory to check. + base : Path + The base directory from which the relative path is calculated. + + Returns + ------- + Path | None + The relative path of *path* to *base*, or ``None`` if *path* is outside *base*. + + """ + try: + return path.relative_to(base) + except ValueError: # path is not a sub-path of base + return None diff --git a/src/gitingest/utils/notebook_utils.py b/src/gitingest/utils/notebook.py similarity index 64% rename from src/gitingest/utils/notebook_utils.py rename to src/gitingest/utils/notebook.py index bae62064..cfa09238 100644 --- a/src/gitingest/utils/notebook_utils.py +++ b/src/gitingest/utils/notebook.py @@ -1,24 +1,27 @@ """Utilities for processing Jupyter notebooks.""" +from __future__ import annotations + import json import warnings from itertools import chain -from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from gitingest.utils.exceptions import InvalidNotebookError +if TYPE_CHECKING: + from pathlib import Path -def process_notebook(file: Path, include_output: bool = True) -> str: - """ - Process a Jupyter notebook file and return an executable Python script as a string. + +def process_notebook(file: Path, *, include_output: bool = True) -> str: + """Process a Jupyter notebook file and return an executable Python script as a string. Parameters ---------- file : Path The path to the Jupyter notebook file. include_output : bool - Whether to include cell outputs in the generated script, by default True. + Whether to include cell outputs in the generated script (default: ``True``). Returns ------- @@ -29,12 +32,14 @@ def process_notebook(file: Path, include_output: bool = True) -> str: ------ InvalidNotebookError If the notebook file is invalid or cannot be processed. + """ try: with file.open(encoding="utf-8") as f: - notebook: Dict[str, Any] = json.load(f) + notebook: dict[str, Any] = json.load(f) except json.JSONDecodeError as exc: - raise InvalidNotebookError(f"Invalid JSON in notebook: {file}") from exc + msg = f"Invalid JSON in notebook: {file}" + raise InvalidNotebookError(msg) from exc # Check if the notebook contains worksheets worksheets = notebook.get("worksheets") @@ -45,10 +50,15 @@ def process_notebook(file: Path, include_output: bool = True) -> str: "https://github.com/ipython/ipython/wiki/IPEP-17:-Notebook-Format-4#remove-multiple-worksheets " "for more information.)", DeprecationWarning, + stacklevel=2, ) if len(worksheets) > 1: - warnings.warn("Multiple worksheets detected. Combining all worksheets into a single script.", UserWarning) + warnings.warn( + "Multiple worksheets detected. Combining all worksheets into a single script.", + UserWarning, + stacklevel=2, + ) cells = list(chain.from_iterable(ws["cells"] for ws in worksheets)) @@ -65,32 +75,33 @@ def process_notebook(file: Path, include_output: bool = True) -> str: return "\n\n".join(result) + "\n" -def _process_cell(cell: Dict[str, Any], include_output: bool) -> Optional[str]: - """ - Process a Jupyter notebook cell and return the cell content as a string. +def _process_cell(cell: dict[str, Any], *, include_output: bool) -> str | None: + """Process a Jupyter notebook cell and return the cell content as a string. Parameters ---------- - cell : Dict[str, Any] + cell : dict[str, Any] The cell dictionary from a Jupyter notebook. include_output : bool - Whether to include cell outputs in the generated script + Whether to include cell outputs in the generated script. Returns ------- - str, optional - The cell content as a string, or None if the cell is empty. + str | None + The cell content as a string, or ``None`` if the cell is empty. Raises ------ ValueError If an unexpected cell type is encountered. + """ cell_type = cell["cell_type"] # Validate cell type and handle unexpected types if cell_type not in ("markdown", "code", "raw"): - raise ValueError(f"Unknown cell type: {cell_type}") + msg = f"Unknown cell type: {cell_type}" + raise ValueError(msg) cell_str = "".join(cell["source"]) @@ -105,40 +116,34 @@ def _process_cell(cell: Dict[str, Any], include_output: bool) -> Optional[str]: # Add cell output as comments outputs = cell.get("outputs") if include_output and outputs: - # Include cell outputs as comments - output_lines = [] - + raw_lines: list[str] = [] for output in outputs: - output_lines += _extract_output(output) - - for output_line in output_lines: - if not output_line.endswith("\n"): - output_line += "\n" + raw_lines += _extract_output(output) - cell_str += "\n# Output:\n# " + "\n# ".join(output_lines) + cell_str += "\n# Output:\n# " + "\n# ".join(raw_lines) return cell_str -def _extract_output(output: Dict[str, Any]) -> List[str]: - """ - Extract the output from a Jupyter notebook cell. +def _extract_output(output: dict[str, Any]) -> list[str]: + """Extract the output from a Jupyter notebook cell. Parameters ---------- - output : Dict[str, Any] + output : dict[str, Any] The output dictionary from a Jupyter notebook cell. Returns ------- - List[str] + list[str] The output as a list of strings. Raises ------ ValueError If an unknown output type is encountered. + """ output_type = output["output_type"] @@ -151,4 +156,5 @@ def _extract_output(output: Dict[str, Any]) -> List[str]: if output_type == "error": return [f"Error: {output['ename']}: {output['evalue']}"] - raise ValueError(f"Unknown output type: {output_type}") + msg = f"Unknown output type: {output_type}" + raise ValueError(msg) diff --git a/src/gitingest/utils/os_utils.py b/src/gitingest/utils/os_utils.py index a2d49916..438c108c 100644 --- a/src/gitingest/utils/os_utils.py +++ b/src/gitingest/utils/os_utils.py @@ -1,12 +1,10 @@ """Utility functions for working with the operating system.""" -import os from pathlib import Path async def ensure_directory(path: Path) -> None: - """ - Ensure the directory exists, creating it if necessary. + """Ensure the directory exists, creating it if necessary. Parameters ---------- @@ -17,8 +15,10 @@ async def ensure_directory(path: Path) -> None: ------ OSError If the directory cannot be created + """ try: - os.makedirs(path, exist_ok=True) + path.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise OSError(f"Failed to create directory {path}: {exc}") from exc + msg = f"Failed to create directory {path}: {exc}" + raise OSError(msg) from exc diff --git a/src/gitingest/utils/path_utils.py b/src/gitingest/utils/path_utils.py index c6edd501..fe387da4 100644 --- a/src/gitingest/utils/path_utils.py +++ b/src/gitingest/utils/path_utils.py @@ -1,39 +1,34 @@ """Utility functions for working with file paths.""" -import os import platform from pathlib import Path def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool: - """ - Check if a symlink points to a location within the base directory. - - This function resolves the target of a symlink and ensures it is within the specified - base directory, returning `True` if it is safe, or `False` if the symlink points outside - the base directory. + """Return ``True`` if *symlink_path* resolves inside *base_path*. Parameters ---------- symlink_path : Path - The path of the symlink to check. + Symlink whose target should be validated. base_path : Path - The base directory to ensure the symlink points within. + Directory that the symlink target must remain within. Returns ------- bool - `True` if the symlink points within the base directory, `False` otherwise. + Whether the symlink is “safe” (i.e., does not escape *base_path*). + """ - try: - if platform.system() == "Windows": - if not os.path.islink(str(symlink_path)): - return False + # On Windows a non-symlink is immediately unsafe + if platform.system() == "Windows" and not symlink_path.is_symlink(): + return False + try: target_path = symlink_path.resolve() base_resolved = base_path.resolve() - - return base_resolved in target_path.parents or target_path == base_resolved except (OSError, ValueError): - # If there's any error resolving the paths, consider it unsafe + # Any resolution error → treat as unsafe return False + + return base_resolved in target_path.parents or target_path == base_resolved diff --git a/src/gitingest/utils/query_parser_utils.py b/src/gitingest/utils/query_parser_utils.py index 0181d371..7b57ba40 100644 --- a/src/gitingest/utils/query_parser_utils.py +++ b/src/gitingest/utils/query_parser_utils.py @@ -1,13 +1,14 @@ """Utility functions for parsing and validating query parameters.""" +from __future__ import annotations + import os import string -from typing import List, Set, Tuple -HEX_DIGITS: Set[str] = set(string.hexdigits) +HEX_DIGITS: set[str] = set(string.hexdigits) -KNOWN_GIT_HOSTS: List[str] = [ +KNOWN_GIT_HOSTS: list[str] = [ "github.com", "gitlab.com", "bitbucket.org", @@ -18,8 +19,7 @@ def _is_valid_git_commit_hash(commit: str) -> bool: - """ - Validate if the provided string is a valid Git commit hash. + """Validate if the provided string is a valid Git commit hash. This function checks if the commit hash is a 40-character string consisting only of hexadecimal digits, which is the standard format for Git commit hashes. @@ -32,14 +32,15 @@ def _is_valid_git_commit_hash(commit: str) -> bool: Returns ------- bool - True if the string is a valid 40-character Git commit hash, otherwise False. + `True` if the string is a valid 40-character Git commit hash, otherwise `False`. + """ - return len(commit) == 40 and all(c in HEX_DIGITS for c in commit) + sha_hex_length = 40 + return len(commit) == sha_hex_length and all(c in HEX_DIGITS for c in commit) def _is_valid_pattern(pattern: str) -> bool: - """ - Validate if the given pattern contains only valid characters. + """Validate if the given pattern contains only valid characters. This function checks if the pattern contains only alphanumeric characters or one of the following allowed characters: dash (`-`), underscore (`_`), dot (`.`), @@ -53,14 +54,14 @@ def _is_valid_pattern(pattern: str) -> bool: Returns ------- bool - True if the pattern is valid, otherwise False. + `True` if the pattern is valid, otherwise `False`. + """ return all(c.isalnum() or c in "-_./+*@" for c in pattern) def _validate_host(host: str) -> None: - """ - Validate a hostname. + """Validate a hostname. The host is accepted if it is either present in the hard-coded `KNOWN_GIT_HOSTS` list or if it satisfies the simple heuristics in `_looks_like_git_host`, which try to recognise common self-hosted Git services (e.g. GitLab @@ -75,15 +76,16 @@ def _validate_host(host: str) -> None: ------ ValueError If the host cannot be recognised as a probable Git hosting domain. + """ host = host.lower() if host not in KNOWN_GIT_HOSTS and not _looks_like_git_host(host): - raise ValueError(f"Unknown domain '{host}' in URL") + msg = f"Unknown domain '{host}' in URL" + raise ValueError(msg) def _looks_like_git_host(host: str) -> bool: - """ - Check if the given host looks like a Git host. + """Check if the given host looks like a Git host. The current heuristic returns `True` when the host starts with `git.` (e.g. `git.example.com`), starts with `gitlab.` (e.g. `gitlab.company.com`), or starts with `github.` (e.g. `github.company.com` for GitHub Enterprise). @@ -96,15 +98,15 @@ def _looks_like_git_host(host: str) -> bool: Returns ------- bool - True if the host looks like a Git host, otherwise False. + `True` if the host looks like a Git host, otherwise `False`. + """ host = host.lower() return host.startswith(("git.", "gitlab.", "github.")) def _validate_url_scheme(scheme: str) -> None: - """ - Validate the given scheme against the known schemes. + """Validate the given scheme against the known schemes. Parameters ---------- @@ -114,16 +116,17 @@ def _validate_url_scheme(scheme: str) -> None: Raises ------ ValueError - If the scheme is not 'http' or 'https'. + If the scheme is not `http` or `https`. + """ scheme = scheme.lower() if scheme not in ("https", "http"): - raise ValueError(f"Invalid URL scheme '{scheme}' in URL") + msg = f"Invalid URL scheme '{scheme}' in URL" + raise ValueError(msg) -def _get_user_and_repo_from_path(path: str) -> Tuple[str, str]: - """ - Extract the user and repository names from a given path. +def _get_user_and_repo_from_path(path: str) -> tuple[str, str]: + """Extract the user and repository names from a given path. Parameters ---------- @@ -132,23 +135,25 @@ def _get_user_and_repo_from_path(path: str) -> Tuple[str, str]: Returns ------- - Tuple[str, str] + tuple[str, str] A tuple containing the user and repository names. Raises ------ ValueError If the path does not contain at least two parts. + """ + min_path_parts = 2 path_parts = path.lower().strip("/").split("/") - if len(path_parts) < 2: - raise ValueError(f"Invalid repository URL '{path}'") + if len(path_parts) < min_path_parts: + msg = f"Invalid repository URL '{path}'" + raise ValueError(msg) return path_parts[0], path_parts[1] def _normalize_pattern(pattern: str) -> str: - """ - Normalize the given pattern by removing leading separators and appending a wildcard. + """Normalize the given pattern by removing leading separators and appending a wildcard. This function processes the pattern string by stripping leading directory separators and appending a wildcard (`*`) if the pattern ends with a separator. @@ -162,6 +167,7 @@ def _normalize_pattern(pattern: str) -> str: ------- str The normalized pattern. + """ pattern = pattern.lstrip(os.sep) if pattern.endswith(os.sep): diff --git a/src/gitingest/utils/timeout_wrapper.py b/src/gitingest/utils/timeout_wrapper.py index 7d1d5f91..f811e29c 100644 --- a/src/gitingest/utils/timeout_wrapper.py +++ b/src/gitingest/utils/timeout_wrapper.py @@ -2,16 +2,17 @@ import asyncio import functools -from typing import Any, Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, TypeVar +from gitingest.utils.compat_typing import ParamSpec from gitingest.utils.exceptions import AsyncTimeoutError T = TypeVar("T") +P = ParamSpec("P") -def async_timeout(seconds) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]: - """ - Async Timeout decorator. +def async_timeout(seconds: int) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: + """Async Timeout decorator. This decorator wraps an asynchronous function and ensures it does not run for longer than the specified number of seconds. If the function execution exceeds @@ -24,19 +25,21 @@ def async_timeout(seconds) -> Callable[[Callable[..., Awaitable[T]]], Callable[. Returns ------- - Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]] + Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]] A decorator that, when applied to an async function, ensures the function completes within the specified time limit. If the function takes too long, an `AsyncTimeoutError` is raised. + """ - def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: @functools.wraps(func) - async def wrapper(*args: Any, **kwargs: Any) -> T: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) except asyncio.TimeoutError as exc: - raise AsyncTimeoutError(f"Operation timed out after {seconds} seconds") from exc + msg = f"Operation timed out after {seconds} seconds" + raise AsyncTimeoutError(msg) from exc return wrapper diff --git a/src/server/__init__.py b/src/server/__init__.py index e69de29b..49d346db 100644 --- a/src/server/__init__.py +++ b/src/server/__init__.py @@ -0,0 +1 @@ +"""Server module.""" diff --git a/src/server/form_types.py b/src/server/form_types.py new file mode 100644 index 00000000..127d2adc --- /dev/null +++ b/src/server/form_types.py @@ -0,0 +1,16 @@ +"""Reusable form type aliases for FastAPI form parameters.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from fastapi import Form + +from gitingest.utils.compat_typing import Annotated + +if TYPE_CHECKING: + from gitingest.utils.compat_typing import TypeAlias + +StrForm: TypeAlias = Annotated[str, Form(...)] +IntForm: TypeAlias = Annotated[int, Form(...)] +OptStrForm: TypeAlias = Annotated[Optional[str], Form()] diff --git a/src/server/main.py b/src/server/main.py index f314a3ad..b3563550 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -1,8 +1,9 @@ """Main module for the FastAPI application.""" +from __future__ import annotations + import os from pathlib import Path -from typing import Dict from dotenv import load_dotenv from fastapi import FastAPI, Request @@ -45,22 +46,21 @@ @app.get("/health") -async def health_check() -> Dict[str, str]: - """ - Health check endpoint to verify that the server is running. +async def health_check() -> dict[str, str]: + """Health check endpoint to verify that the server is running. Returns ------- - Dict[str, str] + dict[str, str] A JSON object with a "status" key indicating the server's health status. + """ return {"status": "healthy"} @app.head("/") async def head_root() -> HTMLResponse: - """ - Respond to HTTP HEAD requests for the root URL. + """Respond to HTTP HEAD requests for the root URL. Mirrors the headers and status code of the index page. @@ -68,6 +68,7 @@ async def head_root() -> HTMLResponse: ------- HTMLResponse An empty HTML response with appropriate headers. + """ return HTMLResponse(content=None, headers={"content-type": "text/html; charset=utf-8"}) @@ -75,8 +76,7 @@ async def head_root() -> HTMLResponse: @app.get("/api/", response_class=HTMLResponse) @app.get("/api", response_class=HTMLResponse) async def api_docs(request: Request) -> HTMLResponse: - """ - Render the API documentation page. + """Render the API documentation page. Parameters ---------- @@ -87,32 +87,33 @@ async def api_docs(request: Request) -> HTMLResponse: ------- HTMLResponse A rendered HTML page displaying API documentation. + """ return templates.TemplateResponse("api.jinja", {"request": request}) @app.get("/robots.txt") async def robots() -> FileResponse: - """ - Serve the `robots.txt` file to guide search engine crawlers. + """Serve the ``robots.txt`` file to guide search engine crawlers. Returns ------- FileResponse - The `robots.txt` file located in the static directory. + The ``robots.txt`` file located in the static directory. + """ return FileResponse("static/robots.txt") @app.get("/llm.txt") async def llm_txt() -> FileResponse: - """ - Serve the `llm.txt` file to provide information about the site to LLMs. + """Serve the ``llm.txt`` file to provide information about the site to LLMs. Returns ------- FileResponse - The `llm.txt` file located in the static directory. + The ``llm.txt`` file located in the static directory. + """ return FileResponse("static/llm.txt") diff --git a/src/server/models.py b/src/server/models.py new file mode 100644 index 00000000..614c79f7 --- /dev/null +++ b/src/server/models.py @@ -0,0 +1,71 @@ +"""Pydantic models for the query form.""" + +from __future__ import annotations + +from pydantic import BaseModel + +# needed for type checking (pydantic) +from server.form_types import IntForm, OptStrForm, StrForm # noqa: TC001 (typing-only-first-party-import) + + +class QueryForm(BaseModel): + """Form data for the query. + + Attributes + ---------- + input_text : str + Text or URL supplied in the form. + max_file_size : int + The maximum allowed file size for the input, specified by the user. + pattern_type : str + The type of pattern used for the query (``include`` or ``exclude``). + pattern : str + Glob/regex pattern string. + token : str | None + GitHub personal-access token (``None`` if omitted). + + """ + + input_text: str + max_file_size: int + pattern_type: str + pattern: str + token: str | None = None + + @classmethod + def as_form( + cls, + input_text: StrForm, + max_file_size: IntForm, + pattern_type: StrForm, + pattern: StrForm, + token: OptStrForm, + ) -> QueryForm: + """Create a QueryForm from FastAPI form parameters. + + Parameters + ---------- + input_text : StrForm + The input text provided by the user. + max_file_size : IntForm + The maximum allowed file size for the input. + pattern_type : StrForm + The type of pattern used for the query (``include`` or ``exclude``). + pattern : StrForm + Glob/regex pattern string. + token : OptStrForm + GitHub personal-access token (``None`` if omitted). + + Returns + ------- + QueryForm + The QueryForm instance. + + """ + return cls( + input_text=input_text, + max_file_size=max_file_size, + pattern_type=pattern_type, + pattern=pattern, + token=token, + ) diff --git a/src/server/query_processor.py b/src/server/query_processor.py index 35a0fc53..806becbd 100644 --- a/src/server/query_processor.py +++ b/src/server/query_processor.py @@ -1,29 +1,38 @@ """Process a query by parsing input, cloning a repository, and generating a summary.""" -from functools import partial -from typing import Optional +from __future__ import annotations -from fastapi import Request -from starlette.templating import _TemplateResponse +from functools import partial +from pathlib import Path +from typing import TYPE_CHECKING, cast -from gitingest.cloning import clone_repo +from gitingest.clone import clone_repo from gitingest.ingestion import ingest_query -from gitingest.query_parsing import IngestionQuery, parse_query -from server.server_config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE, templates +from gitingest.query_parser import IngestionQuery, parse_query +from server.server_config import ( + DEFAULT_FILE_SIZE_KB, + EXAMPLE_REPOS, + MAX_DISPLAY_SIZE, + templates, +) from server.server_utils import Colors, log_slider_to_size +if TYPE_CHECKING: + from fastapi import Request + from starlette.templating import _TemplateResponse + async def process_query( request: Request, + *, input_text: str, slider_position: int, pattern_type: str = "exclude", pattern: str = "", is_index: bool = False, - token: Optional[str] = None, + token: str | None = None, ) -> _TemplateResponse: - """ - Process a query by parsing input, cloning a repository, and generating a summary. + """Process a query by parsing input, cloning a repository, and generating a summary. Handle user input, process Git repository data, and prepare a response for rendering a template with the processed results or an error message. @@ -37,14 +46,13 @@ async def process_query( slider_position : int Position of the slider, representing the maximum file size in the query. pattern_type : str - Type of pattern to use, either "include" or "exclude" (default is "exclude"). + Type of pattern to use (either "include" or "exclude") (default: ``"exclude"``). pattern : str Pattern to include or exclude in the query, depending on the pattern type. is_index : bool - Flag indicating whether the request is for the index page (default is False). - token : str, optional - GitHub personal-access token (PAT). Needed when *input_text* refers to a - **private** repository. + Flag indicating whether the request is for the index page (default: ``False``). + token : str | None + GitHub personal-access token (PAT). Needed when the repository is private. Returns ------- @@ -55,6 +63,7 @@ async def process_query( ------ ValueError If an invalid pattern type is provided. + """ if pattern_type == "include": include_patterns = pattern @@ -63,7 +72,8 @@ async def process_query( exclude_patterns = pattern include_patterns = None else: - raise ValueError(f"Invalid pattern type: {pattern_type}") + msg = f"Invalid pattern type: {pattern_type}" + raise ValueError(msg) template = "index.jinja" if is_index else "git.jinja" template_response = partial(templates.TemplateResponse, name=template) @@ -79,8 +89,10 @@ async def process_query( "token": token, } + query: IngestionQuery | None = None + try: - query: IngestionQuery = await parse_query( + query = await parse_query( source=input_text, max_file_size=max_file_size, from_web=True, @@ -88,21 +100,24 @@ async def process_query( ignore_patterns=exclude_patterns, token=token, ) - if not query.url: - raise ValueError("The 'url' parameter is required.") + query.ensure_url() # Sets the "/" for the page title context["short_repo_url"] = f"{query.user_name}/{query.repo_name}" clone_config = query.extract_clone_config() await clone_repo(clone_config, token=token) + summary, tree, content = ingest_query(query) - with open(f"{clone_config.local_path}.txt", "w", encoding="utf-8") as f: + + local_txt_file = Path(clone_config.local_path).with_suffix(".txt") + + with local_txt_file.open("w", encoding="utf-8") as f: f.write(tree + "\n" + content) + except Exception as exc: - # hack to print error message when query is not defined - if "query" in locals() and query is not None and isinstance(query, dict): - _print_error(query["url"], exc, max_file_size, pattern_type, pattern) + if query and query.url: + _print_error(query.url, exc, max_file_size, pattern_type, pattern) else: print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") print(f"{Colors.RED}{exc}{Colors.END}") @@ -120,6 +135,9 @@ async def process_query( "download full ingest to see more)\n" + content[:MAX_DISPLAY_SIZE] ) + query.ensure_url() + query.url = cast("str", query.url) + _print_success( url=query.url, max_file_size=max_file_size, @@ -135,16 +153,14 @@ async def process_query( "tree": tree, "content": content, "ingest_id": query.id, - } + }, ) return template_response(context=context) def _print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) -> None: - """ - Print a formatted summary of the query details, including the URL, file size, - and pattern information, for easier debugging or logging. + """Print a formatted summary of the query details for debugging. Parameters ---------- @@ -156,26 +172,28 @@ def _print_query(url: str, max_file_size: int, pattern_type: str, pattern: str) Specifies the type of pattern to use, either "include" or "exclude". pattern : str The actual pattern string to include or exclude in the query. + """ print(f"{Colors.WHITE}{url:<20}{Colors.END}", end="") - if int(max_file_size / 1024) != 50: - print(f" | {Colors.YELLOW}Size: {int(max_file_size/1024)}kb{Colors.END}", end="") + if int(max_file_size / 1024) != DEFAULT_FILE_SIZE_KB: + print( + f" | {Colors.YELLOW}Size: {int(max_file_size / 1024)}kb{Colors.END}", + end="", + ) if pattern_type == "include" and pattern != "": print(f" | {Colors.YELLOW}Include {pattern}{Colors.END}", end="") elif pattern_type == "exclude" and pattern != "": print(f" | {Colors.YELLOW}Exclude {pattern}{Colors.END}", end="") -def _print_error(url: str, e: Exception, max_file_size: int, pattern_type: str, pattern: str) -> None: - """ - Print a formatted error message including the URL, file size, pattern details, and the exception encountered, - for debugging or logging purposes. +def _print_error(url: str, exc: Exception, max_file_size: int, pattern_type: str, pattern: str) -> None: + """Print a formatted error message for debugging. Parameters ---------- url : str The URL associated with the query that caused the error. - e : Exception + exc : Exception The exception raised during the query or process. max_file_size : int The maximum file size allowed for the query, in bytes. @@ -183,16 +201,15 @@ def _print_error(url: str, e: Exception, max_file_size: int, pattern_type: str, Specifies the type of pattern to use, either "include" or "exclude". pattern : str The actual pattern string to include or exclude in the query. + """ print(f"{Colors.BROWN}WARN{Colors.END}: {Colors.RED}<- {Colors.END}", end="") _print_query(url, max_file_size, pattern_type, pattern) - print(f" | {Colors.RED}{e}{Colors.END}") + print(f" | {Colors.RED}{exc}{Colors.END}") def _print_success(url: str, max_file_size: int, pattern_type: str, pattern: str, summary: str) -> None: - """ - Print a formatted success message, including the URL, file size, pattern details, and a summary with estimated - tokens, for debugging or logging purposes. + """Print a formatted success message for debugging. Parameters ---------- @@ -206,6 +223,7 @@ def _print_success(url: str, max_file_size: int, pattern_type: str, pattern: str The actual pattern string to include or exclude in the query. summary : str A summary of the query result, including details like estimated tokens. + """ estimated_tokens = summary[summary.index("Estimated tokens:") + len("Estimated ") :] print(f"{Colors.GREEN}INFO{Colors.END}: {Colors.GREEN}<- {Colors.END}", end="") diff --git a/src/server/routers/__init__.py b/src/server/routers/__init__.py index a1159830..bfddefaa 100644 --- a/src/server/routers/__init__.py +++ b/src/server/routers/__init__.py @@ -1,4 +1,4 @@ -"""This module contains the routers for the FastAPI application.""" +"""Module containing the routers for the FastAPI application.""" from server.routers.download import router as download from server.routers.dynamic import router as dynamic diff --git a/src/server/routers/download.py b/src/server/routers/download.py index e2b405ea..66f39229 100644 --- a/src/server/routers/download.py +++ b/src/server/routers/download.py @@ -1,61 +1,45 @@ -"""This module contains the FastAPI router for downloading a digest file.""" +"""Module containing the FastAPI router for downloading a digest file.""" from fastapi import APIRouter, HTTPException -from fastapi.responses import Response +from fastapi.responses import FileResponse from gitingest.config import TMP_BASE_PATH router = APIRouter() -@router.get("/download/{digest_id}") -async def download_ingest(digest_id: str) -> Response: - """ - Download a .txt file associated with a given digest ID. - - This function searches for a `.txt` file in a directory corresponding to the provided - digest ID. If a file is found, it is read and returned as a downloadable attachment. - If no `.txt` file is found, an error is raised. +@router.get("/download/{digest_id}", response_class=FileResponse) +async def download_ingest(digest_id: str) -> FileResponse: + """Return the first ``*.txt`` file produced for ``digest_id`` as a download. Parameters ---------- digest_id : str - The unique identifier for the digest. It is used to find the corresponding directory - and locate the .txt file within that directory. + Identifier that the ingest step emitted (also the directory name that stores the artefacts). Returns ------- - Response - A FastAPI Response object containing the content of the found `.txt` file. The file is - sent with the appropriate media type (`text/plain`) and the correct `Content-Disposition` - header to prompt a file download. + FileResponse + Streamed response with media type ``text/plain`` that prompts the browser to download the file. Raises ------ HTTPException - If the digest directory is not found or if no `.txt` file exists in the directory. + **404** - digest directory is missing or contains no ``*.txt`` file. + **403** - the process lacks permission to read the directory or file. + """ directory = TMP_BASE_PATH / digest_id - try: - if not directory.exists(): - raise FileNotFoundError("Directory not found") - - txt_files = [f for f in directory.iterdir() if f.suffix == ".txt"] - if not txt_files: - raise FileNotFoundError("No .txt file found") + if not directory.is_dir(): + raise HTTPException(status_code=404, detail=f"Digest {digest_id!r} not found") - except FileNotFoundError as exc: - raise HTTPException(status_code=404, detail="Digest not found") from exc - - # Find the first .txt file in the directory - first_file = txt_files[0] - - with first_file.open(encoding="utf-8") as f: - content = f.read() + try: + first_txt_file = next(directory.glob("*.txt")) + except StopIteration as exc: + raise HTTPException(status_code=404, detail=f"No .txt file found for digest {digest_id!r}") from exc - return Response( - content=content, - media_type="text/plain", - headers={"Content-Disposition": f"attachment; filename={first_file.name}"}, - ) + try: + return FileResponse(path=first_txt_file, media_type="text/plain", filename=first_txt_file.name) + except PermissionError as exc: + raise HTTPException(status_code=403, detail=f"Permission denied for {first_txt_file}") from exc diff --git a/src/server/routers/dynamic.py b/src/server/routers/dynamic.py index 57a54a56..5ed36fe4 100644 --- a/src/server/routers/dynamic.py +++ b/src/server/routers/dynamic.py @@ -1,8 +1,10 @@ -"""This module defines the dynamic router for handling dynamic path requests.""" +"""The dynamic router module defines handlers for dynamic path requests.""" -from fastapi import APIRouter, Form, Request +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse +from gitingest.utils.compat_typing import Annotated +from server.models import QueryForm from server.query_processor import process_query from server.server_config import templates from server.server_utils import limiter @@ -12,11 +14,10 @@ @router.get("/{full_path:path}") async def catch_all(request: Request, full_path: str) -> HTMLResponse: - """ - Render a page with a Git URL based on the provided path. + """Render a page with a Git URL based on the provided path. This endpoint catches all GET requests with a dynamic path, constructs a Git URL - using the `full_path` parameter, and renders the `git.jinja` template with that URL. + using the ``full_path`` parameter, and renders the ``git.jinja`` template with that URL. Parameters ---------- @@ -30,6 +31,7 @@ async def catch_all(request: Request, full_path: str) -> HTMLResponse: HTMLResponse An HTML response containing the rendered template, with the Git URL and other default parameters such as loading state and file size. + """ return templates.TemplateResponse( "git.jinja", @@ -44,48 +46,32 @@ async def catch_all(request: Request, full_path: str) -> HTMLResponse: @router.post("/{full_path:path}", response_class=HTMLResponse) @limiter.limit("10/minute") -async def process_catch_all( - request: Request, - input_text: str = Form(...), - max_file_size: int = Form(...), - pattern_type: str = Form(...), - pattern: str = Form(...), - token: str = Form(...), -) -> HTMLResponse: - """ - Process the form submission with user input for query parameters. +async def process_catch_all(request: Request, form: Annotated[QueryForm, Depends(QueryForm.as_form)]) -> HTMLResponse: + """Process the form submission with user input for query parameters. This endpoint handles POST requests, processes the input parameters (e.g., text, file size, pattern), - and calls the `process_query` function to handle the query logic, returning the result as an HTML response. + and calls the ``process_query`` function to handle the query logic, returning the result as an HTML response. Parameters ---------- request : Request - The incoming request object, which provides context for rendering the response. - input_text : str - The input text provided by the user for processing, by default taken from the form. - max_file_size : int - The maximum allowed file size for the input, specified by the user. - pattern_type : str - The type of pattern used for the query, specified by the user. - pattern : str - The pattern string used in the query, specified by the user. - token : str - GitHub personal-access token (PAT). Needed when *input_text* refers to a - **private** repository. + FastAPI request context. + form : Annotated[QueryForm, Depends(QueryForm.as_form)] + The form data submitted by the user. + Returns ------- HTMLResponse - An HTML response generated after processing the form input and query logic, - which will be rendered and returned to the user. + Rendered HTML with the query results. + """ - resolved_token = None if token == "" else token + resolved_token = form.token if form.token else None return await process_query( request, - input_text, - max_file_size, - pattern_type, - pattern, + input_text=form.input_text, + slider_position=form.max_file_size, + pattern_type=form.pattern_type, + pattern=form.pattern, is_index=False, token=resolved_token, ) diff --git a/src/server/routers/index.py b/src/server/routers/index.py index 8c11aaa8..9385d6ff 100644 --- a/src/server/routers/index.py +++ b/src/server/routers/index.py @@ -1,8 +1,10 @@ -"""This module defines the FastAPI router for the home page of the application.""" +"""Module defining the FastAPI router for the home page of the application.""" -from fastapi import APIRouter, Form, Request +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse +from gitingest.utils.compat_typing import Annotated +from server.models import QueryForm from server.query_processor import process_query from server.server_config import EXAMPLE_REPOS, templates from server.server_utils import limiter @@ -12,10 +14,9 @@ @router.get("/", response_class=HTMLResponse) async def home(request: Request) -> HTMLResponse: - """ - Render the home page with example repositories and default parameters. + """Render the home page with example repositories and default parameters. - This endpoint serves the home page of the application, rendering the `index.jinja` template + This endpoint serves the home page of the application, rendering the ``index.jinja`` template and providing it with a list of example repositories and default file size values. Parameters @@ -28,6 +29,7 @@ async def home(request: Request) -> HTMLResponse: HTMLResponse An HTML response containing the rendered home page template, with example repositories and other default parameters such as file size. + """ return templates.TemplateResponse( "index.jinja", @@ -41,49 +43,34 @@ async def home(request: Request) -> HTMLResponse: @router.post("/", response_class=HTMLResponse) @limiter.limit("10/minute") -async def index_post( - request: Request, - input_text: str = Form(...), - max_file_size: int = Form(...), - pattern_type: str = Form(...), - pattern: str = Form(...), - token: str = Form(...), -) -> HTMLResponse: - """ - Process the form submission with user input for query parameters. +async def index_post(request: Request, form: Annotated[QueryForm, Depends(QueryForm.as_form)]) -> HTMLResponse: + """Process the form submission with user input for query parameters. This endpoint handles POST requests from the home page form. It processes the user-submitted - input (e.g., text, file size, pattern type) and invokes the `process_query` function to handle + input (e.g., text, file size, pattern type) and invokes the ``process_query`` function to handle the query logic, returning the result as an HTML response. Parameters ---------- request : Request The incoming request object, which provides context for rendering the response. - input_text : str - The input text provided by the user for processing, by default taken from the form. - max_file_size : int - The maximum allowed file size for the input, specified by the user. - pattern_type : str - The type of pattern used for the query, specified by the user. - pattern : str - The pattern string used in the query, specified by the user. - token : str - GitHub personal-access token (PAT). Needed when *input_text* refers to a - **private** repository. + form : Annotated[QueryForm, Depends(QueryForm.as_form)] + The form data submitted by the user. + Returns ------- HTMLResponse An HTML response containing the results of processing the form input and query logic, which will be rendered and returned to the user. + """ - resolved_token = None if token == "" else token + resolved_token = form.token if form.token else None return await process_query( request, - input_text, - max_file_size, - pattern_type, - pattern, + input_text=form.input_text, + slider_position=form.max_file_size, + pattern_type=form.pattern_type, + pattern=form.pattern, is_index=True, token=resolved_token, ) diff --git a/src/server/server_config.py b/src/server/server_config.py index 0f910623..29b02a1a 100644 --- a/src/server/server_config.py +++ b/src/server/server_config.py @@ -1,14 +1,16 @@ """Configuration for the server.""" -from typing import Dict, List +from __future__ import annotations from fastapi.templating import Jinja2Templates MAX_DISPLAY_SIZE: int = 300_000 -DELETE_REPO_AFTER: int = 60 * 60 # In seconds +DELETE_REPO_AFTER: int = 60 * 60 # In seconds (1 hour) +# Default maximum file size to include in the digest +DEFAULT_FILE_SIZE_KB: int = 50 -EXAMPLE_REPOS: List[Dict[str, str]] = [ +EXAMPLE_REPOS: list[dict[str, str]] = [ {"name": "Gitingest", "url": "https://github.com/cyclotruc/gitingest"}, {"name": "FastAPI", "url": "https://github.com/tiangolo/fastapi"}, {"name": "Flask", "url": "https://github.com/pallets/flask"}, diff --git a/src/server/server_utils.py b/src/server/server_utils.py index 9972c9ba..3f1010d1 100644 --- a/src/server/server_utils.py +++ b/src/server/server_utils.py @@ -4,8 +4,9 @@ import math import shutil import time -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from pathlib import Path +from typing import AsyncGenerator from fastapi import FastAPI, Request from fastapi.responses import Response @@ -14,6 +15,7 @@ from slowapi.util import get_remote_address from gitingest.config import TMP_BASE_PATH +from gitingest.utils.async_compat import to_thread from server.server_config import DELETE_REPO_AFTER # Initialize a rate limiter @@ -21,8 +23,7 @@ async def rate_limit_exception_handler(request: Request, exc: Exception) -> Response: - """ - Custom exception handler for rate-limiting errors. + """Handle rate-limiting errors with a custom exception handler. Parameters ---------- @@ -40,6 +41,7 @@ async def rate_limit_exception_handler(request: Request, exc: Exception) -> Resp ------ exc If the exception is not a RateLimitExceeded error, it is re-raised. + """ if isinstance(exc, RateLimitExceeded): # Delegate to the default rate limit handler @@ -49,102 +51,120 @@ async def rate_limit_exception_handler(request: Request, exc: Exception) -> Resp @asynccontextmanager -async def lifespan(_: FastAPI): - """ - Lifecycle manager for handling startup and shutdown events for the FastAPI application. - - Parameters - ---------- - _ : FastAPI - The FastAPI application instance (unused). +async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: + """Manage startup & graceful-shutdown tasks for the FastAPI app. - Yields + Returns ------- - None + AsyncGenerator[None, None] Yields control back to the FastAPI application while the background task runs. + """ task = asyncio.create_task(_remove_old_repositories()) - yield - # Cancel the background task on shutdown - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + yield # app runs while the background task is alive + task.cancel() # ask the worker to stop + with suppress(asyncio.CancelledError): + await task # swallow the cancellation signal -async def _remove_old_repositories(): - """ - Periodically remove old repository folders. - Background task that runs periodically to clean up old repository directories. +async def _remove_old_repositories( + base_path: Path = TMP_BASE_PATH, + scan_interval: int = 60, + delete_after: int = DELETE_REPO_AFTER, +) -> None: + """Periodically delete old repositories/directories. - This task: - - Scans the TMP_BASE_PATH directory every 60 seconds - - Removes directories older than DELETE_REPO_AFTER seconds - - Before deletion, logs repository URLs to history.txt if a matching .txt file exists - - Handles errors gracefully if deletion fails + Every ``scan_interval`` seconds the coroutine scans ``base_path`` and deletes directories older than + ``delete_after`` seconds. The repository URL is extracted from the first ``.txt`` file in each directory + and appended to ``history.txt``, assuming the filename format: "owner-repository.txt". Filesystem errors are + logged and the loop continues. + + Parameters + ---------- + base_path : Path + The path to the base directory where repositories are stored (default: ``TMP_BASE_PATH``) + scan_interval : int + The number of seconds between scans (default: 60). + delete_after : int + The number of seconds after which a repository is considered old and will be deleted + (default: ``DELETE_REPO_AFTER``). - The repository URL is extracted from the first .txt file in each directory, - assuming the filename format: "owner-repository.txt" """ while True: - try: - if not TMP_BASE_PATH.exists(): - await asyncio.sleep(60) - continue + if not base_path.exists(): + await asyncio.sleep(scan_interval) + continue - current_time = time.time() - - for folder in TMP_BASE_PATH.iterdir(): - # Skip if folder is not old enough - if current_time - folder.stat().st_ctime <= DELETE_REPO_AFTER: + now = time.time() + try: + for folder in base_path.iterdir(): + if now - folder.stat().st_ctime <= delete_after: # Not old enough continue await _process_folder(folder) - except Exception as exc: + except (OSError, PermissionError) as exc: print(f"Error in _remove_old_repositories: {exc}") - await asyncio.sleep(60) + await asyncio.sleep(scan_interval) async def _process_folder(folder: Path) -> None: - """ - Process a single folder for deletion and logging. + """Append the repo URL (if discoverable) to ``history.txt`` and delete *folder*. Parameters ---------- folder : Path The path to the folder to be processed. + """ - # Try to log repository URL before deletion + history_file = Path("history.txt") + try: - txt_files = [f for f in folder.iterdir() if f.suffix == ".txt"] + first_txt_file: Path = next(folder.glob("*.txt")) + except StopIteration: # No .txt file found + return + # Try to log repository URL before deletion + try: # Extract owner and repository name from the filename - filename = txt_files[0].stem - if txt_files and "-" in filename: + filename = first_txt_file.stem # e.g. "owner-repo" + if "-" in filename: owner, repo = filename.split("-", 1) repo_url = f"{owner}/{repo}" + await to_thread(_append_line, history_file, repo_url) - with open("history.txt", mode="a", encoding="utf-8") as history: - history.write(f"{repo_url}\n") - - except Exception as exc: + except (OSError, PermissionError) as exc: print(f"Error logging repository URL for {folder}: {exc}") # Delete the folder try: - shutil.rmtree(folder) - except Exception as exc: - print(f"Error deleting {folder}: {exc}") + await to_thread(shutil.rmtree, folder) + except PermissionError as exc: + print(f"No permission to delete {folder}: {exc}") + except OSError as exc: + print(f"Could not delete {folder}: {exc}") -def log_slider_to_size(position: int) -> int: +def _append_line(path: Path, line: str) -> None: + """Append a line to a file. + + Parameters + ---------- + path : Path + The path to the file to append the line to. + line : str + The line to append to the file. + """ - Convert a slider position to a file size in bytes using a logarithmic scale. + with path.open("a", encoding="utf-8") as fp: + fp.write(f"{line}\n") + + +def log_slider_to_size(position: int) -> int: + """Convert a slider position to a file size in bytes using a logarithmic scale. Parameters ---------- @@ -155,6 +175,7 @@ def log_slider_to_size(position: int) -> int: ------- int File size in bytes corresponding to the slider position. + """ maxp = 500 minv = math.log(1) @@ -164,7 +185,7 @@ def log_slider_to_size(position: int) -> int: ## Color printing utility class Colors: - """ANSI color codes""" + """ANSI color codes.""" BLACK = "\033[0;30m" RED = "\033[0;31m" diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..b6e6b827 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the gitingest package.""" diff --git a/tests/conftest.py b/tests/conftest.py index 50a5a90d..3710ab4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,22 @@ -""" -Fixtures for tests. +"""Fixtures for tests. This file provides shared fixtures for creating sample queries, a temporary directory structure, and a helper function -to write `.ipynb` notebooks for testing notebook utilities. +to write ``.ipynb`` notebooks for testing notebook utilities. """ +from __future__ import annotations + import json from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import TYPE_CHECKING, Any, Callable, Dict from unittest.mock import AsyncMock import pytest -from pytest_mock import MockerFixture -from gitingest.query_parsing import IngestionQuery +from gitingest.query_parser import IngestionQuery + +if TYPE_CHECKING: + from pytest_mock import MockerFixture WriteNotebookFunc = Callable[[str, Dict[str, Any]], Path] @@ -23,15 +26,15 @@ @pytest.fixture def sample_query() -> IngestionQuery: - """ - Provide a default `IngestionQuery` object for use in tests. + """Provide a default ``IngestionQuery`` object for use in tests. - This fixture returns a `IngestionQuery` pre-populated with typical fields and some default ignore patterns. + This fixture returns a ``IngestionQuery`` pre-populated with typical fields and some default ignore patterns. Returns ------- IngestionQuery - The sample `IngestionQuery` object. + The sample ``IngestionQuery`` object. + """ return IngestionQuery( user_name="test_user", @@ -50,8 +53,7 @@ def sample_query() -> IngestionQuery: @pytest.fixture def temp_directory(tmp_path: Path) -> Path: - """ - Create a temporary directory structure for testing repository scanning. + """Create a temporary directory structure for testing repository scanning. The structure includes: test_repo/ @@ -71,12 +73,13 @@ def temp_directory(tmp_path: Path) -> Path: Parameters ---------- tmp_path : Path - The temporary directory path provided by the `tmp_path` fixture. + The temporary directory path provided by the ``tmp_path`` fixture. Returns ------- Path - The path to the created `test_repo` directory. + The path to the created ``test_repo`` directory. + """ test_dir = tmp_path / "test_repo" test_dir.mkdir() @@ -112,22 +115,22 @@ def temp_directory(tmp_path: Path) -> Path: @pytest.fixture def write_notebook(tmp_path: Path) -> WriteNotebookFunc: - """ - Provide a helper function to write a `.ipynb` notebook file with the given content. + """Provide a helper function to write a ``.ipynb`` notebook file with the given content. Parameters ---------- tmp_path : Path - The temporary directory path provided by the `tmp_path` fixture. + The temporary directory path provided by the ``tmp_path`` fixture. Returns ------- WriteNotebookFunc - A callable that accepts a filename and a dictionary (representing JSON notebook data), writes it to a `.ipynb` - file, and returns the path to the file. + A callable that accepts a filename and a dictionary (representing JSON notebook data), writes it to a + ``.ipynb`` file, and returns the path to the file. + """ - def _write_notebook(name: str, content: Dict[str, Any]) -> Path: + def _write_notebook(name: str, content: dict[str, Any]) -> Path: notebook_path = tmp_path / name with notebook_path.open(mode="w", encoding="utf-8") as f: json.dump(content, f) @@ -137,10 +140,10 @@ def _write_notebook(name: str, content: Dict[str, Any]) -> Path: @pytest.fixture -def stub_branches(mocker: MockerFixture) -> Callable[[List[str]], None]: +def stub_branches(mocker: MockerFixture) -> Callable[[list[str]], None]: """Return a function that stubs git branch discovery to *branches*.""" - def _factory(branches: List[str]) -> None: + def _factory(branches: list[str]) -> None: mocker.patch( "gitingest.utils.git_utils.run_command", new_callable=AsyncMock, @@ -157,24 +160,18 @@ def _factory(branches: List[str]) -> None: @pytest.fixture def repo_exists_true(mocker: MockerFixture) -> AsyncMock: - """Patch `gitingest.cloning.check_repo_exists` to always return ``True``. - - Many cloning-related tests assume that the remote repository exists. This fixture centralises - that behaviour so individual tests no longer need to repeat the same ``mocker.patch`` call. - The mock object is returned so that tests can make assertions on how it was used or override - its behaviour when needed. - """ - return mocker.patch("gitingest.cloning.check_repo_exists", return_value=True) + """Patch ``gitingest.clone.check_repo_exists`` to always return *True*.""" + return mocker.patch("gitingest.clone.check_repo_exists", return_value=True) @pytest.fixture def run_command_mock(mocker: MockerFixture) -> AsyncMock: - """Patch `gitingest.cloning.run_command` with an ``AsyncMock``. + """Patch ``gitingest.clone.run_command`` with an ``AsyncMock``. The mocked function returns a dummy process whose ``communicate`` method yields generic *stdout* / *stderr* bytes. Tests can still access / tweak the mock via the fixture argument. """ - mock_exec = mocker.patch("gitingest.cloning.run_command", new_callable=AsyncMock) + mock_exec = mocker.patch("gitingest.clone.run_command", new_callable=AsyncMock) # Provide a default dummy process so most tests don't have to create one. dummy_process = AsyncMock() diff --git a/tests/query_parser/__init__.py b/tests/query_parser/__init__.py index e69de29b..4634bc87 100644 --- a/tests/query_parser/__init__.py +++ b/tests/query_parser/__init__.py @@ -0,0 +1 @@ +"""Tests for the query parser.""" diff --git a/tests/query_parser/test_git_host_agnostic.py b/tests/query_parser/test_git_host_agnostic.py index eef06f97..f4014d92 100644 --- a/tests/query_parser/test_git_host_agnostic.py +++ b/tests/query_parser/test_git_host_agnostic.py @@ -1,19 +1,18 @@ -""" -Tests to verify that the query parser is Git host agnostic. +"""Tests to verify that the query parser is Git host agnostic. -These tests confirm that `parse_query` correctly identifies user/repo pairs and canonical URLs for GitHub, GitLab, +These tests confirm that ``parse_query`` correctly identifies user/repo pairs and canonical URLs for GitHub, GitLab, Bitbucket, Gitea, and Codeberg, even if the host is omitted. """ -from typing import List, Tuple +from __future__ import annotations import pytest -from gitingest.query_parsing import parse_query +from gitingest.query_parser import parse_query from gitingest.utils.query_parser_utils import KNOWN_GIT_HOSTS # Repository matrix: (host, user, repo) -_REPOS: List[Tuple[str, str, str]] = [ +_REPOS: list[tuple[str, str, str]] = [ ("github.com", "tiangolo", "fastapi"), ("gitlab.com", "gitlab-org", "gitlab-runner"), ("bitbucket.org", "na-dna", "llm-knowledge-share"), @@ -25,7 +24,7 @@ # Generate cartesian product of repository tuples with URL variants. -@pytest.mark.parametrize("host, user, repo", _REPOS, ids=[f"{h}:{u}/{r}" for h, u, r in _REPOS]) +@pytest.mark.parametrize(("host", "user", "repo"), _REPOS, ids=[f"{h}:{u}/{r}" for h, u, r in _REPOS]) @pytest.mark.parametrize("variant", ["full", "noscheme", "slug"]) @pytest.mark.asyncio async def test_parse_query_without_host( @@ -34,8 +33,7 @@ async def test_parse_query_without_host( repo: str, variant: str, ) -> None: - """Verify that `parse_query` handles URLs, host-omitted URLs and raw slugs.""" - + """Verify that ``parse_query`` handles URLs, host-omitted URLs and raw slugs.""" # Build the input URL based on the selected variant if variant == "full": url = f"https://{host}/{user}/{repo}" @@ -49,7 +47,7 @@ async def test_parse_query_without_host( # For slug form with a custom host (not in KNOWN_GIT_HOSTS) we expect a failure, # because the parser cannot guess which domain to use. if variant == "slug" and host not in KNOWN_GIT_HOSTS: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Could not find a valid repository host"): await parse_query(url, max_file_size=50, from_web=True) return diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index dd0fa32e..5c9c4e44 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -1,23 +1,28 @@ -""" -Tests for the `query_parsing` module. +"""Tests for the `query_parser` module. These tests cover URL parsing, pattern parsing, and handling of branches/subpaths for HTTP(S) repositories and local paths. """ +from __future__ import annotations + from pathlib import Path -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable from unittest.mock import AsyncMock import pytest -from pytest_mock import MockerFixture -from gitingest.query_parsing import _parse_patterns, _parse_remote_repo, parse_query -from gitingest.schemas.ingestion_schema import IngestionQuery +from gitingest.query_parser import _parse_patterns, _parse_remote_repo, parse_query from gitingest.utils.ignore_patterns import DEFAULT_IGNORE_PATTERNS from tests.conftest import DEMO_URL -URLS_HTTPS: List[str] = [ +if TYPE_CHECKING: + from pytest_mock import MockerFixture + + from gitingest.schemas.ingestion import IngestionQuery + + +URLS_HTTPS: list[str] = [ DEMO_URL, "https://gitlab.com/user/repo", "https://bitbucket.org/user/repo", @@ -29,13 +34,13 @@ "https://gitlab.example.se/user/repo", ] -URLS_HTTP: List[str] = [url.replace("https://", "http://") for url in URLS_HTTPS] +URLS_HTTP: list[str] = [url.replace("https://", "http://") for url in URLS_HTTPS] @pytest.mark.parametrize("url", URLS_HTTPS, ids=lambda u: u) @pytest.mark.asyncio async def test_parse_url_valid_https(url: str) -> None: - """Valid HTTPS URLs parse correctly and `query.url` equals the input.""" + """Valid HTTPS URLs parse correctly and ``query.url`` equals the input.""" query = await _assert_basic_repo_fields(url) assert query.url == url # HTTPS: canonical URL should equal input @@ -50,11 +55,10 @@ async def test_parse_url_valid_http(url: str) -> None: @pytest.mark.asyncio async def test_parse_url_invalid() -> None: - """ - Test `_parse_remote_repo` with an invalid URL. + """Test ``_parse_remote_repo`` with an invalid URL. Given an HTTPS URL lacking a repository structure (e.g., "https://github.com"), - When `_parse_remote_repo` is called, + When ``_parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com" @@ -66,11 +70,10 @@ async def test_parse_url_invalid() -> None: @pytest.mark.asyncio @pytest.mark.parametrize("url", [DEMO_URL, "https://gitlab.com/user/repo"]) async def test_parse_query_basic(url: str) -> None: - """ - Test `parse_query` with a basic valid repository URL. + """Test ``parse_query`` with a basic valid repository URL. Given an HTTPS URL and ignore_patterns="*.txt": - When `parse_query` is called, + When ``parse_query`` is called, Then user/repo, URL, and ignore patterns should be parsed correctly. """ query = await parse_query(source=url, max_file_size=50, from_web=True, ignore_patterns="*.txt") @@ -84,11 +87,10 @@ async def test_parse_query_basic(url: str) -> None: @pytest.mark.asyncio async def test_parse_query_mixed_case() -> None: - """ - Test `parse_query` with mixed-case URLs. + """Test ``parse_query`` with mixed-case URLs. Given a URL with mixed-case parts (e.g. "Https://GitHub.COM/UsEr/rEpO"): - When `parse_query` is called, + When ``parse_query`` is called, Then the user and repo names should be normalized to lowercase. """ url = "Https://GitHub.COM/UsEr/rEpO" @@ -100,11 +102,10 @@ async def test_parse_query_mixed_case() -> None: @pytest.mark.asyncio async def test_parse_query_include_pattern() -> None: - """ - Test `parse_query` with a specified include pattern. + """Test ``parse_query`` with a specified include pattern. Given a URL and include_patterns="*.py": - When `parse_query` is called, + When ``parse_query`` is called, Then the include pattern should be set, and default ignore patterns remain applied. """ query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="*.py") @@ -115,11 +116,10 @@ async def test_parse_query_include_pattern() -> None: @pytest.mark.asyncio async def test_parse_query_invalid_pattern() -> None: - """ - Test `parse_query` with an invalid pattern. + """Test ``parse_query`` with an invalid pattern. Given an include pattern containing special characters (e.g., "*.py;rm -rf"): - When `parse_query` is called, + When ``parse_query`` is called, Then a ValueError should be raised indicating invalid characters. """ with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): @@ -127,12 +127,11 @@ async def test_parse_query_invalid_pattern() -> None: @pytest.mark.asyncio -async def test_parse_url_with_subpaths(stub_branches: Callable[[List[str]], None]) -> None: - """ - Test `_parse_remote_repo` with a URL containing branch and subpath. +async def test_parse_url_with_subpaths(stub_branches: Callable[[list[str]], None]) -> None: + """Test ``_parse_remote_repo`` with a URL containing branch and subpath. Given a URL referencing a branch ("main") and a subdir ("subdir/file"): - When `_parse_remote_repo` is called with remote branch fetching, + When ``_parse_remote_repo`` is called with remote branch fetching, Then user, repo, branch, and subpath should be identified correctly. """ url = DEMO_URL + "/tree/main/subdir/file" @@ -149,11 +148,10 @@ async def test_parse_url_with_subpaths(stub_branches: Callable[[List[str]], None @pytest.mark.asyncio async def test_parse_url_invalid_repo_structure() -> None: - """ - Test `_parse_remote_repo` with a URL missing a repository name. + """Test ``_parse_remote_repo`` with a URL missing a repository name. Given a URL like "https://github.com/user": - When `_parse_remote_repo` is called, + When ``_parse_remote_repo`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com/user" @@ -163,11 +161,10 @@ async def test_parse_url_invalid_repo_structure() -> None: def test_parse_patterns_valid() -> None: - """ - Test `_parse_patterns` with valid comma-separated patterns. + """Test ``_parse_patterns`` with valid comma-separated patterns. Given patterns like "*.py, *.md, docs/*": - When `_parse_patterns` is called, + When ``_parse_patterns`` is called, Then it should return a set of parsed strings. """ patterns = "*.py, *.md, docs/*" @@ -177,11 +174,10 @@ def test_parse_patterns_valid() -> None: def test_parse_patterns_invalid_characters() -> None: - """ - Test `_parse_patterns` with invalid characters. + """Test ``_parse_patterns`` with invalid characters. Given a pattern string containing special characters (e.g. "*.py;rm -rf"): - When `_parse_patterns` is called, + When ``_parse_patterns`` is called, Then a ValueError should be raised indicating invalid pattern syntax. """ patterns = "*.py;rm -rf" @@ -192,12 +188,11 @@ def test_parse_patterns_invalid_characters() -> None: @pytest.mark.asyncio async def test_parse_query_with_large_file_size() -> None: - """ - Test `parse_query` with a very large file size limit. + """Test ``parse_query`` with a very large file size limit. Given a URL and max_file_size=10**9: - When `parse_query` is called, - Then `max_file_size` should be set correctly and default ignore patterns remain unchanged. + When ``parse_query`` is called, + Then ``max_file_size`` should be set correctly and default ignore patterns remain unchanged. """ query = await parse_query(DEMO_URL, max_file_size=10**9, from_web=True) @@ -207,12 +202,11 @@ async def test_parse_query_with_large_file_size() -> None: @pytest.mark.asyncio async def test_parse_query_empty_patterns() -> None: - """ - Test `parse_query` with empty patterns. + """Test ``parse_query`` with empty patterns. Given empty include_patterns and ignore_patterns: - When `parse_query` is called, - Then include_patterns becomes None and default ignore patterns apply. + When ``parse_query`` is called, + Then ``include_patterns`` becomes *None* and default ignore patterns apply. """ query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") @@ -222,11 +216,10 @@ async def test_parse_query_empty_patterns() -> None: @pytest.mark.asyncio async def test_parse_query_include_and_ignore_overlap() -> None: - """ - Test `parse_query` with overlapping patterns. + """Test ``parse_query`` with overlapping patterns. Given include="*.py" and ignore={"*.py", "*.txt"}: - When `parse_query` is called, + When ``parse_query`` is called, Then "*.py" should be removed from ignore patterns. """ query = await parse_query( @@ -245,11 +238,10 @@ async def test_parse_query_include_and_ignore_overlap() -> None: @pytest.mark.asyncio async def test_parse_query_local_path() -> None: - """ - Test `parse_query` with a local file path. + """Test ``parse_query`` with a local file path. Given "/home/user/project" and from_web=False: - When `parse_query` is called, + When ``parse_query`` is called, Then the local path should be set, id generated, and slug formed accordingly. """ path = "/home/user/project" @@ -263,12 +255,11 @@ async def test_parse_query_local_path() -> None: @pytest.mark.asyncio async def test_parse_query_relative_path() -> None: - """ - Test `parse_query` with a relative path. + """Test ``parse_query`` with a relative path. Given "./project" and from_web=False: - When `parse_query` is called, - Then local_path resolves relatively, and slug ends with "project". + When ``parse_query`` is called, + Then ``local_path`` resolves relatively, and ``slug`` ends with "project". """ path = "./project" query = await parse_query(path, max_file_size=100, from_web=False) @@ -280,11 +271,10 @@ async def test_parse_query_relative_path() -> None: @pytest.mark.asyncio async def test_parse_query_empty_source() -> None: - """ - Test `parse_query` with an empty string. + """Test ``parse_query`` with an empty string. Given an empty source string: - When `parse_query` is called, + When ``parse_query`` is called, Then a ValueError should be raised indicating an invalid repository URL. """ url = "" @@ -295,7 +285,7 @@ async def test_parse_query_empty_source() -> None: @pytest.mark.asyncio @pytest.mark.parametrize( - "path, expected_branch, expected_commit", + ("path", "expected_branch", "expected_commit"), [ ("/tree/main", "main", None), ("/tree/abcd1234abcd1234abcd1234abcd1234abcd1234", None, "abcd1234abcd1234abcd1234abcd1234abcd1234"), @@ -305,13 +295,12 @@ async def test_parse_url_branch_and_commit_distinction( path: str, expected_branch: str, expected_commit: str, - stub_branches: Callable[[List[str]], None], + stub_branches: Callable[[list[str]], None], ) -> None: - """ - Test `_parse_remote_repo` distinguishing branch vs. commit hash. + """Test ``_parse_remote_repo`` distinguishing branch vs. commit hash. Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL: - When `_parse_remote_repo` is called with branch fetching, + When ``_parse_remote_repo`` is called with branch fetching, Then the function should correctly set `branch` or `commit` based on the URL content. """ stub_branches(["main", "dev", "feature-branch"]) @@ -325,11 +314,10 @@ async def test_parse_url_branch_and_commit_distinction( @pytest.mark.asyncio async def test_parse_query_uuid_uniqueness() -> None: - """ - Test `parse_query` for unique UUID generation. + """Test ``parse_query`` for unique UUID generation. Given the same path twice: - When `parse_query` is called repeatedly, + When ``parse_query`` is called repeatedly, Then each call should produce a different query id. """ path = "/home/user/project" @@ -341,11 +329,10 @@ async def test_parse_query_uuid_uniqueness() -> None: @pytest.mark.asyncio async def test_parse_url_with_query_and_fragment() -> None: - """ - Test `_parse_remote_repo` with query parameters and a fragment. + """Test ``_parse_remote_repo`` with query parameters and a fragment. Given a URL like "https://github.com/user/repo?arg=value#fragment": - When `_parse_remote_repo` is called, + When ``_parse_remote_repo`` is called, Then those parts should be stripped, leaving a clean user/repo URL. """ url = DEMO_URL + "?arg=value#fragment" @@ -358,11 +345,10 @@ async def test_parse_url_with_query_and_fragment() -> None: @pytest.mark.asyncio async def test_parse_url_unsupported_host() -> None: - """ - Test `_parse_remote_repo` with an unsupported host. + """Test ``_parse_remote_repo`` with an unsupported host. Given "https://only-domain.com": - When `_parse_remote_repo` is called, + When ``_parse_remote_repo`` is called, Then a ValueError should be raised for the unknown domain. """ url = "https://only-domain.com" @@ -373,11 +359,10 @@ async def test_parse_url_unsupported_host() -> None: @pytest.mark.asyncio async def test_parse_query_with_branch() -> None: - """ - Test `parse_query` when a branch is specified in a blob path. + """Test ``parse_query`` when a branch is specified in a blob path. Given "https://github.com/pandas-dev/pandas/blob/2.2.x/...": - When `parse_query` is called, + When ``parse_query`` is called, Then the branch should be identified, subpath set, and commit remain None. """ url = "https://github.com/pandas-dev/pandas/blob/2.2.x/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" @@ -396,7 +381,7 @@ async def test_parse_query_with_branch() -> None: @pytest.mark.asyncio @pytest.mark.parametrize( - "path, expected_branch, expected_subpath", + ("path", "expected_branch", "expected_subpath"), [ ("/tree/main/src", "main", "/src"), ("/tree/fix1", "fix1", "/"), @@ -409,11 +394,10 @@ async def test_parse_repo_source_with_failed_git_command( expected_subpath: str, mocker: MockerFixture, ) -> None: - """ - Test `_parse_remote_repo` when git fetch fails. + """Test ``_parse_remote_repo`` when git fetch fails. Given a URL referencing a branch, but Git fetching fails: - When `_parse_remote_repo` is called, + When ``_parse_remote_repo`` is called, Then it should fall back to path components for branch identification. """ url = DEMO_URL + path @@ -446,15 +430,17 @@ async def test_parse_repo_source_with_failed_git_command( ) async def test_parse_repo_source_with_various_url_patterns( path: str, - expected_branch: Optional[str], + expected_branch: str | None, expected_subpath: str, - stub_branches: Callable[[List[str]], None], + stub_branches: Callable[[list[str]], None], ) -> None: - """ - `_parse_remote_repo` should detect (or reject) a branch and resolve the - sub-path for various GitHub-style URL permutations. + """Test ``_parse_remote_repo`` with various GitHub-style URL permutations. + + Given various GitHub-style URL permutations: + When ``_parse_remote_repo`` is called, + Then it should detect (or reject) a branch and resolve the sub-path. - Branch discovery is stubbed so that only names passed to `stub_branches` are considered "remote". + Branch discovery is stubbed so that only names passed to ``stub_branches`` are considered "remote". """ stub_branches(["feature/fix1", "main", "feature-branch", "fix"]) @@ -466,8 +452,7 @@ async def test_parse_repo_source_with_various_url_patterns( async def _assert_basic_repo_fields(url: str) -> IngestionQuery: - """Run _parse_remote_repo and assert user, repo and slug are parsed.""" - + """Run ``_parse_remote_repo`` and assert user, repo and slug are parsed.""" query = await _parse_remote_repo(url) assert query.user_name == "user" diff --git a/tests/test_cli.py b/tests/test_cli.py index e033164a..ab018e32 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,12 +1,11 @@ """Tests for the Gitingest CLI.""" -import os +from __future__ import annotations + from inspect import signature from pathlib import Path -from typing import List import pytest -from _pytest.monkeypatch import MonkeyPatch from click.testing import CliRunner, Result from gitingest.cli import main @@ -14,7 +13,7 @@ @pytest.mark.parametrize( - "cli_args, expect_file", + ("cli_args", "expect_file"), [ pytest.param(["./"], True, id="default-options"), pytest.param( @@ -34,14 +33,21 @@ ), ], ) -def test_cli_writes_file(tmp_path: Path, monkeypatch: MonkeyPatch, cli_args: List[str], expect_file: bool) -> None: +def test_cli_writes_file( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + cli_args: list[str], + *, + expect_file: bool, +) -> None: """Run the CLI and verify that the SARIF file is created (or not).""" + expectes_exit_code = 0 # Work inside an isolated temp directory monkeypatch.chdir(tmp_path) result = _invoke_isolated_cli_runner(cli_args) - assert result.exit_code == 0, result.stderr + assert result.exit_code == expectes_exit_code, result.stderr # Summary line should be on STDOUT stdout_lines = result.stdout.splitlines() @@ -54,9 +60,10 @@ def test_cli_writes_file(tmp_path: Path, monkeypatch: MonkeyPatch, cli_args: Lis def test_cli_with_stdout_output() -> None: """Test CLI invocation with output directed to STDOUT.""" + output_file = Path(OUTPUT_FILE_NAME) # Clean up any existing digest.txt file before test - if os.path.exists(OUTPUT_FILE_NAME): - os.remove(OUTPUT_FILE_NAME) + if output_file.exists(): + output_file.unlink() try: result = _invoke_isolated_cli_runner(["./", "--output", "-", "--exclude-pattern", "tests/"]) @@ -64,10 +71,10 @@ def test_cli_with_stdout_output() -> None: # ─── core expectations (stdout) ────────────────────────────────────- assert result.exit_code == 0, f"CLI exited with code {result.exit_code}, stderr: {result.stderr}" assert "---" in result.stdout, "Expected file separator '---' not found in STDOUT" - assert ( - "src/gitingest/cli.py" in result.stdout - ), "Expected content (e.g., src/gitingest/cli.py) not found in STDOUT" - assert not os.path.exists(OUTPUT_FILE_NAME), f"Output file {OUTPUT_FILE_NAME} was unexpectedly created." + assert "src/gitingest/cli.py" in result.stdout, ( + "Expected content (e.g., src/gitingest/cli.py) not found in STDOUT" + ) + assert not output_file.exists(), f"Output file {output_file} was unexpectedly created." # ─── the summary must *not* pollute STDOUT, must appear on STDERR ─── summary = "Analysis complete! Output sent to stdout." @@ -75,17 +82,17 @@ def test_cli_with_stdout_output() -> None: stderr_lines = result.stderr.splitlines() assert summary not in stdout_lines, "Unexpected summary message found in STDOUT" assert summary in stderr_lines, "Expected summary message not found in STDERR" - assert f"Output written to: {OUTPUT_FILE_NAME}" not in stderr_lines + assert f"Output written to: {output_file.name}" not in stderr_lines finally: # Clean up any digest.txt file that might have been created during test - if os.path.exists(OUTPUT_FILE_NAME): - os.remove(OUTPUT_FILE_NAME) + if output_file.exists(): + output_file.unlink() -def _invoke_isolated_cli_runner(args: List[str]) -> Result: - """Return a CliRunner that keeps stderr apart on Click 8.0-8.1.""" +def _invoke_isolated_cli_runner(args: list[str]) -> Result: + """Return a ``CliRunner`` that keeps ``stderr`` separate on Click 8.0-8.1.""" kwargs = {} if "mix_stderr" in signature(CliRunner.__init__).parameters: - kwargs["mix_stderr"] = False # Click 8.0–8.1 + kwargs["mix_stderr"] = False # Click 8.0-8.1 runner = CliRunner(**kwargs) return runner.invoke(main, args) diff --git a/tests/test_repository_clone.py b/tests/test_clone.py similarity index 78% rename from tests/test_repository_clone.py rename to tests/test_clone.py index d5d395c8..d560d684 100644 --- a/tests/test_repository_clone.py +++ b/tests/test_clone.py @@ -1,20 +1,20 @@ -""" -Tests for the `cloning` module. +"""Tests for the `clone` module. These tests cover various scenarios for cloning repositories, verifying that the appropriate Git commands are invoked and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches. """ import asyncio -import os +import subprocess from pathlib import Path from unittest.mock import AsyncMock import pytest from pytest_mock import MockerFixture -from gitingest.cloning import clone_repo +from gitingest.clone import clone_repo from gitingest.schemas import CloneConfig +from gitingest.utils.async_compat import to_thread from gitingest.utils.exceptions import AsyncTimeoutError from gitingest.utils.git_utils import check_repo_exists from tests.conftest import DEMO_URL, LOCAL_REPO_PATH @@ -26,13 +26,13 @@ @pytest.mark.asyncio async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository with a specific commit hash. + """Test cloning a repository with a specific commit hash. Given a valid URL and a commit hash: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. """ + expected_call_count = 2 clone_config = CloneConfig( url=DEMO_URL, local_path=LOCAL_REPO_PATH, @@ -43,33 +43,32 @@ async def test_clone_with_commit(repo_exists_true: AsyncMock, run_command_mock: await clone_repo(clone_config) repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == 2 # Clone and checkout calls + assert run_command_mock.call_count == expected_call_count # Clone and checkout calls @pytest.mark.asyncio async def test_clone_without_commit(repo_exists_true: AsyncMock, run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository when no commit hash is provided. + """Test cloning a repository when no commit hash is provided. Given a valid URL and no commit hash: - When `clone_repo` is called, + When ``clone_repo`` is called, Then only the clone_repo operation should be performed (no checkout). """ + expected_call_count = 1 clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit=None, branch="main") await clone_repo(clone_config) repo_exists_true.assert_called_once_with(clone_config.url, token=None) - assert run_command_mock.call_count == 1 # Only clone call + assert run_command_mock.call_count == expected_call_count # Only clone call @pytest.mark.asyncio async def test_clone_nonexistent_repository(repo_exists_true: AsyncMock) -> None: - """ - Test cloning a nonexistent repository URL. + """Test cloning a nonexistent repository URL. Given an invalid or nonexistent URL: - When `clone_repo` is called, + When ``clone_repo`` is called, Then a ValueError should be raised with an appropriate error message. """ clone_config = CloneConfig( @@ -89,19 +88,24 @@ async def test_clone_nonexistent_repository(repo_exists_true: AsyncMock) -> None @pytest.mark.asyncio @pytest.mark.parametrize( - "mock_stdout, return_code, expected", + ("mock_stdout", "return_code", "expected"), [ (b"HTTP/1.1 200 OK\n", 0, True), # Existing repo (b"HTTP/1.1 404 Not Found\n", 0, False), # Non-existing repo (b"HTTP/1.1 200 OK\n", 1, False), # Failed request ], ) -async def test_check_repo_exists(mock_stdout: bytes, return_code: int, expected: bool, mocker: MockerFixture) -> None: - """ - Test the `check_repo_exists` function with different Git HTTP responses. +async def test_check_repo_exists( + mock_stdout: bytes, + return_code: int, + *, + expected: bool, + mocker: MockerFixture, +) -> None: + """Test the ``check_repo_exists`` function with different Git HTTP responses. Given various stdout lines and return codes: - When `check_repo_exists` is called, + When ``check_repo_exists`` is called, Then it should correctly indicate whether the repository exists. """ mock_exec = mocker.patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) @@ -117,11 +121,10 @@ async def test_check_repo_exists(mock_stdout: bytes, return_code: int, expected: @pytest.mark.asyncio async def test_clone_with_custom_branch(run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository with a specified custom branch. + """Test cloning a repository with a specified custom branch. Given a valid URL and a branch: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned shallowly to that branch. """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, branch="feature-branch") @@ -142,11 +145,10 @@ async def test_clone_with_custom_branch(run_command_mock: AsyncMock) -> None: @pytest.mark.asyncio async def test_git_command_failure(run_command_mock: AsyncMock) -> None: - """ - Test cloning when the Git command fails during execution. + """Test cloning when the Git command fails during execution. - Given a valid URL, but `run_command` raises a RuntimeError: - When `clone_repo` is called, + Given a valid URL, but ``run_command`` raises a RuntimeError: + When ``clone_repo`` is called, Then a RuntimeError should be raised with the correct message. """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) @@ -159,12 +161,11 @@ async def test_git_command_failure(run_command_mock: AsyncMock) -> None: @pytest.mark.asyncio async def test_clone_default_shallow_clone(run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository with the default shallow clone options. + """Test cloning a repository with the default shallow clone options. Given a valid URL and no branch or commit: - When `clone_repo` is called, - Then the repository should be cloned with `--depth=1` and `--single-branch`. + When ``clone_repo`` is called, + Then the repository should be cloned with ``--depth=1`` and ``--single-branch``. """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) @@ -182,30 +183,29 @@ async def test_clone_default_shallow_clone(run_command_mock: AsyncMock) -> None: @pytest.mark.asyncio async def test_clone_commit_without_branch(run_command_mock: AsyncMock) -> None: - """ - Test cloning when a commit hash is provided but no branch is specified. + """Test cloning when a commit hash is provided but no branch is specified. Given a valid URL and a commit hash (but no branch): - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned and checked out at that commit. """ + expected_call_count = 2 # Simulating a valid commit hash clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40) await clone_repo(clone_config) - assert run_command_mock.call_count == 2 # Clone and checkout calls + assert run_command_mock.call_count == expected_call_count # Clone and checkout calls run_command_mock.assert_any_call("git", "clone", "--single-branch", clone_config.url, clone_config.local_path) run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "checkout", clone_config.commit) @pytest.mark.asyncio async def test_check_repo_exists_with_redirect(mocker: MockerFixture) -> None: - """ - Test `check_repo_exists` when a redirect (302) is returned. + """Test ``check_repo_exists`` when a redirect (302) is returned. Given a URL that responds with "302 Found": - When `check_repo_exists` is called, + When ``check_repo_exists`` is called, Then it should return `False`, indicating the repo is inaccessible. """ mock_exec = mocker.patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) @@ -221,12 +221,11 @@ async def test_check_repo_exists_with_redirect(mocker: MockerFixture) -> None: @pytest.mark.asyncio async def test_check_repo_exists_with_permanent_redirect(mocker: MockerFixture) -> None: - """ - Test `check_repo_exists` when a permanent redirect (301) is returned. + """Test ``check_repo_exists`` when a permanent redirect (301) is returned. Given a URL that responds with "301 Found": - When `check_repo_exists` is called, - Then it should return `True`, indicating the repo may exist at the new location. + When ``check_repo_exists`` is called, + Then it should return *True*, indicating the repo may exist at the new location. """ mock_exec = mocker.patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) mock_process = AsyncMock() @@ -241,12 +240,11 @@ async def test_check_repo_exists_with_permanent_redirect(mocker: MockerFixture) @pytest.mark.asyncio async def test_clone_with_timeout(run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository when a timeout occurs. + """Test cloning a repository when a timeout occurs. - Given a valid URL, but `run_command` times out: - When `clone_repo` is called, - Then an `AsyncTimeoutError` should be raised to indicate the operation exceeded time limits. + Given a valid URL, but ``run_command`` times out: + When ``clone_repo`` is called, + Then an ``AsyncTimeoutError`` should be raised to indicate the operation exceeded time limits. """ clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH) @@ -258,11 +256,10 @@ async def test_clone_with_timeout(run_command_mock: AsyncMock) -> None: @pytest.mark.asyncio async def test_clone_specific_branch(tmp_path: Path) -> None: - """ - Test cloning a specific branch of a repository. + """Test cloning a specific branch of a repository. Given a valid repository URL and a branch name: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned and checked out at that branch. """ repo_url = "https://github.com/cyclotruc/gitingest.git" @@ -274,17 +271,24 @@ async def test_clone_specific_branch(tmp_path: Path) -> None: assert local_path.exists(), "The repository was not cloned successfully." assert local_path.is_dir(), "The cloned repository path is not a directory." - current_branch = os.popen(f"git -C {local_path} branch --show-current").read().strip() + + current_branch = ( + await to_thread( + subprocess.check_output, + ["git", "-C", str(local_path), "branch", "--show-current"], + text=True, + ) + ).strip() + assert current_branch == branch_name, f"Expected branch '{branch_name}', got '{current_branch}'." @pytest.mark.asyncio async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: AsyncMock) -> None: - """ - Test cloning a branch with slashes in the name. + """Test cloning a branch with slashes in the name. Given a valid repository URL and a branch name with slashes: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned and checked out at that branch. """ branch_name = "fix/in-operator" @@ -307,11 +311,10 @@ async def test_clone_branch_with_slashes(tmp_path: Path, run_command_mock: Async @pytest.mark.asyncio async def test_clone_creates_parent_directory(tmp_path: Path, run_command_mock: AsyncMock) -> None: - """ - Test that clone_repo creates parent directories if they don't exist. + """Test that ``clone_repo`` creates parent directories if they don't exist. Given a local path with non-existent parent directories: - When `clone_repo` is called, + When ``clone_repo`` is called, Then it should create the parent directories before attempting to clone. """ nested_path = tmp_path / "deep" / "nested" / "path" / "repo" @@ -332,13 +335,13 @@ async def test_clone_creates_parent_directory(tmp_path: Path, run_command_mock: @pytest.mark.asyncio async def test_clone_with_specific_subpath(run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository with a specific subpath. + """Test cloning a repository with a specific subpath. Given a valid repository URL and a specific subpath: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned with sparse checkout enabled and the specified subpath. """ + expected_call_count = 2 clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, subpath="src/docs") await clone_repo(clone_config) @@ -358,19 +361,19 @@ async def test_clone_with_specific_subpath(run_command_mock: AsyncMock) -> None: # Verify the sparse-checkout command sets the correct path run_command_mock.assert_any_call("git", "-C", clone_config.local_path, "sparse-checkout", "set", "src/docs") - assert run_command_mock.call_count == 2 + assert run_command_mock.call_count == expected_call_count @pytest.mark.asyncio async def test_clone_with_commit_and_subpath(run_command_mock: AsyncMock) -> None: - """ - Test cloning a repository with both a specific commit and subpath. + """Test cloning a repository with both a specific commit and subpath. Given a valid repository URL, commit hash, and subpath: - When `clone_repo` is called, + When ``clone_repo`` is called, Then the repository should be cloned with sparse checkout enabled, checked out at the specific commit, and only include the specified subpath. """ + expected_call_count = 3 # Simulating a valid commit hash clone_config = CloneConfig(url=DEMO_URL, local_path=LOCAL_REPO_PATH, commit="a" * 40, subpath="src/docs") @@ -406,4 +409,4 @@ async def test_clone_with_commit_and_subpath(run_command_mock: AsyncMock) -> Non clone_config.commit, ) - assert run_command_mock.call_count == 3 + assert run_command_mock.call_count == expected_call_count diff --git a/tests/test_flow_integration.py b/tests/test_flow_integration.py index 7821b60a..163fd882 100644 --- a/tests/test_flow_integration.py +++ b/tests/test_flow_integration.py @@ -6,8 +6,8 @@ from typing import Generator import pytest +from fastapi import status from fastapi.testclient import TestClient -from pytest import FixtureRequest from pytest_mock import MockerFixture from src.server.main import app @@ -25,24 +25,24 @@ def test_client() -> Generator[TestClient, None, None]: @pytest.fixture(autouse=True) -def mock_static_files(mocker: MockerFixture) -> Generator[None, None, None]: +def mock_static_files(mocker: MockerFixture) -> None: """Mock the static file mount to avoid directory errors.""" mock_static = mocker.patch("src.server.main.StaticFiles", autospec=True) mock_static.return_value = None - yield mock_static + return mock_static @pytest.fixture(autouse=True) -def mock_templates(mocker: MockerFixture) -> Generator[None, None, None]: +def mock_templates(mocker: MockerFixture) -> None: """Mock Jinja2 template rendering to bypass actual file loading.""" mock_template = mocker.patch("starlette.templating.Jinja2Templates.TemplateResponse", autospec=True) mock_template.return_value = "Mocked Template Response" - yield mock_template + return mock_template @pytest.fixture(scope="module", autouse=True) def cleanup_tmp_dir() -> Generator[None, None, None]: - """Remove /tmp/gitingest after this test-module is done.""" + """Remove ``/tmp/gitingest`` after this test-module is done.""" yield # run tests temp_dir = Path("/tmp/gitingest") if temp_dir.exists(): @@ -53,7 +53,7 @@ def cleanup_tmp_dir() -> Generator[None, None, None]: @pytest.mark.asyncio -async def test_remote_repository_analysis(request: FixtureRequest) -> None: +async def test_remote_repository_analysis(request: pytest.FixtureRequest) -> None: """Test the complete flow of analyzing a remote repository.""" client = request.getfixturevalue("test_client") form_data = { @@ -65,12 +65,12 @@ async def test_remote_repository_analysis(request: FixtureRequest) -> None: } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Form submission failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Form submission failed: {response.text}" assert "Mocked Template Response" in response.text @pytest.mark.asyncio -async def test_invalid_repository_url(request: FixtureRequest) -> None: +async def test_invalid_repository_url(request: pytest.FixtureRequest) -> None: """Test handling of an invalid repository URL.""" client = request.getfixturevalue("test_client") form_data = { @@ -82,12 +82,12 @@ async def test_invalid_repository_url(request: FixtureRequest) -> None: } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Request failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Request failed: {response.text}" assert "Mocked Template Response" in response.text @pytest.mark.asyncio -async def test_large_repository(request: FixtureRequest) -> None: +async def test_large_repository(request: pytest.FixtureRequest) -> None: """Simulate analysis of a large repository with nested folders.""" client = request.getfixturevalue("test_client") form_data = { @@ -99,16 +99,16 @@ async def test_large_repository(request: FixtureRequest) -> None: } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Request failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Request failed: {response.text}" assert "Mocked Template Response" in response.text @pytest.mark.asyncio -async def test_concurrent_requests(request: FixtureRequest) -> None: +async def test_concurrent_requests(request: pytest.FixtureRequest) -> None: """Test handling of multiple concurrent requests.""" client = request.getfixturevalue("test_client") - def make_request(): + def make_request() -> None: form_data = { "input_text": "https://github.com/octocat/Hello-World", "max_file_size": "243", @@ -117,7 +117,7 @@ def make_request(): "token": "", } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Request failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Request failed: {response.text}" assert "Mocked Template Response" in response.text with ThreadPoolExecutor(max_workers=5) as executor: @@ -127,7 +127,7 @@ def make_request(): @pytest.mark.asyncio -async def test_large_file_handling(request: FixtureRequest) -> None: +async def test_large_file_handling(request: pytest.FixtureRequest) -> None: """Test handling of repositories with large files.""" client = request.getfixturevalue("test_client") form_data = { @@ -139,12 +139,12 @@ async def test_large_file_handling(request: FixtureRequest) -> None: } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Request failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Request failed: {response.text}" assert "Mocked Template Response" in response.text @pytest.mark.asyncio -async def test_repository_with_patterns(request: FixtureRequest) -> None: +async def test_repository_with_patterns(request: pytest.FixtureRequest) -> None: """Test repository analysis with include/exclude patterns.""" client = request.getfixturevalue("test_client") form_data = { @@ -156,5 +156,5 @@ async def test_repository_with_patterns(request: FixtureRequest) -> None: } response = client.post("/", data=form_data) - assert response.status_code == 200, f"Request failed: {response.text}" + assert response.status_code == status.HTTP_200_OK, f"Request failed: {response.text}" assert "Mocked Template Response" in response.text diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index df6e4e72..8fbb961e 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -1,22 +1,29 @@ -""" -Tests for the `git_utils` module. +"""Tests for the `git_utils` module. These tests validate the `validate_github_token` function, which ensures that GitHub personal access tokens (PATs) are properly formatted. """ +from __future__ import annotations + import base64 +from typing import TYPE_CHECKING import pytest from gitingest.utils.exceptions import InvalidGitHubTokenError from gitingest.utils.git_utils import ( - _is_github_host, create_git_auth_header, create_git_command, + is_github_host, validate_github_token, ) +if TYPE_CHECKING: + from pathlib import Path + + from pytest_mock import MockerFixture + @pytest.mark.parametrize( "token", @@ -27,7 +34,7 @@ "github_pat_1234567890abcdef1234567890abcdef1234", ], ) -def test_validate_github_token_valid(token): +def test_validate_github_token_valid(token: str) -> None: """validate_github_token should accept properly-formatted tokens.""" # Should not raise any exception validate_github_token(token) @@ -43,14 +50,14 @@ def test_validate_github_token_valid(token): "", # Empty string ], ) -def test_validate_github_token_invalid(token): - """validate_github_token should raise ValueError on malformed tokens.""" +def test_validate_github_token_invalid(token: str) -> None: + """Test that ``validate_github_token`` raises ``InvalidGitHubTokenError`` on malformed tokens.""" with pytest.raises(InvalidGitHubTokenError): validate_github_token(token) @pytest.mark.parametrize( - "base_cmd, local_path, url, token, expected_suffix", + ("base_cmd", "local_path", "url", "token", "expected_suffix"), [ ( ["git", "clone"], @@ -78,20 +85,26 @@ def test_validate_github_token_invalid(token): ), ], ) -def test_create_git_command(base_cmd, local_path, url, token, expected_suffix): - """create_git_command should build the correct command list based on inputs.""" +def test_create_git_command( + base_cmd: list[str], + local_path: str, + url: str, + token: str | None, + expected_suffix: list[str], +) -> None: + """Test that ``create_git_command`` builds the correct command list based on inputs.""" cmd = create_git_command(base_cmd, local_path, url, token) # The command should start with base_cmd and the -C option - expected_prefix = base_cmd + ["-C", local_path] + expected_prefix = [*base_cmd, "-C", local_path] assert cmd[: len(expected_prefix)] == expected_prefix # The suffix (anything after prefix) should match expected assert cmd[len(expected_prefix) :] == expected_suffix -def test_create_git_command_invalid_token(): - """Supplying an invalid token for a GitHub URL should raise ValueError.""" +def test_create_git_command_invalid_token() -> None: + """Test that supplying an invalid token for a GitHub URL raises ``InvalidGitHubTokenError``.""" with pytest.raises(InvalidGitHubTokenError): create_git_command( ["git", "clone"], @@ -108,8 +121,8 @@ def test_create_git_command_invalid_token(): "github_pat_1234567890abcdef1234567890abcdef1234", ], ) -def test_create_git_auth_header(token): - """create_git_auth_header should produce correct base64-encoded header.""" +def test_create_git_auth_header(token: str) -> None: + """Test that ``create_git_auth_header`` produces correct base64-encoded header.""" header = create_git_auth_header(token) expected_basic = base64.b64encode(f"x-oauth-basic:{token}".encode()).decode() expected = f"http.https://github.com/.extraheader=Authorization: Basic {expected_basic}" @@ -117,34 +130,40 @@ def test_create_git_auth_header(token): @pytest.mark.parametrize( - "url, token, should_call", + ("url", "token", "should_call"), [ ("https://github.com/foo/bar.git", "ghp_" + "f" * 36, True), ("https://github.com/foo/bar.git", None, False), ("https://gitlab.com/foo/bar.git", "ghp_" + "g" * 36, False), ], ) -def test_create_git_command_helper_calls(mocker, url, token, should_call): - """Verify validate_github_token & create_git_auth_header are invoked only when appropriate.""" - +def test_create_git_command_helper_calls( + mocker: MockerFixture, + tmp_path: Path, + *, + url: str, + token: str | None, + should_call: bool, +) -> None: + """Test that ``validate_github_token`` and ``create_git_auth_header`` are invoked only when appropriate.""" + work_dir = tmp_path / "repo" validate_mock = mocker.patch("gitingest.utils.git_utils.validate_github_token") header_mock = mocker.patch("gitingest.utils.git_utils.create_git_auth_header", return_value="HEADER") - cmd = create_git_command(["git", "clone"], "/tmp", url, token) + cmd = create_git_command(["git", "clone"], str(work_dir), url, token) if should_call: validate_mock.assert_called_once_with(token) - header_mock.assert_called_once_with(token) + header_mock.assert_called_once_with(token, url=url) assert "HEADER" in cmd else: validate_mock.assert_not_called() header_mock.assert_not_called() - # HEADER should not be included in command list assert "HEADER" not in cmd @pytest.mark.parametrize( - "url, expected", + ("url", "expected"), [ # GitHub.com URLs ("https://github.com/owner/repo.git", True), @@ -168,13 +187,13 @@ def test_create_git_command_helper_calls(mocker, url, token, should_call): ("ftp://github.com/owner/repo.git", True), # Different protocol but still github.com ], ) -def test_is_github_host(url, expected): - """_is_github_host should correctly identify GitHub and GitHub Enterprise URLs.""" - assert _is_github_host(url) == expected +def test_is_github_host(url: str, *, expected: bool) -> None: + """Test that ``is_github_host`` correctly identifies GitHub and GitHub Enterprise URLs.""" + assert is_github_host(url) == expected @pytest.mark.parametrize( - "token, url, expected_hostname", + ("token", "url", "expected_hostname"), [ # GitHub.com URLs (default) ("ghp_" + "a" * 36, "https://github.com", "github.com"), @@ -185,16 +204,16 @@ def test_is_github_host(url, expected): ("ghp_" + "d" * 36, "http://github.internal", "github.internal"), ], ) -def test_create_git_auth_header_with_ghe_url(token, url, expected_hostname): - """create_git_auth_header should handle GitHub Enterprise URLs correctly.""" - header = create_git_auth_header(token, url) +def test_create_git_auth_header_with_ghe_url(token: str, url: str, expected_hostname: str) -> None: + """Test that ``create_git_auth_header`` handles GitHub Enterprise URLs correctly.""" + header = create_git_auth_header(token, url=url) expected_basic = base64.b64encode(f"x-oauth-basic:{token}".encode()).decode() expected = f"http.https://{expected_hostname}/.extraheader=Authorization: Basic {expected_basic}" assert header == expected @pytest.mark.parametrize( - "base_cmd, local_path, url, token, expected_auth_hostname", + ("base_cmd", "local_path", "url", "token", "expected_auth_hostname"), [ # GitHub.com URLs - should use default hostname ( @@ -228,12 +247,18 @@ def test_create_git_auth_header_with_ghe_url(token, url, expected_hostname): ), ], ) -def test_create_git_command_with_ghe_urls(base_cmd, local_path, url, token, expected_auth_hostname): - """create_git_command should handle GitHub Enterprise URLs correctly.""" +def test_create_git_command_with_ghe_urls( + base_cmd: list[str], + local_path: str, + url: str, + token: str, + expected_auth_hostname: str, +) -> None: + """Test that ``create_git_command`` handles GitHub Enterprise URLs correctly.""" cmd = create_git_command(base_cmd, local_path, url, token) # Should have base command and -C option - expected_prefix = base_cmd + ["-C", local_path] + expected_prefix = [*base_cmd, "-C", local_path] assert cmd[: len(expected_prefix)] == expected_prefix # Should have -c and auth header @@ -247,7 +272,7 @@ def test_create_git_command_with_ghe_urls(base_cmd, local_path, url, token, expe @pytest.mark.parametrize( - "base_cmd, local_path, url, token", + ("base_cmd", "local_path", "url", "token"), [ # Should NOT add auth headers for non-GitHub URLs (["git", "clone"], "/some/path", "https://gitlab.com/owner/repo.git", "ghp_" + "a" * 36), @@ -255,10 +280,15 @@ def test_create_git_command_with_ghe_urls(base_cmd, local_path, url, token, expe (["git", "clone"], "/some/path", "https://git.example.com/owner/repo.git", "ghp_" + "c" * 36), ], ) -def test_create_git_command_ignores_non_github_urls(base_cmd, local_path, url, token): - """create_git_command should not add auth headers for non-GitHub URLs.""" +def test_create_git_command_ignores_non_github_urls( + base_cmd: list[str], + local_path: str, + url: str, + token: str, +) -> None: + """Test that ``create_git_command`` does not add auth headers for non-GitHub URLs.""" cmd = create_git_command(base_cmd, local_path, url, token) # Should only have base command and -C option, no auth headers - expected = base_cmd + ["-C", local_path] + expected = [*base_cmd, "-C", local_path] assert cmd == expected diff --git a/tests/test_gitignore_feature.py b/tests/test_gitignore_feature.py index fcbc74f7..e588abc9 100644 --- a/tests/test_gitignore_feature.py +++ b/tests/test_gitignore_feature.py @@ -1,22 +1,21 @@ -""" -Tests for the gitignore functionality in Gitingest. -""" +"""Tests for the gitignore functionality in Gitingest.""" from pathlib import Path import pytest from gitingest.entrypoint import ingest_async -from gitingest.utils.ignore_patterns import load_gitignore_patterns +from gitingest.utils.ignore_patterns import load_ignore_patterns @pytest.fixture(name="repo_path") def repo_fixture(tmp_path: Path) -> Path: - """ - Create a temporary repository structure with: - - A .gitignore that excludes 'exclude.txt' - - 'include.txt' (should be processed) - - 'exclude.txt' (should be skipped when gitignore rules are respected) + """Create a temporary repository structure. + + The repository structure includes: + - A ``.gitignore`` that excludes ``exclude.txt`` + - ``include.txt`` (should be processed) + - ``exclude.txt`` (should be skipped when gitignore rules are respected) """ # Create a .gitignore file that excludes 'exclude.txt' gitignore_file = tmp_path / ".gitignore" @@ -33,15 +32,13 @@ def repo_fixture(tmp_path: Path) -> Path: return tmp_path -def test_load_gitignore_patterns(tmp_path: Path): - """ - Test that load_gitignore_patterns() correctly loads patterns from a .gitignore file. - """ +def test_load_gitignore_patterns(tmp_path: Path) -> None: + """Test that ``load_ignore_patterns()`` correctly loads patterns from a ``.gitignore`` file.""" gitignore = tmp_path / ".gitignore" # Write some sample patterns with a comment line included gitignore.write_text("exclude.txt\n*.log\n# a comment\n") - patterns = load_gitignore_patterns(tmp_path) + patterns = load_ignore_patterns(tmp_path, filename=".gitignore") # Check that the expected patterns are loaded assert "exclude.txt" in patterns @@ -52,12 +49,11 @@ def test_load_gitignore_patterns(tmp_path: Path): @pytest.mark.asyncio -async def test_ingest_with_gitignore(repo_path: Path): - """ - Integration test for ingest_async() respecting .gitignore rules. +async def test_ingest_with_gitignore(repo_path: Path) -> None: + """Integration test for ``ingest_async()`` respecting ``.gitignore`` rules. - When ``include_gitignored`` is ``False`` (default), the content of 'exclude.txt' should be omitted. - When ``include_gitignored`` is ``True``, both files should be present. + When ``include_gitignored`` is *False* (default), the content of ``exclude.txt`` should be omitted. + When ``include_gitignored`` is *True*, both files should be present. """ # Run ingestion with the gitignore functionality enabled. _, _, content_with_ignore = await ingest_async(source=str(repo_path)) diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 363749fd..4ffc8828 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -1,26 +1,29 @@ -""" -Tests for the `query_ingestion` module. +"""Tests for the `query_ingestion` module. These tests validate directory scanning, file content extraction, notebook handling, and the overall ingestion logic, including filtering patterns and subpaths. """ +from __future__ import annotations + import re -from pathlib import Path -from typing import Set, TypedDict +from typing import TYPE_CHECKING, TypedDict import pytest from gitingest.ingestion import ingest_query -from gitingest.query_parsing import IngestionQuery + +if TYPE_CHECKING: + from pathlib import Path + + from gitingest.query_parser import IngestionQuery def test_run_ingest_query(temp_directory: Path, sample_query: IngestionQuery) -> None: - """ - Test `ingest_query` to ensure it processes the directory and returns expected results. + """Test ``ingest_query`` to ensure it processes the directory and returns expected results. - Given a directory with .txt and .py files: - When `ingest_query` is invoked, + Given a directory with ``.txt`` and ``.py`` files: + When ``ingest_query`` is invoked, Then it should produce a summary string listing the files analyzed and a combined content string. """ sample_query.local_path = temp_directory @@ -50,12 +53,14 @@ def test_run_ingest_query(temp_directory: Path, sample_query: IngestionQuery) -> class PatternScenario(TypedDict): - include_patterns: Set[str] - ignore_patterns: Set[str] + """A scenario for testing the ingestion of a set of patterns.""" + + include_patterns: set[str] + ignore_patterns: set[str] expected_num_files: int - expected_content: Set[str] - expected_structure: Set[str] - expected_not_structure: Set[str] + expected_content: set[str] + expected_structure: set[str] + expected_not_structure: set[str] @pytest.mark.parametrize( @@ -70,7 +75,7 @@ class PatternScenario(TypedDict): "expected_content": {"file2.py", "dir2/file_dir2.txt"}, "expected_structure": {"test_repo/", "dir2/"}, "expected_not_structure": {"src/", "subdir/", "dir1/"}, - } + }, ), id="include-explicit-files", ), @@ -88,7 +93,7 @@ class PatternScenario(TypedDict): "expected_content": {"file1.txt", "file2.py", "dir1/file_dir1.txt", "dir2/file_dir2.txt"}, "expected_structure": {"test_repo/", "dir1/", "dir2/"}, "expected_not_structure": {"src/", "subdir/"}, - } + }, ), id="include-wildcard-directory", ), @@ -105,7 +110,7 @@ class PatternScenario(TypedDict): }, "expected_structure": {"test_repo/", "src/", "subdir/"}, "expected_not_structure": {"dir1/", "dir2/"}, - } + }, ), id="include-wildcard-files", ), @@ -122,7 +127,7 @@ class PatternScenario(TypedDict): }, "expected_structure": {"test_repo/", "dir2/", "src/", "subdir/"}, "expected_not_structure": {"dir1/"}, - } + }, ), id="include-recursive-wildcard", ), @@ -142,7 +147,7 @@ class PatternScenario(TypedDict): }, "expected_structure": {"test_repo/", "src/", "subdir/", "dir1/"}, "expected_not_structure": {"dir2/"}, - } + }, ), id="exclude-explicit-files", ), @@ -161,7 +166,7 @@ class PatternScenario(TypedDict): }, "expected_structure": {"test_repo/", "src/", "subdir/", "dir2/"}, "expected_not_structure": {"dir1/"}, - } + }, ), id="exclude-wildcard-directory", ), @@ -187,7 +192,7 @@ class PatternScenario(TypedDict): "subdir/", }, "expected_not_structure": {*()}, - } + }, ), id="exclude-recursive-wildcard", ), @@ -198,19 +203,17 @@ def test_include_ignore_patterns( sample_query: IngestionQuery, pattern_scenario: PatternScenario, ) -> None: - """ - Test `ingest_query` to ensure included and ignored paths are included and ignored respectively. + """Test ``ingest_query`` to ensure included and ignored paths are included and ignored respectively. - Given a directory with .txt and .py files, and a set of include patterns or a set of ignore patterns: - When `ingest_query` is invoked, + Given a directory with ``.txt`` and ``.py`` files, and a set of include patterns or a set of ignore patterns: + When ``ingest_query`` is invoked, Then it should produce a summary string listing the files analyzed and a combined content string. """ - sample_query.local_path = temp_directory sample_query.subpath = "/" sample_query.type = None - sample_query.include_patterns = pattern_scenario["include_patterns"] or None - sample_query.ignore_patterns = pattern_scenario["ignore_patterns"] or None + sample_query.include_patterns = pattern_scenario["include_patterns"] + sample_query.ignore_patterns = pattern_scenario["ignore_patterns"] summary, structure, content = ingest_query(sample_query) diff --git a/tests/test_notebook_utils.py b/tests/test_notebook_utils.py index e51bbca0..927f410a 100644 --- a/tests/test_notebook_utils.py +++ b/tests/test_notebook_utils.py @@ -1,5 +1,4 @@ -""" -Tests for the `notebook_utils` module. +"""Tests for the `notebook` utils module. These tests validate how notebooks are processed into Python-like output, ensuring that markdown/raw cells are converted to triple-quoted blocks, code cells remain executable code, and various edge cases (multiple worksheets, @@ -8,32 +7,34 @@ import pytest -from gitingest.utils.notebook_utils import process_notebook +from gitingest.utils.notebook import process_notebook from tests.conftest import WriteNotebookFunc def test_process_notebook_all_cells(write_notebook: WriteNotebookFunc) -> None: - """ - Test processing a notebook containing markdown, code, and raw cells. + """Test processing a notebook containing markdown, code, and raw cells. Given a notebook with: - One markdown cell - One code cell - One raw cell - When `process_notebook` is invoked, + When ``process_notebook`` is invoked, Then markdown and raw cells should appear in triple-quoted blocks, and code cells remain as normal code. """ + expected_count = 4 notebook_content = { "cells": [ {"cell_type": "markdown", "source": ["# Markdown cell"]}, {"cell_type": "code", "source": ['print("Hello Code")']}, {"cell_type": "raw", "source": [""]}, - ] + ], } nb_path = write_notebook("all_cells.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 4, "Two non-code cells => 2 triple-quoted blocks => 4 total triple quotes." + assert result.count('"""') == expected_count, ( + "Two non-code cells => 2 triple-quoted blocks => 4 total triple quotes." + ) # Ensure markdown and raw cells are in triple quotes assert "# Markdown cell" in result @@ -45,13 +46,12 @@ def test_process_notebook_all_cells(write_notebook: WriteNotebookFunc) -> None: def test_process_notebook_with_worksheets(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook containing the (as of IPEP-17 deprecated) 'worksheets' key. + """Test a notebook containing the (as of IPEP-17 deprecated) ``worksheets`` key. - Given a notebook that uses the 'worksheets' key with a single worksheet, - When `process_notebook` is called, - Then a `DeprecationWarning` should be raised, and the content should match an equivalent notebook - that has top-level 'cells'. + Given a notebook that uses the ``worksheets`` key with a single worksheet, + When ``process_notebook`` is called, + Then a ``DeprecationWarning`` should be raised, and the content should match an equivalent notebook + that has top-level ``cells``. """ with_worksheets = { "worksheets": [ @@ -60,9 +60,9 @@ def test_process_notebook_with_worksheets(write_notebook: WriteNotebookFunc) -> {"cell_type": "markdown", "source": ["# Markdown cell"]}, {"cell_type": "code", "source": ['print("Hello Code")']}, {"cell_type": "raw", "source": [""]}, - ] - } - ] + ], + }, + ], } without_worksheets = with_worksheets["worksheets"][0] # same, but no 'worksheets' key @@ -79,13 +79,12 @@ def test_process_notebook_with_worksheets(write_notebook: WriteNotebookFunc) -> def test_process_notebook_multiple_worksheets(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook containing multiple 'worksheets'. + """Test a notebook containing multiple ``worksheets``. Given a notebook with two worksheets: - First with a markdown cell - Second with a code cell - When `process_notebook` is called, + When ``process_notebook`` is called, Then a warning about multiple worksheets should be raised, and the second worksheet's content should appear in the final output. """ @@ -93,13 +92,13 @@ def test_process_notebook_multiple_worksheets(write_notebook: WriteNotebookFunc) "worksheets": [ {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, {"cells": [{"cell_type": "code", "source": ["# Second Worksheet"]}]}, - ] + ], } single_worksheet = { "worksheets": [ {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, - ] + ], } nb_multi = write_notebook("multiple_worksheets.ipynb", multi_worksheets) @@ -107,16 +106,18 @@ def test_process_notebook_multiple_worksheets(write_notebook: WriteNotebookFunc) # Expect DeprecationWarning + UserWarning with pytest.warns( - DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." + DeprecationWarning, + match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook.", + ), pytest.warns( + UserWarning, + match="Multiple worksheets detected. Combining all worksheets into a single script.", ): - with pytest.warns( - UserWarning, match="Multiple worksheets detected. Combining all worksheets into a single script." - ): - result_multi = process_notebook(nb_multi) + result_multi = process_notebook(nb_multi) # Expect DeprecationWarning only with pytest.warns( - DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." + DeprecationWarning, + match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook.", ): result_single = process_notebook(nb_single) @@ -129,18 +130,17 @@ def test_process_notebook_multiple_worksheets(write_notebook: WriteNotebookFunc) def test_process_notebook_code_only(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook containing only code cells. + """Test a notebook containing only code cells. Given a notebook with code cells only: - When `process_notebook` is called, + When ``process_notebook`` is called, Then no triple quotes should appear in the output. """ notebook_content = { "cells": [ {"cell_type": "code", "source": ["print('Code Cell 1')"]}, {"cell_type": "code", "source": ["x = 42"]}, - ] + ], } nb_path = write_notebook("code_only.ipynb", notebook_content) result = process_notebook(nb_path) @@ -151,85 +151,84 @@ def test_process_notebook_code_only(write_notebook: WriteNotebookFunc) -> None: def test_process_notebook_markdown_only(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook with only markdown cells. + """Test a notebook with only markdown cells. Given a notebook with two markdown cells: - When `process_notebook` is called, + When ``process_notebook`` is called, Then each markdown cell should become a triple-quoted block (2 blocks => 4 triple quotes total). """ + expected_count = 4 notebook_content = { "cells": [ {"cell_type": "markdown", "source": ["# Markdown Header"]}, {"cell_type": "markdown", "source": ["Some more markdown."]}, - ] + ], } nb_path = write_notebook("markdown_only.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 4, "Two markdown cells => 2 blocks => 4 triple quotes total." + assert result.count('"""') == expected_count, "Two markdown cells => 2 blocks => 4 triple quotes total." assert "# Markdown Header" in result assert "Some more markdown." in result def test_process_notebook_raw_only(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook with only raw cells. + """Test a notebook with only raw cells. Given two raw cells: - When `process_notebook` is called, + When ``process_notebook`` is called, Then each raw cell should become a triple-quoted block (2 blocks => 4 triple quotes total). """ + expected_count = 4 notebook_content = { "cells": [ {"cell_type": "raw", "source": ["Raw content line 1"]}, {"cell_type": "raw", "source": ["Raw content line 2"]}, - ] + ], } nb_path = write_notebook("raw_only.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 4, "Two raw cells => 2 blocks => 4 triple quotes." + assert result.count('"""') == expected_count, "Two raw cells => 2 blocks => 4 triple quotes." assert "Raw content line 1" in result assert "Raw content line 2" in result def test_process_notebook_empty_cells(write_notebook: WriteNotebookFunc) -> None: - """ - Test that cells with an empty 'source' are skipped. + """Test that cells with an empty ``source`` are skipped. - Given a notebook with 4 cells, 3 of which have empty `source`: - When `process_notebook` is called, + Given a notebook with 4 cells, 3 of which have empty ``source``: + When ``process_notebook`` is called, Then only the non-empty cell should appear in the output (1 block => 2 triple quotes). """ + expected_count = 2 notebook_content = { "cells": [ {"cell_type": "markdown", "source": []}, {"cell_type": "code", "source": []}, {"cell_type": "raw", "source": []}, {"cell_type": "markdown", "source": ["# Non-empty markdown"]}, - ] + ], } nb_path = write_notebook("empty_cells.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 2, "Only one non-empty cell => 1 block => 2 triple quotes" + assert result.count('"""') == expected_count, "Only one non-empty cell => 1 block => 2 triple quotes" assert "# Non-empty markdown" in result def test_process_notebook_invalid_cell_type(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook with an unknown cell type. + """Test a notebook with an unknown cell type. - Given a notebook cell whose `cell_type` is unrecognized: - When `process_notebook` is called, + Given a notebook cell whose ``cell_type`` is unrecognized: + When ``process_notebook`` is called, Then a ValueError should be raised. """ notebook_content = { "cells": [ {"cell_type": "markdown", "source": ["# Valid markdown"]}, {"cell_type": "unknown", "source": ["Unrecognized cell type"]}, - ] + ], } nb_path = write_notebook("invalid_cell_type.ipynb", notebook_content) @@ -238,11 +237,10 @@ def test_process_notebook_invalid_cell_type(write_notebook: WriteNotebookFunc) - def test_process_notebook_with_output(write_notebook: WriteNotebookFunc) -> None: - """ - Test a notebook that has code cells with outputs. + """Test a notebook that has code cells with outputs. Given a code cell and multiple output objects: - When `process_notebook` is called with `include_output=True`, + When ``process_notebook`` is called with ``include_output=True``, Then the outputs should be appended as commented lines under the code. """ notebook_content = { @@ -261,33 +259,25 @@ def test_process_notebook_with_output(write_notebook: WriteNotebookFunc) -> None {"output_type": "execute_result", "data": {"text/plain": ["[1, 2, 3, 4, 5]"]}}, {"output_type": "display_data", "data": {"text/plain": ["
"]}}, ], - } - ] + }, + ], } nb_path = write_notebook("with_output.ipynb", notebook_content) with_output = process_notebook(nb_path, include_output=True) without_output = process_notebook(nb_path, include_output=False) - expected_source = "\n".join( - [ - "# Jupyter notebook converted to Python script.\n", - "import matplotlib.pyplot as plt", - "print('my_data')", - "my_data = [1, 2, 3, 4, 5]", - "plt.plot(my_data)", - "my_data\n", - ] - ) - expected_output = "\n".join( - [ - "# Output:", - "# my_data", - "# [1, 2, 3, 4, 5]", - "#
\n", - ] + expected_source = ( + "# Jupyter notebook converted to Python script.\n\n" + "import matplotlib.pyplot as plt\n" + "print('my_data')\n" + "my_data = [1, 2, 3, 4, 5]\n" + "plt.plot(my_data)\n" + "my_data\n" ) + expected_output = "# Output:\n# my_data\n# [1, 2, 3, 4, 5]\n#
\n" + expected_combined = expected_source + expected_output assert with_output == expected_combined, "Should include source code and comment-ified output." From 82acb0a3b21af3e2afc6ea6ad42da1aa4c8337f2 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:28:47 +0200 Subject: [PATCH 2/5] standardize docstrings --- src/gitingest/cli.py | 22 +++++------ src/gitingest/clone.py | 4 +- src/gitingest/config.py | 4 +- src/gitingest/entrypoint.py | 46 ++++++++++++----------- src/gitingest/output_formatter.py | 2 +- src/gitingest/query_parser.py | 18 ++++----- src/gitingest/schemas/ingestion.py | 2 +- src/gitingest/utils/async_compat.py | 4 +- src/gitingest/utils/exceptions.py | 2 +- src/gitingest/utils/file_utils.py | 6 +-- src/gitingest/utils/git_utils.py | 12 +++--- src/gitingest/utils/ignore_patterns.py | 6 +-- src/gitingest/utils/ingestion_utils.py | 2 +- src/gitingest/utils/path_utils.py | 4 +- src/gitingest/utils/query_parser_utils.py | 24 ++++++------ src/gitingest/utils/timeout_wrapper.py | 4 +- src/server/server_config.py | 7 +++- src/server/server_utils.py | 8 ++-- src/static/js/utils.js | 8 ++-- tests/conftest.py | 4 +- tests/query_parser/test_query_parser.py | 6 +-- tests/test_clone.py | 6 +-- tests/test_git_utils.py | 4 +- tests/test_gitignore_feature.py | 4 +- tests/test_ingestion.py | 2 +- tests/test_notebook_utils.py | 2 +- 26 files changed, 107 insertions(+), 106 deletions(-) diff --git a/src/gitingest/cli.py b/src/gitingest/cli.py index 7fffb6ae..949d9016 100644 --- a/src/gitingest/cli.py +++ b/src/gitingest/cli.py @@ -53,7 +53,7 @@ class _CLIArgs(TypedDict): envvar="GITHUB_TOKEN", default=None, help=( - "GitHub personal access token for accessing private repositories. " + "GitHub personal access token (PAT) for accessing private repositories. " "If omitted, the CLI will look for the GITHUB_TOKEN environment variable." ), ) @@ -81,9 +81,9 @@ async def _async_main( ) -> None: """Analyze a directory or repository and create a text dump of its contents. - This command analyzes the contents of a specified source directory or repository, applies custom include and - exclude patterns, and generates a text summary of the analysis which is then written to an output file - or printed to stdout. + This command scans the specified ``source`` (a local directory or Git repo), + applies custom include and exclude patterns, and generates a text summary of + the analysis. The summary is written to an output file or printed to ``stdout``. Parameters ---------- @@ -96,20 +96,20 @@ async def _async_main( include_pattern : tuple[str, ...] | None Glob patterns for including files in the output. branch : str | None - Git branch to ingest. If *None*, the repository's default branch is used. + Git branch to ingest. If ``None``, the repository's default branch is used. include_gitignored : bool - If *True*, also ingest files matched by ``.gitignore`` or ``.gitingestignore`` (default: ``False``). + If ``True``, also ingest files matched by ``.gitignore`` or ``.gitingestignore`` (default: ``False``). token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the ``GITHUB_TOKEN`` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. output : str | None - Destination file path. If *None*, the output is written to ``.txt`` in the current directory. - Use ``"-"`` to write to *stdout*. + Destination file path. If ``None``, the output is written to ``.txt`` in the current directory. + Use ``"-"`` to write to ``stdout``. Raises ------ click.Abort - If there is an error during the execution of the command, this exception is raised to abort the process. + Raised if an error occurs during execution and the command must be aborted. """ try: diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py index ab14d261..0237e5fc 100644 --- a/src/gitingest/clone.py +++ b/src/gitingest/clone.py @@ -35,8 +35,8 @@ async def clone_repo(config: CloneConfig, token: str | None = None) -> None: config : CloneConfig The configuration for cloning the repository. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the `GITHUB_TOKEN` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Raises ------ diff --git a/src/gitingest/config.py b/src/gitingest/config.py index 3f4e3724..3d154684 100644 --- a/src/gitingest/config.py +++ b/src/gitingest/config.py @@ -3,10 +3,10 @@ import tempfile from pathlib import Path -MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB +MAX_FILE_SIZE = 10 * 1024 * 1024 # Maximum size of a single file to process (10 MB) MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal MAX_FILES = 10_000 # Maximum number of files to process -MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB +MAX_TOTAL_SIZE_BYTES = 500 * 1024 * 1024 # Maximum size of output file (500 MB) DEFAULT_TIMEOUT = 60 # seconds OUTPUT_FILE_NAME = "digest.txt" diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index cbc041a2..37205f8c 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -9,7 +9,7 @@ from pathlib import Path from gitingest.clone import clone_repo -from gitingest.config import TMP_BASE_PATH +from gitingest.config import MAX_FILE_SIZE, TMP_BASE_PATH from gitingest.ingestion import ingest_query from gitingest.query_parser import IngestionQuery, parse_query from gitingest.utils.async_compat import to_thread @@ -20,7 +20,7 @@ async def ingest_async( source: str, *, - max_file_size: int = 10 * 1024 * 1024, # 10 MB + max_file_size: int = MAX_FILE_SIZE, # 10 MB include_patterns: str | set[str] | None = None, exclude_patterns: str | set[str] | None = None, branch: str | None = None, @@ -41,19 +41,20 @@ async def ingest_async( max_file_size : int Maximum allowed file size for file ingestion. Files larger than this size are ignored (default: 10 MB). include_patterns : str | set[str] | None - Pattern or set of patterns specifying which files to include. If *None*, all files are included. + Pattern or set of patterns specifying which files to include. If ``None``, all files are included. exclude_patterns : str | set[str] | None - Pattern or set of patterns specifying which files to exclude. If *None*, no files are excluded. + Pattern or set of patterns specifying which files to exclude. If ``None``, no files are excluded. branch : str | None - The branch to clone and ingest. If *None*, the default branch is used. + The branch to clone and ingest (default: the default branch). include_gitignored : bool - If *True*, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). + If ``True``, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the ``GITHUB_TOKEN`` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. output : str | None - File path where the summary and content should be written. If *"-"* (dash), the results are written to stdout. - If *None*, the results are not written to a file. + File path where the summary and content should be written. + If ``"-"`` (dash), the results are written to ``stdout``. + If ``None``, the results are not written to a file. Returns ------- @@ -99,7 +100,7 @@ async def ingest_async( def ingest( source: str, *, - max_file_size: int = 10 * 1024 * 1024, # 10 MB + max_file_size: int = MAX_FILE_SIZE, include_patterns: str | set[str] | None = None, exclude_patterns: str | set[str] | None = None, branch: str | None = None, @@ -120,19 +121,20 @@ def ingest( max_file_size : int Maximum allowed file size for file ingestion. Files larger than this size are ignored (default: 10 MB). include_patterns : str | set[str] | None - Pattern or set of patterns specifying which files to include. If *None*, all files are included. + Pattern or set of patterns specifying which files to include. If ``None``, all files are included. exclude_patterns : str | set[str] | None - Pattern or set of patterns specifying which files to exclude. If *None*, no files are excluded. + Pattern or set of patterns specifying which files to exclude. If ``None``, no files are excluded. branch : str | None The branch to clone and ingest (default: the default branch). include_gitignored : bool - If *True*, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). + If ``True``, include files ignored by ``.gitignore`` and ``.gitingestignore`` (default: ``False``). token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the ``GITHUB_TOKEN`` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. output : str | None - File path where the summary and content should be written. If *"-"* (dash), the results are written to stdout. - If *None*, the results are not written to a file. + File path where the summary and content should be written. + If ``"-"`` (dash), the results are written to ``stdout``. + If ``None``, the results are not written to a file. Returns ------- @@ -162,7 +164,7 @@ def ingest( def _apply_gitignores(query: IngestionQuery) -> None: - """Update `query.ignore_patterns` in-place. + """Update ``query.ignore_patterns`` in-place. Parameters ---------- @@ -187,7 +189,7 @@ async def _clone_if_remote(query: IngestionQuery, token: str | None) -> None: Raises ------ TypeError - If `clone_repo` does not return a coroutine. + If ``clone_repo`` does not return a coroutine. """ if not query.url: # local path ingestion @@ -207,7 +209,7 @@ async def _clone_if_remote(query: IngestionQuery, token: str | None) -> None: async def _write_output(tree: str, content: str, target: str | None) -> None: - """Write combined output to *target* (`'-'` ⇒ stdout). + """Write combined output to ``target`` (``"-"`` ⇒ stdout). Parameters ---------- @@ -216,7 +218,7 @@ async def _write_output(tree: str, content: str, target: str | None) -> None: content : str The content of the files in the repository or directory. target : str | None - The path to the output file. If *None*, the results are not written to a file. + The path to the output file. If ``None``, the results are not written to a file. """ if target == "-": diff --git a/src/gitingest/output_formatter.py b/src/gitingest/output_formatter.py index 4714fccd..034380ce 100644 --- a/src/gitingest/output_formatter.py +++ b/src/gitingest/output_formatter.py @@ -179,7 +179,7 @@ def _format_token_count(text: str) -> str | None: Returns ------- str | None - The formatted number of tokens as a string (e.g., ``1.2k``, ``1.2M``), or ``None`` if an error occurs. + The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or ``None`` if an error occurs. """ try: diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 0078043b..fb50817f 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -48,8 +48,8 @@ async def parse_query( ignore_patterns : set[str] | str | None Patterns to ignore. Can be a set of strings or a single string. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the `GITHUB_TOKEN` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- @@ -99,17 +99,17 @@ async def _parse_remote_repo(source: str, token: str | None = None) -> Ingestion """Parse a repository URL into a structured query dictionary. If source is: - - A fully qualified URL (`https://gitlab.com/...`), parse & verify that domain - - A URL missing `https://` (`gitlab.com/...`), add `https://` and parse - - A `slug` (`pandas-dev/pandas`), attempt known domains until we find one that exists. + - A fully qualified URL ('https://gitlab.com/...'), parse & verify that domain + - A URL missing 'https://' ('gitlab.com/...'), add 'https://' and parse + - A *slug* ('pandas-dev/pandas'), attempt known domains until we find one that exists. Parameters ---------- source : str The URL or domain-less slug to parse. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the ``GITHUB_TOKEN`` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- @@ -300,8 +300,8 @@ async def try_domains_for_user_and_repo(user_name: str, repo_name: str, token: s repo_name : str The name of the repository. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the ``GITHUB_TOKEN`` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- diff --git a/src/gitingest/schemas/ingestion.py b/src/gitingest/schemas/ingestion.py index efa21319..03c6083b 100644 --- a/src/gitingest/schemas/ingestion.py +++ b/src/gitingest/schemas/ingestion.py @@ -101,7 +101,7 @@ def extract_clone_config(self) -> CloneConfig: Raises ------ ValueError - If the `url` parameter is not provided. + If the ``url`` parameter is not provided. """ if not self.url: diff --git a/src/gitingest/utils/async_compat.py b/src/gitingest/utils/async_compat.py index 18123799..d74ef664 100644 --- a/src/gitingest/utils/async_compat.py +++ b/src/gitingest/utils/async_compat.py @@ -21,14 +21,14 @@ async def to_thread(func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R: """Back-port :func:`asyncio.to_thread` for Python < 3.9. - Run `func` in the default thread-pool executor and return the result. + Run ``func`` in the default thread-pool executor and return the result. """ loop = asyncio.get_running_loop() ctx = contextvars.copy_context() func_call = functools.partial(ctx.run, func, *args, **kwargs) return await loop.run_in_executor(None, func_call) - # Patch stdlib so that *existing* imports of ``asyncio`` see the shim. + # Patch stdlib so that *existing* imports of asyncio see the shim. if not hasattr(asyncio, "to_thread"): asyncio.to_thread = to_thread diff --git a/src/gitingest/utils/exceptions.py b/src/gitingest/utils/exceptions.py index dac985db..6ae34ceb 100644 --- a/src/gitingest/utils/exceptions.py +++ b/src/gitingest/utils/exceptions.py @@ -26,7 +26,7 @@ def __init__(self, pattern: str) -> None: class AsyncTimeoutError(Exception): """Exception raised when an async operation exceeds its timeout limit. - This exception is used by the `async_timeout` decorator to signal that the wrapped + This exception is used by the ``async_timeout`` decorator to signal that the wrapped asynchronous function has exceeded the specified time limit for execution. """ diff --git a/src/gitingest/utils/file_utils.py b/src/gitingest/utils/file_utils.py index ac2b8940..2c6ef74d 100644 --- a/src/gitingest/utils/file_utils.py +++ b/src/gitingest/utils/file_utils.py @@ -44,7 +44,7 @@ def _read_chunk(path: Path) -> bytes | None: Returns ------- bytes | None - The first *_CHUNK_SIZE* bytes of *path*, or None on any `OSError`. + The first ``_CHUNK_SIZE`` bytes of ``path``, or ``None`` on any ``OSError``. """ try: @@ -55,7 +55,7 @@ def _read_chunk(path: Path) -> bytes | None: def _decodes(chunk: bytes, encoding: str) -> bool: - """Return *True* if *chunk* decodes cleanly with *encoding*. + """Return ``True`` if ``chunk`` decodes cleanly with ``encoding``. Parameters ---------- @@ -67,7 +67,7 @@ def _decodes(chunk: bytes, encoding: str) -> bool: Returns ------- bool - True if the chunk decodes cleanly with the encoding, False otherwise. + ``True`` if the chunk decodes cleanly with the encoding, ``False`` otherwise. """ try: diff --git a/src/gitingest/utils/git_utils.py b/src/gitingest/utils/git_utils.py index 53f324e4..104d71a4 100644 --- a/src/gitingest/utils/git_utils.py +++ b/src/gitingest/utils/git_utils.py @@ -88,8 +88,8 @@ async def check_repo_exists(url: str, token: str | None = None) -> bool: url : str The URL of the Git repository to check. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the `GITHUB_TOKEN` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- @@ -139,8 +139,8 @@ async def _check_github_repo_exists(url: str, token: str | None = None) -> bool: url : str The URL of the GitHub repository to check. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the `GITHUB_TOKEN` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- @@ -240,8 +240,8 @@ async def fetch_remote_branch_list(url: str, token: str | None = None) -> list[s url : str The URL of the Git repository to fetch branches from. token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. - Can also be set via the `GITHUB_TOKEN` env var. + GitHub personal access token (PAT) for accessing private repositories. + Can also be set via the ``GITHUB_TOKEN`` environment variable. Returns ------- diff --git a/src/gitingest/utils/ignore_patterns.py b/src/gitingest/utils/ignore_patterns.py index 079ce758..9339b512 100644 --- a/src/gitingest/utils/ignore_patterns.py +++ b/src/gitingest/utils/ignore_patterns.py @@ -169,7 +169,7 @@ def load_ignore_patterns(root: Path, filename: str) -> set[str]: """Load ignore patterns from ``filename`` found under ``root``. The loader walks the directory tree, looks for the supplied ``filename``, - and returns a unified ``set`` of patterns. It implements the same parsing rules + and returns a unified set of patterns. It implements the same parsing rules we use for ``.gitignore`` and ``.gitingestignore`` (git-wildmatch syntax with support for negation and root-relative paths). @@ -223,12 +223,12 @@ def _parse_ignore_file(ignore_file: Path, root: Path) -> set[str]: if not line or line.startswith("#"): # comments / blank lines continue - # Handle negation (`!foobar`) + # Handle negation ("!foobar") negated = line.startswith("!") if negated: line = line[1:] - # Handle leading slash (`/foobar`) + # Handle leading slash ("/foobar") if line.startswith("/"): line = line.lstrip("/") diff --git a/src/gitingest/utils/ingestion_utils.py b/src/gitingest/utils/ingestion_utils.py index c4ba24ae..21a03f22 100644 --- a/src/gitingest/utils/ingestion_utils.py +++ b/src/gitingest/utils/ingestion_utils.py @@ -79,7 +79,7 @@ def _relative_or_none(path: Path, base: Path) -> Path | None: Returns ------- Path | None - The relative path of *path* to *base*, or ``None`` if *path* is outside *base*. + The relative path of ``path`` to ``base``, or ``None`` if ``path`` is outside ``base``. """ try: diff --git a/src/gitingest/utils/path_utils.py b/src/gitingest/utils/path_utils.py index fe387da4..55ed3d48 100644 --- a/src/gitingest/utils/path_utils.py +++ b/src/gitingest/utils/path_utils.py @@ -5,7 +5,7 @@ def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool: - """Return ``True`` if *symlink_path* resolves inside *base_path*. + """Return ``True`` if ``symlink_path`` resolves inside ``base_path``. Parameters ---------- @@ -17,7 +17,7 @@ def _is_safe_symlink(symlink_path: Path, base_path: Path) -> bool: Returns ------- bool - Whether the symlink is “safe” (i.e., does not escape *base_path*). + Whether the symlink is “safe” (i.e., does not escape ``base_path``). """ # On Windows a non-symlink is immediately unsafe diff --git a/src/gitingest/utils/query_parser_utils.py b/src/gitingest/utils/query_parser_utils.py index 7b57ba40..e6e98e18 100644 --- a/src/gitingest/utils/query_parser_utils.py +++ b/src/gitingest/utils/query_parser_utils.py @@ -32,7 +32,7 @@ def _is_valid_git_commit_hash(commit: str) -> bool: Returns ------- bool - `True` if the string is a valid 40-character Git commit hash, otherwise `False`. + ``True`` if the string is a valid 40-character Git commit hash, otherwise ``False``. """ sha_hex_length = 40 @@ -43,8 +43,8 @@ def _is_valid_pattern(pattern: str) -> bool: """Validate if the given pattern contains only valid characters. This function checks if the pattern contains only alphanumeric characters or one - of the following allowed characters: dash (`-`), underscore (`_`), dot (`.`), - forward slash (`/`), plus (`+`), asterisk (`*`), or the at sign (`@`). + of the following allowed characters: dash ('-'), underscore ('_'), dot ('.'), + forward slash ('/'), plus ('+'), asterisk ('*'), or the at sign ('@'). Parameters ---------- @@ -54,7 +54,7 @@ def _is_valid_pattern(pattern: str) -> bool: Returns ------- bool - `True` if the pattern is valid, otherwise `False`. + ``True`` if the pattern is valid, otherwise ``False``. """ return all(c.isalnum() or c in "-_./+*@" for c in pattern) @@ -63,9 +63,9 @@ def _is_valid_pattern(pattern: str) -> bool: def _validate_host(host: str) -> None: """Validate a hostname. - The host is accepted if it is either present in the hard-coded `KNOWN_GIT_HOSTS` list or if it satisfies the - simple heuristics in `_looks_like_git_host`, which try to recognise common self-hosted Git services (e.g. GitLab - instances on sub-domains such as `gitlab.example.com` or `git.example.com`). + The host is accepted if it is either present in the hard-coded ``KNOWN_GIT_HOSTS`` list or if it satisfies the + simple heuristics in ``_looks_like_git_host``, which try to recognise common self-hosted Git services (e.g. GitLab + instances on sub-domains such as 'gitlab.example.com' or 'git.example.com'). Parameters ---------- @@ -87,8 +87,8 @@ def _validate_host(host: str) -> None: def _looks_like_git_host(host: str) -> bool: """Check if the given host looks like a Git host. - The current heuristic returns `True` when the host starts with `git.` (e.g. `git.example.com`), starts with - `gitlab.` (e.g. `gitlab.company.com`), or starts with `github.` (e.g. `github.company.com` for GitHub Enterprise). + The current heuristic returns ``True`` when the host starts with ``git.`` (e.g. 'git.example.com'), starts with + 'gitlab.' (e.g. 'gitlab.company.com'), or starts with 'github.' (e.g. 'github.company.com' for GitHub Enterprise). Parameters ---------- @@ -98,7 +98,7 @@ def _looks_like_git_host(host: str) -> bool: Returns ------- bool - `True` if the host looks like a Git host, otherwise `False`. + ``True`` if the host looks like a Git host, otherwise ``False``. """ host = host.lower() @@ -116,7 +116,7 @@ def _validate_url_scheme(scheme: str) -> None: Raises ------ ValueError - If the scheme is not `http` or `https`. + If the scheme is not 'http' or 'https'. """ scheme = scheme.lower() @@ -156,7 +156,7 @@ def _normalize_pattern(pattern: str) -> str: """Normalize the given pattern by removing leading separators and appending a wildcard. This function processes the pattern string by stripping leading directory separators - and appending a wildcard (`*`) if the pattern ends with a separator. + and appending a wildcard (``*``) if the pattern ends with a separator. Parameters ---------- diff --git a/src/gitingest/utils/timeout_wrapper.py b/src/gitingest/utils/timeout_wrapper.py index f811e29c..a5268d15 100644 --- a/src/gitingest/utils/timeout_wrapper.py +++ b/src/gitingest/utils/timeout_wrapper.py @@ -16,7 +16,7 @@ def async_timeout(seconds: int) -> Callable[[Callable[P, Awaitable[T]]], Callabl This decorator wraps an asynchronous function and ensures it does not run for longer than the specified number of seconds. If the function execution exceeds - this limit, it raises an `AsyncTimeoutError`. + this limit, it raises an ``AsyncTimeoutError``. Parameters ---------- @@ -28,7 +28,7 @@ def async_timeout(seconds: int) -> Callable[[Callable[P, Awaitable[T]]], Callabl Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]] A decorator that, when applied to an async function, ensures the function completes within the specified time limit. If the function takes too long, - an `AsyncTimeoutError` is raised. + an ``AsyncTimeoutError`` is raised. """ diff --git a/src/server/server_config.py b/src/server/server_config.py index 29b02a1a..9a74b640 100644 --- a/src/server/server_config.py +++ b/src/server/server_config.py @@ -7,8 +7,11 @@ MAX_DISPLAY_SIZE: int = 300_000 DELETE_REPO_AFTER: int = 60 * 60 # In seconds (1 hour) -# Default maximum file size to include in the digest -DEFAULT_FILE_SIZE_KB: int = 50 +DEFAULT_FILE_SIZE_KB: int = 50 # Default maximum file size to include in the digest + +# Slider configuration (if updated, update the logSliderToSize function in src/static/js/utils.js) +MAX_FILE_SIZE_KB: int = 100 * 1024 # 100 MB +MAX_SLIDER_POSITION: int = 500 # Maximum slider position EXAMPLE_REPOS: list[dict[str, str]] = [ {"name": "Gitingest", "url": "https://github.com/cyclotruc/gitingest"}, diff --git a/src/server/server_utils.py b/src/server/server_utils.py index 3f1010d1..8d145b36 100644 --- a/src/server/server_utils.py +++ b/src/server/server_utils.py @@ -16,7 +16,7 @@ from gitingest.config import TMP_BASE_PATH from gitingest.utils.async_compat import to_thread -from server.server_config import DELETE_REPO_AFTER +from server.server_config import DELETE_REPO_AFTER, MAX_FILE_SIZE_KB, MAX_SLIDER_POSITION # Initialize a rate limiter limiter = Limiter(key_func=get_remote_address) @@ -177,10 +177,8 @@ def log_slider_to_size(position: int) -> int: File size in bytes corresponding to the slider position. """ - maxp = 500 - minv = math.log(1) - maxv = math.log(102_400) - return round(math.exp(minv + (maxv - minv) * pow(position / maxp, 1.5))) * 1024 + maxv = math.log(MAX_FILE_SIZE_KB) + return round(math.exp(maxv * pow(position / MAX_SLIDER_POSITION, 1.5))) * 1024 ## Color printing utility diff --git a/src/static/js/utils.js b/src/static/js/utils.js index eba042eb..538ac5dc 100644 --- a/src/static/js/utils.js +++ b/src/static/js/utils.js @@ -150,12 +150,10 @@ function copyFullDigest() { // Add the logSliderToSize helper function function logSliderToSize(position) { - const minp = 0; - const maxp = 500; - const minv = Math.log(1); - const maxv = Math.log(102400); + const maxPosition = 500; + const maxValue = Math.log(102400); // 100 MB - const value = Math.exp(minv + (maxv - minv) * Math.pow(position / maxp, 1.5)); + const value = Math.exp(maxValue * Math.pow(position / maxPosition, 1.5)); return Math.round(value); } diff --git a/tests/conftest.py b/tests/conftest.py index 3710ab4e..be22720e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,7 +160,7 @@ def _factory(branches: list[str]) -> None: @pytest.fixture def repo_exists_true(mocker: MockerFixture) -> AsyncMock: - """Patch ``gitingest.clone.check_repo_exists`` to always return *True*.""" + """Patch ``gitingest.clone.check_repo_exists`` to always return ``True``.""" return mocker.patch("gitingest.clone.check_repo_exists", return_value=True) @@ -169,7 +169,7 @@ def run_command_mock(mocker: MockerFixture) -> AsyncMock: """Patch ``gitingest.clone.run_command`` with an ``AsyncMock``. The mocked function returns a dummy process whose ``communicate`` method yields generic - *stdout* / *stderr* bytes. Tests can still access / tweak the mock via the fixture argument. + ``stdout`` / ``stderr`` bytes. Tests can still access / tweak the mock via the fixture argument. """ mock_exec = mocker.patch("gitingest.clone.run_command", new_callable=AsyncMock) diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index 5c9c4e44..b615b093 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -1,4 +1,4 @@ -"""Tests for the `query_parser` module. +"""Tests for the ``query_parser`` module. These tests cover URL parsing, pattern parsing, and handling of branches/subpaths for HTTP(S) repositories and local paths. @@ -206,7 +206,7 @@ async def test_parse_query_empty_patterns() -> None: Given empty include_patterns and ignore_patterns: When ``parse_query`` is called, - Then ``include_patterns`` becomes *None* and default ignore patterns apply. + Then ``include_patterns`` becomes ``None`` and default ignore patterns apply. """ query = await parse_query(DEMO_URL, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") @@ -301,7 +301,7 @@ async def test_parse_url_branch_and_commit_distinction( Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL: When ``_parse_remote_repo`` is called with branch fetching, - Then the function should correctly set `branch` or `commit` based on the URL content. + Then the function should correctly set ``branch`` or ``commit`` based on the URL content. """ stub_branches(["main", "dev", "feature-branch"]) diff --git a/tests/test_clone.py b/tests/test_clone.py index d560d684..c76088e7 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -1,4 +1,4 @@ -"""Tests for the `clone` module. +"""Tests for the ``clone`` module. These tests cover various scenarios for cloning repositories, verifying that the appropriate Git commands are invoked and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches. @@ -206,7 +206,7 @@ async def test_check_repo_exists_with_redirect(mocker: MockerFixture) -> None: Given a URL that responds with "302 Found": When ``check_repo_exists`` is called, - Then it should return `False`, indicating the repo is inaccessible. + Then it should return ``False``, indicating the repo is inaccessible. """ mock_exec = mocker.patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) mock_process = AsyncMock() @@ -225,7 +225,7 @@ async def test_check_repo_exists_with_permanent_redirect(mocker: MockerFixture) Given a URL that responds with "301 Found": When ``check_repo_exists`` is called, - Then it should return *True*, indicating the repo may exist at the new location. + Then it should return ``True``, indicating the repo may exist at the new location. """ mock_exec = mocker.patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) mock_process = AsyncMock() diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index 8fbb961e..1410d355 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -1,6 +1,6 @@ -"""Tests for the `git_utils` module. +"""Tests for the ``git_utils`` module. -These tests validate the `validate_github_token` function, which ensures that +These tests validate the ``validate_github_token`` function, which ensures that GitHub personal access tokens (PATs) are properly formatted. """ diff --git a/tests/test_gitignore_feature.py b/tests/test_gitignore_feature.py index e588abc9..68ab821a 100644 --- a/tests/test_gitignore_feature.py +++ b/tests/test_gitignore_feature.py @@ -52,8 +52,8 @@ def test_load_gitignore_patterns(tmp_path: Path) -> None: async def test_ingest_with_gitignore(repo_path: Path) -> None: """Integration test for ``ingest_async()`` respecting ``.gitignore`` rules. - When ``include_gitignored`` is *False* (default), the content of ``exclude.txt`` should be omitted. - When ``include_gitignored`` is *True*, both files should be present. + When ``include_gitignored`` is ``False`` (default), the content of ``exclude.txt`` should be omitted. + When ``include_gitignored`` is ``True``, both files should be present. """ # Run ingestion with the gitignore functionality enabled. _, _, content_with_ignore = await ingest_async(source=str(repo_path)) diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index 4ffc8828..f3585e05 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -1,4 +1,4 @@ -"""Tests for the `query_ingestion` module. +"""Tests for the ``query_ingestion`` module. These tests validate directory scanning, file content extraction, notebook handling, and the overall ingestion logic, including filtering patterns and subpaths. diff --git a/tests/test_notebook_utils.py b/tests/test_notebook_utils.py index 927f410a..120b374f 100644 --- a/tests/test_notebook_utils.py +++ b/tests/test_notebook_utils.py @@ -1,4 +1,4 @@ -"""Tests for the `notebook` utils module. +"""Tests for the ``notebook`` utils module. These tests validate how notebooks are processed into Python-like output, ensuring that markdown/raw cells are converted to triple-quoted blocks, code cells remain executable code, and various edge cases (multiple worksheets, From 979df2cd2bca9eeac6fabf1b41dc254b84746964 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:59:58 +0200 Subject: [PATCH 3/5] remove to_thread in favour of loop.run_in_executor --- src/gitingest/entrypoint.py | 8 +++---- src/gitingest/utils/async_compat.py | 35 ----------------------------- src/server/server_utils.py | 18 +++++++-------- tests/test_clone.py | 16 ++++++++----- 4 files changed, 21 insertions(+), 56 deletions(-) delete mode 100644 src/gitingest/utils/async_compat.py diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index 37205f8c..91406aac 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -12,7 +12,6 @@ from gitingest.config import MAX_FILE_SIZE, TMP_BASE_PATH from gitingest.ingestion import ingest_query from gitingest.query_parser import IngestionQuery, parse_query -from gitingest.utils.async_compat import to_thread from gitingest.utils.auth import resolve_token from gitingest.utils.ignore_patterns import load_ignore_patterns @@ -221,11 +220,10 @@ async def _write_output(tree: str, content: str, target: str | None) -> None: The path to the output file. If ``None``, the results are not written to a file. """ + data = f"{tree}\n{content}" + loop = asyncio.get_running_loop() if target == "-": - loop = asyncio.get_running_loop() - data = f"{tree}\n{content}" await loop.run_in_executor(None, sys.stdout.write, data) await loop.run_in_executor(None, sys.stdout.flush) elif target is not None: - path = Path(target) - await to_thread(path.write_text, f"{tree}\n{content}", encoding="utf-8") + await loop.run_in_executor(None, Path(target).write_text, data, "utf-8") diff --git a/src/gitingest/utils/async_compat.py b/src/gitingest/utils/async_compat.py deleted file mode 100644 index d74ef664..00000000 --- a/src/gitingest/utils/async_compat.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Compatibility utilities wrapping asyncio features that are missing on older Python versions (e.g. 3.8).""" - -from __future__ import annotations - -import asyncio -import contextvars -import functools -import sys -from typing import Callable, TypeVar - -from gitingest.utils.compat_typing import ParamSpec - -P = ParamSpec("P") -R = TypeVar("R") - - -if sys.version_info >= (3, 9): - from asyncio import to_thread -else: - - async def to_thread(func: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R: - """Back-port :func:`asyncio.to_thread` for Python < 3.9. - - Run ``func`` in the default thread-pool executor and return the result. - """ - loop = asyncio.get_running_loop() - ctx = contextvars.copy_context() - func_call = functools.partial(ctx.run, func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) - - # Patch stdlib so that *existing* imports of asyncio see the shim. - if not hasattr(asyncio, "to_thread"): - asyncio.to_thread = to_thread - -__all__ = ["to_thread"] diff --git a/src/server/server_utils.py b/src/server/server_utils.py index 8d145b36..d1e2eeea 100644 --- a/src/server/server_utils.py +++ b/src/server/server_utils.py @@ -15,7 +15,6 @@ from slowapi.util import get_remote_address from gitingest.config import TMP_BASE_PATH -from gitingest.utils.async_compat import to_thread from server.server_config import DELETE_REPO_AFTER, MAX_FILE_SIZE_KB, MAX_SLIDER_POSITION # Initialize a rate limiter @@ -112,7 +111,7 @@ async def _remove_old_repositories( async def _process_folder(folder: Path) -> None: - """Append the repo URL (if discoverable) to ``history.txt`` and delete *folder*. + """Append the repo URL (if discoverable) to ``history.txt`` and delete ``folder``. Parameters ---------- @@ -121,27 +120,26 @@ async def _process_folder(folder: Path) -> None: """ history_file = Path("history.txt") + loop = asyncio.get_running_loop() try: - first_txt_file: Path = next(folder.glob("*.txt")) + first_txt_file = next(folder.glob("*.txt")) except StopIteration: # No .txt file found return - # Try to log repository URL before deletion + # Append owner/repo to history.txt try: - # Extract owner and repository name from the filename - filename = first_txt_file.stem # e.g. "owner-repo" + filename = first_txt_file.stem # "owner-repo" if "-" in filename: owner, repo = filename.split("-", 1) repo_url = f"{owner}/{repo}" - await to_thread(_append_line, history_file, repo_url) - + await loop.run_in_executor(None, _append_line, history_file, repo_url) except (OSError, PermissionError) as exc: print(f"Error logging repository URL for {folder}: {exc}") - # Delete the folder + # Delete the cloned repo try: - await to_thread(shutil.rmtree, folder) + await loop.run_in_executor(None, shutil.rmtree, folder) except PermissionError as exc: print(f"No permission to delete {folder}: {exc}") except OSError as exc: diff --git a/tests/test_clone.py b/tests/test_clone.py index c76088e7..72284a35 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -14,7 +14,6 @@ from gitingest.clone import clone_repo from gitingest.schemas import CloneConfig -from gitingest.utils.async_compat import to_thread from gitingest.utils.exceptions import AsyncTimeoutError from gitingest.utils.git_utils import check_repo_exists from tests.conftest import DEMO_URL, LOCAL_REPO_PATH @@ -272,13 +271,18 @@ async def test_clone_specific_branch(tmp_path: Path) -> None: assert local_path.exists(), "The repository was not cloned successfully." assert local_path.is_dir(), "The cloned repository path is not a directory." + loop = asyncio.get_running_loop() current_branch = ( - await to_thread( - subprocess.check_output, - ["git", "-C", str(local_path), "branch", "--show-current"], - text=True, + ( + await loop.run_in_executor( + None, + subprocess.check_output, + ["git", "-C", str(local_path), "branch", "--show-current"], + ) ) - ).strip() + .decode() + .strip() + ) assert current_branch == branch_name, f"Expected branch '{branch_name}', got '{current_branch}'." From 94be0ac00f05a3f514788c2e0053e66724fa396c Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sat, 28 Jun 2025 18:14:47 +0200 Subject: [PATCH 4/5] standardize docstrings --- src/gitingest/clone.py | 2 +- src/gitingest/entrypoint.py | 4 ++-- src/gitingest/output_formatter.py | 2 +- src/gitingest/utils/auth.py | 2 +- src/gitingest/utils/git_utils.py | 20 ++++++++++---------- src/gitingest/utils/os_utils.py | 4 ++-- src/server/models.py | 4 ++-- src/server/query_processor.py | 2 +- src/server/server_utils.py | 2 +- tests/test_cli.py | 2 +- tests/test_clone.py | 2 +- 11 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/gitingest/clone.py b/src/gitingest/clone.py index 0237e5fc..a3c02c3e 100644 --- a/src/gitingest/clone.py +++ b/src/gitingest/clone.py @@ -65,7 +65,7 @@ async def clone_repo(config: CloneConfig, token: str | None = None) -> None: clone_cmd = ["git"] if token and is_github_host(url): - clone_cmd += ["-c", create_git_auth_header(token, url)] + clone_cmd += ["-c", create_git_auth_header(token, url=url)] clone_cmd += ["clone", "--single-branch"] # TODO: Re-enable --recurse-submodules when submodule support is needed diff --git a/src/gitingest/entrypoint.py b/src/gitingest/entrypoint.py index 91406aac..96f7aa0c 100644 --- a/src/gitingest/entrypoint.py +++ b/src/gitingest/entrypoint.py @@ -83,7 +83,7 @@ async def ingest_async( repo_cloned = False try: - await _clone_if_remote(query, token) + await _clone_if_remote(query, token=token) repo_cloned = bool(query.url) summary, tree, content = ingest_query(query) @@ -183,7 +183,7 @@ async def _clone_if_remote(query: IngestionQuery, token: str | None) -> None: query : IngestionQuery The query to clone. token : str | None - The token to use for the query. + GitHub personal access token (PAT) for accessing private repositories. Raises ------ diff --git a/src/gitingest/output_formatter.py b/src/gitingest/output_formatter.py index 034380ce..fb1338fb 100644 --- a/src/gitingest/output_formatter.py +++ b/src/gitingest/output_formatter.py @@ -118,8 +118,8 @@ def _gather_file_contents(node: FileSystemNode) -> str: def _create_tree_structure( query: IngestionQuery, - node: FileSystemNode, *, + node: FileSystemNode, prefix: str = "", is_last: bool = True, ) -> str: diff --git a/src/gitingest/utils/auth.py b/src/gitingest/utils/auth.py index 6fe241d1..50185dd3 100644 --- a/src/gitingest/utils/auth.py +++ b/src/gitingest/utils/auth.py @@ -11,7 +11,7 @@ def resolve_token(token: str | None) -> str | None: Parameters ---------- token : str | None - The token to use for the query. + GitHub personal access token (PAT) for accessing private repositories. Returns ------- diff --git a/src/gitingest/utils/git_utils.py b/src/gitingest/utils/git_utils.py index 104d71a4..857e1417 100644 --- a/src/gitingest/utils/git_utils.py +++ b/src/gitingest/utils/git_utils.py @@ -253,7 +253,7 @@ async def fetch_remote_branch_list(url: str, token: str | None = None) -> list[s # Add authentication if needed if token and is_github_host(url): - fetch_branches_command += ["-c", create_git_auth_header(token, url)] + fetch_branches_command += ["-c", create_git_auth_header(token, url=url)] fetch_branches_command += ["ls-remote", "--heads", url] @@ -274,18 +274,18 @@ def create_git_command(base_cmd: list[str], local_path: str, url: str, token: st Parameters ---------- base_cmd : list[str] - The base git command to start with + The base git command to start with. local_path : str - The local path where the git command should be executed + The local path where the git command should be executed. url : str - The repository URL to check if it's a GitHub repository + The repository URL to check if it's a GitHub repository. token : str | None - GitHub personal access token for authentication + GitHub personal access token (PAT) for accessing private repositories. Returns ------- list[str] - The git command with authentication if needed + The git command with authentication if needed. """ cmd = [*base_cmd, "-C", local_path] @@ -301,7 +301,7 @@ def create_git_auth_header(token: str, url: str = "https://github.com") -> str: Parameters ---------- token : str - GitHub personal access token + GitHub personal access token (PAT) for accessing private repositories. url : str The GitHub URL to create the authentication header for. Defaults to "https://github.com" if not provided. @@ -309,7 +309,7 @@ def create_git_auth_header(token: str, url: str = "https://github.com") -> str: Returns ------- str - The git config command for setting the authentication header + The git config command for setting the authentication header. """ hostname = urlparse(url).hostname @@ -323,12 +323,12 @@ def validate_github_token(token: str) -> None: Parameters ---------- token : str - The GitHub token to validate + GitHub personal access token (PAT) for accessing private repositories. Raises ------ InvalidGitHubTokenError - If the token format is invalid + If the token format is invalid. """ if not re.match(GITHUB_PAT_PATTERN, token): diff --git a/src/gitingest/utils/os_utils.py b/src/gitingest/utils/os_utils.py index 438c108c..d90dddd2 100644 --- a/src/gitingest/utils/os_utils.py +++ b/src/gitingest/utils/os_utils.py @@ -9,12 +9,12 @@ async def ensure_directory(path: Path) -> None: Parameters ---------- path : Path - The path to ensure exists + The path to ensure exists. Raises ------ OSError - If the directory cannot be created + If the directory cannot be created. """ try: diff --git a/src/server/models.py b/src/server/models.py index 614c79f7..2383aa95 100644 --- a/src/server/models.py +++ b/src/server/models.py @@ -22,7 +22,7 @@ class QueryForm(BaseModel): pattern : str Glob/regex pattern string. token : str | None - GitHub personal-access token (``None`` if omitted). + GitHub personal access token (PAT) for accessing private repositories. """ @@ -54,7 +54,7 @@ def as_form( pattern : StrForm Glob/regex pattern string. token : OptStrForm - GitHub personal-access token (``None`` if omitted). + GitHub personal access token (PAT) for accessing private repositories. Returns ------- diff --git a/src/server/query_processor.py b/src/server/query_processor.py index 806becbd..0577f963 100644 --- a/src/server/query_processor.py +++ b/src/server/query_processor.py @@ -52,7 +52,7 @@ async def process_query( is_index : bool Flag indicating whether the request is for the index page (default: ``False``). token : str | None - GitHub personal-access token (PAT). Needed when the repository is private. + GitHub personal access token (PAT) for accessing private repositories. Returns ------- diff --git a/src/server/server_utils.py b/src/server/server_utils.py index d1e2eeea..b0371661 100644 --- a/src/server/server_utils.py +++ b/src/server/server_utils.py @@ -83,7 +83,7 @@ async def _remove_old_repositories( Parameters ---------- base_path : Path - The path to the base directory where repositories are stored (default: ``TMP_BASE_PATH``) + The path to the base directory where repositories are stored (default: ``TMP_BASE_PATH``). scan_interval : int The number of seconds between scans (default: 60). delete_after : int diff --git a/tests/test_cli.py b/tests/test_cli.py index ab018e32..f001b84a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -36,8 +36,8 @@ def test_cli_writes_file( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, - cli_args: list[str], *, + cli_args: list[str], expect_file: bool, ) -> None: """Run the CLI and verify that the SARIF file is created (or not).""" diff --git a/tests/test_clone.py b/tests/test_clone.py index 72284a35..a855ca36 100644 --- a/tests/test_clone.py +++ b/tests/test_clone.py @@ -96,8 +96,8 @@ async def test_clone_nonexistent_repository(repo_exists_true: AsyncMock) -> None ) async def test_check_repo_exists( mock_stdout: bytes, - return_code: int, *, + return_code: int, expected: bool, mocker: MockerFixture, ) -> None: From 4512cfb479b824219e64a384bd73cb7f99ab77d5 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sat, 28 Jun 2025 18:41:56 +0200 Subject: [PATCH 5/5] =?UTF-8?q?Vulnerable=20to...=20=E2=80=93>=20Minimum?= =?UTF-8?q?=20safe=20release=20(...)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index af518476..5171396d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,15 +6,15 @@ readme = {file = "README.md", content-type = "text/markdown" } requires-python = ">= 3.8" dependencies = [ "click>=8.0.0", - "fastapi[standard]>=0.109.1", # Vulnerable to https://osv.dev/vulnerability/PYSEC-2024-38 + "fastapi[standard]>=0.109.1", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2024-38) "pydantic", "python-dotenv", "slowapi", - "starlette>=0.40.0", # Vulnerable to https://osv.dev/vulnerability/GHSA-f96h-pmfr-66vw + "starlette>=0.40.0", # Minimum safe release (https://osv.dev/vulnerability/GHSA-f96h-pmfr-66vw) "tiktoken>=0.7.0", # Support for o200k_base encoding "pathspec>=0.12.1", "typing_extensions>= 4.0.0; python_version < '3.10'", - "uvicorn>=0.11.7", # Vulnerable to https://osv.dev/vulnerability/PYSEC-2020-150 + "uvicorn>=0.11.7", # Minimum safe release (https://osv.dev/vulnerability/PYSEC-2020-150) ] license = {file = "LICENSE"}