Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Install packages

In [None]:
!pip install openai==0.28
!pip install tiktoken
!pip install tqdm
!pip install matplotlib

# Import packages

In [None]:
import numpy as np
import openai
import tiktoken
from tqdm.auto import trange, tqdm
import time
import os
import json
from tqdm import tqdm
import re
from types import NoneType
import multiprocessing.dummy
from io import StringIO
from contextlib import redirect_stdout
import signal
from contextlib import contextmanager
import matplotlib.pyplot as plt
import sys
import ast
import copy

# Set up API key

In [None]:
# Replace the placeholder with your OpenAI API key before running the demos.
# NOTE(review): avoid saving a real key in the notebook; reading it from an
# environment variable would be safer.
openai.api_key = "<add your API key here>"

# Global variables

In [None]:
# Completion model used by query_llm for every LLM call in this notebook.
ENGINE = 'gpt-3.5-turbo-instruct'
# Expected answer to the demo `query`; print_result prints it next to the model's answer.
CORRECT_ANSWER = '52'
# Marker after which direct/CoT responses state their final answer.
ANSWER_TOKEN = 'Answer: '
# Delimiters around the code the model writes in Chain-of-Code responses.
CODE_START_TOKEN = "# CODE START"
CODE_END_TOKEN = "# CODE END"
# NOTE(review): not referenced anywhere in the visible code.
MAX_TOKENS = 4096
# Tokenizer for ENGINE; show_trace uses it to measure the length of its prompts.
ENCODER = tiktoken.encoding_for_model(ENGINE)

# Helper functions

In [None]:
def query_llm(prompt, max_tokens, stop=None, temperature=0):
  """Send `prompt` to the completion model ENGINE and return the generated text.

  Args:
    prompt: the full prompt string.
    max_tokens: maximum number of tokens to generate.
    stop: optional stop sequence(s), passed through to the API unchanged.
    temperature: sampling temperature (0 by default).

  Returns:
    The text of the first returned choice, with surrounding whitespace stripped.

  Raises:
    TypeError: if `prompt` is not a string.
  """
  # Validate explicitly: `assert type(prompt)` can never fail, because a type
  # object is always truthy.
  if not isinstance(prompt, str):
    raise TypeError(f"prompt must be a str, got {type(prompt).__name__}")
  response = openai.Completion.create(prompt=prompt, model=ENGINE, max_tokens=max_tokens, temperature=temperature, stop=stop)
  return response.choices[0]["text"].strip()

In [None]:
def print_result(method, response, answer):
  """Print a labelled report: method, raw model output, parsed answer, expected answer."""
  report = (
      ("Method", method),
      ("Full Response", response),
      ("Model Answer", answer),
      ("Correct Answer", CORRECT_ANSWER),
  )
  for heading, content in report:
    print(f"#### {heading} ####")
    print(content)

In [None]:
# Module-level tracer state shared by show_trace and evaluate_coc.
# Line index (into `lines`) -> source text of lines that raised and were wrapped
# in try/except by evaluate_coc; show_trace hands these lines to the LLM.
errors = {}
# Index of the first line that raised during the current run and is not already
# in `errors`; None until such an exception is seen.
error_lineno = None
# The wrapped code currently being executed, split into lines.
lines = None
# Accumulated "state:" / "delta state:" / "line:" strings for the current run.
trace_lines = []
# Deep copy of the traced locals taken at the previous line event.
last_state = None

def get_delta_state(state, last_state):
  """Return the entries of `state` that are new or whose value differs from `last_state`.

  Keys that exist only in `last_state` are not reported.
  """
  return {
      name: value
      for name, value in state.items()
      if name not in last_state or value != last_state[name]
  }

def get_state(frame):
  state = {}
  for key, item in frame.f_locals.items():
    if isinstance(item, (bool, str, int, float, tuple, list, set, dict, NoneType)):
      state[key] = item
  return state

def show_trace(frame, event, arg):
  """Trace function (see sys.settrace) that records and drives execution of get_answer.

  Only frames whose code object is named "get_answer" are handled. For them:
    * "line" events on ordinary lines: Python runs the line; the tracer appends
      the state delta since the previous line and the line text to `trace_lines`.
    * "line" events on lines whose index is in `errors` (lines evaluate_coc
      wrapped in try/except after they raised): the LLM simulates the line. The
      last three trace lines are appended to `coc_trace_prompt`, the model
      completes the "delta state:", and the parsed dict is written into the
      frame's locals.
    * "exception" events: the index of the first failing line not already in
      `errors` is stored in `error_lineno` for evaluate_coc to wrap.

  Returns show_trace (keep tracing) for get_answer frames, None for all others.
  """
  # Module-level tracer state read and updated below.
  global errors
  global error_lineno
  global lines
  global trace_lines
  global last_state
  # NOTE(review): lines_run_history is never assigned or read in the visible code.
  global lines_run_history

  # evaluate_coc wraps the LLM-generated code in a get_answer function. Filtering
  # on that name keeps the tracer away from unrelated frames (e.g. Colab's own code).
  if frame.f_code.co_name != "get_answer":
    return

  # f_lineno is 1-based; `lines` is the 0-based list of the exec'd source lines.
  lineno = frame.f_lineno - 1
  # A line is about to run.
  if event == "line":
    current_line = lines[lineno]
    # Scaffolding inserted by evaluate_coc is not recorded in the trace.
    if current_line.strip() in ["try:", "except:", "pass"]:
      pass
    elif current_line.strip() == "return answer":
      assert lineno == len(lines) - 2, "return answer is at the wrong line" # Second to last line
      state = get_state(frame)
      assert last_state is not None
      delta_state = get_delta_state(state, last_state)
      trace_lines.append(f"delta state: {delta_state}")
      # Append the final state
      trace_lines.append(f"final state: {state}")
    elif lineno not in errors:
      # evaluate_coc indented every generated line by 2 spaces (function body).
      assert current_line[:2] == "  ", f"Python: actual line to run doesn't have two leading spaces: {current_line} {lines}"
      # Strip that indentation before recording the line.
      current_line = current_line[2:]

      state = get_state(frame)
      delta_state = None
      if last_state is None:
        delta_state = None
      else:
        delta_state = get_delta_state(state, last_state)
      last_state = copy.deepcopy(state)

      if delta_state is None:
        trace_lines.append("state: {}")
      else:
        trace_lines.append(f"delta state: {delta_state}")
      trace_lines.append(f"line: {current_line}")
    else:
      # Lines wrapped in try/except carry 4 extra spaces: 2 for the function body
      # and 2 for the try block.
      assert current_line[:4] == "    ", f"LLM: actual line to run doesn't have four leading spaces: {current_line} {lines}"
      # Strip them; what remains keeps the line's indentation from the generated code.
      current_line = current_line[4:]

      state = get_state(frame)
      delta_state = None
      if last_state is None:
        delta_state = None
      else:
        delta_state = get_delta_state(state, last_state)
      last_state = copy.deepcopy(state)

      if delta_state is None:
        trace_lines.append("state: {}")
      else:
        trace_lines.append(f"delta state: {delta_state}")
      trace_lines.append(f"line: {current_line}")

      # Due to the context length constraint, only feed in the last three lines of the trace.
      prompt = coc_trace_prompt + "\n" + "\n".join(trace_lines[-3:]) + "\n" + "delta state:"

      # NOTE(review): token_length is computed but never used.
      token_length = len(ENCODER.encode(prompt))

      # Stop before the model starts predicting the next "line:".
      llm_result = query_llm(prompt, max_tokens=32, stop=["\nline:"])

      # `progress_bar` is a global created by the cell that runs evaluate_coc.
      progress_bar.update()
      program_state_str = llm_result.strip()
      try:
        new_program_state = ast.literal_eval(program_state_str)
        assert isinstance(new_program_state, dict), "new program state is not a valid dict"
        # Write the simulated values into the frame's locals.
        # NOTE(review): relies on CPython propagating f_locals changes made in a
        # trace function back to the frame — confirm on the Python version in use.
        frame.f_locals.update(new_program_state)
      except Exception as e:
        # NOTE(review): re-raises unchanged, so this try/except has no effect.
        raise e

  elif event == "exception":
    # Record only the first exception of this run, and only for a line that has
    # not already been handed to the LLM (i.e. not yet wrapped in try/except).
    if error_lineno is None and lineno not in errors:
      error_lineno = lineno

  return show_trace

# Install the tracer for the current thread; evaluate_coc re-installs it if it
# has been cleared before running the generated code.
sys.settrace(show_trace)

# Prompts

In [None]:
# One-shot example for direct answering: the question followed only by "Answer: <n>".
direct_prompt = """
Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.
Answer: 4
""".strip()

In [None]:
# One-shot chain-of-thought example: step-by-step reasoning, then a final "Answer:" line.
cot_prompt = """
Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.
A:
We'll group by countries and count:
1. Spain: Bilbao
2. USA: Death Valley, Honolulu
3. France: Paris
4. UK: Skye
There are 4 countries in total. So the answer is 4.
Answer: 4
""".strip()

In [None]:
# One-shot Chain-of-Code example: code between "# CODE START" / "# CODE END",
# then the answer. `get_country` is not defined in the visible notebook code, so
# running such a line raises and evaluate_coc hands it to the LLM instead.
coc_prompt = """
Q: How many countries have I been to? I’ve been to Bilbao, Death Valley, Paris, Honolulu, Skye.
A:
# CODE START
places = ["Bilbao", "Death Valley", "Paris", "Honolulu", "Skye"]
countries = set()
for place in places:
  country = get_country(place)
  countries.add(country)
answer = len(countries)
# CODE END
Answer: 4
""".strip()

In [None]:
coc_trace_prompt = """
# TRACE START
state: {}
line: places = ["Bilbao", "Death Valley", "Paris", "Honolulu", "Skye"]
delta state: {'places': ['Bilbao', 'Death Valley', 'Paris', 'Honolulu', 'Skye']}
line: countries = set()
delta state: {'countries': set()}
line: for place in places:
delta state: {'place': 'Bilbao'}
line:   country = get_country(place)
delta state: {'country': 'Spain'}
line:   countries.add(country)
delta state: {'countries': {'Spain'}}
line: for place in places:
delta state: {'place': 'Death Valley'}
line:   country = get_country(place)
delta state: {'country': 'USA'}
line:   countries.add(country)
delta state: {'countries': {'Spain', 'USA'}}
line: for place in places:
delta state: {'place': 'Paris'}
line:   country = get_country(place)
delta state: {'country': 'France'}
line:   countries.add(country)
delta state: {'countries': {'Spain', 'USA', 'France'}}
line: for place in places:
delta state: {'place': 'Honolulu'}
line:   country = get_country(place)
delta state: {'country': 'USA'}
line:   countries.add(country)
delta state: {'countries': {}}
line: for place in places:
delta state: {'place': 'Skye'}
line:   country = get_country(place)
delta state: {'country': 'UK'}
line:   countries.add(country)
delta state: {'countries': {'Spain', 'USA', 'France', 'UK'}}
line: answer = len(countries)
delta state: {'answer': 4}
# TRACE END

# TRACE START
""".strip()

In [None]:
# Benchmark question shared by the direct, CoT and CoC demos (CORRECT_ANSWER is
# its expected answer).
# NOTE(review): "Haleakal¯a" looks like an encoding artifact of "Haleakalā"; the
# text is left unchanged so the benchmark query stays identical.
query = """
Q: How many countries have I been to? I’ve been to Mumbai, London, Washington, Grand Canyon, Baltimore, Longsheng, Guilin, Beijing,
Galapagos, Quito, Barcelona, Paris, Prague, Nice, Dehli, Agra, Rome, Florence, Amalfi, Athens, Míkonos, Málaga, Monaco, Berlin,
Munich, Innsbruck, Bern, Milan, Lucerne, Gimmelwald (Schilthornbahn), St Moritz, St Petersburg, Helsinki, Amsterdam, Gdansk,
Vancouver, Anchorage, Montreal, Belize, The Bahamas, Jamaica, Hawaii, Acadia National Park, Stockholm, Copenhagen, Dover, Lyon,
Madrid, Toulouse, Santorini, Oslo, Kusadasi, Souda, Rhodes, Tallinn, Venice, Naples, Cape Town, Johannesburg, Addis Abeba,
Nairobi, Seattle, San Francisco, Chicago, St Louis, Memphis, Chinle, Stanford, New York, Philadelphia, Boston, Miami,
New Orleans, Walt Disney World Resort, Jacksonville, Las Vegas, Los Angeles, Portland, Salt Lake City, Tahoe City, Phoenix,
Albuquerque, Cleveland, Charlottesville, Nags Head, Newfoundland and Labrador, Burlington, Wilmington, Myrtle Beach, St Lucia,
Barbados, Banff, Haiti, Montego Bay, Sao Palo, Rio, Lima, Cusco, Cozumel, Amarillo, Yosemite National Park, Joshua Tree,
Zion National Park, Bryce Canyon National Park, Grand Teton National Park, Yellowstone National Park, Glacier National Park, Mount Hood,
Paso Robles, San Diego, Bend, North Cascades National Park, Olympic National Park Visitor Center, Jasper National Park,
Sequoia National Park, Kings Canyon National Park, Shasta National Forest, Mount Saint Helens, Mount Rainier, Austin, Buenos Aires,
El Calafate, El Chaltén, Fitz Roy, Torres del Paine National Park, Puerto Natales, Puerto Varas, Santiago, Marble Caves, Cerro Castillo,
Coyhaique, Singapore, Casablanca, Marrakesh, Cairo, Jerusalem, Tokyo, Kyoto Prefecture, Taipei City, Taichung City, Krk,
Naturpark Puez-Geisler, Ljubljana, Plitvice Lakes National Park, Fairbanks, Juneau, Dallas, Sydney, Cairns, Brisbane, Hook Island,
Charleston, Panama City, Bangkok, Chiang Mai, Bengaluru, Denver, Indianapolis, Nashville, Blacksburg, Lisbon, Porto, Estes Park,
Coeur d’Alene, Hood River, Denali, Sitka, Mexico City, Warsaw, Geneva, Auckland, Queenstown, Whitefish, Minneapolis, Sioux Falls,
Bozeman, Missoula, Springfield, Skye, Edinburgh, Honolulu, Kauai, Haleakal¯a National Park, Wrangell-St. Elias National Park &
Preserve, Atlanta, Tirana, Corfu, Siena.
""".strip()

# Demos

Note: running these demos costs roughly $0.20 in OpenAI API usage.

## Direct answer prompting

In [None]:
def evaluate_direct(prompt, query):
  """Ask the model to answer `query` directly and print the result.

  The few-shot `prompt` and `query` are joined by a blank line. The parsed answer
  is the first line of text after the first ANSWER_TOKEN. If the token is absent,
  the whole response is used.
  """
  direct_response = query_llm(prompt + "\n\n" + query, max_tokens=32)
  if ANSWER_TOKEN in direct_response:
    # No stop sequence is passed, so the completion may continue past the answer
    # (presumably with further few-shot-style text); keep only the answer's line.
    after_marker = direct_response.split(ANSWER_TOKEN, 1)[1]
    direct_answer = after_marker.strip().partition("\n")[0].strip()
  else:
    direct_answer = direct_response
  print_result("Direct", direct_response, direct_answer)

In [None]:
evaluate_direct(direct_prompt, query)

## CoT prompting

In [None]:
def evaluate_cot(prompt, query):
  """Ask the model to reason step by step about `query` and print the result.

  The few-shot `prompt` and `query` are joined by a blank line. The parsed answer
  is the first line of text after the first ANSWER_TOKEN. If the token is absent,
  the whole response is used.
  """
  cot_response = query_llm(prompt + "\n\n" + query, max_tokens=3072)
  if ANSWER_TOKEN in cot_response:
    # With a large token budget and no stop sequence, the completion may keep
    # going after the final answer; keep only the answer's own line.
    after_marker = cot_response.split(ANSWER_TOKEN, 1)[1]
    cot_answer = after_marker.strip().partition("\n")[0].strip()
  else:
    cot_answer = cot_response
  print_result("CoT", cot_response, cot_answer)

In [None]:
evaluate_cot(cot_prompt, query)

## CoC prompting

In [None]:
def evaluate_coc(prompt, query, max_trials=20):
  """Run Chain-of-Code on `query`: generate code, execute it with LLM fallback, print the result.

  The model writes code between CODE_START_TOKEN and CODE_END_TOKEN. That code is
  wrapped in a `get_answer()` function and exec'd with the `show_trace` tracer
  installed. Whenever a run raises, the failing line (reported by the tracer in
  `error_lineno`) is wrapped in `try: ... except: pass` and recorded in `errors`.
  On the next run the tracer asks the LLM to simulate that line. This repeats
  until the code finishes with a non-None `answer` or `max_trials` runs are used.

  Args:
    prompt: few-shot Chain-of-Code prompt.
    query: the question to answer.
    max_trials: maximum number of execute-and-patch rounds (default 20).

  Raises:
    Any exception from the generated code that the tracer could not attribute to
    a line inside get_answer (nothing can be patched in that case).
  """
  global errors
  global error_lineno
  global lines
  global trace_lines
  global last_state

  # Start from clean tracer state so a previous call cannot leak stale line
  # indices, error markers or trace lines into this one.
  errors = {}
  error_lineno = None
  trace_lines = []
  last_state = None

  coc_response = query_llm(prompt + "\n\n" + query, max_tokens=1024)
  code_to_run = coc_response.split(CODE_START_TOKEN)[1].split(CODE_END_TOKEN)[0].strip()

  # Wrap the generated code in get_answer(), indenting each line by two spaces.
  # show_trace relies on both the function name and that indentation.
  body = "\n".join("  " + code_line for code_line in code_to_run.split("\n"))
  code_to_run = f"""def get_answer():
{body}
  return answer
answer = get_answer()"""
  lines = code_to_run.split("\n")

  # Stays None if every trial fails, instead of being unbound at print time.
  coc_answer = None
  for _ in range(max_trials):
    if sys.gettrace() is None:
      sys.settrace(show_trace)
    assert sys.gettrace() is not None, "get trace is None"
    # Fresh namespace per trial; exec binds get_answer and answer here.
    local_vars = {}
    try:
      exec(code_to_run, globals(), local_vars)
      coc_answer = local_vars["answer"]
      assert coc_answer is not None
      break
    except Exception:
      if error_lineno is None:
        # The tracer did not attribute the failure to a line in get_answer, so
        # there is nothing to patch; surface the real error.
        raise
      line = lines[error_lineno]
      num_indent = len(line) - len(line.lstrip())

      # The patch inserts one line ("try:") before the failing line and two
      # ("except:", "pass") after it, so every recorded index after the failing
      # line moves down by 3. This matters in loops, where a line may fail
      # after a later line was already patched.
      errors = {(k + 3 if k > error_lineno else k): v for k, v in errors.items()}
      # After "try:" is inserted, the failing line itself sits at error_lineno + 1.
      errors[error_lineno + 1] = line

      lines[error_lineno] = " " * 2 + line
      lines.insert(error_lineno, " " * num_indent + "try:")
      lines.insert(error_lineno + 2, " " * num_indent + "except:")
      lines.insert(error_lineno + 3, " " * (num_indent + 2) + "pass")
      code_to_run = "\n".join(lines)

      # Reset per-run tracer state before the next trial.
      error_lineno = None
      trace_lines = []
      last_state = None

  print_result('CoC', coc_response, coc_answer)

In [None]:
# This cell runs for roughly one minute.
# show_trace advances this global progress bar once per LLM-simulated line.
# NUM_PLACES presumably matches the number of places listed in `query`; it only
# sets the bar's total.
NUM_PLACES = 188
progress_bar = tqdm(total=NUM_PLACES)
try:
  evaluate_coc(coc_prompt, query)
finally:
  # Close the bar even if evaluation raises, so no dangling bar is left behind.
  progress_bar.close()

# Interactive Demos

## Code Generation

In [None]:
# Few-shot prompt for the code-generation demo. It ends with "Q: " so the user's
# question can be appended directly; helper calls such as get_adjective_type are
# pseudocode for the LMulator to resolve.
code_generation_prompt = """Q: Which sentence has the correct adjective order:
Options:
(A) rubber terrible ship
(B) terrible rubber ship
A:
# CODE START
import numpy as np
options = {"(A)": "rubber terrible ship", "(B)": "terrible rubber ship"}
priority = {"opinion": 1, "size": 2, "age": 3, "shape": 4, "color": 5, "origin": 6, "material": 7, "purpose": 8}
valid_types = list(priority.keys())
scores = []
for option, sentence in options.items():
  adjectives = sentence.split(" ")[:-1]
  order = [priority[get_adjective_type(adjective, valid_types, ret_type=str)] for adjective in adjectives]
  scores.append([order[i+1] > order[i] for i in range(len(order) - 1)].count(True))
answer = list(options.keys())[np.argmax(scores)]
# CODE END

Q: Today is Christmas Eve of 1937. What is the date 10 days ago in MM/DD/YYYY?
A:
# CODE START
import datetime
options = {"12/14/2026": "(A)", "12/14/1950": "(B)", "12/14/2007": "(C)", "12/14/1937": "(D)", "07/14/1938": "(E)", "12/14/1988": "(F)"}
today = datetime.date(year=1937, month=12, day=24)
date = today - datetime.timedelta(days=10)
answer = date.strftime("%m/%d/%Y")
# CODE END

Q: Recommend a movie similar to Star Wars Episode IV - A New Hope, Indiana Jones and the Last Crusade, Star Wars Episode V - The Empire Strikes Back, The Big Lebowski:
A:
# CODE START
ref_movies = ["Star Wars Episode IV - A New Hope", "Indiana Jones and the Last Crusade", "Star Wars Episode V - The Empire Strikes Back", "The Big Lebowski"]
ref_movie_infos = get_movie_genre_and_year(ref_movies, ret_type="list[tuple[str, int]]")
answer = get_most_similar_movie(ref_movies, ref_movie_infos, ret_type="str")
# CODE END

Q: """

In [None]:
def evaluate_code_generation(prompt, query):
  """Print the code the model writes for `query` under the few-shot `prompt`.

  Generation stops before the model starts another "Q:" example.
  """
  full_prompt = prompt + query
  generated_code = query_llm(full_prompt, max_tokens=256, stop=["\n\nQ:"])
  print(generated_code)

This interactive demo showcases the code generation aspect of Chain of Code.

Aside from the three examples below, feel free to try any reasoning task of your choice by typing it into the input box on the right-hand side.
The model should output reasonable Python code/pseudocode that solves the task.

In [None]:
# Suggested inputs; copy one into the `query` form field below to try it.
example_1 = "What type of food does two concentric circles look like?"
example_2 = "If I stack three Eiffel Towers on top of each other, how tall is the new tower?"
example_3 = "What are 10 smallest countries in the world?"
# Colab form field. NOTE: this rebinds the global `query` used by the earlier demo cells.
query = "What type of food does two concentric circles look like?" #@param {allow-input: true, type:"string"}
# Append the answer cue so the model starts writing code right away.
query += "\nA:\n"

evaluate_code_generation(code_generation_prompt, query)

## Code Execution (LMulator)

In [None]:
# Few-shot LMulator prompt: each trace pairs a "line:" with the "delta state:" it
# produces. It ends with an open trace and "line: " so the user's line can be
# appended and the model completes its delta state.
lmulator_prompt = """# TRACE START
state: {}
line: adjective = "red"
delta state: {'adjective': 'red'}
line: valid_types = {"opinion", "size", "age", "shape", "color", "origin", "material", "purpose"}
delta state: {'valid_types': {'opinion', 'size', 'age', 'shape', 'color', 'origin', 'material', 'purpose'}}
line: adj_type = get_adjective_type(adjective, valid_types, ret_type=str)
delta state: {'adj_type': 'color'}
# TRACE END

# TRACE START
state: {}
line: obj1 = "soda can"
delta state: {'obj1': 'soda can'}
line: is_obj1_recyclable = is_recyclable(obj1, ret_type=bool)
delta state: {'is_obj1_recyclable': True}
line: obj2 = "fruit"
delta state: {'obj2': 'fruit'}
line: is_obj2_recyclable = is_recyclable(obj2, ret_type=bool)
delta state: {'is_obj2_recyclable': False}
# TRACE END

# TRACE START
state: {}
line: num1 = 23
delta state: {'num1': 23}
line: num2 = 52
delta state: {'num2': 52}
line: sum_of_two = num1 + num2
delta state: {'sum_of_two': 75}
line: greated_than_one_hundred = is_greater(sum_of_two, 100)
delta state: {'greated_than_one_hundred': False}
# TRACE END

# TRACE START
state: {}
line: """

In [None]:
def evaluate_lmulator(prompt, query):
  """Print the delta state the model predicts for executing the code line `query`.

  Generation stops at the next "line:" or at the end of the trace.
  """
  stop_sequences = ["\nline:", "\n# TRACE END"]
  simulated_state = query_llm(prompt + query, max_tokens=256, stop=stop_sequences)
  print(simulated_state)

This interactive demo showcases the code execution aspect of Chain of Code (LMulator).

Aside from the three examples below, feel free to type any Python code/pseudocode into the input box on the right-hand side and have the LMulator simulate its execution.

Our model should be able to output a reasonable delta state as the result of simulated code execution.

In [None]:
# Suggested lines of pseudocode; copy one into the `query` form field below to try it.
example_1 = "extinct_animal = which_animal_is_extinct(['panda', 'dinosaur', 'pig'], ret_type=str)"
example_2 = "divisible_test = is_divisible(divident=2142, divisor=17, ret_type=bool)"
example_3 = "distance_between_cities = get_distance(city1='San Francisco', city2='New York', unit='kilometers', ret_type=float)"
# Colab form field holding the single line to simulate.
query = "extinct_animal = which_animal_is_extinct(['panda', 'dinosaur', 'pig'], ret_type=str)" #@param {allow-input: true, type:"string"}

evaluate_lmulator(lmulator_prompt, query)