From 63b5f232ac875f20101dcbba6703c5987326d772 Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Sun, 19 Jan 2025 11:58:55 +0100 Subject: [PATCH 1/2] Refactor code for lifespan, template usage, and improved tests - Move background tasks and rate-limit handler into utils.py - Reference TEMPLATES from config instead of inline Jinja2Templates - Adopt Given/When/Then docstrings for test clarity - Parametrize some tests and consolidate code across query_parser tests - Add pytest.warns context handler to test_parse_repo_source_with_failed_git_command --- src/config.py | 4 + src/gitingest/query_parser.py | 2 +- src/main.py | 141 +--------- src/query_processor.py | 7 +- src/routers/dynamic.py | 5 +- src/routers/index.py | 6 +- src/utils.py | 138 ++++++++++ tests/conftest.py | 76 ++++-- tests/query_parser/test_git_host_agnostic.py | 15 +- tests/query_parser/test_query_parser.py | 262 ++++++++++++++----- tests/test_notebook_utils.py | 131 ++++++---- tests/test_query_ingestion.py | 171 +++++++----- tests/test_repository_clone.py | 152 +++++++---- 13 files changed, 707 insertions(+), 403 deletions(-) create mode 100644 src/utils.py diff --git a/src/config.py b/src/config.py index 7365ab8b..9d9c2113 100644 --- a/src/config.py +++ b/src/config.py @@ -2,6 +2,8 @@ from pathlib import Path +from fastapi.templating import Jinja2Templates + MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB MAX_DIRECTORY_DEPTH = 20 # Maximum depth of directory traversal MAX_FILES = 10_000 # Maximum number of files to process @@ -18,3 +20,5 @@ {"name": "Tldraw", "url": "https://github.com/tldraw/tldraw"}, {"name": "ApiAnalytics", "url": "https://github.com/tom-draper/api-analytics"}, ] + +TEMPLATES = Jinja2Templates(directory="templates") diff --git a/src/gitingest/query_parser.py b/src/gitingest/query_parser.py index 435a7996..272ae2d6 100644 --- a/src/gitingest/query_parser.py +++ b/src/gitingest/query_parser.py @@ -227,7 +227,7 @@ async def 
_configure_branch_and_subpath(remaining_parts: list[str], url: str) -> # Fetch the list of branches from the remote repository branches: list[str] = await fetch_remote_branch_list(url) except RuntimeError as e: - warnings.warn(f"Warning: Failed to fetch branch list: {e}") + warnings.warn(f"Warning: Failed to fetch branch list: {e}", RuntimeWarning) return remaining_parts.pop(0) branch = [] diff --git a/src/main.py b/src/main.py index 556b3e1d..241e9458 100644 --- a/src/main.py +++ b/src/main.py @@ -1,157 +1,27 @@ """ Main module for the FastAPI application. """ -import asyncio import os -import shutil -import time -from contextlib import asynccontextmanager -from pathlib import Path from api_analytics.fastapi import Analytics from dotenv import load_dotenv from fastapi import FastAPI, Request -from fastapi.responses import FileResponse, HTMLResponse, Response +from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates -from slowapi import _rate_limit_exceeded_handler from slowapi.errors import RateLimitExceeded from starlette.middleware.trustedhost import TrustedHostMiddleware -from config import DELETE_REPO_AFTER, TMP_BASE_PATH +from config import TEMPLATES from routers import download, dynamic, index from server_utils import limiter +from utils import lifespan, rate_limit_exception_handler # Load environment variables from .env file load_dotenv() - -async def remove_old_repositories(): - """ - Background task that runs periodically to clean up old repository directories. 
- - This task: - - Scans the TMP_BASE_PATH directory every 60 seconds - - Removes directories older than DELETE_REPO_AFTER seconds - - Before deletion, logs repository URLs to history.txt if a matching .txt file exists - - Handles errors gracefully if deletion fails - - The repository URL is extracted from the first .txt file in each directory, - assuming the filename format: "owner-repository.txt" - """ - while True: - try: - if not TMP_BASE_PATH.exists(): - await asyncio.sleep(60) - continue - - current_time = time.time() - - for folder in TMP_BASE_PATH.iterdir(): - if not folder.is_dir(): - continue - - # Skip if folder is not old enough - if current_time - folder.stat().st_ctime <= DELETE_REPO_AFTER: - continue - - await process_folder(folder) - - except Exception as e: - print(f"Error in remove_old_repositories: {e}") - - await asyncio.sleep(60) - - -async def process_folder(folder: Path) -> None: - """ - Process a single folder for deletion and logging. - - Parameters - ---------- - folder : Path - The path to the folder to be processed. - """ - # Try to log repository URL before deletion - try: - txt_files = [f for f in folder.iterdir() if f.suffix == ".txt"] - - # Extract owner and repository name from the filename - if txt_files and "-" in (filename := txt_files[0].stem): - owner, repo = filename.split("-", 1) - repo_url = f"{owner}/{repo}" - with open("history.txt", mode="a", encoding="utf-8") as history: - history.write(f"{repo_url}\n") - - except Exception as e: - print(f"Error logging repository URL for {folder}: {e}") - - # Delete the folder - try: - shutil.rmtree(folder) - except Exception as e: - print(f"Error deleting {folder}: {e}") - - -@asynccontextmanager -async def lifespan(_: FastAPI): - """ - Lifecycle manager for the FastAPI application. - Handles startup and shutdown events. - - Parameters - ---------- - _ : FastAPI - The FastAPI application instance (unused). 
- - Yields - ------- - None - Yields control back to the FastAPI application while the background task runs. - """ - task = asyncio.create_task(remove_old_repositories()) - - yield - # Cancel the background task on shutdown - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Initialize the FastAPI application with lifespan app = FastAPI(lifespan=lifespan) app.state.limiter = limiter - -async def rate_limit_exception_handler(request: Request, exc: Exception) -> Response: - """ - Custom exception handler for rate-limiting errors. - - Parameters - ---------- - request : Request - The incoming HTTP request. - exc : Exception - The exception raised, expected to be RateLimitExceeded. - - Returns - ------- - Response - A response indicating that the rate limit has been exceeded. - - Raises - ------ - exc - If the exception is not a RateLimitExceeded error, it is re-raised. - """ - if isinstance(exc, RateLimitExceeded): - # Delegate to the default rate limit handler - return _rate_limit_exceeded_handler(request, exc) - # Re-raise other exceptions - raise exc - - # Register the custom exception handler for rate limits app.add_exception_handler(RateLimitExceeded, rate_limit_exception_handler) @@ -174,9 +44,6 @@ async def rate_limit_exception_handler(request: Request, exc: Exception) -> Resp # Add middleware to enforce allowed hosts app.add_middleware(TrustedHostMiddleware, allowed_hosts=allowed_hosts) -# Set up template rendering -templates = Jinja2Templates(directory="templates") - @app.get("/health") async def health_check() -> dict[str, str]: @@ -222,7 +89,7 @@ async def api_docs(request: Request) -> HTMLResponse: HTMLResponse A rendered HTML page displaying API documentation. 
""" - return templates.TemplateResponse("api.jinja", {"request": request}) + return TEMPLATES.TemplateResponse("api.jinja", {"request": request}) @app.get("/robots.txt") diff --git a/src/query_processor.py b/src/query_processor.py index 62f1c83f..72603592 100644 --- a/src/query_processor.py +++ b/src/query_processor.py @@ -3,17 +3,14 @@ from functools import partial from fastapi import Request -from fastapi.templating import Jinja2Templates from starlette.templating import _TemplateResponse -from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE +from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE, TEMPLATES from gitingest.query_ingestion import run_ingest_query from gitingest.query_parser import ParsedQuery, parse_query from gitingest.repository_clone import CloneConfig, clone_repo from server_utils import Colors, log_slider_to_size -templates = Jinja2Templates(directory="templates") - async def process_query( request: Request, @@ -64,7 +61,7 @@ async def process_query( raise ValueError(f"Invalid pattern type: {pattern_type}") template = "index.jinja" if is_index else "git.jinja" - template_response = partial(templates.TemplateResponse, name=template) + template_response = partial(TEMPLATES.TemplateResponse, name=template) max_file_size = log_slider_to_size(slider_position) context = { diff --git a/src/routers/dynamic.py b/src/routers/dynamic.py index 0787fbfa..48e6c080 100644 --- a/src/routers/dynamic.py +++ b/src/routers/dynamic.py @@ -2,13 +2,12 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse -from fastapi.templating import Jinja2Templates +from config import TEMPLATES from query_processor import process_query from server_utils import limiter router = APIRouter() -templates = Jinja2Templates(directory="templates") @router.get("/{full_path:path}") @@ -32,7 +31,7 @@ async def catch_all(request: Request, full_path: str) -> HTMLResponse: An HTML response containing the rendered template, with the Git URL and other default 
parameters such as loading state and file size. """ - return templates.TemplateResponse( + return TEMPLATES.TemplateResponse( "git.jinja", { "request": request, diff --git a/src/routers/index.py b/src/routers/index.py index b338c301..b5d2f6c9 100644 --- a/src/routers/index.py +++ b/src/routers/index.py @@ -2,14 +2,12 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse -from fastapi.templating import Jinja2Templates -from config import EXAMPLE_REPOS +from config import EXAMPLE_REPOS, TEMPLATES from query_processor import process_query from server_utils import limiter router = APIRouter() -templates = Jinja2Templates(directory="templates") @router.get("/", response_class=HTMLResponse) @@ -31,7 +29,7 @@ async def home(request: Request) -> HTMLResponse: An HTML response containing the rendered home page template, with example repositories and other default parameters such as file size. """ - return templates.TemplateResponse( + return TEMPLATES.TemplateResponse( "index.jinja", { "request": request, diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 00000000..7c968dc0 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,138 @@ +""" Utility functions for the FastAPI server. """ + +import asyncio +import shutil +import time +from contextlib import asynccontextmanager +from pathlib import Path + +from fastapi import FastAPI, Request +from fastapi.responses import Response +from slowapi import _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded + +from config import DELETE_REPO_AFTER, TMP_BASE_PATH + + +async def rate_limit_exception_handler(request: Request, exc: Exception) -> Response: + """ + Custom exception handler for rate-limiting errors. + + Parameters + ---------- + request : Request + The incoming HTTP request. + exc : Exception + The exception raised, expected to be RateLimitExceeded. + + Returns + ------- + Response + A response indicating that the rate limit has been exceeded. 
+
+    Raises
+    ------
+    exc
+        If the exception is not a RateLimitExceeded error, it is re-raised.
+    """
+    if isinstance(exc, RateLimitExceeded):
+        # Delegate to the default rate limit handler
+        return _rate_limit_exceeded_handler(request, exc)
+    # Re-raise other exceptions
+    raise exc
+
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    """
+    Lifecycle manager for handling startup and shutdown events for the FastAPI application.
+
+    Parameters
+    ----------
+    _ : FastAPI
+        The FastAPI application instance (unused).
+
+    Yields
+    -------
+    None
+        Yields control back to the FastAPI application while the background task runs.
+    """
+    task = asyncio.create_task(_remove_old_repositories())
+
+    yield
+    # Cancel the background task on shutdown
+    task.cancel()
+    try:
+        await task
+    except asyncio.CancelledError:
+        pass
+
+
+async def _remove_old_repositories():
+    """
+    Periodically remove old repository folders.
+
+    Background task that runs periodically to clean up old repository directories.
+
+    This task:
+    - Scans the TMP_BASE_PATH directory every 60 seconds
+    - Removes directories older than DELETE_REPO_AFTER seconds
+    - Before deletion, logs repository URLs to history.txt if a matching .txt file exists
+    - Handles errors gracefully if deletion fails
+
+    The repository URL is extracted from the first .txt file in each directory,
+    assuming the filename format: "owner-repository.txt"
+    """
+    while True:
+        try:
+            if not TMP_BASE_PATH.exists():
+                await asyncio.sleep(60)
+                continue
+
+            current_time = time.time()
+
+            for folder in TMP_BASE_PATH.iterdir():
+                if not folder.is_dir():
+                    continue
+
+                # Skip if folder is not old enough
+                if current_time - folder.stat().st_ctime <= DELETE_REPO_AFTER:
+                    continue
+
+                await _process_folder(folder)
+
+        except Exception as e:
+            print(f"Error in _remove_old_repositories: {e}")
+
+        await asyncio.sleep(60)
+
+
+async def _process_folder(folder: Path) -> None:
+    """
+    Process a single folder for deletion and logging.
+ + Parameters + ---------- + folder : Path + The path to the folder to be processed. + """ + # Try to log repository URL before deletion + try: + txt_files = [f for f in folder.iterdir() if f.suffix == ".txt"] + + # Extract owner and repository name from the filename + if txt_files and "-" in (filename := txt_files[0].stem): + owner, repo = filename.split("-", 1) + repo_url = f"{owner}/{repo}" + + with open("history.txt", mode="a", encoding="utf-8") as history: + history.write(f"{repo_url}\n") + + except Exception as e: + print(f"Error logging repository URL for {folder}: {e}") + + # Delete the folder + try: + shutil.rmtree(folder) + except Exception as e: + print(f"Error deleting {folder}: {e}") diff --git a/tests/conftest.py b/tests/conftest.py index c11ee726..507d1f51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,12 @@ -""" This module contains fixtures for the tests. """ +""" +Fixtures for tests. + +This file provides shared fixtures for creating sample queries, a temporary directory structure, and a helper function +to write `.ipynb` notebooks for testing notebook utilities. +""" import json +from collections.abc import Callable from pathlib import Path from typing import Any @@ -8,9 +14,21 @@ from gitingest.query_parser import ParsedQuery +WriteNotebookFunc = Callable[[str, dict[str, Any]], Path] + @pytest.fixture def sample_query() -> ParsedQuery: + """ + Provide a default `ParsedQuery` object for use in tests. + + This fixture returns a `ParsedQuery` pre-populated with typical fields and some default ignore patterns. + + Returns + ------- + ParsedQuery + The sample `ParsedQuery` object. 
+ """ return ParsedQuery( user_name="test_user", repo_name="test_repo", @@ -30,22 +48,33 @@ def sample_query() -> ParsedQuery: @pytest.fixture def temp_directory(tmp_path: Path) -> Path: """ - # Creates the following structure: - # test_repo/ - # ├── file1.txt - # ├── file2.py - # └── src/ - # | ├── subfile1.txt - # | └── subfile2.py - # | └── subdir/ - # | └── file_subdir.txt - # | └── file_subdir.py - # └── dir1/ - # | └── file_dir1.txt - # └── dir2/ - # └── file_dir2.txt + Create a temporary directory structure for testing repository scanning. + + The structure includes: + test_repo/ + ├── file1.txt + ├── file2.py + ├── src/ + │ ├── subfile1.txt + │ ├── subfile2.py + │ └── subdir/ + │ ├── file_subdir.txt + │ └── file_subdir.py + ├── dir1/ + │ └── file_dir1.txt + └── dir2/ + └── file_dir2.txt + + Parameters + ---------- + tmp_path : Path + The temporary directory path provided by the `tmp_path` fixture. + + Returns + ------- + Path + The path to the created `test_repo` directory. """ - test_dir = tmp_path / "test_repo" test_dir.mkdir() @@ -79,9 +108,20 @@ def temp_directory(tmp_path: Path) -> Path: @pytest.fixture -def write_notebook(tmp_path: Path): +def write_notebook(tmp_path: Path) -> WriteNotebookFunc: """ - A fixture that returns a helper function to write a .ipynb notebook file at runtime with given content. + Provide a helper function to write a `.ipynb` notebook file with the given content. + + Parameters + ---------- + tmp_path : Path + The temporary directory path provided by the `tmp_path` fixture. + + Returns + ------- + WriteNotebookFunc + A callable that accepts a filename and a dictionary (representing JSON notebook data), writes it to a `.ipynb` + file, and returns the path to the file. 
""" def _write_notebook(name: str, content: dict[str, Any]) -> Path: diff --git a/tests/query_parser/test_git_host_agnostic.py b/tests/query_parser/test_git_host_agnostic.py index 9831362d..b35d9184 100644 --- a/tests/query_parser/test_git_host_agnostic.py +++ b/tests/query_parser/test_git_host_agnostic.py @@ -1,4 +1,9 @@ -""" Tests to verify that the query parser is Git host agnostic. """ +""" +Tests to verify that the query parser is Git host agnostic. + +These tests confirm that `parse_query` correctly identifies user/repo pairs and canonical URLs for GitHub, GitLab, +Bitbucket, Gitea, and Codeberg, even if the host is omitted. +""" import pytest @@ -67,8 +72,16 @@ async def test_parse_query_without_host( expected_repo: str, expected_url: str, ) -> None: + """ + Test `parse_query` for Git host agnosticism. + + Given multiple URL variations for the same user/repo on different Git hosts (with or without host names): + When `parse_query` is called with each variation, + Then the parser should correctly identify the user, repo, canonical URL, and other default fields. + """ for url in urls: parsed_query = await parse_query(url, max_file_size=50, from_web=True) + assert parsed_query.user_name == expected_user assert parsed_query.repo_name == expected_repo assert parsed_query.url == expected_url diff --git a/tests/query_parser/test_query_parser.py b/tests/query_parser/test_query_parser.py index fc5fb0aa..8b828909 100644 --- a/tests/query_parser/test_query_parser.py +++ b/tests/query_parser/test_query_parser.py @@ -1,4 +1,9 @@ -""" Tests for the query_parser module. """ +""" +Tests for the `query_parser` module. + +These tests cover URL parsing, pattern parsing, and handling of branches/subpaths for HTTP(S) repositories and local +paths. 
+""" from pathlib import Path from unittest.mock import AsyncMock, patch @@ -9,10 +14,14 @@ from gitingest.query_parser import _parse_patterns, _parse_repo_source, parse_query +@pytest.mark.asyncio async def test_parse_url_valid_https() -> None: """ - Test `_parse_repo_source` with valid HTTPS URLs from supported platforms (GitHub, GitLab, Bitbucket, Gitea). - Verifies that user and repository names are correctly extracted. + Test `_parse_repo_source` with valid HTTPS URLs. + + Given various HTTPS URLs on supported platforms: + When `_parse_repo_source` is called, + Then user name, repo name, and the URL should be extracted correctly. """ test_cases = [ "https://github.com/user/repo", @@ -24,15 +33,20 @@ async def test_parse_url_valid_https() -> None: ] for url in test_cases: parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" assert parsed_query.repo_name == "repo" assert parsed_query.url == url +@pytest.mark.asyncio async def test_parse_url_valid_http() -> None: """ - Test `_parse_repo_source` with valid HTTP URLs from supported platforms. - Verifies that user and repository names, as well as the slug, are correctly extracted. + Test `_parse_repo_source` with valid HTTP URLs. + + Given various HTTP URLs on supported platforms: + When `_parse_repo_source` is called, + Then user name, repo name, and the slug should be extracted correctly. """ test_cases = [ "http://github.com/user/repo", @@ -44,71 +58,99 @@ async def test_parse_url_valid_http() -> None: ] for url in test_cases: parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" assert parsed_query.repo_name == "repo" assert parsed_query.slug == "user-repo" +@pytest.mark.asyncio async def test_parse_url_invalid() -> None: """ - Test `_parse_repo_source` with an invalid URL that does not include a repository structure. - Verifies that a ValueError is raised with an appropriate error message. + Test `_parse_repo_source` with an invalid URL. 
+ + Given an HTTPS URL lacking a repository structure (e.g., "https://github.com"), + When `_parse_repo_source` is called, + Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com" with pytest.raises(ValueError, match="Invalid repository URL"): await _parse_repo_source(url) -async def test_parse_query_basic() -> None: +@pytest.mark.asyncio +@pytest.mark.parametrize("url", ["https://github.com/user/repo", "https://gitlab.com/user/repo"]) +async def test_parse_query_basic(url): """ - Test `parse_query` with basic inputs including valid repository URLs. - Verifies that user and repository names, URL, and ignore patterns are correctly parsed. + Test `parse_query` with a basic valid repository URL. + + Given an HTTPS URL and ignore_patterns="*.txt": + When `parse_query` is called, + Then user/repo, URL, and ignore patterns should be parsed correctly. """ - test_cases = ["https://github.com/user/repo", "https://gitlab.com/user/repo"] - for url in test_cases: - parsed_query = await parse_query(url, max_file_size=50, from_web=True, ignore_patterns="*.txt") - assert parsed_query.user_name == "user" - assert parsed_query.repo_name == "repo" - assert parsed_query.url == url - assert parsed_query.ignore_patterns - assert "*.txt" in parsed_query.ignore_patterns + parsed_query = await parse_query(source=url, max_file_size=50, from_web=True, ignore_patterns="*.txt") + + assert parsed_query.user_name == "user" + assert parsed_query.repo_name == "repo" + assert parsed_query.url == url + assert parsed_query.ignore_patterns + assert "*.txt" in parsed_query.ignore_patterns +@pytest.mark.asyncio async def test_parse_query_mixed_case() -> None: """ - Test `parse_query` with mixed case URLs. + Test `parse_query` with mixed-case URLs. + + Given a URL with mixed-case parts (e.g. "Https://GitHub.COM/UsEr/rEpO"): + When `parse_query` is called, + Then the user and repo names should be normalized to lowercase. 
""" url = "Https://GitHub.COM/UsEr/rEpO" parsed_query = await parse_query(url, max_file_size=50, from_web=True) + assert parsed_query.user_name == "user" assert parsed_query.repo_name == "repo" +@pytest.mark.asyncio async def test_parse_query_include_pattern() -> None: """ - Test `parse_query` with an include pattern. - Verifies that the include pattern is set correctly and default ignore patterns are applied. + Test `parse_query` with a specified include pattern. + + Given a URL and include_patterns="*.py": + When `parse_query` is called, + Then the include pattern should be set, and default ignore patterns remain applied. """ url = "https://github.com/user/repo" parsed_query = await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py") + assert parsed_query.include_patterns == {"*.py"} assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS +@pytest.mark.asyncio async def test_parse_query_invalid_pattern() -> None: """ - Test `parse_query` with an invalid pattern containing special characters. - Verifies that a ValueError is raised with an appropriate error message. + Test `parse_query` with an invalid pattern. + + Given an include pattern containing special characters (e.g., "*.py;rm -rf"): + When `parse_query` is called, + Then a ValueError should be raised indicating invalid characters. """ url = "https://github.com/user/repo" with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): await parse_query(url, max_file_size=50, from_web=True, include_patterns="*.py;rm -rf") +@pytest.mark.asyncio async def test_parse_url_with_subpaths() -> None: """ - Test `_parse_repo_source` with a URL containing a branch and subpath. - Verifies that user name, repository name, branch, and subpath are correctly extracted. + Test `_parse_repo_source` with a URL containing branch and subpath. 
+ + Given a URL referencing a branch ("main") and a subdir ("subdir/file"): + When `_parse_repo_source` is called with remote branch fetching, + Then user, repo, branch, and subpath should be identified correctly. """ url = "https://github.com/user/repo/tree/main/subdir/file" with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command: @@ -118,16 +160,21 @@ async def test_parse_url_with_subpaths() -> None: ) as mock_fetch_branches: mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" assert parsed_query.repo_name == "repo" assert parsed_query.branch == "main" assert parsed_query.subpath == "/subdir/file" +@pytest.mark.asyncio async def test_parse_url_invalid_repo_structure() -> None: """ - Test `_parse_repo_source` with an invalid repository structure in the URL. - Verifies that a ValueError is raised with an appropriate error message. + Test `_parse_repo_source` with a URL missing a repository name. + + Given a URL like "https://github.com/user": + When `_parse_repo_source` is called, + Then a ValueError should be raised indicating an invalid repository URL. """ url = "https://github.com/user" with pytest.raises(ValueError, match="Invalid repository URL"): @@ -136,50 +183,71 @@ async def test_parse_url_invalid_repo_structure() -> None: def test_parse_patterns_valid() -> None: """ - Test `_parse_patterns` with valid patterns separated by commas. - Verifies that the patterns are correctly parsed into a list. + Test `_parse_patterns` with valid comma-separated patterns. + + Given patterns like "*.py, *.md, docs/*": + When `_parse_patterns` is called, + Then it should return a set of parsed strings. 
""" patterns = "*.py, *.md, docs/*" parsed_patterns = _parse_patterns(patterns) + assert parsed_patterns == {"*.py", "*.md", "docs/*"} def test_parse_patterns_invalid_characters() -> None: """ - Test `_parse_patterns` with invalid patterns containing special characters. - Verifies that a ValueError is raised with an appropriate error message. + Test `_parse_patterns` with invalid characters. + + Given a pattern string containing special characters (e.g. "*.py;rm -rf"): + When `_parse_patterns` is called, + Then a ValueError should be raised indicating invalid pattern syntax. """ patterns = "*.py;rm -rf" with pytest.raises(ValueError, match="Pattern.*contains invalid characters"): _parse_patterns(patterns) +@pytest.mark.asyncio async def test_parse_query_with_large_file_size() -> None: """ Test `parse_query` with a very large file size limit. - Verifies that the file size limit and default ignore patterns are set correctly. + + Given a URL and max_file_size=10**9: + When `parse_query` is called, + Then `max_file_size` should be set correctly and default ignore patterns remain unchanged. """ url = "https://github.com/user/repo" parsed_query = await parse_query(url, max_file_size=10**9, from_web=True) + assert parsed_query.max_file_size == 10**9 assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS +@pytest.mark.asyncio async def test_parse_query_empty_patterns() -> None: """ - Test `parse_query` with empty include and ignore patterns. - Verifies that the include patterns are set to None and default ignore patterns are applied. + Test `parse_query` with empty patterns. + + Given empty include_patterns and ignore_patterns: + When `parse_query` is called, + Then include_patterns becomes None and default ignore patterns apply. 
""" url = "https://github.com/user/repo" parsed_query = await parse_query(url, max_file_size=50, from_web=True, include_patterns="", ignore_patterns="") + assert parsed_query.include_patterns is None assert parsed_query.ignore_patterns == DEFAULT_IGNORE_PATTERNS +@pytest.mark.asyncio async def test_parse_query_include_and_ignore_overlap() -> None: """ - Test `parse_query` with overlapping include and ignore patterns. - Verifies that overlapping patterns are removed from the ignore patterns. + Test `parse_query` with overlapping patterns. + + Given include="*.py" and ignore={"*.py", "*.txt"}: + When `parse_query` is called, + Then "*.py" should be removed from ignore patterns. """ url = "https://github.com/user/repo" parsed_query = await parse_query( @@ -189,102 +257,155 @@ async def test_parse_query_include_and_ignore_overlap() -> None: include_patterns="*.py", ignore_patterns={"*.py", "*.txt"}, ) + assert parsed_query.include_patterns == {"*.py"} assert parsed_query.ignore_patterns is not None assert "*.py" not in parsed_query.ignore_patterns assert "*.txt" in parsed_query.ignore_patterns +@pytest.mark.asyncio async def test_parse_query_local_path() -> None: """ Test `parse_query` with a local file path. - Verifies that the local path is set, a unique ID is generated, and the slug is correctly created. + + Given "/home/user/project" and from_web=False: + When `parse_query` is called, + Then the local path should be set, id generated, and slug formed accordingly. """ path = "/home/user/project" parsed_query = await parse_query(path, max_file_size=100, from_web=False) tail = Path("home/user/project") + assert parsed_query.local_path.parts[-len(tail.parts) :] == tail.parts assert parsed_query.id is not None assert parsed_query.slug == "user/project" +@pytest.mark.asyncio async def test_parse_query_relative_path() -> None: """ - Test `parse_query` with a relative file path. - Verifies that the local path and slug are correctly resolved. 
+ Test `parse_query` with a relative path. + + Given "./project" and from_web=False: + When `parse_query` is called, + Then local_path resolves relatively, and slug ends with "project". """ path = "./project" parsed_query = await parse_query(path, max_file_size=100, from_web=False) tail = Path("project") + assert parsed_query.local_path.parts[-len(tail.parts) :] == tail.parts assert parsed_query.slug.endswith("project") +@pytest.mark.asyncio async def test_parse_query_empty_source() -> None: """ - Test `parse_query` with an empty source input. - Verifies that a ValueError is raised with an appropriate error message. + Test `parse_query` with an empty string. + + Given an empty source string: + When `parse_query` is called, + Then a ValueError should be raised indicating an invalid repository URL. """ with pytest.raises(ValueError, match="Invalid repository URL"): await parse_query("", max_file_size=100, from_web=True) -async def test_parse_url_branch_and_commit_distinction() -> None: - """ - Test `_parse_repo_source` with URLs containing either a branch name or a commit hash. - Verifies that the branch and commit are correctly distinguished. +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url, expected_branch, expected_commit", + [ + ("https://github.com/user/repo/tree/main", "main", None), + ( + "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234", + None, + "abcd1234abcd1234abcd1234abcd1234abcd1234", + ), + ], +) +async def test_parse_url_branch_and_commit_distinction(url: str, expected_branch: str, expected_commit: str) -> None: """ - url_branch = "https://github.com/user/repo/tree/main" - url_commit = "https://github.com/user/repo/tree/abcd1234abcd1234abcd1234abcd1234abcd1234" + Test `_parse_repo_source` distinguishing branch vs. commit hash. 
+ Given either a branch URL (e.g., ".../tree/main") or a 40-character commit URL: + When `_parse_repo_source` is called with branch fetching, + Then the function should correctly set `branch` or `commit` based on the URL content. + """ with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command: + # Mocking the return value to include 'main' and some additional branches mock_run_git_command.return_value = (b"refs/heads/main\nrefs/heads/dev\nrefs/heads/feature-branch\n", b"") with patch( "gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock ) as mock_fetch_branches: mock_fetch_branches.return_value = ["main", "dev", "feature-branch"] - parsed_query_with_branch = await _parse_repo_source(url_branch) - parsed_query_with_commit = await _parse_repo_source(url_commit) - - assert parsed_query_with_branch.branch == "main" - assert parsed_query_with_branch.commit is None + parsed_query = await _parse_repo_source(url) - assert parsed_query_with_commit.branch is None - assert parsed_query_with_commit.commit == "abcd1234abcd1234abcd1234abcd1234abcd1234" + # Verify that `branch` and `commit` match our expectations + assert parsed_query.branch == expected_branch + assert parsed_query.commit == expected_commit +@pytest.mark.asyncio async def test_parse_query_uuid_uniqueness() -> None: """ - Test `parse_query` to ensure that each call generates a unique UUID for the query. + Test `parse_query` for unique UUID generation. + + Given the same path twice: + When `parse_query` is called repeatedly, + Then each call should produce a different query id. 
""" path = "/home/user/project" parsed_query_1 = await parse_query(path, max_file_size=100, from_web=False) parsed_query_2 = await parse_query(path, max_file_size=100, from_web=False) + assert parsed_query_1.id != parsed_query_2.id +@pytest.mark.asyncio async def test_parse_url_with_query_and_fragment() -> None: """ - Test `_parse_repo_source` with a URL containing query parameters and a fragment. - Verifies that the URL is cleaned and other fields are correctly extracted. + Test `_parse_repo_source` with query parameters and a fragment. + + Given a URL like "https://github.com/user/repo?arg=value#fragment": + When `_parse_repo_source` is called, + Then those parts should be stripped, leaving a clean user/repo URL. """ url = "https://github.com/user/repo?arg=value#fragment" parsed_query = await _parse_repo_source(url) + assert parsed_query.user_name == "user" assert parsed_query.repo_name == "repo" assert parsed_query.url == "https://github.com/user/repo" # URL should be cleaned +@pytest.mark.asyncio async def test_parse_url_unsupported_host() -> None: + """ + Test `_parse_repo_source` with an unsupported host. + + Given "https://only-domain.com": + When `_parse_repo_source` is called, + Then a ValueError should be raised for the unknown domain. + """ url = "https://only-domain.com" with pytest.raises(ValueError, match="Unknown domain 'only-domain.com' in URL"): await _parse_repo_source(url) +@pytest.mark.asyncio async def test_parse_query_with_branch() -> None: + """ + Test `parse_query` when a branch is specified in a blob path. + + Given "https://github.com/pandas-dev/pandas/blob/2.2.x/...": + When `parse_query` is called, + Then the branch should be identified, subpath set, and commit remain None. 
+ """ url = "https://github.com/pandas-dev/pandas/blob/2.2.x/.github/ISSUE_TEMPLATE/documentation_improvement.yaml" parsed_query = await parse_query(url, max_file_size=10**9, from_web=True) + assert parsed_query.user_name == "pandas-dev" assert parsed_query.repo_name == "pandas" assert parsed_query.url == "https://github.com/pandas-dev/pandas" @@ -307,16 +428,25 @@ async def test_parse_query_with_branch() -> None: ) async def test_parse_repo_source_with_failed_git_command(url, expected_branch, expected_subpath): """ - Test `_parse_repo_source` when git command fails. - Verifies that the function returns the first path component as the branch. + Test `_parse_repo_source` when git fetch fails. + + Given a URL referencing a branch, but Git fetching fails: + When `_parse_repo_source` is called, + Then it should fall back to path components for branch identification. """ with patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches: mock_fetch_branches.side_effect = Exception("Failed to fetch branch list") - parsed_query = await _parse_repo_source(url) + with pytest.warns( + RuntimeWarning, + match="Warning: Failed to fetch branch list: Git command failed: " + "git ls-remote --heads https://github.com/user/repo", + ): - assert parsed_query.branch == expected_branch - assert parsed_query.subpath == expected_subpath + parsed_query = await _parse_repo_source(url) + + assert parsed_query.branch == expected_branch + assert parsed_query.subpath == expected_subpath @pytest.mark.asyncio @@ -332,6 +462,13 @@ async def test_parse_repo_source_with_failed_git_command(url, expected_branch, e ], ) async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, expected_subpath): + """ + Test `_parse_repo_source` with various URL patterns. 
+ + Given multiple branch/blob patterns (including nonexistent branches): + When `_parse_repo_source` is called with remote branch fetching, + Then the correct branch/subpath should be set or None if unmatched. + """ with ( patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_run_git_command, patch("gitingest.repository_clone.fetch_remote_branch_list", new_callable=AsyncMock) as mock_fetch_branches, @@ -344,5 +481,6 @@ async def test_parse_repo_source_with_various_url_patterns(url, expected_branch, mock_fetch_branches.return_value = ["feature/fix1", "main", "feature-branch"] parsed_query = await _parse_repo_source(url) + assert parsed_query.branch == expected_branch assert parsed_query.subpath == expected_subpath diff --git a/tests/test_notebook_utils.py b/tests/test_notebook_utils.py index 6a23b926..3335a797 100644 --- a/tests/test_notebook_utils.py +++ b/tests/test_notebook_utils.py @@ -1,17 +1,27 @@ -""" Tests for the notebook_utils module. """ +""" +Tests for the `notebook_utils` module. + +These tests validate how notebooks are processed into Python-like output, ensuring that markdown/raw cells are +converted to triple-quoted blocks, code cells remain executable code, and various edge cases (multiple worksheets, +empty cells, outputs, etc.) are handled appropriately. +""" import pytest from gitingest.notebook_utils import process_notebook +from tests.conftest import WriteNotebookFunc -def test_process_notebook_all_cells(write_notebook): +def test_process_notebook_all_cells(write_notebook: WriteNotebookFunc) -> None: """ - Test a notebook containing markdown, code, and raw cells. - - - Markdown/raw cells => triple-quoted - - Code cells => remain normal code - - For 1 markdown + 1 raw => 2 triple-quoted blocks => 4 occurrences of triple-quotes. + Test processing a notebook containing markdown, code, and raw cells. 
+ + Given a notebook with: + - One markdown cell + - One code cell + - One raw cell + When `process_notebook` is invoked, + Then markdown and raw cells should appear in triple-quoted blocks, and code cells remain as normal code. """ notebook_content = { "cells": [ @@ -23,24 +33,25 @@ def test_process_notebook_all_cells(write_notebook): nb_path = write_notebook("all_cells.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 4, "Expected 4 triple-quote occurrences for 2 blocks." + assert result.count('"""') == 4, "Two non-code cells => 2 triple-quoted blocks => 4 total triple quotes." - # Check that markdown and raw content are inside triple-quoted blocks + # Ensure markdown and raw cells are in triple quotes assert "# Markdown cell" in result assert "" in result - # Check code cell is present and not wrapped in triple quotes + # Ensure code cell is not in triple quotes assert 'print("Hello Code")' in result assert '"""\nprint("Hello Code")\n"""' not in result -def test_process_notebook_with_worksheets(write_notebook): +def test_process_notebook_with_worksheets(write_notebook: WriteNotebookFunc) -> None: """ - Test a notebook containing the 'worksheets' key (deprecated as of IPEP-17). + Test a notebook containing the (as of IPEP-17 deprecated) 'worksheets' key. - - Should raise a DeprecationWarning. - - We process only the first (and only) worksheet's cells. - - The resulting content matches an equivalent notebook with "cells" at top level. + Given a notebook that uses the 'worksheets' key with a single worksheet, + When `process_notebook` is called, + Then a `DeprecationWarning` should be raised, and the content should match an equivalent notebook + that has top-level 'cells'. 
""" with_worksheets = { "worksheets": [ @@ -53,7 +64,7 @@ def test_process_notebook_with_worksheets(write_notebook): } ] } - without_worksheets = with_worksheets["worksheets"][0] # same, but no 'worksheets' key at top + without_worksheets = with_worksheets["worksheets"][0] # same, but no 'worksheets' key nb_with = write_notebook("with_worksheets.ipynb", with_worksheets) nb_without = write_notebook("without_worksheets.ipynb", without_worksheets) @@ -61,15 +72,22 @@ def test_process_notebook_with_worksheets(write_notebook): with pytest.warns(DeprecationWarning, match="Worksheets are deprecated as of IPEP-17."): result_with = process_notebook(nb_with) - # No warnings here + # Should not raise a warning result_without = process_notebook(nb_without) - assert result_with == result_without, "Both notebooks should produce identical content." + assert result_with == result_without, "Content from the single worksheet should match the top-level equivalent." -def test_process_notebook_multiple_worksheets(write_notebook): +def test_process_notebook_multiple_worksheets(write_notebook: WriteNotebookFunc) -> None: """ Test a notebook containing multiple 'worksheets'. + + Given a notebook with two worksheets: + - First with a markdown cell + - Second with a code cell + When `process_notebook` is called, + Then a warning about multiple worksheets should be raised, and the second worksheet's content should appear + in the final output. 
""" multi_worksheets = { "worksheets": [ @@ -78,7 +96,6 @@ def test_process_notebook_multiple_worksheets(write_notebook): ] } - # Single-worksheet version (only the first) single_worksheet = { "worksheets": [ {"cells": [{"cell_type": "markdown", "source": ["# First Worksheet"]}]}, @@ -88,6 +105,7 @@ def test_process_notebook_multiple_worksheets(write_notebook): nb_multi = write_notebook("multiple_worksheets.ipynb", multi_worksheets) nb_single = write_notebook("single_worksheet.ipynb", single_worksheet) + # Expect DeprecationWarning + UserWarning with pytest.warns( DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." ): @@ -96,25 +114,27 @@ def test_process_notebook_multiple_worksheets(write_notebook): ): result_multi = process_notebook(nb_multi) + # Expect DeprecationWarning only with pytest.warns( DeprecationWarning, match="Worksheets are deprecated as of IPEP-17. Consider updating the notebook." ): result_single = process_notebook(nb_single) - # The second worksheet (with code) should have been ignored - assert result_multi != result_single, "The multi-worksheet notebook should have more content." - assert len(result_multi) > len(result_single), "The multi-worksheet notebook should have more content." - assert "# First Worksheet" in result_single, "First worksheet content should be present." - assert "# Second Worksheet" not in result_single, "Second worksheet content should be absent." - assert "# First Worksheet" in result_multi, "First worksheet content should be present." - assert "# Second Worksheet" in result_multi, "Second worksheet content should be present." + assert result_multi != result_single, "Two worksheets should produce more content than one." + assert len(result_multi) > len(result_single), "The multi-worksheet notebook should have extra code content." 
+ assert "# First Worksheet" in result_single + assert "# Second Worksheet" not in result_single + assert "# First Worksheet" in result_multi + assert "# Second Worksheet" in result_multi -def test_process_notebook_code_only(write_notebook): +def test_process_notebook_code_only(write_notebook: WriteNotebookFunc) -> None: """ Test a notebook containing only code cells. - No triple quotes should appear. + Given a notebook with code cells only: + When `process_notebook` is called, + Then no triple quotes should appear in the output. """ notebook_content = { "cells": [ @@ -125,17 +145,18 @@ def test_process_notebook_code_only(write_notebook): nb_path = write_notebook("code_only.ipynb", notebook_content) result = process_notebook(nb_path) - # No triple quotes - assert '"""' not in result + assert '"""' not in result, "No triple quotes expected when there are only code cells." assert "print('Code Cell 1')" in result assert "x = 42" in result -def test_process_notebook_markdown_only(write_notebook): +def test_process_notebook_markdown_only(write_notebook: WriteNotebookFunc) -> None: """ - Test a notebook with 2 markdown cells. + Test a notebook with only markdown cells. - 2 markdown cells => each becomes 1 triple-quoted block => 2 blocks => 4 triple quotes. + Given a notebook with two markdown cells: + When `process_notebook` is called, + Then each markdown cell should become a triple-quoted block (2 blocks => 4 triple quotes total). """ notebook_content = { "cells": [ @@ -146,16 +167,18 @@ def test_process_notebook_markdown_only(write_notebook): nb_path = write_notebook("markdown_only.ipynb", notebook_content) result = process_notebook(nb_path) - assert result.count('"""') == 4, "Two markdown cells => two triple-quoted blocks => 4 triple quotes total." + assert result.count('"""') == 4, "Two markdown cells => 2 blocks => 4 triple quotes total." assert "# Markdown Header" in result assert "Some more markdown." 
in result -def test_process_notebook_raw_only(write_notebook): +def test_process_notebook_raw_only(write_notebook: WriteNotebookFunc) -> None: """ - Test a notebook with 2 raw cells. + Test a notebook with only raw cells. - 2 raw cells => 2 blocks => 4 triple quotes. + Given two raw cells: + When `process_notebook` is called, + Then each raw cell should become a triple-quoted block (2 blocks => 4 triple quotes total). """ notebook_content = { "cells": [ @@ -166,17 +189,18 @@ def test_process_notebook_raw_only(write_notebook): nb_path = write_notebook("raw_only.ipynb", notebook_content) result = process_notebook(nb_path) - # 2 raw cells => 2 triple-quoted blocks => 4 occurrences - assert result.count('"""') == 4 + assert result.count('"""') == 4, "Two raw cells => 2 blocks => 4 triple quotes." assert "Raw content line 1" in result assert "Raw content line 2" in result -def test_process_notebook_empty_cells(write_notebook): +def test_process_notebook_empty_cells(write_notebook: WriteNotebookFunc) -> None: """ - Test that cells with an empty 'source' are skipped entirely. + Test that cells with an empty 'source' are skipped. - 4 cells but 3 are empty => only 1 non-empty cell => 1 triple-quoted block => 2 quotes. + Given a notebook with 4 cells, 3 of which have empty `source`: + When `process_notebook` is called, + Then only the non-empty cell should appear in the output (1 block => 2 triple quotes). 
""" notebook_content = { "cells": [ @@ -189,16 +213,17 @@ def test_process_notebook_empty_cells(write_notebook): nb_path = write_notebook("empty_cells.ipynb", notebook_content) result = process_notebook(nb_path) - # Only one non-empty markdown cell => 1 block => 2 triple quotes - assert result.count('"""') == 2 + assert result.count('"""') == 2, "Only one non-empty cell => 1 block => 2 triple quotes" assert "# Non-empty markdown" in result -def test_process_notebook_invalid_cell_type(write_notebook): +def test_process_notebook_invalid_cell_type(write_notebook: WriteNotebookFunc) -> None: """ Test a notebook with an unknown cell type. - Should raise a ValueError. + Given a notebook cell whose `cell_type` is unrecognized: + When `process_notebook` is called, + Then a ValueError should be raised. """ notebook_content = { "cells": [ @@ -212,11 +237,13 @@ def test_process_notebook_invalid_cell_type(write_notebook): process_notebook(nb_path) -def test_process_notebook_with_output(write_notebook): +def test_process_notebook_with_output(write_notebook: WriteNotebookFunc) -> None: """ - Test a notebook with code cells and outputs. + Test a notebook that has code cells with outputs. - The outputs should be included as comments if `include_output=True`. + Given a code cell and multiple output objects: + When `process_notebook` is called with `include_output=True`, + Then the outputs should be appended as commented lines under the code. """ notebook_content = { "cells": [ @@ -263,5 +290,5 @@ def test_process_notebook_with_output(write_notebook): expected_combined = expected_source + expected_output - assert with_output == expected_combined, "Expected source code and output as comments." - assert without_output == expected_source, "Expected source code only." + assert with_output == expected_combined, "Should include source code and comment-ified output." + assert without_output == expected_source, "Should include only the source code without output." 
diff --git a/tests/test_query_ingestion.py b/tests/test_query_ingestion.py index 09076586..cde8df3f 100644 --- a/tests/test_query_ingestion.py +++ b/tests/test_query_ingestion.py @@ -1,35 +1,57 @@ -""" Tests for the query_ingestion module """ +""" +Tests for the `query_ingestion` module. + +These tests validate directory scanning, file content extraction, notebook handling, and the overall ingestion logic, +including filtering patterns and subpaths. +""" from pathlib import Path from unittest.mock import patch +import pytest + from gitingest.query_ingestion import _extract_files_content, _read_file_content, _scan_directory, run_ingest_query from gitingest.query_parser import ParsedQuery def test_scan_directory(temp_directory: Path, sample_query: ParsedQuery) -> None: + """ + Test `_scan_directory` with default settings. + + Given a populated test directory: + When `_scan_directory` is called, + Then it should return a structured node containing the correct directories and file counts. + """ sample_query.local_path = temp_directory result = _scan_directory(temp_directory, query=sample_query) - if result is None: - assert False, "Result is None" + assert result is not None, "Expected a valid directory node structure" assert result["type"] == "directory" - assert result["file_count"] == 8 # All .txt and .py files - assert result["dir_count"] == 4 # src, src/subdir, dir1, dir2 - assert len(result["children"]) == 5 # file1.txt, file2.py, src, dir1, dir2 + assert result["file_count"] == 8, "Should count all .txt and .py files" + assert result["dir_count"] == 4, "Should include src, src/subdir, dir1, dir2" + assert len(result["children"]) == 5, "Should contain file1.txt, file2.py, src, dir1, dir2" def test_extract_files_content(temp_directory: Path, sample_query: ParsedQuery) -> None: - sample_query.local_path = temp_directory + """ + Test `_extract_files_content` to ensure it gathers contents from scanned nodes. 
+ Given a populated test directory: + When `_extract_files_content` is called with a valid scan result, + Then it should return a list of file info containing the correct filenames and paths. + """ + sample_query.local_path = temp_directory nodes = _scan_directory(temp_directory, query=sample_query) - if nodes is None: - assert False, "Nodes is None" + + assert nodes is not None, "Expected a valid scan result" + files = _extract_files_content(query=sample_query, node=nodes) - assert len(files) == 8 # All .txt and .py files - # Check for presence of key files + assert len(files) == 8, "Should extract all .txt and .py files" + paths = [f["path"] for f in files] + + # Verify presence of key files assert any("file1.txt" in p for p in paths) assert any("subfile1.txt" in p for p in paths) assert any("file2.py" in p for p in paths) @@ -39,124 +61,128 @@ def test_extract_files_content(temp_directory: Path, sample_query: ParsedQuery) assert any("file_dir2.txt" in p for p in paths) -def test_read_file_content_with_notebook(tmp_path: Path): +def test_read_file_content_with_notebook(tmp_path: Path) -> None: + """ + Test `_read_file_content` with a notebook file. + + Given a minimal .ipynb file: + When `_read_file_content` is called, + Then `process_notebook` should be invoked to handle notebook-specific content. + """ notebook_path = tmp_path / "dummy_notebook.ipynb" notebook_path.write_text("{}", encoding="utf-8") # minimal JSON - # Patch the symbol as it is used in query_ingestion with patch("gitingest.query_ingestion.process_notebook") as mock_process: _read_file_content(notebook_path) + mock_process.assert_called_once_with(notebook_path) def test_read_file_content_with_non_notebook(tmp_path: Path): + """ + Test `_read_file_content` with a non-notebook file. + + Given a standard .py file: + When `_read_file_content` is called, + Then `process_notebook` should not be triggered. 
+ """ py_file_path = tmp_path / "dummy_file.py" py_file_path.write_text("print('Hello')", encoding="utf-8") with patch("gitingest.query_ingestion.process_notebook") as mock_process: _read_file_content(py_file_path) + mock_process.assert_not_called() -# Test that when using a ['*.txt'] as include pattern, only .txt files are processed & .py files are excluded def test_include_txt_pattern(temp_directory: Path, sample_query: ParsedQuery) -> None: + """ + Test including only .txt files using a pattern like `*.txt`. + + Given a directory with mixed .txt and .py files: + When `include_patterns` is set to `*.txt`, + Then `_scan_directory` should include only .txt files, excluding .py files. + """ sample_query.local_path = temp_directory sample_query.include_patterns = {"*.txt"} result = _scan_directory(temp_directory, query=sample_query) - assert result is not None, "Result should not be None" + assert result is not None, "Expected a valid directory node structure" files = _extract_files_content(query=sample_query, node=result) file_paths = [f["path"] for f in files] - assert len(files) == 5, "Should have found exactly 5 .txt files" + + assert len(files) == 5, "Should find exactly 5 .txt files" assert all(path.endswith(".txt") for path in file_paths), "Should only include .txt files" expected_files = ["file1.txt", "subfile1.txt", "file_subdir.txt", "file_dir1.txt", "file_dir2.txt"] for expected_file in expected_files: assert any(expected_file in path for path in file_paths), f"Missing expected file: {expected_file}" - assert not any(path.endswith(".py") for path in file_paths), "Should not include .py files" + assert not any(path.endswith(".py") for path in file_paths), "No .py files should be included" def test_include_nonexistent_extension(temp_directory: Path, sample_query: ParsedQuery) -> None: + """ + Test including a nonexistent extension (e.g., `*.query`). 
+ + Given a directory with no files matching `*.query`: + When `_scan_directory` is called with that pattern, + Then no files should be returned in the result. + """ sample_query.local_path = temp_directory - sample_query.include_patterns = {"*.query"} # Is a Non existant extension ? + sample_query.include_patterns = {"*.query"} # Nonexistent extension result = _scan_directory(temp_directory, query=sample_query) - assert result is not None, "Result should not be None" + assert result is not None, "Expected a valid directory node structure" - # Extract the files content & set file limit cap files = _extract_files_content(query=sample_query, node=result) - # Verify no file processed with wrong extension - assert len(files) == 0, "Should not find any files with .qwerty extension" + assert len(files) == 0, "Should not find any files matching *.query" assert result["type"] == "directory" - assert result["file_count"] == 0 + assert result["file_count"] == 0, "No files counted with this pattern" assert result["dir_count"] == 0 assert len(result["children"]) == 0 -# single folder patterns -def test_include_src_star_pattern(temp_directory: Path, sample_query: ParsedQuery) -> None: - """ - Test that when using 'src/*' as include pattern, files under the src directory - are included. - Note: Windows is not supported - test converts Windows paths to Unix-style for validation. +@pytest.mark.parametrize("include_pattern", ["src/*", "src/**", "src*"]) +def test_include_src_patterns(temp_directory: Path, sample_query: ParsedQuery, include_pattern: str) -> None: """ - sample_query.local_path = temp_directory - sample_query.include_patterns = {"src/*"} + Test including files under the `src` directory with various patterns. 
- result = _scan_directory(temp_directory, query=sample_query) - assert result is not None, "Result should not be None" + Given a directory containing `src` with subfiles: + When `include_patterns` is set to `src/*`, `src/**`, or `src*`, + Then `_scan_directory` should include the correct files under `src`. - files = _extract_files_content(query=sample_query, node=result) - # Convert Windows paths to Unix-style for test validation - file_paths = {f["path"].replace("\\", "/") for f in files} - expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} - assert file_paths == expected_paths, "Missing or unexpected files in result" - - -def test_include_src_recursive(temp_directory: Path, sample_query: ParsedQuery) -> None: - """ - Test that when using 'src/**' as include pattern, all files under src - directory are included recursively. - Note: Windows is not supported - test converts Windows paths to Unix-style for validation. + Note: Windows is not supported; paths are converted to Unix-style for validation. 
""" sample_query.local_path = temp_directory - sample_query.include_patterns = {"src/**"} + sample_query.include_patterns = {include_pattern} result = _scan_directory(temp_directory, query=sample_query) - assert result is not None, "Result should not be None" + assert result is not None, "Expected a valid directory node structure" files = _extract_files_content(query=sample_query, node=result) - # Convert Windows paths to Unix-style for test validation - file_paths = {f["path"].replace("\\", "/") for f in files} - expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} - assert file_paths == expected_paths, "Missing or unexpected files in result" - - -def test_include_src_wildcard_prefix(temp_directory: Path, sample_query: ParsedQuery) -> None: - """ - Test that when using 'src*' as include pattern, it matches the src directory - and any paths that start with 'src'. - Note: Windows is not supported - test converts Windows paths to Unix-style for validation. - """ - sample_query.local_path = temp_directory - sample_query.include_patterns = {"src*"} - result = _scan_directory(temp_directory, query=sample_query) - assert result is not None, "Result should not be None" - - files = _extract_files_content(query=sample_query, node=result) - # Convert Windows paths to Unix-style for test validation + # Convert Windows paths to Unix-style file_paths = {f["path"].replace("\\", "/") for f in files} - expected_paths = {"src/subfile1.txt", "src/subfile2.py", "src/subdir/file_subdir.txt", "src/subdir/file_subdir.py"} + + expected_paths = { + "src/subfile1.txt", + "src/subfile2.py", + "src/subdir/file_subdir.txt", + "src/subdir/file_subdir.py", + } assert file_paths == expected_paths, "Missing or unexpected files in result" def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> None: """ - Test the run_ingest_query function to ensure it processes the directory correctly. 
+ Test `run_ingest_query` to ensure it processes the directory and returns expected results. + + Given a directory with .txt and .py files: + When `run_ingest_query` is invoked, + Then it should produce a summary string listing the files analyzed and a combined content string. """ sample_query.local_path = temp_directory sample_query.subpath = "/" @@ -166,6 +192,8 @@ def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> No assert "Repository: test_user/test_repo" in summary assert "Files analyzed: 8" in summary + + # Check presence of key files in the content assert "src/subfile1.txt" in content assert "src/subfile2.py" in content assert "src/subdir/file_subdir.txt" in content @@ -176,7 +204,6 @@ def test_run_ingest_query(temp_directory: Path, sample_query: ParsedQuery) -> No assert "dir2/file_dir2.txt" in content -# multiple patterns -# TODO: test with multiple include patterns: ['*.txt', '*.py'] -# TODO: test with multiple include patterns: ['/src/*', '*.txt'] -# TODO: test with multiple include patterns: ['/src*', '*.txt'] +# TODO: Additional tests: +# - Multiple include patterns, e.g. ["*.txt", "*.py"] or ["/src/*", "*.txt"]. +# - Edge cases with weird file names or deep subdirectory structures. diff --git a/tests/test_repository_clone.py b/tests/test_repository_clone.py index 9ff2736f..de417bea 100644 --- a/tests/test_repository_clone.py +++ b/tests/test_repository_clone.py @@ -1,4 +1,9 @@ -""" Tests for the repository_clone module. """ +""" +Tests for the `repository_clone` module. + +These tests cover various scenarios for cloning repositories, verifying that the appropriate Git commands are invoked +and handling edge cases such as nonexistent URLs, timeouts, redirects, and specific commits or branches. +""" import asyncio from unittest.mock import AsyncMock, patch @@ -12,8 +17,11 @@ @pytest.mark.asyncio async def test_clone_repo_with_commit() -> None: """ - Test the `clone_repo` function when a specific commit hash is provided. 
- Verifies that the repository is cloned and checked out to the specified commit. + Test cloning a repository with a specific commit hash. + + Given a valid URL and a commit hash: + When `clone_repo` is called, + Then the repository should be cloned and checked out at that commit. """ clone_config = CloneConfig( url="https://github.com/user/repo", @@ -27,7 +35,9 @@ async def test_clone_repo_with_commit() -> None: mock_process = AsyncMock() mock_process.communicate.return_value = (b"output", b"error") mock_exec.return_value = mock_process + await clone_repo(clone_config) + mock_check.assert_called_once_with(clone_config.url) assert mock_exec.call_count == 2 # Clone and checkout calls @@ -35,10 +45,18 @@ async def test_clone_repo_with_commit() -> None: @pytest.mark.asyncio async def test_clone_repo_without_commit() -> None: """ - Test the `clone_repo` function when no commit hash is provided. - Verifies that only the repository clone operation is performed. + Test cloning a repository when no commit hash is provided. + + Given a valid URL and no commit hash: + When `clone_repo` is called, + Then only the clone operation should be performed (no checkout). 
""" - query = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", commit=None, branch="main") + query = CloneConfig( + url="https://github.com/user/repo", + local_path="/tmp/repo", + commit=None, + branch="main", + ) with patch("gitingest.repository_clone._check_repo_exists", return_value=True) as mock_check: with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_exec: @@ -47,6 +65,7 @@ async def test_clone_repo_without_commit() -> None: mock_exec.return_value = mock_process await clone_repo(query) + mock_check.assert_called_once_with(query.url) assert mock_exec.call_count == 1 # Only clone call @@ -54,8 +73,11 @@ async def test_clone_repo_without_commit() -> None: @pytest.mark.asyncio async def test_clone_repo_nonexistent_repository() -> None: """ - Test the `clone_repo` function when the repository does not exist. - Verifies that a ValueError is raised with an appropriate error message. + Test cloning a nonexistent repository URL. + + Given an invalid or nonexistent URL: + When `clone_repo` is called, + Then a ValueError should be raised with an appropriate error message. 
""" clone_config = CloneConfig( url="https://github.com/user/nonexistent-repo", @@ -66,41 +88,49 @@ async def test_clone_repo_nonexistent_repository() -> None: with patch("gitingest.repository_clone._check_repo_exists", return_value=False) as mock_check: with pytest.raises(ValueError, match="Repository not found"): await clone_repo(clone_config) + mock_check.assert_called_once_with(clone_config.url) @pytest.mark.asyncio -async def test_check_repo_exists() -> None: +@pytest.mark.parametrize( + "mock_stdout, return_code, expected", + [ + (b"HTTP/1.1 200 OK\n", 0, True), # Existing repo + (b"HTTP/1.1 404 Not Found\n", 0, False), # Non-existing repo + (b"HTTP/1.1 200 OK\n", 1, False), # Failed request + ], +) +async def test_check_repo_exists(mock_stdout: bytes, return_code: int, expected: bool) -> None: """ - Test the `_check_repo_exists` function to verify if a repository exists. - Covers cases for existing repositories, non-existing repositories (404), and failed requests. + Test the `_check_repo_exists` function with different Git HTTP responses. + + Given various stdout lines and return codes: + When `_check_repo_exists` is called, + Then it should correctly indicate whether the repository exists. 
""" url = "https://github.com/user/repo" with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_exec: mock_process = AsyncMock() - mock_process.communicate.return_value = (b"HTTP/1.1 200 OK\n", b"") + # Mock the subprocess output + mock_process.communicate.return_value = (mock_stdout, b"") + mock_process.returncode = return_code mock_exec.return_value = mock_process - # Test existing repository - mock_process.returncode = 0 - assert await _check_repo_exists(url) is True + repo_exists = await _check_repo_exists(url) - # Test non-existing repository (404 response) - mock_process.communicate.return_value = (b"HTTP/1.1 404 Not Found\n", b"") - mock_process.returncode = 0 - assert await _check_repo_exists(url) is False - - # Test failed request - mock_process.returncode = 1 - assert await _check_repo_exists(url) is False + assert repo_exists is expected @pytest.mark.asyncio async def test_clone_repo_invalid_url() -> None: """ - Test the `clone_repo` function when an invalid or empty URL is provided. - Verifies that a ValueError is raised with an appropriate error message. + Test cloning when the URL is invalid or empty. + + Given an empty URL: + When `clone_repo` is called, + Then a ValueError should be raised with an appropriate error message. """ clone_config = CloneConfig( url="", @@ -113,8 +143,11 @@ async def test_clone_repo_invalid_url() -> None: @pytest.mark.asyncio async def test_clone_repo_invalid_local_path() -> None: """ - Test the `clone_repo` function when an invalid or empty local path is provided. - Verifies that a ValueError is raised with an appropriate error message. + Test cloning when the local path is invalid or empty. + + Given an empty local path: + When `clone_repo` is called, + Then a ValueError should be raised with an appropriate error message. 
""" clone_config = CloneConfig( url="https://github.com/user/repo", @@ -127,17 +160,17 @@ async def test_clone_repo_invalid_local_path() -> None: @pytest.mark.asyncio async def test_clone_repo_with_custom_branch() -> None: """ - Test the `clone_repo` function when a custom branch is specified. - Verifies that the repository is cloned with the specified branch using a shallow clone. + Test cloning a repository with a specified custom branch. + + Given a valid URL and a branch: + When `clone_repo` is called, + Then the repository should be cloned shallowly to that branch. """ - clone_config = CloneConfig( - url="https://github.com/user/repo", - local_path="/tmp/repo", - branch="feature-branch", - ) + clone_config = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo", branch="feature-branch") with patch("gitingest.repository_clone._check_repo_exists", return_value=True): with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_exec: await clone_repo(clone_config) + mock_exec.assert_called_once_with( "git", "clone", @@ -153,8 +186,11 @@ async def test_clone_repo_with_custom_branch() -> None: @pytest.mark.asyncio async def test_git_command_failure() -> None: """ - Test the `clone_repo` function when a Git command fails during execution. - Verifies that a RuntimeError is raised with an appropriate error message. + Test cloning when the Git command fails during execution. + + Given a valid URL, but `_run_git_command` raises a RuntimeError: + When `clone_repo` is called, + Then a RuntimeError should be raised with the correct message. """ clone_config = CloneConfig( url="https://github.com/user/repo", @@ -169,16 +205,21 @@ async def test_git_command_failure() -> None: @pytest.mark.asyncio async def test_clone_repo_default_shallow_clone() -> None: """ - Test the `clone_repo` function with default shallow clone behavior. - Verifies that the repository is cloned with `--depth=1` and `--single-branch` options. 
+ Test cloning a repository with the default shallow clone options. + + Given a valid URL and no branch or commit: + When `clone_repo` is called, + Then the repository should be cloned with `--depth=1` and `--single-branch`. """ clone_config = CloneConfig( url="https://github.com/user/repo", local_path="/tmp/repo", ) + with patch("gitingest.repository_clone._check_repo_exists", return_value=True): with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_exec: await clone_repo(clone_config) + mock_exec.assert_called_once_with( "git", "clone", "--depth=1", "--single-branch", clone_config.url, clone_config.local_path ) @@ -187,8 +228,11 @@ async def test_clone_repo_default_shallow_clone() -> None: @pytest.mark.asyncio async def test_clone_repo_commit_without_branch() -> None: """ - Test the `clone_repo` function when a commit hash is provided but no branch is specified. - Verifies that the repository is cloned and checked out to the specified commit. + Test cloning when a commit hash is provided but no branch is specified. + + Given a valid URL and a commit hash (but no branch): + When `clone_repo` is called, + Then the repository should be cloned and checked out at that commit. 
""" clone_config = CloneConfig( url="https://github.com/user/repo", @@ -198,6 +242,7 @@ async def test_clone_repo_commit_without_branch() -> None: with patch("gitingest.repository_clone._check_repo_exists", return_value=True): with patch("gitingest.repository_clone._run_git_command", new_callable=AsyncMock) as mock_exec: await clone_repo(clone_config) + assert mock_exec.call_count == 2 # Clone and checkout calls mock_exec.assert_any_call("git", "clone", "--single-branch", clone_config.url, clone_config.local_path) mock_exec.assert_any_call("git", "-C", clone_config.local_path, "checkout", clone_config.commit) @@ -206,9 +251,11 @@ async def test_clone_repo_commit_without_branch() -> None: @pytest.mark.asyncio async def test_check_repo_exists_with_redirect() -> None: """ - Test the `_check_repo_exists` function when the repository URL returns a redirect response. + Test `_check_repo_exists` when a redirect (302) is returned. - Verifies that the function returns False when a 302 Found response is received. + Given a URL that responds with "302 Found": + When `_check_repo_exists` is called, + Then it should return `False`, indicating the repo is inaccessible. """ url = "https://github.com/user/repo" with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_exec: @@ -217,15 +264,19 @@ async def test_check_repo_exists_with_redirect() -> None: mock_process.returncode = 0 # Simulate successful request mock_exec.return_value = mock_process - assert await _check_repo_exists(url) is False + repo_exists = await _check_repo_exists(url) + + assert repo_exists is False @pytest.mark.asyncio async def test_check_repo_exists_with_permanent_redirect() -> None: """ - Test the `_check_repo_exists` function when the repository URL returns a redirect response. + Test `_check_repo_exists` when a permanent redirect (301) is returned. - Verifies that the function returns True when a 301 Found response is received. 
+ Given a URL that responds with "301 Found": + When `_check_repo_exists` is called, + Then it should return `True`, indicating the repo may exist at the new location. """ url = "https://github.com/user/repo" with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock) as mock_exec: @@ -234,14 +285,19 @@ async def test_check_repo_exists_with_permanent_redirect() -> None: mock_process.returncode = 0 # Simulate successful request mock_exec.return_value = mock_process - assert await _check_repo_exists(url) + repo_exists = await _check_repo_exists(url) + + assert repo_exists @pytest.mark.asyncio async def test_clone_repo_with_timeout() -> None: """ - Test the `clone_repo` function when the cloning process exceeds the timeout limit. - Verifies that an AsyncTimeoutError is raised. + Test cloning a repository when a timeout occurs. + + Given a valid URL, but `_run_git_command` times out: + When `clone_repo` is called, + Then an `AsyncTimeoutError` should be raised to indicate the operation exceeded time limits. 
""" clone_config = CloneConfig(url="https://github.com/user/repo", local_path="/tmp/repo") From 903e6991b0f02e192f63223968fe33a5d237a8cf Mon Sep 17 00:00:00 2001 From: Filip Christiansen <22807962+filipchristiansen@users.noreply.github.com> Date: Tue, 21 Jan 2025 07:15:10 +0100 Subject: [PATCH 2/2] make TEMPLATES lowercase --- src/config.py | 2 +- src/main.py | 4 ++-- src/query_processor.py | 4 ++-- src/routers/dynamic.py | 4 ++-- src/routers/index.py | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/config.py b/src/config.py index 9d9c2113..a262db5e 100644 --- a/src/config.py +++ b/src/config.py @@ -21,4 +21,4 @@ {"name": "ApiAnalytics", "url": "https://github.com/tom-draper/api-analytics"}, ] -TEMPLATES = Jinja2Templates(directory="templates") +templates = Jinja2Templates(directory="templates") diff --git a/src/main.py b/src/main.py index 241e9458..7bfec181 100644 --- a/src/main.py +++ b/src/main.py @@ -10,7 +10,7 @@ from slowapi.errors import RateLimitExceeded from starlette.middleware.trustedhost import TrustedHostMiddleware -from config import TEMPLATES +from config import templates from routers import download, dynamic, index from server_utils import limiter from utils import lifespan, rate_limit_exception_handler @@ -89,7 +89,7 @@ async def api_docs(request: Request) -> HTMLResponse: HTMLResponse A rendered HTML page displaying API documentation. 
""" - return TEMPLATES.TemplateResponse("api.jinja", {"request": request}) + return templates.TemplateResponse("api.jinja", {"request": request}) @app.get("/robots.txt") diff --git a/src/query_processor.py b/src/query_processor.py index 72603592..af7a9079 100644 --- a/src/query_processor.py +++ b/src/query_processor.py @@ -5,7 +5,7 @@ from fastapi import Request from starlette.templating import _TemplateResponse -from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE, TEMPLATES +from config import EXAMPLE_REPOS, MAX_DISPLAY_SIZE, templates from gitingest.query_ingestion import run_ingest_query from gitingest.query_parser import ParsedQuery, parse_query from gitingest.repository_clone import CloneConfig, clone_repo @@ -61,7 +61,7 @@ async def process_query( raise ValueError(f"Invalid pattern type: {pattern_type}") template = "index.jinja" if is_index else "git.jinja" - template_response = partial(TEMPLATES.TemplateResponse, name=template) + template_response = partial(templates.TemplateResponse, name=template) max_file_size = log_slider_to_size(slider_position) context = { diff --git a/src/routers/dynamic.py b/src/routers/dynamic.py index 48e6c080..d711655a 100644 --- a/src/routers/dynamic.py +++ b/src/routers/dynamic.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse -from config import TEMPLATES +from config import templates from query_processor import process_query from server_utils import limiter @@ -31,7 +31,7 @@ async def catch_all(request: Request, full_path: str) -> HTMLResponse: An HTML response containing the rendered template, with the Git URL and other default parameters such as loading state and file size. 
""" - return TEMPLATES.TemplateResponse( + return templates.TemplateResponse( "git.jinja", { "request": request, diff --git a/src/routers/index.py b/src/routers/index.py index b5d2f6c9..ff226130 100644 --- a/src/routers/index.py +++ b/src/routers/index.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Form, Request from fastapi.responses import HTMLResponse -from config import EXAMPLE_REPOS, TEMPLATES +from config import EXAMPLE_REPOS, templates from query_processor import process_query from server_utils import limiter @@ -29,7 +29,7 @@ async def home(request: Request) -> HTMLResponse: An HTML response containing the rendered home page template, with example repositories and other default parameters such as file size. """ - return TEMPLATES.TemplateResponse( + return templates.TemplateResponse( "index.jinja", { "request": request,