```python
!guardrails hub install hub://guardrails/valid_choices --quiet
```

```
Installing hub://guardrails/valid_choices...
✅Successfully installed guardrails/valid_choices!
```




# Enforcing Guardrails on Choice Selection

!!! note
    To download this tutorial as a Jupyter notebook, click [here](https://github.com/guardrails-ai/guardrails/blob/main/docs/examples/select_choice_based_on_action.ipynb).

In this example, we want the LLM to pick an action (`fight` or `flight`) and, depending on that action, return a different JSON object. If the action is `fight`, the object should contain a `weapon` field. If the action is `flight`, it should contain `flight_direction` and `distance` fields.

We make the assumption that:

1. We don't need any external libraries that are not already installed in the environment.
2. We are able to execute the code in the environment.

## Objective

We want the LLM to play a role-playing (RP) game in which it chooses either to `fight` or to take `flight`. If it chooses `fight`, it should also pick a `weapon`. If it chooses `flight`, it should pick a `flight_direction` and a `distance`.
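To make the objective concrete, here is a minimal plain-Python sketch of the rules the spec in Step 1 encodes. It is only an illustration, not how Guardrails performs validation, and `ALLOWED` and `check_action` are hypothetical names. The `chosen_action` value selects a case, and every field in that case must take one of its allowed values.

```python
# Hypothetical plain-Python version of the rules in the spec (not the Guardrails API).
ALLOWED = {
    "fight": {"weapon": ["crossbow", "machine gun"]},
    "flight": {
        "flight_direction": ["north", "south", "east", "west"],
        "distance": [1, 2, 3, 4],
    },
}


def check_action(action: dict) -> list:
    """Return a list of error messages; an empty list means the action is valid."""
    case = action.get("chosen_action")
    if case not in ALLOWED:
        return [f"unknown chosen_action: {case!r}"]
    errors = []
    # Each field of the selected case must hold one of its allowed values.
    for field, choices in ALLOWED[case].items():
        value = action.get(field)
        if value not in choices:
            errors.append(f"{field}={value!r} is not one of {choices}")
    return errors


print(check_action({"chosen_action": "flight", "flight_direction": "north", "distance": 2}))  # []
print(check_action({"chosen_action": "fight", "weapon": "sword"}))
```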

## Step 1: Generate the `RAIL` Spec

Ordinarily, we would write the `RAIL` spec in a separate file. For this example, we instead define the output spec directly in the notebook, either as a `RAIL` XML string or as a Pydantic model.

XML option:

```python
rail_str = """
<rail version="0.1">

<output>
    <choice name="action" on-fail-choice="reask" discriminator="chosen_action">
        <case name="fight">
            <string name="weapon" format="valid-choices: {['crossbow', 'machine gun']}" on-fail-valid-choices="reask" />
        </case>
        <case name="flight">
            <string name="flight_direction" format="valid-choices: {['north','south','east','west']}" on-fail-valid-choices="exception" />
            <integer name="distance" format="valid-choices: {[1,2,3,4]}" on-fail-valid-choices="exception" />
        </case>
    </choice>
</output>
<messages>
<message role="user">
You are a human in an enchanted forest. You come across opponents of different types, and you should fight smaller opponents and run away from bigger ones.

You run into a ${opp_type}. What do you do?

${gr.complete_xml_suffix_v2}
</message>
</messages>
</rail>
"""
```

Pydantic model option:

```python
from guardrails.hub import ValidChoices
from pydantic import BaseModel, Field
from typing import Literal, Union

prompt = """
You are a human in an enchanted forest. You come across opponents of different types, and you should fight smaller opponents and run away from bigger ones.

You run into a ${opp_type}. What do you do?

${gr.complete_xml_suffix_v2}"""

class Fight(BaseModel):
    chosen_action: Literal['fight']
    weapon: str = Field(validators=[ValidChoices(['crossbow', 'machine gun'], on_fail="reask")])

class Flight(BaseModel):
    chosen_action: Literal['flight']
    flight_direction: str = Field(validators=[ValidChoices(['north','south','east','west'], on_fail="exception")])
    distance: int = Field(validators=[ValidChoices([1,2,3,4], on_fail="exception")])

class FightOrFlight(BaseModel):
    action: Union[Fight, Flight] = Field(discriminator='chosen_action')
```

## Step 2: Create a `Guard` object from the spec

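We create a `gd.Guard` object that will check, validate, and correct the LLM's output. This object:

1. Enforces the quality criteria specified in the spec (here, that the output matches the case selected by `chosen_action` and that each field holds one of its allowed values).
2. Takes corrective action when the quality criteria are not met, as set by each field's on-fail setting (re-asking the LLM for `weapon`, raising an exception for `flight_direction` and `distance`).
3. Compiles the schema and type info from the spec and adds it to the prompt.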
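The validate-then-correct behavior in the list above can be sketched in plain Python. This is an illustration under stated assumptions, not the Guardrails implementation: `guarded_call`, `ValidationFailed`, and the `(message, action)` failure format are hypothetical, and this sketch simply raises once its re-ask budget is used up.

```python
from typing import Callable, List, Tuple

# Hypothetical failure format: (error message, on-fail action), where the
# action is either "reask" or "exception".
Failure = Tuple[str, str]


class ValidationFailed(Exception):
    """Raised when the output cannot be made valid (hypothetical)."""


def guarded_call(
    call_llm: Callable[[str], dict],
    validate: Callable[[dict], List[Failure]],
    prompt: str,
    max_reasks: int = 1,
) -> dict:
    """Call the LLM, validate its output, and re-ask or raise on failure."""
    output = call_llm(prompt)
    for attempt in range(max_reasks + 1):
        failures = validate(output)
        if not failures:
            return output  # every check passed
        # A failure whose on-fail action is "exception" stops immediately.
        for message, action in failures:
            if action == "exception":
                raise ValidationFailed(message)
        # Otherwise re-ask: send the errors back to the LLM, if budget remains.
        if attempt < max_reasks:
            feedback = "\n".join(message for message, _ in failures)
            output = call_llm(f"{prompt}\n\nFix these errors:\n{feedback}")
    raise ValidationFailed("output still invalid after re-asking")


# Usage with a fake LLM that answers badly once, then correctly.
replies = iter([{"weapon": "sword"}, {"weapon": "crossbow"}])


def fake_llm(prompt: str) -> dict:
    return next(replies)


def validate_weapon(output: dict) -> List[Failure]:
    if output.get("weapon") in ("crossbow", "machine gun"):
        return []
    return [(f"weapon={output.get('weapon')!r} is not a valid choice", "reask")]


print(guarded_call(fake_llm, validate_weapon, "You run into a goblin."))  # {'weapon': 'crossbow'}
```

Separating "reask" from "exception" mirrors the spec above: a bad `weapon` is worth another attempt, while a bad `flight_direction` or `distance` aborts the call.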

From XML:

```python
import guardrails as gd

from rich import print

guard = gd.Guard.for_rail_string(rail_str)
```

Or from Pydantic:

```python
import guardrails as gd

from rich import print

guard = gd.Guard.for_pydantic(output_class=FightOrFlight)
```

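Either way, the `Guard` object compiles the output schema and adds it to the prompt. Because both cells assign to `guard`, the one created last (from the Pydantic model) is the one used in the calls below, together with the `prompt` string defined alongside that model.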
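As a rough sketch of what "adds it to the prompt" means, the snippet below renders a simplified description of the two cases as JSON and substitutes it for the `${gr.complete_xml_suffix_v2}` placeholder. The suffix Guardrails actually generates is different (the notebook prints the real prompt in Step 3); `SCHEMA_SKETCH` and `add_schema_to_prompt` are hypothetical names.

```python
import json

# Hypothetical, simplified stand-in for the schema the Guard compiles.
SCHEMA_SKETCH = {
    "action": {
        "discriminator": "chosen_action",
        "cases": {
            "fight": {"weapon": ["crossbow", "machine gun"]},
            "flight": {
                "flight_direction": ["north", "south", "east", "west"],
                "distance": [1, 2, 3, 4],
            },
        },
    }
}

PLACEHOLDER = "${gr.complete_xml_suffix_v2}"


def add_schema_to_prompt(template_text: str, schema: dict) -> str:
    """Replace the placeholder with output instructions built from `schema`."""
    instructions = (
        "Respond only with JSON that matches this schema:\n"
        + json.dumps(schema, indent=2)
    )
    return template_text.replace(PLACEHOLDER, instructions)


example = add_schema_to_prompt(
    "You run into a goblin. What do you do?\n\n" + PLACEHOLDER, SCHEMA_SKETCH
)
print(example)
```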

## Step 3: Wrap the LLM API call with `Guard`

We can now wrap the LLM API call with the `Guard` object. The `Guard` sends the prompt to the LLM, validates the output against the spec, and applies the configured on-fail action (re-ask or exception) when a check fails.

To start, we test with a `giant` as an opponent and look at the output.
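The `${opp_type}` placeholder in the prompt is filled from `prompt_params`. Its `${name}` syntax matches Python's `string.Template`, so the substitution can be sketched with the standard library; treat this as an illustration of the idea rather than a description of Guardrails' internals. `safe_substitute` leaves placeholders it cannot parse, such as the dotted `${gr.complete_xml_suffix_v2}`, untouched, whereas `substitute` would raise a `ValueError` on it.

```python
from string import Template

# A cut-down copy of the notebook's prompt, named template_text so that it
# does not overwrite the notebook's own `prompt` variable.
template_text = """You run into a ${opp_type}. What do you do?

${gr.complete_xml_suffix_v2}"""

# Fill only the placeholders we have values for; the dotted one is left as-is.
filled = Template(template_text).safe_substitute(opp_type="giant")
print(filled)
```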

```python
# Set your OPENAI_API_KEY as an environment variable
# import os
# os.environ["OPENAI_API_KEY"] = "YOUR_API_KEY"

raw_llm_response, validated_response, *rest = guard(
    messages=[{"role":"user", "content": prompt}],
    prompt_params={'opp_type': 'giant'},
    model="gpt-4o-mini",
    max_tokens=256,
    temperature=0.0,
)
```



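Running the cell above unpacks the guard's return value into:

1. `raw_llm_response`: the raw LLM text output as a single string.
2. `validated_response`: the validated output as a dictionary that follows the `FightOrFlight` schema, with the chosen case under the `action` key.

Any remaining fields of the return value are collected into `rest`.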
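The starred assignment works with any iterable return value. The sketch below uses a hypothetical `FakeOutcome` named tuple in place of the real return object: its first two fields mirror the list above, while the remaining field names and the sample data are assumptions for illustration. Different variable names are used so the notebook's own `raw_llm_response` and `validated_response` are not overwritten.

```python
from typing import Any, NamedTuple, Optional


class FakeOutcome(NamedTuple):
    """Hypothetical stand-in for the object a guard call returns."""
    raw_llm_output: str
    validated_output: Optional[dict]
    reask: Any = None
    validation_passed: bool = True
    error: Optional[str] = None


outcome = FakeOutcome(
    raw_llm_output='{"action": {"chosen_action": "fight", "weapon": "crossbow"}}',
    validated_output={"action": {"chosen_action": "fight", "weapon": "crossbow"}},
)

# Same pattern as the notebook: the first two fields get names and the
# starred target collects everything else into a list.
raw, validated, *extra = outcome
print(validated["action"]["weapon"])  # crossbow
print(extra)  # [None, True, None]
```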

We can see that if the LLM chooses `flight`, the output is a dictionary with `flight_direction` and `distance` fields.
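Because `chosen_action` is part of every valid result, downstream code can branch on it to know which fields are present. A minimal sketch, using a hypothetical `describe_action` helper and made-up sample dictionaries rather than actual model output:

```python
def describe_action(validated: dict) -> str:
    """Turn a validated fight-or-flight result into a short sentence."""
    action = validated["action"]
    if action["chosen_action"] == "fight":
        return f"Fight with a {action['weapon']}."
    # Otherwise the discriminator says this is the flight case.
    return f"Flee {action['flight_direction']}, distance {action['distance']}."


sample_fight = {"action": {"chosen_action": "fight", "weapon": "crossbow"}}
sample_flight = {
    "action": {"chosen_action": "flight", "flight_direction": "north", "distance": 3}
}
print(describe_action(sample_fight))   # Fight with a crossbow.
print(describe_action(sample_flight))  # Flee north, distance 3.
```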

```python
print(validated_response)
```

We can also print the final prompt that was sent to the LLM:

```python
print(guard.history.last.iterations.first.inputs.messages[0]["content"])
```

We can inspect the guard's history to see which validators ran and which corrective actions, if any, were taken.

```python
print(guard.history.last.tree)
```

Now, let's test with a `goblin` as an opponent.

We can see that the LLM chose to `fight`, so the output contains a `weapon` field.

```python
raw_llm_response, validated_response, *rest = guard(
    messages=[{"role":"user", "content": prompt}],
    prompt_params={'opp_type': 'goblin'},
    model="gpt-4o-mini",
    max_tokens=256,
    temperature=0.0,
)
```



```python
print(validated_response)
```

As before, we can inspect the guard's history after each call to see which validators ran and whether any re-asks occurred.

```python
print(guard.history.last.tree)
```