<a href="https://colab.research.google.com/github/madaan/memprompt/blob/main/CompletionAndChat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:

!pip install openai

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting openai
  Downloading openai-0.27.2-py3-none-any.whl (70 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.1/70.1 KB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting aiohttp
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (264 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m264.6/264.6 KB[0m [31m21.8 MB/s[0m eta [36m0:00:00[0m
Collecting multidict<7.0,>=4.5
  Downloading multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (114 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m114.2/114.2 KB[0m [31m14.1 MB/s[0m eta [36m0:00

In [5]:
import functools
import os
import random
import time
from collections import Counter
from typing import Any, Dict

import openai


# from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = (openai.error.RateLimitError,),
):
    """Retry a function with exponential backoff.

    Intended to be used as a decorator. The returned wrapper calls ``func`` and,
    whenever it raises one of ``errors``, sleeps and tries again.

    Args:
        func: The callable to protect.
        initial_delay: Starting delay in seconds. It is multiplied *before* the
            first sleep, so the first wait is already
            ``initial_delay * exponential_base`` (times the jitter factor).
        exponential_base: Factor by which the delay grows after each failure.
        jitter: When true, the growth factor is additionally multiplied by a
            random value in ``[1, 2)``.
        max_retries: Number of retries allowed after the first failed attempt.
        errors: Exception types that trigger a retry. Anything else propagates
            immediately.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.

    Raises:
        Exception: When more than ``max_retries`` retryable failures occurred;
            the last retryable error is attached as ``__cause__``.
    """

    @functools.wraps(func)  # keep the wrapped function's name/docstring for introspection
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a call succeeds, the retry budget is exhausted, or a
        # non-retryable exception escapes.
        while True:
            try:
                return func(*args, **kwargs)
            except errors as e:
                num_retries += 1

                if num_retries > max_retries:
                    # Chain the original error so the root cause stays in the traceback.
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    ) from e

                # `jitter` is a bool: False contributes 0, True a random factor in [1, 2).
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper

class BaseAPIWrapper:
    """Interface shared by the endpoint-specific wrappers.

    Every method is a static stub that raises ``NotImplementedError``; concrete
    wrappers override them for a particular OpenAI endpoint.
    """

    @staticmethod
    def call(
        prompt: str,
        max_tokens: int,
        engine: str,
        stop_token: str,
        temperature: float,
        num_completions: int = 1,
    ) -> dict:
        """Query the API and return the raw response object.

        Args:
            prompt: Text sent to the model.
            max_tokens: Upper bound on generated tokens per completion.
            engine: Name of the model to query.
            stop_token: Sequence at which generation should stop.
            temperature: Sampling temperature.
            num_completions: How many completions to request.
        """
        raise NotImplementedError()

    @staticmethod
    def get_first_response(response) -> str:
        """Return the text of the first choice in ``response``."""
        raise NotImplementedError()

    @staticmethod
    def get_majority_answer(response) -> str:
        """Return the most frequent answer text among the choices in ``response``."""
        raise NotImplementedError()

    @staticmethod
    def get_all_responses(response) -> list:
        """Return the text of every choice in ``response`` as a list."""
        raise NotImplementedError()


class CompletionAPIWrapper(BaseAPIWrapper):
    """Wrapper around the text completion endpoint (``openai.Completion``)."""

    @staticmethod
    @retry_with_exponential_backoff
    def call(
        prompt: str,
        max_tokens: int,
        engine: str,
        stop_token: str,
        temperature: float,
        num_completions: int = 1,
    ) -> dict:
        """Calls the completion API.

        if the num_completions is > 2, we call the API multiple times. This is to prevent
        overflow issues that can occur when the number of completions is too large.

        Args:
            prompt: Text sent to the model.
            max_tokens: Upper bound on generated tokens per completion.
            engine: Model name, passed as ``engine``.
            stop_token: Single stop sequence; sent to the API as a one-element list.
            temperature: Sampling temperature.
            num_completions: Total number of completions wanted.

        Returns:
            The API response. When batching was needed, this is the response of the
            first batch with the ``choices`` of later batches appended; all other
            fields (id, usage, per-choice ``index``) are not merged.
        """
        if num_completions > 2:
            response_combined = None
            for num_done in range(0, num_completions, 2):
                # Recurse through the decorated entry point so every batch of at most
                # two completions gets its own retry/backoff handling.
                response = CompletionAPIWrapper.call(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    engine=engine,
                    stop_token=stop_token,
                    temperature=temperature,
                    num_completions=min(num_completions - num_done, 2),
                )
                if response_combined is None:
                    response_combined = response
                else:
                    response_combined["choices"] += response["choices"]
            return response_combined

        return openai.Completion.create(
            engine=engine,
            prompt=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=1,
            stop=[stop_token],
            # logprobs=3,
            n=num_completions,
        )

    @staticmethod
    def get_first_response(response) -> str:
        """Returns the first response from the list of responses."""
        return response["choices"][0]["text"]

    @staticmethod
    def get_majority_answer(response) -> str:
        """Returns the majority answer from the list of responses.

        If the two most frequent answers are tied, the first response is returned
        instead. A single distinct answer (one completion, or a unanimous vote) is
        returned directly.

        Raises:
            IndexError: If ``response["choices"]`` is empty.
        """
        texts = [choice["text"] for choice in response["choices"]]
        top_two = Counter(texts).most_common(2)
        # Only compare against a runner-up when one exists; indexing [1] blindly
        # raised IndexError whenever all completions agreed.
        if len(top_two) > 1 and top_two[0][1] == top_two[1][1]:
            return CompletionAPIWrapper.get_first_response(response)

        return top_two[0][0]

    @staticmethod
    def get_all_responses(response) -> list:
        """Returns the list of responses."""
        return [choice["text"] for choice in response["choices"]]


class ChatGPTAPIWrapper(BaseAPIWrapper):
    """Wrapper around the chat endpoint (``openai.ChatCompletion``)."""

    @staticmethod
    @retry_with_exponential_backoff
    def call(
        prompt: str,
        max_tokens: int,
        engine: str,
        stop_token: str,
        temperature: float,
        num_completions: int = 1,
    ) -> dict:
        """Calls the Chat API.

        if the num_completions is > 2, we call the API multiple times. This is to prevent
        overflow issues that can occur when the number of completions is too large.

        Args:
            prompt: Text sent as the single user message, after a fixed system message.
            max_tokens: Upper bound on generated tokens per completion.
            engine: Model name, passed as ``model``.
            stop_token: Single stop sequence; sent to the API as a one-element list.
            temperature: Sampling temperature.
            num_completions: Total number of completions wanted.

        Returns:
            The API response. When batching was needed, this is the response of the
            first batch with the ``choices`` of later batches appended; all other
            fields (id, usage, per-choice ``index``) are not merged.
        """
        if num_completions > 2:
            response_combined = None
            for num_done in range(0, num_completions, 2):
                # Recurse through the decorated entry point so every batch of at most
                # two completions gets its own retry/backoff handling.
                response = ChatGPTAPIWrapper.call(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    engine=engine,
                    stop_token=stop_token,
                    temperature=temperature,
                    num_completions=min(num_completions - num_done, 2),
                )
                if response_combined is None:
                    response_combined = response
                else:
                    response_combined["choices"] += response["choices"]
            return response_combined

        # Built only on the path that actually talks to the API.
        messages = [
            {
                "role": "system",
                "content": "You are ChatGPT, a large language model trained by OpenAI.",
            },
            {"role": "user", "content": prompt},
        ]
        return openai.ChatCompletion.create(
            model=engine,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=1,
            stop=[stop_token],
            # logprobs=3,
            n=num_completions,
        )

    @staticmethod
    def get_first_response(response) -> str:
        """Returns the first response from the list of responses."""
        return response["choices"][0]["message"]["content"]

    @staticmethod
    def get_majority_answer(response) -> str:
        """Returns the majority answer from the list of responses.

        If the two most frequent answers are tied, the first response is returned
        instead. A single distinct answer (one completion, or a unanimous vote) is
        returned directly.

        Raises:
            IndexError: If ``response["choices"]`` is empty.
        """
        contents = [choice["message"]["content"] for choice in response["choices"]]
        top_two = Counter(contents).most_common(2)
        # Only compare against a runner-up when one exists; indexing [1] blindly
        # raised IndexError whenever all completions agreed.
        if len(top_two) > 1 and top_two[0][1] == top_two[1][1]:
            return ChatGPTAPIWrapper.get_first_response(response)

        return top_two[0][0]

    @staticmethod
    def get_all_responses(response) -> list:
        """Returns the list of responses."""
        return [choice["message"]["content"] for choice in response["choices"]]


class OpenaiAPIWrapper:
    """Facade that routes to the chat or completion wrapper based on the model name."""

    # Model names served through the chat endpoint; every other name is treated
    # as a completion model.
    chat_engines = ["gpt-3.5-turbo", "gpt-4", "gpt-3.5-turbo-0301", "gpt-4-0314"]

    @staticmethod
    def get_api_wrapper(engine: str) -> "type[BaseAPIWrapper]":
        """Return the wrapper *class* (not an instance) that handles ``engine``."""
        if engine in OpenaiAPIWrapper.chat_engines:
            return ChatGPTAPIWrapper
        return CompletionAPIWrapper

    @staticmethod
    def call(
        prompt: str,
        max_tokens: int,
        engine: str,
        stop_token: str,
        temperature: float,
        num_completions: int = 1,
    ) -> dict:
        """Send ``prompt`` to ``engine`` through the matching wrapper and return the raw response."""
        api_wrapper = OpenaiAPIWrapper.get_api_wrapper(engine)
        return api_wrapper.call(
            prompt=prompt,
            max_tokens=max_tokens,
            engine=engine,
            stop_token=stop_token,
            temperature=temperature,
            num_completions=num_completions,
        )

    # The extractors below pick the wrapper from the ``"model"`` field of the
    # response, so that field must match an entry of ``chat_engines`` for chat
    # responses to be parsed with the chat layout.

    @staticmethod
    def get_first_response(response) -> str:
        """Return the text of the first choice in ``response``."""
        api_wrapper = OpenaiAPIWrapper.get_api_wrapper(response["model"])
        return api_wrapper.get_first_response(response)

    @staticmethod
    def get_majority_answer(response) -> str:
        """Return the most frequent answer among the choices in ``response``."""
        api_wrapper = OpenaiAPIWrapper.get_api_wrapper(response["model"])
        return api_wrapper.get_majority_answer(response)

    @staticmethod
    def get_all_responses(response) -> list:
        """Return the text of every choice in ``response`` as a list."""
        api_wrapper = OpenaiAPIWrapper.get_api_wrapper(response["model"])
        return api_wrapper.get_all_responses(response)



In [7]:

def test_completion():
    """Smoke-test the completion path: request three completions and print the results."""
    # Same prompt text as a triple-quoted block, spelled out line by line so the
    # whitespace-only second line stays visible.
    prompt = (
        "Optimize the following Python code:\n"
        "  \n"
        "# Start of code\n"
        "n = int(input())\n"
        "result = 0\n"
        "for i in range(1, n + 1):\n"
        "  result += i\n"
        "return result\n"
    )
    request = {
        "prompt": prompt,
        "max_tokens": 300,
        "engine": "text-davinci-003",
        "stop_token": "Optimize the following Python code:\n\n",
        "temperature": 0.7,
        "num_completions": 3,
    }
    response = OpenaiAPIWrapper.call(**request)

    # Raw response first, then the two derived views of it.
    print(response)
    for extract in (
        OpenaiAPIWrapper.get_first_response,
        OpenaiAPIWrapper.get_majority_answer,
    ):
        print(extract(response))


def test_chat():
    """Smoke-test the chat path: request three completions and print the results."""
    # Same prompt text as a triple-quoted block, spelled out line by line so the
    # whitespace-only second line stays visible.
    prompt = (
        "Optimize the following Python code:\n"
        "  \n"
        "# Start of code\n"
        "n = int(input())\n"
        "result = 0\n"
        "for i in range(1, n + 1):\n"
        "  result += i\n"
        "return result\n"
    )
    request = {
        "prompt": prompt,
        "max_tokens": 300,
        "engine": "gpt-3.5-turbo",
        "stop_token": "End of code",
        "temperature": 0.7,
        "num_completions": 3,
    }
    response = OpenaiAPIWrapper.call(**request)

    # Raw response first, then the two derived views of it.
    print(response)
    for extract in (
        OpenaiAPIWrapper.get_first_response,
        OpenaiAPIWrapper.get_majority_answer,
    ):
        print(extract(response))



In [8]:


# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# typed into (and saved with) the notebook. The placeholder is only a fallback
# and must be replaced if the variable is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY", "sk-???")

In [9]:

if __name__ == "__main__":
    # Exercise both endpoints in order, announcing each one before it runs.
    for banner, run_check in (
        ("Testing completion API", test_completion),
        ("Testing chat API", test_chat),
    ):
        print(banner)
        run_check()

Testing completion API
{
  "choices": [
    {
      "finish_reason": null,
      "index": 0,
      "logprobs": null,
      "text": "# End of code\n\ndef calculate_sum(n): \n  return (n * (n + 1)) // 2\n\nn = int(input())\nresult = calculate_sum(n) \nprint(result)"
    },
    {
      "finish_reason": "stop",
      "index": 1,
      "logprobs": null,
      "text": "# End of code\n\ndef sum(n):\n  return (n*(n+1))//2\n\nn = int(input())\nresult = sum(n)\nprint(result)"
    },
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "text": "# End of code\n\n# Optimized code\nn = int(input())\nreturn n * (n+1) // 2"
    }
  ],
  "created": 1680366579,
  "id": "cmpl-70YOhXoWRwEBhVBiMWGozGebWDLkL",
  "model": "text-davinci-003",
  "object": "text_completion",
  "usage": {
    "completion_tokens": 94,
    "prompt_tokens": 46,
    "total_tokens": 140
  }
}
# End of code

def calculate_sum(n): 
  return (n * (n + 1)) // 2

n = int(input())
result = calculate_sum(n) 
