Merge pull request #346 from n3d1117/feature/support-functions
Support functions (aka plugins)
n3d1117 committed Aug 4, 2023
2 parents 3d2231a + 3e4c228 commit 30d441d
Showing 23 changed files with 1,429 additions and 142 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ __pycache__
.DS_Store
/usage_logs
venv
/.cache
115 changes: 79 additions & 36 deletions README.md

Large diffs are not rendered by default.

21 changes: 16 additions & 5 deletions bot/main.py
@@ -3,7 +3,8 @@

from dotenv import load_dotenv

from openai_helper import OpenAIHelper, default_max_tokens
from plugin_manager import PluginManager
from openai_helper import OpenAIHelper, default_max_tokens, are_functions_available
from telegram_bot import ChatGPTTelegramBot


@@ -27,6 +28,7 @@ def main():

# Setup configurations
model = os.environ.get('OPENAI_MODEL', 'gpt-3.5-turbo')
functions_available = are_functions_available(model=model)
max_tokens_default = default_max_tokens(model=model)
openai_config = {
'api_key': os.environ['OPENAI_API_KEY'],
@@ -41,14 +43,18 @@ def main():
'temperature': float(os.environ.get('TEMPERATURE', 1.0)),
'image_size': os.environ.get('IMAGE_SIZE', '512x512'),
'model': model,
'enable_functions': os.environ.get('ENABLE_FUNCTIONS', str(functions_available)).lower() == 'true',
'functions_max_consecutive_calls': int(os.environ.get('FUNCTIONS_MAX_CONSECUTIVE_CALLS', 10)),
'presence_penalty': float(os.environ.get('PRESENCE_PENALTY', 0.0)),
'frequency_penalty': float(os.environ.get('FREQUENCY_PENALTY', 0.0)),
'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
'show_plugins_used': os.environ.get('SHOW_PLUGINS_USED', 'false').lower() == 'true',
}

# log deprecation warning for old budget variable names
# old variables are caught in the telegram_config definition for now
# remove support for old budget names at some point in the future
if openai_config['enable_functions'] and not functions_available:
logging.error(f'ENABLE_FUNCTIONS is set to true, but the model {model} does not support functions. '
f'Please set ENABLE_FUNCTIONS to false or use a model that does.')
exit(1)
if os.environ.get('MONTHLY_USER_BUDGETS') is not None:
logging.warning('The environment variable MONTHLY_USER_BUDGETS is deprecated. '
'Please use USER_BUDGETS with BUDGET_PERIOD instead.')
@@ -78,8 +84,13 @@ def main():
'bot_language': os.environ.get('BOT_LANGUAGE', 'en'),
}

plugin_config = {
'plugins': os.environ.get('PLUGINS', '').split(',')
}

# Setup and run ChatGPT and Telegram bot
openai_helper = OpenAIHelper(config=openai_config)
plugin_manager = PluginManager(config=plugin_config)
openai_helper = OpenAIHelper(config=openai_config, plugin_manager=plugin_manager)
telegram_bot = ChatGPTTelegramBot(config=telegram_config, openai=openai_helper)
telegram_bot.run()

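For reference, a minimal sketch of how the new PLUGINS variable flows into plugin_config. The value 'weather,crypto' is a hypothetical example; note that an unset or empty PLUGINS yields [''], which is harmless because PluginManager skips names it does not recognize.

import os

# Hypothetical example value; a comma-separated list of plugin names
os.environ.setdefault('PLUGINS', 'weather,crypto')

plugin_config = {
    'plugins': os.environ.get('PLUGINS', '').split(',')
}
print(plugin_config['plugins'])  # ['weather', 'crypto']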
134 changes: 122 additions & 12 deletions bot/openai_helper.py
@@ -14,6 +14,9 @@

from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type

from utils import is_direct_result
from plugin_manager import PluginManager

# Models can be found here: https://platform.openai.com/docs/models/overview
GPT_3_MODELS = ("gpt-3.5-turbo", "gpt-3.5-turbo-0301", "gpt-3.5-turbo-0613")
GPT_3_16K_MODELS = ("gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613")
@@ -39,6 +42,19 @@ def default_max_tokens(model: str) -> int:
return base * 8


def are_functions_available(model: str) -> bool:
"""
Whether the given model supports functions
"""
# Deprecated models
if model in ("gpt-3.5-turbo-0301", "gpt-4-0314", "gpt-4-32k-0314"):
return False
# Stable models will be updated to support functions on June 27, 2023
if model in ("gpt-3.5-turbo", "gpt-4", "gpt-4-32k"):
return datetime.date.today() > datetime.date(2023, 6, 27)
return True


# Load translations
parent_dir_path = os.path.join(os.path.dirname(__file__), os.pardir)
translations_file_path = os.path.join(parent_dir_path, 'translations.json')
@@ -69,14 +85,16 @@ class OpenAIHelper:
ChatGPT helper class.
"""

def __init__(self, config: dict):
def __init__(self, config: dict, plugin_manager: PluginManager):
"""
Initializes the OpenAI helper class with the given configuration.
:param config: A dictionary containing the GPT configuration
:param plugin_manager: The plugin manager
"""
openai.api_key = config['api_key']
openai.proxy = config['proxy']
self.config = config
self.plugin_manager = plugin_manager
self.conversations: dict[int, list] = {}  # {chat_id: history}
self.last_updated: dict[int, datetime.datetime] = {}  # {chat_id: last_update_timestamp}

@@ -97,7 +115,13 @@ async def get_chat_response(self, chat_id: int, query: str) -> tuple[str, str]:
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used
"""
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query)
if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(chat_id, response)
if is_direct_result(response):
return response, '0'

answer = ''

if len(response.choices) > 1 and self.config['n_choices'] > 1:
@@ -113,11 +137,17 @@
self.__add_to_history(chat_id, role="assistant", content=answer)

bot_language = self.config['bot_language']
show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
if self.config['show_usage']:
answer += "\n\n---\n" \
f"💰 {str(response.usage['total_tokens'])} {localized_text('stats_tokens', bot_language)}" \
f" ({str(response.usage['prompt_tokens'])} {localized_text('prompt', bot_language)}," \
f" {str(response.usage['completion_tokens'])} {localized_text('completion', bot_language)})"
if show_plugins_used:
answer += f"\n🔌 {', '.join(plugin_names)}"
elif show_plugins_used:
answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

return answer, response.usage['total_tokens']

@@ -128,22 +158,34 @@ async def get_chat_response_stream(self, chat_id: int, query: str):
:param query: The query to send to the model
:return: The answer from the model and the number of tokens used, or 'not_finished'
"""
plugins_used = ()
response = await self.__common_get_chat_response(chat_id, query, stream=True)
if self.config['enable_functions']:
response, plugins_used = await self.__handle_function_call(chat_id, response, stream=True)
if is_direct_result(response):
yield response, '0'
return

answer = ''
async for item in response:
if 'choices' not in item or len(item.choices) == 0:
continue
delta = item.choices[0].delta
if 'content' in delta:
if 'content' in delta and delta.content is not None:
answer += delta.content
yield answer, 'not_finished'
answer = answer.strip()
self.__add_to_history(chat_id, role="assistant", content=answer)
tokens_used = str(self.__count_tokens(self.conversations[chat_id]))

show_plugins_used = len(plugins_used) > 0 and self.config['show_plugins_used']
plugin_names = tuple(self.plugin_manager.get_plugin_source_name(plugin) for plugin in plugins_used)
if self.config['show_usage']:
answer += f"\n\n---\n💰 {tokens_used} {localized_text('stats_tokens', self.config['bot_language'])}"
if show_plugins_used:
answer += f"\n🔌 {', '.join(plugin_names)}"
elif show_plugins_used:
answer += f"\n\n---\n🔌 {', '.join(plugin_names)}"

yield answer, tokens_used

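As a usage note, a hypothetical consumer of this async generator (helper, chat_id and query are assumptions; the 'not_finished' sentinel and the final token count are exactly as yielded above):

async def print_stream(helper, chat_id, query):
    # Interim yields carry the partial answer plus the 'not_finished' sentinel;
    # the final yield carries the full answer and the token count as a string.
    async for answer, tokens in helper.get_chat_response_stream(chat_id, query):
        if tokens == 'not_finished':
            print(f'partial: {answer}')
        else:
            print(f'final: {answer} ({tokens} tokens)')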
@@ -186,16 +228,24 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=False):
logging.warning(f'Error while summarising chat history: {str(e)}. Popping elements instead...')
self.conversations[chat_id] = self.conversations[chat_id][-self.config['max_history_size']:]

return await openai.ChatCompletion.acreate(
model=self.config['model'],
messages=self.conversations[chat_id],
temperature=self.config['temperature'],
n=self.config['n_choices'],
max_tokens=self.config['max_tokens'],
presence_penalty=self.config['presence_penalty'],
frequency_penalty=self.config['frequency_penalty'],
stream=stream
)
common_args = {
'model': self.config['model'],
'messages': self.conversations[chat_id],
'temperature': self.config['temperature'],
'n': self.config['n_choices'],
'max_tokens': self.config['max_tokens'],
'presence_penalty': self.config['presence_penalty'],
'frequency_penalty': self.config['frequency_penalty'],
'stream': stream
}

if self.config['enable_functions']:
functions = self.plugin_manager.get_functions_specs()
if len(functions) > 0:
common_args['functions'] = functions
common_args['function_call'] = 'auto'

return await openai.ChatCompletion.acreate(**common_args)

except openai.error.RateLimitError as e:
raise e
Expand All @@ -206,6 +256,60 @@ async def __common_get_chat_response(self, chat_id: int, query: str, stream=Fals
except Exception as e:
raise Exception(f"⚠️ _{localized_text('error', bot_language)}._ ⚠️\n{str(e)}") from e

async def __handle_function_call(self, chat_id, response, stream=False, times=0, plugins_used=()):
function_name = ''
arguments = ''
if stream:
async for item in response:
if 'choices' in item and len(item.choices) > 0:
first_choice = item.choices[0]
if 'delta' in first_choice \
and 'function_call' in first_choice.delta:
if 'name' in first_choice.delta.function_call:
function_name += first_choice.delta.function_call.name
if 'arguments' in first_choice.delta.function_call:
arguments += str(first_choice.delta.function_call.arguments)
elif 'finish_reason' in first_choice and first_choice.finish_reason == 'function_call':
break
else:
return response, plugins_used
else:
return response, plugins_used
else:
if 'choices' in response and len(response.choices) > 0:
first_choice = response.choices[0]
if 'function_call' in first_choice.message:
if 'name' in first_choice.message.function_call:
function_name += first_choice.message.function_call.name
if 'arguments' in first_choice.message.function_call:
arguments += str(first_choice.message.function_call.arguments)
else:
return response, plugins_used
else:
return response, plugins_used

logging.info(f'Calling function {function_name} with arguments {arguments}')
function_response = await self.plugin_manager.call_function(function_name, arguments)

if function_name not in plugins_used:
plugins_used += (function_name,)

if is_direct_result(function_response):
self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name,
content=json.dumps({'result': 'Done, the content has been sent '
'to the user.'}))
return function_response, plugins_used

self.__add_function_call_to_history(chat_id=chat_id, function_name=function_name, content=function_response)
response = await openai.ChatCompletion.acreate(
model=self.config['model'],
messages=self.conversations[chat_id],
functions=self.plugin_manager.get_functions_specs(),
function_call='auto' if times < self.config['functions_max_consecutive_calls'] else 'none',
stream=stream
)
return await self.__handle_function_call(chat_id, response, stream, times + 1, plugins_used)

async def generate_image(self, prompt: str) -> tuple[str, str]:
"""
Generates an image from the given prompt using the DALL·E model.
@@ -264,6 +368,12 @@ def __max_age_reached(self, chat_id) -> bool:
max_age_minutes = self.config['max_conversation_age_minutes']
return last_updated < now - datetime.timedelta(minutes=max_age_minutes)

def __add_function_call_to_history(self, chat_id, function_name, content):
"""
Adds a function call to the conversation history
"""
self.conversations[chat_id].append({"role": "function", "name": function_name, "content": content})

def __add_to_history(self, chat_id, role, content):
"""
Adds a message to the conversation history.
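To make the round trip in __handle_function_call easier to follow, here is a condensed, non-streaming sketch against the openai 0.x API used in this PR. The get_time spec and the stubbed plugin result are hypothetical stand-ins for a real plugin.

import json
import os

import openai

openai.api_key = os.environ['OPENAI_API_KEY']

# Hypothetical function spec, standing in for PluginManager.get_functions_specs()
functions = [{
    'name': 'get_time',
    'description': 'Get the current time for a city',
    'parameters': {
        'type': 'object',
        'properties': {'city': {'type': 'string'}},
        'required': ['city'],
    },
}]
messages = [{'role': 'user', 'content': 'What time is it in Rome?'}]

response = openai.ChatCompletion.create(
    model='gpt-3.5-turbo-0613', messages=messages,
    functions=functions, function_call='auto')

choice = response.choices[0]
if choice.finish_reason == 'function_call':
    call = choice.message.function_call
    args = json.loads(call.arguments)
    # A real plugin would execute here; stub a result instead
    result = json.dumps({'city': args['city'], 'time': '10:00'})
    messages.append({'role': 'function', 'name': call.name, 'content': result})
    # Second round trip lets the model phrase the final answer
    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo-0613', messages=messages)

print(response.choices[0].message.content)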
68 changes: 68 additions & 0 deletions bot/plugin_manager.py
@@ -0,0 +1,68 @@
import json

from plugins.gtts_text_to_speech import GTTSTextToSpeech
from plugins.dice import DicePlugin
from plugins.youtube_audio_extractor import YouTubeAudioExtractorPlugin
from plugins.ddg_image_search import DDGImageSearchPlugin
from plugins.ddg_translate import DDGTranslatePlugin
from plugins.spotify import SpotifyPlugin
from plugins.crypto import CryptoPlugin
from plugins.weather import WeatherPlugin
from plugins.ddg_web_search import DDGWebSearchPlugin
from plugins.wolfram_alpha import WolframAlphaPlugin
from plugins.deepl import DeeplTranslatePlugin
from plugins.worldtimeapi import WorldTimeApiPlugin
from plugins.whois_ import WhoisPlugin


class PluginManager:
"""
A class to manage the plugins and call the correct functions
"""

def __init__(self, config):
enabled_plugins = config.get('plugins', [])
plugin_mapping = {
'wolfram': WolframAlphaPlugin,
'weather': WeatherPlugin,
'crypto': CryptoPlugin,
'ddg_web_search': DDGWebSearchPlugin,
'ddg_translate': DDGTranslatePlugin,
'ddg_image_search': DDGImageSearchPlugin,
'spotify': SpotifyPlugin,
'worldtimeapi': WorldTimeApiPlugin,
'youtube_audio_extractor': YouTubeAudioExtractorPlugin,
'dice': DicePlugin,
'deepl_translate': DeeplTranslatePlugin,
'gtts_text_to_speech': GTTSTextToSpeech,
'whois': WhoisPlugin,
}
self.plugins = [plugin_mapping[plugin]() for plugin in enabled_plugins if plugin in plugin_mapping]

def get_functions_specs(self):
"""
Return the list of function specs that can be called by the model
"""
return [spec for specs in map(lambda plugin: plugin.get_spec(), self.plugins) for spec in specs]

async def call_function(self, function_name, arguments):
"""
Call a function based on the name and parameters provided
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return json.dumps({'error': f'Function {function_name} not found'})
return json.dumps(await plugin.execute(function_name, **json.loads(arguments)), default=str)

def get_plugin_source_name(self, function_name) -> str:
"""
Return the source name of the plugin that provides the given function
"""
plugin = self.__get_plugin_by_function_name(function_name)
if not plugin:
return ''
return plugin.get_source_name()

def __get_plugin_by_function_name(self, function_name):
return next((plugin for plugin in self.plugins
if function_name in map(lambda spec: spec.get('name'), plugin.get_spec())), None)
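Writing a new plugin then amounts to subclassing Plugin and registering it in plugin_mapping above. A toy example following the interface that CryptoPlugin below implements (get_source_name, get_spec, execute); the Plugin base class is not shown in this excerpt, so treat the exact signatures as inferred:

import random
from typing import Dict

from .plugin import Plugin


class CoinFlipPlugin(Plugin):
    """
    A toy plugin that flips a coin (illustrative only, not part of this PR)
    """
    def get_source_name(self) -> str:
        return "CoinFlip"

    def get_spec(self) -> [Dict]:
        return [{
            "name": "flip_coin",
            "description": "Flip a coin and return heads or tails",
            "parameters": {"type": "object", "properties": {}},
        }]

    async def execute(self, function_name, **kwargs) -> Dict:
        return {"result": random.choice(["heads", "tails"])}

Enabling it would take an entry like 'coin_flip': CoinFlipPlugin in plugin_mapping plus PLUGINS=coin_flip in the environment.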
30 changes: 30 additions & 0 deletions bot/plugins/crypto.py
@@ -0,0 +1,30 @@
from typing import Dict

import requests

from .plugin import Plugin


# Author: https://github.com/stumpyfr
class CryptoPlugin(Plugin):
"""
A plugin to fetch the current rate of various cryptocurrencies
"""
def get_source_name(self) -> str:
return "CoinCap"

def get_spec(self) -> [Dict]:
return [{
"name": "get_crypto_rate",
"description": "Get the current rate of various crypto currencies",
"parameters": {
"type": "object",
"properties": {
"asset": {"type": "string", "description": "Asset of the crypto"}
},
"required": ["asset"],
},
}]

async def execute(self, function_name, **kwargs) -> Dict:
return requests.get(f"https://api.coincap.io/v2/rates/{kwargs['asset']}").json()
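A quick way to exercise this plugin end to end through PluginManager (a sketch assuming the bot/ layout from this PR is on the import path and api.coincap.io is reachable):

import asyncio

from plugin_manager import PluginManager

manager = PluginManager(config={'plugins': ['crypto']})
print(manager.get_functions_specs())  # a single spec: get_crypto_rate

# call_function receives the arguments as a JSON string, as the model emits them
result = asyncio.run(manager.call_function('get_crypto_rate', '{"asset": "bitcoin"}'))
print(result)  # JSON string with CoinCap's rate payload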