# Week 4 Assignment: Unit Test Code Generator

This tool generates comprehensive unit tests for Python code covering:
- **Positive cases**: Normal, expected behavior
- **Negative cases**: Error handling and invalid inputs
- **Edge cases**: Boundary conditions and special values
- **Stress testing**: Performance with large inputs

It uses frontier models (or local models served by Ollama) to generate high-quality pytest test cases that can be run directly.

In [None]:
# imports

import os
import io
import sys
import re
import contextlib
import unittest
import tempfile
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
from IPython.display import Markdown, display

# pytest is required to run generated tests: report where it was found, or how to install it.
try:
    import pytest
except ImportError:
    for _setup_hint in (
        "WARNING: pytest not installed. Generated tests will fail to run.",
        "From the project root (llm_engineering) run: uv sync",
        "Then restart the kernel and ensure this notebook uses that Python (e.g. select the '.venv' kernel).",
    ):
        print(_setup_hint)
else:
    print(f"pytest found: {pytest.__file__}")

In [None]:
load_dotenv(override=True)

# Read each provider key from the environment (values from .env override existing ones).
openai_api_key = os.getenv('OPENAI_API_KEY')
anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')
google_api_key = os.getenv('GOOGLE_API_KEY')
grok_api_key = os.getenv('GROK_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
openrouter_api_key = os.getenv('OPENROUTER_API_KEY')

# Report which keys are present, revealing only a short prefix of each one.
# Entry fields: (provider label, key value, prefix length to show, note printed when missing).
_OPTIONAL_NOTE = " (and this is optional)"
for _label, _key, _shown, _missing_note in (
    ("OpenAI", openai_api_key, 8, ""),
    ("Anthropic", anthropic_api_key, 7, _OPTIONAL_NOTE),
    ("Google", google_api_key, 2, _OPTIONAL_NOTE),
    ("Grok", grok_api_key, 4, _OPTIONAL_NOTE),
    ("Groq", groq_api_key, 4, _OPTIONAL_NOTE),
    ("OpenRouter", openrouter_api_key, 6, _OPTIONAL_NOTE),
):
    if _key:
        print(f"{_label} API Key exists and begins {_key[:_shown]}")
    else:
        print(f"{_label} API Key not set{_missing_note}")

In [3]:
# Connect to client libraries

# Placeholder API key for providers whose key is not set. The OpenAI client treats
# api_key=None as "read OPENAI_API_KEY from the environment". Passing None would
# therefore send the OpenAI key to Anthropic/Google/xAI/Groq/OpenRouter, or raise at
# construction time when OPENAI_API_KEY is unset too. With a placeholder, every client
# can be built, and a request to an unconfigured provider fails with an authentication
# error instead.
_MISSING_KEY = "missing-api-key"

openai = OpenAI(api_key=openai_api_key or _MISSING_KEY)

anthropic_url = "https://api.anthropic.com/v1/"
gemini_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
grok_url = "https://api.x.ai/v1"
groq_url = "https://api.groq.com/openai/v1"
ollama_url = "http://localhost:11434/v1"
openrouter_url = "https://openrouter.ai/api/v1"

anthropic = OpenAI(api_key=anthropic_api_key or _MISSING_KEY, base_url=anthropic_url)
gemini = OpenAI(api_key=google_api_key or _MISSING_KEY, base_url=gemini_url)
grok = OpenAI(api_key=grok_api_key or _MISSING_KEY, base_url=grok_url)
groq = OpenAI(api_key=groq_api_key or _MISSING_KEY, base_url=groq_url)
# A local Ollama server ignores the key, but the client requires a non-empty value.
ollama = OpenAI(api_key="ollama", base_url=ollama_url)
openrouter = OpenAI(api_key=openrouter_api_key or _MISSING_KEY, base_url=openrouter_url)

In [4]:
# Model name -> OpenAI-compatible client that serves it. `models` (the dropdown choices)
# is derived from this dict's insertion order, so the two can never drift apart.
clients = {
    "gpt-5": openai,
    "claude-sonnet-4-5-20250929": anthropic,
    "grok-4": grok,
    "gemini-2.5-pro": gemini,
    "qwen2.5-coder": ollama,
    "deepseek-coder-v2": ollama,
    "gpt-oss:20b": ollama,
    "qwen/qwen3-coder-30b-a3b-instruct": openrouter,
    "openai/gpt-oss-120b": groq,
}
models = list(clients)

# Want to keep costs ultra-low? Uncomment these lines:
# models = ["gpt-5-nano", "claude-haiku-4-5", "gemini-2.5-flash-lite"]
# clients = {"gpt-5-nano": openai, "claude-haiku-4-5": anthropic, "gemini-2.5-flash-lite": gemini}

## Unit Test Generation Prompts

These prompts instruct the model to generate comprehensive tests covering all test categories.

In [5]:
# System instructions sent as the first chat message (see messages_for). They require:
# all four test categories; pytest idioms (raises/parametrize/fixtures, unittest.mock);
# the original code echoed first; and plain Python output with no markdown fences,
# grouped under fixed section-header comments.
SYSTEM_PROMPT = """You are an expert Python testing engineer specializing in comprehensive test coverage.
Your task is to generate extensive, runnable unit tests for the given Python code.

MANDATORY REQUIREMENTS:

1. **Test Coverage Categories** - You MUST include ALL four types:
   a) POSITIVE CASES: Normal, expected behavior with valid inputs
   b) NEGATIVE CASES: Error handling, exceptions, invalid inputs
   c) EDGE CASES: Boundary conditions, empty inputs, None values, special characters
   d) STRESS TESTING: Large inputs, performance under load, memory efficiency

2. **Technical Requirements**:
   - Use pytest framework with assert statements
   - Use pytest.raises() for exception testing
   - Use pytest.mark.parametrize for multiple test cases
   - Add fixtures for complex setup if needed
   - Mock external dependencies (requests, databases, file I/O) using unittest.mock

3. **Code Structure**:
   - Include the original code first (copy it exactly)
   - Then add all test code below
   - Group tests by category with clear comments
   - Each test function should have a descriptive docstring

4. **Output Format**:
   - Respond ONLY with Python code
   - NO markdown code fences
   - NO explanations outside of Python comments
   - All code must be executable as a single block

5. **Test Organization**:
   - Start with: # ============ POSITIVE TEST CASES ============
   - Then: # ============ NEGATIVE TEST CASES ============
   - Then: # ============ EDGE CASES ============
   - Finally: # ============ STRESS TESTS ============
"""

def user_prompt_for(python_code: str) -> str:
    """Build the user message that asks for tests of *python_code*.

    The message restates the four required test categories, asks for the original code
    to be echoed first, and embeds *python_code* in a fenced ``python`` block.
    """
    instructions = "\n".join([
        "Generate comprehensive pytest unit tests for this Python code.",
        "",
        "You MUST include tests for ALL four categories:",
        "1. Positive cases (normal expected behavior)",
        "2. Negative cases (error handling and exceptions)",
        "3. Edge cases (boundary conditions, empty values, None)",
        "4. Stress tests (large inputs, performance)",
        "",
        "Include the original code in your response first, then all test code.",
        "Respond only with Python code that can be executed directly.",
        "",
        "Python code to test:",
        "",
        "```python",
    ])
    return f"{instructions}\n{python_code}\n```\n"

In [6]:
def messages_for(python_code: str):
    """Return the chat messages for *python_code*: the system prompt, then the user request."""
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_prompt_for(python_code)}
    return [system_message, user_message]

In [7]:
def extract_code(reply: str) -> str:
    """Return the Python code contained in a model reply, without markdown fences.

    If the reply contains complete fenced blocks, only their contents are returned,
    joined by a blank line; surrounding prose is dropped. Blocks tagged ``python``,
    ``py`` or ``python3``, or left untagged, are preferred over blocks with other tags
    (e.g. ``bash``). If there is no complete fenced block, lone fence lines are removed
    (e.g. the opening fence of a truncated reply) and the rest is returned unchanged.

    Args:
        reply: Raw text returned by the model.

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    reply = reply.strip()
    blocks = re.findall(
        r"^[ \t]*```[ \t]*([\w+.-]*)[ \t]*\n(.*?)^[ \t]*```[ \t]*$",
        reply,
        flags=re.MULTILINE | re.DOTALL,
    )
    if blocks:
        python_tags = {"", "py", "python", "python3"}
        python_bodies = [body for tag, body in blocks if tag.lower() in python_tags]
        chosen = python_bodies or [body for _, body in blocks]
        # strip("\n") only, so the first code line keeps its indentation inside the join.
        return "\n\n".join(body.strip("\n") for body in chosen).strip()
    # Remove whole fence lines only. Never consume the following newline's indentation,
    # which would break indented code.
    return re.sub(r"^[ \t]*```[\w+.-]*[ \t]*$\n?", "", reply, flags=re.MULTILINE).strip()

In [8]:
def generate_unit_tests(model: str, python_code: str) -> str:
    """Ask *model* to write pytest tests for *python_code* and return the generated code.

    Args:
        model: A key of ``clients`` selecting both the model name and its provider client.
        python_code: Source code to generate tests for.

    Returns:
        The model's reply with markdown fences removed. On failure it returns a comment line
        starting with ``#`` instead: empty input, unknown model, or any exception raised
        while calling the model or reading its reply.
    """
    if not python_code or not python_code.strip():
        return "# Please paste some Python code first."

    client = clients.get(model)
    if client is None:
        return f"# Error: Unknown model: {model}"

    try:
        kwargs = {"model": model, "messages": messages_for(python_code)}

        # Only send reasoning_effort to GPT models served by OpenAI's own API. The provider
        # is identified by the client's base URL. The previous check looked for
        # "groq"/"openrouter" in the model name, but names such as "openai/gpt-oss-120b"
        # (Groq) and "gpt-oss:20b" (Ollama) never contain those words, so it always passed.
        base_url = str(getattr(client, "base_url", ""))
        if "gpt" in model and "api.openai.com" in base_url:
            kwargs["reasoning_effort"] = "high"

        response = client.chat.completions.create(**kwargs)
        reply = response.choices[0].message.content or ""
        return extract_code(reply)
    except Exception as e:
        # UI boundary: report any provider/network error as a comment instead of raising.
        return f"# Error calling model: {e}"

## Running the Generated Tests

Functions to execute the generated test code using pytest.

In [9]:
def _strip_ansi(text: str) -> str:
    """Remove ANSI escape codes for clean output in Gradio."""
    return re.sub(r"\x1b\[[0-9;]*m", "", text)

def _clean_pytest_output(text: str, temp_path: str = "") -> str:
    """Strip ANSI codes from runner output and show the temp test file as ``generated_tests.py``.

    Args:
        text: Raw captured runner output.
        temp_path: Path of the temporary file the tests were written to. When empty, only
            the ANSI stripping is done.

    Returns:
        The cleaned text. Every path token ending in the temp file's basename is replaced,
        including:
        - the exact ``temp_path``;
        - ``../``-style relative paths;
        - Windows ``..\\``-style paths;
        - absolute paths spelled differently from ``temp_path`` (e.g. a symlink-resolved
          form);
        - the bare file name.
    """
    text = _strip_ansi(text)
    if not temp_path:
        return text
    display_name = "generated_tests.py"
    # Exact path first: it may contain spaces, which the token pattern below cannot span.
    text = text.replace(temp_path, display_name)
    # Any remaining token ending in the basename. The prefix stops at whitespace, quotes
    # and brackets, so trailers such as "::test_x" or ":12:" are left intact.
    base = os.path.basename(temp_path)
    token_pattern = r"[^\s'\"()\[\]<>]*" + re.escape(base)
    return re.sub(token_pattern, display_name, text)

def _output_to_html(plain_text: str) -> str:
    """Convert pytest output to HTML with green for PASSED, red for FAILED."""
    import html
    lines = plain_text.splitlines()
    out = []
    for line in lines:
        escaped = html.escape(line)
        # Highlight PASSED in green
        if " PASSED" in line or line.strip().endswith("PASSED"):
            escaped = escaped.replace("PASSED", '<span style="color:#0a0;font-weight:600">PASSED</span>')
        # Highlight FAILED in red
        if " FAILED" in line or line.strip().endswith("FAILED"):
            escaped = escaped.replace("FAILED", '<span style="color:#c00;font-weight:600">FAILED</span>')
        # Summary lines
        if "passed" in line.lower() and "=" in line:
            escaped = f'<div style="margin-top:0.75em;font-weight:600;color:#0a0;">{escaped}</div>'
        elif "failed" in line.lower() and "=" in line:
            escaped = f'<div style="margin-top:0.75em;font-weight:600;color:#c00;">{escaped}</div>'
        out.append(escaped)
    return "<pre style='margin:0;font-family:monospace;font-size:0.9em;line-height:1.4;'>" + "\n".join(out) + "</pre>"

def _is_pytest_style_test(name: str, obj) -> bool:
    """Return True if pytest would collect *obj* as a test but unittest's loader would not."""
    if isinstance(obj, type):
        # A plain ``class TestX:`` (not a TestCase subclass) is only collected by pytest.
        return name.startswith("Test") and not issubclass(obj, unittest.TestCase)
    # Module-level ``test*`` functions are only collected by pytest.
    return name.startswith("test") and callable(obj)

def run_unit_tests(code: str) -> str:
    """Execute generated code (source under test + its tests) and return the runner output.

    The code is first exec'd in a fresh namespace to define its objects and surface
    syntax/import errors early. What happens next depends on the tests it defines:
    - Purely unittest-style code (only ``unittest.TestCase`` subclasses) is run with
      ``unittest.TextTestRunner``.
    - Anything with pytest-style tests is written to a temporary UTF-8 file and run with
      ``pytest.main``. pytest also collects TestCase classes, so mixed code loses nothing.
      On this path the module-level code runs twice: once via ``exec`` and once when
      pytest imports the file.

    Args:
        code: Python source containing the code under test and its tests.

    Returns:
        One of:
        - "No code to run." for empty input;
        - setup instructions when pytest is missing;
        - otherwise the captured stdout/stderr, with ANSI codes stripped and the temporary
          file path shown as ``generated_tests.py``.
        Any exception (or ``SystemExit``) raised by the code is reported in the text
        instead of propagating.
    """
    if not code or not code.strip():
        return "No code to run."

    buffer = io.StringIO()
    temp_path = ""

    try:
        with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
            # Ensure pytest is available
            try:
                import pytest
            except ImportError:
                buffer.write(
                    "pytest is not installed in this Python environment.\n\n"
                    "Fix: From the project root folder (llm_engineering) run:\n"
                    "  uv sync\n\n"
                    "Then restart the kernel and select the kernel that uses this project's "
                    "environment (click kernel name top-right → Select Another Kernel → "
                    "choose the one showing '.venv' or 'llm_engineering').\n"
                )
                return buffer.getvalue()

            # Execute the code to define functions and tests. pytest was imported above,
            # so a ModuleNotFoundError here concerns some other module; it is reported by
            # the generic handler below.
            ns = {}
            exec(code, ns)

            test_cases = [
                obj for obj in ns.values()
                if isinstance(obj, type) and issubclass(obj, unittest.TestCase)
            ]
            has_pytest_tests = any(
                _is_pytest_style_test(name, obj) for name, obj in ns.items()
            )

            if test_cases and not has_pytest_tests:
                # Purely unittest-style code: run it with unittest's own runner.
                suite = unittest.TestSuite()
                for case in test_cases:
                    suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
                runner = unittest.TextTestRunner(stream=buffer, verbosity=2)
                runner.run(suite)
            else:
                # pytest-style (or mixed) tests: pytest needs a real file to import.
                tmp = tempfile.NamedTemporaryFile(
                    mode="w", suffix=".py", delete=False, encoding="utf-8"
                )
                temp_path = tmp.name
                try:
                    # Python reads source files as UTF-8. The platform default encoding
                    # (e.g. cp1252 on Windows) can fail on non-ASCII test strings.
                    with tmp:
                        tmp.write(code)
                    pytest.main([temp_path, "-v", "--tb=short", "-x"])
                finally:
                    # Also runs when the write fails, so the delete=False file never leaks.
                    os.unlink(temp_path)
    except SystemExit as e:
        # e.g. an unguarded unittest.main() or sys.exit() in the generated code.
        buffer.write(f"Error executing tests: code called sys.exit({e.code!r})")
    except Exception as e:
        buffer.write(f"Error executing tests: {e}")

    return _clean_pytest_output(buffer.getvalue(), temp_path)

def run_unit_tests_html(code: str) -> str:
    """Run the tests in *code* and render the captured output as colour-highlighted HTML."""
    return _output_to_html(run_unit_tests(code))

## Example Code to Test

Here are some sample functions you can use to test the unit test generator.

In [None]:
# Sample inputs for the generator, offered as "Load ... Example" buttons in the Gradio UI.
# Each is Python source held in a string.
# Simple: two arithmetic functions; divide raises ValueError on a zero divisor.
example_code_simple = '''def add(a, b):
    """Add two numbers and return the result."""
    return a + b

def divide(a, b):
    """Divide a by b. Raises ValueError if b is zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero")
    return a / b
'''

# Medium: functions with type checks (TypeError) and value checks (ValueError).
example_code_medium = '''def factorial(n):
    """Calculate factorial of n. Raises ValueError for negative numbers."""
    if not isinstance(n, int):
        raise TypeError("Input must be an integer")
    if n < 0:
        raise ValueError("Factorial not defined for negative numbers")
    if n == 0 or n == 1:
        return 1
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result

def is_palindrome(text):
    """Check if a string is a palindrome (case-insensitive)."""
    if not isinstance(text, str):
        raise TypeError("Input must be a string")
    cleaned = ''.join(c.lower() for c in text if c.isalnum())
    return cleaned == cleaned[::-1]
'''

# Complex: a stateful class plus an algorithmic function (returns 0 for an empty list).
example_code_complex = '''class BankAccount:
    """A simple bank account with deposit, withdraw, and balance operations."""
    
    def __init__(self, initial_balance=0):
        if initial_balance < 0:
            raise ValueError("Initial balance cannot be negative")
        self._balance = initial_balance
    
    def deposit(self, amount):
        """Deposit money into the account."""
        if amount <= 0:
            raise ValueError("Deposit amount must be positive")
        self._balance += amount
        return self._balance
    
    def withdraw(self, amount):
        """Withdraw money from the account."""
        if amount <= 0:
            raise ValueError("Withdrawal amount must be positive")
        if amount > self._balance:
            raise ValueError("Insufficient funds")
        self._balance -= amount
        return self._balance
    
    def get_balance(self):
        """Get current account balance."""
        return self._balance

def find_max_subarray_sum(arr):
    """Find maximum sum of contiguous subarray (Kadane's algorithm)."""
    if not arr:
        return 0
    max_sum = current_sum = arr[0]
    for num in arr[1:]:
        current_sum = max(num, current_sum + num)
        max_sum = max(max_sum, current_sum)
    return max_sum
'''

print("Example codes loaded. Use these in the Gradio interface below.")

## Test the Generator (Without Gradio)

You can test the generator directly here before launching the Gradio interface.

In [None]:
# Test with a simple example
model = models[0]  # First model in the list; change the index to try another provider
# Call the generator. The original line built a tuple `(model, example_code_simple)`
# instead, which made the next cell fail on `.startswith`.
generated_tests = generate_unit_tests(model, example_code_simple)
print("Generated Tests:")
print("="*80)
print(generated_tests)
print("="*80)

In [12]:
# Run the generated tests unless generation failed. generate_unit_tests reports failures
# as text starting with "# Error" or "# Please paste". Checking only for a leading "#"
# would also skip valid generated code whose first line is an ordinary comment.
if generated_tests and not generated_tests.startswith(("# Error", "# Please paste")):
    print("\nRunning Tests:")
    print("="*80)
    test_output = run_unit_tests(generated_tests)
    print(test_output)
    print("="*80)

## Gradio Interface

Interactive web interface for generating and running unit tests.

In [None]:
def gradio_generate(model: str, python_code: str) -> str:
    """Gradio handler: generate tests for *python_code* with the selected *model*."""
    return generate_unit_tests(model, python_code)

def gradio_run_tests(test_code: str) -> str:
    """Gradio handler: run *test_code* and return colour-highlighted HTML results."""
    return run_unit_tests_html(test_code)

# Create Gradio interface
with gr.Blocks(title="Unit Test Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Unit Test Code Generator
    
    Generate comprehensive unit tests for your Python code covering:
    - **Positive Cases**: Normal, expected behavior
    - **Negative Cases**: Error handling and invalid inputs  
    - **Edge Cases**: Boundary conditions and special values
    - **Stress Tests**: Performance with large inputs
    """)
    
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=models,
                value=models[0],
                label="Select Model",
                info="Choose which AI model to generate tests"
            )
            
            input_code = gr.Code(
                label="Python Code to Test",
                language="python",
                lines=20,
                placeholder="Paste your Python code here..."
            )
            
            with gr.Row():
                example_btn1 = gr.Button("Load Simple Example", size="sm")
                example_btn2 = gr.Button("Load Medium Example", size="sm")
                example_btn3 = gr.Button("Load Complex Example", size="sm")
            
            # Emoji label restored: it previously appeared as mojibake ("üî®").
            generate_btn = gr.Button("🔨 Generate Unit Tests", variant="primary", size="lg")
        
        with gr.Column():
            # Output of generation and input to "Run Tests", so the user can edit before running.
            output_tests = gr.Code(
                label="Generated Unit Tests",
                language="python",
                lines=20
            )
            
            # Emoji label restored: it previously appeared as mojibake ("‚ñ∂Ô∏è").
            run_btn = gr.Button("▶️ Run Tests", variant="secondary", size="lg")
            
            test_results = gr.HTML(
                label="Test Results",
                value="<p style='color:#666;'>Test output will appear here...</p>"
            )
    
    gr.Markdown("""
    ### Instructions:
    1. Choose a model from the dropdown
    2. Paste your Python code or load an example
    3. Click "Generate Unit Tests" to create comprehensive tests
    4. Review the generated tests
    5. Click "Run Tests" to execute them and see results
    
    The generated tests will include positive, negative, edge, and stress test cases.
    """)
    
    # Event handlers
    example_btn1.click(fn=lambda: example_code_simple, outputs=input_code)
    example_btn2.click(fn=lambda: example_code_medium, outputs=input_code)
    example_btn3.click(fn=lambda: example_code_complex, outputs=input_code)
    
    generate_btn.click(
        fn=gradio_generate,
        inputs=[model_dropdown, input_code],
        outputs=output_tests
    )
    
    run_btn.click(
        fn=gradio_run_tests,
        inputs=output_tests,
        outputs=test_results
    )

# Launch the interface
demo.launch()