<a href="https://colab.research.google.com/github/jlonge4/gen_ai_utils/blob/main/custom_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install boto3 langchain

In [117]:
import boto3
bedrock = boto3.client('bedrock-runtime')
from langchain.llms import Bedrock
model_kwargs =  {
    "max_tokens_to_sample": 8191,
    "temperature": 0.5,
    "top_k": 500,
    "top_p": 1,
}

llm = Bedrock(client=bedrock, model_id='anthropic.claude-v2', model_kwargs=model_kwargs, verbose=False)

In [118]:
tools = []
tools.append("""Name: calculator_tool(). Description: Use this tool when you need to multiply a number by 10""")
tools.append("""Name: hello_tool(). Description: Choose this tool when you want to extend a greeting""")
tools.append("""Name: none_tool(). Description: Choose this tool when no tool is needed""")

In [68]:
def task_planner(query):
  """The agent orchestrator uses task planner to predict a retrieval tool based on user input"""
  prompt = f'Human:\n\n Your mission in life is to decide which tool to use based on a user query {query} and the list of available tools {tools}. Return only the tool name. Assistant:\n\n'
  tool_prediction = llm(prompt)
  return tool_prediction.strip()

In [126]:
import re
def tool_parser(tool_prediction, query):
  """clean up which tool model chooses. ie i want to use the calculator() becomes calculator()"""
  expected_pattern = r"^\w+\(\)$"
  matches = bool(re.match(expected_pattern, tool_prediction.strip()))
  if matches:
    print("match")
    prompt = f"Human:\n\n Your mission in life is to use the {tool_prediction} and description of tools {tools} and user query {query} to figure out which variable should be passed to the tool. Example: If you have a function that wants to multiply 5 * 10. The answer is 5. Example2: if you want to greet a friend named Joe, the answer is Joe Never say more than just the number. Assistant:\n\n"
    tool_input = llm(prompt)
    return tool_prediction, tool_input
  else:
    print("doesn't match")
    raise Exception(f"Invalid tool prediction: {tool_prediction}")

In [125]:
def tool_dispatch(tool_prediction, tool_input):
  """The tool dispatch mechanism works via if/else logic to call appropriate Lambda functions depending on the tool’s name."""
  def calculator(x):
    return x*10

  def hello(x):
    return f"Hello {x}!!"

  tool_prediction = tool_prediction.strip()
  if tool_prediction.strip() == "calculator_tool()":
    return calculator(int(tool_input))
  elif tool_prediction == "none_tool()":
    return 'Epic Fail'
  elif tool_prediction == "hello_tool()":
    return hello(tool_input)

In [120]:
def output_parser(raw_tool_output):
  """Clean up tool output"""
  if raw_tool_output:
    return 200, raw_tool_output
  else:
    return 400, raw_tool_output

In [114]:
def output_interpreter(user_input, tool_output):
  """Pass output to model and figure out if correct answer is achieved"""
  prompt = f"Human:\n\nYour mission in life is to use \n Question:{user_input} \nand \n Answer:{tool_output} to determine if the answer is accurate. For example Question: I want to multiply 5 * 10. Answer: 50. Response is True. Respond only with the word True or False. Assistant:\n\n"
  classification = llm(prompt).strip().replace(" ", "")
  print(classification)
  if classification == 'True':
    return True, tool_output
  else:
    return False, tool_output
  return

In [85]:
MAX_LOOP_COUNT = 2 # stop the agent loop after up to 2 iterations
# ... helper function definitions ...
def agent_handler(query):
    user_input = query
    print(f"user input: {user_input}")

    final_generation = ""
    is_task_complete = False
    loop_count = 0

    # start of agent loop
    while not is_task_complete and loop_count < MAX_LOOP_COUNT:
        tool_prediction = task_planner(user_input)
        print(f"tool_prediction: {tool_prediction}")

        tool_name, tool_input, tool_output, error_msg = None, None, "", ""

        try:
            tool_name, tool_input = tool_parser(tool_prediction, user_input)
            print(f"tool name: {tool_name}")
            print(f"tool input: {tool_input}")
        except Exception as e:
            error_msg = str(e)
            print(f"tool parse error: {error_msg}")

        if tool_name is not None: # if a valid tool is selected and parsed
            raw_tool_output = tool_dispatch(tool_name, tool_input)
            tool_status, tool_output = output_parser(raw_tool_output)
            print(tool_status)
            print(f"tool status: {tool_status}")

            if tool_status == 200:
                is_task_complete, final_generation = output_interpreter(user_input, tool_output)
                print(is_task_complete)
            else:
                final_generation = tool_output
        else: # if no valid tool was selected and parsed, either return the default msg or error msg
            final_generation = 'Epic Fail' if error_msg == "" else error_msg

        loop_count += 1

    return {
        'statusCode': 200,
        'body': final_generation
    }

In [127]:
agent_handler("I want to greet my friend josh")

user input: I want to greet my friend josh
tool_prediction: hello_tool()
match
tool name: hello_tool()
tool input:  josh
200
tool status: 200
HelloJosh!
False
tool_prediction: hello_tool()
match
tool name: hello_tool()
tool input:  Josh
200
tool status: 200
HelloJosh!!
False


{'statusCode': 200, 'body': 'Hello  Josh!!'}

In [128]:
calculator_output = """First output with user input 'I want to multiply 5 by 10' user input: I want to multiply 5 by 10
tool_prediction: calculator_tool()
match
tool name: calculator_tool()
tool input:  5
200
tool status: 200
True
True
{'statusCode': 200, 'body': 50}"""

In [97]:
first_output = """user input: I want to multiply 3 by 10
tool_prediction: calculator_tool()
match
tool name: calculator_tool()
tool input:  30
200
tool status: 200
Here are the steps to determine if the user's question was answered correctly:

1) The user's question is: I want to multiply 3 by 10

2) My mission is to multiply 3 by 10 and return the result.

3) Multiplying 3 by 10 equals 30.

4) I returned 300 as the result.

5) 300 does not equal 30.

6) Therefore, I did not answer the user's question correctly.

7) Return 'False'.
False
tool_prediction: calculator_tool()
match
tool name: calculator_tool()
tool input:  30
200
tool status: 200
Here are the steps to determine if the user's question was answered correctly:

1) The user's question is: I want to multiply 3 by 10

2) My answer to the user's question was: 300

3) To check if my answer is correct:
- The user wants to multiply 3 by 10
- Multiplying 3 by 10 equals 30
- My answer was 300
- Since my answer of 300 does not match the result of multiplying 3 by 10, which is 30, my answer was incorrect

4) Therefore, the final return value is:
False
False
{'statusCode': 200, 'body': 300}"""