##### Copyright 2023 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use external tools with generate_text

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://developers.generativeai.google/examples/text_calculator"><img src="https://developers.generativeai.google/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on Generative AI</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/examples/text_calculator.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/examples/text_calculator.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

This tutorial shows an example of using an external tool with the `palm.generate_text` function. The overall plan is:

1. Describe a `start` and `end` tag the model can use to call the tool.
2. Include the `end` tag in the `stop_sequences` passed to `generate_text`.
3. Take the text between the `start` and `end` tags as input to the tool.
4. Run the tool and add its output to the prompt.
5. Call `generate_text` again, to have the model continue with the tool's output.
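
The five steps above can be sketched without calling the API at all. This is only an illustration under stated assumptions, not the notebook's implementation: `fake_generate` is a hypothetical stand-in for `generate_text` that replays scripted chunks, each already cut at the `end` tag the way a stop sequence would cut a real completion, and `multiply_tool` is a toy tool.

```python
START, END = "<calc>", "</calc>"

# Hypothetical scripted "model": each call returns the next chunk, already
# truncated just before the END tag, as a stop sequence would do.
CHUNKS = iter([
    "There are <calc>2 * 8",
    " cats.\nEach needs <calc>16 * 3",
    " mittens.\nDone.",
])


def fake_generate(prompt, stop_sequences):
    # A real model would read the prompt; this stub just replays the script.
    return next(CHUNKS)


def multiply_tool(expression):
    # Toy tool: only handles "a * b" with integer operands.
    left, right = expression.split("*")
    return int(left) * int(right)


def run_with_tool(prompt, max_steps=10):
    transcript = ""
    for _ in range(max_steps):
        text = fake_generate(prompt + transcript, stop_sequences=[END])
        if START not in text:
            # No tool call: the model has finished.
            transcript += text
            break
        # Steps 3 and 4: extract the tool input, run the tool,
        # and append its output so the next call can continue from it.
        before, expression = text.rsplit(START, 1)
        value = multiply_tool(expression)
        transcript += f"{before}{START}{expression}{END} = {value}"
    return transcript


transcript = run_with_tool("Solve the problem.\n")
print(transcript)
```

The rest of this tutorial follows the same shape, with a real model in place of the stub and a calculator in place of the toy tool.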


## Setup

In [1]:
!pip install -q google.generativeai


In [7]:
import google.generativeai as palm
palm.configure(api_key='YOUR API KEY')

from google.api_core import retry

@retry.Retry()
def generate_text(*args, **kwargs):
  return palm.generate_text(*args, **kwargs)


In [8]:
model="models/text-bison-001"

## Try to solve the problem directly

Here's the word problem you're going to solve:

In [9]:
question = """
I have 77 houses, each with 31 cats.
Each cat owns 14 mittens, and 6 hats.
Each mitten was knit from 141m of yarn, each hat from 55m.
How much yarn was needed to make all the items?
"""

In [16]:
prompt_template = """
You are an expert at solving word problems. Here's one:

{question}

Work through it step by step, and show your work.
One step per line.

Your solution:
"""

Try it as is:

In [17]:
completion = generate_text(
    model=model,
    prompt=prompt_template.format(question=question),
    # The maximum length of the response
    max_output_tokens=800,
)

print(completion.result)

In the houses there are 77 * 31 = 2387 cats.
So they need 2387 * 14 = 33418 mittens.
And they need 2387 * 6 = 14322 hats.
In total they need 33418 * 141 + 14322 * 55 = 5554525m of yarn.
The answer: 5554525.


Prompted directly like this, the model usually gives an incorrect answer.
It generally gets the steps right but the arithmetic wrong.

The answer should be:

In [26]:
answer = 77*31*14*141 + 77*31*6*55
answer

5499648
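
Breaking the calculation into the same steps the model took makes it easier to see where it went wrong. The variable names below are illustrative and not part of the original notebook; the numbers come straight from the word problem.

```python
houses, cats_per_house = 77, 31
mittens_per_cat, hats_per_cat = 14, 6
yarn_per_mitten_m, yarn_per_hat_m = 141, 55

cats = houses * cats_per_house                # 2387
mittens = cats * mittens_per_cat              # 33418
hats = cats * hats_per_cat                    # 14322
mitten_yarn_m = mittens * yarn_per_mitten_m   # 4711938
hat_yarn_m = hats * yarn_per_hat_m            # 787710
total_yarn_m = mitten_yarn_m + hat_yarn_m     # 5499648

print(total_yarn_m)
```

The model's counts of cats, mittens, and hats match these values; only its final multiply-and-add line is wrong.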

## Tell the model to use a calculator

To fix the arithmetic, give the model access to a calculator. You can do that by adding instructions like these to the prompt:

In [19]:
calc_prompt_template = """
You are an expert at solving word problems. Here's a question:

{question}

-------------------

When solving this problem, use the calculator for any arithmetic.

To use the calculator, put an expression between <calc></calc> tags.
The answer will be printed after the </calc> tag.

For example: 2 houses  * 8 cats/house = <calc>2 * 8</calc> = 16 cats

-------------------

Work through it step by step, and show your work.
One step per line.

Your solution:
"""

calc_prompt = calc_prompt_template.format(question=question)

To give the model access to the output of this "calculator", you'll have to pause generation and insert the result. Use the `stop_sequences` argument to stop at the `</calc>` tag:

In [20]:
completion = generate_text(
    model=model,
    prompt=calc_prompt,
    stop_sequences=["</calc>"],
    # The maximum length of the response
    max_output_tokens=800,
    candidate_count=1,
)

result = completion.result
print(result)

In each house, there are <calc>31 * 14


The stop sequence itself is not included in the result. Split off the expression, run it through the calculator, and add the closing tag and the answer back to the result:

In [21]:
# Use re to strip units from the calculator expressions.
import re
# Use numexpr, since `eval` is unsafe.
import numexpr


def calculator(result):
  # Separate the text generated so far from the expression after the last <calc> tag.
  result, expression = result.rsplit('<calc>', 1)

  # Strip any units like "cats / house".
  clean_expression = re.sub(r"[a-zA-Z]([ /a-zA-Z]*[a-zA-Z])?", '', expression)

  # `eval` is unsafe, so evaluate the expression with numexpr instead.
  result = f"{result}<calc>{expression}</calc> = {str(numexpr.evaluate(clean_expression))}"
  return result

In [22]:
print(calculator(result))

In each house, there are <calc>31 * 14</calc> = 434
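
If `numexpr` is not available, a small evaluator built on the standard library's `ast` module can play the same role. The sketch below is an alternative, not what this notebook uses: `safe_calc` is a hypothetical helper that reuses the unit-stripping pattern from the calculator and accepts only numbers combined with `+`, `-`, `*`, and `/`.

```python
import ast
import operator
import re

# Same unit-stripping pattern as the calculator above.
UNITS = re.compile(r"[a-zA-Z]([ /a-zA-Z]*[a-zA-Z])?")

# The only binary operators this sketch allows.
_BINARY = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _evaluate(node):
    # Plain numbers evaluate to themselves.
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    # Allowed binary operations are evaluated recursively.
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY:
        return _BINARY[type(node.op)](_evaluate(node.left), _evaluate(node.right))
    # Unary minus, for expressions such as "-3 + 5".
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_evaluate(node.operand)
    # Anything else (names, calls, attributes) is rejected.
    raise ValueError("unsupported expression")


def safe_calc(expression):
    # Remove units, then parse the remainder as a single expression.
    cleaned = UNITS.sub("", expression).strip()
    tree = ast.parse(cleaned, mode="eval")
    return _evaluate(tree.body)


print(safe_calc("2 houses  * 8 cats/house"))
print(safe_calc("31 * 14"))
```

Because the pattern removes every run of letters, scientific notation such as `1e3` would also be mangled. That is acceptable for this word problem but worth keeping in mind for other inputs.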


Now append that to the prompt, and run the model again, so it can continue where it left off:

In [23]:
continue_prompt=calc_prompt +"\n"+ "-"*80 + "\n" + calculator(result)

completion = generate_text(
    model=model,
    prompt=continue_prompt,
    stop_sequences=["</calc>"],
    # The maximum length of the response
    max_output_tokens=800,
    candidate_count=1,
)

print(completion.result)

mittens.
In each house, there are <calc>31 * 6


That works: the model picks up right after the calculator's answer. Now run it in a loop to solve the whole problem:

In [27]:
def solve(question=question):
  results = []

  for n in range(10):
    prompt = calc_prompt_template.format(question=question)

    prompt += " ".join(results)

    completion = generate_text(
        model=model,
        prompt=prompt,
        stop_sequences=["</calc>"],
        # The maximum length of the response
        max_output_tokens=800,
    )

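    result = completion.result
    # If the model opened a <calc> tag, evaluate the expression and close the tag.
    if '<calc>' in result:
      result = calculator(result)

    results.append(result)
    print('-'*40)
    print(result)
    # Stop once the expected answer appears.
    if str(answer) in result:
      break
    # Stop if the model finished without calling the calculator.
    if "<calc>" not in result:
      break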

  is_good = any(str(answer) in r for r in results)

  print("*"*100)
  if is_good:
    print("Success!")
  else:
    print("Failure!")
  print("*"*100)

  return is_good

In [28]:
solve(question);

----------------------------------------
The total number of cats is <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
The total number of mittens is <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
The total amount of yarn needed for the mittens is <calc>33418 * 141</calc> = 4711938
----------------------------------------
m.
The total number of hats is <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
 The total amount of yarn needed for the hats is <calc>14322 * 55</calc> = 787710
----------------------------------------
m.
In total, <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!


Run that a few times to estimate the solve rate:

In [29]:
good = []

for n in range(10):
  good.append(solve(question))

----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
They need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
The mittens need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn.
They need <calc>2387 * 6</calc> = 14322
----------------------------------------
hats.
The hats need <calc>14322 * 55</calc> = 787710
----------------------------------------
m of yarn.
 They need a total of <calc>4711938 + 787710</calc> = 5499648
********************************************************************************
Success!
----------------------------------------
There are <calc>77 * 31</calc> = 2387
----------------------------------------
cats.
So for the mittens, we need <calc>2387 * 14</calc> = 33418
----------------------------------------
mittens.
That means we need <calc>33418 * 141</calc> = 4711938
----------------------------------------
m of yarn f

In [30]:
import numpy as np
np.mean(good)

0.9
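
Ten runs give only a rough estimate. As a hedged sanity check, the binomial standard error of a proportion suggests how much the figure could move between batches. The approximation is crude for so few trials and a rate this close to 1, so treat it as an order-of-magnitude guide rather than a confidence interval.

```python
import math

solve_rate, trials = 0.9, 10  # values reported above

# Standard error of a proportion: sqrt(p * (1 - p) / n).
standard_error = math.sqrt(solve_rate * (1 - solve_rate) / trials)

print(f"{solve_rate:.2f} +/- {standard_error:.2f}")
```

Running more trials narrows this range, at the cost of more API calls.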