diff --git a/.github/workflows/modelkit-ci.yml b/.github/workflows/modelkit-ci.yml index 486bcb3fc..a3e7ecba1 100644 --- a/.github/workflows/modelkit-ci.yml +++ b/.github/workflows/modelkit-ci.yml @@ -19,21 +19,23 @@ jobs: fail-fast: false matrix: include: - - group: unit - paths: tests/unit + - group: analyze + paths: tests/unit/analyze - group: models - paths: tests/models tests/loader tests/dataset_tests tests/export + paths: >- + tests/unit/models tests/unit/loader tests/unit/datasets + tests/unit/export - group: optim - paths: tests/optim + paths: tests/unit/optim - group: commands paths: >- - tests/commands tests/config tests/build tests/compiler - tests/session tests/eval + tests/unit/commands tests/unit/config tests/unit/build + tests/unit/compiler tests/unit/session tests/unit/eval - group: remaining paths: >- - tests/core tests/onnx tests/cache tests/utils tests/sysinfo - tests/inspect tests/regression tests/optracing - tests/test_cli.py tests/test_text_classification.py + tests/unit/core tests/unit/onnx tests/unit/cache + tests/unit/utils tests/unit/sysinfo tests/unit/inspect + tests/unit/optracing tests/regression name: test (${{ matrix.group }}) diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..ec43fa6a9 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,37 @@ +# CLAUDE.md + +## Cardinal Rules + +### 1. No Hardcoded Logic + +Never hardcode model architecture names, node/operator names, input/output tensor names, layer naming patterns, or any model-specific logic. All solutions must be universal and architecture-agnostic. + +### 2. Pytest Only + +All testing uses pytest with code-generated results. Never create standalone test scripts, use LLM-generated expectations, or generate test data manually. + +### 3. Mandatory Test Verification + +Run `uv run pytest tests/` after every implementation or test revision. Never assume tests pass without verification. + +### 4. Never Skip Failing Tests + +Investigate root cause and fix the underlying issue. Never use `pytest.mark.skip` or `xfail` to hide failures. Skips are only acceptable for hardware/EP requirements (CUDA, DirectML, AVX). + +## Development Commands + +- **Python**: Always use `uv run` or activate venv first. Never run bare python commands. +- **Temp files**: Use `temp/` folder in project root. +- **Node.js**: Available via fnm. Use `eval "$(fnm env)"` before npm/npx commands. + +## Code Quality + +- Run `uv run ruff check --fix` after revising Python code +- Follow naming rules in [`/docs/naming-convention.md`](/docs/naming-convention.md) (ONNX, EP, QDQ, Op acronym casing) +- Always ask clarifying questions before planning if requirements are ambiguous +- Critically evaluate proposals — challenge design decisions when warranted + +## Git + +- Never add `Co-Authored-By` when doing git commit +- Do not include "Test plan" section in PR descriptions diff --git a/docs/naming-convention.md b/docs/naming-convention.md new file mode 100644 index 000000000..c70d67edf --- /dev/null +++ b/docs/naming-convention.md @@ -0,0 +1,100 @@ +# ModelKit Naming Convention + +This document defines the naming rules for the ModelKit codebase. All new code and refactored code must follow these conventions. + +## 1. Acronyms in Class Names + +Domain acronyms in PascalCase class names **retain their uppercase form**, except for two-letter abbreviations used as generic prefixes. + +### Canonical Acronym Table + +| Acronym | Meaning | Class Casing | Example | +|---------|---------|--------------|---------| +| ONNX | Open Neural Network Exchange | `ONNX` | `ONNXStaticAnalyzer`, `ONNXLoader` | +| EP | Execution Provider | `EP` | `EPChecker`, `EPConfig`, `EPMonitor` | +| QDQ | Quantize-Dequantize | `QDQ` | `QDQParameterConfig`, `QDQGenerator` | +| QNN | Qualcomm Neural Network | `QNN` | `QNNMonitor` | +| Op | Operator (2-letter prefix) | `Op` | `OpUnsupportedError` | +| IO | Input/Output | `IO` | `IOConfigInfo` | +| HF | HuggingFace | `HF` | `HF_MODEL_CLASS_MAPPING` | +| HTP | Hexagon Tensor Processor | `HTP` | (directory/module level) | + +### Why `Op` Not `OP` + +Two-letter acronyms used as **class name prefixes** use PascalCase: + +- `OPUnsupported` reads ambiguously as three tokens (O-P-Unsupported) +- `OpUnsupported` reads clearly as two tokens (Op-Unsupported) +- Consistent with conventions like `Id` vs `ID` + +All-caps is acceptable in **constants** (e.g., `SUPPORTED_OPS`). + +## 2. Module and Package Names + +Follow PEP 8: all lowercase with underscores. + +``` +correct: onnx_op.py, ep_checker.py, qdq_fix.py +wrong: OnnxOp.py, EP_Checker.py +``` + +## 3. Function and Method Names + +Snake_case, lowercase. + +``` +correct: normalize_ep_name(), generate_build_config() +wrong: normalizeEPName(), GenerateBuildConfig() +``` + +## 4. Constants + +UPPER_CASE with underscores. + +``` +correct: SUPPORTED_EPS, EP_ALIASES, DEVICE_TO_DEVICE_TYPE +wrong: supportedEps, ep_aliases +``` + +## 5. Directory Abbreviation Policy + +The codebase uses a mix of abbreviated and full directory names. The established names are frozen — do not rename existing directories for consistency alone. For **new** directories, prefer full names unless the abbreviation is widely recognized in the domain (e.g., `optim`, `eval`, `quant`). + +| Established Abbreviation | Full Form | +|---|---| +| `optim` | optimization | +| `quant` | quantization | +| `eval` | evaluation | +| `sysinfo` | system information | +| `optracing` | operator tracing | + +## 6. Avoid Name Collisions Across Hierarchy + +Do not reuse a parent or sibling package name at a deeper level. When creating new subpackages, verify the name does not already exist elsewhere in the tree. + +Known collisions to be aware of: + +| Name | Locations | Issue | +|---|---|---| +| `winml` | top-level namespace, `modelkit/winml.py`, `models/winml/` | 3-level collision | +| `core` | `modelkit/core/`, `analyze/core/` | same name, different content | +| `models` | `modelkit/models/`, `analyze/models/` | ML models vs data models | +| `utils` | `modelkit/utils/`, `analyze/utils/` | no shared content | +| `pattern` | `modelkit/pattern/`, `analyze/pattern/` | active vs near-empty | +| `inspect` | `modelkit/inspect/` | shadows Python stdlib | + +## 7. Current Violations + +The following classes violate the acronym naming rules and should be renamed: + +| Current | Correct | File | +|---|---|---| +| `OnnxOP` | `ONNXOp` | `src/winml/modelkit/analyze/models/onnx_op.py` | +| `OnnxConfigNotFoundError` | `ONNXConfigNotFoundError` | `src/winml/modelkit/export/io.py` | +| `OnnxModelOutput` | `ONNXModelOutput` | `src/winml/modelkit/export/htp/metadata_builder.py` | +| `EpContextNodeChecker` | `EPContextNodeChecker` | `src/winml/modelkit/analyze/core/node_checkers/ep_context_node_checker.py` | +| `EpPackage` | `EPPackage` | `src/winml/modelkit/sysinfo/software.py` | +| `QdqFixResult` | `QDQFixResult` | `src/winml/modelkit/quant/qdq_fix.py` | +| `OPOptionalInputSupportError` | `OpOptionalInputSupportError` | `src/winml/modelkit/analyze/exceptions.py` | +| `OPLackOfRequiredInformationError` | `OpLackOfRequiredInformationError` | `src/winml/modelkit/analyze/exceptions.py` | +| `OPUnsupportedError` | `OpUnsupportedError` | `src/winml/modelkit/analyze/exceptions.py` | diff --git a/docs/pytest-best-practices.md b/docs/pytest-best-practices.md new file mode 100644 index 000000000..30142d2df --- /dev/null +++ b/docs/pytest-best-practices.md @@ -0,0 +1,2832 @@ +# Complete Pytest Best Practices Guide (2025) + +A comprehensive guide covering all aspects of pytest, from basic usage to advanced patterns and project organization. + +## Table of Contents + +1. [Project Structure & Organization](#project-structure--organization) +2. [Test Discovery & Naming Conventions](#test-discovery--naming-conventions) +3. [Fixtures: The Heart of Pytest](#fixtures-the-heart-of-pytest) +4. [Markers & Test Categorization](#markers--test-categorization) +5. [Parametrization: Data-Driven Testing](#parametrization-data-driven-testing) +6. [Assertions & Error Handling](#assertions--error-handling) +7. [Configuration & Settings](#configuration--settings) +8. [Conftest.py: Shared Test Logic](#conftest-py-shared-test-logic) +9. [Mocking & Monkeypatching](#mocking--monkeypatching) +10. [Database Testing Patterns](#database-testing-patterns) +11. [Performance & Optimization](#performance--optimization) +12. [CI/CD Integration](#cicd-integration) +13. [Plugin Ecosystem](#plugin-ecosystem) +14. [Snapshot & Regression Testing](#snapshot--regression-testing) +15. [Property-Based Testing with Hypothesis](#property-based-testing-with-hypothesis) +16. [Test Asset Generation & Management](#test-asset-generation--management) +17. [Common Patterns & Anti-Patterns](#common-patterns--anti-patterns) +18. [Debugging & Troubleshooting](#debugging--troubleshooting) +19. [Best Practices Checklist](#best-practices-checklist) + +--- + +## Project Structure & Organization + +### Recommended Layout + +``` +project/ +├── src/ # Source code +│ └── myproject/ +│ ├── __init__.py +│ ├── core/ +│ │ ├── __init__.py +│ │ └── engine.py +│ ├── utils/ +│ │ ├── __init__.py +│ │ └── helpers.py +│ └── api/ +│ ├── __init__.py +│ └── endpoints.py +├── tests/ # Test directory +│ ├── __init__.py # Makes tests a package (optional - see note below) +│ ├── conftest.py # Shared fixtures and configuration +│ ├── unit/ # Unit tests +│ │ ├── __init__.py +│ │ ├── test_engine.py +│ │ └── test_helpers.py +│ ├── integration/ # Integration tests +│ │ ├── __init__.py +│ │ └── test_api.py +│ ├── e2e/ # End-to-end tests +│ │ ├── __init__.py +│ │ └── test_workflows.py +│ └── fixtures/ # Shared test data/utilities +│ ├── __init__.py +│ └── test_data.py +├── pyproject.toml # Modern Python project config (preferred) +├── pytest.ini # Legacy pytest configuration (avoid) +├── .coveragerc # Coverage configuration +└── tox.ini # Multiple environment testing +``` + +### Key Principles + +1. **Mirror Source Structure**: Test directory structure should mirror your source code +2. **Separate Test Types**: Keep unit, integration, and e2e tests in separate directories +3. **`__init__.py` in Tests**: Optional - use only when you need to import between test modules (see detailed explanation below) +4. **Centralize Fixtures**: Use `conftest.py` for shared fixtures + +### Should You Use `__init__.py` in Test Directories? + +The use of `__init__.py` in test directories is **optional** and depends on your specific needs: + +#### When to USE `__init__.py` in tests ✅ + +1. **Cross-test imports**: When you need to import helper functions or classes between test modules + ```python + # tests/unit/test_user.py + from tests.helpers.factories import UserFactory # Requires __init__.py + ``` + +2. **Test utilities as a package**: When you have reusable test utilities that need to be imported + ``` + tests/ + ├── __init__.py + ├── helpers/ + │ ├── __init__.py + │ ├── factories.py + │ └── assertions.py + ``` + +3. **Namespace packages**: When you need to avoid naming conflicts with application modules + ```python + # Disambiguates tests.models from myapp.models + from tests.models import TestUser + from myapp.models import User + ``` + +#### When NOT to use `__init__.py` in tests ❌ + +1. **Simple test structures**: Most projects don't need it - pytest discovers tests without it +2. **Import mode conflicts**: Can cause issues with pytest's import mechanisms +3. **Accidental test collection**: May cause pytest to collect non-test files + +#### Best Practice Recommendation + +**Default approach**: Start WITHOUT `__init__.py` in test directories. Only add it when you have a specific need for cross-test imports or test utilities. + +``` +# Recommended minimal structure +tests/ +├── conftest.py # Shared fixtures (no __init__.py needed) +├── unit/ +│ └── test_models.py # Tests work without __init__.py +└── integration/ + └── test_api.py +``` + +#### pytest.ini Configuration for Import Issues + +If you encounter import issues, configure pytest's import mode instead of adding `__init__.py`: + +```ini +# pytest.ini +[pytest] +# Use importlib mode for better import handling +import_mode = importlib + +# Or use prepend mode (default) +import_mode = prepend +``` + +### Alternative Layouts + +#### Tests Outside Application Code (Recommended) +``` +project/ +├── src/myproject/ +└── tests/ +``` + +#### Tests as Part of Application (Less Common) +``` +project/ +└── myproject/ + ├── core/ + │ ├── engine.py + │ └── tests/ + │ └── test_engine.py + └── utils/ + ├── helpers.py + └── tests/ + └── test_helpers.py +``` + +--- + +## Test Discovery & Naming Conventions + +### Default Discovery Rules + +Pytest automatically discovers tests following these patterns: + +- **Test files**: `test_*.py` or `*_test.py` +- **Test classes**: `Test*` (must not have an `__init__` method) +- **Test functions**: `test_*` +- **Test methods**: `test_*` inside `Test*` classes + +### Naming Best Practices + +```python +# ❌ Bad: Unclear test names +def test_1(): + pass + +def test_user(): + pass + +def test_function(): + pass + +# ✅ Good: Descriptive test names +def test_user_creation_with_valid_email(): + """Test that a user can be created with a valid email address.""" + pass + +def test_user_creation_fails_with_duplicate_email(): + """Test that creating a user with an existing email raises an error.""" + pass + +def test_password_reset_sends_email_to_registered_user(): + """Test that password reset email is sent to registered users.""" + pass +``` + +### Test Class Organization + +```python +class TestUserAuthentication: + """Test cases for user authentication functionality.""" + + def test_login_with_valid_credentials_returns_token(self): + """Test successful login returns authentication token.""" + pass + + def test_login_with_invalid_password_returns_401(self): + """Test login with wrong password returns 401 status.""" + pass + + def test_login_with_nonexistent_user_returns_404(self): + """Test login with non-existent user returns 404 status.""" + pass +``` + +### Custom Discovery Configuration + +```ini +# pytest.ini +[pytest] +# Custom patterns for test discovery +python_files = test_*.py check_*.py +python_classes = Test* Check* +python_functions = test_* check_* + +# Ignore specific directories +norecursedirs = .git .tox build dist *.egg +``` + +--- + +## Fixtures: The Heart of Pytest + +### Basic Fixture Concepts + +```python +import pytest + +# Simple fixture +@pytest.fixture +def sample_data(): + """Provide sample data for tests.""" + return {"name": "John", "age": 30} + +# Fixture with teardown +@pytest.fixture +def database_connection(): + """Create database connection and clean up after test.""" + conn = create_connection() + yield conn # This is where the test runs + conn.close() # Teardown happens after test + +# Using fixtures in tests +def test_user_data(sample_data): + assert sample_data["name"] == "John" +``` + +### Fixture Scopes + +```python +# Function scope (default) - run once per test function +@pytest.fixture(scope="function") +def function_resource(): + return expensive_setup() + +# Class scope - run once per test class +@pytest.fixture(scope="class") +def class_resource(): + return expensive_setup() + +# Module scope - run once per module +@pytest.fixture(scope="module") +def module_resource(): + return expensive_setup() + +# Session scope - run once per test session +@pytest.fixture(scope="session") +def session_resource(): + return expensive_setup() + +# Package scope - run once per package +@pytest.fixture(scope="package") +def package_resource(): + return expensive_setup() +``` + +### Advanced Fixture Patterns + +#### Factory Fixtures +```python +@pytest.fixture +def make_user(): + """Factory fixture for creating users.""" + created_users = [] + + def _make_user(name, email=None): + user = User(name=name, email=email or f"{name}@example.com") + created_users.append(user) + return user + + yield _make_user + + # Cleanup all created users + for user in created_users: + user.delete() + +def test_user_interactions(make_user): + alice = make_user("alice") + bob = make_user("bob", "bob@company.com") + assert alice.can_message(bob) +``` + +#### Parametrized Fixtures +```python +@pytest.fixture(params=["sqlite", "postgresql", "mysql"]) +def database(request): + """Test with multiple database backends.""" + return setup_database(request.param) + +def test_query_performance(database): + # This test runs three times, once for each database + result = database.execute("SELECT * FROM users") + assert result.execution_time < 100 # ms +``` + +#### Dynamic Fixture Scope +```python +def determine_scope(fixture_name, config): + """Dynamically determine fixture scope based on config.""" + if config.getoption("--quick", None): + return "session" # Reuse fixtures for speed + return "function" # Fresh fixtures for isolation + +@pytest.fixture(scope=determine_scope) +def api_client(): + return APIClient() +``` + +#### Fixture Dependencies +```python +@pytest.fixture +def config(): + return load_config() + +@pytest.fixture +def database(config): + return Database(config["db_url"]) + +@pytest.fixture +def api_client(config, database): + # Fixtures can depend on other fixtures + return APIClient(config["api_url"], database) +``` + +### Auto-use Fixtures + +```python +@pytest.fixture(autouse=True) +def reset_global_state(): + """Automatically run before each test without explicit request.""" + clear_caches() + reset_singletons() + yield + # Cleanup happens after test + +@pytest.fixture(autouse=True, scope="session") +def configure_test_environment(): + """Set up test environment once for entire session.""" + os.environ["TESTING"] = "true" + configure_logging("debug") +``` + +### Fixture Finalization + +```python +@pytest.fixture +def resource_with_finalizer(request): + """Using request.addfinalizer for cleanup.""" + resource = acquire_resource() + + def cleanup(): + release_resource(resource) + + request.addfinalizer(cleanup) + return resource + +# Equivalent using yield +@pytest.fixture +def resource_with_yield(): + """Using yield for cleanup (preferred).""" + resource = acquire_resource() + yield resource + release_resource(resource) +``` + +--- + +## Markers & Test Categorization + +### Built-in Markers + +```python +import pytest +import sys + +# Skip marker +@pytest.mark.skip(reason="Not implemented yet") +def test_future_feature(): + pass + +# Conditional skip +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+") +def test_pattern_matching(): + match value: + case 1: return "one" + case _: return "other" + +# Expected failure +@pytest.mark.xfail(reason="Known bug #123") +def test_known_issue(): + assert buggy_function() == expected_value + +# Strict xfail - fails if test passes +@pytest.mark.xfail(strict=True, reason="Should be fixed in v2.0") +def test_upcoming_fix(): + assert new_feature() == expected + +# Platform-specific tests +@pytest.mark.skipif(sys.platform != "linux", reason="Linux only test") +def test_linux_specific(): + pass + +# Import skip +def test_optional_dependency(): + numpy = pytest.importorskip("numpy", minversion="1.20.0") + # Test only runs if numpy >= 1.20.0 is available +``` + +### Custom Markers + +```ini +# pytest.ini - Register custom markers +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + smoke: core functionality that must always work + integration: requires external services + unit: fast isolated unit tests + flaky: tests that occasionally fail + requires_db: tests that need database access + requires_network: tests that need network access +``` + +```python +# Using custom markers +@pytest.mark.slow +@pytest.mark.integration +def test_full_workflow(): + """Test complete user workflow with external services.""" + pass + +@pytest.mark.smoke +def test_critical_functionality(): + """Test that must always pass.""" + pass + +# Multiple markers +@pytest.mark.unit +@pytest.mark.smoke +def test_core_logic(): + """Fast unit test for critical functionality.""" + pass +``` + +### Marker Expressions + +```bash +# Run only smoke tests +pytest -m smoke + +# Run all tests except slow ones +pytest -m "not slow" + +# Complex expressions +pytest -m "smoke and not slow" +pytest -m "(unit or integration) and not flaky" + +# List all markers +pytest --markers +``` + +### Applying Markers Dynamically + +```python +# In conftest.py +def pytest_collection_modifyitems(items): + """Dynamically add markers during collection.""" + for item in items: + # Add marker based on test location + if "integration" in str(item.fspath): + item.add_marker(pytest.mark.integration) + + # Add marker based on test name + if "slow" in item.name: + item.add_marker(pytest.mark.slow) +``` + +--- + +## Parametrization: Data-Driven Testing + +### Basic Parametrization + +```python +import pytest + +# Single parameter +@pytest.mark.parametrize("number", [1, 2, 3, 4, 5]) +def test_square(number): + assert number ** 2 == number * number + +# Multiple parameters +@pytest.mark.parametrize("input,expected", [ + (2, 4), + (3, 9), + (4, 16), + (-2, 4), +]) +def test_square_with_expected(input, expected): + assert input ** 2 == expected + +# Using test IDs for better output +@pytest.mark.parametrize("input,expected", [ + (2, 4), + (3, 9), + (-2, 4), +], ids=["positive_2", "positive_3", "negative_2"]) +def test_square_with_ids(input, expected): + assert input ** 2 == expected + +# ID function +def idfn(val): + return f"num_{val}" + +@pytest.mark.parametrize("number", [1, 2, 3], ids=idfn) +def test_with_id_function(number): + assert number > 0 +``` + +### Advanced Parametrization + +```python +# Nested parametrization +@pytest.mark.parametrize("x", [1, 2]) +@pytest.mark.parametrize("y", [10, 20]) +def test_multiplication(x, y): + # Runs 4 times: (1,10), (1,20), (2,10), (2,20) + assert x * y == y * x + +# Parametrize with marks +@pytest.mark.parametrize("test_input,expected", [ + ("3+5", 8), + ("2+4", 6), + pytest.param("6*9", 42, marks=pytest.mark.xfail(reason="Hitchhiker's joke")), + pytest.param("1/0", 0, marks=pytest.mark.skip(reason="Division by zero")), +]) +def test_eval(test_input, expected): + assert eval(test_input) == expected + +# Indirect parametrization (parametrize fixtures) +@pytest.mark.parametrize("db_name", ["sqlite", "postgres"], indirect=True) +def test_database_operations(db_name): + # db_name fixture receives the parameter value + assert db_name.connect() +``` + +### Parametrization Patterns + +```python +# Test class parametrization +@pytest.mark.parametrize("browser", ["chrome", "firefox", "safari"]) +class TestWebApplication: + def test_login(self, browser): + # Each test method runs with each browser + pass + + def test_search(self, browser): + pass + +# Dynamic parametrization +def pytest_generate_tests(metafunc): + """Dynamically parametrize tests.""" + if "dynamic_value" in metafunc.fixturenames: + values = load_test_values_from_file() + metafunc.parametrize("dynamic_value", values) + +# Parametrization from fixtures +@pytest.fixture(params=["admin", "user", "guest"]) +def user_role(request): + return create_user_with_role(request.param) + +def test_permissions(user_role): + # Test runs for each user role + assert user_role.can_access("/dashboard") == user_role.is_admin +``` + +--- + +## Assertions & Error Handling + +### Enhanced Assertions + +```python +# Pytest rewrites assert statements for better output +def test_assertion_introspection(): + data = {"name": "Alice", "items": [1, 2, 3]} + # Pytest shows detailed diff on failure + assert data == {"name": "Bob", "items": [1, 2, 3]} + +# Custom assertion messages +def test_with_message(): + result = complex_calculation() + assert result > 0, f"Expected positive result, got {result}" +``` + +### Exception Testing + +```python +import pytest + +# Basic exception testing +def test_raises_exception(): + with pytest.raises(ValueError): + raise ValueError("Invalid value") + +# Check exception message +def test_exception_message(): + with pytest.raises(ValueError, match="Invalid.*value"): + raise ValueError("Invalid value provided") + +# Access exception info +def test_exception_info(): + with pytest.raises(ValueError) as exc_info: + raise ValueError("test error") + + assert str(exc_info.value) == "test error" + assert exc_info.type == ValueError + +# Test multiple exceptions (ExceptionGroup) +def test_exception_group(): + with pytest.raises(ExceptionGroup) as exc_info: + raise ExceptionGroup("errors", [ + ValueError("error 1"), + TypeError("error 2") + ]) + + assert len(exc_info.value.exceptions) == 2 +``` + +### Warning Testing + +```python +import warnings +import pytest + +def test_warns(): + with pytest.warns(UserWarning): + warnings.warn("This is a warning", UserWarning) + +def test_warns_with_match(): + with pytest.warns(DeprecationWarning, match="deprecated"): + warnings.warn("This function is deprecated", DeprecationWarning) + +def test_no_warnings(): + # Ensure no warnings are raised + with warnings.catch_warnings(): + warnings.simplefilter("error") + clean_function() # Should not raise any warnings +``` + +### Approximate Comparisons + +```python +import pytest + +def test_float_comparison(): + assert 0.1 + 0.2 == pytest.approx(0.3) + +def test_list_approximate(): + assert [0.1 + 0.2, 0.2 + 0.4] == pytest.approx([0.3, 0.6]) + +def test_dict_approximate(): + assert {"a": 0.1 + 0.2} == pytest.approx({"a": 0.3}) + +# Custom tolerance +def test_custom_tolerance(): + assert 1.0001 == pytest.approx(1.0, rel=1e-3) + assert 1.0001 == pytest.approx(1.0, abs=1e-3) +``` + +--- + +## Configuration & Settings + +### Configuration File Priority (Critical Knowledge) + +Understanding configuration file priority is essential for debugging pytest configuration issues. + +**Priority Order** (first match wins - configurations are NEVER merged): + +| Priority | File | Notes | +|----------|------|-------| +| 1 (Highest) | `pytest.toml` / `.pytest.toml` | New in pytest 9.0, native TOML | +| 2 | `pytest.ini` / `.pytest.ini` | Classic pytest config | +| 3 | `pyproject.toml` | Modern Python project standard | +| 4 | `tox.ini` | Tox integration | +| 5 (Lowest) | `setup.cfg` | Legacy, not recommended | + +> ⚠️ **Critical Gotcha**: If an empty `pytest.ini` file exists in your project, ALL settings in `pyproject.toml` will be ignored! This is a common source of confusion. Delete any empty `pytest.ini` files. + +**Configuration Sections by File Type**: + +| File Type | Section Name | +|-----------|--------------| +| pytest.ini | `[pytest]` | +| pyproject.toml (pytest 6.0-8.x) | `[tool.pytest.ini_options]` | +| pyproject.toml (pytest 9.0+) | `[tool.pytest]` | +| tox.ini | `[pytest]` | +| setup.cfg | `[tool:pytest]` | + +**Best Practice**: Use `pyproject.toml` as your single source of truth for all Python tooling configuration (pytest, ruff, mypy, etc.). + +### pyproject.toml Configuration (Recommended) + +Using `pyproject.toml` is the modern, preferred approach for Python project configuration. It consolidates all project metadata and tool configurations in one place. + +```toml +# pyproject.toml +[tool.pytest.ini_options] +# Minimum pytest version +minversion = "7.0" + +# Default command line options +addopts = [ + "--strict-markers", # Fail on unknown markers + "--strict-config", # Fail on config errors + "--import-mode=importlib", # Use standard import system (recommended) + "--verbose", # Verbose output + "-ra", # Show all test outcomes + "--cov=myproject", # Coverage for your project + "--cov-report=html", # HTML coverage report + "--cov-report=term-missing", # Terminal report with missing lines +] + +> 💡 **Recommended**: Always include `--import-mode=importlib` in your `addopts`. This uses Python's standard import system instead of modifying `sys.path`, avoiding common import issues. This has been the default since pytest 6.0 but explicitly setting it ensures consistent behavior. + +# Test discovery +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*", "*Tests"] +python_functions = ["test_*"] + +# Python path configuration +pythonpath = ["src"] + +# Import mode (importlib is recommended for most projects) +import_mode = "importlib" + +# Custom markers registration +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: requires external services", + "unit: fast isolated unit tests", + "smoke: core functionality that must always work", + "flaky: tests that occasionally fail", + "requires_network: tests that need network access", +] + +# Output configuration +console_output_style = "progress" + +# Directories to ignore +norecursedirs = [".git", ".tox", "dist", "build", "*.egg", "__pycache__"] + +# Logging configuration +log_cli = true +log_cli_level = "INFO" +log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" + +# Warning filters +filterwarnings = [ + "error", # Turn warnings into errors + "ignore::UserWarning", # Ignore user warnings + "ignore::DeprecationWarning", # Ignore deprecation warnings + "default:.*deprecated.*:DeprecationWarning", # Show deprecation warnings with "deprecated" in message +] + +# Required plugins +required_plugins = [ + "pytest-cov>=4.0", +] + +# Test timeout (requires pytest-timeout) +timeout = 300 +timeout_method = "thread" + +# Strict xfail +xfail_strict = true + +# Asyncio configuration (requires pytest-asyncio) +asyncio_mode = "auto" + +# Coverage configuration (can also be in [tool.coverage]) +[tool.coverage.run] +source = ["myproject"] +omit = [ + "*/tests/*", + "*/venv/*", + "*/.venv/*", + "*/migrations/*", + "*/__pycache__/*", + "*/.pytest_cache/*", +] + +[tool.coverage.report] +precision = 2 +show_missing = true +skip_covered = false +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", +] + +[tool.coverage.html] +directory = "htmlcov" + +[tool.coverage.xml] +output = "coverage.xml" +``` + +### Complete pyproject.toml Example + +Here's a complete `pyproject.toml` that includes project metadata along with pytest configuration: + +```toml +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "myproject" +version = "1.0.0" +description = "My awesome project" +readme = "README.md" +requires-python = ">=3.8" +license = {text = "MIT"} +authors = [ + {name = "Your Name", email = "you@example.com"}, +] +dependencies = [ + "requests>=2.28.0", + "pydantic>=2.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "pytest-mock>=3.10.0", + "pytest-asyncio>=0.21.0", + "pytest-timeout>=2.1.0", + "pytest-xdist>=3.0.0", + "black>=23.0.0", + "ruff>=0.1.0", + "mypy>=1.0.0", +] + +[project.urls] +Homepage = "https://github.com/username/myproject" +Documentation = "https://myproject.readthedocs.io" +Repository = "https://github.com/username/myproject.git" +Issues = "https://github.com/username/myproject/issues" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +# ... (configuration from above) + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] +include = '\.pyi?$' + +[tool.ruff] +line-length = 88 +target-version = "py38" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade +] + +[tool.mypy] +python_version = "3.8" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +``` + +### Migration from pytest.ini to pyproject.toml + +If you have an existing `pytest.ini`, here's how to migrate: + +```ini +# OLD: pytest.ini +[pytest] +markers = + slow: slow tests +testpaths = tests +``` + +Becomes: + +```toml +# NEW: pyproject.toml +[tool.pytest.ini_options] +markers = [ + "slow: slow tests", +] +testpaths = ["tests"] +``` + +### pytest 9.0+ Native TOML Configuration + +Starting with pytest 9.0, you can use the native `[tool.pytest]` table which provides cleaner TOML syntax: + +```toml +# pytest 9.0+ (native TOML arrays - cleaner syntax) +[tool.pytest] +minversion = "9.0" + +# Test discovery +testpaths = ["tests"] +pythonpath = ["."] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +norecursedirs = [".git", ".tox", "dist", "build", ".venv", "__pycache__"] + +# Command line options (native TOML arrays) +addopts = [ + "--strict-markers", + "--strict-config", + "--import-mode=importlib", + "-ra", + "--tb=short", +] + +# Markers +markers = [ + "slow: marks tests as slow", + "integration: integration tests", +] + +# Warning filters +filterwarnings = [ + "error", + "ignore::DeprecationWarning", +] + +# Required plugins +required_plugins = [ + "pytest-cov>=4.0", +] +``` + +**Benefits over `[tool.pytest.ini_options]`**: +- Native TOML array syntax (clearer than space-separated strings in some cases) +- Better TOML type support +- Future-proof configuration format +- Reserved by pytest team for enhanced features + +**Migration**: Simply rename `[tool.pytest.ini_options]` to `[tool.pytest]` when upgrading to pytest 9.0+. + +### Legacy pytest.ini (Not Recommended) + +While `pytest.ini` still works, it's considered legacy. Use `pyproject.toml` instead for these benefits: +- Single configuration file for all Python tools +- Better IDE support +- TOML format is more readable +- Standardized by PEP 518 and PEP 621 + +### Command Line Configuration + +```bash +# Common command line options +pytest -v # Verbose output +pytest -q # Quiet output +pytest -s # No capture, show print statements +pytest -x # Stop on first failure +pytest --maxfail=3 # Stop after 3 failures +pytest -k "user" # Run tests matching "user" +pytest -m "not slow" # Run tests not marked as slow +pytest --lf # Run last failed tests +pytest --ff # Run failed tests first +pytest --tb=short # Short traceback format +pytest --tb=no # No traceback +pytest --setup-show # Show fixture setup/teardown +pytest --fixtures # Show available fixtures +pytest --markers # Show available markers +pytest --collect-only # Only collect tests, don't run +pytest --cache-clear # Clear cache before run +pytest --doctest-modules # Run doctests +pytest --cov=myproject # Coverage report +pytest --cov-report=html # HTML coverage report +pytest --durations=10 # Show 10 slowest tests +pytest --pdb # Drop to debugger on failure +pytest --pdbcls=IPython.terminal.debugger:TerminalPdb # Use IPython debugger +``` + +--- + +## Conftest.py: Shared Test Logic + +### Fixture Sharing + +```python +# tests/conftest.py - Available to all tests +import pytest +import tempfile +from pathlib import Path + +@pytest.fixture(scope="session") +def test_data_dir(): + """Shared test data directory.""" + return Path(__file__).parent / "data" + +@pytest.fixture +def temp_dir(): + """Create temporary directory for test.""" + with tempfile.TemporaryDirectory() as tmp: + yield Path(tmp) + +# tests/unit/conftest.py - Available to unit tests only +@pytest.fixture +def mock_database(): + """Mock database for unit tests.""" + return MockDatabase() + +# tests/integration/conftest.py - Available to integration tests only +@pytest.fixture(scope="module") +def real_database(): + """Real database connection for integration tests.""" + db = Database() + yield db + db.cleanup() +``` + +### Hooks in conftest.py + +```python +# Modify test collection +def pytest_collection_modifyitems(config, items): + """Modify test collection.""" + # Add markers based on test file location + for item in items: + # Add markers based on location + if "integration" in str(item.fspath): + item.add_marker(pytest.mark.integration) + + # Skip tests based on environment + if "requires_gpu" in item.keywords and not has_gpu(): + item.add_marker(pytest.mark.skip(reason="GPU not available")) + +# Custom command line options +def pytest_addoption(parser): + """Add custom command line options.""" + parser.addoption( + "--run-slow", + action="store_true", + default=False, + help="Run slow tests" + ) + parser.addoption( + "--integration", + action="store_true", + default=False, + help="Run integration tests" + ) + +# Configure based on options +def pytest_configure(config): + """Configure pytest based on command line options.""" + if config.getoption("--run-slow"): + config.option.markexpr = "slow" + +# Custom markers registration +def pytest_configure(config): + config.addinivalue_line( + "markers", "slow: marks tests as slow" + ) +``` + +#### Hook Execution Order Control + +Control when your hooks run relative to other plugins: + +```python +@pytest.hookimpl(tryfirst=True) +def pytest_collection_modifyitems(items): + """Execute BEFORE other implementations.""" + # Priority operations here + pass + +@pytest.hookimpl(trylast=True) +def pytest_collection_modifyitems(items): + """Execute AFTER other implementations.""" + # Cleanup or final modifications + pass +``` + +#### Wrapper Hooks (Advanced) + +Wrap other hook implementations for cross-cutting concerns: + +```python +@pytest.hookimpl(wrapper=True) +def pytest_runtest_makereport(item, call): + """Wrap report generation for custom handling.""" + # Code before other hooks run + outcome = yield # Run wrapped hooks + report = outcome.get_result() + + # Code after - modify or log report + if report.when == "call" and report.failed: + # Handle test failure + log_failure(item.nodeid, report.longreprtext) + + return report + +@pytest.hookimpl(wrapper=True, tryfirst=True) +def pytest_runtest_setup(item): + """Wrap setup with timing.""" + start = time.time() + yield # Run actual setup + duration = time.time() - start + item.setup_duration = duration +``` + +#### Storing Data Across Hooks + +Use `item.stash` for type-safe data storage: + +```python +from pytest import StashKey + +# Define typed keys +phase_report_key = StashKey[dict]() +timing_key = StashKey[float]() + +@pytest.hookimpl(wrapper=True) +def pytest_runtest_makereport(item, call): + """Store reports for fixture access.""" + outcome = yield + report = outcome.get_result() + + # Store in stash (type-safe) + item.stash.setdefault(phase_report_key, {})[report.when] = report + return report + +@pytest.fixture +def test_outcome(request): + """Fixture to access test outcome.""" + yield + report = request.node.stash.get(phase_report_key, {}).get("call") + if report and report.failed: + # Handle failure in fixture teardown + pass +``` + +#### Custom Report Sections + +Add extra information to test reports: + +```python +@pytest.hookimpl(tryfirst=True, wrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + report = outcome.get_result() + + # Add custom sections to report + if report.when == "call": + report.sections.append( + ("Custom Info", f"Test: {item.nodeid}\nDuration: {call.duration:.2f}s") + ) + + return report +``` + +### Plugin Registration + +```python +# Register external plugins +pytest_plugins = [ + "myproject.testing.fixtures", + "myproject.testing.helpers", +] + +# Conditional plugin loading +import sys +if sys.platform.startswith("win"): + pytest_plugins.append("myproject.testing.windows") +``` + +--- + +## Mocking & Monkeypatching + +### Using pytest-mock + +```python +# Install: pip install pytest-mock + +def test_with_mock(mocker): + """Using pytest-mock plugin.""" + # Mock a module function + mock_func = mocker.patch("mymodule.function") + mock_func.return_value = 42 + + # Mock an object method + mock_method = mocker.patch.object(MyClass, "method") + mock_method.return_value = "mocked" + + # Spy on a function + spy = mocker.spy(mymodule, "function") + mymodule.function() + spy.assert_called_once() + +# Using side effects +def test_side_effects(mocker): + mock = mocker.patch("mymodule.function") + mock.side_effect = [1, 2, 3] # Returns different values each call + + assert mymodule.function() == 1 + assert mymodule.function() == 2 + assert mymodule.function() == 3 + +# Mock with exceptions +def test_mock_exception(mocker): + mock = mocker.patch("mymodule.function") + mock.side_effect = ValueError("Error!") + + with pytest.raises(ValueError): + mymodule.function() +``` + +### Monkeypatch + +```python +def test_monkeypatch_env(monkeypatch): + """Monkeypatch environment variables.""" + monkeypatch.setenv("API_KEY", "test-key") + monkeypatch.delenv("OLD_VAR", raising=False) + + assert os.environ["API_KEY"] == "test-key" + assert "OLD_VAR" not in os.environ + +def test_monkeypatch_attribute(monkeypatch): + """Monkeypatch object attributes.""" + class MyClass: + value = 10 + + obj = MyClass() + monkeypatch.setattr(obj, "value", 20) + assert obj.value == 20 + +def test_monkeypatch_module(monkeypatch): + """Monkeypatch module functions.""" + import time + + def mock_time(): + return 123456.0 + + monkeypatch.setattr(time, "time", mock_time) + assert time.time() == 123456.0 + +def test_monkeypatch_dict(monkeypatch): + """Monkeypatch dictionary items.""" + config = {"url": "production.com"} + monkeypatch.setitem(config, "url", "test.com") + assert config["url"] == "test.com" +``` + +### Advanced Mocking Patterns + +```python +# Context manager mocking +def test_context_manager(mocker): + mock_cm = mocker.MagicMock() + mock_cm.__enter__.return_value = "resource" + mock_cm.__exit__.return_value = None + + mocker.patch("mymodule.get_resource", return_value=mock_cm) + + with mymodule.get_resource() as resource: + assert resource == "resource" + + mock_cm.__enter__.assert_called_once() + mock_cm.__exit__.assert_called_once() + +# Property mocking +def test_property_mock(mocker): + mock_property = mocker.PropertyMock(return_value=42) + mocker.patch("mymodule.MyClass.my_property", new_callable=mock_property) + + obj = mymodule.MyClass() + assert obj.my_property == 42 + mock_property.assert_called_once() + +# Async mocking +async def test_async_mock(mocker): + mock_async = mocker.AsyncMock(return_value="async result") + mocker.patch("mymodule.async_function", mock_async) + + result = await mymodule.async_function() + assert result == "async result" + mock_async.assert_awaited_once() +``` + +--- + +## Database Testing Patterns + +Testing database interactions requires careful isolation and cleanup strategies. + +### Transaction-Based Isolation + +The most reliable approach is rolling back transactions after each test: + +```python +import pytest + +@pytest.fixture +def db_session(engine): + """Create a transactional test session.""" + connection = engine.connect() + transaction = connection.begin() + session = Session(bind=connection) + + yield session + + session.close() + transaction.rollback() + connection.close() + +def test_user_creation(db_session): + """Test runs in transaction that gets rolled back.""" + user = User(name="test") + db_session.add(user) + db_session.flush() + + assert user.id is not None + # Transaction rolled back - no cleanup needed +``` + +### pytest-django Database Access + +```python +import pytest + +# Mark test to enable database access +@pytest.mark.django_db +def test_user_creation(): + User.objects.create(username="testuser") + assert User.objects.count() == 1 + +# Transaction testing (for testing transaction behavior) +@pytest.mark.django_db(transaction=True) +def test_atomic_operations(): + with transaction.atomic(): + User.objects.create(username="user1") + # Test atomic behavior + +# Multiple database support +@pytest.mark.django_db(databases=["default", "secondary"]) +def test_multi_db(): + User.objects.using("secondary").create(username="remote_user") +``` + +### Database Blocker Pattern + +Control database access at fixture level: + +```python +@pytest.fixture +def setup_data(django_db_blocker): + """Fixture that needs temporary DB access.""" + with django_db_blocker.unblock(): + # Database operations allowed here + User.objects.create(username="fixture_user") + # Database blocked again outside context + +@pytest.fixture +def no_db_fixture(django_db_blocker): + """Ensure no accidental DB access.""" + with django_db_blocker.block(): + yield # DB access will raise error +``` + +### Query Count Assertions + +Prevent N+1 query issues: + +```python +def test_efficient_queries(django_assert_num_queries): + """Assert exact number of queries.""" + with django_assert_num_queries(3): + list(User.objects.all()) + list(Post.objects.all()) + list(Comment.objects.all()) + +def test_max_queries(django_assert_max_num_queries): + """Assert maximum query count.""" + with django_assert_max_num_queries(5): + # Complex operation that should be efficient + process_users() +``` + +### SQLAlchemy Testing Patterns + +```python +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +@pytest.fixture(scope="session") +def engine(): + """Create test database engine.""" + return create_engine("sqlite:///:memory:") + +@pytest.fixture(scope="session") +def tables(engine): + """Create all tables.""" + Base.metadata.create_all(engine) + yield + Base.metadata.drop_all(engine) + +@pytest.fixture +def db_session(engine, tables): + """Create a new database session for each test.""" + connection = engine.connect() + transaction = connection.begin() + Session = sessionmaker(bind=connection) + session = Session() + + yield session + + session.close() + transaction.rollback() + connection.close() +``` + +### Factory Pattern for Test Data + +```python +import pytest +from factory import Factory, Faker, SubFactory + +class UserFactory(Factory): + class Meta: + model = User + + username = Faker("user_name") + email = Faker("email") + +class PostFactory(Factory): + class Meta: + model = Post + + title = Faker("sentence") + author = SubFactory(UserFactory) + +@pytest.fixture +def user_factory(db_session): + """Factory fixture for creating test users.""" + def _create_user(**kwargs): + user = UserFactory.build(**kwargs) + db_session.add(user) + db_session.flush() + return user + return _create_user + +def test_user_posts(user_factory): + author = user_factory(username="author") + post = PostFactory.build(author=author) + assert post.author.username == "author" +``` + +--- + +## Performance & Optimization + +### Parallel Execution with pytest-xdist + +```bash +# Install pytest-xdist +pip install pytest-xdist +``` + +#### Basic Usage + +```bash +pytest -n auto # Use all available CPUs +pytest -n 4 # Use 4 workers +pytest -n logical # Use logical cores (requires psutil) +``` + +#### Distribution Strategies + +Understanding distribution strategies is critical for efficient parallel testing: + +```bash +# Load balancing (default) - distributes tests as workers become available +pytest -n auto --dist load + +# Group by scope - keeps tests sharing fixtures on same worker (RECOMMENDED) +pytest -n auto --dist loadscope + +# Group by file - all tests in a file run on same worker +pytest -n auto --dist loadfile + +# Each test runs on every worker (for environment-specific testing) +pytest -n 2 --dist each +``` + +**When to Use Each Strategy**: + +| Strategy | Use Case | Performance | +|----------|----------|-------------| +| `load` | Independent tests, no shared state | Best parallelization | +| `loadscope` | Tests sharing expensive fixtures | Balanced (recommended default) | +| `loadfile` | File-level isolation needed | Good for integration tests | +| `each` | Multi-environment testing | Multiplies test count | + +#### Grouping Tests with xdist_group Marker + +Force related tests to run on the same worker: + +```python +import pytest + +@pytest.mark.xdist_group(name="database") +def test_create_user(): + """Runs on same worker as other 'database' group tests.""" + db.create_user("alice") + +@pytest.mark.xdist_group(name="database") +def test_query_user(): + """Guaranteed same worker as test_create_user.""" + user = db.get_user("alice") + assert user is not None + +@pytest.mark.xdist_group(name="api") +def test_api_endpoint(): + """Runs on potentially different worker.""" + pass +``` + +#### Session-Scoped Fixtures with Parallel Execution + +Session-scoped fixtures require special handling in parallel execution to avoid race conditions: + +```python +import json +from pathlib import Path +from filelock import FileLock # pip install filelock + +@pytest.fixture(scope="session") +def expensive_shared_data(tmp_path_factory, worker_id): + """Thread-safe session fixture for parallel execution.""" + # Single worker mode - no synchronization needed + if worker_id == "master": + return generate_expensive_data() + + # Multi-worker mode - use file locking + root_tmp = tmp_path_factory.getbasetemp().parent + data_file = root_tmp / "shared_data.json" + lock_file = str(data_file) + ".lock" + + with FileLock(lock_file): + if data_file.is_file(): + # Another worker already created the data + return json.loads(data_file.read_text()) + else: + # First worker creates the data + data = generate_expensive_data() + data_file.write_text(json.dumps(data)) + return data + +@pytest.fixture(scope="session") +def database_url(tmp_path_factory, worker_id): + """Per-worker database for parallel isolation.""" + # Each worker gets its own database + return f"sqlite:///test_db_{worker_id}.sqlite" +``` + +#### Configuration for Parallel Execution + +```toml +# pyproject.toml +[tool.pytest.ini_options] +addopts = [ + "-n", "auto", + "--dist", "loadscope", +] +``` + +> ⚠️ **Warning**: Not all tests are parallelization-safe. Tests that modify global state, shared files, or external services may conflict. Use `xdist_group` or run such tests serially with `-n 0`. + +### Test Duration Analysis + +```python +# Show test durations +pytest --durations=10 # Show 10 slowest tests +pytest --durations=0 # Show all test durations + +# In conftest.py - Custom timing +import time + +@pytest.fixture(autouse=True) +def measure_test_time(request): + start = time.time() + yield + duration = time.time() - start + print(f"\n{request.node.name} took {duration:.2f}s") +``` + +### Caching + +```python +# Using pytest cache +def test_expensive_computation(cache): + # Check cache + result = cache.get("computation_result", None) + if result is None: + # Compute and cache + result = expensive_computation() + cache.set("computation_result", result) + + assert result == expected_value + +# Cache command line +pytest --cache-show # Show cache contents +pytest --cache-clear # Clear cache +``` + +### Fixture Optimization + +```python +# Reuse expensive fixtures with broader scope +@pytest.fixture(scope="session") +def expensive_resource(): + """Create once, use many times.""" + resource = create_expensive_resource() + yield resource + resource.cleanup() + +# Lazy fixture creation +@pytest.fixture +def maybe_expensive(): + """Only created if actually used by test.""" + return ExpensiveObject() + +# Fixture factories for controlled creation +@pytest.fixture +def resource_factory(): + resources = [] + + def _make_resource(**kwargs): + resource = Resource(**kwargs) + resources.append(resource) + return resource + + yield _make_resource + + # Cleanup all at once + for resource in resources: + resource.cleanup() +``` + +--- + +## CI/CD Integration + +### GitHub Actions Example + +```yaml +# .github/workflows/test.yml +name: Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests + run: | + pytest -v --cov=myproject --cov-report=xml + + - name: Upload coverage + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml +``` + +### Test Stages + +```yaml +# Multi-stage testing +stages: + - quick-tests + - full-tests + - integration-tests + +quick-tests: + script: + - pytest -m "unit and not slow" --fail-fast + +full-tests: + script: + - pytest -m "not integration" + +integration-tests: + script: + - pytest -m integration + only: + - main + - merge_requests +``` + +### Coverage Configuration + +```ini +# .coveragerc +[run] +source = myproject +omit = + */tests/* + */venv/* + */migrations/* + */__init__.py + +[report] +precision = 2 +show_missing = True +skip_covered = False + +[html] +directory = htmlcov + +[xml] +output = coverage.xml +``` + +--- + +## Plugin Ecosystem + +### Essential Plugins + +```bash +# Coverage +pip install pytest-cov + +# Parallel execution +pip install pytest-xdist + +# Mocking +pip install pytest-mock + +# Timeout +pip install pytest-timeout + +# HTML reports +pip install pytest-html + +# BDD +pip install pytest-bdd + +# Benchmarking +pip install pytest-benchmark + +# Django +pip install pytest-django + +# Asyncio +pip install pytest-asyncio + +# Flake8 integration +pip install pytest-flake8 + +# Order randomization +pip install pytest-randomly +``` + +### Plugin Usage Examples + +```python +# pytest-timeout +@pytest.mark.timeout(10) # 10 second timeout +def test_slow_operation(): + perform_slow_operation() + +# pytest-benchmark +def test_performance(benchmark): + result = benchmark(my_function, arg1, arg2) + assert result == expected + +# pytest-randomly (randomize test order) +# Just install and it works automatically +# Use --randomly-seed=1234 to reproduce order +``` + +### Async Testing with pytest-asyncio + +#### Installation and Configuration + +```bash +pip install pytest-asyncio +``` + +```toml +# pyproject.toml +[tool.pytest.ini_options] +asyncio_mode = "auto" # Automatically handle async tests +``` + +#### Basic Async Tests + +```python +import pytest + +@pytest.mark.asyncio +async def test_async_function(): + """Test async function.""" + result = await async_operation() + assert result == expected + +@pytest.mark.asyncio +async def test_async_context_manager(): + """Test async context manager.""" + async with AsyncResource() as resource: + result = await resource.fetch() + assert result is not None +``` + +#### Async Fixtures + +```python +@pytest.fixture +async def async_client(): + """Async fixture with proper cleanup.""" + client = await create_async_client() + yield client + await client.close() + +@pytest.fixture(scope="session") +async def async_database(): + """Session-scoped async fixture.""" + db = await Database.connect() + yield db + await db.disconnect() + +@pytest.mark.asyncio +async def test_with_async_fixtures(async_client, async_database): + """Test using async fixtures.""" + result = await async_client.query(async_database) + assert result is not None +``` + +#### Fixture Scopes for Async + +```python +# Function scope (default) - new event loop per test +@pytest.fixture +async def function_resource(): + return await create_resource() + +# Session scope - shared across tests +@pytest.fixture(scope="session") +async def session_resource(): + resource = await expensive_async_setup() + yield resource + await resource.cleanup() +``` + +> ⚠️ **Deprecation Warning**: Sync tests depending on async fixtures will warn in pytest 8.x and error in future versions. Always use `@pytest.mark.asyncio` for tests using async fixtures. + +#### Event Loop Scope (pytest-asyncio 0.21+) + +```python +# Control event loop scope +@pytest.fixture(scope="session") +def event_loop_policy(): + """Use uvloop for faster async.""" + import uvloop + return uvloop.EventLoopPolicy() + +# Or via configuration +# pyproject.toml +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" +``` + +--- + +## Snapshot & Regression Testing + +Snapshot testing captures expected output and compares against future runs. + +### Using syrupy (Recommended) + +```bash +pip install syrupy +``` + +```python +def test_api_response(snapshot): + """Compare API response against snapshot.""" + response = api_client.get("/users/1") + assert response.json() == snapshot + +def test_html_output(snapshot): + """Compare rendered HTML.""" + html = render_template("user_profile.html", user=mock_user) + assert html == snapshot + +def test_complex_object(snapshot): + """Snapshot complex data structures.""" + result = process_data(input_data) + assert result == snapshot +``` + +### Snapshot Management + +```bash +# Update all snapshots (after intentional changes) +pytest --snapshot-update + +# Review snapshot changes interactively +pytest --snapshot-warn-unused + +# CI mode - fail on snapshot mismatch +pytest # Default behavior +``` + +### Custom Snapshot Serializers + +```python +from syrupy.extensions.json import JSONSnapshotExtension + +@pytest.fixture +def snapshot_json(snapshot): + """Use JSON serialization for snapshots.""" + return snapshot.use_extension(JSONSnapshotExtension) + +def test_json_api(snapshot_json): + response = api.get("/data") + assert response.json() == snapshot_json +``` + +### Inline Snapshots + +```python +def test_inline(snapshot): + """Snapshot stored in test file itself.""" + result = calculate_value() + assert result == snapshot(result) # First run creates snapshot +``` + +### Best Practices for Snapshot Testing + +1. **Use for stable outputs**: HTML, JSON responses, serialized objects +2. **Avoid for volatile data**: Timestamps, random IDs, system-specific paths +3. **Review diffs carefully**: Snapshot updates should be intentional +4. **Combine with unit tests**: Snapshots complement, not replace, assertions +5. **Keep snapshots small**: Large snapshots are hard to review + +--- + +## Property-Based Testing with Hypothesis + +Property-based testing generates random inputs to find edge cases. + +### Installation + +```bash +pip install hypothesis +``` + +### Basic Property Tests + +```python +from hypothesis import given, strategies as st + +@given(st.integers()) +def test_integer_properties(x): + """Test properties that should hold for all integers.""" + assert x + 0 == x + assert x * 1 == x + assert x - x == 0 + +@given(st.lists(st.integers())) +def test_sort_is_idempotent(data): + """Sorting twice equals sorting once.""" + assert sorted(data) == sorted(sorted(data)) + +@given(st.lists(st.integers())) +def test_sort_preserves_length(data): + """Sorting doesn't change length.""" + assert len(sorted(data)) == len(data) + +@given(st.text()) +def test_string_roundtrip(s): + """Encoding and decoding returns original.""" + assert s.encode("utf-8").decode("utf-8") == s +``` + +### Combining with pytest Fixtures + +```python +@given(st.integers(min_value=1, max_value=100)) +def test_with_fixture(db_session, quantity): + """Property test with pytest fixture.""" + order = Order(quantity=quantity) + db_session.add(order) + db_session.flush() + + assert order.total == order.price * quantity + +@pytest.mark.parametrize("discount", [0, 10, 25, 50]) +@given(st.integers(min_value=1)) +def test_parametrized_property(discount, price): + """Combine parametrize with hypothesis.""" + discounted = apply_discount(price, discount) + assert discounted <= price +``` + +### Custom Strategies + +```python +from hypothesis import strategies as st + +# Email strategy +emails = st.emails() + +# Custom composite strategy +@st.composite +def user_data(draw): + """Generate valid user data.""" + return { + "username": draw(st.text(min_size=3, max_size=20)), + "email": draw(st.emails()), + "age": draw(st.integers(min_value=18, max_value=120)), + } + +@given(user_data()) +def test_user_creation(data): + user = User(**data) + assert user.is_valid() +``` + +### Controlling Test Generation + +```python +from hypothesis import given, settings, Verbosity + +@given(st.integers()) +@settings( + max_examples=500, # More thorough testing + deadline=1000, # 1 second timeout per example + verbosity=Verbosity.verbose, +) +def test_thorough(x): + assert some_property(x) + +@given(st.integers()) +@settings(max_examples=10) # Quick smoke test +def test_quick(x): + assert basic_property(x) +``` + +### Example Database for Reproducibility + +```python +from hypothesis import given, settings, Phase + +@given(st.integers()) +@settings( + database=None, # Disable example database + phases=[Phase.generate], # Only generate, don't replay +) +def test_stateless(x): + pass +``` + +### Best Practices + +1. **Test properties, not examples**: Focus on invariants that always hold +2. **Keep tests fast**: Each example should be quick +3. **Use `@settings(deadline=None)`** for slow operations +4. **Review failing examples**: Hypothesis shows minimal failing case +5. **Combine with unit tests**: Property tests find edge cases, unit tests verify specific behavior + +--- + +## Test Asset Generation & Management + +Dynamic test asset generation ensures tests are self-contained, reproducible, and independent of external files. This is especially critical for ML/ONNX testing where models must be generated programmatically. + +### Core Principle: Code-Generated Assets + +**CARDINAL RULE**: Never rely on pre-existing files or LLM-generated test data. All test assets must be generated by code during test execution. + +```python +# ❌ BAD: Relying on pre-existing files +def test_model_optimization(): + model = onnx.load("tests/fixtures/bert_model.onnx") # External dependency! + optimized = optimize(model) + assert optimized is not None + +# ✅ GOOD: Generate assets programmatically +def test_model_optimization(simple_model_fixture): + """Model is generated by fixture - no external dependencies.""" + optimized = optimize(simple_model_fixture) + assert optimized is not None +``` + +### Fixture-Based Asset Generation + +#### Session-Scoped Expensive Assets + +For expensive-to-generate assets, use session scope to generate once per test session: + +```python +# conftest.py +import onnx +from onnx import helper, TensorProto +import numpy as np + +@pytest.fixture(scope="session") +def base_model() -> onnx.ModelProto: + """Generate a base ONNX model for testing. + + Session-scoped to avoid regenerating for every test. + """ + # Create input + X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 128]) + + # Create nodes + nodes = [ + helper.make_node("Relu", ["input"], ["relu_out"], name="relu_1"), + helper.make_node("Sigmoid", ["relu_out"], ["output"], name="sigmoid_1"), + ] + + # Create output + Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 128]) + + # Build graph and model + graph = helper.make_graph(nodes, "test_graph", [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + return model +``` + +#### Function-Scoped Mutable Assets + +For assets that tests may modify, use function scope: + +```python +@pytest.fixture(scope="function") +def mutable_model(base_model) -> onnx.ModelProto: + """Create a fresh copy for tests that modify the model.""" + import copy + return copy.deepcopy(base_model) +``` + +### Pattern-Specific Model Generation + +Generate models containing specific patterns for targeted testing: + +```python +# tests/optim/conftest.py + +@pytest.fixture(scope="session") +def gelu_pattern_model() -> onnx.ModelProto: + """Generate model with GELU approximation pattern. + + GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) + This pattern should be detected and fused by GELU fusion optimizers. + """ + X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 768]) + + # Create GELU approximation nodes + nodes = [ + # x³ + helper.make_node("Pow", ["input", "three"], ["x_cubed"], name="pow_1"), + # 0.044715 * x³ + helper.make_node("Mul", ["x_cubed", "coef"], ["scaled_cube"], name="mul_1"), + # x + 0.044715 * x³ + helper.make_node("Add", ["input", "scaled_cube"], ["sum_1"], name="add_1"), + # sqrt(2/π) * (x + 0.044715 * x³) + helper.make_node("Mul", ["sum_1", "sqrt_2_pi"], ["tanh_input"], name="mul_2"), + # tanh(...) + helper.make_node("Tanh", ["tanh_input"], ["tanh_out"], name="tanh_1"), + # 1 + tanh(...) + helper.make_node("Add", ["one", "tanh_out"], ["one_plus_tanh"], name="add_2"), + # 0.5 * x + helper.make_node("Mul", ["half", "input"], ["half_x"], name="mul_3"), + # 0.5 * x * (1 + tanh(...)) + helper.make_node("Mul", ["half_x", "one_plus_tanh"], ["output"], name="mul_4"), + ] + + # Create initializers for constants + initializers = [ + numpy_helper.from_array(np.array([3.0], dtype=np.float32), "three"), + numpy_helper.from_array(np.array([0.044715], dtype=np.float32), "coef"), + numpy_helper.from_array(np.array([0.7978845608], dtype=np.float32), "sqrt_2_pi"), + numpy_helper.from_array(np.array([1.0], dtype=np.float32), "one"), + numpy_helper.from_array(np.array([0.5], dtype=np.float32), "half"), + ] + + Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 768]) + graph = helper.make_graph(nodes, "gelu_pattern", [X], [Y], initializers) + + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +@pytest.fixture(scope="session") +def matmul_add_pattern_model() -> onnx.ModelProto: + """Generate model with MatMul+Add pattern for Gemm fusion testing.""" + X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 512]) + + # Weight and bias initializers + weight = numpy_helper.from_array( + np.random.randn(512, 256).astype(np.float32), "weight" + ) + bias = numpy_helper.from_array( + np.random.randn(256).astype(np.float32), "bias" + ) + + nodes = [ + helper.make_node("MatMul", ["input", "weight"], ["matmul_out"], name="matmul_1"), + helper.make_node("Add", ["matmul_out", "bias"], ["output"], name="add_1"), + ] + + Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 256]) + graph = helper.make_graph(nodes, "matmul_add_pattern", [X], [Y], [weight, bias]) + + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) +``` + +### Multi-Pattern Test Models + +For comprehensive testing, generate models with multiple patterns: + +```python +@pytest.fixture(scope="session") +def all_patterns_model() -> onnx.ModelProto: + """Generate model with ALL optimization patterns for comprehensive testing. + + Patterns included (with prefixes for identification): + - p01_identity_: Identity elimination pattern + - p02_dropout_: Dropout elimination pattern + - p03_reshape_: Reshape fusion pattern + - p04_transpose_: Transpose optimization pattern + - p05_conv_: Conv optimization pattern + - p06_matmuladdrelu_: MatMul+Add+Relu fusion pattern + - p07_attention_: Attention pattern + - p08_biasgelu_: Bias+GELU fusion pattern + - p09_skiplayernorm_: SkipLayerNorm pattern + + Node naming convention: {pattern_prefix}{operation}_{index} + Example: p06_matmuladdrelu_matmul_1 + """ + # Implementation generates all patterns in one model + # Each pattern uses consistent naming for verification + ... +``` + +### Conftest Hierarchy for Asset Sharing + +Organize conftest files hierarchically for proper asset sharing: + +``` +tests/ +├── conftest.py # Root: Core helpers (optimize_at_level, etc.) +├── optim/ +│ ├── conftest.py # Optim-wide: Base model fixtures +│ ├── capabilities/ +│ │ ├── conftest.py # Capability-specific: Pattern models, ORT names +│ │ ├── test_gelu_fusion.py +│ │ └── test_matmul_add.py +│ ├── pipes/ +│ │ ├── conftest.py # Pipe-specific: Pipe configs, mock models +│ │ ├── test_pipe_graph.py +│ │ └── test_pipe_fusion.py +│ └── integration/ +│ ├── conftest.py # Integration: Complex model fixtures +│ └── test_optimizer.py +``` + +#### Root conftest.py - Core Helpers + +```python +# tests/conftest.py +"""Root conftest - Core testing utilities.""" + +import onnx +import onnxruntime as ort +import tempfile +from pathlib import Path + +def optimize_at_level( + model: onnx.ModelProto, + level: int = 2, + disabled_optimizers: list[str] | None = None, +) -> onnx.ModelProto: + """Apply ORT graph optimization at specified level. + + This is the RAW ORT API helper - does NOT use Pipe classes. + Use this in capability tests for isolation testing. + """ + opts = ort.SessionOptions() + opts.graph_optimization_level = ort.GraphOptimizationLevel(level) + + if disabled_optimizers: + for name in disabled_optimizers: + opts.add_session_config_entry( + f"session.disable_specified_optimizers", + ",".join(disabled_optimizers) + ) + + with tempfile.TemporaryDirectory() as tmpdir: + input_path = Path(tmpdir) / "input.onnx" + output_path = Path(tmpdir) / "output.onnx" + + onnx.save(model, str(input_path)) + opts.optimized_model_filepath = str(output_path) + + # Create session to trigger optimization + ort.InferenceSession(str(input_path), opts) + + return onnx.load(str(output_path)) +``` + +#### Domain conftest.py - Shared Fixtures + +```python +# tests/optim/capabilities/conftest.py +"""Capability test fixtures - Pattern-specific models.""" + +import pytest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import onnx + +# Import pattern model generators +from tests.optim.conftest import ( + gelu_pattern_model, + matmul_add_pattern_model, + all_patterns_model, +) + +def get_all_ort_names() -> list[str]: + """Get all registered ORT optimizer names for isolation testing.""" + return [ + "GeluFusionL2", + "BiasGeluFusion", + "MatMulAddFusion", + "LayerNormFusion", + # ... all 49+ ORT optimizer names + ] + +@pytest.fixture(scope="session") +def ort_optimizer_names() -> list[str]: + """Fixture providing all ORT optimizer names.""" + return get_all_ort_names() +``` + +### Asset Verification Helpers + +Create helpers to verify generated assets have expected structure: + +```python +# tests/helpers/model_verification.py + +def count_nodes_by_op(model: onnx.ModelProto, op_type: str) -> int: + """Count nodes of specific operation type.""" + return sum(1 for n in model.graph.node if n.op_type == op_type) + +def count_nodes_by_prefix(model: onnx.ModelProto, prefix: str) -> int: + """Count nodes with name prefix (for pattern identification).""" + return sum(1 for n in model.graph.node if n.name.startswith(prefix)) + +def count_nodes_by_prefix_and_op( + model: onnx.ModelProto, prefix: str, op_type: str +) -> int: + """Count nodes matching both prefix and operation type.""" + return sum( + 1 for n in model.graph.node + if n.name.startswith(prefix) and n.op_type == op_type + ) + +def verify_pattern_exists( + model: onnx.ModelProto, + pattern_prefix: str, + expected_ops: list[str], +) -> bool: + """Verify a pattern exists in the model with expected operations.""" + for op in expected_ops: + if count_nodes_by_prefix_and_op(model, pattern_prefix, op) == 0: + return False + return True +``` + +### Differential Testing with Generated Assets + +Test optimization effects by comparing before/after states: + +```python +def test_gelu_fusion_effectiveness(gelu_pattern_model): + """Test that GELU fusion actually reduces node count.""" + from tests.conftest import optimize_at_level + from tests.helpers.model_verification import count_nodes_by_op + + # Before optimization + before_tanh = count_nodes_by_op(gelu_pattern_model, "Tanh") + before_mul = count_nodes_by_op(gelu_pattern_model, "Mul") + + # Apply optimization with GELU fusion enabled + optimized = optimize_at_level( + gelu_pattern_model, + level=2, + disabled_optimizers=[] # All enabled + ) + + # After optimization - GELU pattern should be fused + after_tanh = count_nodes_by_op(optimized, "Tanh") + after_mul = count_nodes_by_op(optimized, "Mul") + + # Verify fusion occurred + assert after_tanh < before_tanh, "GELU fusion should reduce Tanh nodes" + assert after_mul < before_mul, "GELU fusion should reduce Mul nodes" +``` + +### Best Practices Summary + +1. **Always generate assets in code**: Never rely on external files +2. **Use appropriate fixture scope**: Session for expensive, function for mutable +3. **Name patterns consistently**: Use prefixes for pattern identification +4. **Create verification helpers**: Standardize how you check asset structure +5. **Document pattern structure**: Explain what each generated model contains +6. **Test asset generation**: Verify fixtures produce expected structures +7. **Use conftest hierarchy**: Share assets at appropriate levels +8. **Prefer RAW APIs in unit tests**: Don't couple to higher-level abstractions + +--- + +## Common Patterns & Anti-Patterns + +### Patterns ✅ + +```python +# Good: Descriptive test names +def test_user_registration_sends_welcome_email(): + pass + +# Good: Focused tests +def test_calculate_tax_for_standard_rate(): + income = 50000 + assert calculate_tax(income) == 10000 + +# Good: Using fixtures for setup +@pytest.fixture +def authenticated_client(client, user): + client.login(username=user.username, password="password") + return client + +# Good: Parametrize instead of loops +@pytest.mark.parametrize("value,expected", [ + (1, 1), + (2, 4), + (3, 9), +]) +def test_square(value, expected): + assert value ** 2 == expected + +# Good: Clear test structure (Arrange-Act-Assert) +def test_user_creation(): + # Arrange + data = {"username": "john", "email": "john@example.com"} + + # Act + user = User.create(**data) + + # Assert + assert user.username == "john" + assert user.email == "john@example.com" +``` + +### Anti-Patterns ❌ + +```python +# Bad: Test doing too much +def test_everything(): + user = create_user() + post = create_post(user) + comment = create_comment(post) + assert user.is_active + assert post.author == user + assert comment.post == post + # Too many things tested at once + +# Bad: Modifying global state +def test_with_global_state(): + global CONFIG + CONFIG["debug"] = True # Don't modify globals + assert my_function() == expected + +# Bad: Tests depending on order +def test_first(): + global shared_data + shared_data = setup_data() + +def test_second(): + # Depends on test_first running first + assert shared_data.value == expected + +# Bad: Catching all exceptions +def test_broad_exception(): + try: + risky_operation() + except Exception: # Too broad + pass # Test passes even if unexpected error + +# Bad: No assertion +def test_without_assertion(): + result = my_function() + # No assert - test always passes +``` + +--- + +## Debugging & Troubleshooting + +### Debugging Techniques + +```python +# Drop into debugger on failure +pytest --pdb + +# Drop into IPython debugger +pytest --pdbcls=IPython.terminal.debugger:TerminalPdb + +# Set breakpoint in code +def test_debug(): + value = calculate() + import pdb; pdb.set_trace() # or breakpoint() in Python 3.7+ + assert value == expected + +# Print debugging (use -s flag) +def test_with_print(): + print("Debug info:", value) # Visible with pytest -s + assert value == expected + +# Capture logs +def test_with_logging(caplog): + with caplog.at_level(logging.INFO): + my_function() + assert "Expected message" in caplog.text + +# Detailed failure info +pytest -vv # Very verbose +pytest --tb=short # Short traceback +pytest --tb=line # One line per failure +pytest --tb=no # No traceback +``` + +### Common Issues & Solutions + +```python +# Issue: Import errors +# Solution: Check PYTHONPATH and use --import-mode +pytest --import-mode=importlib + +# Issue: Fixture not found +# Solution: Check scope and conftest.py location +pytest --fixtures # List available fixtures + +# Issue: Tests not discovered +# Solution: Check naming conventions +pytest --collect-only # See what's collected + +# Issue: Flaky tests +# Solution: Use pytest-rerunfailures +pip install pytest-rerunfailures +pytest --reruns 3 --reruns-delay 1 + +# Issue: Test isolation +# Solution: Use fixtures and avoid global state +@pytest.fixture(autouse=True) +def reset_state(): + cleanup_before_test() + yield + cleanup_after_test() +``` + +--- + +--- + +## Deprecations & Migration Guide + +### Deprecated Patterns to Avoid + +Understanding deprecated patterns helps maintain forward compatibility. + +#### Marker Access (Changed in pytest 4.0+) + +```python +# ❌ DEPRECATED - will be removed +marker = item.get_marker("slow") + +# ✅ CURRENT - use these instead +marker = item.get_closest_marker("slow") # Single marker +markers = list(item.iter_markers("slow")) # Multiple markers +``` + +#### Hook Decorators (Changed in pytest 7.0+) + +```python +# ❌ DEPRECATED +@pytest.mark.tryfirst +def pytest_collection_modifyitems(items): + pass + +# ✅ CURRENT +@pytest.hookimpl(tryfirst=True) +def pytest_collection_modifyitems(items): + pass +``` + +#### pytest_namespace Hook (Removed in pytest 8.0) + +```python +# ❌ REMOVED - no longer works +def pytest_namespace(): + return {"my_value": 42} + +# ✅ CURRENT - use pytest_configure instead +def pytest_configure(config): + config.my_value = 42 +``` + +#### Async Fixtures with Sync Tests (Warning in pytest 8.x+) + +```python +# ❌ DEPRECATED - will warn and eventually error +@pytest.fixture +async def async_data(): + return await fetch_data() + +def test_sync(async_data): # Sync test using async fixture + assert async_data is not None + +# ✅ CURRENT - explicit async handling +@pytest.fixture +def async_data(): + import asyncio + return asyncio.run(fetch_data()) + +def test_sync(async_data): + assert async_data is not None + +# OR use async test +@pytest.fixture +async def async_data(): + return await fetch_data() + +@pytest.mark.asyncio +async def test_async(async_data): + assert async_data is not None +``` + +#### yield_fixture Decorator (Removed) + +```python +# ❌ REMOVED +@pytest.yield_fixture +def resource(): + r = acquire() + yield r + release(r) + +# ✅ CURRENT - use regular fixture with yield +@pytest.fixture +def resource(): + r = acquire() + yield r + release(r) +``` + +### Migration Checklist + +When upgrading pytest versions, check for: + +- [ ] Replace `item.get_marker()` with `item.get_closest_marker()` +- [ ] Replace `@pytest.mark.tryfirst/trylast` with `@pytest.hookimpl(tryfirst=True/trylast=True)` +- [ ] Remove any `pytest_namespace` hooks +- [ ] Update async fixtures to use explicit handling +- [ ] Replace `@pytest.yield_fixture` with `@pytest.fixture` +- [ ] Check `--strict-config` passes with your configuration +- [ ] Review `filterwarnings` for any pytest deprecation warnings + +### Version Compatibility Matrix + +| Feature | Minimum Version | Notes | +|---------|-----------------|-------| +| `pyproject.toml` support | pytest 6.0 | `[tool.pytest.ini_options]` | +| Native TOML `[tool.pytest]` | pytest 9.0 | Cleaner syntax | +| `--import-mode=importlib` | pytest 6.0 | Recommended default | +| `@pytest.hookimpl` | pytest 7.0 | Replaces mark decorators | +| `item.iter_markers()` | pytest 4.0 | Replaces `get_marker()` | +| `required_plugins` | pytest 7.0 | With `--strict-config` | + +## Best Practices Checklist + +### ✅ DO's + +1. **Write descriptive test names** that explain what is being tested +2. **Use fixtures** for setup and teardown +3. **Keep tests focused** - one concept per test +4. **Use parametrize** for data-driven tests +5. **Organize tests** to mirror source code structure +6. **Register custom markers** in pytest.ini +7. **Use appropriate scopes** for fixtures +8. **Mock external dependencies** in unit tests +9. **Run fastest tests first** in CI/CD +10. **Use pytest.raises** for exception testing +11. **Document complex test scenarios** +12. **Use tmp_path fixture** for file operations +13. **Configure pytest** in pyproject.toml or pytest.ini +14. **Use pytest plugins** to extend functionality +15. **Profile slow tests** and optimize +16. **Start without `__init__.py`** in test directories - add only when needed +17. **Use `--import-mode=importlib`** for modern import handling +18. **Declare `required_plugins`** for team/CI consistency +19. **Use `--strict-config`** to catch configuration errors early +20. **Handle async fixtures properly** with `@pytest.mark.asyncio` +21. **Use file locking** for session fixtures with parallel execution + +### ❌ DON'Ts + +1. **Don't write tests that depend on execution order** +2. **Don't use global state** that affects other tests +3. **Don't catch broad exceptions** without re-raising +4. **Don't hardcode paths** - use fixtures and tmp_path +5. **Don't skip writing tests** for "simple" functions +6. **Don't mix test types** in the same file +7. **Don't use production credentials** in tests +8. **Don't ignore flaky tests** - fix or mark them +9. **Don't write tests without assertions** +10. **Don't duplicate test logic** - use fixtures +11. **Don't test implementation details** - test behavior +12. **Don't use time.sleep** - use proper synchronization +13. **Don't modify source code** for testing - use mocks +14. **Don't run all tests locally** for every change +15. **Don't ignore test warnings** - fix or suppress explicitly +16. **Don't add `__init__.py` to tests by default** - pytest works without it +17. **Don't use deprecated marker access** - use `get_closest_marker()` not `get_marker()` +18. **Don't mix sync tests with async fixtures** - will warn/error in pytest 8+ +19. **Don't ignore configuration file priority** - empty `pytest.ini` blocks `pyproject.toml` +20. **Don't use `@pytest.yield_fixture`** - use `@pytest.fixture` with yield +21. **Don't forget `xdist_group`** when tests must share state in parallel execution + +### Final Recommendations + +1. **Start Simple**: Begin with basic tests and add complexity as needed +2. **Test First**: Consider TDD for complex logic +3. **Continuous Integration**: Run tests automatically on every commit +4. **Code Coverage**: Aim for high coverage but focus on critical paths +5. **Performance**: Monitor and optimize test suite performance +6. **Documentation**: Document complex test scenarios and fixtures +7. **Maintenance**: Regularly update and refactor tests +8. **Team Standards**: Establish and follow team testing conventions + +Remember: Good tests are as important as good code. They provide confidence, documentation, and safety for refactoring. diff --git a/src/winml/modelkit/analyze/core/information_engine.py b/src/winml/modelkit/analyze/core/information_engine.py index bfe4f3c7a..9f9202f9b 100644 --- a/src/winml/modelkit/analyze/core/information_engine.py +++ b/src/winml/modelkit/analyze/core/information_engine.py @@ -460,7 +460,7 @@ def _query_doc_constraints(self, runtime_result: PatternRuntime, pattern_id: str pattern_match = runtime_result.pattern_match - # PatternMatch has matched_node_names (list[OnnxOP]), not matched_nodes + # PatternMatch has matched_node_names (list[ONNXOp]), not matched_nodes if ( not hasattr(pattern_match, "matched_node_names") or not pattern_match.matched_node_names diff --git a/src/winml/modelkit/analyze/core/node_checkers/__init__.py b/src/winml/modelkit/analyze/core/node_checkers/__init__.py index 3b540505d..5c24c5601 100644 --- a/src/winml/modelkit/analyze/core/node_checkers/__init__.py +++ b/src/winml/modelkit/analyze/core/node_checkers/__init__.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from .ep_context_node_checker import EpContextNodeChecker +from .ep_context_node_checker import EPContextNodeChecker -__all__ = ["EpContextNodeChecker"] +__all__ = ["EPContextNodeChecker"] diff --git a/src/winml/modelkit/analyze/core/node_checkers/ep_context_node_checker.py b/src/winml/modelkit/analyze/core/node_checkers/ep_context_node_checker.py index 948f17d1f..28bef64bf 100644 --- a/src/winml/modelkit/analyze/core/node_checkers/ep_context_node_checker.py +++ b/src/winml/modelkit/analyze/core/node_checkers/ep_context_node_checker.py @@ -18,7 +18,7 @@ @NodeCheckerRegistry.register_checker() -class EpContextNodeChecker(NodeChecker): +class EPContextNodeChecker(NodeChecker): """Checker for validating EPContext nodes based on their attributes. This checker applies to EPContext nodes in the com.microsoft domain and diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index 47a42ebb0..4f85b3bad 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -27,9 +27,9 @@ ) from ..exceptions import ( - OPLackOfRequiredInformationError, - OPOptionalInputSupportError, - OPUnsupportedError, + OpLackOfRequiredInformationError, + OpOptionalInputSupportError, + OpUnsupportedError, ) from ..models.runtime_checks import NodeTag, PatternAlternative, PatternRuntime, RuntimeTestResult from ..runtime_checker.ep_checker import EPChecker @@ -470,7 +470,7 @@ def get_query_conditions_for_node( for a in node.attribute: if a is None: - raise OPOptionalInputSupportError( + raise OpOptionalInputSupportError( f"Node {node.op_type} has optional attribute. " f"Expected attribute names: {_format_list_preview(attribute_names)}" ) @@ -484,7 +484,7 @@ def get_query_conditions_for_node( try: runtime_checker_op = get_runtime_checker_op(node.op_type)(schema) except KeyError: - raise OPUnsupportedError(f"Node {node.op_type} is not supported") from None + raise OpUnsupportedError(f"Node {node.op_type} is not supported") from None type_vars = {} # fill missing attrs with default values; set None for optional attrs without defaults @@ -601,7 +601,7 @@ def update_conditions_( conditions[f"{input_name}_value"] = None continue # Required input is missing - this is an error - raise OPOptionalInputSupportError( + raise OpOptionalInputSupportError( f"Node {node.op_type} missing required input {input_name}" ) @@ -649,7 +649,7 @@ def update_conditions_( # Input is provided but valueinfo not found # This commonly happens in quantized models where DequantizeLinear outputs # are not properly captured by shape inference - raise OPLackOfRequiredInformationError( + raise OpLackOfRequiredInformationError( f"Node {node.op_type} (name: " f"{node.name}): Input '{inp_name}' " f"(parameter '{input_name}') not found " @@ -677,7 +677,7 @@ def update_conditions_( # KeyError: missing required property (e.g., 'input_value', 'input_shape') # TypeError: invalid property value (e.g., None when expecting iterable) # IndexError: accessing empty shape/array (e.g., shape[-1] on empty tuple) - raise OPLackOfRequiredInformationError( + raise OpLackOfRequiredInformationError( f"Node {node.op_type} (name: {node.name}): " f"Incomplete model information for " f"derive_properties: {e}" @@ -1348,7 +1348,7 @@ def _check_negative_rules( Tuple of (passed, reason_text). passed is False if the op fails this phase. Raises: - OPOptionalInputSupportError: If a required property is missing from conditions. + OpOptionalInputSupportError: If a required property is missing from conditions. """ if op_neg_rules["all_failed"][phase]: return False, f"The op {node.op_type} is not supported by {phase}, " @@ -1357,7 +1357,7 @@ def _check_negative_rules( reason = "" for k, v in op_neg_rules["negative_rules"][phase].items(): if k not in conditions: - raise OPOptionalInputSupportError( + raise OpOptionalInputSupportError( f"{phase.capitalize()} check for op " f"{node.op_type}: required property " f"'{k}' not found in conditions" @@ -1492,9 +1492,9 @@ def get_pattern_id(is_qdq): dynamic_axis_strict_mode=self.dynamic_axis_strict_mode, ) except ( - OPOptionalInputSupportError, - OPLackOfRequiredInformationError, - OPUnsupportedError, + OpOptionalInputSupportError, + OpLackOfRequiredInformationError, + OpUnsupportedError, ) as e: exception_type = type(e).__name__ logger.error( @@ -1616,7 +1616,7 @@ def get_pattern_id(is_qdq): filter_v[k] = conditions[k] else: avail = _format_list_preview(conditions.keys()) - raise OPOptionalInputSupportError( + raise OpOptionalInputSupportError( f"Match key '{k}' not found " f"in conditions for op " f"{node.op_type} (domain: " @@ -1761,7 +1761,7 @@ def get_pattern_id(is_qdq): alternatives=self.alternatives, pattern_match=pattern_match, ) - except (OPOptionalInputSupportError, OPLackOfRequiredInformationError) as e: + except (OpOptionalInputSupportError, OpLackOfRequiredInformationError) as e: exception_type = type(e).__name__ logger.error( "%s caught for op %s (node: %s): %s", diff --git a/src/winml/modelkit/analyze/exceptions.py b/src/winml/modelkit/analyze/exceptions.py index 156929075..9ee08ed3e 100644 --- a/src/winml/modelkit/analyze/exceptions.py +++ b/src/winml/modelkit/analyze/exceptions.py @@ -5,11 +5,11 @@ """Exceptions for static analyzer.""" -class OPOptionalInputSupportError(Exception): +class OpOptionalInputSupportError(Exception): """Raised when optional attributes or inputs are not supported.""" -class OPLackOfRequiredInformationError(Exception): +class OpLackOfRequiredInformationError(Exception): """Raised when required information (shape, dtype, etc.) is missing from the model. This commonly occurs in: @@ -18,5 +18,6 @@ class OPLackOfRequiredInformationError(Exception): - Models with dynamic/symbolic dimensions """ -class OPUnsupportedError(Exception): + +class OpUnsupportedError(Exception): """Raised when an unsupported operator is encountered.""" diff --git a/src/winml/modelkit/analyze/models/__init__.py b/src/winml/modelkit/analyze/models/__init__.py index 5faa73227..4c4cf8f7a 100644 --- a/src/winml/modelkit/analyze/models/__init__.py +++ b/src/winml/modelkit/analyze/models/__init__.py @@ -10,7 +10,7 @@ from .ihv_type import IHVType from .information import Action, ActionLevel, Information from .onnx_model import ModelTag, ONNXModel -from .onnx_op import OnnxOP +from .onnx_op import ONNXOp from .output import AnalysisOutput, EPSupport, ModelStats, extract_model_stats from .runtime_checks import AlternativeType, RuntimeCheckRule, RuntimeTestResult from .support_level import SupportLevel @@ -28,7 +28,7 @@ "ModelStats", "ModelTag", "ONNXModel", - "OnnxOP", + "ONNXOp", "OperatorPattern", "Pattern", "PatternMatchResult", diff --git a/src/winml/modelkit/analyze/models/onnx_op.py b/src/winml/modelkit/analyze/models/onnx_op.py index cbee8f62f..b42fcbe18 100644 --- a/src/winml/modelkit/analyze/models/onnx_op.py +++ b/src/winml/modelkit/analyze/models/onnx_op.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field -class OnnxOP(BaseModel): +class ONNXOp(BaseModel): """Represents an ONNX operator node for output. Attributes: diff --git a/src/winml/modelkit/analyze/models/output.py b/src/winml/modelkit/analyze/models/output.py index f57facfdc..d1c26e62b 100644 --- a/src/winml/modelkit/analyze/models/output.py +++ b/src/winml/modelkit/analyze/models/output.py @@ -53,7 +53,7 @@ class EPSupport(BaseModel): ..., description=( "Operator classification by support level, " - "the list[str] will contain OnnxOP's display name" + "the list[str] will contain ONNXOp's display name" ), ) information: list[Information] = Field( diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 463fbdd63..00cd4ea32 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -56,7 +56,7 @@ WinMLExportConfig, _resolve_export_config_from_specs, ) -from ..export.io import OnnxConfigNotFoundError +from ..export.io import ONNXConfigNotFoundError from ..loader import resolve_loader_config from ..loader.config import WinMLLoaderConfig from ..optim.config import WinMLOptimizationConfig @@ -534,7 +534,7 @@ class name. Uses torchinfo to discover submodules and infer batch_size=WinMLExportConfig().batch_size, **(shape_config or {}), ) - except OnnxConfigNotFoundError: + except ONNXConfigNotFoundError: logger.info( "Optimum has no OnnxConfig for '%s'; using empty export config", _registry_key, diff --git a/src/winml/modelkit/export/__init__.py b/src/winml/modelkit/export/__init__.py index 251fcd3b2..4041e6eef 100644 --- a/src/winml/modelkit/export/__init__.py +++ b/src/winml/modelkit/export/__init__.py @@ -19,7 +19,7 @@ ) from .io import ( MaxLengthTextInputGenerator, - OnnxConfigNotFoundError, + ONNXConfigNotFoundError, register_onnx_overwrite, resolve_io_specs, ) @@ -32,7 +32,7 @@ __all__ = [ "InputTensorSpec", "MaxLengthTextInputGenerator", - "OnnxConfigNotFoundError", + "ONNXConfigNotFoundError", "OutputTensorSpec", "WinMLExportConfig", "export_onnx", diff --git a/src/winml/modelkit/export/htp/htp_metadata_schema.json b/src/winml/modelkit/export/htp/htp_metadata_schema.json index f74fd6db9..624705392 100644 --- a/src/winml/modelkit/export/htp/htp_metadata_schema.json +++ b/src/winml/modelkit/export/htp/htp_metadata_schema.json @@ -201,7 +201,7 @@ "title": "TracingInfo", "type": "object" }, - "OnnxModelOutput": { + "ONNXModelOutput": { "description": "ONNX model output information", "properties": { "path": { @@ -235,7 +235,7 @@ "size_mb", "opset_version" ], - "title": "OnnxModelOutput", + "title": "ONNXModelOutput", "type": "object" }, "FileInfo": { @@ -431,7 +431,7 @@ "type": "object", "properties": { "onnx_model": { - "$ref": "#/$defs/OnnxModelOutput" + "$ref": "#/$defs/ONNXModelOutput" }, "metadata": { "$ref": "#/$defs/FileInfo" @@ -527,4 +527,4 @@ ], "title": "HTPMetadata", "type": "object" -} \ No newline at end of file +} diff --git a/src/winml/modelkit/export/htp/metadata_builder.py b/src/winml/modelkit/export/htp/metadata_builder.py index bf4aff09a..b7728ee8c 100644 --- a/src/winml/modelkit/export/htp/metadata_builder.py +++ b/src/winml/modelkit/export/htp/metadata_builder.py @@ -83,7 +83,7 @@ class TaggingInfo: @dataclass -class OnnxModelOutput: +class ONNXModelOutput: """ONNX model output information.""" path: str @@ -96,7 +96,7 @@ class OnnxModelOutput: class OutputFiles: """Output file information.""" - onnx_model: OnnxModelOutput + onnx_model: ONNXModelOutput metadata: dict[str, str] = field(default_factory=dict) report: dict[str, str] = field(default_factory=dict) @@ -255,7 +255,7 @@ def with_output_files( ) -> HTPMetadataBuilder: """Set output file information.""" self._output_files = OutputFiles( - onnx_model=OnnxModelOutput( + onnx_model=ONNXModelOutput( path=Path(onnx_path).name, size_mb=onnx_size_mb, opset_version=opset_version, diff --git a/src/winml/modelkit/export/io.py b/src/winml/modelkit/export/io.py index 1b0aa9bd8..0780ba4a7 100644 --- a/src/winml/modelkit/export/io.py +++ b/src/winml/modelkit/export/io.py @@ -52,7 +52,7 @@ logger = logging.getLogger(__name__) -class OnnxConfigNotFoundError(ValueError): +class ONNXConfigNotFoundError(ValueError): """Raised when no OnnxConfig is registered for a model_type/task combination.""" @@ -198,7 +198,7 @@ def _get_onnx_config( library_name=library_name, ) except KeyError as e: - raise OnnxConfigNotFoundError( + raise ONNXConfigNotFoundError( f"No OnnxConfig registered for model_type='{model_type}' with task='{task}'. " f"Ensure the model's ONNX config is registered with TasksManager. " f"Original error: {e}" @@ -288,7 +288,9 @@ def _populate_sequence_length_from_config( shape_kwargs["sequence_length"] = seq_len logger.debug( "Set sequence_length=%d from max_position_embeddings=%d (cap=%d)", - seq_len, max_pos, _MAX_EXPORT_SEQ_LEN, + seq_len, + max_pos, + _MAX_EXPORT_SEQ_LEN, ) diff --git a/src/winml/modelkit/pattern/match.py b/src/winml/modelkit/pattern/match.py index 0d1520b5f..17ba5ac70 100644 --- a/src/winml/modelkit/pattern/match.py +++ b/src/winml/modelkit/pattern/match.py @@ -136,20 +136,20 @@ def matched_nodes(self) -> list[str]: @property def matched_node_names(self): - """Get matched nodes as OnnxOP objects. + """Get matched nodes as ONNXOp objects. - Note: Despite the name, this returns OnnxOP objects, not strings. + Note: Despite the name, this returns ONNXOp objects, not strings. This is for backward compatibility. Use matched_nodes for string names. Returns: - List of OnnxOP instances containing node metadata (when used from analyze). - Falls back to dicts when OnnxOP is not available. + List of ONNXOp instances containing node metadata (when used from analyze). + Falls back to dicts when ONNXOp is not available. """ try: - from winml.modelkit.analyze.models.onnx_op import OnnxOP + from winml.modelkit.analyze.models.onnx_op import ONNXOp return [ - OnnxOP( + ONNXOp( node_name=node.name if node.name else f"{node.op_type}_node", op_type=node.op_type, namespace=node.domain if node.domain else "ai.onnx", diff --git a/src/winml/modelkit/quant/qdq_fix.py b/src/winml/modelkit/quant/qdq_fix.py index 5703cc998..9103b67af 100644 --- a/src/winml/modelkit/quant/qdq_fix.py +++ b/src/winml/modelkit/quant/qdq_fix.py @@ -35,7 +35,7 @@ @dataclass -class QdqFixResult: +class QDQFixResult: """Result of QDQ dtype fix operation.""" inputs_fixed: int = 0 @@ -45,7 +45,7 @@ class QdqFixResult: warnings: list[str] = field(default_factory=list) -def fix_qdq_dtype_info(model: onnx.ModelProto) -> QdqFixResult: +def fix_qdq_dtype_info(model: onnx.ModelProto) -> QDQFixResult: """Fix UNDEFINED dtype on QDQ node scale/zero_point tensors. Modifies the model **in-place**: @@ -61,9 +61,9 @@ def fix_qdq_dtype_info(model: onnx.ModelProto) -> QdqFixResult: model: ONNX ModelProto to fix (modified in-place). Returns: - QdqFixResult with counts of fixes applied. + QDQFixResult with counts of fixes applied. """ - result = QdqFixResult() + result = QDQFixResult() graph = model.graph # Step 1: Collect all scale/zp tensor names from QDQ nodes, diff --git a/src/winml/modelkit/sysinfo/software.py b/src/winml/modelkit/sysinfo/software.py index 4af92fbe8..90e050ddf 100644 --- a/src/winml/modelkit/sysinfo/software.py +++ b/src/winml/modelkit/sysinfo/software.py @@ -195,7 +195,7 @@ def calculate_folder_hash(folder_path: str) -> str | None: return hash_obj.hexdigest() -class EpPackage: +class EPPackage: """Represents an Execution Provider (EP) package.""" class SignatureKind(Enum): @@ -208,7 +208,7 @@ class SignatureKind(Enum): System = 4 # pylint: disable=invalid-name @staticmethod - def get_all() -> list["EpPackage"]: + def get_all() -> list["EPPackage"]: """Get all installed WinML Execution Provider packages.""" packages = [] for package in AppxPackage.get_by_hint("*WinML*EP*"): @@ -216,7 +216,7 @@ def get_all() -> list["EpPackage"]: if re.match(r"^Microsoft\.WindowsMLRuntime\.\d+\.\d+$", name): # skip the deprecated WinML runtime package continue - packages.append(EpPackage(package)) + packages.append(EPPackage(package)) return packages def __init__(self, appx_package: AppxPackage) -> None: @@ -226,7 +226,7 @@ def __init__(self, appx_package: AppxPackage) -> None: self._publisher = appx_package.get_property("Publisher", str) self._architecture = appx_package.get_property("Architecture", int) signature_kind = appx_package.get_property("SignatureKind", int) - self._signature_kind = EpPackage.SignatureKind(signature_kind) + self._signature_kind = EPPackage.SignatureKind(signature_kind) self._install_location = appx_package.get_property("InstallLocation", str) self._status = appx_package.get_property("Status", int) try: diff --git a/src/winml/modelkit/sysinfo/sysinfo.py b/src/winml/modelkit/sysinfo/sysinfo.py index 1112a94fb..f29a505c5 100644 --- a/src/winml/modelkit/sysinfo/sysinfo.py +++ b/src/winml/modelkit/sysinfo/sysinfo.py @@ -5,7 +5,7 @@ import re from .hardware import CPU, GPU, NPU, RAM -from .software import OS, EpPackage, PipPackage, PythonRuntime +from .software import OS, EPPackage, PipPackage, PythonRuntime class WindowsAppRuntimeVersion: @@ -49,7 +49,7 @@ def __init__(self) -> None: self._os = OS.get() self._python_runtime = PythonRuntime.get() self._pip_packages = PipPackage.get_all() - self._ep_packages = EpPackage.get_all() + self._ep_packages = EPPackage.get_all() self._windows_app_runtime_version = WindowsAppRuntimeVersion(self._pip_packages) @property @@ -88,7 +88,7 @@ def pip_packages(self) -> list[PipPackage]: return self._pip_packages @property - def ep_packages(self) -> list[EpPackage]: + def ep_packages(self) -> list[EPPackage]: """List of execution provider packages.""" return self._ep_packages diff --git a/tests/CLAUDE.md b/tests/CLAUDE.md new file mode 100644 index 000000000..da7844b25 --- /dev/null +++ b/tests/CLAUDE.md @@ -0,0 +1,17 @@ +# Test Convention + +Inherits all rules from [`/CLAUDE.md`](/CLAUDE.md). Additional test-specific rules below. + +Reference: [`/docs/pytest-best-practices.md`](/docs/pytest-best-practices.md) + +## Always + +- Place unit tests under `tests/unit//` mirroring `src/winml/modelkit//` +- Place integration tests under `tests/integration/`, e2e under `tests/e2e/` +- Put shared fixtures in the narrowest `conftest.py` that covers all consumers + +## Never + +- Create module directories directly under `tests/` — use `tests/unit//` instead +- Put `test_*.py` files in `assets/`, `fixtures/`, or `mock_data/` — those are helpers only +- Duplicate fixtures across multiple `conftest.py` files diff --git a/tests/optim/assets/fusionpipe/builders/attention.py b/tests/optim/assets/fusionpipe/builders/attention.py deleted file mode 100644 index 0d542820c..000000000 --- a/tests/optim/assets/fusionpipe/builders/attention.py +++ /dev/null @@ -1,553 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Attention pattern builders for FusionPipe testing. - -Creates ONNX graphs that match ORT's attention fusion patterns. -Based on: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/transformers/bert_model_generator.py - -Reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_attention.py -""" - -from __future__ import annotations - -import math - -import numpy as np -from onnx import ModelProto, TensorProto, helper - - -def bert_attention_builder( - input1_name: str, - input2_name: str, - mask_name: str, - output_name: str, - prefix: str, - initializers: list, - hidden_size: int = 16, - num_heads: int = 2, -) -> list: - """Create BERT-style attention pattern matching ORT's expected structure. - - Pattern (from ORT bert_model_generator.py): - input1 + input2 -> Add -> LayerNorm -> Q/K/V projections - Mask: Unsqueeze -> Unsqueeze -> Cast -> Sub -> Mul - QK: MatMul(Q, K^T) -> Div -> Add(mask) -> Softmax - Output: MatMul(attn, V) -> Transpose -> Reshape -> MatMul -> Add - Residual: output + layernorm_out -> Add -> LayerNorm - - This pattern is recognized by FusionAttention class. - - Args: - input1_name: Name of first input tensor [batch, seq, hidden] - input2_name: Name of second input tensor (for skip connection) - mask_name: Name of attention mask tensor [batch, seq] - output_name: Name of output tensor [batch, seq, hidden] - prefix: Unique prefix for node names - initializers: List to append weight tensors - hidden_size: Hidden dimension (default: 16) - num_heads: Number of attention heads (default: 2) - - Returns: - List of ONNX nodes forming the attention pattern - """ - rng = np.random.RandomState(hash(prefix) % (2**32)) - head_size = hidden_size // num_heads - nodes = [] - - # Weights - ln_weight = helper.make_tensor( - f"{prefix}ln_weight", TensorProto.FLOAT, [hidden_size], - rng.randn(hidden_size).astype(np.float32), - ) - ln_bias = helper.make_tensor( - f"{prefix}ln_bias", TensorProto.FLOAT, [hidden_size], - rng.randn(hidden_size).astype(np.float32), - ) - initializers.extend([ln_weight, ln_bias]) - - # Q, K, V projection weights - for proj in ["q", "k", "v"]: - weight = helper.make_tensor( - f"{prefix}{proj}_weight", TensorProto.FLOAT, [hidden_size, hidden_size], - rng.randn(hidden_size, hidden_size).astype(np.float32), - ) - bias = helper.make_tensor( - f"{prefix}{proj}_bias", TensorProto.FLOAT, [hidden_size], - rng.randn(hidden_size).astype(np.float32), - ) - initializers.extend([weight, bias]) - - # Output projection weights - out_weight = helper.make_tensor( - f"{prefix}out_weight", TensorProto.FLOAT, [hidden_size, hidden_size], - rng.randn(hidden_size, hidden_size).astype(np.float32), - ) - out_bias = helper.make_tensor( - f"{prefix}out_bias", TensorProto.FLOAT, [hidden_size], - rng.randn(hidden_size).astype(np.float32), - ) - initializers.extend([out_weight, out_bias]) - - # Reshape constants - reshape_qk = helper.make_tensor( - f"{prefix}reshape_qk", TensorProto.INT64, [4], - np.array([0, 0, num_heads, head_size], dtype=np.int64), - ) - reshape_out = helper.make_tensor( - f"{prefix}reshape_out", TensorProto.INT64, [3], - np.array([0, 0, hidden_size], dtype=np.int64), - ) - initializers.extend([reshape_qk, reshape_out]) - - # Div weight (sqrt(head_size)) - div_weight = helper.make_tensor( - f"{prefix}div_weight", TensorProto.FLOAT, [1], - np.array([math.sqrt(head_size)], dtype=np.float32), - ) - # Mask constants - sub_weight = helper.make_tensor( - f"{prefix}sub_weight", TensorProto.FLOAT, [1], - np.array([1.0], dtype=np.float32), - ) - mul_weight = helper.make_tensor( - f"{prefix}mul_weight", TensorProto.FLOAT, [1], - np.array([-10000.0], dtype=np.float32), - ) - # Unsqueeze axes - axes_1 = helper.make_tensor(f"{prefix}axes_1", TensorProto.INT64, [1], [1]) - axes_2 = helper.make_tensor(f"{prefix}axes_2", TensorProto.INT64, [1], [2]) - initializers.extend([div_weight, sub_weight, mul_weight, axes_1, axes_2]) - - # === NODES === - - # 1. Add + LayerNorm (entry point) - nodes.append(helper.make_node( - "Add", [input1_name, input2_name], [f"{prefix}ln_input"], - name=f"{prefix}add_ln", - )) - nodes.append(helper.make_node( - "LayerNormalization", - [f"{prefix}ln_input", f"{prefix}ln_weight", f"{prefix}ln_bias"], - [f"{prefix}ln_out"], - name=f"{prefix}layernorm", - axis=-1, epsilon=1e-5, - )) - - # 2. Q projection: MatMul -> Add -> Reshape -> Transpose - nodes.append(helper.make_node( - "MatMul", [f"{prefix}ln_out", f"{prefix}q_weight"], [f"{prefix}q_mm"], - name=f"{prefix}matmul_q", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}q_mm", f"{prefix}q_bias"], [f"{prefix}q_add"], - name=f"{prefix}add_q", - )) - nodes.append(helper.make_node( - "Reshape", [f"{prefix}q_add", f"{prefix}reshape_qk"], [f"{prefix}q_reshape"], - name=f"{prefix}reshape_q", - )) - nodes.append(helper.make_node( - "Transpose", [f"{prefix}q_reshape"], [f"{prefix}q_trans"], - name=f"{prefix}transpose_q", perm=[0, 2, 1, 3], - )) - - # 3. K projection: MatMul -> Add -> Reshape -> Transpose (different perm for K^T) - nodes.append(helper.make_node( - "MatMul", [f"{prefix}ln_out", f"{prefix}k_weight"], [f"{prefix}k_mm"], - name=f"{prefix}matmul_k", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}k_mm", f"{prefix}k_bias"], [f"{prefix}k_add"], - name=f"{prefix}add_k", - )) - nodes.append(helper.make_node( - "Reshape", [f"{prefix}k_add", f"{prefix}reshape_qk"], [f"{prefix}k_reshape"], - name=f"{prefix}reshape_k", - )) - nodes.append(helper.make_node( - "Transpose", [f"{prefix}k_reshape"], [f"{prefix}k_trans"], - name=f"{prefix}transpose_k", perm=[0, 2, 3, 1], # K^T - )) - - # 4. V projection: MatMul -> Add -> Reshape -> Transpose - nodes.append(helper.make_node( - "MatMul", [f"{prefix}ln_out", f"{prefix}v_weight"], [f"{prefix}v_mm"], - name=f"{prefix}matmul_v", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}v_mm", f"{prefix}v_bias"], [f"{prefix}v_add"], - name=f"{prefix}add_v", - )) - nodes.append(helper.make_node( - "Reshape", [f"{prefix}v_add", f"{prefix}reshape_qk"], [f"{prefix}v_reshape"], - name=f"{prefix}reshape_v", - )) - nodes.append(helper.make_node( - "Transpose", [f"{prefix}v_reshape"], [f"{prefix}v_trans"], - name=f"{prefix}transpose_v", perm=[0, 2, 1, 3], - )) - - # 5. Mask processing: Unsqueeze -> Unsqueeze -> Cast -> Sub -> Mul - nodes.append(helper.make_node( - "Unsqueeze", [mask_name, f"{prefix}axes_1"], [f"{prefix}mask_unsq1"], - name=f"{prefix}unsqueeze1", - )) - nodes.append(helper.make_node( - "Unsqueeze", [f"{prefix}mask_unsq1", f"{prefix}axes_2"], [f"{prefix}mask_unsq2"], - name=f"{prefix}unsqueeze2", - )) - nodes.append(helper.make_node( - "Cast", [f"{prefix}mask_unsq2"], [f"{prefix}mask_cast"], - name=f"{prefix}cast_mask", to=TensorProto.FLOAT, - )) - nodes.append(helper.make_node( - "Sub", [f"{prefix}sub_weight", f"{prefix}mask_cast"], [f"{prefix}mask_sub"], - name=f"{prefix}sub_mask", - )) - nodes.append(helper.make_node( - "Mul", [f"{prefix}mask_sub", f"{prefix}mul_weight"], [f"{prefix}mask_out"], - name=f"{prefix}mul_mask", - )) - - # 6. QK attention: MatMul(Q, K^T) -> Div -> Add(mask) -> Softmax - nodes.append(helper.make_node( - "MatMul", [f"{prefix}q_trans", f"{prefix}k_trans"], [f"{prefix}qk_mm"], - name=f"{prefix}matmul_qk", - )) - nodes.append(helper.make_node( - "Div", [f"{prefix}qk_mm", f"{prefix}div_weight"], [f"{prefix}qk_div"], - name=f"{prefix}div_qk", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}qk_div", f"{prefix}mask_out"], [f"{prefix}qk_add"], - name=f"{prefix}add_qk", - )) - nodes.append(helper.make_node( - "Softmax", [f"{prefix}qk_add"], [f"{prefix}attn_weights"], - name=f"{prefix}softmax", axis=3, - )) - - # 7. Attention @ V: MatMul -> Transpose -> Reshape - nodes.append(helper.make_node( - "MatMul", [f"{prefix}attn_weights", f"{prefix}v_trans"], [f"{prefix}attn_v"], - name=f"{prefix}matmul_attn_v", - )) - nodes.append(helper.make_node( - "Transpose", [f"{prefix}attn_v"], [f"{prefix}attn_trans"], - name=f"{prefix}transpose_attn", perm=[0, 2, 1, 3], - )) - nodes.append(helper.make_node( - "Reshape", [f"{prefix}attn_trans", f"{prefix}reshape_out"], [f"{prefix}attn_reshape"], - name=f"{prefix}reshape_attn", - )) - - # 8. Output projection: MatMul -> Add - nodes.append(helper.make_node( - "MatMul", [f"{prefix}attn_reshape", f"{prefix}out_weight"], [f"{prefix}out_mm"], - name=f"{prefix}matmul_out", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}out_mm", f"{prefix}out_bias"], [f"{prefix}out_add"], - name=f"{prefix}add_out", - )) - - # 9. Residual + Final LayerNorm: Add(output, ln_out) -> LayerNorm - nodes.append(helper.make_node( - "Add", [f"{prefix}out_add", f"{prefix}ln_out"], [f"{prefix}skip_out"], - name=f"{prefix}add_skip", - )) - nodes.append(helper.make_node( - "LayerNormalization", - [f"{prefix}skip_out", f"{prefix}ln_weight", f"{prefix}ln_bias"], - [output_name], - name=f"{prefix}layernorm2", - axis=-1, epsilon=1e-5, - )) - - return nodes - - -def create_bert_attention_model( - hidden_size: int = 16, - num_heads: int = 2, - seq_len: int = 10, - batch_size: int = 1, -) -> ModelProto: - """Create complete ONNX model with BERT attention pattern. - - This model matches ORT's bert_model_generator.py structure and should - be fusible by FusionAttention. - - Args: - hidden_size: Hidden dimension (default: 16) - num_heads: Number of attention heads (default: 2) - seq_len: Sequence length (default: 10) - batch_size: Batch size (default: 1) - - Returns: - Complete ONNX ModelProto ready for fusion testing - """ - initializers: list = [] - nodes = bert_attention_builder( - input1_name="input_1", - input2_name="input_2", - mask_name="attention_mask", - output_name="output", - prefix="attn_", - initializers=initializers, - hidden_size=hidden_size, - num_heads=num_heads, - ) - - # Inputs - input1 = helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] - ) - input2 = helper.make_tensor_value_info( - "input_2", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] - ) - mask = helper.make_tensor_value_info( - "attention_mask", TensorProto.INT64, [batch_size, seq_len] - ) - output = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] - ) - - graph = helper.make_graph( - nodes, - "bert_attention_test", - [input1, input2, mask], - [output], - initializers, - ) - - model = helper.make_model( - graph, - opset_imports=[helper.make_opsetid("", 17)], # opset 17 for LayerNormalization - ) - model.ir_version = 8 - - return model - - -def gpt2_attention_builder( - input_name: str, - output_name: str, - prefix: str, - initializers: list, - hidden_size: int = 16, - num_heads: int = 2, - seq_len: int = 3, -) -> list: - """Create GPT-2 style causal attention pattern. - - Note: GPT-2 attention has a different structure than BERT. - This is a simplified version for testing. - - Args: - input_name: Name of input tensor [batch, seq, hidden] - output_name: Name of output tensor [batch, seq, hidden] - prefix: Unique prefix for node names - initializers: List to append weight tensors - hidden_size: Hidden dimension (default: 16) - num_heads: Number of attention heads (default: 2) - seq_len: Sequence length (default: 3) - - Returns: - List of ONNX nodes forming the GPT-2 attention pattern - """ - rng = np.random.RandomState(hash(prefix) % (2**32)) - head_size = hidden_size // num_heads - nodes = [] - - # Combined QKV projection (GPT-2 style) - qkv_weight = helper.make_tensor( - f"{prefix}qkv_weight", TensorProto.FLOAT, [hidden_size, 3 * hidden_size], - rng.randn(hidden_size, 3 * hidden_size).astype(np.float32), - ) - qkv_bias = helper.make_tensor( - f"{prefix}qkv_bias", TensorProto.FLOAT, [3 * hidden_size], - rng.randn(3 * hidden_size).astype(np.float32), - ) - initializers.extend([qkv_weight, qkv_bias]) - - # Split sizes for QKV - split_sizes = helper.make_tensor( - f"{prefix}split_sizes", TensorProto.INT64, [3], - np.array([hidden_size, hidden_size, hidden_size], dtype=np.int64), - ) - initializers.append(split_sizes) - - # Reshape and other constants - reshape_shape = helper.make_tensor( - f"{prefix}reshape_shape", TensorProto.INT64, [4], - np.array([0, 0, num_heads, head_size], dtype=np.int64), - ) - reshape_back = helper.make_tensor( - f"{prefix}reshape_back", TensorProto.INT64, [3], - np.array([0, 0, hidden_size], dtype=np.int64), - ) - scale = helper.make_tensor( - f"{prefix}scale", TensorProto.FLOAT, [], - [np.float32(np.sqrt(head_size))], - ) - initializers.extend([reshape_shape, reshape_back, scale]) - - # Output projection - out_weight = helper.make_tensor( - f"{prefix}out_weight", TensorProto.FLOAT, [hidden_size, hidden_size], - rng.randn(hidden_size, hidden_size).astype(np.float32), - ) - out_bias = helper.make_tensor( - f"{prefix}out_bias", TensorProto.FLOAT, [hidden_size], - rng.randn(hidden_size).astype(np.float32), - ) - initializers.extend([out_weight, out_bias]) - - # QKV MatMul + Add - nodes.append(helper.make_node( - "MatMul", [input_name, f"{prefix}qkv_weight"], [f"{prefix}qkv_matmul"], - name=f"{prefix}qkv_matmul_node", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}qkv_matmul", f"{prefix}qkv_bias"], [f"{prefix}qkv_out"], - name=f"{prefix}qkv_add_node", - )) - - # Split QKV - nodes.append(helper.make_node( - "Split", [f"{prefix}qkv_out", f"{prefix}split_sizes"], - [f"{prefix}q_split", f"{prefix}k_split", f"{prefix}v_split"], - name=f"{prefix}qkv_split", axis=-1, - )) - - # Reshape and transpose Q, K, V - nodes.extend( - helper.make_node( - "Reshape", [f"{prefix}{proj}_split", f"{prefix}reshape_shape"], - [f"{prefix}{proj}_reshaped"], - name=f"{prefix}{proj}_reshape", allowzero=0, - ) - for proj in ["q", "k", "v"] - ) - - # Transpose Q and V - nodes.extend( - helper.make_node( - "Transpose", [f"{prefix}{proj}_reshaped"], [f"{prefix}{proj}_transposed"], - name=f"{prefix}{proj}_transpose", perm=[0, 2, 1, 3], - ) - for proj in ["q", "v"] - ) - - # K transpose for Q @ K^T - nodes.append(helper.make_node( - "Transpose", [f"{prefix}k_reshaped"], [f"{prefix}k_transposed"], - name=f"{prefix}k_transpose", perm=[0, 2, 3, 1], - )) - - # Q @ K^T - nodes.append(helper.make_node( - "MatMul", [f"{prefix}q_transposed", f"{prefix}k_transposed"], [f"{prefix}qk"], - name=f"{prefix}qk_matmul", - )) - - # Scale - nodes.append(helper.make_node( - "Div", [f"{prefix}qk", f"{prefix}scale"], [f"{prefix}qk_scaled"], - name=f"{prefix}div_scale", - )) - - # Softmax - nodes.append(helper.make_node( - "Softmax", [f"{prefix}qk_scaled"], [f"{prefix}attn_weights"], - name=f"{prefix}softmax", axis=3, - )) - - # Attention @ V - nodes.append(helper.make_node( - "MatMul", [f"{prefix}attn_weights", f"{prefix}v_transposed"], [f"{prefix}attn_out"], - name=f"{prefix}attn_v_matmul", - )) - - # Transpose back - nodes.append(helper.make_node( - "Transpose", [f"{prefix}attn_out"], [f"{prefix}attn_transposed"], - name=f"{prefix}attn_transpose", perm=[0, 2, 1, 3], - )) - - # Reshape back - nodes.append(helper.make_node( - "Reshape", [f"{prefix}attn_transposed", f"{prefix}reshape_back"], - [f"{prefix}attn_reshaped"], - name=f"{prefix}attn_reshape_back", allowzero=0, - )) - - # Output projection - nodes.append(helper.make_node( - "MatMul", [f"{prefix}attn_reshaped", f"{prefix}out_weight"], [f"{prefix}out_matmul"], - name=f"{prefix}out_matmul_node", - )) - nodes.append(helper.make_node( - "Add", [f"{prefix}out_matmul", f"{prefix}out_bias"], [output_name], - name=f"{prefix}out_add_node", - )) - - return nodes - - -def create_gpt2_attention_model( - hidden_size: int = 16, - num_heads: int = 2, - seq_len: int = 10, - batch_size: int = 1, -) -> ModelProto: - """Create complete ONNX model with GPT-2 attention pattern. - - Note: This is a simplified GPT-2 attention without causal masking. - Full GPT-2 fusion requires FusionGptAttention class. - - Args: - hidden_size: Hidden dimension (default: 16) - num_heads: Number of attention heads (default: 2) - seq_len: Sequence length (default: 10) - batch_size: Batch size (default: 1) - - Returns: - Complete ONNX ModelProto ready for testing - """ - initializers: list = [] - nodes = gpt2_attention_builder( - input_name="input", - output_name="output", - prefix="gpt2_attn_", - initializers=initializers, - hidden_size=hidden_size, - num_heads=num_heads, - seq_len=seq_len, - ) - - input_tensor = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] - ) - output_tensor = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] - ) - - graph = helper.make_graph( - nodes, - "gpt2_attention_test", - [input_tensor], - [output_tensor], - initializers, - ) - - model = helper.make_model( - graph, - opset_imports=[helper.make_opsetid("", 17)], - ) - model.ir_version = 8 - - return model diff --git a/tests/optim/fusions/__init__.py b/tests/optim/fusions/__init__.py deleted file mode 100644 index 862c45ce3..000000000 --- a/tests/optim/fusions/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/tests/optracing/__init__.py b/tests/optracing/__init__.py deleted file mode 100644 index 862c45ce3..000000000 --- a/tests/optracing/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/tests/sysinfo/__init__.py b/tests/sysinfo/__init__.py deleted file mode 100644 index 862c45ce3..000000000 --- a/tests/sysinfo/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/tests/unit/analyze/core/model_validators/test_validators.py b/tests/unit/analyze/core/model_validators/test_validators.py index 08afc84ff..23e2fff64 100644 --- a/tests/unit/analyze/core/model_validators/test_validators.py +++ b/tests/unit/analyze/core/model_validators/test_validators.py @@ -21,7 +21,6 @@ ModelValidatorManager, ) from winml.modelkit.analyze.models.onnx_model import ONNXModel -from winml.modelkit.analyze.models.onnx_op import OnnxOP from winml.modelkit.analyze.models.runtime_checks import ( NodeTag, PatternRuntime, @@ -48,9 +47,6 @@ def create_runtime_result_with_tags( op_type=op_type, ) - # Create OnnxOP for matched node - _matched_node = OnnxOP(node_name=node_name, op_type=op_type) - # Create a mock node proto for testing from onnx import helper diff --git a/tests/unit/analyze/core/test_node_checkers.py b/tests/unit/analyze/core/test_node_checkers.py index 2bcf8fc48..511b61174 100644 --- a/tests/unit/analyze/core/test_node_checkers.py +++ b/tests/unit/analyze/core/test_node_checkers.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """ -Unit tests for EpContextNodeChecker. +Unit tests for EPContextNodeChecker. Tests verify: - can_check() method correctly identifies EPContext nodes @@ -15,7 +15,7 @@ from onnx import helper from winml.modelkit.analyze.core.node_checkers.ep_context_node_checker import ( - EpContextNodeChecker, + EPContextNodeChecker, ) from winml.modelkit.analyze.models.runtime_checks import ( AlternativeType, @@ -27,13 +27,13 @@ from winml.modelkit.pattern.models import OperatorPattern, PatternType -class TestEpContextNodeChecker: - """Test EpContextNodeChecker implementation.""" +class TestEPContextNodeChecker: + """Test EPContextNodeChecker implementation.""" @pytest.fixture def ep_context_checker(self): - """Create EpContextNodeChecker instance.""" - return EpContextNodeChecker() + """Create EPContextNodeChecker instance.""" + return EPContextNodeChecker() @pytest.fixture def sample_pattern_match(self): diff --git a/tests/unit/analyze/core/test_pattern_deduplication.py b/tests/unit/analyze/core/test_pattern_deduplication.py index 1b59c1c50..e0e34deae 100644 --- a/tests/unit/analyze/core/test_pattern_deduplication.py +++ b/tests/unit/analyze/core/test_pattern_deduplication.py @@ -220,8 +220,8 @@ def test_matched_nodes_returns_string_list(self): assert match.matched_nodes[0] == "conv1" def test_matched_node_names_returns_onnx_ops(self): - """Test that matched_node_names returns OnnxOP objects.""" - from winml.modelkit.analyze.models.onnx_op import OnnxOP + """Test that matched_node_names returns ONNXOp objects.""" + from winml.modelkit.analyze.models.onnx_op import ONNXOp pattern = SubgraphPattern( pattern_id="SUBGRAPH/Test", @@ -245,10 +245,10 @@ def test_matched_node_names_returns_onnx_ops(self): type_param_to_type={}, ) - # matched_node_names should return list of OnnxOP objects + # matched_node_names should return list of ONNXOp objects assert isinstance(match.matched_node_names, list) assert len(match.matched_node_names) == 1 - assert isinstance(match.matched_node_names[0], OnnxOP) + assert isinstance(match.matched_node_names[0], ONNXOp) assert match.matched_node_names[0].op_type == "Relu" assert match.matched_node_names[0].node_name == "relu1" diff --git a/tests/build/__init__.py b/tests/unit/build/__init__.py similarity index 100% rename from tests/build/__init__.py rename to tests/unit/build/__init__.py diff --git a/tests/build/test_hf.py b/tests/unit/build/test_hf.py similarity index 100% rename from tests/build/test_hf.py rename to tests/unit/build/test_hf.py diff --git a/tests/build/test_module_summary.py b/tests/unit/build/test_module_summary.py similarity index 100% rename from tests/build/test_module_summary.py rename to tests/unit/build/test_module_summary.py diff --git a/tests/build/test_onnx.py b/tests/unit/build/test_onnx.py similarity index 100% rename from tests/build/test_onnx.py rename to tests/unit/build/test_onnx.py diff --git a/tests/cache/__init__.py b/tests/unit/cache/__init__.py similarity index 100% rename from tests/cache/__init__.py rename to tests/unit/cache/__init__.py diff --git a/tests/cache/test_model.py b/tests/unit/cache/test_model.py similarity index 99% rename from tests/cache/test_model.py rename to tests/unit/cache/test_model.py index e75ceadb6..145d96a7d 100644 --- a/tests/cache/test_model.py +++ b/tests/unit/cache/test_model.py @@ -169,6 +169,7 @@ def test_same_output_dir(self) -> None: # What from_pretrained computes from winml.modelkit.cache import get_model_dir + fp_dir = get_model_dir(model_id, cache_dir=cache_dir) # What CLI build --use-cache computes (same function now) diff --git a/tests/cache/test_path.py b/tests/unit/cache/test_path.py similarity index 92% rename from tests/cache/test_path.py rename to tests/unit/cache/test_path.py index 8c6cbf1c8..9e2edf382 100644 --- a/tests/cache/test_path.py +++ b/tests/unit/cache/test_path.py @@ -39,9 +39,7 @@ def test_env_var_override(self, monkeypatch: pytest.MonkeyPatch) -> None: result = get_cache_dir() assert result == Path("/custom/cache") - def test_explicit_override_takes_priority( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_explicit_override_takes_priority(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("WMK_CACHE_DIR", "/env/cache") result = get_cache_dir(override="/explicit/cache") assert result == Path("/explicit/cache") @@ -50,9 +48,7 @@ def test_explicit_override_as_path(self) -> None: result = get_cache_dir(override=Path("/some/path")) assert result == Path("/some/path") - def test_none_override_falls_through( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_none_override_falls_through(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("WMK_CACHE_DIR", raising=False) result = get_cache_dir(override=None) assert result == Path.home() / ".cache" / "winml" @@ -70,9 +66,7 @@ def test_appends_artifacts(self) -> None: result = get_artifacts_dir(Path("/cache/root")) assert result == Path("/cache/root/artifacts") - def test_none_resolves_default( - self, monkeypatch: pytest.MonkeyPatch - ) -> None: + def test_none_resolves_default(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("WMK_CACHE_DIR", raising=False) result = get_artifacts_dir() assert result == Path.home() / ".cache" / "winml" / "artifacts" diff --git a/tests/commands/test_build.py b/tests/unit/commands/test_build.py similarity index 100% rename from tests/commands/test_build.py rename to tests/unit/commands/test_build.py diff --git a/tests/commands/test_build_module.py b/tests/unit/commands/test_build_module.py similarity index 79% rename from tests/commands/test_build_module.py rename to tests/unit/commands/test_build_module.py index 70019f1e0..d4124f7eb 100644 --- a/tests/commands/test_build_module.py +++ b/tests/unit/commands/test_build_module.py @@ -28,11 +28,15 @@ def test_single_config_returns_single(self, tmp_path: Path) -> None: from winml.modelkit.config import WinMLBuildConfig cfg = tmp_path / "config.json" - cfg.write_text(json.dumps({ - "loader": {"task": "fill-mask", "model_type": "bert"}, - "export": {}, - "optim": {}, - })) + cfg.write_text( + json.dumps( + { + "loader": {"task": "fill-mask", "model_type": "bert"}, + "export": {}, + "optim": {}, + } + ) + ) result = _load_config(str(cfg), no_quant=False, no_compile=False) assert isinstance(result, WinMLBuildConfig) @@ -42,28 +46,32 @@ def test_array_config_returns_list(self, tmp_path: Path) -> None: from winml.modelkit.config import WinMLBuildConfig cfg = tmp_path / "modules.json" - cfg.write_text(json.dumps([ - { - "loader": { - "task": "fill-mask", - "model_type": "bert", - "model_class": "BertAttention", - "module_path": "encoder.layer.0.attention", - }, - "export": {}, - "optim": {}, - }, - { - "loader": { - "task": "fill-mask", - "model_type": "bert", - "model_class": "BertAttention", - "module_path": "encoder.layer.1.attention", - }, - "export": {}, - "optim": {}, - }, - ])) + cfg.write_text( + json.dumps( + [ + { + "loader": { + "task": "fill-mask", + "model_type": "bert", + "model_class": "BertAttention", + "module_path": "encoder.layer.0.attention", + }, + "export": {}, + "optim": {}, + }, + { + "loader": { + "task": "fill-mask", + "model_type": "bert", + "model_class": "BertAttention", + "module_path": "encoder.layer.1.attention", + }, + "export": {}, + "optim": {}, + }, + ] + ) + ) result = _load_config(str(cfg), no_quant=False, no_compile=False) assert isinstance(result, list) @@ -75,14 +83,23 @@ def test_array_config_returns_list(self, tmp_path: Path) -> None: def test_array_config_applies_no_quant(self, tmp_path: Path) -> None: """--no-quant applies to every config in the array.""" cfg = tmp_path / "modules.json" - cfg.write_text(json.dumps([ - { - "loader": {"task": "fill-mask", "model_type": "bert", - "model_class": "X", "module_path": "a"}, - "export": {}, "optim": {}, - "quant": {"task": "fill-mask", "model_name": "X", "samples": 1}, - }, - ])) + cfg.write_text( + json.dumps( + [ + { + "loader": { + "task": "fill-mask", + "model_type": "bert", + "model_class": "X", + "module_path": "a", + }, + "export": {}, + "optim": {}, + "quant": {"task": "fill-mask", "model_name": "X", "samples": 1}, + }, + ] + ) + ) result = _load_config(str(cfg), no_quant=True, no_compile=False) assert isinstance(result, list) @@ -159,12 +176,8 @@ def test_build_modules_calls_build_per_instance(self, tmp_path: Path) -> None: mock_result.elapsed = 1.0 with ( - patch( - "winml.modelkit.build.build_hf_model", return_value=mock_result - ) as mock_build, - patch( - "winml.modelkit.commands.build._instantiate_parent_model" - ) as mock_parent, + patch("winml.modelkit.build.build_hf_model", return_value=mock_result) as mock_build, + patch("winml.modelkit.commands.build._instantiate_parent_model") as mock_parent, ): mock_model = MagicMock() mock_parent.return_value = mock_model diff --git a/tests/test_cli.py b/tests/unit/commands/test_cli.py similarity index 100% rename from tests/test_cli.py rename to tests/unit/commands/test_cli.py diff --git a/tests/commands/test_compile_quantize_flags.py b/tests/unit/commands/test_compile_quantize_flags.py similarity index 100% rename from tests/commands/test_compile_quantize_flags.py rename to tests/unit/commands/test_compile_quantize_flags.py diff --git a/tests/commands/test_config_cli.py b/tests/unit/commands/test_config_cli.py similarity index 90% rename from tests/commands/test_config_cli.py rename to tests/unit/commands/test_config_cli.py index 85c531fbd..701cba937 100644 --- a/tests/commands/test_config_cli.py +++ b/tests/unit/commands/test_config_cli.py @@ -83,18 +83,25 @@ def test_help_shows_all_options(self, runner: CliRunner) -> None: # All documented options must appear in help text expected_options = [ - "--model", "-m", - "--task", "-t", + "--model", + "-m", + "--task", + "-t", "--model-class", "--model-type", "--module", - "--config", "-c", + "--config", + "-c", "--shape-config", - "--device", "-d", - "--precision", "-p", - "--output", "-o", + "--device", + "-d", + "--precision", + "-p", + "--output", + "-o", "--library", - "--verbose", "-v", + "--verbose", + "-v", "--no-quant", "--no-compile", "--trust-remote-code", @@ -166,9 +173,7 @@ def test_output_to_file( output_file = tmp_path / "out.json" result = runner.invoke(config, ["-m", "test", "-o", str(output_file)]) - assert result.exit_code == 0, ( - f"Output to file should succeed: {result.output}" - ) + assert result.exit_code == 0, f"Output to file should succeed: {result.output}" def test_model_type_without_model( self, @@ -178,12 +183,8 @@ def test_model_type_without_model( """--model-type bert --task fill-mask should be a valid entry point (no -m needed).""" from winml.modelkit.commands.config import config - result = runner.invoke( - config, ["--model-type", "bert", "--task", "fill-mask"] - ) - assert result.exit_code == 0, ( - f"model-type without model should succeed: {result.output}" - ) + result = runner.invoke(config, ["--model-type", "bert", "--task", "fill-mask"]) + assert result.exit_code == 0, f"model-type without model should succeed: {result.output}" def test_config_file_override( self, @@ -198,9 +199,7 @@ def test_config_file_override( override_file.write_text('{"loader": {"task": "text-classification"}}') result = runner.invoke(config, ["-m", "test", "-c", str(override_file)]) - assert result.exit_code == 0, ( - f"Config file override should succeed: {result.output}" - ) + assert result.exit_code == 0, f"Config file override should succeed: {result.output}" def test_shape_config_file( self, @@ -215,10 +214,7 @@ def test_shape_config_file( shapes_file.write_text('{"height": 224, "width": 224}') result = runner.invoke(config, ["-m", "test", "--shape-config", str(shapes_file)]) - assert result.exit_code == 0, ( - f"Shape config file should succeed: {result.output}" - ) - + assert result.exit_code == 0, f"Shape config file should succeed: {result.output}" def test_no_quant_sets_quant_none( self, @@ -343,9 +339,7 @@ def test_qdq_onnx_sets_quant_none(self, runner: CliRunner, tmp_path: Path) -> No f"Expected quant=null for QDQ model, got: {data.get('quant')}" ) - def test_qdq_onnx_output_confirms_no_quant( - self, runner: CliRunner, tmp_path: Path - ) -> None: + def test_qdq_onnx_output_confirms_no_quant(self, runner: CliRunner, tmp_path: Path) -> None: """Config for a QDQ ONNX should produce export=null and quant=null.""" from winml.modelkit.commands.config import config @@ -363,9 +357,7 @@ def test_qdq_onnx_output_confirms_no_quant( assert data.get("export") is None, "QDQ ONNX build should have export=null" assert data.get("quant") is None, "QDQ ONNX build should have quant=null" - def test_qdq_overrides_device_precision( - self, runner: CliRunner, tmp_path: Path - ) -> None: + def test_qdq_overrides_device_precision(self, runner: CliRunner, tmp_path: Path) -> None: """QDQ detection should keep quant=null even with -d npu -p int8.""" from winml.modelkit.commands.config import config @@ -376,19 +368,13 @@ def test_qdq_overrides_device_precision( patch("winml.modelkit.onnx.is_compiled_onnx", return_value=False), patch("winml.modelkit.onnx.is_quantized_onnx", return_value=True), ): - result = runner.invoke( - config, ["-m", str(onnx_file), "-d", "npu", "-p", "int8"] - ) + result = runner.invoke(config, ["-m", str(onnx_file), "-d", "npu", "-p", "int8"]) assert result.exit_code == 0, f"Failed: {result.output}" data = _extract_json(result.output) - assert data.get("quant") is None, ( - "QDQ detection should take precedence over -d npu -p int8" - ) + assert data.get("quant") is None, "QDQ detection should take precedence over -d npu -p int8" - def test_non_qdq_onnx_has_default_quant( - self, runner: CliRunner, tmp_path: Path - ) -> None: + def test_non_qdq_onnx_has_default_quant(self, runner: CliRunner, tmp_path: Path) -> None: """Config for non-QDQ ONNX should have default quant settings (not null).""" from winml.modelkit.commands.config import config diff --git a/tests/commands/test_export.py b/tests/unit/commands/test_export.py similarity index 98% rename from tests/commands/test_export.py rename to tests/unit/commands/test_export.py index 8c388dfbc..390ea56dd 100644 --- a/tests/commands/test_export.py +++ b/tests/unit/commands/test_export.py @@ -483,12 +483,16 @@ def test_export_uses_resolve_export_config( ], ) mock_loader_cfg = WinMLLoaderConfig( - task="image-classification", model_type="resnet", + task="image-classification", + model_type="resnet", ) - with patch("winml.modelkit.loader.load_hf_model") as mock_load, patch( - "winml.modelkit.export.config.resolve_export_config", - return_value=(mock_export_cfg, mock_loader_cfg), + with ( + patch("winml.modelkit.loader.load_hf_model") as mock_load, + patch( + "winml.modelkit.export.config.resolve_export_config", + return_value=(mock_export_cfg, mock_loader_cfg), + ), ): mock_model = MagicMock() mock_load.return_value = (mock_model, None, "image-classification") diff --git a/tests/commands/test_inspect_cli.py b/tests/unit/commands/test_inspect_cli.py similarity index 92% rename from tests/commands/test_inspect_cli.py rename to tests/unit/commands/test_inspect_cli.py index d31ff9e04..a5dca0f24 100644 --- a/tests/commands/test_inspect_cli.py +++ b/tests/unit/commands/test_inspect_cli.py @@ -84,8 +84,16 @@ def test_help_shows_all_options(self, runner: CliRunner) -> None: result = runner.invoke(inspect, ["--help"]) assert result.exit_code == 0 for flag in [ - "--model", "-m", "--format", "-f", "--verbose", "-v", - "--task", "-t", "--hierarchy", "-H", + "--model", + "-m", + "--format", + "-f", + "--verbose", + "-v", + "--task", + "-t", + "--hierarchy", + "-H", ]: assert flag in result.output, f"Missing flag {flag} in help output" @@ -111,7 +119,9 @@ class TestInspectOutputFormat: """Test output format dispatching (json vs table).""" def test_json_format_accepted( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect @@ -126,7 +136,9 @@ def test_json_format_accepted( mock_table.assert_not_called() def test_table_format_default( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect @@ -150,7 +162,9 @@ class TestInspectFlagCombinations: """Test flag combinations and kwarg passing.""" def test_all_flags_combine( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect @@ -167,7 +181,9 @@ def test_all_flags_combine( assert result.exit_code == 0, f"Failed: {result.output}" def test_task_override_passed_to_api( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect @@ -182,7 +198,9 @@ def test_task_override_passed_to_api( assert call_kwargs["task_override"] == "fill-mask" def test_hierarchy_flag_passed_to_api( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect @@ -197,7 +215,9 @@ def test_hierarchy_flag_passed_to_api( assert call_kwargs["include_hierarchy"] is True def test_verbose_flag_default_false( - self, runner: CliRunner, mock_inspect_result: MagicMock, + self, + runner: CliRunner, + mock_inspect_result: MagicMock, ) -> None: from winml.modelkit.commands.inspect import inspect diff --git a/tests/commands/test_perf_cli.py b/tests/unit/commands/test_perf_cli.py similarity index 95% rename from tests/commands/test_perf_cli.py rename to tests/unit/commands/test_perf_cli.py index 3a4fea18c..7cc2d4335 100644 --- a/tests/commands/test_perf_cli.py +++ b/tests/unit/commands/test_perf_cli.py @@ -103,7 +103,6 @@ def test_valid_device_choices(self, runner: CliRunner, device: str) -> None: assert device in result.output - # ============================================================================= # OUTPUT PATH TESTS # ============================================================================= @@ -114,16 +113,19 @@ class TestPerfOutputPath: def test_hf_model_path(self) -> None: from winml.modelkit.commands.perf import generate_output_path + result = generate_output_path("microsoft/resnet-50") assert result.name == "microsoft_resnet-50_perf.json" def test_onnx_file_uses_stem(self) -> None: from winml.modelkit.commands.perf import generate_output_path + result = generate_output_path("/path/to/model.onnx") assert result.name == "model_perf.json" def test_onnx_no_leading_underscore(self) -> None: from winml.modelkit.commands.perf import generate_output_path + result = generate_output_path("./model.onnx") assert not result.name.startswith("._") assert result.name == "model_perf.json" @@ -131,6 +133,7 @@ def test_onnx_no_leading_underscore(self) -> None: def test_windows_path_handled(self) -> None: """Backslashes in paths should be replaced.""" from winml.modelkit.commands.perf import generate_output_path + result = generate_output_path("C:\\models\\bert-base") assert "\\" not in result.name # On Windows, Path("C:_models_bert-base_perf.json").name strips the @@ -267,19 +270,23 @@ def test_no_quantize_false_passes_no_override(self) -> None: override = mock_from_pretrained.call_args.kwargs["config"] assert override is None - def test_cli_onnx_goes_through_perfbenchmark( - self, runner: CliRunner, tmp_path: Path - ) -> None: + def test_cli_onnx_goes_through_perfbenchmark(self, runner: CliRunner, tmp_path: Path) -> None: """CLI with .onnx file should route through PerfBenchmark, not _run_onnx_benchmark.""" onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake onnx") - with patch.object( - PerfBenchmark, "run", return_value=MagicMock(), - ) as mock_run, patch( - "winml.modelkit.commands.perf.display_console_report", - ), patch( - "winml.modelkit.commands.perf.write_json_report", + with ( + patch.object( + PerfBenchmark, + "run", + return_value=MagicMock(), + ) as mock_run, + patch( + "winml.modelkit.commands.perf.display_console_report", + ), + patch( + "winml.modelkit.commands.perf.write_json_report", + ), ): result = runner.invoke( perf, @@ -290,9 +297,7 @@ def test_cli_onnx_goes_through_perfbenchmark( assert result.exit_code == 0, result.output mock_run.assert_called_once() - def test_cli_onnx_not_found_error( - self, runner: CliRunner, tmp_path: Path - ) -> None: + def test_cli_onnx_not_found_error(self, runner: CliRunner, tmp_path: Path) -> None: """CLI with non-existent .onnx file should raise FileNotFoundError.""" missing = tmp_path / "missing.onnx" result = runner.invoke( diff --git a/tests/commands/test_perf_module.py b/tests/unit/commands/test_perf_module.py similarity index 100% rename from tests/commands/test_perf_module.py rename to tests/unit/commands/test_perf_module.py diff --git a/tests/compiler/test_compile_command.py b/tests/unit/compiler/test_compile_command.py similarity index 100% rename from tests/compiler/test_compile_command.py rename to tests/unit/compiler/test_compile_command.py diff --git a/tests/compiler/test_compiler_configs.py b/tests/unit/compiler/test_compiler_configs.py similarity index 98% rename from tests/compiler/test_compiler_configs.py rename to tests/unit/compiler/test_compiler_configs.py index ccb9a1a67..581321cd5 100644 --- a/tests/compiler/test_compiler_configs.py +++ b/tests/unit/compiler/test_compiler_configs.py @@ -159,9 +159,7 @@ def test_from_dict_basic(self): } config = WinMLCompileConfig.from_dict(data) assert config.ep_config.provider == "qnn" - assert config.ep_config.provider_options == { - "htp_performance_mode": "default" - } + assert config.ep_config.provider_options == {"htp_performance_mode": "default"} assert config.ep_config.enable_ep_context is True assert config.validate is True @@ -270,9 +268,7 @@ def test_no_quantize_no_warning(self, factory_method: str): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") factory() - deprecation_warnings = [ - x for x in w if issubclass(x.category, DeprecationWarning) - ] + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] assert len(deprecation_warnings) == 0 diff --git a/tests/compiler/test_compiler_stages.py b/tests/unit/compiler/test_compiler_stages.py similarity index 100% rename from tests/compiler/test_compiler_stages.py rename to tests/unit/compiler/test_compiler_stages.py diff --git a/tests/compiler/test_utils.py b/tests/unit/compiler/test_utils.py similarity index 100% rename from tests/compiler/test_utils.py rename to tests/unit/compiler/test_utils.py diff --git a/tests/config/__init__.py b/tests/unit/config/__init__.py similarity index 100% rename from tests/config/__init__.py rename to tests/unit/config/__init__.py diff --git a/tests/config/conftest.py b/tests/unit/config/conftest.py similarity index 100% rename from tests/config/conftest.py rename to tests/unit/config/conftest.py diff --git a/tests/config/test_build.py b/tests/unit/config/test_build.py similarity index 99% rename from tests/config/test_build.py rename to tests/unit/config/test_build.py index b2bf7b07b..82fda679e 100644 --- a/tests/config/test_build.py +++ b/tests/unit/config/test_build.py @@ -33,7 +33,7 @@ _build_submodule_config, resolve_quant_compile_config, ) -from winml.modelkit.export import OnnxConfigNotFoundError, resolve_io_specs +from winml.modelkit.export import ONNXConfigNotFoundError, resolve_io_specs from winml.modelkit.export.config import InputTensorSpec, OutputTensorSpec, WinMLExportConfig from winml.modelkit.loader.config import WinMLLoaderConfig from winml.modelkit.optim.config import WinMLOptimizationConfig @@ -368,7 +368,7 @@ def test_optimum_fails_registry_fills_in( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=OnnxConfigNotFoundError("blip not supported"), + side_effect=ONNXConfigNotFoundError("blip not supported"), ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"blip": blip_like_config}), ): @@ -473,7 +473,7 @@ def test_registry_merge_does_not_mutate_singleton( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=OnnxConfigNotFoundError("unsupported"), + side_effect=ONNXConfigNotFoundError("unsupported"), ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"some-vision": registry_config}), ): @@ -512,7 +512,7 @@ def test_underscore_normalization( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=OnnxConfigNotFoundError("unsupported"), + side_effect=ONNXConfigNotFoundError("unsupported"), ), # Registry uses hyphens patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {"clip-text-model": clip_config}), @@ -535,7 +535,7 @@ def test_no_registry_no_optimum_returns_empty( ), patch( "winml.modelkit.config.build._resolve_export_config_from_specs", - side_effect=OnnxConfigNotFoundError("unsupported"), + side_effect=ONNXConfigNotFoundError("unsupported"), ), patch("winml.modelkit.models.hf.MODEL_BUILD_CONFIGS", {}), ): diff --git a/tests/config/test_build_onnx.py b/tests/unit/config/test_build_onnx.py similarity index 97% rename from tests/config/test_build_onnx.py rename to tests/unit/config/test_build_onnx.py index 3e0973680..218e0401b 100644 --- a/tests/config/test_build_onnx.py +++ b/tests/unit/config/test_build_onnx.py @@ -369,7 +369,8 @@ def test_task_stored_in_loader(self, tmp_path) -> None: ), ): config = generate_onnx_build_config( - str(onnx_file), task="image-classification", + str(onnx_file), + task="image-classification", ) assert config.loader.task == "image-classification" @@ -416,7 +417,9 @@ def test_override_applied_last(self, tmp_path) -> None: ), ): config = generate_onnx_build_config( - str(onnx_file), device="npu", override=override, + str(onnx_file), + device="npu", + override=override, ) assert config.optim["gelu_fusion"] is True @@ -449,7 +452,9 @@ def test_override_quant_none_on_raw(self, tmp_path) -> None: # Call with a real override that sets quant=None override_cfg = WinMLBuildConfig.from_dict(override_dict) config = generate_onnx_build_config( - str(onnx_file), device="npu", override=override_cfg, + str(onnx_file), + device="npu", + override=override_cfg, ) mock_merge.assert_called_once() @@ -476,7 +481,8 @@ def test_override_on_compiled_model(self, tmp_path) -> None: patch("winml.modelkit.onnx.is_quantized_onnx", return_value=False), ): config = generate_onnx_build_config( - str(onnx_file), override=override, + str(onnx_file), + override=override, ) # Override is applied after compiled detection, so quant is non-None. @@ -502,7 +508,9 @@ def test_override_none_is_noop(self, tmp_path) -> None: ), ): config = generate_onnx_build_config( - str(onnx_file), device="npu", override=None, + str(onnx_file), + device="npu", + override=None, ) # Without override, raw+npu should have quant and compile @@ -622,7 +630,9 @@ def test_ep_override_forwarded(self, tmp_path) -> None: ), ): config = generate_onnx_build_config( - str(onnx_file), device="gpu", ep="migraphx", + str(onnx_file), + device="gpu", + ep="migraphx", ) assert config.compile is not None @@ -696,7 +706,8 @@ def test_ep_override_changes_provider(self) -> None: return_value=("gpu", ["gpu", "cpu"]), ): _quant, compile_cfg = resolve_quant_compile_config( - device="gpu", ep="tensorrt", + device="gpu", + ep="tensorrt", ) assert compile_cfg is not None @@ -732,7 +743,8 @@ def test_explicit_int8_precision_on_npu(self) -> None: return_value=("npu", ["npu", "cpu"]), ): quant, _compile_cfg = resolve_quant_compile_config( - device="npu", precision="int8", + device="npu", + precision="int8", ) assert quant is not None @@ -746,7 +758,8 @@ def test_explicit_fp32_precision_no_quant(self) -> None: return_value=("gpu", ["gpu", "cpu"]), ): quant, _compile_cfg = resolve_quant_compile_config( - device="gpu", precision="fp32", + device="gpu", + precision="fp32", ) assert quant is None diff --git a/tests/config/test_precision.py b/tests/unit/config/test_precision.py similarity index 96% rename from tests/config/test_precision.py rename to tests/unit/config/test_precision.py index 1ddb3a64f..d1f57cc11 100644 --- a/tests/config/test_precision.py +++ b/tests/unit/config/test_precision.py @@ -98,7 +98,9 @@ def test_auto_device_picks_best( ) -> None: """device='auto' + explicit precision picks best from available_devices.""" policy = resolve_precision( - device="auto", precision=precision, available_devices=available, + device="auto", + precision=precision, + available_devices=available, ) assert policy.device == exp_device @@ -124,8 +126,6 @@ def test_unknown_precision_raises(self) -> None: resolve_precision(device="cpu", precision="bfloat16") - - # ============================================================================= # TestGpuLlmWarning - GPU + LLM task warning # ============================================================================= @@ -265,9 +265,7 @@ class TestResolveQuantTypes: ("int16", "int16", "uint16"), ], ) - def test_named_presets( - self, precision: str, exp_weight: str, exp_act: str - ) -> None: + def test_named_presets(self, precision: str, exp_weight: str, exp_act: str) -> None: """Named quantized presets resolve to correct weight/activation types.""" w, a = resolve_quant_types(precision) assert w == exp_weight @@ -283,9 +281,7 @@ def test_named_presets( ("w16a16", "int16", "uint16"), ], ) - def test_mixed_format_valid( - self, precision: str, exp_weight: str, exp_act: str - ) -> None: + def test_mixed_format_valid(self, precision: str, exp_weight: str, exp_act: str) -> None: """Valid w{x}a{y} combinations resolve to correct types.""" w, a = resolve_quant_types(precision) assert w == exp_weight @@ -343,9 +339,7 @@ def test_non_numeric_mixed_raises(self) -> None: ("Int16", "int16", "uint16"), ], ) - def test_case_insensitive( - self, precision: str, exp_weight: str, exp_act: str - ) -> None: + def test_case_insensitive(self, precision: str, exp_weight: str, exp_act: str) -> None: """resolve_quant_types should be case-insensitive.""" w, a = resolve_quant_types(precision) assert w == exp_weight @@ -400,9 +394,7 @@ def test_unsupported_bits_return_false(self, precision: str) -> None: assert is_quantized_precision(precision) is False # ---- False cases: completely invalid ---- - @pytest.mark.parametrize( - "precision", ["garbage", "wXaY", "", "bfloat16", "w0a0"] - ) + @pytest.mark.parametrize("precision", ["garbage", "wXaY", "", "bfloat16", "w0a0"]) def test_invalid_strings_return_false(self, precision: str) -> None: """Completely invalid precision strings must return False.""" assert is_quantized_precision(precision) is False @@ -435,10 +427,10 @@ class TestMixedPrecisionAutoDevice: "precision,available,exp_device", [ ("w8a16", ["npu", "gpu", "cpu"], "npu"), # prefers NPU - ("w8a16", ["gpu", "cpu"], "gpu"), # no NPU, falls to first - ("w8a8", ["npu", "gpu", "cpu"], "npu"), # prefers NPU - ("w16a16", ["npu", "cpu"], "npu"), # prefers NPU - ("w8a16", ["cpu"], "cpu"), # only CPU available + ("w8a16", ["gpu", "cpu"], "gpu"), # no NPU, falls to first + ("w8a8", ["npu", "gpu", "cpu"], "npu"), # prefers NPU + ("w16a16", ["npu", "cpu"], "npu"), # prefers NPU + ("w8a16", ["cpu"], "cpu"), # only CPU available ], ) def test_mixed_precision_auto_device( @@ -449,7 +441,9 @@ def test_mixed_precision_auto_device( ) -> None: """device='auto' + w{x}a{y} precision picks best from available_devices.""" policy = resolve_precision( - device="auto", precision=precision, available_devices=available, + device="auto", + precision=precision, + available_devices=available, ) assert policy.device == exp_device @@ -594,9 +588,7 @@ def test_explicit_activation_overrides_precision(self) -> None: def test_both_explicit_override_precision(self) -> None: """Both explicit flags should override w8a16 defaults entirely.""" - w, a = self._resolve( - precision="w8a16", weight_type="int8", activation_type="int16" - ) + w, a = self._resolve(precision="w8a16", weight_type="int8", activation_type="int16") assert w == "int8" assert a == "int16" diff --git a/tests/core/test_node_metadata.py b/tests/unit/core/test_node_metadata.py similarity index 100% rename from tests/core/test_node_metadata.py rename to tests/unit/core/test_node_metadata.py diff --git a/tests/core/test_onnx_utils.py b/tests/unit/core/test_onnx_utils.py similarity index 88% rename from tests/core/test_onnx_utils.py rename to tests/unit/core/test_onnx_utils.py index 77248a3c6..353776031 100644 --- a/tests/core/test_onnx_utils.py +++ b/tests/unit/core/test_onnx_utils.py @@ -18,12 +18,8 @@ class TestGetIoConfig: def test_single_input_single_output(self) -> None: """Test simple model with one input and one output.""" # Create simple model: input -> Identity -> output - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 3, 224, 224] - ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 3, 224, 224] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224]) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 224, 224]) node = helper.make_node("Identity", ["input"], ["output"]) graph = helper.make_graph([node], "test_graph", [x_info], [y_info]) @@ -50,9 +46,7 @@ def test_multiple_inputs_outputs(self) -> None: node1 = helper.make_node("Identity", ["input_a"], ["output_x"]) node2 = helper.make_node("Identity", ["input_b"], ["output_y"]) - graph = helper.make_graph( - [node1, node2], "test_graph", [in_a, in_b], [out_x, out_y] - ) + graph = helper.make_graph([node1, node2], "test_graph", [in_a, in_b], [out_x, out_y]) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) config = get_io_config(model) @@ -65,12 +59,8 @@ def test_multiple_inputs_outputs(self) -> None: def test_dynamic_dimensions(self) -> None: """Test model with dynamic batch dimension.""" # Create model with dynamic batch size (None in shape) - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, ["batch", 3, 224, 224] - ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, ["batch", 1000] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, 224, 224]) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 1000]) node = helper.make_node("Identity", ["input"], ["output"]) graph = helper.make_graph([node], "test_graph", [x_info], [y_info]) @@ -85,18 +75,10 @@ def test_dynamic_dimensions(self) -> None: def test_various_dtypes(self) -> None: """Test model with various data types.""" # Create inputs with different dtypes - float32_input = helper.make_tensor_value_info( - "float32_in", TensorProto.FLOAT, [1, 10] - ) - int64_input = helper.make_tensor_value_info( - "int64_in", TensorProto.INT64, [1, 10] - ) - float16_output = helper.make_tensor_value_info( - "float16_out", TensorProto.FLOAT16, [1, 10] - ) - int32_output = helper.make_tensor_value_info( - "int32_out", TensorProto.INT32, [1, 10] - ) + float32_input = helper.make_tensor_value_info("float32_in", TensorProto.FLOAT, [1, 10]) + int64_input = helper.make_tensor_value_info("int64_in", TensorProto.INT64, [1, 10]) + float16_output = helper.make_tensor_value_info("float16_out", TensorProto.FLOAT16, [1, 10]) + int32_output = helper.make_tensor_value_info("int32_out", TensorProto.INT32, [1, 10]) node1 = helper.make_node("Identity", ["float32_in"], ["float16_out"]) node2 = helper.make_node("Identity", ["int64_in"], ["int32_out"]) @@ -152,9 +134,7 @@ def test_mixed_static_dynamic_dims(self) -> None: x_info = helper.make_tensor_value_info( "input", TensorProto.FLOAT, ["batch", 3, "height", "width"] ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, ["batch", 1000] - ) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 1000]) node = helper.make_node("Identity", ["input"], ["output"]) graph = helper.make_graph([node], "test_graph", [x_info], [y_info]) @@ -202,9 +182,7 @@ def test_text_model_pattern(self) -> None: ) node = helper.make_node("Identity", ["input_ids"], ["last_hidden_state"]) - graph = helper.make_graph( - [node], "encoder", [input_ids, attention_mask], [hidden_state] - ) + graph = helper.make_graph([node], "encoder", [input_ids, attention_mask], [hidden_state]) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) config = get_io_config(model) diff --git a/tests/dataset_tests/test_object_detection.py b/tests/unit/datasets/test_object_detection.py similarity index 94% rename from tests/dataset_tests/test_object_detection.py rename to tests/unit/datasets/test_object_detection.py index 3085457cf..e87936a90 100644 --- a/tests/dataset_tests/test_object_detection.py +++ b/tests/unit/datasets/test_object_detection.py @@ -77,11 +77,10 @@ class TestObjectDetectionDatasetDeriveOverrides: def dataset_class(self) -> type: """Get ObjectDetectionDataset class without instantiation.""" from winml.modelkit.datasets.object_detection import ObjectDetectionDataset + return ObjectDetectionDataset - def test_no_io_config_returns_empty_overrides( - self, dataset_class: type - ) -> None: + def test_no_io_config_returns_empty_overrides(self, dataset_class: type) -> None: """Should return empty dict when io_config is None.""" # Create instance without full initialization to test method instance = object.__new__(dataset_class) @@ -89,9 +88,7 @@ def test_no_io_config_returns_empty_overrides( assert overrides == {} - def test_sets_do_pad_false_when_no_pixel_mask( - self, dataset_class: type - ) -> None: + def test_sets_do_pad_false_when_no_pixel_mask(self, dataset_class: type) -> None: """Should set do_pad=False when pixel_mask is not in io_config.""" instance = object.__new__(dataset_class) io_config = { @@ -103,9 +100,7 @@ def test_sets_do_pad_false_when_no_pixel_mask( assert overrides["do_pad"] is False - def test_does_not_set_do_pad_when_pixel_mask_present( - self, dataset_class: type - ) -> None: + def test_does_not_set_do_pad_when_pixel_mask_present(self, dataset_class: type) -> None: """Should not set do_pad when pixel_mask is in io_config.""" instance = object.__new__(dataset_class) io_config = { @@ -117,9 +112,7 @@ def test_does_not_set_do_pad_when_pixel_mask_present( assert "do_pad" not in overrides - def test_extracts_size_from_pixel_values_shape( - self, dataset_class: type - ) -> None: + def test_extracts_size_from_pixel_values_shape(self, dataset_class: type) -> None: """Should extract height/width from pixel_values shape.""" instance = object.__new__(dataset_class) io_config = { @@ -130,9 +123,7 @@ def test_extracts_size_from_pixel_values_shape( assert overrides["size"] == {"height": 800, "width": 1200} - def test_handles_dynamic_dimensions( - self, dataset_class: type - ) -> None: + def test_handles_dynamic_dimensions(self, dataset_class: type) -> None: """Should not set size when dimensions are dynamic (None).""" instance = object.__new__(dataset_class) io_config = { @@ -143,9 +134,7 @@ def test_handles_dynamic_dimensions( assert "size" not in overrides - def test_handles_partial_dynamic_dimensions( - self, dataset_class: type - ) -> None: + def test_handles_partial_dynamic_dimensions(self, dataset_class: type) -> None: """Should not set size when any dimension is dynamic.""" instance = object.__new__(dataset_class) io_config = { @@ -156,9 +145,7 @@ def test_handles_partial_dynamic_dimensions( assert "size" not in overrides - def test_handles_missing_shape_key( - self, dataset_class: type - ) -> None: + def test_handles_missing_shape_key(self, dataset_class: type) -> None: """Should handle pixel_values without shape key.""" instance = object.__new__(dataset_class) io_config = { @@ -171,9 +158,7 @@ def test_handles_missing_shape_key( assert overrides["do_pad"] is False assert "size" not in overrides - def test_handles_short_shape_list( - self, dataset_class: type - ) -> None: + def test_handles_short_shape_list(self, dataset_class: type) -> None: """Should handle shape with fewer than 4 dimensions.""" instance = object.__new__(dataset_class) io_config = { diff --git a/tests/dataset_tests/test_random_dataset.py b/tests/unit/datasets/test_random_dataset.py similarity index 100% rename from tests/dataset_tests/test_random_dataset.py rename to tests/unit/datasets/test_random_dataset.py diff --git a/tests/eval/test_align_labels.py b/tests/unit/eval/test_align_labels.py similarity index 100% rename from tests/eval/test_align_labels.py rename to tests/unit/eval/test_align_labels.py diff --git a/tests/eval/test_eval.py b/tests/unit/eval/test_eval.py similarity index 100% rename from tests/eval/test_eval.py rename to tests/unit/eval/test_eval.py diff --git a/tests/eval/test_image_segmentation_evaluator.py b/tests/unit/eval/test_image_segmentation_evaluator.py similarity index 100% rename from tests/eval/test_image_segmentation_evaluator.py rename to tests/unit/eval/test_image_segmentation_evaluator.py diff --git a/tests/eval/test_map_metric.py b/tests/unit/eval/test_map_metric.py similarity index 100% rename from tests/eval/test_map_metric.py rename to tests/unit/eval/test_map_metric.py diff --git a/tests/eval/test_object_detection_evaluator.py b/tests/unit/eval/test_object_detection_evaluator.py similarity index 100% rename from tests/eval/test_object_detection_evaluator.py rename to tests/unit/eval/test_object_detection_evaluator.py diff --git a/tests/test_text_classification.py b/tests/unit/eval/test_text_classification.py similarity index 99% rename from tests/test_text_classification.py rename to tests/unit/eval/test_text_classification.py index ec0b62771..b5060e921 100644 --- a/tests/test_text_classification.py +++ b/tests/unit/eval/test_text_classification.py @@ -607,6 +607,7 @@ def test_calibration_reader_compatibility(self): # Should be numpy arrays (not torch tensors) import numpy as np + assert isinstance(sample["input_ids"], np.ndarray) # Label should be excluded diff --git a/tests/export/conftest.py b/tests/unit/export/conftest.py similarity index 100% rename from tests/export/conftest.py rename to tests/unit/export/conftest.py diff --git a/tests/export/test_all_architectures_io.py b/tests/unit/export/test_all_architectures_io.py similarity index 100% rename from tests/export/test_all_architectures_io.py rename to tests/unit/export/test_all_architectures_io.py diff --git a/tests/export/test_blip_onnx_config.py b/tests/unit/export/test_blip_onnx_config.py similarity index 97% rename from tests/export/test_blip_onnx_config.py rename to tests/unit/export/test_blip_onnx_config.py index 8e0fd5d07..bae6955ac 100644 --- a/tests/export/test_blip_onnx_config.py +++ b/tests/unit/export/test_blip_onnx_config.py @@ -117,9 +117,7 @@ def test_text_inputs_use_max_position_embeddings(self, task: str, blip_config) - inputs = generate_dummy_inputs("blip", task, blip_config) seq_len = inputs["input_ids"].shape[1] expected = blip_config.text_config.max_position_embeddings - assert seq_len == expected, ( - f"Expected seq_len={expected}, got {seq_len}." - ) + assert seq_len == expected, f"Expected seq_len={expected}, got {seq_len}." @pytest.mark.parametrize( "task", diff --git a/tests/export/test_config_validation.py b/tests/unit/export/test_config_validation.py similarity index 100% rename from tests/export/test_config_validation.py rename to tests/unit/export/test_config_validation.py diff --git a/tests/export/test_io.py b/tests/unit/export/test_io.py similarity index 100% rename from tests/export/test_io.py rename to tests/unit/export/test_io.py diff --git a/tests/export/test_io_specs.py b/tests/unit/export/test_io_specs.py similarity index 99% rename from tests/export/test_io_specs.py rename to tests/unit/export/test_io_specs.py index 37d6b7ebf..fe8f3c877 100644 --- a/tests/export/test_io_specs.py +++ b/tests/unit/export/test_io_specs.py @@ -183,9 +183,7 @@ def test_value_ranges_is_dict( ranges = specs["value_ranges"] assert isinstance(ranges, dict) for name in ranges: - assert name in specs["input_names"], ( - f"Value range for unknown input '{name}'" - ) + assert name in specs["input_names"], f"Value range for unknown input '{name}'" def test_value_ranges_are_numeric_tuples( self, model_type: str, task: str, config_fixture: str, request: pytest.FixtureRequest diff --git a/tests/export/test_onnx_config_overrides.py b/tests/unit/export/test_onnx_config_overrides.py similarity index 100% rename from tests/export/test_onnx_config_overrides.py rename to tests/unit/export/test_onnx_config_overrides.py diff --git a/tests/export/test_pytorch_export.py b/tests/unit/export/test_pytorch_export.py similarity index 98% rename from tests/export/test_pytorch_export.py rename to tests/unit/export/test_pytorch_export.py index e9905a506..cc302ee6d 100644 --- a/tests/export/test_pytorch_export.py +++ b/tests/unit/export/test_pytorch_export.py @@ -222,9 +222,7 @@ def test_input_shape_in_onnx(self, tmp_path) -> None: export_pytorch(model, tmp_path / "model.onnx", config) onnx_model = onnx.load(str(tmp_path / "model.onnx")) - input_shape = [ - d.dim_value for d in onnx_model.graph.input[0].type.tensor_type.shape.dim - ] + input_shape = [d.dim_value for d in onnx_model.graph.input[0].type.tensor_type.shape.dim] assert input_shape == [1, 10] def test_no_input_tensors_raises(self, tmp_path) -> None: diff --git a/tests/export/test_zoedepth_onnx_config.py b/tests/unit/export/test_zoedepth_onnx_config.py similarity index 100% rename from tests/export/test_zoedepth_onnx_config.py rename to tests/unit/export/test_zoedepth_onnx_config.py diff --git a/tests/inspect/test_module_io_capture.py b/tests/unit/inspect/test_module_io_capture.py similarity index 96% rename from tests/inspect/test_module_io_capture.py rename to tests/unit/inspect/test_module_io_capture.py index d5817b7f3..50a64a5fe 100644 --- a/tests/inspect/test_module_io_capture.py +++ b/tests/unit/inspect/test_module_io_capture.py @@ -20,7 +20,9 @@ def __init__(self): self.linear = nn.Linear(64, 32) def forward( - self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, ) -> torch.Tensor: return self.linear(hidden_states) * attention_mask.unsqueeze(-1) @@ -33,7 +35,9 @@ def __init__(self): self.sub = MultiInputModule() def forward( - self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, ) -> torch.Tensor: return self.sub(hidden_states, attention_mask=attention_mask) diff --git a/tests/loader/conftest.py b/tests/unit/loader/conftest.py similarity index 100% rename from tests/loader/conftest.py rename to tests/unit/loader/conftest.py diff --git a/tests/loader/test_all_architectures.py b/tests/unit/loader/test_all_architectures.py similarity index 100% rename from tests/loader/test_all_architectures.py rename to tests/unit/loader/test_all_architectures.py diff --git a/tests/loader/test_config_resolution.py b/tests/unit/loader/test_config_resolution.py similarity index 100% rename from tests/loader/test_config_resolution.py rename to tests/unit/loader/test_config_resolution.py diff --git a/tests/loader/test_detect_task_and_class.py b/tests/unit/loader/test_detect_task_and_class.py similarity index 100% rename from tests/loader/test_detect_task_and_class.py rename to tests/unit/loader/test_detect_task_and_class.py diff --git a/tests/loader/test_detect_task_from_config.py b/tests/unit/loader/test_detect_task_from_config.py similarity index 100% rename from tests/loader/test_detect_task_from_config.py rename to tests/unit/loader/test_detect_task_from_config.py diff --git a/tests/loader/test_get_supported_tasks.py b/tests/unit/loader/test_get_supported_tasks.py similarity index 100% rename from tests/loader/test_get_supported_tasks.py rename to tests/unit/loader/test_get_supported_tasks.py diff --git a/tests/loader/test_hf_model_class_mapping.py b/tests/unit/loader/test_hf_model_class_mapping.py similarity index 100% rename from tests/loader/test_hf_model_class_mapping.py rename to tests/unit/loader/test_hf_model_class_mapping.py diff --git a/tests/loader/test_load_hf_model.py b/tests/unit/loader/test_load_hf_model.py similarity index 100% rename from tests/loader/test_load_hf_model.py rename to tests/unit/loader/test_load_hf_model.py diff --git a/tests/loader/test_loader_config.py b/tests/unit/loader/test_loader_config.py similarity index 98% rename from tests/loader/test_loader_config.py rename to tests/unit/loader/test_loader_config.py index 422a81e4e..b7387510c 100644 --- a/tests/loader/test_loader_config.py +++ b/tests/unit/loader/test_loader_config.py @@ -7,7 +7,6 @@ Test specifications from docs/design/loader/hf.md section 9.5. """ - from winml.modelkit.config import WinMLBuildConfig from winml.modelkit.loader.config import WinMLLoaderConfig @@ -159,9 +158,7 @@ def test_model_config_to_dict_excludes_empty_loader(self): def test_model_config_to_dict_includes_loader_with_values(self): """Test WinMLBuildConfig.to_dict includes loader with values.""" - config = WinMLBuildConfig.from_dict( - {"loader": {"task": "image-classification"}} - ) + config = WinMLBuildConfig.from_dict({"loader": {"task": "image-classification"}}) d = config.to_dict() assert "loader" in d assert d["loader"]["task"] == "image-classification" diff --git a/tests/loader/test_loader_config_module_path.py b/tests/unit/loader/test_loader_config_module_path.py similarity index 100% rename from tests/loader/test_loader_config_module_path.py rename to tests/unit/loader/test_loader_config_module_path.py diff --git a/tests/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py similarity index 100% rename from tests/loader/test_resolve_loader_config.py rename to tests/unit/loader/test_resolve_loader_config.py diff --git a/tests/loader/test_resolve_task_and_model_class.py b/tests/unit/loader/test_resolve_task_and_model_class.py similarity index 100% rename from tests/loader/test_resolve_task_and_model_class.py rename to tests/unit/loader/test_resolve_task_and_model_class.py diff --git a/tests/loader/test_task_utils.py b/tests/unit/loader/test_task_utils.py similarity index 100% rename from tests/loader/test_task_utils.py rename to tests/unit/loader/test_task_utils.py diff --git a/tests/inspect/__init__.py b/tests/unit/models/auto/__init__.py similarity index 100% rename from tests/inspect/__init__.py rename to tests/unit/models/auto/__init__.py diff --git a/tests/models/auto/conftest.py b/tests/unit/models/auto/conftest.py similarity index 89% rename from tests/models/auto/conftest.py rename to tests/unit/models/auto/conftest.py index 0fe095d80..1f051da79 100644 --- a/tests/models/auto/conftest.py +++ b/tests/unit/models/auto/conftest.py @@ -126,20 +126,30 @@ def image_classification_onnx(tmp_path: Path) -> Path: # Simple conv weights (just for structure, not real weights) conv_w = helper.make_tensor( - "conv_w", TensorProto.FLOAT, [64, 3, 7, 7], - np.random.randn(64, 3, 7, 7).astype(np.float32).flatten() + "conv_w", + TensorProto.FLOAT, + [64, 3, 7, 7], + np.random.randn(64, 3, 7, 7).astype(np.float32).flatten(), ) # FC weights fc_w = helper.make_tensor( - "fc_w", TensorProto.FLOAT, [50176, 1000], # 64*28*28 = 50176 - np.random.randn(50176, 1000).astype(np.float32).flatten() + "fc_w", + TensorProto.FLOAT, + [50176, 1000], # 64*28*28 = 50176 + np.random.randn(50176, 1000).astype(np.float32).flatten(), ) # Nodes nodes = [ - helper.make_node("Conv", ["pixel_values", "conv_w"], ["conv_out"], - kernel_shape=[7, 7], strides=[8, 8], pads=[3, 3, 3, 3]), + helper.make_node( + "Conv", + ["pixel_values", "conv_w"], + ["conv_out"], + kernel_shape=[7, 7], + strides=[8, 8], + pads=[3, 3, 3, 3], + ), helper.make_node("Flatten", ["conv_out"], ["flat_out"], axis=1), helper.make_node("MatMul", ["flat_out", "fc_w"], ["logits"]), ] @@ -179,26 +189,23 @@ def sequence_classification_onnx(tmp_path: Path) -> Path: Output: logits (1, 2) """ # Inputs - input_ids = helper.make_tensor_value_info( - "input_ids", TensorProto.INT64, [1, 128] - ) - attention_mask = helper.make_tensor_value_info( - "attention_mask", TensorProto.INT64, [1, 128] - ) + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT64, [1, 128]) + attention_mask = helper.make_tensor_value_info("attention_mask", TensorProto.INT64, [1, 128]) # Output logits = helper.make_tensor_value_info("logits", TensorProto.FLOAT, [1, 2]) # Embedding weights (vocab_size=1000, hidden_size=64) embed_w = helper.make_tensor( - "embed_w", TensorProto.FLOAT, [1000, 64], - np.random.randn(1000, 64).astype(np.float32).flatten() + "embed_w", + TensorProto.FLOAT, + [1000, 64], + np.random.randn(1000, 64).astype(np.float32).flatten(), ) # FC weights fc_w = helper.make_tensor( - "fc_w", TensorProto.FLOAT, [64, 2], - np.random.randn(64, 2).astype(np.float32).flatten() + "fc_w", TensorProto.FLOAT, [64, 2], np.random.randn(64, 2).astype(np.float32).flatten() ) # Nodes @@ -244,17 +251,16 @@ def sample_text_input() -> dict[str, np.ndarray]: # Skip markers for conditional tests requires_npu = pytest.mark.skipif( os.environ.get("WINML_TEST_NPU", "0") != "1", - reason="NPU tests require WINML_TEST_NPU=1 and NPU hardware" + reason="NPU tests require WINML_TEST_NPU=1 and NPU hardware", ) requires_gpu = pytest.mark.skipif( os.environ.get("WINML_TEST_GPU", "0") != "1", - reason="GPU tests require WINML_TEST_GPU=1 and GPU hardware" + reason="GPU tests require WINML_TEST_GPU=1 and GPU hardware", ) requires_hf_hub = pytest.mark.skipif( - os.environ.get("WINML_TEST_OFFLINE", "0") == "1", - reason="HF Hub tests disabled in offline mode" + os.environ.get("WINML_TEST_OFFLINE", "0") == "1", reason="HF Hub tests disabled in offline mode" ) diff --git a/tests/models/auto/test_auto_model.py b/tests/unit/models/auto/test_auto_model.py similarity index 97% rename from tests/models/auto/test_auto_model.py rename to tests/unit/models/auto/test_auto_model.py index b5fd53275..a59ae3aa5 100644 --- a/tests/models/auto/test_auto_model.py +++ b/tests/unit/models/auto/test_auto_model.py @@ -82,7 +82,8 @@ def test_image_classification_patterns(self): # WINML_MODEL_CLASS_MAPPING uses (model_type, task) tuples as keys # Check that the mapping structure is correct image_class_patterns = [ - k for k in WINML_MODEL_CLASS_MAPPING + k + for k in WINML_MODEL_CLASS_MAPPING if isinstance(k, tuple) and "image-classification" in k ] # May be empty if no specializations registered - that's ok @@ -93,7 +94,8 @@ def test_sequence_classification_patterns(self): from winml.modelkit.models.winml import WINML_MODEL_CLASS_MAPPING seq_class_patterns = [ - k for k in WINML_MODEL_CLASS_MAPPING + k + for k in WINML_MODEL_CLASS_MAPPING if isinstance(k, tuple) and "sequence-classification" in k ] # May be empty if no specializations registered - that's ok @@ -177,9 +179,7 @@ def test_get_winml_class_unsupported_task_returns_generic(self): ("image-segmentation", "segformer", "WinMLModelForImageSegmentation"), ], ) - def test_model_class_names( - self, task: str, model_type: str, expected_class_name: str - ): + def test_model_class_names(self, task: str, model_type: str, expected_class_name: str): """Test model class naming convention.""" from winml.modelkit.models.winml import get_winml_class diff --git a/tests/models/auto/test_auto_onnx.py b/tests/unit/models/auto/test_auto_onnx.py similarity index 98% rename from tests/models/auto/test_auto_onnx.py rename to tests/unit/models/auto/test_auto_onnx.py index 779ef99cd..5c9b66f9c 100644 --- a/tests/models/auto/test_auto_onnx.py +++ b/tests/unit/models/auto/test_auto_onnx.py @@ -61,7 +61,9 @@ def test_auto_generates_config_when_none(self, fake_onnx: Path, tmp_path: Path): mock_get_class.return_value = lambda **kw: mock_instance WinMLAutoModel.from_onnx( - str(fake_onnx), task="image-classification", device="npu", + str(fake_onnx), + task="image-classification", + device="npu", ) mock_build.assert_called_once() @@ -193,5 +195,3 @@ def test_passes_ep_from_kwargs(self, fake_onnx: Path, tmp_path: Path): call_kwargs = mock_from_onnx.call_args.kwargs assert call_kwargs["ep"] == "qnn" - - diff --git a/tests/models/auto/test_config.py b/tests/unit/models/auto/test_config.py similarity index 100% rename from tests/models/auto/test_config.py rename to tests/unit/models/auto/test_config.py diff --git a/tests/models/auto/test_image_classification.py b/tests/unit/models/auto/test_image_classification.py similarity index 96% rename from tests/models/auto/test_image_classification.py rename to tests/unit/models/auto/test_image_classification.py index e315d2988..7dce67ad0 100644 --- a/tests/models/auto/test_image_classification.py +++ b/tests/unit/models/auto/test_image_classification.py @@ -34,13 +34,9 @@ def create_mock_model(num_labels: int = 1000): """ from winml.modelkit.models.winml import WinMLModelForImageClassification - model = WinMLModelForImageClassification.__new__( - WinMLModelForImageClassification - ) + model = WinMLModelForImageClassification.__new__(WinMLModelForImageClassification) mock_session = MagicMock() - mock_session.run.return_value = { - "logits": np.random.randn(1, num_labels).astype(np.float32) - } + mock_session.run.return_value = {"logits": np.random.randn(1, num_labels).astype(np.float32)} mock_session.io_config = { "input_names": ["pixel_values"], "output_names": ["logits"], diff --git a/tests/models/auto/test_image_segmentation.py b/tests/unit/models/auto/test_image_segmentation.py similarity index 95% rename from tests/models/auto/test_image_segmentation.py rename to tests/unit/models/auto/test_image_segmentation.py index 2b2bf9bdf..ee7583f6c 100644 --- a/tests/models/auto/test_image_segmentation.py +++ b/tests/unit/models/auto/test_image_segmentation.py @@ -45,18 +45,12 @@ def create_mock_model( WinMLModelForImageSegmentation, ) - model = WinMLModelForImageSegmentation.__new__( - WinMLModelForImageSegmentation - ) + model = WinMLModelForImageSegmentation.__new__(WinMLModelForImageSegmentation) mock_session = MagicMock() mock_session.run.return_value = { - "logits": np.random.randn(1, num_queries, num_classes + 1).astype( - np.float32 - ), + "logits": np.random.randn(1, num_queries, num_classes + 1).astype(np.float32), "pred_boxes": np.random.randn(1, num_queries, 4).astype(np.float32), - "pred_masks": np.random.randn(1, num_queries, output_h, output_w).astype( - np.float32 - ), + "pred_masks": np.random.randn(1, num_queries, output_h, output_w).astype(np.float32), } mock_session.io_config = { "input_names": ["pixel_values"], @@ -189,9 +183,7 @@ def test_forward_missing_outputs_are_none(self): WinMLModelForImageSegmentation, ) - model = WinMLModelForImageSegmentation.__new__( - WinMLModelForImageSegmentation - ) + model = WinMLModelForImageSegmentation.__new__(WinMLModelForImageSegmentation) mock_session = MagicMock() # Only logits output (no pred_masks or pred_boxes) mock_session.run.return_value = { @@ -298,9 +290,7 @@ def test_registered_in_task_mapping(self): # ============================================================================= -def create_mock_semantic_model( - num_labels: int = 150, output_h: int = 128, output_w: int = 128 -): +def create_mock_semantic_model(num_labels: int = 150, output_h: int = 128, output_w: int = 128): """Create a WinMLModelForSemanticSegmentation with mocked session. Semantic segmentation outputs: logits [B, num_labels, H, W] @@ -309,14 +299,10 @@ def create_mock_semantic_model( WinMLModelForSemanticSegmentation, ) - model = WinMLModelForSemanticSegmentation.__new__( - WinMLModelForSemanticSegmentation - ) + model = WinMLModelForSemanticSegmentation.__new__(WinMLModelForSemanticSegmentation) mock_session = MagicMock() mock_session.run.return_value = { - "logits": np.random.randn(1, num_labels, output_h, output_w).astype( - np.float32 - ) + "logits": np.random.randn(1, num_labels, output_h, output_w).astype(np.float32) } mock_session.io_config = { "input_names": ["pixel_values"], @@ -397,10 +383,7 @@ def test_registered_in_task_mapping(self): from winml.modelkit.models.winml import TASK_TO_WINML_CLASS assert "semantic-segmentation" in TASK_TO_WINML_CLASS - assert ( - TASK_TO_WINML_CLASS["semantic-segmentation"] - == "WinMLModelForSemanticSegmentation" - ) + assert TASK_TO_WINML_CLASS["semantic-segmentation"] == "WinMLModelForSemanticSegmentation" class TestOutputTypeDistinction: diff --git a/tests/models/auto/test_integration.py b/tests/unit/models/auto/test_integration.py similarity index 97% rename from tests/models/auto/test_integration.py rename to tests/unit/models/auto/test_integration.py index 37bc428e7..b1ad3600b 100644 --- a/tests/models/auto/test_integration.py +++ b/tests/unit/models/auto/test_integration.py @@ -29,14 +29,10 @@ def _make_mock_model(num_labels: int = 1000): """Create an image classification model with mocked session.""" from winml.modelkit.models.winml import WinMLModelForImageClassification - model = WinMLModelForImageClassification.__new__( - WinMLModelForImageClassification - ) + model = WinMLModelForImageClassification.__new__(WinMLModelForImageClassification) mock_session = MagicMock() - mock_session.run.return_value = { - "logits": np.random.randn(1, num_labels).astype(np.float32) - } + mock_session.run.return_value = {"logits": np.random.randn(1, num_labels).astype(np.float32)} mock_session.io_config = { "input_names": ["pixel_values"], "output_names": ["logits"], diff --git a/tests/models/auto/test_sequence_classification.py b/tests/unit/models/auto/test_sequence_classification.py similarity index 94% rename from tests/models/auto/test_sequence_classification.py rename to tests/unit/models/auto/test_sequence_classification.py index 128365771..d7734876a 100644 --- a/tests/models/auto/test_sequence_classification.py +++ b/tests/unit/models/auto/test_sequence_classification.py @@ -32,13 +32,9 @@ def create_mock_model(num_labels: int = 2): """ from winml.modelkit.models.winml import WinMLModelForSequenceClassification - model = WinMLModelForSequenceClassification.__new__( - WinMLModelForSequenceClassification - ) + model = WinMLModelForSequenceClassification.__new__(WinMLModelForSequenceClassification) mock_session = MagicMock() - mock_session.run.return_value = { - "logits": np.random.randn(1, num_labels).astype(np.float32) - } + mock_session.run.return_value = {"logits": np.random.randn(1, num_labels).astype(np.float32)} mock_session.io_config = { "input_names": ["input_ids", "attention_mask", "token_type_ids"], "output_names": ["logits"], @@ -67,9 +63,7 @@ def test_inherits_from_base(self): WinMLPreTrainedModel, ) - assert issubclass( - WinMLModelForSequenceClassification, WinMLPreTrainedModel - ) + assert issubclass(WinMLModelForSequenceClassification, WinMLPreTrainedModel) def test_has_forward_method(self): """Test class has forward method.""" diff --git a/tests/models/clip/test_loader_config.py b/tests/unit/models/clip/test_loader_config.py similarity index 96% rename from tests/models/clip/test_loader_config.py rename to tests/unit/models/clip/test_loader_config.py index 861bb1d0c..039d380db 100644 --- a/tests/models/clip/test_loader_config.py +++ b/tests/unit/models/clip/test_loader_config.py @@ -120,10 +120,7 @@ def test_mapping_has_image_feature_extraction(self): def test_feature_extraction_maps_to_text_model(self): """feature-extraction maps to CLIPTextModelWithProjection.""" - assert ( - MODEL_CLASS_MAPPING[("clip", "feature-extraction")] - is CLIPTextModelWithProjection - ) + assert MODEL_CLASS_MAPPING[("clip", "feature-extraction")] is CLIPTextModelWithProjection def test_image_feature_extraction_maps_to_vision_model(self): """image-feature-extraction maps to CLIPVisionModelWithProjection.""" @@ -187,9 +184,7 @@ def test_config_has_correct_model_type(self, clip_config): def test_feature_extraction_resolves_to_text_model(self, clip_config): """task='feature-extraction' resolves to CLIPTextModelWithProjection via specialization.""" - task, resolved_class = resolve_task_and_model_class( - clip_config, task="feature-extraction" - ) + task, resolved_class = resolve_task_and_model_class(clip_config, task="feature-extraction") assert task == "feature-extraction" assert resolved_class is CLIPTextModelWithProjection @@ -205,9 +200,7 @@ def test_image_feature_extraction_resolves_to_vision_model(self, clip_config): def test_preserves_original_task_name(self, clip_config): """Returns original task name, not normalized version.""" - task, _ = resolve_task_and_model_class( - clip_config, task="image-feature-extraction" - ) + task, _ = resolve_task_and_model_class(clip_config, task="image-feature-extraction") # Should preserve "image-feature-extraction", not normalize to "feature-extraction" assert task == "image-feature-extraction" diff --git a/tests/models/clip/test_onnx_config.py b/tests/unit/models/clip/test_onnx_config.py similarity index 98% rename from tests/models/clip/test_onnx_config.py rename to tests/unit/models/clip/test_onnx_config.py index b93395e72..54be3c7ef 100644 --- a/tests/models/clip/test_onnx_config.py +++ b/tests/unit/models/clip/test_onnx_config.py @@ -157,9 +157,7 @@ def test_registration(self): def test_outputs_includes_image_embeds(self, clip_vision_config): """Outputs include image_embeds instead of pooler_output.""" - onnx_config = CLIPVisionModelIOConfig( - clip_vision_config, task="feature-extraction" - ) + onnx_config = CLIPVisionModelIOConfig(clip_vision_config, task="feature-extraction") outputs = onnx_config.outputs assert "image_embeds" in outputs diff --git a/tests/models/detr/test_onnx_config.py b/tests/unit/models/detr/test_onnx_config.py similarity index 100% rename from tests/models/detr/test_onnx_config.py rename to tests/unit/models/detr/test_onnx_config.py diff --git a/tests/models/sam2/test_onnx_config.py b/tests/unit/models/sam2/test_onnx_config.py similarity index 100% rename from tests/models/sam2/test_onnx_config.py rename to tests/unit/models/sam2/test_onnx_config.py diff --git a/tests/models/segformer/test_onnx_config.py b/tests/unit/models/segformer/test_onnx_config.py similarity index 100% rename from tests/models/segformer/test_onnx_config.py rename to tests/unit/models/segformer/test_onnx_config.py diff --git a/tests/unit/test_onnx_inspection.py b/tests/unit/onnx/test_onnx_inspection.py similarity index 98% rename from tests/unit/test_onnx_inspection.py rename to tests/unit/onnx/test_onnx_inspection.py index 5fb333638..3568e0299 100644 --- a/tests/unit/test_onnx_inspection.py +++ b/tests/unit/onnx/test_onnx_inspection.py @@ -56,7 +56,10 @@ def _make_qdq_model() -> onnx.ModelProto: y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]) graph = helper.make_graph( - [q, dq], "qdq_test", [x, scale_vi, zp_vi], [y], + [q, dq], + "qdq_test", + [x, scale_vi, zp_vi], + [y], initializer=[scale_init, zp_init], ) return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) diff --git a/tests/unit/test_onnx_metadata.py b/tests/unit/onnx/test_onnx_metadata.py similarity index 90% rename from tests/unit/test_onnx_metadata.py rename to tests/unit/onnx/test_onnx_metadata.py index 67738e795..8f9f6c60a 100644 --- a/tests/unit/test_onnx_metadata.py +++ b/tests/unit/onnx/test_onnx_metadata.py @@ -79,37 +79,52 @@ def tagged_cnn_model() -> onnx.ModelProto: # Nodes with hierarchy tags conv1 = helper.make_node( - "Conv", ["X", "conv1_weight"], ["conv1_out"], - name="/model/block/conv1/Conv", pads=[1, 1, 1, 1], + "Conv", + ["X", "conv1_weight"], + ["conv1_out"], + name="/model/block/conv1/Conv", + pads=[1, 1, 1, 1], ) _tag_node(conv1, "/Model/Block/Conv1") relu1 = helper.make_node( - "Relu", ["conv1_out"], ["relu1_out"], + "Relu", + ["conv1_out"], + ["relu1_out"], name="/model/block/relu1/Relu", ) _tag_node(relu1, "/Model/Block/Relu1") conv2 = helper.make_node( - "Conv", ["relu1_out", "conv2_weight"], ["conv2_out"], - name="/model/block/conv2/Conv", pads=[1, 1, 1, 1], + "Conv", + ["relu1_out", "conv2_weight"], + ["conv2_out"], + name="/model/block/conv2/Conv", + pads=[1, 1, 1, 1], ) _tag_node(conv2, "/Model/Block/Conv2") relu2 = helper.make_node( - "Relu", ["conv2_out"], ["relu2_out"], + "Relu", + ["conv2_out"], + ["relu2_out"], name="/model/block/relu2/Relu", ) _tag_node(relu2, "/Model/Block/Relu2") conv3 = helper.make_node( - "Conv", ["relu2_out", "conv3_weight"], ["conv3_out"], - name="/model/block/conv3/Conv", pads=[1, 1, 1, 1], + "Conv", + ["relu2_out", "conv3_weight"], + ["conv3_out"], + name="/model/block/conv3/Conv", + pads=[1, 1, 1, 1], ) _tag_node(conv3, "/Model/Block/Conv3") add = helper.make_node( - "Add", ["conv3_out", "X"], ["Y"], + "Add", + ["conv3_out", "X"], + ["Y"], name="/model/block/Add", ) _tag_node(add, "/Model/Block/Add") @@ -117,7 +132,8 @@ def tagged_cnn_model() -> onnx.ModelProto: graph = helper.make_graph( [conv1, relu1, conv2, relu2, conv3, add], "tagged_cnn", - [x], [y], + [x], + [y], initializer=inits, ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) @@ -211,10 +227,7 @@ def test_restores_node_props(self, tagged_cnn_model): _strip_all_metadata(tagged_cnn_model) result = restore_metadata(tagged_cnn_model, snapshot) - conv1 = next( - n for n in tagged_cnn_model.graph.node - if n.name == "/model/block/conv1/Conv" - ) + conv1 = next(n for n in tagged_cnn_model.graph.node if n.name == "/model/block/conv1/Conv") props = {p.key: p.value for p in conv1.metadata_props} assert props["winml.hierarchy.tag"] == "/Model/Block/Conv1" assert result.nodes_restored == 6 @@ -224,10 +237,7 @@ def test_restores_winml_attributes(self, tagged_cnn_model): _strip_all_metadata(tagged_cnn_model) restore_metadata(tagged_cnn_model, snapshot) - add_node = next( - n for n in tagged_cnn_model.graph.node - if n.name == "/model/block/Add" - ) + add_node = next(n for n in tagged_cnn_model.graph.node if n.name == "/model/block/Add") attr_map = {a.name: a.s.decode() for a in add_node.attribute if a.name.startswith("winml.")} assert attr_map["winml.node.origin"] == "export" assert attr_map["winml.hierarchy.tag"] == "/Model/Block/Add" @@ -237,13 +247,8 @@ def test_does_not_duplicate(self, tagged_cnn_model): snapshot = capture_metadata(tagged_cnn_model) restore_metadata(tagged_cnn_model, snapshot) - conv1 = next( - n for n in tagged_cnn_model.graph.node - if n.name == "/model/block/conv1/Conv" - ) - tag_count = sum( - 1 for p in conv1.metadata_props if p.key == "winml.hierarchy.tag" - ) + conv1 = next(n for n in tagged_cnn_model.graph.node if n.name == "/model/block/conv1/Conv") + tag_count = sum(1 for p in conv1.metadata_props if p.key == "winml.hierarchy.tag") assert tag_count == 1 def test_unmatched_nodes_skipped(self): @@ -302,7 +307,8 @@ def test_survives_shape_inference(self, tagged_cnn_model): # Shape inference creates a new ModelProto (may strip metadata) new_model = onnx.shape_inference.infer_shapes( - tagged_cnn_model, strict_mode=False, + tagged_cnn_model, + strict_mode=False, ) restore_metadata(new_model, snapshot) @@ -341,7 +347,9 @@ def test_new_nodes_not_affected(self, tagged_cnn_model): # Add a synthetic node that wasn't in original new_node = helper.make_node( - "QuantizeLinear", ["X", "scale", "zp"], ["Q_out"], + "QuantizeLinear", + ["X", "scale", "zp"], + ["Q_out"], name="input_QuantizeLinear", ) tagged_cnn_model.graph.node.append(new_node) @@ -351,10 +359,7 @@ def test_new_nodes_not_affected(self, tagged_cnn_model): # Original 6 restored, new node untouched assert result.nodes_restored == 6 - q_node = next( - n for n in tagged_cnn_model.graph.node - if n.name == "input_QuantizeLinear" - ) + q_node = next(n for n in tagged_cnn_model.graph.node if n.name == "input_QuantizeLinear") assert len(q_node.metadata_props) == 0 def test_model_level_io_metadata_survives(self, tagged_cnn_model): diff --git a/tests/onnx/test_persistence.py b/tests/unit/onnx/test_persistence.py similarity index 100% rename from tests/onnx/test_persistence.py rename to tests/unit/onnx/test_persistence.py diff --git a/tests/optim/__init__.py b/tests/unit/optim/__init__.py similarity index 100% rename from tests/optim/__init__.py rename to tests/unit/optim/__init__.py diff --git a/tests/optim/assets/fusionpipe/builders/__init__.py b/tests/unit/optim/assets/fusionpipe/builders/__init__.py similarity index 100% rename from tests/optim/assets/fusionpipe/builders/__init__.py rename to tests/unit/optim/assets/fusionpipe/builders/__init__.py diff --git a/tests/unit/optim/assets/fusionpipe/builders/attention.py b/tests/unit/optim/assets/fusionpipe/builders/attention.py new file mode 100644 index 000000000..6199157e5 --- /dev/null +++ b/tests/unit/optim/assets/fusionpipe/builders/attention.py @@ -0,0 +1,769 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Attention pattern builders for FusionPipe testing. + +Creates ONNX graphs that match ORT's attention fusion patterns. +Based on: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/transformers/bert_model_generator.py + +Reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_attention.py +""" + +from __future__ import annotations + +import math + +import numpy as np +from onnx import ModelProto, TensorProto, helper + + +def bert_attention_builder( + input1_name: str, + input2_name: str, + mask_name: str, + output_name: str, + prefix: str, + initializers: list, + hidden_size: int = 16, + num_heads: int = 2, +) -> list: + """Create BERT-style attention pattern matching ORT's expected structure. + + Pattern (from ORT bert_model_generator.py): + input1 + input2 -> Add -> LayerNorm -> Q/K/V projections + Mask: Unsqueeze -> Unsqueeze -> Cast -> Sub -> Mul + QK: MatMul(Q, K^T) -> Div -> Add(mask) -> Softmax + Output: MatMul(attn, V) -> Transpose -> Reshape -> MatMul -> Add + Residual: output + layernorm_out -> Add -> LayerNorm + + This pattern is recognized by FusionAttention class. + + Args: + input1_name: Name of first input tensor [batch, seq, hidden] + input2_name: Name of second input tensor (for skip connection) + mask_name: Name of attention mask tensor [batch, seq] + output_name: Name of output tensor [batch, seq, hidden] + prefix: Unique prefix for node names + initializers: List to append weight tensors + hidden_size: Hidden dimension (default: 16) + num_heads: Number of attention heads (default: 2) + + Returns: + List of ONNX nodes forming the attention pattern + """ + rng = np.random.RandomState(hash(prefix) % (2**32)) + head_size = hidden_size // num_heads + nodes = [] + + # Weights + ln_weight = helper.make_tensor( + f"{prefix}ln_weight", + TensorProto.FLOAT, + [hidden_size], + rng.randn(hidden_size).astype(np.float32), + ) + ln_bias = helper.make_tensor( + f"{prefix}ln_bias", + TensorProto.FLOAT, + [hidden_size], + rng.randn(hidden_size).astype(np.float32), + ) + initializers.extend([ln_weight, ln_bias]) + + # Q, K, V projection weights + for proj in ["q", "k", "v"]: + weight = helper.make_tensor( + f"{prefix}{proj}_weight", + TensorProto.FLOAT, + [hidden_size, hidden_size], + rng.randn(hidden_size, hidden_size).astype(np.float32), + ) + bias = helper.make_tensor( + f"{prefix}{proj}_bias", + TensorProto.FLOAT, + [hidden_size], + rng.randn(hidden_size).astype(np.float32), + ) + initializers.extend([weight, bias]) + + # Output projection weights + out_weight = helper.make_tensor( + f"{prefix}out_weight", + TensorProto.FLOAT, + [hidden_size, hidden_size], + rng.randn(hidden_size, hidden_size).astype(np.float32), + ) + out_bias = helper.make_tensor( + f"{prefix}out_bias", + TensorProto.FLOAT, + [hidden_size], + rng.randn(hidden_size).astype(np.float32), + ) + initializers.extend([out_weight, out_bias]) + + # Reshape constants + reshape_qk = helper.make_tensor( + f"{prefix}reshape_qk", + TensorProto.INT64, + [4], + np.array([0, 0, num_heads, head_size], dtype=np.int64), + ) + reshape_out = helper.make_tensor( + f"{prefix}reshape_out", + TensorProto.INT64, + [3], + np.array([0, 0, hidden_size], dtype=np.int64), + ) + initializers.extend([reshape_qk, reshape_out]) + + # Div weight (sqrt(head_size)) + div_weight = helper.make_tensor( + f"{prefix}div_weight", + TensorProto.FLOAT, + [1], + np.array([math.sqrt(head_size)], dtype=np.float32), + ) + # Mask constants + sub_weight = helper.make_tensor( + f"{prefix}sub_weight", + TensorProto.FLOAT, + [1], + np.array([1.0], dtype=np.float32), + ) + mul_weight = helper.make_tensor( + f"{prefix}mul_weight", + TensorProto.FLOAT, + [1], + np.array([-10000.0], dtype=np.float32), + ) + # Unsqueeze axes + axes_1 = helper.make_tensor(f"{prefix}axes_1", TensorProto.INT64, [1], [1]) + axes_2 = helper.make_tensor(f"{prefix}axes_2", TensorProto.INT64, [1], [2]) + initializers.extend([div_weight, sub_weight, mul_weight, axes_1, axes_2]) + + # === NODES === + + # 1. Add + LayerNorm (entry point) + nodes.append( + helper.make_node( + "Add", + [input1_name, input2_name], + [f"{prefix}ln_input"], + name=f"{prefix}add_ln", + ) + ) + nodes.append( + helper.make_node( + "LayerNormalization", + [f"{prefix}ln_input", f"{prefix}ln_weight", f"{prefix}ln_bias"], + [f"{prefix}ln_out"], + name=f"{prefix}layernorm", + axis=-1, + epsilon=1e-5, + ) + ) + + # 2. Q projection: MatMul -> Add -> Reshape -> Transpose + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}ln_out", f"{prefix}q_weight"], + [f"{prefix}q_mm"], + name=f"{prefix}matmul_q", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}q_mm", f"{prefix}q_bias"], + [f"{prefix}q_add"], + name=f"{prefix}add_q", + ) + ) + nodes.append( + helper.make_node( + "Reshape", + [f"{prefix}q_add", f"{prefix}reshape_qk"], + [f"{prefix}q_reshape"], + name=f"{prefix}reshape_q", + ) + ) + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}q_reshape"], + [f"{prefix}q_trans"], + name=f"{prefix}transpose_q", + perm=[0, 2, 1, 3], + ) + ) + + # 3. K projection: MatMul -> Add -> Reshape -> Transpose (different perm for K^T) + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}ln_out", f"{prefix}k_weight"], + [f"{prefix}k_mm"], + name=f"{prefix}matmul_k", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}k_mm", f"{prefix}k_bias"], + [f"{prefix}k_add"], + name=f"{prefix}add_k", + ) + ) + nodes.append( + helper.make_node( + "Reshape", + [f"{prefix}k_add", f"{prefix}reshape_qk"], + [f"{prefix}k_reshape"], + name=f"{prefix}reshape_k", + ) + ) + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}k_reshape"], + [f"{prefix}k_trans"], + name=f"{prefix}transpose_k", + perm=[0, 2, 3, 1], # K^T + ) + ) + + # 4. V projection: MatMul -> Add -> Reshape -> Transpose + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}ln_out", f"{prefix}v_weight"], + [f"{prefix}v_mm"], + name=f"{prefix}matmul_v", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}v_mm", f"{prefix}v_bias"], + [f"{prefix}v_add"], + name=f"{prefix}add_v", + ) + ) + nodes.append( + helper.make_node( + "Reshape", + [f"{prefix}v_add", f"{prefix}reshape_qk"], + [f"{prefix}v_reshape"], + name=f"{prefix}reshape_v", + ) + ) + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}v_reshape"], + [f"{prefix}v_trans"], + name=f"{prefix}transpose_v", + perm=[0, 2, 1, 3], + ) + ) + + # 5. Mask processing: Unsqueeze -> Unsqueeze -> Cast -> Sub -> Mul + nodes.append( + helper.make_node( + "Unsqueeze", + [mask_name, f"{prefix}axes_1"], + [f"{prefix}mask_unsq1"], + name=f"{prefix}unsqueeze1", + ) + ) + nodes.append( + helper.make_node( + "Unsqueeze", + [f"{prefix}mask_unsq1", f"{prefix}axes_2"], + [f"{prefix}mask_unsq2"], + name=f"{prefix}unsqueeze2", + ) + ) + nodes.append( + helper.make_node( + "Cast", + [f"{prefix}mask_unsq2"], + [f"{prefix}mask_cast"], + name=f"{prefix}cast_mask", + to=TensorProto.FLOAT, + ) + ) + nodes.append( + helper.make_node( + "Sub", + [f"{prefix}sub_weight", f"{prefix}mask_cast"], + [f"{prefix}mask_sub"], + name=f"{prefix}sub_mask", + ) + ) + nodes.append( + helper.make_node( + "Mul", + [f"{prefix}mask_sub", f"{prefix}mul_weight"], + [f"{prefix}mask_out"], + name=f"{prefix}mul_mask", + ) + ) + + # 6. QK attention: MatMul(Q, K^T) -> Div -> Add(mask) -> Softmax + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}q_trans", f"{prefix}k_trans"], + [f"{prefix}qk_mm"], + name=f"{prefix}matmul_qk", + ) + ) + nodes.append( + helper.make_node( + "Div", + [f"{prefix}qk_mm", f"{prefix}div_weight"], + [f"{prefix}qk_div"], + name=f"{prefix}div_qk", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}qk_div", f"{prefix}mask_out"], + [f"{prefix}qk_add"], + name=f"{prefix}add_qk", + ) + ) + nodes.append( + helper.make_node( + "Softmax", + [f"{prefix}qk_add"], + [f"{prefix}attn_weights"], + name=f"{prefix}softmax", + axis=3, + ) + ) + + # 7. Attention @ V: MatMul -> Transpose -> Reshape + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}attn_weights", f"{prefix}v_trans"], + [f"{prefix}attn_v"], + name=f"{prefix}matmul_attn_v", + ) + ) + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}attn_v"], + [f"{prefix}attn_trans"], + name=f"{prefix}transpose_attn", + perm=[0, 2, 1, 3], + ) + ) + nodes.append( + helper.make_node( + "Reshape", + [f"{prefix}attn_trans", f"{prefix}reshape_out"], + [f"{prefix}attn_reshape"], + name=f"{prefix}reshape_attn", + ) + ) + + # 8. Output projection: MatMul -> Add + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}attn_reshape", f"{prefix}out_weight"], + [f"{prefix}out_mm"], + name=f"{prefix}matmul_out", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}out_mm", f"{prefix}out_bias"], + [f"{prefix}out_add"], + name=f"{prefix}add_out", + ) + ) + + # 9. Residual + Final LayerNorm: Add(output, ln_out) -> LayerNorm + nodes.append( + helper.make_node( + "Add", + [f"{prefix}out_add", f"{prefix}ln_out"], + [f"{prefix}skip_out"], + name=f"{prefix}add_skip", + ) + ) + nodes.append( + helper.make_node( + "LayerNormalization", + [f"{prefix}skip_out", f"{prefix}ln_weight", f"{prefix}ln_bias"], + [output_name], + name=f"{prefix}layernorm2", + axis=-1, + epsilon=1e-5, + ) + ) + + return nodes + + +def create_bert_attention_model( + hidden_size: int = 16, + num_heads: int = 2, + seq_len: int = 10, + batch_size: int = 1, +) -> ModelProto: + """Create complete ONNX model with BERT attention pattern. + + This model matches ORT's bert_model_generator.py structure and should + be fusible by FusionAttention. + + Args: + hidden_size: Hidden dimension (default: 16) + num_heads: Number of attention heads (default: 2) + seq_len: Sequence length (default: 10) + batch_size: Batch size (default: 1) + + Returns: + Complete ONNX ModelProto ready for fusion testing + """ + initializers: list = [] + nodes = bert_attention_builder( + input1_name="input_1", + input2_name="input_2", + mask_name="attention_mask", + output_name="output", + prefix="attn_", + initializers=initializers, + hidden_size=hidden_size, + num_heads=num_heads, + ) + + # Inputs + input1 = helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] + ) + input2 = helper.make_tensor_value_info( + "input_2", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] + ) + mask = helper.make_tensor_value_info("attention_mask", TensorProto.INT64, [batch_size, seq_len]) + output = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] + ) + + graph = helper.make_graph( + nodes, + "bert_attention_test", + [input1, input2, mask], + [output], + initializers, + ) + + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 17)], # opset 17 for LayerNormalization + ) + model.ir_version = 8 + + return model + + +def gpt2_attention_builder( + input_name: str, + output_name: str, + prefix: str, + initializers: list, + hidden_size: int = 16, + num_heads: int = 2, + seq_len: int = 3, +) -> list: + """Create GPT-2 style causal attention pattern. + + Note: GPT-2 attention has a different structure than BERT. + This is a simplified version for testing. + + Args: + input_name: Name of input tensor [batch, seq, hidden] + output_name: Name of output tensor [batch, seq, hidden] + prefix: Unique prefix for node names + initializers: List to append weight tensors + hidden_size: Hidden dimension (default: 16) + num_heads: Number of attention heads (default: 2) + seq_len: Sequence length (default: 3) + + Returns: + List of ONNX nodes forming the GPT-2 attention pattern + """ + rng = np.random.RandomState(hash(prefix) % (2**32)) + head_size = hidden_size // num_heads + nodes = [] + + # Combined QKV projection (GPT-2 style) + qkv_weight = helper.make_tensor( + f"{prefix}qkv_weight", + TensorProto.FLOAT, + [hidden_size, 3 * hidden_size], + rng.randn(hidden_size, 3 * hidden_size).astype(np.float32), + ) + qkv_bias = helper.make_tensor( + f"{prefix}qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + rng.randn(3 * hidden_size).astype(np.float32), + ) + initializers.extend([qkv_weight, qkv_bias]) + + # Split sizes for QKV + split_sizes = helper.make_tensor( + f"{prefix}split_sizes", + TensorProto.INT64, + [3], + np.array([hidden_size, hidden_size, hidden_size], dtype=np.int64), + ) + initializers.append(split_sizes) + + # Reshape and other constants + reshape_shape = helper.make_tensor( + f"{prefix}reshape_shape", + TensorProto.INT64, + [4], + np.array([0, 0, num_heads, head_size], dtype=np.int64), + ) + reshape_back = helper.make_tensor( + f"{prefix}reshape_back", + TensorProto.INT64, + [3], + np.array([0, 0, hidden_size], dtype=np.int64), + ) + scale = helper.make_tensor( + f"{prefix}scale", + TensorProto.FLOAT, + [], + [np.float32(np.sqrt(head_size))], + ) + initializers.extend([reshape_shape, reshape_back, scale]) + + # Output projection + out_weight = helper.make_tensor( + f"{prefix}out_weight", + TensorProto.FLOAT, + [hidden_size, hidden_size], + rng.randn(hidden_size, hidden_size).astype(np.float32), + ) + out_bias = helper.make_tensor( + f"{prefix}out_bias", + TensorProto.FLOAT, + [hidden_size], + rng.randn(hidden_size).astype(np.float32), + ) + initializers.extend([out_weight, out_bias]) + + # QKV MatMul + Add + nodes.append( + helper.make_node( + "MatMul", + [input_name, f"{prefix}qkv_weight"], + [f"{prefix}qkv_matmul"], + name=f"{prefix}qkv_matmul_node", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}qkv_matmul", f"{prefix}qkv_bias"], + [f"{prefix}qkv_out"], + name=f"{prefix}qkv_add_node", + ) + ) + + # Split QKV + nodes.append( + helper.make_node( + "Split", + [f"{prefix}qkv_out", f"{prefix}split_sizes"], + [f"{prefix}q_split", f"{prefix}k_split", f"{prefix}v_split"], + name=f"{prefix}qkv_split", + axis=-1, + ) + ) + + # Reshape and transpose Q, K, V + nodes.extend( + helper.make_node( + "Reshape", + [f"{prefix}{proj}_split", f"{prefix}reshape_shape"], + [f"{prefix}{proj}_reshaped"], + name=f"{prefix}{proj}_reshape", + allowzero=0, + ) + for proj in ["q", "k", "v"] + ) + + # Transpose Q and V + nodes.extend( + helper.make_node( + "Transpose", + [f"{prefix}{proj}_reshaped"], + [f"{prefix}{proj}_transposed"], + name=f"{prefix}{proj}_transpose", + perm=[0, 2, 1, 3], + ) + for proj in ["q", "v"] + ) + + # K transpose for Q @ K^T + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}k_reshaped"], + [f"{prefix}k_transposed"], + name=f"{prefix}k_transpose", + perm=[0, 2, 3, 1], + ) + ) + + # Q @ K^T + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}q_transposed", f"{prefix}k_transposed"], + [f"{prefix}qk"], + name=f"{prefix}qk_matmul", + ) + ) + + # Scale + nodes.append( + helper.make_node( + "Div", + [f"{prefix}qk", f"{prefix}scale"], + [f"{prefix}qk_scaled"], + name=f"{prefix}div_scale", + ) + ) + + # Softmax + nodes.append( + helper.make_node( + "Softmax", + [f"{prefix}qk_scaled"], + [f"{prefix}attn_weights"], + name=f"{prefix}softmax", + axis=3, + ) + ) + + # Attention @ V + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}attn_weights", f"{prefix}v_transposed"], + [f"{prefix}attn_out"], + name=f"{prefix}attn_v_matmul", + ) + ) + + # Transpose back + nodes.append( + helper.make_node( + "Transpose", + [f"{prefix}attn_out"], + [f"{prefix}attn_transposed"], + name=f"{prefix}attn_transpose", + perm=[0, 2, 1, 3], + ) + ) + + # Reshape back + nodes.append( + helper.make_node( + "Reshape", + [f"{prefix}attn_transposed", f"{prefix}reshape_back"], + [f"{prefix}attn_reshaped"], + name=f"{prefix}attn_reshape_back", + allowzero=0, + ) + ) + + # Output projection + nodes.append( + helper.make_node( + "MatMul", + [f"{prefix}attn_reshaped", f"{prefix}out_weight"], + [f"{prefix}out_matmul"], + name=f"{prefix}out_matmul_node", + ) + ) + nodes.append( + helper.make_node( + "Add", + [f"{prefix}out_matmul", f"{prefix}out_bias"], + [output_name], + name=f"{prefix}out_add_node", + ) + ) + + return nodes + + +def create_gpt2_attention_model( + hidden_size: int = 16, + num_heads: int = 2, + seq_len: int = 10, + batch_size: int = 1, +) -> ModelProto: + """Create complete ONNX model with GPT-2 attention pattern. + + Note: This is a simplified GPT-2 attention without causal masking. + Full GPT-2 fusion requires FusionGptAttention class. + + Args: + hidden_size: Hidden dimension (default: 16) + num_heads: Number of attention heads (default: 2) + seq_len: Sequence length (default: 10) + batch_size: Batch size (default: 1) + + Returns: + Complete ONNX ModelProto ready for testing + """ + initializers: list = [] + nodes = gpt2_attention_builder( + input_name="input", + output_name="output", + prefix="gpt2_attn_", + initializers=initializers, + hidden_size=hidden_size, + num_heads=num_heads, + seq_len=seq_len, + ) + + input_tensor = helper.make_tensor_value_info( + "input", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] + ) + output_tensor = helper.make_tensor_value_info( + "output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] + ) + + graph = helper.make_graph( + nodes, + "gpt2_attention_test", + [input_tensor], + [output_tensor], + initializers, + ) + + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 17)], + ) + model.ir_version = 8 + + return model diff --git a/tests/optim/assets/fusionpipe/builders/layernorm.py b/tests/unit/optim/assets/fusionpipe/builders/layernorm.py similarity index 100% rename from tests/optim/assets/fusionpipe/builders/layernorm.py rename to tests/unit/optim/assets/fusionpipe/builders/layernorm.py diff --git a/tests/optim/assets/fusionpipe/generate_patterns.py b/tests/unit/optim/assets/fusionpipe/generate_patterns.py similarity index 98% rename from tests/optim/assets/fusionpipe/generate_patterns.py rename to tests/unit/optim/assets/fusionpipe/generate_patterns.py index 798ebcf96..9d7f0d4d7 100644 --- a/tests/optim/assets/fusionpipe/generate_patterns.py +++ b/tests/unit/optim/assets/fusionpipe/generate_patterns.py @@ -13,7 +13,7 @@ These synthetic patterns may not fuse - tests verify config passing, not fusion. Usage in pytest fixtures: - from tests.optim.assets.generate_fusion_patterns import ( + from tests.unit.optim.assets.generate_fusion_patterns import ( create_self_attention_model, create_gqa_model, create_groupnorm_model, @@ -38,9 +38,7 @@ def attention_model() -> onnx.ModelProto: # ============================================================================= -def make_compatible_model( - graph: onnx.GraphProto, opset_version: int = 14 -) -> onnx.ModelProto: +def make_compatible_model(graph: onnx.GraphProto, opset_version: int = 14) -> onnx.ModelProto: """Create model with IR version compatible with ORT.""" model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset_version)]) model.ir_version = 8 @@ -544,9 +542,7 @@ def create_groupnorm_model( # Reshape for GroupNorm: [B, C, H, W] → [B, G, C//G, H, W] channels_per_group = channels // num_groups - gn_shape = np.array( - [batch_size, num_groups, channels_per_group, height, width], dtype=np.int64 - ) + gn_shape = np.array([batch_size, num_groups, channels_per_group, height, width], dtype=np.int64) initializers.append(numpy_helper.from_array(gn_shape, f"{prefix}gn_shape")) nodes.append( diff --git a/tests/optim/assets/graphpipe/builders/__init__.py b/tests/unit/optim/assets/graphpipe/builders/__init__.py similarity index 100% rename from tests/optim/assets/graphpipe/builders/__init__.py rename to tests/unit/optim/assets/graphpipe/builders/__init__.py diff --git a/tests/optim/assets/graphpipe/builders/activation.py b/tests/unit/optim/assets/graphpipe/builders/activation.py similarity index 100% rename from tests/optim/assets/graphpipe/builders/activation.py rename to tests/unit/optim/assets/graphpipe/builders/activation.py diff --git a/tests/optim/assets/graphpipe/builders/attention.py b/tests/unit/optim/assets/graphpipe/builders/attention.py similarity index 89% rename from tests/optim/assets/graphpipe/builders/attention.py rename to tests/unit/optim/assets/graphpipe/builders/attention.py index cd3798438..54e41202e 100644 --- a/tests/optim/assets/graphpipe/builders/attention.py +++ b/tests/unit/optim/assets/graphpipe/builders/attention.py @@ -169,9 +169,7 @@ def attention_builder(): ) # --- Q projection: MatMul -> Add -> Reshape -> Transpose --- - nodes.append( - helper.make_node("MatMul", ["ln_out", "wq"], ["q_matmul"], name="q_matmul") - ) + nodes.append(helper.make_node("MatMul", ["ln_out", "wq"], ["q_matmul"], name="q_matmul")) nodes.append(helper.make_node("Add", ["q_matmul", "bq"], ["q_add"], name="q_add")) # Reshape shape via SEPARATE Constant node (ORT pattern matcher requires separate shapes) nodes.append( @@ -185,9 +183,7 @@ def attention_builder(): ), ) ) - nodes.append( - helper.make_node("Reshape", ["q_add", "q_shape"], ["q_reshape"], name="q_reshape") - ) + nodes.append(helper.make_node("Reshape", ["q_add", "q_shape"], ["q_reshape"], name="q_reshape")) nodes.append( helper.make_node( "Transpose", @@ -199,9 +195,7 @@ def attention_builder(): ) # --- K projection: MatMul -> Add -> Reshape -> Transpose (special perm for K^T) --- - nodes.append( - helper.make_node("MatMul", ["ln_out", "wk"], ["k_matmul"], name="k_matmul") - ) + nodes.append(helper.make_node("MatMul", ["ln_out", "wk"], ["k_matmul"], name="k_matmul")) nodes.append(helper.make_node("Add", ["k_matmul", "bk"], ["k_add"], name="k_add")) # Separate Constant for K reshape nodes.append( @@ -215,9 +209,7 @@ def attention_builder(): ), ) ) - nodes.append( - helper.make_node("Reshape", ["k_add", "k_shape"], ["k_reshape"], name="k_reshape") - ) + nodes.append(helper.make_node("Reshape", ["k_add", "k_shape"], ["k_reshape"], name="k_reshape")) nodes.append( helper.make_node( "Transpose", @@ -229,9 +221,7 @@ def attention_builder(): ) # --- V projection: MatMul -> Add -> Reshape -> Transpose --- - nodes.append( - helper.make_node("MatMul", ["ln_out", "wv"], ["v_matmul"], name="v_matmul") - ) + nodes.append(helper.make_node("MatMul", ["ln_out", "wv"], ["v_matmul"], name="v_matmul")) nodes.append(helper.make_node("Add", ["v_matmul", "bv"], ["v_add"], name="v_add")) # Separate Constant for V reshape nodes.append( @@ -245,9 +235,7 @@ def attention_builder(): ), ) ) - nodes.append( - helper.make_node("Reshape", ["v_add", "v_shape"], ["v_reshape"], name="v_reshape") - ) + nodes.append(helper.make_node("Reshape", ["v_add", "v_shape"], ["v_reshape"], name="v_reshape")) nodes.append( helper.make_node( "Transpose", @@ -259,9 +247,7 @@ def attention_builder(): ) # --- Attention: Q @ K^T -> Div(scale) -> Add(mask) -> Softmax -> @ V --- - nodes.append( - helper.make_node("MatMul", ["q_trans", "k_trans"], ["qk"], name="qk_matmul") - ) + nodes.append(helper.make_node("MatMul", ["q_trans", "k_trans"], ["qk"], name="qk_matmul")) # Scale factor: sqrt(head_size) for Div nodes.append( helper.make_node( @@ -269,25 +255,15 @@ def attention_builder(): inputs=[], outputs=["scale"], name="scale_const", - value=helper.make_tensor( - "scale", TensorProto.FLOAT, [], [math.sqrt(head_size)] - ), + value=helper.make_tensor("scale", TensorProto.FLOAT, [], [math.sqrt(head_size)]), ) ) nodes.append(helper.make_node("Div", ["qk", "scale"], ["qk_scaled"], name="qk_div")) nodes.append( - helper.make_node( - "Add", ["qk_scaled", "attention_mask"], ["qk_masked"], name="mask_add" - ) - ) - nodes.append( - helper.make_node( - "Softmax", ["qk_masked"], ["attn_probs"], name="softmax", axis=3 - ) - ) - nodes.append( - helper.make_node("MatMul", ["attn_probs", "v_trans"], ["attn_out"], name="attn_v") + helper.make_node("Add", ["qk_scaled", "attention_mask"], ["qk_masked"], name="mask_add") ) + nodes.append(helper.make_node("Softmax", ["qk_masked"], ["attn_probs"], name="softmax", axis=3)) + nodes.append(helper.make_node("MatMul", ["attn_probs", "v_trans"], ["attn_out"], name="attn_v")) # --- Output: Transpose -> Reshape -> MatMul -> Add -> Add(skip) --- nodes.append( @@ -306,26 +282,18 @@ def attention_builder(): inputs=[], outputs=["out_shape"], name="out_shape", - value=helper.make_tensor( - "out_shape", TensorProto.INT64, [3], [0, 0, hidden_size] - ), + value=helper.make_tensor("out_shape", TensorProto.INT64, [3], [0, 0, hidden_size]), ) ) nodes.append( - helper.make_node( - "Reshape", ["out_trans", "out_shape"], ["out_reshape"], name="out_reshape" - ) + helper.make_node("Reshape", ["out_trans", "out_shape"], ["out_reshape"], name="out_reshape") ) nodes.append( helper.make_node("MatMul", ["out_reshape", "wo"], ["out_matmul"], name="out_matmul") ) - nodes.append( - helper.make_node("Add", ["out_matmul", "bo"], ["out_add"], name="out_add") - ) + nodes.append(helper.make_node("Add", ["out_matmul", "bo"], ["out_add"], name="out_add")) # Skip connection: add back ln_out (this is the 4th edge from LayerNorm output) - nodes.append( - helper.make_node("Add", ["out_add", "ln_out"], ["output"], name="skip_add") - ) + nodes.append(helper.make_node("Add", ["out_add", "ln_out"], ["output"], name="skip_add")) # ========================================================================= # GRAPH AND MODEL @@ -334,9 +302,7 @@ def attention_builder(): input_1 = helper.make_tensor_value_info( "input_1", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] ) - input_2 = helper.make_tensor_value_info( - "input_2", TensorProto.INT32, [batch_size, seq_len] - ) + input_2 = helper.make_tensor_value_info("input_2", TensorProto.INT32, [batch_size, seq_len]) output = helper.make_tensor_value_info( "output", TensorProto.FLOAT, [batch_size, seq_len, hidden_size] ) @@ -398,9 +364,7 @@ def multi_head_attention_builder( ) # Scale factor initializers.append( - numpy_helper.from_array( - np.array([1.0 / math.sqrt(16)], dtype=np.float32), f"{prefix}scale" - ) + numpy_helper.from_array(np.array([1.0 / math.sqrt(16)], dtype=np.float32), f"{prefix}scale") ) return [ @@ -498,14 +462,10 @@ def rotary_embeddings_builder( # Position encoding: cos and sin values (simplified) initializers.append( - numpy_helper.from_array( - rng.uniform(0.8, 1.0, 64).astype(np.float32), f"{prefix}cos_pos" - ) + numpy_helper.from_array(rng.uniform(0.8, 1.0, 64).astype(np.float32), f"{prefix}cos_pos") ) initializers.append( - numpy_helper.from_array( - rng.uniform(-0.2, 0.2, 64).astype(np.float32), f"{prefix}sin_pos" - ) + numpy_helper.from_array(rng.uniform(-0.2, 0.2, 64).astype(np.float32), f"{prefix}sin_pos") ) # Reshape for rotation: [1, 64] -> [1, 32, 2] initializers.append( diff --git a/tests/optim/assets/graphpipe/builders/conv.py b/tests/unit/optim/assets/graphpipe/builders/conv.py similarity index 82% rename from tests/optim/assets/graphpipe/builders/conv.py rename to tests/unit/optim/assets/graphpipe/builders/conv.py index bfbc4a8b2..447b81fdf 100644 --- a/tests/optim/assets/graphpipe/builders/conv.py +++ b/tests/unit/optim/assets/graphpipe/builders/conv.py @@ -39,9 +39,7 @@ from onnx import helper, numpy_helper -def conv_bn_builder( - input_name: str, output_name: str, prefix: str, initializers: list -) -> list: +def conv_bn_builder(input_name: str, output_name: str, prefix: str, initializers: list) -> list: """Build Conv -> BN pattern (shape-preserving: 16->16). Tests ConvBNFusion capability (ORT name: FuseConvBN). @@ -51,22 +49,12 @@ def conv_bn_builder( # Conv 16->16 channels initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) - ) - initializers.append( - numpy_helper.from_array(np.ones(16, dtype=np.float32), f"{prefix}bn_scale") - ) - initializers.append( - numpy_helper.from_array(np.zeros(16, dtype=np.float32), f"{prefix}bn_bias") - ) - initializers.append( - numpy_helper.from_array(np.zeros(16, dtype=np.float32), f"{prefix}bn_mean") - ) - initializers.append( - numpy_helper.from_array(np.ones(16, dtype=np.float32), f"{prefix}bn_var") + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) + initializers.append(numpy_helper.from_array(np.ones(16, dtype=np.float32), f"{prefix}bn_scale")) + initializers.append(numpy_helper.from_array(np.zeros(16, dtype=np.float32), f"{prefix}bn_bias")) + initializers.append(numpy_helper.from_array(np.zeros(16, dtype=np.float32), f"{prefix}bn_mean")) + initializers.append(numpy_helper.from_array(np.ones(16, dtype=np.float32), f"{prefix}bn_var")) return [ helper.make_node( @@ -105,21 +93,15 @@ def conv_add_relu_builder( rng = np.random.RandomState(hash(prefix) % (2**32)) initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # P1-08: Create bias only once, reused below initializers.append( - numpy_helper.from_array( - rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias" - ) + numpy_helper.from_array(rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias") ) # P1-05: Expand to same shape (identity) initializers.append( - numpy_helper.from_array( - np.array([1, 16, 32, 32], dtype=np.int64), f"{prefix}expand_shape" - ) + numpy_helper.from_array(np.array([1, 16, 32, 32], dtype=np.int64), f"{prefix}expand_shape") ) # Dropout in inference mode (training_mode not set = False by default in opset 12+) @@ -158,9 +140,7 @@ def conv_add_relu_builder( [f"{prefix}add2_out"], name=f"{prefix}add2", ), - helper.make_node( - "Relu", [f"{prefix}add2_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Relu", [f"{prefix}add2_out"], [output_name], name=f"{prefix}relu"), ] @@ -188,15 +168,11 @@ def conv_activation_builder( # Conv 16->16 channels with bias initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # Add bias to Conv to make it more realistic initializers.append( - numpy_helper.from_array( - rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b" - ) + numpy_helper.from_array(rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b") ) return [ @@ -219,9 +195,7 @@ def conv_activation_builder( ] -def conv_mul_builder( - input_name: str, output_name: str, prefix: str, initializers: list -) -> list: +def conv_mul_builder(input_name: str, output_name: str, prefix: str, initializers: list) -> list: """Build Conv+Mul(scale) pattern: Input -> Conv -> Mul. P3-08: Tests conv-mul-fusion capability (ORT name: ConvMulFusion). @@ -231,9 +205,7 @@ def conv_mul_builder( # Conv 16->16 channels, shape-preserving initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # Scale factor for Mul (broadcasting: [1, 16, 1, 1] or just [16]) scale_value = rng.randn(16, 1, 1).astype(np.float32) * 0.1 + 1.0 @@ -279,15 +251,11 @@ def conv_add_activation_builder( # Conv 16->16 channels WITHOUT bias (important for fusion!) initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # 1D bias for Add - shape [16, 1, 1] to broadcast with NCHW output [1, 16, 32, 32] initializers.append( - numpy_helper.from_array( - rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias" - ) + numpy_helper.from_array(rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias") ) return [ @@ -350,27 +318,19 @@ def nchwc_transformer_builder( # Reshape input from [1, 64] to [1, 16, 2, 2] for Conv2D initializers.append( - numpy_helper.from_array( - np.array([1, 16, 2, 2], dtype=np.int64), f"{prefix}reshape_shape" - ) + numpy_helper.from_array(np.array([1, 16, 2, 2], dtype=np.int64), f"{prefix}reshape_shape") ) # Conv weights: 16 input channels, 16 output channels, 3x3 kernel # NCHWc transformation works best with channel counts divisible by block size initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) initializers.append( - numpy_helper.from_array( - rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b" - ) + numpy_helper.from_array(rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b") ) # Reshape back to [1, 64] initializers.append( - numpy_helper.from_array( - np.array([1, 64], dtype=np.int64), f"{prefix}out_shape" - ) + numpy_helper.from_array(np.array([1, 64], dtype=np.int64), f"{prefix}out_shape") ) return [ @@ -421,16 +381,12 @@ def conv_add_fusion_builder( # Conv 16->16 channels WITHOUT bias (important for fusion!) initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # 1D bias for Add - shape [16, 1, 1] to broadcast with NCHW output [1, 16, 32, 32] # ONNX broadcasts from rightmost axis, so [16, 1, 1] broadcasts to [1, 16, H, W] initializers.append( - numpy_helper.from_array( - rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias" - ) + numpy_helper.from_array(rng.randn(16, 1, 1).astype(np.float32) * 0.1, f"{prefix}bias") ) return [ @@ -467,26 +423,18 @@ def nhwc_transformer_builder( # Reshape input from [1, 64] to [1, 16, 2, 2] for Conv2D (NCHW format) initializers.append( - numpy_helper.from_array( - np.array([1, 16, 2, 2], dtype=np.int64), f"{prefix}reshape_shape" - ) + numpy_helper.from_array(np.array([1, 16, 2, 2], dtype=np.int64), f"{prefix}reshape_shape") ) # Conv weights: 16 input channels, 16 output channels, 3x3 kernel initializers.append( - numpy_helper.from_array( - rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(16, 16, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) initializers.append( - numpy_helper.from_array( - rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b" - ) + numpy_helper.from_array(rng.randn(16).astype(np.float32) * 0.1, f"{prefix}conv_b") ) # Reshape back to [1, 64] initializers.append( - numpy_helper.from_array( - np.array([1, 64], dtype=np.int64), f"{prefix}out_shape" - ) + numpy_helper.from_array(np.array([1, 64], dtype=np.int64), f"{prefix}out_shape") ) return [ @@ -532,9 +480,7 @@ def nhwc_transformer_builder( ] -def pad_conv_builder( - input_name: str, output_name: str, prefix: str, initializers: list -) -> list: +def pad_conv_builder(input_name: str, output_name: str, prefix: str, initializers: list) -> list: """Build Pad+Conv pattern: Input -> Pad -> Conv. P3-03: Tests pad-fusion capability (ORT name: Pad_Fusion). @@ -547,27 +493,19 @@ def pad_conv_builder( # Reshape input from [1, 64] to [1, 1, 8, 8] for Conv2D initializers.append( - numpy_helper.from_array( - np.array([1, 1, 8, 8], dtype=np.int64), f"{prefix}reshape_shape" - ) + numpy_helper.from_array(np.array([1, 1, 8, 8], dtype=np.int64), f"{prefix}reshape_shape") ) # Pad values for Conv: [batch_start, c_start, H_start, W_start, batch_end, c_end, H_end, W_end] initializers.append( - numpy_helper.from_array( - np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.int64), f"{prefix}pads" - ) + numpy_helper.from_array(np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.int64), f"{prefix}pads") ) # Conv weights: 1 input channel, 1 output channel, 3x3 kernel initializers.append( - numpy_helper.from_array( - rng.randn(1, 1, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w" - ) + numpy_helper.from_array(rng.randn(1, 1, 3, 3).astype(np.float32) * 0.1, f"{prefix}conv_w") ) # Reshape back to [1, 64] initializers.append( - numpy_helper.from_array( - np.array([1, 64], dtype=np.int64), f"{prefix}out_shape" - ) + numpy_helper.from_array(np.array([1, 64], dtype=np.int64), f"{prefix}out_shape") ) return [ diff --git a/tests/optim/assets/graphpipe/builders/core.py b/tests/unit/optim/assets/graphpipe/builders/core.py similarity index 100% rename from tests/optim/assets/graphpipe/builders/core.py rename to tests/unit/optim/assets/graphpipe/builders/core.py diff --git a/tests/optim/assets/graphpipe/builders/elimination.py b/tests/unit/optim/assets/graphpipe/builders/elimination.py similarity index 86% rename from tests/optim/assets/graphpipe/builders/elimination.py rename to tests/unit/optim/assets/graphpipe/builders/elimination.py index 5127e9950..d68ce72f5 100644 --- a/tests/optim/assets/graphpipe/builders/elimination.py +++ b/tests/unit/optim/assets/graphpipe/builders/elimination.py @@ -60,14 +60,10 @@ def slice_elimination_builder( numpy_helper.from_array(np.array([0, 0], dtype=np.int64), f"{prefix}starts") ) initializers.append( - numpy_helper.from_array( - np.array([INT64_MAX, INT64_MAX], dtype=np.int64), f"{prefix}ends" - ) + numpy_helper.from_array(np.array([INT64_MAX, INT64_MAX], dtype=np.int64), f"{prefix}ends") ) # Axes in any order, as long as all dims are covered - initializers.append( - numpy_helper.from_array(np.array([0, 1], dtype=np.int64), f"{prefix}axes") - ) + initializers.append(numpy_helper.from_array(np.array([0, 1], dtype=np.int64), f"{prefix}axes")) return [ # Identity slice (starts=0, ends=INT64_MAX for all dims) - eliminable by ORT @@ -78,12 +74,8 @@ def slice_elimination_builder( name=f"{prefix}slice", ), # Downstream consumer - allows CanRemoveNode to pass - helper.make_node( - "Abs", [f"{prefix}slice_out"], [f"{prefix}abs_out"], name=f"{prefix}abs" - ), - helper.make_node( - "Relu", [f"{prefix}abs_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Abs", [f"{prefix}slice_out"], [f"{prefix}abs_out"], name=f"{prefix}abs"), + helper.make_node("Relu", [f"{prefix}abs_out"], [output_name], name=f"{prefix}relu"), ] @@ -111,9 +103,7 @@ def unsqueeze_elimination_builder( # Create a constant tensor with shape [64] that Unsqueeze will expand to [1, 64] # This constant is what makes the Unsqueeze eliminable const_data = rng.randn(64).astype(np.float32) * 0.01 - initializers.append( - numpy_helper.from_array(const_data, f"{prefix}const_tensor") - ) + initializers.append(numpy_helper.from_array(const_data, f"{prefix}const_tensor")) # Axes for Unsqueeze (opset 13+) - must be constant initializer # Adding axis 0 transforms [64] -> [1, 64] @@ -136,9 +126,7 @@ def unsqueeze_elimination_builder( [f"{prefix}add_out"], name=f"{prefix}add", ), - helper.make_node( - "Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu"), ] @@ -160,9 +148,7 @@ def reshape_elimination_builder( # Create a constant tensor with shape [64] that Reshape will transform to [1, 64] const_data = rng.randn(64).astype(np.float32) * 0.01 - initializers.append( - numpy_helper.from_array(const_data, f"{prefix}const_tensor") - ) + initializers.append(numpy_helper.from_array(const_data, f"{prefix}const_tensor")) # Shape initializer to reshape [64] -> [1, 64] initializers.append( @@ -184,9 +170,7 @@ def reshape_elimination_builder( [f"{prefix}add_out"], name=f"{prefix}add", ), - helper.make_node( - "Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu"), ] @@ -220,9 +204,7 @@ def expand_elimination_builder( [f"{prefix}expand_out"], name=f"{prefix}expand", ), - helper.make_node( - "Relu", [f"{prefix}expand_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Relu", [f"{prefix}expand_out"], [output_name], name=f"{prefix}relu"), ] @@ -266,23 +248,13 @@ def concat_slice_elimination_builder( initializers.append(numpy_helper.from_array(const2, f"{prefix}const2")) # Slice parameters to extract first segment [0:20] - initializers.append( - numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}starts") - ) - initializers.append( - numpy_helper.from_array(np.array([20], dtype=np.int64), f"{prefix}ends") - ) - initializers.append( - numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}axes") - ) - initializers.append( - numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}steps") - ) + initializers.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}starts")) + initializers.append(numpy_helper.from_array(np.array([20], dtype=np.int64), f"{prefix}ends")) + initializers.append(numpy_helper.from_array(np.array([0], dtype=np.int64), f"{prefix}axes")) + initializers.append(numpy_helper.from_array(np.array([1], dtype=np.int64), f"{prefix}steps")) # Pad to restore shape for template compatibility: [20] -> [64] - initializers.append( - numpy_helper.from_array(np.array([0, 44], dtype=np.int64), f"{prefix}pads") - ) + initializers.append(numpy_helper.from_array(np.array([0, 44], dtype=np.int64), f"{prefix}pads")) initializers.append( numpy_helper.from_array(np.array(0.0, dtype=np.float32), f"{prefix}pad_value") ) @@ -336,7 +308,5 @@ def concat_slice_elimination_builder( [f"{prefix}add_out"], name=f"{prefix}add", ), - helper.make_node( - "Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu" - ), + helper.make_node("Relu", [f"{prefix}add_out"], [output_name], name=f"{prefix}relu"), ] diff --git a/tests/optim/assets/graphpipe/builders/gelu.py b/tests/unit/optim/assets/graphpipe/builders/gelu.py similarity index 100% rename from tests/optim/assets/graphpipe/builders/gelu.py rename to tests/unit/optim/assets/graphpipe/builders/gelu.py diff --git a/tests/optim/assets/graphpipe/builders/gemm.py b/tests/unit/optim/assets/graphpipe/builders/gemm.py similarity index 98% rename from tests/optim/assets/graphpipe/builders/gemm.py rename to tests/unit/optim/assets/graphpipe/builders/gemm.py index 5850a0466..ef57a9f33 100644 --- a/tests/optim/assets/graphpipe/builders/gemm.py +++ b/tests/unit/optim/assets/graphpipe/builders/gemm.py @@ -58,9 +58,7 @@ def gemm_activation_builder( ] -def gemm_sum_builder( - input_name: str, output_name: str, prefix: str, initializers: list -) -> list: +def gemm_sum_builder(input_name: str, output_name: str, prefix: str, initializers: list) -> list: """Build Gemm+Sum pattern: Input → Gemm → Sum. P3-06: Tests gemm-sum-fusion capability (ORT name: GemmSumFusion). diff --git a/tests/optim/assets/graphpipe/builders/layernorm.py b/tests/unit/optim/assets/graphpipe/builders/layernorm.py similarity index 96% rename from tests/optim/assets/graphpipe/builders/layernorm.py rename to tests/unit/optim/assets/graphpipe/builders/layernorm.py index a056b9a3e..9250671bb 100644 --- a/tests/optim/assets/graphpipe/builders/layernorm.py +++ b/tests/unit/optim/assets/graphpipe/builders/layernorm.py @@ -520,12 +520,30 @@ def embed_layer_norm_builder() -> onnx.ModelProto: TensorProto.FLOAT, [batch_size, sequence_length, hidden_size], [ - 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, - 9.0, 8.0, 7.0, 6.0, - 1.0, 2.0, 3.0, 4.0, - 5.0, 6.0, 7.0, 8.0, - 9.0, 8.0, 7.0, 6.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 8.0, + 7.0, + 6.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 8.0, + 7.0, + 6.0, ], ), # Segment embedding table [vocab_size=2, hidden_size=4] @@ -569,10 +587,22 @@ def embed_layer_norm_builder() -> onnx.ModelProto: TensorProto.FLOAT, [hidden_size, hidden_size], [ - 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, - 1.0, 2.0, 3.0, 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, ], ), # Add bias [hidden] diff --git a/tests/optim/assets/graphpipe/builders/matmul.py b/tests/unit/optim/assets/graphpipe/builders/matmul.py similarity index 99% rename from tests/optim/assets/graphpipe/builders/matmul.py rename to tests/unit/optim/assets/graphpipe/builders/matmul.py index f7db0d1ee..e776cfc70 100644 --- a/tests/optim/assets/graphpipe/builders/matmul.py +++ b/tests/unit/optim/assets/graphpipe/builders/matmul.py @@ -196,9 +196,7 @@ def matmul_scale_builder( ] -def matmul_bn_builder( - input_name: str, output_name: str, prefix: str, initializers: list -) -> list: +def matmul_bn_builder(input_name: str, output_name: str, prefix: str, initializers: list) -> list: """Build MatMul+BatchNorm pattern: Input → MatMul → BatchNormalization. P4-01: Tests matmul-bn-fusion capability (ORT name: MatMul_BatchNormalization_Fusion). diff --git a/tests/optim/assets/graphpipe/builders/misc.py b/tests/unit/optim/assets/graphpipe/builders/misc.py similarity index 99% rename from tests/optim/assets/graphpipe/builders/misc.py rename to tests/unit/optim/assets/graphpipe/builders/misc.py index d830d2fcd..0dacfcd15 100644 --- a/tests/optim/assets/graphpipe/builders/misc.py +++ b/tests/unit/optim/assets/graphpipe/builders/misc.py @@ -187,9 +187,7 @@ def not_where_builder(input_name: str, output_name: str, prefix: str, initialize # The seed is derived from prefix, so each pattern instance gets a different value. unique_threshold = rng.uniform(-1.0, 1.0) initializers.append( - numpy_helper.from_array( - np.array(unique_threshold, dtype=np.float32), f"{prefix}threshold" - ) + numpy_helper.from_array(np.array(unique_threshold, dtype=np.float32), f"{prefix}threshold") ) # Alternative values for Where - also unique per pattern to prevent any CSE initializers.append( diff --git a/tests/optim/assets/graphpipe/generate_patterns.py b/tests/unit/optim/assets/graphpipe/generate_patterns.py similarity index 99% rename from tests/optim/assets/graphpipe/generate_patterns.py rename to tests/unit/optim/assets/graphpipe/generate_patterns.py index 53dc0d8b1..13f2fb5f2 100644 --- a/tests/optim/assets/graphpipe/generate_patterns.py +++ b/tests/unit/optim/assets/graphpipe/generate_patterns.py @@ -226,9 +226,7 @@ def build(self) -> tuple[list, list, list, list]: ) # Create input/output value infos - inputs = [ - helper.make_tensor_value_info(input_name, TensorProto.FLOAT, list(self.x_shape)) - ] + inputs = [helper.make_tensor_value_info(input_name, TensorProto.FLOAT, list(self.x_shape))] outputs = [ helper.make_tensor_value_info(output_name, TensorProto.FLOAT, list(self.x_shape)) ] diff --git a/tests/optim/assets/surgerypipe/builders/mask.py b/tests/unit/optim/assets/surgerypipe/builders/mask.py similarity index 93% rename from tests/optim/assets/surgerypipe/builders/mask.py rename to tests/unit/optim/assets/surgerypipe/builders/mask.py index c01e4fc9b..064d118ed 100644 --- a/tests/optim/assets/surgerypipe/builders/mask.py +++ b/tests/unit/optim/assets/surgerypipe/builders/mask.py @@ -45,21 +45,15 @@ def build_causal_mask_model( ONNX model with causal mask pattern """ # Build causal mask (lower triangular with zeros, upper with mask_value) - causal_mask_values = np.triu( - np.full((seq_len, seq_len), mask_value, dtype=np.float32), k=1 - ) + causal_mask_values = np.triu(np.full((seq_len, seq_len), mask_value, dtype=np.float32), k=1) causal_mask_values = causal_mask_values.reshape(1, 1, seq_len, seq_len) # Initializers causal_mask_init = numpy_helper.from_array(causal_mask_values, "causal_mask.1") - mask_value_init = numpy_helper.from_array( - np.array(mask_value, dtype=np.float32), "mask_value" - ) + mask_value_init = numpy_helper.from_array(np.array(mask_value, dtype=np.float32), "mask_value") # Input/Output - input_tensor = helper.make_tensor_value_info( - "attention_mask", TensorProto.INT64, [1, seq_len] - ) + input_tensor = helper.make_tensor_value_info("attention_mask", TensorProto.INT64, [1, seq_len]) output_tensor = helper.make_tensor_value_info( "causal_mask", TensorProto.FLOAT, [1, 1, seq_len, seq_len] ) diff --git a/tests/optim/assets/transpipe/builders/__init__.py b/tests/unit/optim/assets/transpipe/builders/__init__.py similarity index 100% rename from tests/optim/assets/transpipe/builders/__init__.py rename to tests/unit/optim/assets/transpipe/builders/__init__.py diff --git a/tests/optim/assets/transpipe/builders/attention.py b/tests/unit/optim/assets/transpipe/builders/attention.py similarity index 100% rename from tests/optim/assets/transpipe/builders/attention.py rename to tests/unit/optim/assets/transpipe/builders/attention.py diff --git a/tests/optim/assets/transpipe/builders/layernorm.py b/tests/unit/optim/assets/transpipe/builders/layernorm.py similarity index 100% rename from tests/optim/assets/transpipe/builders/layernorm.py rename to tests/unit/optim/assets/transpipe/builders/layernorm.py diff --git a/tests/optim/assets/transpipe/test_builders.py b/tests/unit/optim/assets/transpipe/test_builders.py similarity index 100% rename from tests/optim/assets/transpipe/test_builders.py rename to tests/unit/optim/assets/transpipe/test_builders.py diff --git a/tests/optim/capabilities/conftest.py b/tests/unit/optim/capabilities/conftest.py similarity index 100% rename from tests/optim/capabilities/conftest.py rename to tests/unit/optim/capabilities/conftest.py diff --git a/tests/optim/capabilities/test_capability_isolation.py b/tests/unit/optim/capabilities/test_capability_isolation.py similarity index 100% rename from tests/optim/capabilities/test_capability_isolation.py rename to tests/unit/optim/capabilities/test_capability_isolation.py diff --git a/tests/optim/conftest.py b/tests/unit/optim/conftest.py similarity index 100% rename from tests/optim/conftest.py rename to tests/unit/optim/conftest.py diff --git a/tests/models/auto/__init__.py b/tests/unit/optim/fusions/__init__.py similarity index 100% rename from tests/models/auto/__init__.py rename to tests/unit/optim/fusions/__init__.py diff --git a/tests/optim/fusions/test_fusion_rmsnorm.py b/tests/unit/optim/fusions/test_fusion_rmsnorm.py similarity index 87% rename from tests/optim/fusions/test_fusion_rmsnorm.py rename to tests/unit/optim/fusions/test_fusion_rmsnorm.py index fe05d445f..efe7f488c 100644 --- a/tests/optim/fusions/test_fusion_rmsnorm.py +++ b/tests/unit/optim/fusions/test_fusion_rmsnorm.py @@ -54,9 +54,7 @@ def _make_rmsnorm_model( weight = numpy_helper.from_array(weight_value, "weight") pow_node = helper.make_node("Pow", ["input", "pow_exp"], ["pow_out"]) - reduce_mean = helper.make_node( - "ReduceMean", ["pow_out"], ["mean_out"], axes=[-1], keepdims=1 - ) + reduce_mean = helper.make_node("ReduceMean", ["pow_out"], ["mean_out"], axes=[-1], keepdims=1) add_eps = helper.make_node("Add", ["mean_out", "epsilon"], ["add_out"]) sqrt_node = helper.make_node("Sqrt", ["add_out"], ["sqrt_out"]) div_node = helper.make_node("Div", ["input", "sqrt_out"], ["div_out"]) @@ -82,9 +80,7 @@ def _make_multi_rmsnorm_model( """ rng = np.random.RandomState(42) - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 8, hidden_size] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 8, hidden_size]) nodes = [] initializers = [] @@ -95,18 +91,12 @@ def _make_multi_rmsnorm_model( weight_val = rng.randn(hidden_size).astype(np.float32) out_name = f"{prefix}output" - pow_exp = numpy_helper.from_array( - np.array(2.0, dtype=np.float32), f"{prefix}pow_exp" - ) - epsilon = numpy_helper.from_array( - np.array(1e-6, dtype=np.float32), f"{prefix}epsilon" - ) + pow_exp = numpy_helper.from_array(np.array(2.0, dtype=np.float32), f"{prefix}pow_exp") + epsilon = numpy_helper.from_array(np.array(1e-6, dtype=np.float32), f"{prefix}epsilon") weight = numpy_helper.from_array(weight_val, f"{prefix}weight") initializers.extend([pow_exp, epsilon, weight]) - pow_node = helper.make_node( - "Pow", [prev_output, f"{prefix}pow_exp"], [f"{prefix}pow_out"] - ) + pow_node = helper.make_node("Pow", [prev_output, f"{prefix}pow_exp"], [f"{prefix}pow_out"]) reduce_mean = helper.make_node( "ReduceMean", [f"{prefix}pow_out"], @@ -117,21 +107,13 @@ def _make_multi_rmsnorm_model( add_eps = helper.make_node( "Add", [f"{prefix}mean_out", f"{prefix}epsilon"], [f"{prefix}add_out"] ) - sqrt_node = helper.make_node( - "Sqrt", [f"{prefix}add_out"], [f"{prefix}sqrt_out"] - ) - div_node = helper.make_node( - "Div", [prev_output, f"{prefix}sqrt_out"], [f"{prefix}div_out"] - ) - mul_weight = helper.make_node( - "Mul", [f"{prefix}div_out", f"{prefix}weight"], [out_name] - ) + sqrt_node = helper.make_node("Sqrt", [f"{prefix}add_out"], [f"{prefix}sqrt_out"]) + div_node = helper.make_node("Div", [prev_output, f"{prefix}sqrt_out"], [f"{prefix}div_out"]) + mul_weight = helper.make_node("Mul", [f"{prefix}div_out", f"{prefix}weight"], [out_name]) nodes.extend([pow_node, reduce_mean, add_eps, sqrt_node, div_node, mul_weight]) prev_output = out_name - y_info = helper.make_tensor_value_info( - prev_output, TensorProto.FLOAT, [1, 8, hidden_size] - ) + y_info = helper.make_tensor_value_info(prev_output, TensorProto.FLOAT, [1, 8, hidden_size]) graph = helper.make_graph( nodes, "multi_rmsnorm_test", [x_info], [y_info], initializer=initializers @@ -217,16 +199,12 @@ def test_weight_adjustment(self): """Verify weight is multiplied by sqrt(hidden_size).""" hidden_size = 16 original_weight = np.arange(1, hidden_size + 1, dtype=np.float32) - model = _make_rmsnorm_model( - hidden_size=hidden_size, weight_value=original_weight - ) + model = _make_rmsnorm_model(hidden_size=hidden_size, weight_value=original_weight) result = _apply_fusion(model) # Find the adjusted weight initializer adjusted_inits = [ - init - for init in result.graph.initializer - if "l2norm_adjusted" in init.name + init for init in result.graph.initializer if "l2norm_adjusted" in init.name ] assert len(adjusted_inits) == 1 @@ -238,15 +216,11 @@ def test_all_ones_weight_collapses(self): """When weight is all 1.0, verify collapse to scalar [sqrt(N)].""" hidden_size = 32 ones_weight = np.ones(hidden_size, dtype=np.float32) - model = _make_rmsnorm_model( - hidden_size=hidden_size, weight_value=ones_weight - ) + model = _make_rmsnorm_model(hidden_size=hidden_size, weight_value=ones_weight) result = _apply_fusion(model) adjusted_inits = [ - init - for init in result.graph.initializer - if "l2norm_adjusted" in init.name + init for init in result.graph.initializer if "l2norm_adjusted" in init.name ] assert len(adjusted_inits) == 1 @@ -263,16 +237,10 @@ class TestNoMatch: def test_no_match_without_weight_initializer(self): """Mul without initializer input should NOT trigger fusion.""" hidden_size = 64 - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 8, hidden_size] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 8, hidden_size]) # A second runtime input instead of an initializer - w_info = helper.make_tensor_value_info( - "weight_runtime", TensorProto.FLOAT, [hidden_size] - ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 8, hidden_size] - ) + w_info = helper.make_tensor_value_info("weight_runtime", TensorProto.FLOAT, [hidden_size]) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 8, hidden_size]) pow_exp = numpy_helper.from_array(np.array(2.0, dtype=np.float32), "pow_exp") epsilon = numpy_helper.from_array(np.array(1e-6, dtype=np.float32), "epsilon") @@ -285,9 +253,7 @@ def test_no_match_without_weight_initializer(self): sqrt_node = helper.make_node("Sqrt", ["add_out"], ["sqrt_out"]) div_node = helper.make_node("Div", ["input", "sqrt_out"], ["div_out"]) # Mul uses runtime input, not initializer - mul_weight = helper.make_node( - "Mul", ["div_out", "weight_runtime"], ["output"] - ) + mul_weight = helper.make_node("Mul", ["div_out", "weight_runtime"], ["output"]) graph = helper.make_graph( [pow_node, reduce_mean, add_eps, sqrt_node, div_node, mul_weight], @@ -311,12 +277,8 @@ def test_no_match_wrong_pow_exponent(self): rng = np.random.RandomState(42) weight_value = rng.randn(hidden_size).astype(np.float32) - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 8, hidden_size] - ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 8, hidden_size] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 8, hidden_size]) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 8, hidden_size]) # Exponent is 3.0, not 2.0 pow_exp = numpy_helper.from_array(np.array(3.0, dtype=np.float32), "pow_exp") @@ -353,12 +315,8 @@ def test_no_match_incomplete_chain(self): rng = np.random.RandomState(42) weight_value = rng.randn(hidden_size).astype(np.float32) - x_info = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 8, hidden_size] - ) - y_info = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 8, hidden_size] - ) + x_info = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 8, hidden_size]) + y_info = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 8, hidden_size]) pow_exp = numpy_helper.from_array(np.array(2.0, dtype=np.float32), "pow_exp") epsilon = numpy_helper.from_array(np.array(1e-6, dtype=np.float32), "epsilon") diff --git a/tests/optim/integration/test_optimizer.py b/tests/unit/optim/integration/test_optimizer.py similarity index 100% rename from tests/optim/integration/test_optimizer.py rename to tests/unit/optim/integration/test_optimizer.py diff --git a/tests/optim/pipes/conftest.py b/tests/unit/optim/pipes/conftest.py similarity index 100% rename from tests/optim/pipes/conftest.py rename to tests/unit/optim/pipes/conftest.py diff --git a/tests/optim/pipes/test_constant_folding.py b/tests/unit/optim/pipes/test_constant_folding.py similarity index 100% rename from tests/optim/pipes/test_constant_folding.py rename to tests/unit/optim/pipes/test_constant_folding.py diff --git a/tests/optim/pipes/test_pipe_base.py b/tests/unit/optim/pipes/test_pipe_base.py similarity index 93% rename from tests/optim/pipes/test_pipe_base.py rename to tests/unit/optim/pipes/test_pipe_base.py index 577396f93..b4f1bba2b 100644 --- a/tests/optim/pipes/test_pipe_base.py +++ b/tests/unit/optim/pipes/test_pipe_base.py @@ -81,9 +81,7 @@ def test_base_pipe_requires_build_config(self) -> None: class IncompletePipe(BasePipe): # type: ignore[abstract] name: ClassVar[str] = "incomplete" - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: return model IncompletePipe() @@ -114,9 +112,7 @@ class CompletePipe(BasePipe): def build_config(cls, **kwargs: Any) -> PipeConfig: return PipeConfig() - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: return model # Should be instantiable @@ -137,9 +133,7 @@ class TestPipe(BasePipe): def build_config(cls, **kwargs: Any) -> TestPipeConfig: return TestPipeConfig(enabled=kwargs.get("enabled", False)) - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: # Simple pass-through for test return model @@ -171,9 +165,7 @@ class TestPipe(BasePipe): def build_config(cls, **kwargs: Any) -> PipeConfig: return PipeConfig() - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: return model # capabilities should be a dict @@ -196,9 +188,7 @@ class TestPipe(BasePipe): def build_config(cls, **kwargs: Any) -> PipeConfig: return PipeConfig() - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: return model # Default should_process returns True @@ -219,9 +209,7 @@ class TestPipe(BasePipe): def build_config(cls, **kwargs: Any) -> TestPipeConfig: return TestPipeConfig(value=kwargs.get("value", 0)) - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: return model # Should accept any kwargs without error @@ -238,9 +226,7 @@ class TestPipe(BasePipe): def build_config(cls, **kwargs: Any) -> PipeConfig: return PipeConfig() - def process( - self, model: onnx.ModelProto, config: PipeConfig - ) -> onnx.ModelProto: + def process(self, model: onnx.ModelProto, config: PipeConfig) -> onnx.ModelProto: # Return the model unchanged return model diff --git a/tests/optim/pipes/test_pipe_config.py b/tests/unit/optim/pipes/test_pipe_config.py similarity index 98% rename from tests/optim/pipes/test_pipe_config.py rename to tests/unit/optim/pipes/test_pipe_config.py index 3e5a1b1cf..c4be9f675 100644 --- a/tests/optim/pipes/test_pipe_config.py +++ b/tests/unit/optim/pipes/test_pipe_config.py @@ -71,10 +71,13 @@ def test_defaults_all_disabled(self) -> None: # Capabilities with default=True (like ConstantFolding) stay enabled # L1 variants (GeluFusion, LayerNormFusion) are added for proper isolation # Always-disabled (AttentionFusion, EmbedLayerNormFusion) are never enabled - caps_count = len([ - c for c in GRAPH_CAPABILITIES.values() - if hasattr(c, "ort_name") and c.ort_name and not c.default - ]) + caps_count = len( + [ + c + for c in GRAPH_CAPABILITIES.values() + if hasattr(c, "ort_name") and c.ort_name and not c.default + ] + ) expected_count = caps_count + len(L1_VARIANTS) + len(ALWAYS_DISABLED) assert len(config.disabled_optimizers) == expected_count @@ -275,10 +278,13 @@ def test_isolation_mode_only_enabled_run(self) -> None: # All others should be disabled (except default=True caps like ConstantFolding) # gelu_fusion enables 2 items (GeluFusionL2 + GeluFusion L1 variant) disabled_count = len(config.disabled_optimizers) - caps_count = len([ - c for c in GRAPH_CAPABILITIES.values() - if hasattr(c, "ort_name") and c.ort_name and not c.default - ]) + caps_count = len( + [ + c + for c in GRAPH_CAPABILITIES.values() + if hasattr(c, "ort_name") and c.ort_name and not c.default + ] + ) # Total disabled = caps(default=False) + L1_variants + always_disabled - enabled(2) expected_disabled = caps_count + len(L1_VARIANTS) + len(ALWAYS_DISABLED) - 2 assert disabled_count == expected_disabled diff --git a/tests/optim/pipes/test_pipe_fusion.py b/tests/unit/optim/pipes/test_pipe_fusion.py similarity index 96% rename from tests/optim/pipes/test_pipe_fusion.py rename to tests/unit/optim/pipes/test_pipe_fusion.py index 2e2e6735f..5bdcb7a8e 100644 --- a/tests/optim/pipes/test_pipe_fusion.py +++ b/tests/unit/optim/pipes/test_pipe_fusion.py @@ -194,9 +194,7 @@ def test_build_config_single_fusion(self, fusion_capabilities: dict) -> None: def test_build_config_multiple_fusions(self, fusion_capabilities: dict) -> None: """Enable multiple fusions via kwargs.""" - config = ORTFusionPipe.build_config( - attention_fusion=True, layer_norm_fusion=True - ) + config = ORTFusionPipe.build_config(attention_fusion=True, layer_norm_fusion=True) assert config.enable_attention is True assert config.enable_layer_norm is True @@ -214,9 +212,7 @@ def test_build_config_fusion_attrs_mapping(self, fusion_capabilities: dict) -> N config = ORTFusionPipe.build_config(layer_norm_fusion=True) assert config.enable_layer_norm is True - def test_build_config_respects_capability_defaults( - self, fusion_capabilities: dict - ) -> None: + def test_build_config_respects_capability_defaults(self, fusion_capabilities: dict) -> None: """Build config without overrides should respect capability defaults. Note: GELU capabilities are disabled due to ORT bundling issue. @@ -252,9 +248,7 @@ def test_build_config_ignores_unknown_kwargs(self, fusion_capabilities: dict) -> # Known parameter should still work assert config.enable_layer_norm is True - def test_build_config_with_model_type_and_fusions( - self, fusion_capabilities: dict - ) -> None: + def test_build_config_with_model_type_and_fusions(self, fusion_capabilities: dict) -> None: """Build config with both model type and fusion options.""" config = ORTFusionPipe.build_config( model_type="gpt2", layer_norm_fusion=True, attention_fusion=True @@ -377,9 +371,7 @@ def test_should_process_returns_true_with_multiple_fusions(self) -> None: Note: GELU capabilities (5) are disabled due to ORT bundling issue. """ - config = ORTFusionPipeConfig( - enable_attention=True, enable_layer_norm=True - ) + config = ORTFusionPipeConfig(enable_attention=True, enable_layer_norm=True) assert ORTFusionPipe.should_process(config) is True # Test all 12 fusion toggles enabled (GELU disabled) @@ -439,9 +431,9 @@ def test_should_process_checks_all_fusion_options(self) -> None: test_config_kwargs = all_options.copy() test_config_kwargs[option_name] = True test_config = ORTFusionPipeConfig(**test_config_kwargs) - assert ( - ORTFusionPipe.should_process(test_config) is True - ), f"should_process should return True when {option_name} is enabled" + assert ORTFusionPipe.should_process(test_config) is True, ( + f"should_process should return True when {option_name} is enabled" + ) class TestORTFusionPipeIntegration: @@ -558,9 +550,7 @@ def test_all_fusion_options_accessible(self) -> None: class TestORTFusionPipeEdgeCases: """Edge case tests for ORTFusionPipe.""" - def test_build_config_with_all_fusions_enabled( - self, fusion_capabilities: dict - ) -> None: + def test_build_config_with_all_fusions_enabled(self, fusion_capabilities: dict) -> None: """Test build_config with all fusion options enabled. This edge case ensures the system can handle enabling all fusion @@ -569,18 +559,14 @@ def test_build_config_with_all_fusions_enabled( Note: GELU capabilities are disabled due to ORT bundling issue. """ - config = ORTFusionPipe.build_config( - attention_fusion=True, layer_norm_fusion=True - ) + config = ORTFusionPipe.build_config(attention_fusion=True, layer_norm_fusion=True) assert isinstance(config, ORTFusionPipeConfig) # Verify all enabled fusions are reflected in config assert config.enable_attention is True assert config.enable_layer_norm is True - def test_build_config_with_different_model_types( - self, fusion_capabilities: dict - ) -> None: + def test_build_config_with_different_model_types(self, fusion_capabilities: dict) -> None: """Test build_config with different model types. This tests that build_config correctly handles various model types @@ -596,9 +582,7 @@ def test_build_config_with_different_model_types( assert isinstance(config, ORTFusionPipeConfig) assert config.model_type == model_type - def test_process_passthrough_when_no_fusions( - self, sample_model: onnx.ModelProto - ) -> None: + def test_process_passthrough_when_no_fusions(self, sample_model: onnx.ModelProto) -> None: """Test that model is returned unchanged when no fusions are enabled. When all fusion options are False, the should_process check should return @@ -841,15 +825,11 @@ def _count_nodes_by_op_type( count += 1 return count - def _has_node( - self, model: onnx.ModelProto, op_type: str, domain: str = "" - ) -> bool: + def _has_node(self, model: onnx.ModelProto, op_type: str, domain: str = "") -> bool: """Check if model has a node with specific op_type and domain.""" return self._count_nodes_by_op_type(model, op_type, domain) > 0 - def test_self_attention_model_structure( - self, self_attention_model: onnx.ModelProto - ) -> None: + def test_self_attention_model_structure(self, self_attention_model: onnx.ModelProto) -> None: """Verify self-attention test model has expected structure.""" # Should have MatMul nodes for Q, K, V projections matmul_count = self._count_nodes_by_op_type(self_attention_model, "MatMul") @@ -981,15 +961,11 @@ def test_groupnorm_model_structure(self, groupnorm_model: onnx.ModelProto) -> No assert reshape_count >= 2, f"Expected at least 2 Reshape nodes, got {reshape_count}" # Should have ReduceMean for normalization - reducemean_count = sum( - 1 for n in groupnorm_model.graph.node if n.op_type == "ReduceMean" - ) + reducemean_count = sum(1 for n in groupnorm_model.graph.node if n.op_type == "ReduceMean") assert reducemean_count >= 1, f"Expected ReduceMean nodes, got {reducemean_count}" @pytest.mark.skip(reason="GroupNorm fusion requires SD model type and specific pattern") - def test_groupnorm_channels_last_layout( - self, groupnorm_model: onnx.ModelProto - ) -> None: + def test_groupnorm_channels_last_layout(self, groupnorm_model: onnx.ModelProto) -> None: """Test GroupNorm fusion with channels_last (NHWC) layout. Note: Skipped because GroupNorm fusion is only enabled for diff --git a/tests/optim/pipes/test_pipe_fusion_direct.py b/tests/unit/optim/pipes/test_pipe_fusion_direct.py similarity index 98% rename from tests/optim/pipes/test_pipe_fusion_direct.py rename to tests/unit/optim/pipes/test_pipe_fusion_direct.py index 6e0b272d9..277abeecb 100644 --- a/tests/optim/pipes/test_pipe_fusion_direct.py +++ b/tests/unit/optim/pipes/test_pipe_fusion_direct.py @@ -490,9 +490,7 @@ def test_simplified_layernorm_fusion_numeric_equivalence( class TestFusionIntegration: """Integration tests for combined fusion operations.""" - def test_all_fusions_combined( - self, decomposed_layernorm_model: onnx.ModelProto - ) -> None: + def test_all_fusions_combined(self, decomposed_layernorm_model: onnx.ModelProto) -> None: """Test applying multiple fusions in sequence on LayerNorm model.""" from onnxruntime.transformers.fusion_layernorm import FusionLayerNormalization from onnxruntime.transformers.fusion_skiplayernorm import ( @@ -519,12 +517,8 @@ def test_all_fusions_combined( def test_no_fusion_when_pattern_not_present(self) -> None: """Test that fusion does nothing when pattern not present.""" # Create a simple model without any fusion patterns - input_tensor = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 10, 64] - ) - output_tensor = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 10, 64] - ) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 10, 64]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 10, 64]) relu_node = helper.make_node( "Relu", @@ -557,9 +551,7 @@ def test_no_fusion_when_pattern_not_present(self) -> None: f"Expected no change: before={before_count}, after={after_count}" ) - def test_fusion_isolation( - self, bert_attention_model: onnx.ModelProto - ) -> None: + def test_fusion_isolation(self, bert_attention_model: onnx.ModelProto) -> None: """Test that attention fusion doesn't create unrelated ops.""" optimized = apply_attention_fusions(bert_attention_model) op_types = get_op_types(optimized) diff --git a/tests/optim/pipes/test_pipe_graph.py b/tests/unit/optim/pipes/test_pipe_graph.py similarity index 100% rename from tests/optim/pipes/test_pipe_graph.py rename to tests/unit/optim/pipes/test_pipe_graph.py diff --git a/tests/optim/pipes/test_pipe_graph_isolated.py b/tests/unit/optim/pipes/test_pipe_graph_isolated.py similarity index 100% rename from tests/optim/pipes/test_pipe_graph_isolated.py rename to tests/unit/optim/pipes/test_pipe_graph_isolated.py diff --git a/tests/optim/pipes/test_pipe_rewrite.py b/tests/unit/optim/pipes/test_pipe_rewrite.py similarity index 100% rename from tests/optim/pipes/test_pipe_rewrite.py rename to tests/unit/optim/pipes/test_pipe_rewrite.py diff --git a/tests/optim/pipes/test_pipe_surgery.py b/tests/unit/optim/pipes/test_pipe_surgery.py similarity index 95% rename from tests/optim/pipes/test_pipe_surgery.py rename to tests/unit/optim/pipes/test_pipe_surgery.py index fbdaef951..baad76ae6 100644 --- a/tests/optim/pipes/test_pipe_surgery.py +++ b/tests/unit/optim/pipes/test_pipe_surgery.py @@ -187,9 +187,7 @@ def test_process_clamps_causal_mask_extreme_values( else: pytest.fail("causal_mask.1 not found in result model") - def test_process_clamps_mask_value_scalar( - self, causal_mask_model: onnx.ModelProto - ) -> None: + def test_process_clamps_mask_value_scalar(self, causal_mask_model: onnx.ModelProto) -> None: """Verify process clamps scalar mask_value constant.""" pipe = SurgeryPipe() config = SurgeryPipeConfig(clamp_constant_values=True, clamp_min=-1e4, clamp_max=1e4) @@ -205,9 +203,7 @@ def test_process_clamps_mask_value_scalar( else: pytest.fail("mask_value not found in result model") - def test_process_preserves_zero_values( - self, causal_mask_model: onnx.ModelProto - ) -> None: + def test_process_preserves_zero_values(self, causal_mask_model: onnx.ModelProto) -> None: """Verify process preserves zero values in causal mask (only clamps extremes).""" pipe = SurgeryPipe() config = SurgeryPipeConfig(clamp_constant_values=True) @@ -248,9 +244,7 @@ def test_process_does_not_modify_normal_constants( ) break - def test_process_custom_clamp_range( - self, causal_mask_model: onnx.ModelProto - ) -> None: + def test_process_custom_clamp_range(self, causal_mask_model: onnx.ModelProto) -> None: """Verify process uses custom clamp range.""" pipe = SurgeryPipe() config = SurgeryPipeConfig( @@ -268,9 +262,7 @@ def test_process_custom_clamp_range( assert tensor.max() <= 100, f"Max value {tensor.max()} above custom clamp_max" break - def test_process_returns_copy_not_original( - self, causal_mask_model: onnx.ModelProto - ) -> None: + def test_process_returns_copy_not_original(self, causal_mask_model: onnx.ModelProto) -> None: """Verify process returns a copy, not the original model.""" pipe = SurgeryPipe() config = SurgeryPipeConfig(clamp_constant_values=True) @@ -280,9 +272,7 @@ def test_process_returns_copy_not_original( # Result should be a different object assert result is not causal_mask_model - def test_process_model_remains_valid( - self, causal_mask_model: onnx.ModelProto - ) -> None: + def test_process_model_remains_valid(self, causal_mask_model: onnx.ModelProto) -> None: """Verify processed model is still valid ONNX.""" pipe = SurgeryPipe() config = SurgeryPipeConfig(clamp_constant_values=True) diff --git a/tests/optim/test_api.py b/tests/unit/optim/test_api.py similarity index 97% rename from tests/optim/test_api.py rename to tests/unit/optim/test_api.py index d656dc90b..18e191c80 100644 --- a/tests/optim/test_api.py +++ b/tests/unit/optim/test_api.py @@ -427,9 +427,7 @@ def test_model_validation_error_on_invalid_model(self, tmp_path: Path) -> None: with pytest.raises(ModelValidationError): optimize_onnx(invalid_file) - def test_configuration_error_on_invalid_config( - self, model_file: Path, tmp_path: Path - ) -> None: + def test_configuration_error_on_invalid_config(self, model_file: Path, tmp_path: Path) -> None: """Raise ConfigurationError for invalid config values.""" config_path = tmp_path / "config.json" with config_path.open("w") as f: @@ -450,9 +448,7 @@ class TestOptimizeOnnxIntegration: Uses all_patterns_model fixture from conftest.py for real optimization. """ - def test_end_to_end_basic( - self, all_patterns_model: onnx.ModelProto, tmp_path: Path - ) -> None: + def test_end_to_end_basic(self, all_patterns_model: onnx.ModelProto, tmp_path: Path) -> None: """Full pipeline: load, optimize, save.""" output_path = tmp_path / "optimized.onnx" result = optimize_onnx(all_patterns_model, output_path) @@ -479,18 +475,14 @@ def test_end_to_end_with_capabilities( assert isinstance(result, onnx.ModelProto) assert output_path.exists() - def test_model_proto_passthrough( - self, all_patterns_model: onnx.ModelProto - ) -> None: + def test_model_proto_passthrough(self, all_patterns_model: onnx.ModelProto) -> None: """ModelProto input should work without file I/O.""" result = optimize_onnx(all_patterns_model) assert isinstance(result, onnx.ModelProto) # Verify it's a valid model onnx.checker.check_model(result) - def test_returns_optimized_model( - self, all_patterns_model: onnx.ModelProto - ) -> None: + def test_returns_optimized_model(self, all_patterns_model: onnx.ModelProto) -> None: """Verify optimization actually modifies the model.""" # Run optimization result = optimize_onnx(all_patterns_model) diff --git a/tests/optim/test_error_paths.py b/tests/unit/optim/test_error_paths.py similarity index 94% rename from tests/optim/test_error_paths.py rename to tests/unit/optim/test_error_paths.py index c4f260c55..17446888c 100644 --- a/tests/optim/test_error_paths.py +++ b/tests/unit/optim/test_error_paths.py @@ -241,9 +241,7 @@ def mock_init_pipes(cls: type) -> None: class TestOptimizerPipeFailure: """Test Optimizer when a pipe fails (lines 133-135).""" - def test_pipe_processing_failure_raises( - self, simple_model: onnx.ModelProto - ) -> None: + def test_pipe_processing_failure_raises(self, simple_model: onnx.ModelProto) -> None: """Test that pipe failure raises and logs error.""" optimizer = Optimizer() @@ -275,9 +273,7 @@ class TestOptimizerNoPostValidation: Validation is handled by load_onnx (path-based, safe for any size). """ - def test_no_post_validation_check( - self, simple_model: onnx.ModelProto - ) -> None: + def test_no_post_validation_check(self, simple_model: onnx.ModelProto) -> None: """Optimizer.optimize() completes without calling check_model.""" optimizer = Optimizer() @@ -297,9 +293,7 @@ def mock_init_pipes(cls: type) -> None: class TestOptimizerResolveDependencies: """Test _resolve_dependencies default value path (line 195).""" - def test_resolve_dependencies_uses_defaults( - self, simple_model: onnx.ModelProto - ) -> None: + def test_resolve_dependencies_uses_defaults(self, simple_model: onnx.ModelProto) -> None: """Test that missing kwargs use capability defaults.""" optimizer = Optimizer() @@ -310,9 +304,7 @@ def test_resolve_dependencies_uses_defaults( # (the actual values depend on capability definitions) assert isinstance(result, dict) - def test_resolve_dependencies_partial_kwargs( - self, simple_model: onnx.ModelProto - ) -> None: + def test_resolve_dependencies_partial_kwargs(self, simple_model: onnx.ModelProto) -> None: """Test with some kwargs provided, others use defaults.""" optimizer = Optimizer() @@ -326,9 +318,7 @@ def test_resolve_dependencies_partial_kwargs( class TestShapeInferenceFallback: """Test shape inference fallback paths in modelkit.onnx.shape.""" - def test_symbolic_failure_falls_back_to_onnx( - self, simple_model: onnx.ModelProto - ) -> None: + def test_symbolic_failure_falls_back_to_onnx(self, simple_model: onnx.ModelProto) -> None: """Test that symbolic failure falls back to ONNX shape inference.""" from winml.modelkit.onnx.shape import infer_shapes @@ -345,9 +335,7 @@ def test_symbolic_failure_falls_back_to_onnx( result = infer_shapes(simple_model) assert result is simple_model - def test_both_inference_failures_returns_original( - self, simple_model: onnx.ModelProto - ) -> None: + def test_both_inference_failures_returns_original(self, simple_model: onnx.ModelProto) -> None: """Test that both failures return original model.""" from winml.modelkit.onnx.shape import infer_shapes @@ -364,9 +352,7 @@ def test_both_inference_failures_returns_original( result = infer_shapes(simple_model) assert result is simple_model - def test_symbolic_success_skips_onnx( - self, simple_model: onnx.ModelProto - ) -> None: + def test_symbolic_success_skips_onnx(self, simple_model: onnx.ModelProto) -> None: """Test that symbolic success skips ONNX inference.""" from winml.modelkit.onnx.shape import infer_shapes @@ -387,9 +373,7 @@ def test_symbolic_success_skips_onnx( class TestOptimizerPipeSkipping: """Test pipe skipping when should_process returns False.""" - def test_pipe_skipped_when_no_capabilities( - self, simple_model: onnx.ModelProto - ) -> None: + def test_pipe_skipped_when_no_capabilities(self, simple_model: onnx.ModelProto) -> None: """Test that pipes are skipped when should_process returns False.""" optimizer = Optimizer() diff --git a/tests/optim/test_registry_cli.py b/tests/unit/optim/test_registry_cli.py similarity index 100% rename from tests/optim/test_registry_cli.py rename to tests/unit/optim/test_registry_cli.py diff --git a/tests/optim/verification/__init__.py b/tests/unit/optim/verification/__init__.py similarity index 97% rename from tests/optim/verification/__init__.py rename to tests/unit/optim/verification/__init__.py index 68112c85c..c748a6ba9 100644 --- a/tests/optim/verification/__init__.py +++ b/tests/unit/optim/verification/__init__.py @@ -15,7 +15,7 @@ 4. Numeric Verification: Outputs must match within tolerance (optional) Usage: - from tests.optim.verification import verify_capability_effect + from tests.unit.optim.verification import verify_capability_effect verify_capability_effect( model_before=baseline, diff --git a/tests/optim/verification/conftest.py b/tests/unit/optim/verification/conftest.py similarity index 100% rename from tests/optim/verification/conftest.py rename to tests/unit/optim/verification/conftest.py diff --git a/tests/optim/verification/test_mandatory_optimizations.py b/tests/unit/optim/verification/test_mandatory_optimizations.py similarity index 100% rename from tests/optim/verification/test_mandatory_optimizations.py rename to tests/unit/optim/verification/test_mandatory_optimizations.py diff --git a/tests/optim/verification/test_ort_disable_limit_empirical.py b/tests/unit/optim/verification/test_ort_disable_limit_empirical.py similarity index 93% rename from tests/optim/verification/test_ort_disable_limit_empirical.py rename to tests/unit/optim/verification/test_ort_disable_limit_empirical.py index d7d88fb5b..3d03ee2bf 100644 --- a/tests/optim/verification/test_ort_disable_limit_empirical.py +++ b/tests/unit/optim/verification/test_ort_disable_limit_empirical.py @@ -47,12 +47,8 @@ def create_relu_model() -> onnx.ModelProto: This creates: input -> Relu -> output ORT has a ReluFusion optimizer that we can target. """ - input_tensor = helper.make_tensor_value_info( - "input", TensorProto.FLOAT, [1, 64, 224, 224] - ) - output_tensor = helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [1, 64, 224, 224] - ) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 64, 224, 224]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 64, 224, 224]) relu_node = helper.make_node( "Relu", @@ -217,9 +213,7 @@ def run_optimization_with_disabled_list( ) try: - _ = ort.InferenceSession( - str(input_file), sess_opts, providers=["CPUExecutionProvider"] - ) + _ = ort.InferenceSession(str(input_file), sess_opts, providers=["CPUExecutionProvider"]) success = True except Exception: success = False @@ -235,9 +229,7 @@ class TestORTDisableLimit: def test_ort_accepts_empty_disable_list(self, temp_dir: Path) -> None: """Verify ORT works with no disabled optimizers.""" model = create_relu_model() - optimized, success = run_optimization_with_disabled_list( - model, [], temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, [], temp_dir) assert success, "ORT should accept empty disable list" assert optimized is not None @@ -245,9 +237,7 @@ def test_ort_accepts_empty_disable_list(self, temp_dir: Path) -> None: def test_ort_accepts_single_disabled_optimizer(self, temp_dir: Path) -> None: """Verify ORT works with a single disabled optimizer.""" model = create_relu_model() - optimized, success = run_optimization_with_disabled_list( - model, ["GeluFusion"], temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, ["GeluFusion"], temp_dir) assert success, "ORT should accept single disabled optimizer" assert optimized is not None @@ -257,9 +247,7 @@ def test_ort_accepts_10_disabled_optimizers(self, temp_dir: Path) -> None: model = create_relu_model() optimizers = get_known_ort_optimizers()[:10] - optimized, success = run_optimization_with_disabled_list( - model, optimizers, temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, optimizers, temp_dir) assert success, "ORT should accept 10 disabled optimizers" assert optimized is not None @@ -275,9 +263,7 @@ def test_ort_accepts_32_disabled_optimizers(self, temp_dir: Path) -> None: assert len(optimizers) == 32, "Need exactly 32 optimizers for this test" - optimized, success = run_optimization_with_disabled_list( - model, optimizers, temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, optimizers, temp_dir) assert success, "ORT should accept 32 disabled optimizers" assert optimized is not None @@ -294,9 +280,7 @@ def test_ort_behavior_with_40_disabled_optimizers(self, temp_dir: Path) -> None: assert len(optimizers) == 40, "Need 40 optimizers for this test" - optimized, success = run_optimization_with_disabled_list( - model, optimizers, temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, optimizers, temp_dir) # Document the actual behavior print(f"\n[40 items] ORT session creation: {'SUCCESS' if success else 'FAILED'}") @@ -312,9 +296,7 @@ def test_ort_behavior_with_50_disabled_optimizers(self, temp_dir: Path) -> None: assert len(optimizers) >= 50, f"Need 50 optimizers, have {len(optimizers)}" optimizers = optimizers[:50] - optimized, success = run_optimization_with_disabled_list( - model, optimizers, temp_dir - ) + optimized, success = run_optimization_with_disabled_list(model, optimizers, temp_dir) print(f"\n[50 items] ORT session creation: {'SUCCESS' if success else 'FAILED'}") @@ -338,9 +320,7 @@ def test_character_limit_2048(self, temp_dir: Path) -> None: total_chars = len(";".join(fake_optimizers)) print(f"\n[Char limit test] Total characters: {total_chars}") - _optimized, success = run_optimization_with_disabled_list( - model, fake_optimizers, temp_dir - ) + _optimized, success = run_optimization_with_disabled_list(model, fake_optimizers, temp_dir) print(f"[Char limit test] ORT session creation: {'SUCCESS' if success else 'FAILED'}") @@ -357,16 +337,12 @@ def test_ort_actually_disables_specified_optimizers(self, temp_dir: Path) -> Non original_node_count = len(model.graph.node) # First, optimize without disabling anything - optimized_full, _ = run_optimization_with_disabled_list( - model, [], temp_dir - ) + optimized_full, _ = run_optimization_with_disabled_list(model, [], temp_dir) full_opt_nodes = len(optimized_full.graph.node) # Now optimize with maximum disabled optimizers all_optimizers = get_known_ort_optimizers() - optimized_disabled, _ = run_optimization_with_disabled_list( - model, all_optimizers, temp_dir - ) + optimized_disabled, _ = run_optimization_with_disabled_list(model, all_optimizers, temp_dir) disabled_opt_nodes = len(optimized_disabled.graph.node) print("\n[Disable effectiveness test]") @@ -395,9 +371,7 @@ def test_incremental_limit_discovery(self, temp_dir: Path) -> None: break subset = optimizers[:count] - _, success = run_optimization_with_disabled_list( - model, subset, temp_dir - ) + _, success = run_optimization_with_disabled_list(model, subset, temp_dir) results.append((count, success)) print("\n[Incremental limit discovery]") diff --git a/tests/optim/verification/test_ort_mapping.py b/tests/unit/optim/verification/test_ort_mapping.py similarity index 100% rename from tests/optim/verification/test_ort_mapping.py rename to tests/unit/optim/verification/test_ort_mapping.py diff --git a/tests/onnx/__init__.py b/tests/unit/optracing/__init__.py similarity index 100% rename from tests/onnx/__init__.py rename to tests/unit/optracing/__init__.py diff --git a/tests/optracing/fixtures/optrace_resnet50.csv b/tests/unit/optracing/fixtures/optrace_resnet50.csv similarity index 100% rename from tests/optracing/fixtures/optrace_resnet50.csv rename to tests/unit/optracing/fixtures/optrace_resnet50.csv diff --git a/tests/optracing/fixtures/qhas_resnet50.json b/tests/unit/optracing/fixtures/qhas_resnet50.json similarity index 99% rename from tests/optracing/fixtures/qhas_resnet50.json rename to tests/unit/optracing/fixtures/qhas_resnet50.json index 2198629d2..29e48d9b5 100644 --- a/tests/optracing/fixtures/qhas_resnet50.json +++ b/tests/unit/optracing/fixtures/qhas_resnet50.json @@ -386,4 +386,4 @@ "data": [] } } -} \ No newline at end of file +} diff --git a/tests/optracing/test_csv_parser.py b/tests/unit/optracing/test_csv_parser.py similarity index 99% rename from tests/optracing/test_csv_parser.py rename to tests/unit/optracing/test_csv_parser.py index 9a0afde99..14f6aa7cc 100644 --- a/tests/optracing/test_csv_parser.py +++ b/tests/unit/optracing/test_csv_parser.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test QNN profiling CSV parser.""" + from pathlib import Path from winml.modelkit.optracing.qnn.csv_parser import parse_qnn_profiling_csv diff --git a/tests/optracing/test_detection.py b/tests/unit/optracing/test_detection.py similarity index 99% rename from tests/optracing/test_detection.py rename to tests/unit/optracing/test_detection.py index 7f3934153..5670d1105 100644 --- a/tests/optracing/test_detection.py +++ b/tests/unit/optracing/test_detection.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test QNN EP detection for op-tracing.""" + from winml.modelkit.optracing import is_qnn_profiling_available diff --git a/tests/optracing/test_integration.py b/tests/unit/optracing/test_integration.py similarity index 94% rename from tests/optracing/test_integration.py rename to tests/unit/optracing/test_integration.py index ba12f986d..015bec03f 100644 --- a/tests/optracing/test_integration.py +++ b/tests/unit/optracing/test_integration.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Integration tests using real QNN profiling data.""" + import json from pathlib import Path @@ -27,9 +28,7 @@ def test_basic_pipeline_csv_to_json(tmp_path): op_path=op["name"], # CSV doesn't distinguish type vs path op_id=op["op_id"], duration_us=op["cycles"], # keep raw cycles as duration placeholder - percent_of_total=( - (op["cycles"] / total_cycles * 100) if total_cycles else 0 - ), + percent_of_total=((op["cycles"] / total_cycles * 100) if total_cycles else 0), ) for op in csv_data["operators"] ] @@ -92,9 +91,7 @@ def test_detail_pipeline_qhas_to_json(tmp_path): assert data["metadata"]["tracing_level"] == "detail" assert data["summary"]["time_us"] > 0 # At least one operator should have DRAM read data populated - assert any( - op["dram_read_bytes"] is not None for op in data["operators"] - ) + assert any(op["dram_read_bytes"] is not None for op in data["operators"]) def test_json_schema_basic(): @@ -103,9 +100,7 @@ def test_json_schema_basic(): model="test", device="npu", tracing_level="basic", - operators=[ - OperatorMetrics(name="Conv", op_path="/conv", duration_us=10.0) - ], + operators=[OperatorMetrics(name="Conv", op_path="/conv", duration_us=10.0)], ) data = result.to_dict() @@ -232,6 +227,4 @@ def test_cross_parser_top_operator_is_conv(): # The top op for ResNet should contain "conv" (the large 7x7 convolution) assert "conv" in top_csv, f"Expected 'conv' in top CSV op: {top_csv}" - assert "conv" in top_qhas_name, ( - f"Expected 'conv' in top QHAS op: {top_qhas_name}" - ) + assert "conv" in top_qhas_name, f"Expected 'conv' in top QHAS op: {top_qhas_name}" diff --git a/tests/optracing/test_perf_optracing_cli.py b/tests/unit/optracing/test_perf_optracing_cli.py similarity index 99% rename from tests/optracing/test_perf_optracing_cli.py rename to tests/unit/optracing/test_perf_optracing_cli.py index 22dc58082..0e532b7b5 100644 --- a/tests/optracing/test_perf_optracing_cli.py +++ b/tests/unit/optracing/test_perf_optracing_cli.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Tests for the --op-tracing CLI option on wmk perf.""" + from __future__ import annotations from unittest.mock import patch diff --git a/tests/optracing/test_qhas_parser.py b/tests/unit/optracing/test_qhas_parser.py similarity index 99% rename from tests/optracing/test_qhas_parser.py rename to tests/unit/optracing/test_qhas_parser.py index 728a475c4..9a767991a 100644 --- a/tests/optracing/test_qhas_parser.py +++ b/tests/unit/optracing/test_qhas_parser.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test QHAS JSON parser.""" + import json from pathlib import Path diff --git a/tests/optracing/test_qnn_profiler.py b/tests/unit/optracing/test_qnn_profiler.py similarity index 87% rename from tests/optracing/test_qnn_profiler.py rename to tests/unit/optracing/test_qnn_profiler.py index 973dafdcb..5cbfb3e70 100644 --- a/tests/optracing/test_qnn_profiler.py +++ b/tests/unit/optracing/test_qnn_profiler.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test QNN profiler and viewer with mocked ORT (no QNN hardware needed).""" + from __future__ import annotations from pathlib import Path @@ -30,9 +31,7 @@ def test_qnn_profiler_creates_session_options(): """Verify session options are configured correctly.""" - profiler = QNNProfiler( - Path("model.onnx"), output_dir=Path("out"), level="basic" - ) + profiler = QNNProfiler(Path("model.onnx"), output_dir=Path("out"), level="basic") mock_ort = MagicMock() mock_options = MagicMock() mock_ort.SessionOptions.return_value = mock_options @@ -54,9 +53,7 @@ def test_qnn_profiler_creates_session_options(): def test_qnn_profiler_provider_options_basic(): """Verify provider options for basic mode (profiling_level=detailed).""" - profiler = QNNProfiler( - Path("model.onnx"), output_dir=Path("out"), level="basic" - ) + profiler = QNNProfiler(Path("model.onnx"), output_dir=Path("out"), level="basic") opts = profiler._build_provider_options(Path("out/profiling.csv")) assert len(opts) == 1 @@ -71,9 +68,7 @@ def test_qnn_profiler_provider_options_basic(): def test_qnn_profiler_provider_options_detail(): """Verify provider options for detail mode (profiling_level=optrace).""" - profiler = QNNProfiler( - Path("model.onnx"), output_dir=Path("out"), level="detail" - ) + profiler = QNNProfiler(Path("model.onnx"), output_dir=Path("out"), level="detail") opts = profiler._build_provider_options(Path("out/profiling.csv")) po = opts[0] @@ -172,24 +167,21 @@ def write_csv_on_del(): # Simulate CSV being flushed when session is deleted. mock_ort.InferenceSession.return_value = mock_session - with patch.dict("sys.modules", {"onnxruntime": mock_ort}), patch( - "winml.modelkit.optracing.qnn.profiler.QNNProfiler._collect_results" - ) as mock_collect: + with ( + patch.dict("sys.modules", {"onnxruntime": mock_ort}), + patch("winml.modelkit.optracing.qnn.profiler.QNNProfiler._collect_results") as mock_collect, + ): # Write the CSV before _collect_results is called. write_csv_on_del() - profiler = QNNProfiler( - model_path, output_dir=output_dir, level="basic" - ) + profiler = QNNProfiler(model_path, output_dir=output_dir, level="basic") # Instead of running the full flow (which needs real ORT import), # test the collect_results path directly. mock_collect.return_value = MagicMock() # Verify session creation was called correctly via builder methods. profiler._build_session_options(mock_ort) - po = profiler._build_provider_options( - output_dir / "profiling_output.csv" - ) + po = profiler._build_provider_options(output_dir / "profiling_output.csv") assert po[0]["profiling_level"] == "detailed" # Now test the CSV parsing path directly. @@ -208,12 +200,8 @@ def write_csv_on_del(): def test_qnn_profiler_empty_artifacts(tmp_path): """Profiler returns empty result when no artifacts exist.""" - profiler = QNNProfiler( - Path("model.onnx"), output_dir=tmp_path, level="basic" - ) - result = profiler._collect_results( - tmp_path / "nonexistent.csv", iterations=5 - ) + profiler = QNNProfiler(Path("model.onnx"), output_dir=tmp_path, level="basic") + result = profiler._collect_results(tmp_path / "nonexistent.csv", iterations=5) assert result.model == "model.onnx" assert len(result.operators) == 0 assert result.num_samples == 0 @@ -269,12 +257,8 @@ def test_find_qnn_sdk_from_common_path(monkeypatch, tmp_path): def test_run_basic_viewer_no_sdk(tmp_path): """Basic viewer returns None when SDK is not found.""" - with patch( - "winml.modelkit.optracing.qnn.viewer._find_viewer_exe", return_value=None - ): - result = run_basic_viewer( - tmp_path / "log.qnn", tmp_path / "output.csv" - ) + with patch("winml.modelkit.optracing.qnn.viewer._find_viewer_exe", return_value=None): + result = run_basic_viewer(tmp_path / "log.qnn", tmp_path / "output.csv") assert result is None @@ -285,16 +269,17 @@ def test_run_basic_viewer_success(tmp_path): def fake_run(cmd, **kwargs): output_csv.write_text("header\ndata", encoding="utf-8") - with patch( - "winml.modelkit.optracing.qnn.viewer._find_viewer_exe", - return_value=Path("/fake/viewer.exe"), - ), patch( - "winml.modelkit.optracing.qnn.viewer.subprocess.run", - side_effect=fake_run, + with ( + patch( + "winml.modelkit.optracing.qnn.viewer._find_viewer_exe", + return_value=Path("/fake/viewer.exe"), + ), + patch( + "winml.modelkit.optracing.qnn.viewer.subprocess.run", + side_effect=fake_run, + ), ): - result = run_basic_viewer( - tmp_path / "log.qnn", output_csv - ) + result = run_basic_viewer(tmp_path / "log.qnn", output_csv) assert result == output_csv @@ -326,12 +311,15 @@ def test_run_qhas_viewer_writes_config(tmp_path): def fake_run(cmd, **kwargs): output.write_text("{}", encoding="utf-8") - with patch( - "winml.modelkit.optracing.qnn.viewer._find_viewer_exe", - return_value=Path("/fake/viewer.exe"), - ), patch( - "winml.modelkit.optracing.qnn.viewer.subprocess.run", - side_effect=fake_run, + with ( + patch( + "winml.modelkit.optracing.qnn.viewer._find_viewer_exe", + return_value=Path("/fake/viewer.exe"), + ), + patch( + "winml.modelkit.optracing.qnn.viewer.subprocess.run", + side_effect=fake_run, + ), ): run_qhas_viewer( tmp_path / "log.qnn", diff --git a/tests/optracing/test_registry.py b/tests/unit/optracing/test_registry.py similarity index 99% rename from tests/optracing/test_registry.py rename to tests/unit/optracing/test_registry.py index 38c8ca2bb..2a9e86511 100644 --- a/tests/optracing/test_registry.py +++ b/tests/unit/optracing/test_registry.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test OpTracer registry: registration, lookup, and EP pattern matching.""" + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/tests/optracing/test_report.py b/tests/unit/optracing/test_report.py similarity index 99% rename from tests/optracing/test_report.py rename to tests/unit/optracing/test_report.py index 7e20b8c4a..483cdb990 100644 --- a/tests/optracing/test_report.py +++ b/tests/unit/optracing/test_report.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Tests for op-tracing console report and JSON file output.""" + import json from io import StringIO @@ -21,6 +22,7 @@ # Fixtures # --------------------------------------------------------------------------- + def _make_basic_result() -> OpTraceResult: """Create a basic-mode OpTraceResult with sample operators.""" return OpTraceResult( @@ -107,6 +109,7 @@ def _make_empty_result() -> OpTraceResult: # display_op_trace_report — basic mode # --------------------------------------------------------------------------- + class TestDisplayBasicReport: def test_renders_without_error(self): """Basic report renders without raising.""" @@ -163,6 +166,7 @@ def test_default_console_created(self): # display_op_trace_report — detail mode # --------------------------------------------------------------------------- + class TestDisplayDetailReport: def test_renders_without_error(self): """Detail report renders without raising.""" @@ -220,6 +224,7 @@ def test_vtcm_hit_ratio_displayed(self): # display_op_trace_report — empty operators # --------------------------------------------------------------------------- + class TestDisplayEmptyReport: def test_empty_operators_renders(self): """Report with no operators renders without error.""" @@ -242,6 +247,7 @@ def test_empty_operators_shows_no_data_message(self): # write_op_trace_json # --------------------------------------------------------------------------- + class TestWriteOpTraceJson: def test_creates_file(self, tmp_path): """JSON file is created at the specified path.""" @@ -306,6 +312,7 @@ def test_empty_result_json(self, tmp_path): # _format_bytes # --------------------------------------------------------------------------- + class TestFormatBytes: @pytest.mark.parametrize( ("value", "expected"), diff --git a/tests/optracing/test_result.py b/tests/unit/optracing/test_result.py similarity index 95% rename from tests/optracing/test_result.py rename to tests/unit/optracing/test_result.py index 8ef216265..cbd17ac47 100644 --- a/tests/optracing/test_result.py +++ b/tests/unit/optracing/test_result.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- """Test OpTraceResult dataclass and serialization.""" + import json from winml.modelkit.optracing.result import OperatorMetrics, OpTraceResult @@ -38,9 +39,7 @@ def test_op_trace_result_to_dict(): model="resnet-50", device="npu", tracing_level="basic", - operators=[ - OperatorMetrics(name="Conv2d", op_path="/conv", duration_us=10.0) - ], + operators=[OperatorMetrics(name="Conv2d", op_path="/conv", duration_us=10.0)], ) d = result.to_dict() assert d["metadata"]["model"] == "resnet-50" diff --git a/tests/session/conftest.py b/tests/unit/session/conftest.py similarity index 100% rename from tests/session/conftest.py rename to tests/unit/session/conftest.py diff --git a/tests/session/test_compile_qairt_bin.py b/tests/unit/session/test_compile_qairt_bin.py similarity index 100% rename from tests/session/test_compile_qairt_bin.py rename to tests/unit/session/test_compile_qairt_bin.py diff --git a/tests/session/test_ep_monitor.py b/tests/unit/session/test_ep_monitor.py similarity index 94% rename from tests/session/test_ep_monitor.py rename to tests/unit/session/test_ep_monitor.py index eb007d937..20bce5c75 100644 --- a/tests/session/test_ep_monitor.py +++ b/tests/unit/session/test_ep_monitor.py @@ -193,16 +193,28 @@ def test_get_command_submissions_sums_across_contexts(self): client = XrtSmiClient() contexts = [ HwContext( - pid=100, context_id=0, status="Active", - command_submissions=50, command_completions=50, - gops="N/A", fps="N/A", latency="N/A", - priority="Low", errors=0, + pid=100, + context_id=0, + status="Active", + command_submissions=50, + command_completions=50, + gops="N/A", + fps="N/A", + latency="N/A", + priority="Low", + errors=0, ), HwContext( - pid=100, context_id=1, status="Idle", - command_submissions=30, command_completions=30, - gops="N/A", fps="N/A", latency="N/A", - priority="Low", errors=0, + pid=100, + context_id=1, + status="Idle", + command_submissions=30, + command_completions=30, + gops="N/A", + fps="N/A", + latency="N/A", + priority="Low", + errors=0, ), ] with patch.object(client, "get_hw_contexts", return_value=contexts): @@ -871,13 +883,15 @@ class TestLiveMonitorDisplay: def test_render_status_warmup_phase(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=110, warmup=10, model_id="test", device="npu" - ) + display = LiveMonitorDisplay(total_iterations=110, warmup=10, model_id="test", device="npu") status = display._render_status( - iteration=5, latency_ms=1.0, util_samples=[50.0], - memory_local_mb=10.0, memory_shared_mb=20.0, - cpu_pct=5.0, ram_mb=8000.0, + iteration=5, + latency_ms=1.0, + util_samples=[50.0], + memory_local_mb=10.0, + memory_shared_mb=20.0, + cpu_pct=5.0, + ram_mb=8000.0, ) assert "Warmup" in status assert "npu" in status.lower() or "Device" in status @@ -885,13 +899,15 @@ def test_render_status_warmup_phase(self): def test_render_status_benchmark_phase(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=110, warmup=10, model_id="test", device="npu" - ) + display = LiveMonitorDisplay(total_iterations=110, warmup=10, model_id="test", device="npu") status = display._render_status( - iteration=50, latency_ms=2.0, util_samples=[80.0, 90.0], - memory_local_mb=31.0, memory_shared_mb=43.0, - cpu_pct=15.0, ram_mb=40000.0, + iteration=50, + latency_ms=2.0, + util_samples=[80.0, 90.0], + memory_local_mb=31.0, + memory_shared_mb=43.0, + cpu_pct=15.0, + ram_mb=40000.0, ) assert "Iter" in status assert "Throughput" in status @@ -900,45 +916,45 @@ def test_render_status_benchmark_phase(self): def test_render_status_zero_latency_no_crash(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=10, warmup=0, model_id="test", device="cpu" - ) + display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # latency_ms=0 should not cause division by zero status = display._render_status( - iteration=1, latency_ms=0.0, util_samples=[], + iteration=1, + latency_ms=0.0, + util_samples=[], ) assert "Throughput" in status def test_render_status_empty_samples(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=10, warmup=0, model_id="test", device="cpu" - ) + display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") status = display._render_status( - iteration=1, latency_ms=1.0, util_samples=[], + iteration=1, + latency_ms=1.0, + util_samples=[], ) assert "0.0%" in status # NPU should show 0.0% def test_update_noop_when_live_is_none(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=10, warmup=0, model_id="test", device="cpu" - ) + display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # _live is None (not entered context) — should not crash display.update( - iteration=1, latency_ms=1.0, util_samples=[50.0], + iteration=1, + latency_ms=1.0, + util_samples=[50.0], ) def test_print_final_snapshot_is_noop(self): from winml.modelkit.commands.live_chart import LiveMonitorDisplay - display = LiveMonitorDisplay( - total_iterations=10, warmup=0, model_id="test", device="cpu" - ) + display = LiveMonitorDisplay(total_iterations=10, warmup=0, model_id="test", device="cpu") # Should not crash or print anything display.print_final_snapshot( - util_samples=[50.0], memory_mb=10.0, - latency_ms=1.0, hw_dict={}, + util_samples=[50.0], + memory_mb=10.0, + latency_ms=1.0, + hw_dict={}, ) diff --git a/tests/session/test_is_compatible.py b/tests/unit/session/test_is_compatible.py similarity index 87% rename from tests/session/test_is_compatible.py rename to tests/unit/session/test_is_compatible.py index 475d594e8..39a997aa0 100644 --- a/tests/session/test_is_compatible.py +++ b/tests/unit/session/test_is_compatible.py @@ -37,9 +37,7 @@ def cpu_session(tmp_path: Path) -> WinMLSession: [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4])], [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4])], ) - model = helper.make_model( - graph, opset_imports=[helper.make_opsetid("", 17)] - ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) model.ir_version = 8 import onnx @@ -84,9 +82,7 @@ def test_add_compatible_with_cpu(self, cpu_session: WinMLSession) -> None: assert cpu_session.is_compatible(node, graph) is True - def test_unknown_op_incompatible( - self, cpu_session: WinMLSession - ) -> None: + def test_unknown_op_incompatible(self, cpu_session: WinMLSession) -> None: """Nonexistent op should be incompatible.""" node = helper.make_node( "CompletelyFakeOp12345", @@ -125,34 +121,24 @@ def test_without_graph_context_still_works_for_valid_op( assert cpu_session.is_compatible(node) is True - def test_node_with_no_inputs_returns_false( - self, cpu_session: WinMLSession - ) -> None: + def test_node_with_no_inputs_returns_false(self, cpu_session: WinMLSession) -> None: """Node with empty input list should return False.""" node = helper.make_node("Relu", inputs=[], outputs=["Y"]) assert cpu_session.is_compatible(node) is False - def test_node_with_no_outputs_returns_false( - self, cpu_session: WinMLSession - ) -> None: + def test_node_with_no_outputs_returns_false(self, cpu_session: WinMLSession) -> None: """Node with empty output list should return False.""" node = helper.make_node("Relu", inputs=["X"], outputs=[]) assert cpu_session.is_compatible(node) is False - def test_graph_context_resolves_shapes( - self, cpu_session: WinMLSession - ) -> None: + def test_graph_context_resolves_shapes(self, cpu_session: WinMLSession) -> None: """When graph is provided, shapes come from graph value_info.""" node = helper.make_node("Relu", inputs=["X"], outputs=["Y"]) - x_info = helper.make_tensor_value_info( - "X", TensorProto.FLOAT, [2, 8] - ) - y_info = helper.make_tensor_value_info( - "Y", TensorProto.FLOAT, [2, 8] - ) + x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 8]) + y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 8]) graph = helper.make_graph([node], "test", [x_info], [y_info]) # Should work with the real shapes from graph diff --git a/tests/session/test_qairt_session.py b/tests/unit/session/test_qairt_session.py similarity index 100% rename from tests/session/test_qairt_session.py rename to tests/unit/session/test_qairt_session.py diff --git a/tests/session/test_winml_session.py b/tests/unit/session/test_winml_session.py similarity index 100% rename from tests/session/test_winml_session.py rename to tests/unit/session/test_winml_session.py diff --git a/tests/sysinfo/test_device.py b/tests/unit/sysinfo/test_device.py similarity index 94% rename from tests/sysinfo/test_device.py rename to tests/unit/sysinfo/test_device.py index 5a3f2bd11..67a6c9a7c 100644 --- a/tests/sysinfo/test_device.py +++ b/tests/unit/sysinfo/test_device.py @@ -175,11 +175,13 @@ def test_resolve_device_auto_npu_with_ep(self) -> None: ), patch( "winml.modelkit.sysinfo.device._get_available_eps", - return_value=frozenset({ - "QNNExecutionProvider", - "DmlExecutionProvider", - "CPUExecutionProvider", - }), + return_value=frozenset( + { + "QNNExecutionProvider", + "DmlExecutionProvider", + "CPUExecutionProvider", + } + ), ), ): device, available = resolve_device("auto") @@ -196,10 +198,12 @@ def test_resolve_device_auto_npu_without_ep(self) -> None: ), patch( "winml.modelkit.sysinfo.device._get_available_eps", - return_value=frozenset({ - "DmlExecutionProvider", - "CPUExecutionProvider", - }), + return_value=frozenset( + { + "DmlExecutionProvider", + "CPUExecutionProvider", + } + ), ), ): device, available = resolve_device("auto") @@ -249,10 +253,12 @@ def test_resolve_device_explicit_valid(self) -> None: ), patch( "winml.modelkit.sysinfo.device._get_available_eps", - return_value=frozenset({ - "DmlExecutionProvider", - "CPUExecutionProvider", - }), + return_value=frozenset( + { + "DmlExecutionProvider", + "CPUExecutionProvider", + } + ), ), ): device, available = resolve_device("gpu") diff --git a/tests/utils/test_config_utils.py b/tests/unit/utils/test_config_utils.py similarity index 100% rename from tests/utils/test_config_utils.py rename to tests/unit/utils/test_config_utils.py