Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add script to check imports #19611

Merged
merged 11 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 50 additions & 0 deletions .github/workflows/_test_doc_imports.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
name: test
ccurme marked this conversation as resolved.
Show resolved Hide resolved

env:
POETRY_VERSION: "1.7.1"

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
ccurme marked this conversation as resolved.
Show resolved Hide resolved
name: "check doc imports #${{ matrix.python-version }}"
steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
cache-key: core

- name: Install dependencies
shell: bash
run: poetry install --with test

- name: Install langchain editable
run: |
poetry run pip install -e libs/langchain libs/community libs/experimental
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we install core editable as well?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also check partner packages (and install those here)


- name: Check doc imports
shell: bash
run: |
python docs/scripts/check_imports.py
ccurme marked this conversation as resolved.
Show resolved Hide resolved

- name: Ensure the test did not create any additional files
shell: bash
run: |
set -eu

STATUS="$(git status)"
echo "$STATUS"

# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'
6 changes: 6 additions & 0 deletions .github/workflows/check_diffs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ jobs:
working-directory: ${{ matrix.working-directory }}
secrets: inherit

test_doc_imports:
needs: [ build ]
if: ${{ needs.build.outputs.dirs-to-test != '[]' }}
uses: ./.github/workflows/_test_doc_imports.yml
secrets: inherit

compile-integration-tests:
name: cd ${{ matrix.working-directory }}
needs: [ build ]
Expand Down
129 changes: 129 additions & 0 deletions docs/scripts/check_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import importlib
ccurme marked this conversation as resolved.
Show resolved Hide resolved
import json
import logging
import os
import re
import warnings
from pathlib import Path
from typing import List, Tuple

logger = logging.getLogger(__name__)

DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
import_pattern = re.compile(
r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
)


def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
"""Get (module, import) statements from a single code cell."""
import_statements = []
for line in code_lines:
line = line.strip()
if line.startswith("#") or not line:
continue
# Join lines that end with a backslash
if line.endswith("\\"):
line = line[:-1].rstrip() + " "
continue
matches = import_pattern.findall(line)
for match in matches:
if match[0]: # simple import statement
import_statements.append((match[0], ""))
else: # from ___ import statement
module, items = match[1], match[2]
items_list = items.replace(" ", "").split(",")
for item in items_list:
import_statements.append((module, item))
return import_statements


def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
"""Get (module, import) statements from a Jupyter notebook."""
with open(notebook_path, "r", encoding="utf-8") as file:
notebook = json.load(file)
code_cells = [cell for cell in notebook["cells"] if cell["cell_type"] == "code"]
import_statements = []
for cell in code_cells:
code_lines = cell["source"]
import_statements.extend(_get_imports_from_code_cell(code_lines))
return import_statements


def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""Collect offending import statements."""
offending_imports = []
for module, item in import_statements:
try:
if item:
try:
# submodule
full_module_name = f"{module}.{item}"
importlib.import_module(full_module_name)
except ModuleNotFoundError:
# attribute
try:
imported_module = importlib.import_module(module)
getattr(imported_module, item)
except Exception as e:
ccurme marked this conversation as resolved.
Show resolved Hide resolved
offending_imports.append((module, item))
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why a broad exception?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in specific exceptions I was expecting. These are broad in case I missed something among ImportError, AttributeError, or ModuleNotFoundError.

offending_imports.append((module, item))
else:
importlib.import_module(module)
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this import error?

offending_imports.append((module, item))

return offending_imports


def _is_relevant_import(module: str) -> bool:
"""Check if module is recognized."""
# Ignore things like langchain_{bla}, where bla is unrecognized.
recognized_packages = [
"langchain",
"langchain_core",
"langchain_community",
"langchain_experimental",
"langchain_text_splitters",
]
return module.split(".")[0] in recognized_packages


def _serialize_bad_imports(bad_files: list) -> str:
"""Serialize bad imports to a string."""
bad_imports_str = ""
for file, bad_imports in bad_files:
bad_imports_str += f"File: {file}\n"
for module, item in bad_imports:
bad_imports_str += f" {module}.{item}\n"
return bad_imports_str


def check_notebooks(directory: str) -> list:
"""Check notebooks for broken import statements."""
bad_files = []
for root, _, files in os.walk(directory):
for file in files:
if file.endswith(".ipynb"):
notebook_path = os.path.join(root, file)
import_statements = [
(module, item)
for module, item in _extract_import_statements(notebook_path)
if _is_relevant_import(module)
]
bad_imports = _get_bad_imports(import_statements)
if bad_imports:
bad_files.append(
(
file,
bad_imports,
)
)
return bad_files


if __name__ == "__main__":
bad_files = check_notebooks(DOCS_DIR)
if bad_files:
warnings.warn("Found bad imports:\n" f"{_serialize_bad_imports(bad_files)}")