In [8]:
import json
import logging
from copy import deepcopy
from typing import Dict, List, Union, Optional
from xml.etree import ElementTree as ET

import openai
import rich

from guardrails.prompt_repo import Prompt
from guardrails.utils.docs_utils import read_pdf
from guardrails.x_schema import XSchema
from guardrails.x_schema import extract_prompt_from_xml

In [3]:
# Location of the XML spec the schema is built from.
PROMPT_XML_PATH = 'guardrails/prompt.xml'

# Instruction text handed to the schema alongside the XML spec.
# `{document}` is a placeholder — presumably filled with the document text
# via `.format(document=...)` before the prompt is sent; confirm against usage.
base_prompt_template = Prompt(
    """Given the following document, answer the following questions. If the answer doesn't exist in the document, enter 'None'.

{document}"""
)

schema = XSchema.from_xml(PROMPT_XML_PATH, base_prompt_template)



In [10]:
# Locate the `string` element named `explanation` that sits under the unnamed
# `object` inside the `list` named `fees` in the parsed XML spec.
explanation_element = schema.parsed_xml.find(
    ".//list[@name='fees']/object/string[@name='explanation']"
)

# `find` returns None when nothing matches; fail here with a readable message
# instead of an AttributeError on `.tag` below.
assert explanation_element is not None, (
    "no string[@name='explanation'] found under list[@name='fees']/object"
)

# Print information about the element
print(f"Element: {explanation_element.tag}")

Element: string


In [11]:
# Walk the schema: `fees` -> its `item` child -> that item's `name` child,
# then show the `on_fail` handler of the first validator attached to it.
fees_item_name_field = schema.schema['fees'].children['item'].children['name']
fees_item_name_field.validators[0].on_fail

<bound method Validator.noop of <guardrails.x_validators.LowerCase object at 0x7fec28919b80>>

In [12]:
# Pretty-print the prompt held by the schema.
schema_prompt = schema.prompt
rich.print(schema_prompt)

In [6]:
# Pretty-print whatever `extract_prompt_from_xml` produces for the parsed XML
# spec. The helper is imported in the notebook's top import cell.
rich.print(extract_prompt_from_xml(schema.parsed_xml))

In [13]:
def get_pruned_tree(
    root: Union[ET.ElementTree, ET.Element],
    reask_elements: Optional[List[ET.Element]] = None,
) -> Union[ET.ElementTree, ET.Element]:
    """Prune `root` in place down to the reask elements and their ancestors.

    Every element that is neither listed in `reask_elements` nor an ancestor
    of a listed element is removed, so a branch containing no reask element
    disappears entirely. The root element itself is never removed, even if
    it ends up with no children.

    Elements are matched by object identity, so `reask_elements` must hold
    element objects taken from this very tree (e.g. via `root.find(...)`).

    Args:
        root: The XML tree, or its root element. Modified in place.
        reask_elements: The elements that are to be reasked. If None, the
            tree is returned untouched.

    Returns:
        The same `root` object that was passed in, after pruning.
    """
    if reask_elements is None:
        return root

    # Accept either a tree wrapper (has `getroot`) or a bare element.
    top = root.getroot() if hasattr(root, "getroot") else root
    if top is None:
        # Tree wrapper without a root element: nothing to prune.
        return root

    reask_ids = {id(element) for element in reask_elements}

    def _prune(node) -> bool:
        """Strip unwanted children of `node`; report whether `node` must stay."""
        keep = id(node) in reask_ids
        # Walk a snapshot of the children because they are removed mid-loop.
        for child in list(node):
            if _prune(child):
                # A kept descendant makes this node an ancestor worth keeping.
                keep = True
            else:
                node.remove(child)
        return keep

    # The return value is ignored on purpose: the root always stays.
    _prune(top)
    return root


In [14]:
# Called without reask elements, `get_pruned_tree` hands the tree back as-is.
unpruned_tree = get_pruned_tree(schema.parsed_xml)
rich.print(unpruned_tree)

In [5]:
# content = read_pdf('chase_card_agreement.pdf')
# validated_response, raw_response = schema.ask_with_validation(content[:6000])

In [7]:
# prompt = schema.prompt.format(document=content[:6000])
# response = schema.llm_ask(prompt)
# response_as_dict = json.loads(response)

# # Save response_as_dict to a file
# with open('response_as_dict.json', 'w') as f:
#     json.dump(response_as_dict, f, indent=4)

In [8]:
# Load the LLM response previously saved as JSON on disk.
with open('response_as_dict.json', 'r') as saved_response_file:
    response_as_dict = json.loads(saved_response_file.read())

In [9]:
# Validate each top-level field of the response against the schema.
# Work on a deep copy so `response_as_dict` stays available for comparison.
validated_response = deepcopy(response_as_dict)

# Loop over a snapshot of the items: the dict is passed to `validate` and the
# name is rebound to its return value on every pass, so iterating the live
# view would raise RuntimeError if a validator added or removed a key on the
# passed dict in place.
for field, value in list(validated_response.items()):
    if field not in schema.schema:
        print(f"Field {field} not in schema")
        continue

    validated_response = schema.schema[field].validate(
        field,
        value,
        validated_response,
    )





Validating annual membership is two words...
Validation outcome: True






Validating my chase plan is two words...
Validation outcome: False






Validating balance transfers is two words...
Validation outcome: True






Validating cash advances is two words...
Validation outcome: True






Validating foreign transactions is two words...
Validation outcome: True






Validating late payment is two words...
Validation outcome: True






Validating over-the-credit-limit is two words...
Validation outcome: False






Validating return payment is two words...
Validation outcome: True






Validating return check is two words...
Validation outcome: True





In [10]:
# Pretty-print the response as it stands after the validation pass.
rich.print(validated_response)

In [19]:
# Pretty-print the response as loaded from JSON (before validation).
rich.print(response_as_dict)

In [9]:
# Plain-text dump of the schema's prompt (no rich formatting).
schema_prompt = schema.prompt
print(schema_prompt)

<prompt>
    <list name="fees" description="What fees and charges are associated with my account?">
        <object>
            <integer name="index" format="1-indexed" />
            <string name="name" format="lower-case; two-words" />
            <string name="explanation" format="one-line" />
            <float name="value" format="percentage" />
            <string name="description" format="length: 0 200" />
            <string name="example" required="True" format="tone-twitter explain-high-quality" />
            <string name="advertisement" format="tagline tv-ad" />
        </object>
    </list>
    <string name="interest_rates" description="What are the interest rates offered by the bank on savings and checking accounts, loans, and credit products?" />
    
</prompt>


In [7]:
# NOTE(review): the only assignment to `raw_response` visible in this notebook
# is inside a commented-out cell (the `schema.ask_with_validation(...)` call),
# so this cell presumably raises NameError on a fresh kernel — confirm.
print(raw_response)



The fees associated with your account include:

1. Maintenance Fee: A fee charged for the maintenance of your account. Value: 0.25%. Description: This fee is charged for the upkeep of your account and is usually charged on a monthly basis. Example: "A $5 monthly maintenance fee is charged for the upkeep of your account." Advertisement: "Keep your money safe with our low maintenance fees!"

2. Transaction Fee: A fee charged for each transaction made with your account. Value: 0.10%. Description: This fee is charged for each transaction made with your account and is usually charged on a per-transaction basis. Example: "A $0.50 fee is charged for each transaction made with your account." Advertisement: "Make transactions with ease and low fees!"

3. ATM Fee: A fee charged for using an ATM. Value: 0.15%. Description: This fee is charged for using an ATM and is usually charged on a per-transaction basis. Example: "A $2 fee is charged for each ATM transaction made with your account." Advert

In [9]:
# Read the API key from a local file. Strip surrounding whitespace so a
# trailing newline left by an editor does not end up inside the key.
# `openai` is imported in the notebook's top import cell.
with open('openai_api_key.txt', 'r') as f:
    openai_api_key = f.read().strip()

# Ask the API which models this key can access.
models_list = openai.Model.list(api_key=openai_api_key)


In [12]:
# Iterating the response object directly yields its top-level keys
# (the recorded output shows `object` and `data`), not the models themselves.
for top_level_key in models_list:
    print(top_level_key)

object
data


In [25]:
# The model entries live under the `data` key; print each entry's `id`.
for model_entry in models_list['data']:
    model_id = model_entry['id']
    print(model_id)

babbage
davinci
gpt-3.5-turbo-0301
text-davinci-003
babbage-code-search-code
text-similarity-babbage-001
text-davinci-001
ada
curie-instruct-beta
babbage-code-search-text
babbage-similarity
gpt-3.5-turbo
code-davinci-002
code-search-babbage-text-001
text-embedding-ada-002
code-cushman-001
whisper-1
code-search-babbage-code-001
audio-transcribe-deprecated
text-ada-001
text-similarity-ada-001
text-davinci-insert-002
ada-code-search-code
ada-similarity
code-search-ada-text-001
text-search-ada-query-001
text-curie-001
text-davinci-edit-001
davinci-search-document
ada-code-search-text
text-search-ada-doc-001
code-davinci-edit-001
davinci-instruct-beta
text-similarity-curie-001
code-search-ada-code-001
ada-search-query
text-search-davinci-query-001
curie-search-query
davinci-search-query
text-davinci-insert-001
babbage-search-document
ada-search-document
text-search-curie-query-001
text-search-babbage-doc-001
text-davinci-002
curie-search-document
text-search-curie-doc-001
babbage-search-que