In [None]:
!uv pip install pytest

In [None]:
import os
import re
import ast
import sys
import uuid
import json
import textwrap
import subprocess
from pathlib import Path
from dataclasses import dataclass
from typing import List, Protocol, Tuple, Dict, Optional

from dotenv import load_dotenv
from openai import OpenAI
from openai import BadRequestError as _OpenAIBadRequest
import gradio as gr

load_dotenv(override=True)

# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---
GEMINI_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
GROQ_BASE   = "https://api.groq.com/openai/v1"

# --- API Keys (add these in your .env) ---
openai_api_key = os.getenv("OPENAI_API_KEY")   # OpenAI
google_api_key = os.getenv("GOOGLE_API_KEY")   # Gemini
groq_api_key   = os.getenv("GROQ_API_KEY")     # Groq

# --- Clients ---
# Every client is created only when its key is present; a missing key yields
# None so that provider is simply left out of the model registry. Building
# OpenAI() without a key raises at construction time, which would stop the
# notebook for setups that only configure Gemini or Groq.
openai_client = OpenAI(api_key=openai_api_key) if openai_api_key else None
gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None
groq_client   = OpenAI(api_key=groq_api_key,   base_url=GROQ_BASE)   if groq_api_key   else None

# --- Model registry: label -> { client, model } ---
MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}

def _register(label: str, client: Optional[OpenAI], model_id: str):
    """Add a model to the registry only if its client is configured."""
    if client is not None:
        MODEL_REGISTRY[label] = {"client": client, "model": model_id}

# Catalogue of (label, client, model id). Order matters: the first entry whose
# client is configured becomes DEFAULT_MODEL.
_MODEL_CATALOG = (
    # OpenAI
    ("OpenAI • GPT-5",        openai_client, "gpt-5"),
    ("OpenAI • GPT-5 Nano",   openai_client, "gpt-5-nano"),
    ("OpenAI • GPT-4o-mini",  openai_client, "gpt-4o-mini"),
    # Gemini (Google)
    ("Gemini • 2.5 Pro",      gemini_client, "gemini-2.5-pro"),
    ("Gemini • 2.5 Flash",    gemini_client, "gemini-2.5-flash"),
    # Groq
    ("Groq • Llama 3.1 8B",   groq_client,   "llama-3.1-8b-instant"),
    ("Groq • Llama 3.3 70B",  groq_client,   "llama-3.3-70b-versatile"),
    ("Groq • GPT-OSS 20B",    groq_client,   "openai/gpt-oss-20b"),
    ("Groq • GPT-OSS 120B",   groq_client,   "openai/gpt-oss-120b"),
)
for _label, _client, _model_id in _MODEL_CATALOG:
    _register(_label, _client, _model_id)

# None when no provider is configured at all.
DEFAULT_MODEL = next(iter(MODEL_REGISTRY), None)

_available_labels = ", ".join(MODEL_REGISTRY)
print(f"Providers configured → OpenAI:{bool(openai_api_key)}  Gemini:{bool(google_api_key)}  Groq:{bool(groq_api_key)}")
print("Models available     →", _available_labels or "None (add API keys in .env)")


In [None]:
class CompletionClient(Protocol):
    """Structural interface for LLM backends addressed by a registry label."""

    def complete(self, *, model_label: str, system: str, user: str) -> str:
        """Return the model's reply text for the given system and user prompts."""
        ...


def _extract_code_or_text(s: str) -> str:
    """Prefer fenced python if present; otherwise return raw text."""
    m = re.search(r"```(?:python)?\s*(.*?)```", s, flags=re.S | re.I)
    return m.group(1).strip() if m else s.strip()


class MultiModelChatClient:
    """Routes requests to the right provider/client based on model label.

    The registry maps a display label to a dict with two entries: ``client``
    (an OpenAI-compatible client) and ``model`` (the provider's model id).
    """
    def __init__(self, registry: Dict[str, Dict[str, object]]):
        self._registry = registry

    def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:
        """Send one system+user chat completion and return the cleaned reply."""
        # Only the model and messages are sent: no temperature or other
        # sampling parameters, because strict providers reject them.
        resp = client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "system", "content": system},
                {"role": "user",   "content": user},
            ],
        )
        # message.content may be None; treat that as an empty reply.
        text = (resp.choices[0].message.content or "").strip()
        return _extract_code_or_text(text)

    def complete(self, *, model_label: str, system: str, user: str) -> str:
        """Run a completion against the model registered under ``model_label``.

        Args:
            model_label: Key into the registry given at construction.
            system: System prompt.
            user: User prompt.

        Returns:
            The reply text, with a surrounding code fence removed if present.

        Raises:
            ValueError: If ``model_label`` is not in the registry.

        Errors raised by the provider client propagate unchanged. There is no
        retry: the request carries no optional parameters that could be
        dropped, so resending it would only repeat the same failure.
        """
        if model_label not in self._registry:
            raise ValueError(f"Unknown model label: {model_label}")
        info = self._registry[model_label]
        return self._call(
            client=info["client"],
            model_id=str(info["model"]),
            system=system,
            user=user,
        )


In [None]:
@dataclass(frozen=True)
class SymbolInfo:
    """One public symbol found in a module."""
    kind: str      # "function" | "class" | "method"
    name: str      # bare name, or "Class.method" for methods
    signature: str # rendered "def ...:" line; just the class name for classes
    lineno: int    # line of the definition as reported by the AST node

class PublicAPIExtractor:
    """Extract concise 'public API' summary from a Python module.

    A symbol is public when its name does not start with an underscore. Only
    top-level functions, top-level classes and the methods directly inside
    those classes are reported; nested definitions are not visited.
    """

    # Sync and async definitions are both part of a module's callable API.
    _FUNCTION_NODES = (ast.FunctionDef, ast.AsyncFunctionDef)

    def extract(self, source: str) -> List[SymbolInfo]:
        """Parse ``source`` and list its public symbols.

        Args:
            source: Python module source code.

        Returns:
            SymbolInfo entries sorted by kind, then case-insensitive name,
            then line number.

        Raises:
            SyntaxError: If ``source`` cannot be parsed by ``ast.parse``.
        """
        tree = ast.parse(source)
        out: List[SymbolInfo] = []
        for node in tree.body:
            if isinstance(node, self._FUNCTION_NODES):
                if not node.name.startswith("_"):
                    out.append(SymbolInfo("function", node.name, self._sig(node), node.lineno))
            elif isinstance(node, ast.ClassDef) and not node.name.startswith("_"):
                out.append(SymbolInfo("class", node.name, node.name, node.lineno))
                for sub in node.body:
                    if isinstance(sub, self._FUNCTION_NODES) and not sub.name.startswith("_"):
                        out.append(SymbolInfo("method",
                                              f"{node.name}.{sub.name}",
                                              self._sig(sub),
                                              sub.lineno))
        return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))

    def _sig(self, fn: ast.AST) -> str:
        """Render a compact one-line signature for a (possibly async) function node.

        Parameter names are listed without annotations or default values;
        keyword-only parameters are suffixed with ``=?``. The ``/`` and bare
        ``*`` markers are emitted so positional-only and keyword-only
        parameters are not mistaken for ordinary positional ones.
        """
        spec = fn.args
        args = [a.arg for a in spec.posonlyargs]
        if spec.posonlyargs:
            args.append("/")
        args.extend(a.arg for a in spec.args)
        if spec.vararg:
            args.append("*" + spec.vararg.arg)
        elif spec.kwonlyargs:
            # No *args present: a bare "*" is what makes the rest keyword-only.
            args.append("*")
        args.extend(a.arg + "=?" for a in spec.kwonlyargs)
        if spec.kwarg:
            args.append("**" + spec.kwarg.arg)
        ret = ""
        if fn.returns is not None:
            try:
                ret = f" -> {ast.unparse(fn.returns)}"
            except Exception:
                # Best effort: the return annotation is only a hint for the
                # prompt, so an unparse failure just omits it.
                pass
        keyword = "async def" if isinstance(fn, ast.AsyncFunctionDef) else "def"
        return f"{keyword} {fn.name}({', '.join(args)}){ret}:"


In [None]:
class PromptBuilder:
    """Builds deterministic prompts for pytest generation."""
    SYSTEM = (
        "You are a senior Python engineer. Produce a single, self-contained pytest file.\n"
        "Rules:\n"
        "- Output only Python test code (no prose, no markdown fences).\n"
        "- Use plain pytest tests (functions), no classes unless unavoidable.\n"
        "- Deterministic: avoid network/IO; seed randomness if used.\n"
        "- Import the target module by module name only.\n"
        "- Cover every public function and method with at least one tiny test.\n"
        "- Prefer straightforward, fast assertions.\n"
    )

    def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:
        """Compose the user prompt for one module.

        Args:
            module_name: Name the tests should import the module under.
            source: Full module source, embedded verbatim for reference.
            symbols: Public symbols to list in the API summary; each item
                provides ``kind`` and ``signature``.

        Returns:
            The prompt text with surrounding whitespace stripped.
        """
        summary = "\n".join(f"- {s.kind:<6}  {s.signature}" for s in symbols) or "- (no public symbols)"
        # Assembled line by line rather than dedenting an indented template:
        # the summary and source span several lines, so dedenting after
        # interpolation would leave the template indented and shift only the
        # first source line, misrepresenting the module's indentation.
        lines = [
            f"Create pytest tests for module `{module_name}`.",
            "",
            "Public API Summary:",
            summary,
            "",
            "Constraints:",
            f"- Import as: `import {module_name} as mod`",
            "- Keep tests tiny, fast, and deterministic.",
            "",
            "Full module source (for reference):",
            f"# --- BEGIN SOURCE {module_name}.py ---",
            source,
            "# --- END SOURCE ---",
        ]
        return "\n".join(lines).strip()


In [None]:
def _ensure_header_and_import(code: str, module_name: str) -> str:
    """Ensure tests import pytest and the target module as 'mod'."""
    code = code.strip()
    needs_pytest = "import pytest" not in code
    has_mod = (f"import {module_name} as mod" in code) or (f"from {module_name} import" in code)
    needs_import = not has_mod

    header = []
    if needs_pytest:
        header.append("import pytest")
    if needs_import:
        header.append(f"import {module_name} as mod")

    return ("\n".join(header) + "\n\n" + code) if header else code


def build_module_name_from_path(path: str) -> str:
    """Return the module name for `path`: its file name without the final suffix."""
    source_file = Path(path)
    return source_file.stem


In [None]:
class TestGenerator:
    """Pipeline: extract the public API, build the prompt, call the model, fix up imports."""

    def __init__(self, llm: CompletionClient):
        self._extractor = PublicAPIExtractor()
        self._prompts = PromptBuilder()
        self._llm = llm

    def generate_tests(self, model_label: str, module_name: str, source: str) -> str:
        """Return pytest code for `source`, generated by the model behind `model_label`."""
        user_prompt = self._prompts.build_user(
            module_name=module_name,
            source=source,
            symbols=self._extractor.extract(source),
        )
        reply = self._llm.complete(
            model_label=model_label,
            system=self._prompts.SYSTEM,
            user=user_prompt,
        )
        # The reply may omit the pytest/module imports; add them if needed.
        return _ensure_header_and_import(reply, module_name)


In [None]:
def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:
    """
    Parse the final summary line like:
      '3 passed, 1 failed, 2 skipped in 0.12s'
    Return (summary_line, counts_dict).
    """
    summary_line = ""
    for line in output.strip().splitlines()[::-1]:  # scan from end
        if " passed" in line or " failed" in line or " error" in line or " skipped" in line or " deselected" in line:
            summary_line = line.strip()
            break

    counts = {"passed": 0, "failed": 0, "errors": 0, "skipped": 0, "xfail": 0, "xpassed": 0}
    m = re.findall(r"(\d+)\s+(passed|failed|errors?|skipped|xfailed|xpassed)", summary_line)
    for num, kind in m:
        if kind.startswith("error"):
            counts["errors"] += int(num)
        elif kind == "passed":
            counts["passed"] += int(num)
        elif kind == "failed":
            counts["failed"] += int(num)
        elif kind == "skipped":
            counts["skipped"] += int(num)
        elif kind == "xfailed":
            counts["xfail"] += int(num)
        elif kind == "xpassed":
            counts["xpassed"] += int(num)

    return summary_line or "(no summary line found)", counts


def run_pytest_on_snippet(module_name: str, module_code: str, tests_code: str) -> Tuple[str, str]:
    """
    Create an isolated temp workspace, write module + tests, run pytest,
    and return (human_summary, full_cli_output).
    """
    if not module_name or not module_code.strip() or not tests_code.strip():
        return "❌ Provide module name, module code, and tests.", ""

    run_id = uuid.uuid4().hex[:8]
    base = Path(".pytest_runs") / f"run_{run_id}"
    tests_dir = base / "tests"
    tests_dir.mkdir(parents=True, exist_ok=True)

    # Write module and tests
    (base / f"{module_name}.py").write_text(module_code, encoding="utf-8")
    (tests_dir / f"test_{module_name}.py").write_text(tests_code, encoding="utf-8")

    # Run pytest with this temp dir on PYTHONPATH
    env = os.environ.copy()
    env["PYTHONPATH"] = str(base) + os.pathsep + env.get("PYTHONPATH", "")
    cmd = [sys.executable, "-m", "pytest", "-q"]  # quiet output, but still includes summary
    proc = subprocess.run(cmd, cwd=base, env=env, text=True, capture_output=True)

    full_out = (proc.stdout or "") + ("\n" + proc.stderr if proc.stderr else "")
    summary_line, counts = _parse_pytest_summary(full_out)

    badges = []
    for key in ("passed", "failed", "errors", "skipped", "xpassed", "xfail"):
        val = counts.get(key, 0)
        if val:
            badges.append(f"**{key}: {val}**")
    badges = "  •  ".join(badges) if badges else "no tests collected?"

    human = f"{summary_line}\n\n{badges}"
    return human, full_out


In [None]:
# Module-level singletons: one routing client over the configured models and
# one generator built on it. SERVICE is what generate_from_code calls.
LLM = MultiModelChatClient(MODEL_REGISTRY)
SERVICE = TestGenerator(LLM)

def generate_from_code(model_label: str, module_name: str, code: str, save: bool, out_dir: str) -> Tuple[str, str]:
    """Generate pytest code for pasted module source and optionally save it.

    Args:
        model_label: Key of the model in MODEL_REGISTRY.
        module_name: Name the tests import the module under; surrounding
            whitespace is ignored. May be dotted (``pkg.mod``).
        code: Source of the module under test.
        save: When true, write the tests to ``<out_dir>/test_<module_name>.py``.
        out_dir: Output folder; an empty value falls back to ``tests``.

    Returns:
        (tests_code, message). On a validation problem tests_code is "" and
        message starts with "❌"; otherwise message is the save confirmation
        or "" when nothing was saved.
    """
    if not model_label or model_label not in MODEL_REGISTRY:
        return "", "❌ Pick a model (or add API keys for providers in .env)."
    # Normalise once so the prompt, the import line and the saved file name
    # all use the same module name.
    module_name = (module_name or "").strip()
    if not module_name:
        return "", "❌ Please provide a module name."
    # The name ends up in an import statement and in a file path, so every
    # dotted part must be an identifier (this also rules out path separators).
    if not all(part.isidentifier() for part in module_name.split(".")):
        return "", "❌ Module name must be a valid Python identifier."
    if not (code or "").strip():
        return "", "❌ Please paste some Python code."

    try:
        tests_code = SERVICE.generate_tests(model_label=model_label, module_name=module_name, source=code)
    except SyntaxError as exc:
        # Raised while parsing the pasted module to extract its public API.
        return "", f"❌ Module code is not valid Python (line {exc.lineno}): {exc.msg}"

    saved = ""
    if save:
        out = Path(out_dir or "tests")
        out.mkdir(parents=True, exist_ok=True)
        out_path = out / f"test_{module_name}.py"
        out_path.write_text(tests_code, encoding="utf-8")
        saved = f"✅ Saved to {out_path}"
    return tests_code, saved


def generate_from_file(model_label: str, file_obj, save: bool, out_dir: str) -> Tuple[str, str]:
    """Generate pytest code for an uploaded module.

    Args:
        model_label: Key of the model in MODEL_REGISTRY.
        file_obj: Raw bytes of the uploaded file, or None when nothing was
            uploaded.
        save: Forwarded to generate_from_code.
        out_dir: Forwarded to generate_from_code.

    Returns:
        (tests_code, message) as returned by generate_from_code, or
        ("", error message) when there is no file or it cannot be decoded.

    The upload's original file name is not available here, so the module is
    always named ``uploaded_module``.
    """
    if file_obj is None:
        return "", "❌ Please upload a .py file."
    try:
        # utf-8-sig decodes plain UTF-8 as well and drops a leading BOM,
        # which would otherwise stay in the text as U+FEFF ahead of the code.
        code = file_obj.decode("utf-8-sig")
    except UnicodeDecodeError:
        return "", "❌ Uploaded file is not valid UTF-8 text."
    module_name = build_module_name_from_path("uploaded_module.py")
    return generate_from_code(model_label, module_name, code, save, out_dir)


In [None]:
# Sample module pre-filled into the code editor: two functions (one of which
# raises on a zero divisor) and a small stateful class.
EXAMPLE_CODE = """\
def add(a: int, b: int) -> int:
    return a + b

def divide(a: float, b: float) -> float:
    if b == 0:
        raise ZeroDivisionError("b must be non-zero")
    return a / b

class Counter:
    def __init__(self, start: int = 0):
        self.value = start

    def inc(self, by: int = 1):
        self.value += by
        return self.value
"""


In [None]:
# Two-column Gradio app: module code and options on the left, generated tests
# and pytest results on the right.
with gr.Blocks(title="PyTest Generator") as ui:
    gr.Markdown(
        "## 🧪 PyTest Generator (Week 4 • Community Contribution)\n"
        "Generate **minimal, deterministic** pytest tests from a Python module using your chosen model/provider."
    )

    with gr.Row(equal_height=True):
        # LEFT: inputs (module code)
        with gr.Column(scale=6):
            with gr.Row():
                # Choices are the registry labels; DEFAULT_MODEL is None when
                # no provider is configured.
                model_dd = gr.Dropdown(
                    list(MODEL_REGISTRY.keys()),
                    value=DEFAULT_MODEL,
                    label="Model (OpenAI, Gemini, Groq)"
                )
                module_name_tb = gr.Textbox(
                    label="Module name (used in `import <name> as mod`)",
                    value="mymodule"
                )
            code_in = gr.Code(
                label="Python module code",
                language="python",
                lines=24,
                value=EXAMPLE_CODE
            )
            with gr.Row():
                # Saving is on by default, so each generation writes a test
                # file into the output folder unless this is unticked.
                save_cb = gr.Checkbox(label="Also save generated tests to /tests", value=True)
                out_dir_tb = gr.Textbox(label="Output folder", value="tests")
            gen_btn = gr.Button("Generate tests", variant="primary")

        # RIGHT: outputs (generated tests + pytest run)
        with gr.Column(scale=6):
            tests_out = gr.Code(label="Generated tests (pytest)", language="python", lines=24)
            with gr.Row():
                run_btn = gr.Button("Run PyTest", variant="secondary")
            # Shared status area: shows the generation message and, after a
            # run, the pytest summary.
            summary_md = gr.Markdown()
            full_out = gr.Textbox(label="Full PyTest output", lines=12)

    # --- events ---

    def _on_gen(model_label, name, code, save, outdir):
        """Generate tests and return (tests_code, status text)."""
        tests, msg = generate_from_code(model_label, name, code, save, outdir)
        # An empty message means nothing was saved and nothing went wrong.
        status = msg or "✅ Done"
        return tests, status

    gen_btn.click(
        _on_gen,
        inputs=[model_dd, module_name_tb, code_in, save_cb, out_dir_tb],
        outputs=[tests_out, summary_md],
    )

    def _on_run(name, code, tests):
        """Run pytest on the current module code and tests; return (summary, full output)."""
        summary, details = run_pytest_on_snippet(name, code, tests)
        return summary, details

    # Runs whatever is currently in the tests editor against the module code
    # on the left.
    run_btn.click(
        _on_run,
        inputs=[module_name_tb, code_in, tests_out],
        outputs=[summary_md, full_out],
    )

# Starts the app and opens it in a browser tab.
ui.launch(inbrowser=True)
