In [None]:
from datasets import load_dataset
from invariant.analyzer import Policy
import re
from tests.utils import *

import nest_asyncio
nest_asyncio.apply()

In [None]:
# Load the "xingyaoww/opendevin-code-act" dataset and keep only its "codeact"
# entry (presumably the split name -- confirm against the dataset card).
dataset = load_dataset("xingyaoww/opendevin-code-act")["codeact"]

In [None]:
# Patterns that capture the body of each action tag an assistant message may
# contain. The key doubles as the function name of the generated tool call.
regex = {
    "bash": r'<execute_bash>(.*?)</execute_bash>',
    "ipython": r'<execute_ipython>(.*?)</execute_ipython>',
    "browse": r'<execute_browse>(.*?)</execute_browse>',
}

# Non-assistant messages starting with this prefix are treated as tool outputs.
OBSERVATION_PREFIX = "OBSERVATION:\n\n"

# The policy does not depend on the conversation, so it is parsed once here
# instead of on every loop iteration.
policy = Policy.from_string("""
    from invariant.detectors import python_code
                                
    raise PolicyViolation("found double syntax error") if:
        (call1: ToolCall) -> (call2: ToolCall)
        call1.function.name == "ipython"
        python_code(call1.function.arguments["arg"]).syntax_error
        call2.function.name == "ipython"
        python_code(call2.function.arguments["arg"]).syntax_error
    """)

# Defined up front so the name exists even if no conversation is processed.
trace = []
convs = dataset["conversations"]
for conv in convs[:100]:
    # Start a fresh trace for every conversation. Accumulating messages across
    # conversations would let the policy pair tool calls from unrelated
    # conversations and would reuse the same tool-call ids (message indices).
    trace = []
    last_call_idx = None
    for idx, msg in enumerate(conv):
        content = msg["content"]
        if msg["role"] == "assistant":
            function_name, arg, thought = None, None, None
            # Every pattern is tried; if several action kinds are present,
            # the last one in `regex` order wins.
            for lang, pattern in regex.items():
                match = re.search(pattern, content, re.DOTALL)
                if match is not None:
                    function_name = lang
                    arg = match.group(1)
                    # Text preceding the action tag is kept as the message.
                    thought = content[:match.start()]
            if function_name is None:
                trace.append(assistant(content))
            else:
                # The message index serves as the tool-call id so that the
                # following observation can refer back to this call.
                last_call_idx = str(idx)
                call = tool_call(last_call_idx, function_name, {"arg": arg})
                trace.append(assistant(thought, call))
        elif content.startswith(OBSERVATION_PREFIX):
            # Tool output: strip the prefix and attach it to the latest call.
            trace.append(tool(last_call_idx, content[len(OBSERVATION_PREFIX):]))
        else:
            trace.append(user(content))

    res = policy.analyze(trace)
    if res.errors:
        # Report the first violating conversation and stop scanning.
        print(res)
        for t in trace:
            print(t)
        break
    