# Professional Testing with pytest

**Chapter 10 - Learning Python, 5th Edition**

Testing is fundamental to professional Python development. `pytest` is the
de facto standard testing framework, offering powerful features like fixtures,
parametrization, and rich assertion introspection. This notebook covers test
structure, assertion patterns, fixtures, mocking, and best practices.

## Test Structure and Naming Conventions

pytest discovers tests automatically based on naming conventions:
- Test files: `test_*.py` or `*_test.py`
- Test functions: `test_*`
- Test classes: `Test*` (the class must not define an `__init__` method)

The standard pattern is **Arrange-Act-Assert (AAA)**:
1. **Arrange**: Set up preconditions and inputs
2. **Act**: Execute the code under test
3. **Assert**: Verify the expected outcome

Below we define small modules inline and write tests against them. In a real
project, code and tests live in separate files.

In [None]:
# Code under test: a simple calculator module
from __future__ import annotations
from dataclasses import dataclass, field


class CalculatorError(Exception):
    """Root of the calculator's exception hierarchy; catch it to handle any calculator failure."""


class DivisionByZeroError(CalculatorError):
    """Signals an attempt to divide by a zero divisor."""


@dataclass
class Calculator:
    """Four-function calculator that rounds each result and logs the operation.

    Attributes:
        history: One ``"<a> <op> <b> = <result>"`` entry per successful
            operation, oldest first. Every instance gets its own list.
        precision: Number of decimal places passed to ``round`` for results.
    """

    history: list[str] = field(default_factory=list)
    precision: int = 2

    def _record(self, a: float, symbol: str, b: float, raw: float) -> float:
        # Shared tail of every operation: round, log, then hand back the value.
        outcome = round(raw, self.precision)
        self.history.append(f"{a} {symbol} {b} = {outcome}")
        return outcome

    def add(self, a: float, b: float) -> float:
        return self._record(a, "+", b, a + b)

    def subtract(self, a: float, b: float) -> float:
        return self._record(a, "-", b, a - b)

    def multiply(self, a: float, b: float) -> float:
        return self._record(a, "*", b, a * b)

    def divide(self, a: float, b: float) -> float:
        # Reject a zero divisor up front, before anything is computed or logged.
        if b == 0:
            raise DivisionByZeroError(f"Cannot divide {a} by zero")
        return self._record(a, "/", b, a / b)

    def clear_history(self) -> None:
        self.history.clear()


# Quick smoke test: exercise a couple of operations and show the log
calc = Calculator()
sum_result = calc.add(2, 3)
print(f"2 + 3 = {sum_result}")
quotient = calc.divide(10, 3)
print(f"10 / 3 = {quotient}")
print(f"History: {calc.history}")

## Assertions and Useful Assertion Patterns

pytest uses plain `assert` statements with rich introspection -- when an
assertion fails, pytest shows the actual values of both sides. No need for
`assertEqual`, `assertTrue`, etc. Here we demonstrate tests using the
AAA pattern with various assertion techniques.

In [None]:
# Test functions following AAA pattern (Arrange-Act-Assert)
# In a real project these would be in test_calculator.py

def test_add_integers() -> None:
    """Adding two integers returns their exact sum."""
    # Arrange: a fresh calculator with default precision
    calculator = Calculator()

    # Act
    total = calculator.add(2, 3)

    # Assert
    assert total == 5


def test_add_floats_precision() -> None:
    """A custom precision controls how many decimals a float sum keeps."""
    calculator = Calculator(precision=3)
    total = calculator.add(1.1111, 2.2222)
    assert total == 3.333  # 3.3333 rounded to 3 decimal places


def test_divide_returns_float() -> None:
    """Integer division inputs still produce a rounded float result."""
    quotient = Calculator().divide(10, 3)
    assert isinstance(quotient, float)
    assert quotient == 3.33  # default precision=2


def test_history_tracking() -> None:
    """Each operation appends a readable entry to the history log."""
    calculator = Calculator()
    calculator.add(1, 2)
    calculator.multiply(3, 4)

    entries = calculator.history
    assert len(entries) == 2
    assert "1 + 2 = 3" in entries
    assert "3 * 4 = 12" in entries


def test_clear_history() -> None:
    """clear_history leaves the log empty after operations were recorded."""
    calculator = Calculator()
    calculator.add(1, 2)

    calculator.clear_history()

    assert calculator.history == []


# Approximate equality for floating point (pytest.approx equivalent)
def test_float_comparison() -> None:
    """Show why floats must be compared with a tolerance, not with ==."""
    total = 0.1 + 0.2

    # Exact equality fails: 0.1 and 0.2 are not exactly representable in
    # binary floating point, so their sum is 0.30000000000000004.
    assert total != 0.3  # Surprising but true!

    # pytest idiom: assert total == pytest.approx(0.3)
    # Manual equivalent with an absolute tolerance:
    assert abs(total - 0.3) < 1e-9


# Run all tests and report, mimicking pytest's per-test PASSED/FAILED lines
# followed by a one-line summary.
tests = [
    test_add_integers,
    test_add_floats_precision,
    test_divide_returns_float,
    test_history_tracking,
    test_clear_history,
    test_float_comparison,
]

n_passed = 0
n_failed = 0
for test_func in tests:
    try:
        test_func()
    except AssertionError as e:
        n_failed += 1
        print(f"  FAILED: {test_func.__name__} - {e}")
    except Exception as e:
        # pytest also reports any other exception escaping a test as a failure;
        # catching it here keeps one broken test from aborting the whole run.
        n_failed += 1
        print(f"  FAILED: {test_func.__name__} - {type(e).__name__}: {e}")
    else:
        n_passed += 1
        print(f"  PASSED: {test_func.__name__}")

print(f"\n{n_passed} passed, {n_failed} failed")

## pytest Fixtures

Fixtures provide reusable setup/teardown logic. They replace the traditional
`setUp`/`tearDown` from unittest with a more flexible, composable model.

Key fixture concepts:
- **Scope**: `function` (default), `class`, `module`, `session`
- **yield fixtures**: Code after `yield` runs as teardown
- **conftest.py**: Shared fixtures available to all tests in its directory and all subdirectories
- **Fixture composition**: Fixtures can depend on other fixtures

Below we simulate fixture behavior by calling the fixture functions by hand, since pytest's fixture injection is not active when functions are called directly in a notebook.

In [None]:
# Simulating pytest fixtures as they would appear in test files
from typing import Generator
import tempfile
import os


# --- What fixtures look like in real pytest code ---

# @pytest.fixture
def calculator_fixture() -> Calculator:
    """Hand each caller a brand-new Calculator with two-decimal precision."""
    fresh = Calculator(precision=2)
    return fresh


# @pytest.fixture
def preloaded_calculator_fixture() -> Calculator:
    """Return a Calculator whose history already holds two operations."""
    seeded = Calculator()
    # Seed the log: 10 + 20, then 5 * 5 (same order as a test would see it)
    for operation, x, y in ((seeded.add, 10, 20), (seeded.multiply, 5, 5)):
        operation(x, y)
    return seeded


# @pytest.fixture  (yield fixture with teardown)
def temp_file_fixture() -> Generator[str, None, None]:
    """Provide a temporary file path; clean up after the test.

    Yields:
        Path to an empty ``.txt`` file created with ``tempfile.mkstemp``.

    The teardown sits in a ``finally`` clause. It therefore removes the file
    in two cases: when the generator is resumed normally (as pytest does after
    the test), and when the generator is closed early or has an exception
    thrown into it (e.g. a hand-driven "test" fails before resuming it).
    """
    fd, path = tempfile.mkstemp(suffix=".txt")
    os.close(fd)  # only the path is needed; the test reopens the file itself
    print(f"    [Setup] Created temp file: {path}")

    try:
        yield path  # This is where the test runs
    finally:
        # Teardown: runs on normal resumption, close(), or throw()
        if os.path.exists(path):
            os.unlink(path)
            print(f"    [Teardown] Removed temp file: {path}")


# --- Using fixtures in tests ---
# In real pytest, fixtures are injected as function parameters:
#   def test_add(calculator: Calculator) -> None:
#       assert calculator.add(1, 2) == 3


# Demonstrate fixture lifecycle
print("=== Fixture lifecycle demo ===")
print("\n--- Fresh calculator fixture ---")
calc = calculator_fixture()
print(f"  History: {calc.history} (empty, fresh instance)")
calc.add(5, 10)
print(f"  After add: {calc.history}")

print("\n--- Preloaded calculator fixture ---")
calc = preloaded_calculator_fixture()
print(f"  History: {calc.history} (pre-populated)")

print("\n--- Yield fixture with teardown ---")
# Drive the generator by hand the way pytest does: the first next() runs
# setup and returns the yielded value; resuming it again runs the teardown.
gen = temp_file_fixture()
path = next(gen)  # Setup runs, we get the path
try:
    print(f"    [Test] Writing to {path}")
    with open(path, "w", encoding="utf-8") as f:
        f.write("test data")
finally:
    # Resume the generator even if the "test" body raised, so teardown always
    # runs; the default argument absorbs the StopIteration from exhaustion.
    next(gen, None)

print(f"    File exists after teardown: {os.path.exists(path)}")


# --- Fixture scope illustration ---
print("\n=== Fixture scopes ===")
# Ordered from narrowest to widest lifetime.
scopes = dict(
    [
        ("function", "New instance per test function (default, most isolated)"),
        ("class", "Shared across all methods in a test class"),
        ("module", "Shared across all tests in a module (file)"),
        ("session", "Shared across the entire test session (all files)"),
    ]
)
for name, summary in scopes.items():
    print(f"  @pytest.fixture(scope='{name}')\n    -> {summary}")

## Exception Testing with pytest.raises

Testing that code raises the correct exception is just as important as testing
the happy path. `pytest.raises` is a context manager that captures and
inspects exceptions. Here we demonstrate the pattern manually.

In [None]:
import re


# In real pytest code:
#   def test_divide_by_zero():
#       calc = Calculator()
#       with pytest.raises(DivisionByZeroError, match="Cannot divide"):
#           calc.divide(10, 0)


def assert_raises(
    exc_type: type[Exception],
    callable_obj: object,
    *args: object,
    match: str | None = None,
) -> Exception:
    """Simplified version of pytest.raises for demonstration."""
    try:
        callable_obj(*args)  # type: ignore[operator]
    except exc_type as e:
        if match and not re.search(match, str(e)):
            raise AssertionError(
                f"Exception message '{e}' did not match pattern '{match}'"
            ) from e
        return e
    except Exception as e:
        raise AssertionError(
            f"Expected {exc_type.__name__}, got {type(e).__name__}: {e}"
        ) from e
    else:
        raise AssertionError(f"Expected {exc_type.__name__} but no exception was raised")


# Test exception type
calc = Calculator()
caught = assert_raises(DivisionByZeroError, calc.divide, 10, 0)
print(f"PASSED: Caught {type(caught).__name__}: {caught}")

# Test exception message with regex match
caught = assert_raises(DivisionByZeroError, calc.divide, 42, 0, match=r"Cannot divide 42")
print(f"PASSED: Message matched pattern: '{caught}'")

# Test exception inheritance (DivisionByZeroError is a CalculatorError)
caught = assert_raises(CalculatorError, calc.divide, 1, 0)
print(f"PASSED: Caught via parent class: {type(caught).__name__}")

# Test that NO exception is raised when it should be
try:
    assert_raises(DivisionByZeroError, calc.divide, 10, 2)
except AssertionError as problem:
    print(f"PASSED: Correctly detected missing exception: {problem}")
else:
    print("FAILED: Should have reported missing exception")

# Test wrong exception type
try:
    assert_raises(TypeError, calc.divide, 10, 0)
except AssertionError as problem:
    print(f"PASSED: Correctly detected wrong exception type: {problem}")
else:
    print("FAILED: Should have reported wrong exception type")

## pytest.mark: Parametrize, Skip, and Expected Failures

Markers add metadata to tests:
- `@pytest.mark.parametrize`: Run the same test with different inputs
- `@pytest.mark.skip` / `@pytest.mark.skipif`: Skip tests conditionally
- `@pytest.mark.xfail`: Mark a test as expected to fail

Parametrization is especially powerful -- it generates a separate test case
for each set of inputs, so failures pinpoint exactly which case broke.

In [None]:
import sys
import platform


# --- @pytest.mark.parametrize ---
# In real pytest:
#   @pytest.mark.parametrize("a, b, expected", [
#       (2, 3, 5),
#       (-1, 1, 0),
#       (0.1, 0.2, 0.3),
#   ])
#   def test_add_parametrized(calculator, a, b, expected):
#       assert calculator.add(a, b) == pytest.approx(expected)


def test_add_parametrized() -> None:
    """Demonstrate parametrized testing.

    Mirrors ``@pytest.mark.parametrize`` combined with a function-scoped
    fixture: every case gets its own fresh ``Calculator``, and each case's
    outcome is reported on its own line. The test then fails at the end if
    any case failed, so it can never pass silently.

    Raises:
        AssertionError: if one or more cases produce an unexpected sum.
    """
    test_cases: list[tuple[float, float, float]] = [
        (2, 3, 5),
        (-1, 1, 0),
        (0, 0, 0),
        (100, -50, 50),
        (0.1, 0.2, 0.3),
        (-2.5, -3.5, -6.0),
    ]

    failures: list[str] = []
    for a, b, expected in test_cases:
        calc = Calculator()  # fresh instance per case, like a function-scoped fixture
        result = calc.add(a, b)
        # Manual stand-in for pytest.approx(expected), using an absolute tolerance
        passed = abs(result - expected) < 1e-9
        status = "PASSED" if passed else "FAILED"
        print(f"  {status}: add({a}, {b}) == {expected} (got {result})")
        if not passed:
            failures.append(f"add({a}, {b}) == {expected} (got {result})")

    assert not failures, f"{len(failures)} case(s) failed: {failures}"


# --- @pytest.mark.skip and @pytest.mark.skipif ---
def test_skip_examples() -> None:
    """Print the simulated outcome of skip / skipif markers."""
    # @pytest.mark.skip(reason="Not yet implemented") -- unconditional
    print("  SKIPPED: test_new_feature - reason: Not yet implemented")

    # @pytest.mark.skipif(sys.platform == "win32", reason="Unix only")
    on_windows = sys.platform == "win32"
    print(
        "  SKIPPED: test_unix_permissions - reason: Unix only"
        if on_windows
        else f"  RAN: test_unix_permissions (platform={sys.platform})"
    )

    # @pytest.mark.skipif(sys.version_info < (3, 12), reason="Requires 3.12+")
    major_minor = sys.version_info[:2]
    too_old = sys.version_info < (3, 12)
    if not too_old:
        print(f"  RAN: test_312_feature (version={'.'.join(map(str, major_minor))})")
    else:
        print(f"  SKIPPED: test_312_feature - reason: Requires 3.12+ (have {major_minor})")


# --- @pytest.mark.xfail ---
def test_xfail_example() -> None:
    """Simulate an xfail-marked test: a failing assertion is the expected outcome."""
    # @pytest.mark.xfail(reason="Known floating point precision issue")
    try:
        assert 0.1 + 0.2 == 0.3  # This will fail
    except AssertionError:
        print("  XFAIL: Float comparison failed as expected (known issue)")
    else:
        print("  XPASS: Float comparison unexpectedly passed (strict=True would fail)")


# Run each marker demo under its own heading, in order.
for heading, demo in (
    ("=== Parametrized tests ===", test_add_parametrized),
    ("\n=== Skip/skipif examples ===", test_skip_examples),
    ("\n=== Expected failures (xfail) ===", test_xfail_example),
):
    print(heading)
    demo()

# --- What these look like in real pytest files ---
# Text only: the string below is printed, never executed. Note that
# `pytest.approx(0.3)` is passed directly as a parametrize value, and that
# `strict=True` on xfail turns an unexpected pass (XPASS) into a failure.
print("\n=== Real pytest syntax (for reference) ===")
real_pytest_example = '''
import pytest

@pytest.mark.parametrize("a, b, expected", [
    (2, 3, 5),
    (-1, 1, 0),
    (0.1, 0.2, pytest.approx(0.3)),
])
def test_add(calculator: Calculator, a: float, b: float, expected: float) -> None:
    assert calculator.add(a, b) == expected

@pytest.mark.skip(reason="Not yet implemented")
def test_future_feature() -> None:
    ...

@pytest.mark.xfail(reason="Known bug #1234", strict=True)
def test_known_bug() -> None:
    ...
'''
print(real_pytest_example)

## Mocking with unittest.mock

Mocking replaces real objects with controllable stand-ins during testing.
This is essential for isolating units from external dependencies
(databases, APIs, file systems, time).

Key tools:
- `MagicMock`: A flexible mock object that accepts any attribute access or call (unless restricted with `spec=`, which limits it to the attributes of the given class)
- `patch`: Temporarily replace an object in a specific namespace
- `side_effect`: Make a mock raise exceptions or return different values per call

In [None]:
from unittest.mock import MagicMock, patch, call
from dataclasses import dataclass
from typing import Protocol


# --- Production code: a service that depends on external systems ---

class DatabaseError(Exception):
    """Signals that the backing store failed to complete an operation."""


class UserRepository(Protocol):
    """Structural interface: any object with these two methods qualifies."""

    def get_user(self, user_id: int) -> dict[str, object]:
        """Interface method: look up the record for ``user_id``."""
        ...

    def save_user(self, user: dict[str, object]) -> bool:
        """Interface method: store ``user``, returning a bool."""
        ...


@dataclass
class UserService:
    """Service that uses a repository (to be mocked in tests).

    Attributes:
        repository: Object providing ``get_user`` and ``save_user``.
    """

    repository: "UserRepository"

    def get_user_display_name(self, user_id: int) -> str:
        """Return ``"<first> <last>"`` for a user, trimmed of outer spaces.

        A missing ``first_name`` falls back to ``"Unknown"``; a missing
        ``last_name`` is omitted.
        """
        user = self.repository.get_user(user_id)
        first = user.get("first_name", "Unknown")
        last = user.get("last_name", "")
        return f"{first} {last}".strip()

    def update_user_email(self, user_id: int, new_email: str) -> bool:
        """Save a copy of the user's record with ``email`` replaced.

        The record returned by the repository is copied before it is
        changed. The repository's own object (which it may cache or share)
        is therefore not mutated, even if ``save_user`` later fails.

        Returns:
            Whatever ``repository.save_user`` returns.
        """
        updated = {**self.repository.get_user(user_id), "email": new_email}
        return self.repository.save_user(updated)


# --- MagicMock basics ---
print("=== MagicMock basics ===")
# spec=UserRepository restricts the mock to attributes the protocol defines
mock_repo = MagicMock(spec=UserRepository)
alice_record = {
    "first_name": "Alice",
    "last_name": "Smith",
    "email": "alice@example.com",
}
mock_repo.get_user.return_value = alice_record

service = UserService(repository=mock_repo)
display_name = service.get_user_display_name(user_id=42)
print(f"  Display name: {display_name}")
assert display_name == "Alice Smith"

# Interrogate the mock: exactly one call, with positional argument 42
get_user_mock = mock_repo.get_user
get_user_mock.assert_called_once_with(42)
print(f"  get_user called with: {get_user_mock.call_args}")
print(f"  get_user call count: {get_user_mock.call_count}")


# --- side_effect: simulate exceptions ---
print("\n=== side_effect for exceptions ===")
mock_repo_failing = MagicMock(spec=UserRepository)
# An exception instance as side_effect is raised whenever the mock is called
mock_repo_failing.get_user.side_effect = DatabaseError("Connection timeout")

service_failing = UserService(repository=mock_repo_failing)
try:
    service_failing.get_user_display_name(1)
except DatabaseError as e:
    print(f"  Caught expected error: {e}")
else:
    # Without this branch a mis-configured side_effect would pass silently.
    raise AssertionError("Expected DatabaseError from get_user side_effect")


# --- side_effect: different return values per call ---
print("\n=== side_effect for sequential returns ===")
mock_repo_sequence = MagicMock(spec=UserRepository)
# An iterable side_effect supplies one item per call; exception instances in
# it are raised instead of returned.
mock_repo_sequence.get_user.side_effect = [
    {"first_name": "Alice", "last_name": "Smith"},
    {"first_name": "Bob"},  # No last name
    DatabaseError("Server down"),  # Third call raises
]

service_seq = UserService(repository=mock_repo_sequence)
print(f"  Call 1: {service_seq.get_user_display_name(1)}")
print(f"  Call 2: {service_seq.get_user_display_name(2)}")
try:
    service_seq.get_user_display_name(3)
except DatabaseError as e:
    print(f"  Call 3: Raised {e}")
else:
    raise AssertionError("Expected the third get_user call to raise DatabaseError")

# Verify all calls. assert_has_calls only checks that this sequence appears
# somewhere in the recorded calls, so also pin the total to rule out extras.
expected_calls = [call(1), call(2), call(3)]
mock_repo_sequence.get_user.assert_has_calls(expected_calls)
assert mock_repo_sequence.get_user.call_count == len(expected_calls)
print("  All expected calls verified")

## Mocking with patch: Replacing Module-Level Objects

`patch` temporarily replaces an object at its lookup location. This is
critical for testing code that calls `time.time()`, `os.path.exists()`,
or other module-level functions. The key rule: **patch where the object
is looked up, not where it is defined**.

In [None]:
from unittest.mock import patch, MagicMock
import time
import os


# --- Production code that depends on time and filesystem ---

def get_uptime_message() -> str:
    """Report the current epoch time, rounded to whole seconds.

    ``time.time`` is looked up on every call, so patching ``"time.time"``
    changes what this function sees.
    """
    return f"System checked at {time.time():.0f}"


def read_config(path: str) -> dict[str, str]:
    """Return config for ``path``, or built-in defaults if it does not exist."""
    if not os.path.exists(path):
        return {"source": "defaults", "debug": "false"}
    # In real code, we'd read and parse the file
    return {"source": "file", "path": path}


# --- patch as context manager ---
print("=== patch as context manager ===")
frozen_now = 1700000000.0
with patch("time.time", return_value=frozen_now):
    # get_uptime_message looks up time.time at call time, so it sees the mock
    message = get_uptime_message()
    print(f"  {message}")
    assert "1700000000" in message

# Leaving the with-block restores the real time.time
print(f"  Real time after patch: {time.time():.0f}")


# --- patch for filesystem operations ---
print("\n=== Patching os.path.exists ===")

# Simulate the file being present, then absent, and check which branch runs.
for exists, label, expected_source in (
    (True, "exists", "file"),
    (False, "missing", "defaults"),
):
    with patch("os.path.exists", return_value=exists):
        config = read_config("/etc/myapp/config.yaml")
        print(f"  File '{label}': {config}")
        assert config["source"] == expected_source


# --- patch as decorator (shown as reference) ---
# NOTE: this applies if myapp/services.py does `import time`. In that case the
# target "myapp.services.time.time" resolves to the `time` attribute of the
# shared `time` module, so it patches exactly what "time.time" would (globally,
# for the duration of the test). Patching "in the using module" only makes a
# difference when that module bound the name directly, e.g. with
# `from time import time`, in which case the target is "myapp.services.time".
print("\n=== patch decorator syntax (for test files) ===")
reference = '''
# Patch where the object is USED, not where it's defined:

@patch("myapp.services.time.time", return_value=1700000000.0)
def test_uptime_message(mock_time: MagicMock) -> None:
    message = get_uptime_message()
    assert "1700000000" in message
    mock_time.assert_called_once()

# Multiple patches (applied bottom-up, passed left-to-right):
@patch("myapp.services.os.path.exists", return_value=True)
@patch("myapp.services.open", create=True)
def test_read_config(mock_open: MagicMock, mock_exists: MagicMock) -> None:
    config = read_config("/etc/config.yaml")
    mock_exists.assert_called_once_with("/etc/config.yaml")
'''
print(reference)

## Testing Patterns: Complete Test Suite Example

A well-organized test suite follows consistent patterns. Below is what a
complete `conftest.py` and test file would look like for our Calculator,
demonstrating fixture sharing, parametrization, and clean AAA structure.

In [None]:
# === conftest.py (shared fixtures) ===
# Text only: this string is printed further down, never executed. The
# `calculator_with_history` fixture requests `calculator` by parameter name,
# which illustrates fixture composition.
conftest_example = '''
# conftest.py - pytest discovers this automatically
import pytest
from myapp.calculator import Calculator

@pytest.fixture
def calculator() -> Calculator:
    """Provide a fresh Calculator for each test."""
    return Calculator(precision=2)

@pytest.fixture
def calculator_with_history(calculator: Calculator) -> Calculator:
    """Calculator with pre-populated history."""
    calculator.add(10, 20)
    calculator.multiply(3, 7)
    return calculator
'''

# === test_calculator.py ===
# Text only: this string is printed further down, never executed. Each
# parametrize case becomes its own test ID. The `calculator` and
# `calculator_with_history` arguments are supplied by the conftest.py fixtures.
test_file_example = '''
import pytest
from myapp.calculator import Calculator, DivisionByZeroError


class TestCalculatorArithmetic:
    """Tests for basic arithmetic operations."""

    @pytest.mark.parametrize("a, b, expected", [
        (2, 3, 5),
        (-1, 1, 0),
        (0, 0, 0),
        (1.5, 2.5, 4.0),
    ])
    def test_add(self, calculator: Calculator, a: float, b: float, expected: float) -> None:
        assert calculator.add(a, b) == pytest.approx(expected)

    @pytest.mark.parametrize("a, b, expected", [
        (10, 3, 3.33),
        (1, 3, 0.33),
        (-6, 2, -3.0),
    ])
    def test_divide(self, calculator: Calculator, a: float, b: float, expected: float) -> None:
        assert calculator.divide(a, b) == pytest.approx(expected)

    def test_divide_by_zero(self, calculator: Calculator) -> None:
        with pytest.raises(DivisionByZeroError, match="Cannot divide"):
            calculator.divide(10, 0)


class TestCalculatorHistory:
    """Tests for history tracking."""

    def test_history_starts_empty(self, calculator: Calculator) -> None:
        assert calculator.history == []

    def test_operations_are_recorded(
        self, calculator_with_history: Calculator
    ) -> None:
        assert len(calculator_with_history.history) == 2

    def test_clear_history(
        self, calculator_with_history: Calculator
    ) -> None:
        calculator_with_history.clear_history()
        assert calculator_with_history.history == []
'''

# Show each example file under its own banner.
for banner, listing in (
    ("=== conftest.py ===", conftest_example),
    ("=== test_calculator.py ===", test_file_example),
):
    print(banner)
    print(listing)

# Illustrative, hand-written sample output (pytest is not actually run here).
# 4 add cases + 3 divide cases + 1 divide-by-zero test + 3 history tests = 11.
print("=== Running tests would produce output like: ===")
print("""
$ pytest -v test_calculator.py

test_calculator.py::TestCalculatorArithmetic::test_add[2-3-5]          PASSED
test_calculator.py::TestCalculatorArithmetic::test_add[-1-1-0]         PASSED
test_calculator.py::TestCalculatorArithmetic::test_add[0-0-0]          PASSED
test_calculator.py::TestCalculatorArithmetic::test_add[1.5-2.5-4.0]   PASSED
test_calculator.py::TestCalculatorArithmetic::test_divide[10-3-3.33]   PASSED
test_calculator.py::TestCalculatorArithmetic::test_divide[1-3-0.33]    PASSED
test_calculator.py::TestCalculatorArithmetic::test_divide[-6-2--3.0]   PASSED
test_calculator.py::TestCalculatorArithmetic::test_divide_by_zero      PASSED
test_calculator.py::TestCalculatorHistory::test_history_starts_empty   PASSED
test_calculator.py::TestCalculatorHistory::test_operations_are_recorded PASSED
test_calculator.py::TestCalculatorHistory::test_clear_history          PASSED

========================= 11 passed in 0.03s =========================
""")