In [21]:
!pip install anthropic



In [22]:
import anthropic
import time
import random
from google.colab import userdata

## Errors

In [23]:
class TimeoutError(Exception):
  """timeout error -- can retry"""
  pass

class RateLimitError(Exception):
  pass

## Tools

In [24]:
search_papers_tool = {
    "name": "search_papers",
    "description": "Search for academic papers by the given keyword query. Returns the list of paper titles and ids. "
    + "Use this tool to find relevant papers on a topic. ",
    "input_schema": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search keyword/topic, e.g. 'transformer', 'agent'"
            },
            "max_results": {
                "type": "integer",
                "description": "The maximum number of results to return. Defaults to 5.",
                "default": 5
            }
        },
        "required": ["query"]
    }
}

In [25]:
def search_papers(query: str, max_results: int=5) -> str:
  time.sleep(0.5)

  if random.random() < 0.2:
    raise TimeoutError("Search API call timed out")

  if len(query) < 3:
    raise ValueError("Query must be at least 3 characters long")

  papers = {
      "transformers": ["Paper: 'Attention Is All You Need' (ID: 1706.03762)",
                        "Paper: 'BERT: Pre-training' (ID: 1810.04805)"],
      "agents": ["Paper: 'ReAct: Reasoning and Acting' (ID: 2210.03629)",
                  "Paper: 'AutoGPT: Autonomous Agents' (ID: 2306.02224)"],
      "rag": ["Paper: 'Retrieval-Augmented Generation' (ID: 2005.11401)",
              "Paper: 'Dense Passage Retrieval' (ID: 2004.04906)"]
  }
  for k, v in papers.items():
    if query.lower() in k:
      return "\n".join(v[:max_results])
  else:
    return "No papers found for query"

In [26]:
get_paper_details_tool = {
    "name": "get_paper_details",
    "description": "Get the full details of an academic paper by the paper ID. Returns the title, authors and abstract of the paper. "
    + "Use this tool to get more details of a specific paper after search_papers. ",
    "input_schema": {
        "type": "object",
        "properties": {
            "paper_id": {
                "type": "string",
                "description": "The ID of the paper, e.g. '1706.03762'"
            },
        },
        "required": ["paper_id"]
    }
}


def get_paper_details(paper_id: str) -> str:
  time.sleep(0.3)

  if random.random() < 0.1:
    raise TimeoutError("Paper details API call timed out")

  details = {
          "1706.03762": "Title: Attention Is All You Need\nAuthors: Vaswani et al.\nAbstract: The dominant sequence transduction models...",
          "1810.04805": "Title: BERT\nAuthors: Devlin et al.\nAbstract: We introduce BERT...",
          "2210.03629": "Title: ReAct\nAuthors: Yao et al.\nAbstract: We explore using LLMs...",
  }

  if paper_id in details:
    return details[paper_id]
  else:
    return "Paper details not found"


In [27]:
summarize_text_tool = {
    "name": "summarize_text",
    "description": "Summarize the input text into a short paragraph that is at most max_words long. "
    + "Use this tool to summarize paper abstracts or details. ",
    "input_schema": {
        "type": "object",
        "properties": {
            "text": {
                "type": "string",
                "description": "The text to summarize"
            },
            "max_words": {
                "type": "integer",
                "description": "Maximum words in the summary",
                "default": 50
            },
        },
        "required": ["text"]
    }
}


def summarize_text(text: str, max_words: int = 50) -> str:
  time.sleep(0.2)

  if len(text) < 10:
    raise ValueError("Text must be at least 10 characters long for summarization")

  if random.random() < 0.05:
    raise RateLimitError("Rate limit exceeded. Try again later")

  return f"Summary: {text[:max_words]}... [truncated]"

In [28]:
details = get_paper_details("1706.03762")
print(summarize_text(details))

Summary: Title: Attention Is All You Need
Authors: Vaswani ... [truncated]


## Dispatcher

In [38]:
def process_tool_call(tool_name: str, tool_input: dict, max_retries: int = 3) -> str:
  retry_count = 0
  for attempt in range(max_retries):
    try:
      if tool_name == "search_papers":
        if "query" not in tool_input:
          raise ValueError("search_papers requires 'query' parameter")
        return search_papers(query=tool_input["query"], max_results=tool_input.get("max_results", 5))
      if tool_name == "get_paper_details":
        if "paper_id" not in tool_input:
          raise ValueError("get_paper_details requires 'paper_id' parameter")
        return get_paper_details(paper_id=tool_input["paper_id"])
      if tool_name == "summarize_text":
        if "text" not in tool_input:
          raise ValueError("summarize_text requires 'text' parameter")
        return summarize_text(text=tool_input["text"], max_words=tool_input.get("max_words", 50))
      return f"Error: unknown tool name {tool_name}"
    except TimeoutError as e:
      print(f" Attempt {attempt + 1}/{max_retries} Timeout: {e}")
      if attempt + 1 < max_retries:
        time.sleep(2 ** attempt)
      continue
    except RateLimitError as e:
      print(f" Attempt {attempt + 1}/{max_retries} Rate limited: {e}")
      if attempt + 1 < max_retries:
        time.sleep(3)
      continue
    except ValueError as e:
      return f"Error: {e}"
    except Exception as e:
      return f"Error: {e}"

  return "Error: max retries exceeded"


In [39]:
process_tool_call("search_papers", {})

"Error: search_papers requires 'query' parameter"

In [30]:
process_tool_call("search_papers", {"query": "transformer"})

"Paper: 'Attention Is All You Need' (ID: 1706.03762)\nPaper: 'BERT: Pre-training' (ID: 1810.04805)"

In [41]:
process_tool_call("get_paper_details", {"paper_id": "1706.03762"})

"Error: get_paper_details requires 'paper_id' parameter"

In [32]:
process_tool_call("summarize_text", {"text": "The dominant sequence transduction models are based on complex recurrent or"})

'Summary: The dominant sequence transduction models are base... [truncated]'

In [42]:
process_tool_call("summarize_text", {"text": "short"})

"Error: summarize_text requires 'text' parameter"

## Agent loop

In [None]:
def handle_query(query: str, max_iterations=10) -> str:
  api_key = userdata.get("ANTHROPIC_API_KEY")
  client=anthropic.Anthropic(api_key=api_key)

  tools = [search_papers_tool, get_paper_details_tool, summarize_text_tool]

  system_prompt = f"""
  You are a research assistant answering user's research question.
  You have the following tools available: {",".join(t["name"] for t in tools)}.

  - Use the available tool to find relevant information to answer the user's question. Generate a final reponse by incorporating the tool results.
  - When a tool returns an error, try an alternative approach or inform the user.
  - Only invoke tools that are absolutely needed to answer the user's query. Do not invoke unnecessary tools.
  - ONLY use the available tools. Do not call tools that are not provided to you.
  """

  messages = [
      {"role": "user", "content": query}
  ]

  iter = 0
  while iter < max_iterations:
    try:
      response = client.messages.create(
          max_tokens=1024,
          model="claude-3-haiku-20240307",
          messages=messages,
          system=system_prompt,
          tools=tools
      )
    except Exception as e:
      print(f"claude API error: {e}")
      break

    print(f"***LOG: received response {response}")
    messages.append({"role": "assistant", "content": response.content})

    if response.stop_reason == "tool_use":
      tool_calls = [block for block in response.content if block.type == "tool_use"]
      tool_results = []
      for tool_call in tool_calls:
        tool_output = process_tool_call(tool_call.name, tool_call.input)
        print(f"***LOG: tool call {tool_call.name}({tool_call.input}), result={tool_output}")
        tool_results.append({
            "type": "tool_result",
            "tool_use_id": tool_call.id,
            "content": tool_output
        })
      messages.append({"role": "user", "content": tool_results})
      continue

    if response.stop_reason == "end_turn":

      break
    else:
      print(f"WARNING: unexpected stop reason {response.stop_reason}")

  if iter >= max_iterations:
    print("WARNING: max iterations reached")

  final_text = next((block.text for block in response.content if block.type == "text"), "No response")
  return final_text, messages

## tests

In [None]:
def test_with_query(query: str):
  final_text, messages = handle_query(query)
  print("=" * 100)
  print(f"Final text: {final_text}")
  print("====== Message History =====")
  for i, m in enumerate(messages):
    print(f"{i}: {m}")
  print("=" * 50)

In [None]:
test_with_query("Find papers about transformers")

***LOG: received response Message(id='msg_01DkHhG4QbHN2EZWGWPUKyVf', content=[TextBlock(citations=None, text="Okay, let's search for papers on transformers.", type='text'), ToolUseBlock(id='toolu_01Ck4fwHegoqJFnnURLr7DTz', input={'query': 'transformers'}, name='search_papers', type='tool_use')], model='claude-3-haiku-20240307', role='assistant', stop_reason='tool_use', stop_sequence=None, type='message', usage=Usage(cache_creation=CacheCreation(ephemeral_1h_input_tokens=0, ephemeral_5m_input_tokens=0), cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=779, output_tokens=67, server_tool_use=None, service_tier='standard'))
***LOG: tool call search_papers({'query': 'transformers'}), result=Paper: 'Attention Is All You Need' (ID: 1706.03762)
Paper: 'BERT: Pre-training' (ID: 1810.04805)
***LOG: received response Message(id='msg_01Aw4zutmHMnZE8WeMiTNRW8', content=[TextBlock(citations=None, text='The search returned 2 relevant paper titles and IDs on the topic of transfor

In [None]:
test_with_query("Find papers about agents and get details on the first one")

***LOG: received response Message(id='msg_01GvKtC11sdBoUnFXaEoWeuN', content=[TextBlock(citations=None, text="Okay, let's find some papers about agents and get details on the first one.", type='text'), ToolUseBlock(id='toolu_013NbKicXb1eKq4eF6LnwUuP', input={'query': 'agents', 'max_results': 1}, name='search_papers', type='tool_use')], model='claude-3-haiku-20240307', role='assistant', stop_reason='tool_use', stop_sequence=None, type='message', usage=Usage(cache_creation=CacheCreation(ephemeral_1h_input_tokens=0, ephemeral_5m_input_tokens=0), cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=785, output_tokens=91, server_tool_use=None, service_tier='standard'))
***LOG: tool call search_papers({'query': 'agents', 'max_results': 1}), result=Paper: 'ReAct: Reasoning and Acting' (ID: 2210.03629)
***LOG: received response Message(id='msg_0161CTGgHzFdCNjuCSNpBycr', content=[ToolUseBlock(id='toolu_0178gsGsutMQzCGfS8GPUubo', input={'paper_id': '2210.03629'}, name='get_pape

In [None]:
test_with_query("Search for papers about RAG, get details on one, and summarize it")

***LOG: received response Message(id='msg_01LsgMtujQNGwRrjVgwU4Niz', content=[TextBlock(citations=None, text="Okay, let's find some relevant papers on RAG and summarize one of them.", type='text'), ToolUseBlock(id='toolu_01VbhHvqSUjzcHRQzNxfPqhB', input={'query': 'RAG'}, name='search_papers', type='tool_use')], model='claude-3-haiku-20240307', role='assistant', stop_reason='tool_use', stop_sequence=None, type='message', usage=Usage(cache_creation=CacheCreation(ephemeral_1h_input_tokens=0, ephemeral_5m_input_tokens=0), cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=790, output_tokens=74, server_tool_use=None, service_tier='standard'))
 Attempt 1/3 Timeout: Search API call timed out
***LOG: tool call search_papers({'query': 'RAG'}), result=Paper: 'Retrieval-Augmented Generation' (ID: 2005.11401)
Paper: 'Dense Passage Retrieval' (ID: 2004.04906)
***LOG: received response Message(id='msg_01DXrNwNiob5ABmt19A91wHo', content=[TextBlock(citations=None, text="The search 