In [2]:
import asyncio

In [3]:
import asyncio

async def agent(steps: int = 10, delay: float = 1.0):
    """Simulated long-running agent used to demonstrate guardrails.

    Prints a start line, one line per completed step, and an end line.

    Args:
        steps: number of work steps to perform (default 10).
        delay: seconds to sleep before each step completes (default 1.0).

    Returns:
        1, a placeholder result marking that the agent finished.
    """
    print('[agent] start')
    for i in range(steps):
        # Each await is a cancellation point, which is what lets a guardrail stop us.
        await asyncio.sleep(delay)
        print('[agent]', i)
    print('[agent] end')
    return 1


async def guardrail(delay: float = 1.5):
    """Simulated guardrail whose check always passes.

    Args:
        delay: seconds to sleep, standing in for the time a real check takes
            (default 1.5).

    Returns:
        True, meaning the check passed.
    """
    print('[guardrail pass] start')
    await asyncio.sleep(delay)
    print('[guardrail pass] check pass')
    print('[guardrail pass] stop')

    return True

In [4]:
from dataclasses import dataclass

@dataclass(init=True, repr=True, eq=True)
class GuardrailFunctionOutput:
    """Result record describing the outcome of one guardrail check.

    Attributes:
        output_info: short human-readable description of the outcome,
            e.g. 'check passed' or 'check fails'.
        tripwire_triggered: True when the check failed and the agent
            should be stopped.
    """

    output_info: str
    tripwire_triggered: bool

In [5]:
class GuardrailException(Exception):
    """Raised by a guardrail when its check trips.

    The message becomes the exception's ``str()``; the guardrail's
    GuardrailFunctionOutput is kept on ``.info`` so handlers can report why
    the check failed.
    """

    def __init__(self, message: str, info: GuardrailFunctionOutput):
        self.info = info  # outcome record attached for exception handlers
        super().__init__(message)

In [6]:
check_pass = await guardrail()
if check_pass:
    await agent()

[guardrail pass] start
[guardrail pass] check pass
[guardrail pass] stop
[agent] start
[agent] 0
[agent] 1
[agent] 2
[agent] 3
[agent] 4
[agent] 5
[agent] 6
[agent] 7
[agent] 8
[agent] 9
[agent] end


In [7]:
r1, r2 = await asyncio.gather(
    guardrail(),
    agent()
)

[guardrail pass] start
[agent] start
[agent] 0
[guardrail pass] check pass
[guardrail pass] stop
[agent] 1
[agent] 2
[agent] 3
[agent] 4
[agent] 5
[agent] 6
[agent] 7
[agent] 8
[agent] 9
[agent] end


In [8]:
from dataclasses import dataclass

# NOTE(review): this redefines GuardrailFunctionOutput from an earlier cell.
# Instances created before this cell ran belong to the old class and compare
# unequal to instances of this one.
@dataclass()
class GuardrailFunctionOutput:
    """Outcome of one guardrail check: a description plus the tripwire flag."""

    output_info: str          # short text describing the outcome
    tripwire_triggered: bool  # True -> the guardrail wants the agent stopped

In [9]:
# NOTE(review): this redefines GuardrailException from an earlier cell.
# Exceptions created from the old class will not match
# `except GuardrailException` once this cell has run.
class GuardrailException(Exception):
    """Exception a guardrail raises to stop the agent; ``.info`` holds the outcome."""

    def __init__(self, message: str, info: GuardrailFunctionOutput):
        # Store the outcome first; Exception.__init__ only sets args/str().
        self.info = info
        super().__init__(message)


async def guardrail(delay: float = 1.5):
    """Simulated guardrail whose check always passes.

    Args:
        delay: seconds to sleep, standing in for the time a real check takes
            (default 1.5).

    Returns:
        A GuardrailFunctionOutput with ``output_info='check passed'`` and
        ``tripwire_triggered=False``.
    """
    print('[guardrail] start')
    await asyncio.sleep(delay)
    print('[guardrail] stop')

    info = GuardrailFunctionOutput(
        output_info='check passed',
        tripwire_triggered=False
    )

    # Return the outcome record rather than discarding it. It is also truthy,
    # so `if await guardrail():` reads as "passed" (a bare 0 would not).
    return info


async def guardrail_fail(delay: float = 2.5):
    """Simulated guardrail whose check always fails.

    Args:
        delay: seconds to sleep before the check fails (default 2.5).

    Raises:
        GuardrailException: always, with message "check fails" and an
            ``info`` record whose ``tripwire_triggered`` is True.
    """
    print('[guardrail fail] start')
    await asyncio.sleep(delay)
    print('[guardrail fail] check fails')

    info = GuardrailFunctionOutput(
        output_info='check fails',
        tripwire_triggered=True
    )
    # Failure is signalled by raising, so asyncio.gather() surfaces it to the caller.
    raise GuardrailException("check fails", info)

[agent] 2
[agent] 3
[agent] 4
[agent] 5
[agent] 6
[agent] 7
[agent] 8
[agent] 9
[agent] end


In [14]:
agent_task = asyncio.create_task(agent())
guardrail_task = asyncio.create_task(guardrail_fail())

try:
    await asyncio.gather(agent_task, guardrail_task)
    result = await agent_task
    print(result)
except GuardrailException as e:
    print(e.info)
    agent_task.cancel()
    try:
        await agent_task
    except asyncio.CancelledError:
        print("[main] Agent cancelled")

[agent] start
[guardrail fail] start
[agent] 0
[agent] 1
[guardrail fail] check fails
GuardrailFunctionOutput(output_info='check fails', tripwire_triggered=True)
[main] Agent cancelled


In [15]:
# Goal: run the guardrails and the agent in parallel.
# If any guardrail fails, the agent is stopped.
# If all guardrails pass, the agent keeps running to completion.

In [15]:
async def run_with_guardrails(agent_coroutine, guardrails):
    """
    Run `agent_coroutine` while multiple guardrails monitor it.

    The agent and all guardrails run concurrently. A guardrail signals failure
    by raising GuardrailException. When that happens, the agent and any
    still-running guardrails are cancelled and the exception is re-raised.

    Parameters:
        agent_coroutine: an awaitable (coroutine, Task or Future), e.g. agent()
        guardrails: an iterable of awaitables, e.g. [guard1(), guard2()]

    Returns:
        The result of the agent, once the agent and every guardrail have
        finished without any guardrail raising.

    Raises:
        GuardrailException from any guardrail.
        Any other exception from the agent or a guardrail propagates unchanged.
        In every case, no task started here is left running.
    """
    # ensure_future accepts any awaitable; create_task only accepts coroutines.
    agent_task = asyncio.ensure_future(agent_coroutine)
    guard_tasks = []

    try:
        # Scheduled inside the try so that, if wrapping one guardrail fails,
        # the tasks already started are still cleaned up in `finally`.
        for g in guardrails:
            guard_tasks.append(asyncio.ensure_future(g))

        # If any guardrail raises GuardrailException,
        # gather will throw and we drop into except.
        await asyncio.gather(agent_task, *guard_tasks)

        # Agent (and every guardrail) finished successfully.
        return agent_task.result()

    except GuardrailException as e:
        # At least one guardrail fired.
        print("[guardrail fired]", e.info)

        # Cancel the agent and wait for it to unwind.
        agent_task.cancel()
        try:
            await agent_task
        except asyncio.CancelledError:
            print("[run_with_guardrails] agent cancelled")

        raise

    finally:
        # gather() does not cancel siblings when one task fails. Cancel
        # everything still running: other guardrails, or the agent when a
        # non-guardrail error occurred. Cancelling a finished task is a no-op.
        # return_exceptions=True also retrieves any stored exceptions.
        all_tasks = [agent_task, *guard_tasks]
        for t in all_tasks:
            t.cancel()
        await asyncio.gather(*all_tasks, return_exceptions=True)


In [16]:
result = await run_with_guardrails(
    agent(),
    [guardrail()]
)


[agent] start
[guardrail] start
[agent] 0
[guardrail] stop
[agent] 1
[agent] 2
[agent] 3
[agent] 4
[agent] 5
[agent] 6
[agent] 7
[agent] 8
[agent] 9
[agent] end


In [17]:
import ver31
import search_agent
agent = search_agent.create_agent()

In [None]:
result = await run_with_guardrails(
    ver31.run(agent, 'llm as a judge'),
    [guardrail()] #, guardrail_fail()]
)

[guardrail fail] start
[guardrail fail] check fails
[guardrail fired] GuardrailFunctionOutput(output_info='check fails', tripwire_triggered=True)
[run_with_guardrails] agent cancelled


GuardrailException: check fails