In [26]:
import copy
import logging
import json
from collections import defaultdict
from typing import Union, Dict, List, Optional, Any
from lxml import etree as ET

import rich

from guardrails.x_datatypes import registry as types_registry

from guardrails.x_schema import XSchema
from guardrails.prompt_repo import Prompt
from guardrails.utils.docs_utils import read_pdf

In [2]:
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [19]:
base_prompt_template = Prompt("""Given the following document, answer the following questions. If the answer doesn't exist in the document, enter 'None'.

{document}""")

schema = XSchema.from_xml('guardrails/prompt.xml', base_prompt_template)



In [20]:
# # Get all elements in `schema.parsed_xml`
# elements = schema.parsed_xml.findall(".//*")
# # Print information about each element
# for element in elements:
#     print(f"Element: {element.tag}")
#     print(f"Attributes: {element.attrib}")
#     print(f"Text: {element.text}")
#     print()

# explanation_element = schema.parsed_xml.find(".//list[@name='fees']/object")
explanation_element = schema.parsed_xml.find(".//list[@name='fees']/object/string[@name='explanation']")

# Get the element that has the property `name` with the value `explanation`
# and is the descendant of elements `list` with name 'fees' and `object` (no name)
explanation_element = schema.parsed_xml.find(".//list[@name='fees']/object/string[@name='explanation']")

# Print information about the element
print(f"Element: {explanation_element.tag}")

Element: string


In [21]:
obj = {
    "fees": [
        {
            "index": 1,
            "name": "annual membership",
            "explanation": "annual fee for membership",
            "value": 0,
            "description": None,
            "example": None,
            "advertisement": None,
        },
        {
            "index": 2,
            "name": "my chase plan",
            "explanation": None,
            "value": 1.72,
            "description": """Monthly fee of 0% of the amount of each eligible purchase transaction or amount selected to 
create a My Chase Plan while in the 0% Intro Purchase APR period. After that, monthly fee of 1.72% of the amount of
each eligible purchase transaction or amount selected to create a My Chase Plan. The My Chase Plan Fee will be 
determined at the time each My Chase Plan is created and will remain the same until the My Chase Plan is paid in 
full.""",
            "example": None,
            "advertisement": None,
        },
        {
            "index": 3,
            "name": "balance transfers",
            "explanation": "intro fee of either $5 or 3% of the amount of each transfer, whichever is greater",
            "value": 3,
            "description": """Intro fee of either $5 or 3% of the amount of each transfer, whichever is greater, on 
transfers made within 60 days of account opening. After that: Either $5 or 5% of the amount of each transfer, 
whichever is greater.""",
            "example": "Transfer $100 and pay a fee of $3.",
            "advertisement": "Transfer your balance and pay only a 3% fee!",
        },
        {
            "index": 4,
            "name": "cash advances",
            "explanation": "either $10 or 5% of the amount of each transaction, whichever is greater",
            "value": 5,
            "description": "Either $10 or 5% of the amount of each transaction, whichever is greater.",
            "example": "Withdraw $100 and pay a fee of $5.",
            "advertisement": "Withdraw cash and pay only a 5% fee!",
        },
        {
            "index": 5,
            "name": "foreign transactions",
            "explanation": "3% of the amount of each transaction in U.S. dollars",
            "value": 3,
            "description": None,
            "example": "Make a purchase of $100 and pay a fee of $3.",
            "advertisement": "Make purchases abroad and pay only a 3% fee!",
        },
        {
            "index": 6,
            "name": "late payment",
            "explanation": "up to $40",
            "value": 0,
            "description": "Up to $40.",
            "example": "Make a late payment and pay a fee of up to $40.",
            "advertisement": "Make a late payment and pay only up to $40!",
        },
        {
            "index": 7,
            "name": "over-the-credit-limit",
            "explanation": None,
            "value": 0,
            "description": None,
            "example": None,
            "advertisement": None,
        },
        {
            "index": 8,
            "name": "return payment",
            "explanation": "up to $40",
            "value": 0,
            "description": "Up to $40.",
            "example": "Make a return payment and pay a fee of up to $40.",
            "advertisement": "Make a return payment and pay only up to $40!",
        },
        {
            "index": 9,
            "name": "return check",
            "explanation": None,
            "value": 0,
            "description": None,
            "example": None,
            "advertisement": None,
        },
    ],
    "interest_rates": None,
}

nones = []


def _gather_reasks(response: Union[list, dict], path: List[str] = []):
    if isinstance(response, dict):
        iterable = response.items()
    elif isinstance(response, list):
        iterable = enumerate(response)
    else:
        raise ValueError(f"Expected dict or list, got {type(response)}")
    for field, value in iterable:
        if value is None:
            nones.append((path + [field], value))

        if isinstance(value, dict):
            _gather_reasks(value, path + [field])

        if isinstance(value, list):
            for idx, item in enumerate(value):
                if item is None:
                    nones.append((path + [field, idx], item))
                else:
                    _gather_reasks(item, path + [field, idx])

_gather_reasks(obj)
rich.print(nones)

In [22]:
rich.print(nones)

In [84]:
parsed_xml_copy = copy.deepcopy(schema.parsed_xml)

In [85]:
"""
[
    (['fees', 0, 'description'], None),
    (['fees', 0, 'example'], None),
    (['fees', 0, 'advertisement'], None),
    (['fees', 1, 'explanation'], None),
    (['fees', 1, 'example'], None),
    (['fees', 1, 'advertisement'], None),
    (['fees', 4, 'description'], None),
    (['fees', 6, 'explanation'], None),
    (['fees', 6, 'description'], None),
    (['fees', 6, 'example'], None),
    (['fees', 6, 'advertisement'], None),
    (['fees', 8, 'explanation'], None),
    (['fees', 8, 'description'], None),
    (['fees', 8, 'example'], None),
    (['fees', 8, 'advertisement'], None),
    (['interest_rates'], None)
]
"""

# explanation
things = [
    (['fees', 0, 'description'], None),
    (['fees', 0, 'example'], None),
    (['fees', 0, 'advertisement'], None),
    (['fees', 1, 'explanation'], None),
    (['fees', 1, 'example'], None),
    (['fees', 1, 'advertisement'], None),
    (['fees', 4, 'description'], None),
    (['fees', 6, 'explanation'], None),
    (['fees', 6, 'description'], None),
    (['fees', 6, 'example'], None),
    (['fees', 6, 'advertisement'], None),
    (['fees', 8, 'explanation'], None),
    (['fees', 8, 'description'], None),
    (['fees', 8, 'example'], None),
    (['fees', 8, 'advertisement'], None),
    (['interest_rates'], None)
]

nones_by_element = defaultdict(list)

for path, value in things:
    # Make a find query for each path: replace int values in path with '*'
    query = "."
    for part in path:
        if isinstance(part, int):
            query += "/*"
        else:
            query += f"/*[@name='{part}']"

    # Find the element
    element = parsed_xml_copy.find(query)
    # Print the element
    nones_by_element[element].append((path, value))

rich.print(dict(nones_by_element))

In [86]:
def get_pruned_tree(
    root: ET._Element,
    reask_elements: List[ET._Element] = None,
) -> str:
    """Prune tree of any elements that are not in `reasks`.

    Return the tree with only the elements that are keys of `reasks` and their
    parents. If `reasks` is None, return the entire tree. If an element is
    removed, remove all ancestors that have no children.

    Args:
        root: The XML tree.
        reasks: The elements that are to be reasked.

    Returns:
        The prompt.
    """
    if reask_elements is None:
        return root

    # Get all elements in `root`
    elements = root.findall(".//*")
    for element in elements:
        if (element not in reask_elements) and len(element) == 0:
            parent = element.getparent()
            parent.remove(element)

            # Remove all ancestors that have no children
            while len(parent) == 0:
                grandparent = parent.getparent()
                grandparent.remove(parent)
                parent = grandparent

    return root

pruned_tree = get_pruned_tree(parsed_xml_copy, list(nones_by_element.keys()))


 /response/list fees
False
False

 /response/list/object None
False
False

 /response/list/object/integer index
False
True
element: <Element integer at 0x7fef28ad9b00>
parent: <Element object at 0x7fef28ad9800>
Pre removal parent len: 7
Post removal parent len: 6

 /response/list/object/string[1] name
False
True
element: <Element string at 0x7fef28ad9380>
parent: <Element object at 0x7fef28ad9800>
Pre removal parent len: 6
Post removal parent len: 5

 /response/list/object/string[1] explanation
True
True

 /response/list/object/float value
False
True
element: <Element float at 0x7fef28ad9240>
parent: <Element object at 0x7fef28ad9800>
Pre removal parent len: 5
Post removal parent len: 4

 /response/list/object/string[2] description
True
True

 /response/list/object/string[3] example
True
True

 /response/list/object/string[4] advertisement
True
True

 /response/string[1] interest_rates
True
True

 /response/string[2] follow_up_url
False
True
element: <Element string at 0x7fef28ad9e00>

In [97]:
list(list(list(parsed_xml_copy)[0])[0])[3].attrib

{'name': 'advertisement', 'format': 'tagline tv-ad'}

In [79]:
len(list(nones_by_element.keys())[0])

0

In [82]:
list(list(list(schema.parsed_xml)[0])[0])[4] in list(nones_by_element.keys())

True

In [53]:
dir(element)

['__bool__',
 '__class__',
 '__contains__',
 '__copy__',
 '__deepcopy__',
 '__delattr__',
 '__delitem__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_init',
 'addnext',
 'addprevious',
 'append',
 'attrib',
 'base',
 'clear',
 'cssselect',
 'extend',
 'find',
 'findall',
 'findtext',
 'get',
 'getchildren',
 'getiterator',
 'getnext',
 'getparent',
 'getprevious',
 'getroottree',
 'index',
 'insert',
 'items',
 'iter',
 'iterancestors',
 'iterchildren',
 'iterdescendants',
 'iterfind',
 'itersiblings',
 'itertext',
 'keys',
 'makeelement',
 'nsmap',
 'prefix',
 'remove',
 'replace',
 'set',
 'sourceline',
 'tag',
 'tail',
 'text',
 'values',
 'xpath']

In [56]:
element.getroottree().getpath(element)

'/response/string[1]'

In [40]:
list(pruned_tree)

[<!-- <string name='fees' description='What fees and charges are associated with my account?' format="max-len: 5; explain-like-im-five; valid-choices: {[0,5,10]}"/> -->]

In [36]:
list(list(list(list(schema.parsed_xml)[0])[0])[0])

[]

In [24]:
rich.print(
    schema.schema,
)

In [25]:
schema.schema['fees'].validators, schema.schema['fees'].children['item'].children['name'].validators, schema.schema['interest_rates'].children

([],
 [<guardrails.x_validators.LowerCase at 0x7fa4989c7a00>,
  <guardrails.x_validators.TwoWords at 0x7fa4989c7190>],
 {})

In [26]:
print(schema.prompt)

Given the following document, answer the following questions. If the answer doesn't exist in the document, enter 'None'.

{document}

Given below is XML that describes the information to extract from this document and the tags to extract it into.

<prompt>
    <list name="fees" description="What fees and charges are associated with my account?">
        <object>
            <integer name="index" format="1-indexed" />
            <string name="name" format="lower-case; two-words" />
            <string name="explanation" format="one-line" />
            <float name="value" format="percentage" />
            <string name="description" format="length: 0 200" />
            <string name="example" required="True" format="tone-twitter explain-high-quality" />
            <string name="advertisement" format="tagline tv-ad" />
        </object>
    </list>
    <string name="interest_rates" description="What are the interest rates offered by the bank on savings and checking accounts, loans, and

In [27]:
content = read_pdf('chase_card_agreement.pdf')

8763142813423

In [28]:
response, response_as_dict, validated_response = schema.ask_with_validation(content[:6000])

KeyboardInterrupt: 

In [None]:
rich.print(response), rich.print(response_as_dict), rich.print(validated_response)

(None, None, None)

In [6]:
# prompt = schema.prompt.format(document=content[:6000])
# response = schema.llm_ask(prompt)
# response_as_dict = json.loads(response)

# # Save response_as_dict to a file
# with open('response_as_dict.json', 'w') as f:
#     json.dump(response_as_dict, f, indent=4)

In [7]:
with open('response_as_dict.json', 'r') as f:
    response_as_dict = json.load(f)


In [14]:
for field, value in response_as_dict.items():

    print(f"Field: {field}")

    if field not in schema.schema:
        print(f"Field {field} not in schema")
        continue

    print(f"Field {field} is in schema")
    outcome = schema.schema[field].validate(value)
    print(f"Outcome: {outcome}")


DEBUG:guardrails.x_validators:Validating annual membership is lower case...
DEBUG:guardrails.x_validators:Validation outcome: True
DEBUG:guardrails.x_validators:Validating annual membership is two words...
DEBUG:guardrails.x_validators:Validation outcome: True
DEBUG:guardrails.x_validators:Validating annual fee for membership is a single line...
DEBUG:guardrails.x_validators:Validation outcome: True
DEBUG:guardrails.x_validators:Validating None is in length range 0 - 200...
DEBUG:guardrails.x_validators:Value None is in range 0 - 200.
DEBUG:guardrails.x_validators:Validating my chase plan is lower case...
DEBUG:guardrails.x_validators:Validation outcome: True
DEBUG:guardrails.x_validators:Validating my chase plan is two words...
DEBUG:guardrails.x_validators:Validation outcome: False


Field: fees
Field fees is in schema
Outcome: False
Field: interest_rates
Field interest_rates is in schema
Outcome: True


In [9]:
print(schema.prompt)

<prompt>
    <list name="fees" description="What fees and charges are associated with my account?">
        <object>
            <integer name="index" format="1-indexed" />
            <string name="name" format="lower-case; two-words" />
            <string name="explanation" format="one-line" />
            <float name="value" format="percentage" />
            <string name="description" format="length: 0 200" />
            <string name="example" required="True" format="tone-twitter explain-high-quality" />
            <string name="advertisement" format="tagline tv-ad" />
        </object>
    </list>
    <string name="interest_rates" description="What are the interest rates offered by the bank on savings and checking accounts, loans, and credit products?" />
    
</prompt>


In [7]:
print(raw_response)



The fees associated with your account include:

1. Maintenance Fee: A fee charged for the maintenance of your account. Value: 0.25%. Description: This fee is charged for the upkeep of your account and is usually charged on a monthly basis. Example: "A $5 monthly maintenance fee is charged for the upkeep of your account." Advertisement: "Keep your money safe with our low maintenance fees!"

2. Transaction Fee: A fee charged for each transaction made with your account. Value: 0.10%. Description: This fee is charged for each transaction made with your account and is usually charged on a per-transaction basis. Example: "A $0.50 fee is charged for each transaction made with your account." Advertisement: "Make transactions with ease and low fees!"

3. ATM Fee: A fee charged for using an ATM. Value: 0.15%. Description: This fee is charged for using an ATM and is usually charged on a per-transaction basis. Example: "A $2 fee is charged for each ATM transaction made with your account." Advert

In [9]:
import openai

with open('openai_api_key.txt', 'r') as f:
    openai_api_key = f.read()

models_list = openai.Model.list(api_key=openai_api_key)


In [12]:
for model in models_list:
    print(model)

object
data


In [25]:
for model in models_list['data']:
    print(model['id'])

babbage
davinci
gpt-3.5-turbo-0301
text-davinci-003
babbage-code-search-code
text-similarity-babbage-001
text-davinci-001
ada
curie-instruct-beta
babbage-code-search-text
babbage-similarity
gpt-3.5-turbo
code-davinci-002
code-search-babbage-text-001
text-embedding-ada-002
code-cushman-001
whisper-1
code-search-babbage-code-001
audio-transcribe-deprecated
text-ada-001
text-similarity-ada-001
text-davinci-insert-002
ada-code-search-code
ada-similarity
code-search-ada-text-001
text-search-ada-query-001
text-curie-001
text-davinci-edit-001
davinci-search-document
ada-code-search-text
text-search-ada-doc-001
code-davinci-edit-001
davinci-instruct-beta
text-similarity-curie-001
code-search-ada-code-001
ada-search-query
text-search-davinci-query-001
curie-search-query
davinci-search-query
text-davinci-insert-001
babbage-search-document
ada-search-document
text-search-curie-query-001
text-search-babbage-doc-001
text-davinci-002
curie-search-document
text-search-curie-doc-001
babbage-search-que

In [18]:
from dataclasses import dataclass

@dataclass
class Foo(BaseException):

    key: str
    value: list
    error: BaseException = None
    
    def __str__(self) -> str:
        return f"Error {self.error} with key {self.key} and value {self.value}."

Foo('key', [1, 2, 3])

Foo(key='key', value=[1, 2, 3], error=None)

In [19]:
raise Foo('key', [1, 2, 3], ValueError('My error'))

TypeError: exceptions must derive from BaseException