# LLM4WCA Test
We evaluate how well OpenAI's GPT-5 model generates worst-case constraints for the Java programs in the [WARP benchmark](https://github.com/dannkoh/warp-benchmark). For each program, we ask the model for the worst-case constraint at input size *n* ∈ {2, 4, 8, 16, 30}, and then use **Z3** to evaluate the syntactic and semantic accuracy of the generated constraints.

In [1]:
import json
import os
import re
from pathlib import Path
from typing import Any, Optional
from pprint import pprint

import openai
from dotenv import load_dotenv

load_dotenv()

# The API key is read from the OPENAI_KEY environment variable, which
# load_dotenv() above may have populated from a .env file.
client = openai.OpenAI(api_key=os.getenv("OPENAI_KEY"))

# Input sizes n to query. A tuple (rather than a set) supports the same
# `n in SIZES` / `for n in SIZES` usage but iterates in a deterministic,
# ascending order, so batch files list their requests in size order.
SIZES = (2, 4, 8, 16, 30)

In [2]:
# Load Java Programs
# java_programs maps each program name (the .java file name without extension)
# to a dict with:
#   "program" -> the Java source text, and
#   n (int, one per size in SIZES) -> {"constants": ..., "assertions": ...},
#   holding the reference SMT-LIB declarations/assertions for that size, or
#   None for both when no matching .smt2 file was found.
java_programs = {}
for file in Path("../../spf-wca/src/examples/custom/").glob("*.java"):
    java_programs[file.name.removesuffix(".java")] = {"program": file.read_text(encoding="utf-8")}


def split_smt_declarations(text):
    """Split SMT-LIB text into its constant declarations and its assertions.

    Only non-blank lines that, after stripping surrounding whitespace, start
    with ``(declare-const`` or ``(assert`` are kept; each group is returned as
    a single newline-joined string. All other lines (e.g. ``(set-logic ...)``,
    ``(check-sat)``) are dropped.

    NOTE(review): an ``(assert ...)`` spanning several lines keeps only its
    first line -- confirm the reference files put each assertion on one line.

    Returns:
        A ``(constants, assertions)`` tuple of strings (empty when none match).
    """
    smt_lines = [line.strip() for line in text.splitlines() if line.strip()]
    constants = "\n".join(line for line in smt_lines if line.startswith("(declare-const"))
    assertions = "\n".join(line for line in smt_lines if line.startswith("(assert"))
    return constants, assertions


# Reference constraint files are named "[custom.]<Program>_<n>.smt2".
pattern_re = re.compile(r"^(?:custom\.)?([^.]+)_([0-9]+)\.smt2$")
for file in Path("../../spf-wca/custom/").rglob("*.smt2"):
    match = pattern_re.match(file.name)
    if not match:
        continue
    problem, n_str = match.groups()
    n = int(n_str)  # cannot fail: the regex only admits ASCII digits here
    if n not in SIZES:
        continue
    if problem not in java_programs:
        raise ValueError(f"{file}: no Java program named {problem!r} was loaded")
    constants, assertions = split_smt_declarations(file.read_text(encoding="utf-8"))
    java_programs[problem][n] = {"constants": constants, "assertions": assertions}

# Sizes without a reference .smt2 file get explicit None placeholders.
for entry in java_programs.values():
    for n in SIZES:
        entry.setdefault(n, {"constants": None, "assertions": None})

In [None]:
def format_prompt(program, n):
    """Build the user prompt asking the model for an SMT-LIB constraint.

    The Java source is embedded in a fenced ```java block, followed by the
    instructions for size ``n``. The model is asked to restrict itself to
    QF_LIA integer arithmetic and to prefix its final constraint with
    'Answer:'. The returned text ends with a single newline.

    NOTE(review): the wording asks for the "worst-case time complexity" and
    then "the corresponding constraint"; it is kept verbatim because collected
    results depend on this exact prompt.
    """
    prompt_lines = [
        "```java",
        f"{program}",
        "```",
        f"Given the Java program above, determine its worst-case time complexity for input size n={n}.",
        "",
        "Express the corresponding constraint in SMT-LIB v2 format using only the standard QF_LIA (Quantifier-Free Linear Integer Arithmetic) theory.",
        "- Use only integer types (`Int`) and standard SMT-LIB v2 arithmetic functions, including division and modulo.",
        "- Do not use bitvectors, custom functions, or any non-standard theories or types.",
        "",
        "At the end, provide the final SMT-LIB v2 constraint, starting with 'Answer:' on a new line.",
    ]
    return "\n".join(prompt_lines) + "\n"

In [None]:
model = "gpt-5-2025-08-07"
endpoint = "/v1/chat/completions"

# Build one JSONL batch file per Java program -- one chat-completion request per
# input size n, tagged with custom_id=str(n) so responses can later be matched
# back to their size -- and upload it to OpenAI with purpose="batch".
os.makedirs("batch", exist_ok=True)  # open(..., "w") below fails if the directory is missing
for prog_name, data in java_programs.items():
    batch = [
        {
            "custom_id": str(n),
            "method": "POST",
            "url": endpoint,  # keep request URL and batch endpoint in sync
            "body": {
                "model": model,
                "messages": [{"role": "user", "content": format_prompt(data["program"], n)}],
            },
        }
        for n in SIZES
    ]
    batch_path = f"batch/{prog_name}-batch.jsonl"
    with open(batch_path, "w", encoding="utf-8") as f:
        for item in batch:
            json.dump(item, f)
            f.write("\n")

    # Upload via a context manager so the file handle is closed afterwards.
    with open(batch_path, "rb") as f:
        client.files.create(file=f, purpose="batch")


In [None]:
# Submit one batch job per Java program. Only files named "<program>-batch.jsonl"
# for a program in java_programs are considered, and if a program's file was
# uploaded more than once, only its most recent upload is submitted. This avoids
# paying for duplicate jobs or for unrelated "batch" files in the account.
expected_filenames = {f"{prog_name}-batch.jsonl" for prog_name in java_programs}
latest_uploads = {}
for file in client.files.list():
    if file.purpose != "batch" or file.filename not in expected_filenames:
        continue
    newest = latest_uploads.get(file.filename)
    if newest is None or file.created_at > newest.created_at:
        latest_uploads[file.filename] = file

for file in latest_uploads.values():
    client.batches.create(
        completion_window="24h",
        endpoint=endpoint,
        input_file_id=file.id,
    )

In [None]:
# Print every batch job created at or after the experiment's cutoff
# (Unix timestamp 1756138709); older jobs are skipped.
for batch in client.batches.list():
    if batch.created_at < 1756138709:
        continue
    print(batch)

In [3]:
import json

# Store each model response as java_programs[prog_name][n]["response"], reading
# the output files of batch jobs created at/after the cutoff timestamp.
for job in client.batches.list():
    if job.created_at < 1756138709:
        continue
    if job.output_file_id is None:
        # No output file yet (e.g. the batch is still running or failed).
        print(f"Skipping batch {job.id}: status={job.status}, no output file")
        continue
    prog_name = client.files.retrieve(job.input_file_id).filename.removesuffix("-batch.jsonl")
    if prog_name not in java_programs:
        print(f"Skipping batch {job.id}: unknown program {prog_name!r}")
        continue

    result = client.files.content(job.output_file_id).content.decode("utf-8")
    for line in result.splitlines():
        if not line.strip():
            continue
        item = json.loads(line)
        response = item.get("response")
        if not response or response.get("status_code") != 200:
            print(f"Skipping {prog_name} n={item.get('custom_id')}: request failed ({item.get('error')})")
            continue
        content = response["body"]["choices"][0]["message"]["content"]
        # custom_id was written as str(n) when the batch file was built.
        java_programs[prog_name][int(item["custom_id"])]["response"] = (content or "").strip()

In [11]:
print(java_programs.keys())

results_dir = Path("results")
results_dir.mkdir(exist_ok=True)  # open(..., "w") fails if the directory is missing

# json.dump turns the integer size keys (2, 4, ...) into the strings "2", "4", ...
with (results_dir / "java_programs.dict").open("w", encoding="utf-8") as f:
    json.dump(java_programs, f, indent=2)

# A minimal valid nbformat-4 notebook with no cells. A zero-byte .ipynb file is
# not valid notebook JSON.
empty_notebook = {"cells": [], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
for key in java_programs:
    notebook_path = results_dir / f"{key}.ipynb"
    # Never overwrite: an existing notebook may already contain this program's analysis.
    if not notebook_path.exists():
        notebook_path.write_text(json.dumps(empty_notebook, indent=1) + "\n", encoding="utf-8")


dict_keys(['RampUp', 'ComplexStateMachineParser', 'QuickSort', 'NaiveFibonacci', 'Collatz', 'KnapsackSolver', 'MergeSort', 'BinaryTreeSearch', 'MazeSolver', 'SortedListInsert', 'CaseFlipper', 'Dijkstra', 'TowerOfHanoi', 'BinarySearch', 'ArrayTwister', 'BubbleSort', 'BinarySearchTreeHeight', 'SubarraySumFinder', 'GreedyStepper', 'DizzyRamp'])
