In [None]:
!pip install openai
!pip install cascades
!pip install duckduckgo-search # or your preferred web query api

In [None]:
import os
import openai

# Colab form field. Leave it blank to keep an OPENAI_API_KEY that is already
# present in the environment instead of overwriting it with an empty string.
api_key = "" # @param
if api_key:
  os.environ['OPENAI_API_KEY'] = api_key

# Read the key itself. The previous `os.getenv(os.environ[...])` looked up an
# environment variable *named after the key value*, which yields None.
openai.api_key = os.environ.get('OPENAI_API_KEY')

# Report only whether a key is configured; never echo the secret into the
# notebook output, where it would be saved and shared with the file.
print('OPENAI_API_KEY configured:', bool(openai.api_key))

In [70]:
import cascades as cc

In [42]:
"""Sampling from OpenAI api."""
import bisect
import dataclasses
import functools
import os
from typing import Iterable, Optional, Text
import uuid

from cascades._src.distributions import base as dists
import openai
import jax

# Configure the OpenAI client from the environment; os.getenv returns None
# when OPENAI_API_KEY is not set.
openai.api_key = os.getenv('OPENAI_API_KEY')


# TODO(ddohan): Persist cache to disk
@functools.lru_cache()
def cached_completion(rng=None, **kwargs):
  """Memoized wrapper around `openai.Completion.create`.

  Args:
    rng: Not sent to the API. It only takes part in the cache key, so calls
      with a different `rng` are cached as separate entries.
    **kwargs: Forwarded unchanged to `openai.Completion.create`. All values
      must be hashable because they are part of the cache key.

  Returns:
    The value returned by `openai.Completion.create`.
  """
  # Discard the seed explicitly: it has already done its job as a cache key.
  del rng
  response = openai.Completion.create(**kwargs)
  return response


@dataclasses.dataclass(eq=True, frozen=True)
class GPT(dists.Distribution):
  """Sample a String from GPT.

  Attributes:
    prompt: Text the model is asked to continue.
    stop: Stop sequence(s) forwarded to the API when sampling.
    engine: Model name sent as the `model` argument. It carries no type
      annotation, so it is a plain class attribute rather than a dataclass
      field (it is not a constructor argument).
    temperature: Forwarded to the API unchanged.
    max_tokens: Maximum number of tokens requested when sampling.
    top_p: Forwarded to the API unchanged.
    frequency_penalty: Forwarded to the API unchanged.
    presence_penalty: Forwarded to the API unchanged.
  """

  prompt: Text = ''

  stop: Optional[Iterable[Text]] = ('\n',)

  engine = 'davinci-codex'
  temperature: float = 0.7
  max_tokens: int = 128
  top_p: float = .95
  frequency_penalty: int = 0
  presence_penalty: int = 0

  def sample(self, rng=None, raw=False):
    """Sample a value from the distribution given rng key.

    Args:
      rng: Optional random key. None draws a fresh uuid so the request is
        never served from the cache; an int is used as-is; a Jax key is
        folded into an int. The value only acts as a cache key for
        `cached_completion`.
      raw: If True, also return the first choice of the OpenAI response.

    Returns:
      a RandomSample or, if raw is True, a `(total_logprob, choice)` tuple
      where `choice` is the first entry of the response's `choices`.

    Raises:
      ValueError: If rng is neither None, an integer nor a Jax array.
    """
    if rng is None:
      rng = uuid.uuid4().hex
    elif not isinstance(rng, int):
      # `jax.numpy.DeviceArray` only exists in older Jax releases and
      # `jax.Array` only in newer ones, so look up whichever is available.
      array_type = (getattr(jax, 'Array', None) or
                    getattr(jax.numpy, 'DeviceArray', None))
      if array_type is None or not isinstance(rng, array_type):
        raise ValueError(f'RNG must be an integer or Jax key. Was {rng}')
      rng = int(jax.random.randint(rng, (), 0, 1_000_000_000))

    # `cached_completion` is memoized with lru_cache, so every argument must
    # be hashable: freeze list-like stop sequences into a tuple.
    stop = self.stop
    if stop is not None and not isinstance(stop, str):
      stop = tuple(stop)

    result = cached_completion(
        rng=rng,
        model=self.engine,
        prompt=self.prompt,
        temperature=self.temperature,
        max_tokens=self.max_tokens,
        top_p=self.top_p,
        frequency_penalty=self.frequency_penalty,
        presence_penalty=self.presence_penalty,
        logprobs=0,
        stop=stop,
        echo=False,
    )['choices'][0]

    completion = result['text']
    logprobs = result['logprobs']
    offsets = logprobs['text_offset']

    # Offsets may be counted from the start of the prompt rather than from the
    # start of the completion, so measure them relative to the first returned
    # token. When the first offset is 0 this is identical to using the raw
    # offsets.
    base = offsets[0] if offsets else 0
    end = bisect.bisect_left(offsets, base + len(completion) + 1)
    # Skip tokens for which no logprob was reported (None).
    total_logprob = sum(
        lp for lp in logprobs['token_logprobs'][:end] if lp is not None)
    if raw:
      return (total_logprob, result)
    return dists.RandomSample(log_p=total_logprob, value=completion)

  def log_prob(self, value, raw=False):
    """Get log prob of completion.

    The prompt and `value` are concatenated and sent with `echo=True` and
    `max_tokens=0`, so the response carries logprobs for the supplied text
    instead of newly generated tokens.

    Args:
      value: Completion to score.
      raw: If True, also return the first choice of the OpenAI response.

    Returns:
      float logprob or, if raw is True, a `(total_logprob, choice)` tuple.
    """
    text = self.prompt + value
    span = (len(self.prompt), len(text))
    result = cached_completion(
        rng=0,
        model=self.engine,
        prompt=text,
        temperature=self.temperature,
        max_tokens=0,
        echo=True,
        top_p=self.top_p,
        frequency_penalty=self.frequency_penalty,
        presence_penalty=self.presence_penalty,
        logprobs=0,
    )['choices'][0]
    logprobs = result['logprobs']
    # First token whose start offset is <= the end of the prompt, i.e. the
    # token that contains the first character of `value`.
    start = bisect.bisect(logprobs['text_offset'], span[0]) - 1
    end = bisect.bisect_left(logprobs['text_offset'], span[1])
    excerpt = text[span[0]:span[1]]
    joined = ''.join(logprobs['tokens'][start:end])
    # Sanity check: the selected tokens must cover the scored text.
    assert excerpt in joined
    # Tokens without a reported logprob (None) cannot be summed; skip them.
    total_logprob = sum(
        lp for lp in logprobs['token_logprobs'][start:end] if lp is not None)
    if raw:
      return (total_logprob, result)
    return total_logprob

In [71]:
from cascades._src.distributions import gpt
# Expose the library's GPT distribution under the `cc` namespace.
cc.GPT = gpt.GPT
# cc.GPT = GPT # temp hack until the api is updated

In [72]:
# Smoke test: draw one completion for a short prompt, passing 0 as the rng.
cc.GPT(prompt='Hello').sample(0)

RandomSample(log_p=-17.8193744, value=' world!", 7, 0.6)', dist=None)

# Web Search

In [73]:
from duckduckgo_search import ddg

In [74]:
# Run one example query to see what the raw search results look like.
keywords = 'How many legs does a rabbit have?'
results = ddg(
    keywords,
    region='wt-wt',
    safesearch='Moderate',
    time='y',
    max_results=28,
)
print(results)

[{'title': 'How Many Legs Does A Rabbit Have - Realonomics', 'href': 'https://aeries.norushcharge.com/how-many-legs-does-a-rabbit-have/', 'body': "Does a rabbit have 2 or 4 legs? A rabbit has four legs. The two in the front are called the forelegs and the two in the rear are called the hind legs. … Muscling in the hind legs is much more extensive than in the forelegs. A rabbit's body is broken into two sections called the forequarter and the hindquarter. How many feet does a rabbit have?"}, {'title': 'How Many Legs Do Rabbits Have - Realonomics', 'href': 'https://aeries.norushcharge.com/how-many-legs-do-rabbits-have/', 'body': 'How Many Legs Do Rabbits Have? four legs Do rabbits have legs? The hind limbs of the rabbit are longer than the front limbs. This allows them to produce their hopping form of locomotion. Longer hind limbs are more capable of producing faster speeds. Hares which have longer legs than cottontail rabbits are able to move considerably faster.'}, {'title': "Rabbit An

In [75]:
# Inspect the fields of one search hit and the passage text stored under 'body'.
results[0].keys(), results[0]['body']

(dict_keys(['title', 'href', 'body']),
 "Does a rabbit have 2 or 4 legs? A rabbit has four legs. The two in the front are called the forelegs and the two in the rear are called the hind legs. … Muscling in the hind legs is much more extensive than in the forelegs. A rabbit's body is broken into two sections called the forequarter and the hindquarter. How many feet does a rabbit have?")

In [76]:
import functools
@functools.lru_cache(maxsize=1000)
def get_passages(query, num_passages=5, output=None):
  """Run a web search for `query`, memoized per argument combination.

  Args:
    query: Search string, forwarded to `ddg` as `keywords`.
    num_passages: Forwarded to `ddg` as `max_results`.
    output: Forwarded to `ddg` as `output` (json, csv, print).

  Returns:
    Whatever `ddg` returns for these arguments.
  """
  return ddg(keywords=query, max_results=num_passages, output=output)

In [77]:
@cc.model
def qa_with_search(question):
  """Answer a question using the top web search passage as context.

  Args:
    question: Question text. It is used both as the search query and inside
      the few-shot prompt.

  Yields:
    A log entry named 'context' holding the retrieved passage, then the GPT
    distribution whose sample is recorded under the name 'answer'.

  Returns:
    The sampled answer text.
  """
  passages = get_passages(question, num_passages=1)
  # The search may come back with nothing; fall back to an empty context
  # instead of failing on `[0]`.
  context = passages[0]['body'] if passages else ''
  yield cc.log(context, name='context')
  prompt = f"""The answer sheet for the questions is below:

Question: Which planet is the hottest in the solar system?
Context: It has a strong greenhouse effect, similar to the one we experience on Earth. Because of this, Venus is the hottest planet in the solar system. The surface of Venus is approximately 465°C! Fourth from the Sun, after Earth, is Mars.
Answer: Venus

Question: Which country produces the most coffee in the world?
Context: With the rise in popularity of coffee among Europeans, Brazil became the world's largest producer in the 1840s and has been ever since. Some 300,000 coffee farms are spread over the Brazilian landscape.
Answer: Brazil

Question: {question}
Context: {context}
Answer:"""
  answer = yield GPT(prompt=prompt, stop='\n', name='answer')
  return answer.value

@cc.model
def qa(question):
  """Answer a question from a few-shot prompt alone, with no retrieved context.

  Args:
    question: Question text inserted as the final prompt entry.

  Yields:
    The GPT distribution whose sample is recorded under the name 'answer'.

  Returns:
    The sampled answer text.
  """
  few_shot_prompt = f"""Answer the questions below given a document from the web:

Question: What is often seen as the smallest unit of memory?
Answer: kilobyte

Question: Which planet is the hottest in the solar system?
Answer: Venus

Question: Which country produces the most coffee in the world?
Answer: Brazil

Question: {question}
Answer:"""
  # A single newline ends the completion, keeping the answer to one line.
  sampled = yield GPT(prompt=few_shot_prompt, stop='\n', name='answer')
  return sampled.value

In [78]:
# Time one answer sampled without any web context, then display the record.
%time no_search = qa.sample('Which bones are babies born without?')
no_search

CPU times: user 7.03 ms, sys: 2.23 ms, total: 9.26 ms
Wall time: 10.4 ms


Record(
  answer: Sample(name='answer', score=0, value=' Clavicle', should_stop=False, replayed=False, metadata=None)
  return:  Clavicle
)

In [79]:
# Time one answer sampled with the retrieved passage as context, then display the record.
%time with_search = qa_with_search.sample('Which bones are babies born without?')
with_search

CPU times: user 30 ms, sys: 6.58 ms, total: 36.6 ms
Wall time: 2.13 s


Record(
  context: Log(name='context', score=None, value='One example of a bone that babies are born without: the kneecap (or patella). The kneecap starts out as cartilage and starts significantly hardening into bone between the ages of 2 and 6 years old. In most cases, several areas of cartilage in the knee begin to harden at the same time and eventually fuse together to form one solid bone.', should_stop=False, replayed=False, metadata=None)
  answer: Sample(name='answer', score=0, value=' patella', should_stop=False, replayed=False, metadata=None)
  return:  patella
)

In [80]:
def compare(question):
  """Sample an answer to `question` from both models.

  Args:
    question: Question text passed to each model's `sample`.

  Returns:
    A `(no_search, search)` tuple: the result of `qa.sample` followed by the
    result of `qa_with_search.sample`.
  """
  # Tuple elements are evaluated left to right: plain model first, then search.
  return qa.sample(question), qa_with_search.sample(question)

In [81]:
# Show both models' answers to the same question side by side.
compare('Which bone are babies born without')

(Record(
   answer: Sample(name='answer', score=0, value=' Stapes', should_stop=False, replayed=False, metadata=None)
   return:  Stapes
 ), Record(
   context: Log(name='context', score=None, value='Firstly, a newborn has several "proto bones" (ie cartilagneous precusors) which are not bones at all - yet. During development, many bones consist of "several bones, joined by cartilage" which will become one bone eventually. Take an example - the femur. This consists of at least five bones until total fusion aged perhaps 17yo.', should_stop=False, replayed=False, metadata=None)
   answer: Sample(name='answer', score=0, value=' rib', should_stop=False, replayed=False, metadata=None)
   return:  rib
 ))

In [82]:
from concurrent import futures
# Shared thread pool with 16 workers for the parallel sampling cells below.
pool = futures.ThreadPoolExecutor(max_workers=16)

In [None]:
# Submit 16 samples of the search-backed model to the thread pool.
Q = 'Which bone is a baby born without?'
rs = qa_with_search.sample_parallel(pool, Q, n=16)
rs[:3]  # Records that haven't been run yet.

In [None]:
# Get the result of the first record's future. The 20 is the argument to
# `result()`, not a result count -- presumably a timeout in seconds, if
# `.future` is a concurrent.futures.Future (TODO confirm).
rs[0].future.result(20)

In [None]:
# Collect the return value of every sampled record.
[r.return_value for r in rs]

In [None]:
%%time
# Same parallel sampling for the model without search, timed as a whole cell.
# NOTE(review): only the last record's future is awaited before printing;
# whether the earlier records are guaranteed to be finished by then is not
# visible here -- confirm.
rs = qa.sample_parallel(pool, Q, n=16)
rs[-1].future.result(20)
print([r.return_value for r in rs])