# Unit Test Generator

The requirement: use a frontier model to generate fast, repeatable unit tests for Python code.


In [None]:
# imports

import os
import io
import sys
import ast
import unittest, contextlib
from dotenv import load_dotenv
from openai import OpenAI
import google.generativeai
import anthropic
from IPython.display import Markdown, display, update_display
import gradio as gr
import subprocess

# environment

# Load variables from a local .env file; override=True is passed so that
# values from the file replace any already present in the environment.
load_dotenv(override=True)

# Make sure both API keys exist in the process environment. An existing value
# is left as it is; the placeholder is stored only when the key is missing.
os.environ.setdefault('OPENAI_API_KEY', 'your-key-if-not-using-env')
os.environ.setdefault('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')

# Model identifiers used by the streaming helpers.
OPENAI_MODEL = "gpt-4o"
CLAUDE_MODEL = "claude-3-7-sonnet-20250219"

# API clients, constructed without explicit credentials.
openai = OpenAI()
claude = anthropic.Anthropic()

# System prompt shared by both models (note the intentional trailing space).
system_message = (
    "You are an assistant that implements unit testing for Python code. "
    "Respond only with Python code; use comments sparingly and do not provide any explanation other than occasional comments. "
)

def _is_main_guard(node):
    """Return True if *node* is an ``if __name__ == "__main__":`` statement.

    Both operand orders are accepted (``__name__ == "__main__"`` and
    ``"__main__" == __name__``).
    """
    if not isinstance(node, ast.If):
        return False
    test = node.test
    if not (
        isinstance(test, ast.Compare)
        and len(test.ops) == 1
        and isinstance(test.ops[0], ast.Eq)
        and len(test.comparators) == 1
    ):
        return False

    def is_dunder_name(operand):
        return isinstance(operand, ast.Name) and operand.id == "__name__"

    def is_main_literal(operand):
        return isinstance(operand, ast.Constant) and operand.value == "__main__"

    left, right = test.left, test.comparators[0]
    return (is_dunder_name(left) and is_main_literal(right)) or (
        is_main_literal(left) and is_dunder_name(right)
    )


def remove_main_block_from_code(code: str) -> str:
    """
    Remove top-level `if __name__ == "__main__":` blocks from code.

    Only statements directly in the module body are considered, so a guard
    nested inside a function, class or loop is left alone (removing it could
    leave an empty, syntactically invalid body). The guard is cut out of the
    original text by line range, so comments and formatting elsewhere in the
    file are preserved; code without a guard is returned unchanged.

    Args:
        code: Python source text. Lines are assumed to be separated by "\n",
            as produced by reading a file in text mode.

    Returns:
        The source with every top-level main guard (including any
        ``elif``/``else`` attached to it) removed. If the source cannot be
        processed, the error is printed and the original code is returned.
    """
    try:
        tree = ast.parse(code)

        # 1-based line numbers covered by the guards to drop.
        doomed = set()
        for node in tree.body:
            if _is_main_guard(node):
                doomed.update(range(node.lineno, node.end_lineno + 1))

        if not doomed:
            return code  # nothing to strip: keep the text byte-for-byte

        # A compound statement always starts on its own line and owns every
        # line up to end_lineno, so dropping whole lines is safe.
        kept = [
            line
            for number, line in enumerate(code.split("\n"), start=1)
            if number not in doomed
        ]
        return "\n".join(kept)
    except Exception as e:
        # Deliberate best-effort: never let a parsing problem block prompt
        # construction; fall back to the untouched source instead.
        print("Error removing __main__ block:", e)
        return code  # fallback: return original code if AST fails

def user_prompt_for(python_file):
    """
    Build the user prompt asking the model to write unit tests for a file.

    `python_file` may be a dict with a "name" key (from Gradio), an object
    exposing a `.name` attribute (e.g. a tempfile), or a plain path string.
    The file is read as UTF-8, its `__main__` blocks are stripped, and the
    resulting source is appended to the fixed instructions.
    """
    # Work out where the source lives for each supported input shape.
    if isinstance(python_file, dict):
        source_path = python_file["name"]
    else:
        # Objects with a .name attribute supply it; anything else is the path.
        source_path = getattr(python_file, "name", python_file)

    with open(source_path, "r", encoding="utf-8") as handle:
        raw_source = handle.read()

    instructions = (
        "Write unit tests for this Python code. "
        "Respond only with Python code; do not explain your work other than a few comments. "
        "The unit testing is done in Jupyterlab, so you should use packages that play nicely with the Jupyter kernel. \n\n"
        "Include the original Python code in your generated output so that I can run all in one fell swoop.\n\n"
    )

    # The prompt carries the source without its __main__ blocks.
    return instructions + remove_main_block_from_code(raw_source)

def messages_for(python_file):
    """Return the two-message chat list (system, then user) for `python_file`."""
    system_turn = {"role": "system", "content": system_message}
    user_turn = {"role": "user", "content": user_prompt_for(python_file)}
    return [system_turn, user_turn]
	
def stream_gpt(python_file):
    """
    Stream unit-test code for `python_file` from the OpenAI model.

    Yields the cumulative reply after every received chunk, with the
    markdown code-fence markers removed.
    """
    chunks = openai.chat.completions.create(
        model=OPENAI_MODEL,
        messages=messages_for(python_file),
        stream=True,
    )
    so_far = ""
    for chunk in chunks:
        piece = chunk.choices[0].delta.content
        if piece:
            so_far += piece
        # Yield on every chunk, even when it carried no text.
        unfenced = so_far.replace('```python\n', '')
        yield unfenced.replace('```', '')
		
def stream_claude(python_file):
    """
    Stream unit-test code for `python_file` from the Claude model.

    Yields the cumulative reply after every received text fragment, with the
    markdown code-fence markers removed.
    """
    request = dict(
        model=CLAUDE_MODEL,
        max_tokens=2000,
        system=system_message,
        messages=[{"role": "user", "content": user_prompt_for(python_file)}],
    )
    so_far = ""
    with claude.messages.stream(**request) as events:
        for piece in events.text_stream:
            so_far += piece
            unfenced = so_far.replace('```python\n', '')
            yield unfenced.replace('```', '')
			
def unit_test(python_file, model):
    """
    Stream generated unit-test code for `python_file` from the chosen model.

    `model` must be "GPT" or "Claude"; any other value raises
    ValueError("Unknown model") when the generator is first advanced.
    Yields each cumulative reply produced by the selected streamer.
    """
    if model not in ("GPT", "Claude"):
        raise ValueError("Unknown model")

    streamer = stream_gpt if model == "GPT" else stream_claude
    yield from streamer(python_file)

def execute_python(code):
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer):
            # execute code in isolated namespace
            ns = {}
            exec(code, ns)

            # manually collect TestCase subclasses
            test_cases = [
                obj for obj in ns.values()
                if isinstance(obj, type) and issubclass(obj, unittest.TestCase)
            ]
            if test_cases:
                suite = unittest.TestSuite()
                for case in test_cases:
                    suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
                runner = unittest.TextTestRunner(stream=buffer, verbosity=2)
                runner.run(suite)
    except Exception as e:
        print(f"Error during execution: {e}", file=buffer)

    return buffer.getvalue()

In [None]:
# --- Gradio UI ---
# Layout: file upload next to the generated test code, then the model picker,
# the two action buttons, and finally the test-run output.
with gr.Blocks() as ui:
    gr.Markdown("## Unit Test Generator\nUpload a Python file and get structured unit testing.")
    with gr.Row():  # Row 1: source file upload (.py only) and the generated test code
        orig_code = gr.File(label="Upload your Python file", file_types=[".py"])
        test_code = gr.Textbox(label="Unit test code:", lines=10)
    with gr.Row():  # Row 2: model choice; the values match the names unit_test() accepts
        model = gr.Dropdown(["GPT", "Claude"], label="Select model", value="GPT")
    with gr.Row():  # Row 3: trigger test generation
        generate = gr.Button("Generate unit test code")
    with gr.Row():  # Row 4: trigger execution of the code shown in the test textbox
        unit_run = gr.Button("Run Python unit test")
    with gr.Row():  # Row 5: captured output of the test run
        test_out = gr.Textbox(label="Unit test result:", lines=10)

    # unit_test is a generator yielding progressively longer text, which is
    # sent to the test-code textbox.
    generate.click(unit_test, inputs=[orig_code, model], outputs=[test_code])

    # Runs whatever is currently in the test-code textbox and shows the
    # captured output.
    unit_run.click(execute_python, inputs=[test_code], outputs=[test_out])

In [None]:
# Start the Gradio app; inbrowser=True requests that it open in a browser tab.
ui.launch(inbrowser=True)