# Week 4 Assignment: Unit Test Codegen Tool

Use a frontier model to generate unit tests for Python code. Paste your code, pick a model, generate tests, and run them.

In [11]:
# imports

import os
import io
import sys
import re
import contextlib
import unittest
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
from IPython.display import Markdown, display

# Generated tests are executed with pytest, so report up front whether it can be imported.
try:
    import pytest
except ImportError:
    print("WARNING: pytest not installed. Generated tests will fail to run.")
    print("From the project root (llm_engineering) run:  uv sync")
    print("Then restart the kernel and ensure this notebook uses that Python (e.g. select the '.venv' kernel).")
else:
    print("pytest found:", pytest.__file__)


pytest found: /Users/collinewaitire/projects/python-p/llm_engineering/.venv/lib/python3.12/site-packages/pytest/__init__.py


In [12]:
# Read provider credentials from the environment (values in .env override existing variables).
load_dotenv(override=True)
openai_api_key = os.getenv('OPENAI_API_KEY')
anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')
google_api_key = os.getenv('GOOGLE_API_KEY')
grok_api_key = os.getenv('GROK_API_KEY')
groq_api_key = os.getenv('GROQ_API_KEY')
openrouter_api_key = os.getenv('OPENROUTER_API_KEY')

# Report which keys were found; only a short prefix of each key is ever printed.
print(
    f"OpenAI API Key exists and begins {openai_api_key[:8]}"
    if openai_api_key
    else "OpenAI API Key not set"
)
print(
    f"Anthropic API Key exists and begins {anthropic_api_key[:7]}"
    if anthropic_api_key
    else "Anthropic API Key not set (optional)"
)
print(
    f"Google API Key exists and begins {google_api_key[:2]}"
    if google_api_key
    else "Google API Key not set (optional)"
)
print(
    f"Grok API Key exists and begins {grok_api_key[:4]}"
    if grok_api_key
    else "Grok API Key not set (optional)"
)
print(
    f"Groq API Key exists and begins {groq_api_key[:4]}"
    if groq_api_key
    else "Groq API Key not set (optional)"
)
print(
    f"OpenRouter API Key exists and begins {openrouter_api_key[:6]}"
    if openrouter_api_key
    else "OpenRouter API Key not set (optional)"
)


OpenAI API Key exists and begins sk-proj-
Anthropic API Key not set (optional)
Google API Key not set (optional)
Grok API Key not set (optional)
Groq API Key not set (optional)
OpenRouter API Key not set (optional)


In [13]:
# Connect to client libraries (same as day5)

# The default client reads OPENAI_API_KEY from the environment.
openai = OpenAI()
anthropic_url = "https://api.anthropic.com/v1/"
gemini_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
grok_url = "https://api.x.ai/v1"
groq_url = "https://api.groq.com/openai/v1"
ollama_url = "http://localhost:11434/v1"
openrouter_url = "https://openrouter.ai/api/v1"

# When api_key is None the OpenAI SDK falls back to the OPENAI_API_KEY environment
# variable. Combined with a third-party base_url, that would send the OpenAI secret
# to another provider. Use an obviously invalid placeholder for unconfigured
# providers so their requests fail with an authentication error instead.
MISSING_API_KEY = "missing-api-key"

anthropic = OpenAI(api_key=anthropic_api_key or MISSING_API_KEY, base_url=anthropic_url)
gemini = OpenAI(api_key=google_api_key or MISSING_API_KEY, base_url=gemini_url)
grok = OpenAI(api_key=grok_api_key or MISSING_API_KEY, base_url=grok_url)
groq = OpenAI(api_key=groq_api_key or MISSING_API_KEY, base_url=groq_url)
ollama = OpenAI(api_key="ollama", base_url=ollama_url)
openrouter = OpenAI(api_key=openrouter_api_key or MISSING_API_KEY, base_url=openrouter_url)

# Model ids offered in the UI dropdown, in display order.
models = [
    "gpt-5",
    "claude-sonnet-4-5-20250929",
    "grok-4",
    "gemini-2.5-pro",
    "qwen2.5-coder",
    "deepseek-coder-v2",
    "gpt-oss:20b",
    "qwen/qwen3-coder-30b-a3b-instruct",
    "openai/gpt-oss-120b",
]

# Which client serves each model id; looked up by generate_unit_tests.
clients = {
    "gpt-5": openai,
    "claude-sonnet-4-5-20250929": anthropic,
    "grok-4": grok,
    "gemini-2.5-pro": gemini,
    "openai/gpt-oss-120b": groq,
    "qwen2.5-coder": ollama,
    "deepseek-coder-v2": ollama,
    "gpt-oss:20b": ollama,
    "qwen/qwen3-coder-30b-a3b-instruct": openrouter,
}


## Unit test generation prompts

In [14]:
# System prompt sent with every request: asks for pytest-style tests, code-only output,
# and the original source repeated ahead of the tests so the reply can run as one block.
SYSTEM_PROMPT = """You are an expert at writing unit tests for Python code.
Your task is to generate comprehensive, runnable unit tests for the given Python code.

Rules:
- Use pytest. Use assert statements and pytest idioms (e.g. pytest.raises for exceptions).
- Respond ONLY with Python code. No markdown, no explanation outside comments.
- Include the original code in your response so the tests can run in one block (paste the source first, then the test code).
- Cover normal cases, edge cases, and error cases where appropriate.
- If the code has dependencies (e.g. requests), mock them in tests; use unittest.mock or pytest fixtures as needed.
- Add a brief comment before each test describing what it checks.
"""

def user_prompt_for(python_code: str) -> str:
    """Build the user message asking the model to write tests for ``python_code``.

    The source is embedded verbatim inside a fenced ``python`` block, preceded by
    instructions to return the original code followed by the tests as plain Python.

    Args:
        python_code: The Python source the tests should cover.

    Returns:
        The complete user prompt text.
    """
    prompt = f"""Generate pytest unit tests for this Python code.
Include the original code in your response first, then the test code, so everything can be executed together in a single block.
Respond only with Python code.

Python code to test:

```python
{python_code}
```
"""
    return prompt


In [15]:
def messages_for(python_code: str):
    """Assemble the chat messages for a test-generation request.

    Returns a two-item list: the shared system prompt followed by the user
    prompt built from ``python_code``.
    """
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_prompt_for(python_code)}
    return [system_message, user_message]

# Opening fence at the very start of the reply: a language tag followed by a newline,
# or (legacy forms) "```python" / "```" followed directly by code on the same line.
_OPENING_FENCE = re.compile(r'^```(?:[\w+.-]*[ \t]*\r?\n|python\s*|\s*)')
# Closing fence at the very end of the reply.
_CLOSING_FENCE = re.compile(r'```(?:python)?\s*$')
# A fenced block starting on its own line; the closing fence is optional so a
# truncated reply still yields its code.
_FENCED_BLOCK = re.compile(
    r'^```[^\n`]*\n(.*?)(?:^```[ \t]*$|\Z)',
    re.DOTALL | re.MULTILINE,
)


def extract_code(reply: str) -> str:
    """Strip markdown code fences from model output.

    Handles three shapes of reply:

    * no fences at all - returned stripped of surrounding whitespace;
    * the whole reply wrapped in a single fence - only the outer fence is removed,
      so fence markers inside the code (e.g. in string literals) are preserved;
    * prose before and/or after one or more fenced blocks - the blocks are
      extracted and joined with a blank line, and the prose is dropped.

    Any language tag on the opening fence (``python``, ``py``, ...) is removed.

    Args:
        reply: Raw text returned by the model (``None`` is treated as empty).

    Returns:
        The code with fences removed and surrounding whitespace stripped.
    """
    text = (reply or "").strip()
    if "```" not in text:
        return text

    fully_wrapped = text.startswith("```") and text.endswith("```")
    if not fully_wrapped:
        # Explanatory text surrounds the code: keep only the fenced content.
        blocks = [block.strip("\n") for block in _FENCED_BLOCK.findall(text)]
        blocks = [block for block in blocks if block.strip()]
        if blocks:
            return "\n\n".join(blocks).strip()

    # Either the reply is one outer-fenced block, or no complete block was found:
    # remove a fence at the very start and at the very end only.
    text = _OPENING_FENCE.sub("", text, count=1)
    text = _CLOSING_FENCE.sub("", text, count=1)
    return text.strip()

# Provider markers for which the `reasoning_effort` parameter must not be sent.
_NO_REASONING_EFFORT_MARKERS = ("groq", "openrouter")


def _wants_reasoning_effort(model: str, client) -> bool:
    """Return True when `reasoning_effort` should be sent for this model/client pair."""
    if "gpt" not in model:
        return False
    # Model ids (e.g. "openai/gpt-oss-120b") do not contain the provider name, so the
    # provider is recognised from the client's endpoint as well as from the id.
    target = f"{model} {getattr(client, 'base_url', '')}".lower()
    return not any(marker in target for marker in _NO_REASONING_EFFORT_MARKERS)


def generate_unit_tests(model: str, python_code: str) -> str:
    """Call the chosen model to generate unit tests for the given Python code.

    Args:
        model: A model id that is a key of ``clients``.
        python_code: The source code to generate tests for.

    Returns:
        The generated code with markdown fences removed. Problems are reported as
        plain text rather than raised: a prompt to paste code when the input is
        blank, ``Unknown model: ...`` for an unmapped model, or
        ``Error calling model: ...`` when the request fails.
    """
    if not python_code or not python_code.strip():
        return "Please paste some Python code first."
    client = clients.get(model)
    if not client:
        return f"Unknown model: {model}"
    try:
        kwargs = {"model": model, "messages": messages_for(python_code)}
        if _wants_reasoning_effort(model, client):
            kwargs["reasoning_effort"] = "high"
        response = client.chat.completions.create(**kwargs)
        # content may be None (e.g. an empty completion); treat it as empty text.
        reply = response.choices[0].message.content or ""
        return extract_code(reply)
    except Exception as e:
        # Surface the failure in the UI instead of crashing the Gradio handler.
        return f"Error calling model: {e}"


## Run generated tests

In [16]:
def _strip_ansi(text: str) -> str:
    """Remove ANSI escape codes so output is readable in Gradio textbox."""
    return re.sub(r"\x1b\[[0-9;]*m", "", text)

def _clean_pytest_output(text: str, temp_path: str = "") -> str:
    """Make pytest output presentable: drop ANSI codes and hide the temp file path.

    Args:
        text: Raw captured output.
        temp_path: Path of the temporary test file; when empty, only ANSI codes
            are removed.

    Returns:
        The cleaned text, with every reference to the temporary file shown as
        ``generated_tests.py``.
    """
    cleaned = _strip_ansi(text)
    if not temp_path:
        return cleaned

    file_name = os.path.basename(temp_path)
    # Absolute path as passed to pytest.
    cleaned = cleaned.replace(temp_path, "generated_tests.py")
    # Relative form pytest may print: ../../../../path/tmpXXX.py -> generated_tests.py
    cleaned = re.sub(r"(\.\./)+[^\s]*" + re.escape(file_name), "generated_tests.py", cleaned)
    # Drop any "../" or ".." still left in front of the friendly name.
    return re.sub(r"(\.\./)*(\.\.)?generated_tests\.py", "generated_tests.py", cleaned)

def _output_to_html(plain_text: str) -> str:
    """Convert pytest output to HTML with green for PASSED, red for FAILED."""
    import html
    lines = plain_text.splitlines()
    out = []
    for line in lines:
        escaped = html.escape(line)
        if " PASSED " in line or line.strip().endswith("PASSED"):
            escaped = escaped.replace("PASSED", '<span style="color:#0a0;font-weight:600">PASSED</span>')
        if " FAILED " in line or line.strip().endswith("FAILED"):
            escaped = escaped.replace("FAILED", '<span style="color:#c00;font-weight:600">FAILED</span>')
        if "passed" in line and "warning" in line.lower() and "=" in line:
            escaped = f'<div style="margin-top:0.75em;font-weight:600;color:#0a0;">{escaped}</div>'
        elif "failed" in line and "=" in line:
            escaped = f'<div style="margin-top:0.75em;font-weight:600;color:#c00;">{escaped}</div>'
        out.append(escaped)
    return "<pre style='margin:0;font-family:monospace;font-size:0.9em;line-height:1.4;'>" + "\n".join(out) + "</pre>"

def run_unit_tests(code: str) -> str:
    """Execute the combined code (source + tests) and return the captured test output.

    The code is first executed in a fresh namespace. If that defines any
    ``unittest.TestCase`` subclasses they are run with ``unittest``; otherwise the
    code is written to a temporary file and run with pytest (which imports, and
    therefore executes, the code a second time). stdout and stderr are captured
    for the whole run.

    SECURITY: ``code`` is typically model-generated and is run with ``exec``
    inside this process, with the same access as the notebook kernel. Only run
    code you have reviewed.

    Args:
        code: Python source containing both the code under test and its tests.

    Returns:
        The captured output with the temporary file path cleaned up, or a short
        message when there is nothing to run, pytest is unavailable, or the code
        raised an error while executing.
    """
    if not code or not code.strip():
        return "No code to run."
    buffer = io.StringIO()
    temp_path = ""
    try:
        with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
            # Ensure pytest is available (generated code often does "import pytest" at top)
            try:
                import pytest
            except ImportError:
                buffer.write(
                    "pytest is not installed in this Python environment.\n\n"
                    "Fix: From the project root folder (llm_engineering) run:\n"
                    "  uv sync\n\n"
                    "Then restart the kernel and pick the kernel that uses this project's "
                    "environment (in Cursor: click the kernel name top-right → Select Another Kernel "
                    "→ Python Environments → choose the one that shows '.venv' or 'llm_engineering').\n"
                )
                return buffer.getvalue()
            ns = {}
            try:
                # Executes untrusted, model-generated code in-process (see SECURITY above).
                exec(code, ns)
            except ModuleNotFoundError as e:
                if "pytest" in str(e):
                    buffer.write(
                        "pytest not found in this environment.\n\n"
                        "Use the project's Python: from the llm_engineering folder run 'uv sync', "
                        "then restart the kernel and select the .venv / llm_engineering kernel.\n"
                    )
                    return buffer.getvalue()
                raise
            # Run unittest.TestCase subclasses if present. Classes that live in the
            # unittest package itself (e.g. from `from unittest import TestCase`) are
            # skipped: they hold no tests and would otherwise stop pytest-style code
            # from reaching the pytest branch.
            test_cases = [
                obj for obj in ns.values()
                if isinstance(obj, type)
                and issubclass(obj, unittest.TestCase)
                and not (getattr(obj, "__module__", "") or "").startswith("unittest")
            ]
            if test_cases:
                suite = unittest.TestSuite()
                for case in test_cases:
                    suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
                runner = unittest.TextTestRunner(stream=buffer, verbosity=2)
                runner.run(suite)
            else:
                # Pytest-style: write to temp file and run pytest (no color for clean UI output)
                import tempfile
                handle = tempfile.NamedTemporaryFile(
                    mode="w", suffix=".py", delete=False, encoding="utf-8"
                )
                temp_path = handle.name
                try:
                    with handle:
                        handle.write(code)
                    pytest.main([temp_path, "-v", "--tb=short", "-x", "--color=no"])
                finally:
                    # Remove the file even if writing or the pytest run failed.
                    os.unlink(temp_path)
    except SystemExit as e:
        # Generated code may call sys.exit() / unittest.main(); SystemExit is not an
        # Exception subclass, so report it here instead of letting it escape.
        buffer.write(f"Error: code exited early via SystemExit({e.code!r})")
    except Exception as e:
        buffer.write(f"Error: {e}")
    out = buffer.getvalue()
    return _clean_pytest_output(out, temp_path)

def run_unit_tests_html(code: str) -> str:
    """Run the tests in ``code`` and return the output as colour-coded HTML for Gradio."""
    return _output_to_html(run_unit_tests(code))


In [None]:
# run_unit_tests is defined in the cell above


## Gradio UI

In [17]:
# Starter snippet preloaded into the input editor.
sample_code = """def add(a: int, b: int) -> int:
    return a + b

def divide(a: float, b: float) -> float:
    if b == 0:
        raise ValueError("division by zero")
    return a / b
"""

# Shown in the results panel until the first test run.
initial_output_html = "<pre style='margin:0;color:#666'>Run tests to see output. Passing tests will appear in green, failed in red.</pre>"

with gr.Blocks(title="Unit Test Codegen") as ui:
    gr.Markdown("# Unit Test Codegen Tool")
    gr.Markdown("Paste Python code, choose a model, and generate pytest unit tests. Then run them.")

    # Editors: the code to test, and the generated source + tests.
    with gr.Row():
        code_in = gr.Code(label="Python code to test", value=sample_code, language="python", lines=12)
        test_code = gr.Code(label="Generated unit tests (source + tests)", value="", language="python", lines=20)

    # Controls: model picker and the two actions.
    with gr.Row():
        model = gr.Dropdown(choices=models, value=models[0], label="Model")
        generate_btn = gr.Button("Generate unit tests")
        run_btn = gr.Button("Run tests")

    test_output = gr.HTML(label="Test run output", value=initial_output_html)

    # Wiring: generate fills the tests editor; run renders the results as HTML.
    generate_btn.click(fn=generate_unit_tests, inputs=[model, code_in], outputs=[test_code])
    run_btn.click(fn=run_unit_tests_html, inputs=[test_code], outputs=[test_output])

ui.launch(inbrowser=True)


* Running on local URL:  http://127.0.0.1:7864
* To create a public link, set `share=True` in `launch()`.


