Skip to content

Commit

Permalink
Merge pull request #22 from ijwfly/feature/api-usage-safety-measures
Browse files Browse the repository at this point in the history
Token usage safety features
  • Loading branch information
ijwfly committed Jan 8, 2024
2 parents 28cd632 + f69f748 commit bdb8ad8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
8 changes: 6 additions & 2 deletions app/bot/message_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ async def process(self, is_cancelled):
chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled
)

async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled):
async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled, recursive_count=0):
if recursive_count >= settings.SUCCESSIVE_FUNCTION_CALLS_LIMIT:
# sometimes the model retries function calls indefinitely; this is a safety measure
raise ValueError('Model makes too many successive function calls')

response_dialog_message, message_id = await self.handle_response_generator(response_generator)
if response_dialog_message.function_call:
function_name = response_dialog_message.function_call.name
Expand All @@ -104,7 +108,7 @@ async def handle_gpt_response(self, chat_gpt_manager, context_manager, response_
context_dialog_messages = await context_manager.add_message(function_response, function_response_message_id)
response_generator = await chat_gpt_manager.send_user_message(self.user, context_dialog_messages, is_cancelled)

await self.handle_gpt_response(chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled)
await self.handle_gpt_response(chat_gpt_manager, context_manager, response_generator, function_storage, is_cancelled, recursive_count + 1)
else:
dialog_messages = self.split_dialog_message(response_dialog_message)
for dialog_message in dialog_messages:
Expand Down
10 changes: 9 additions & 1 deletion app/context/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ class ContextConfiguration:
long_term_memory_tokens: int
# short term memory is used for storing last messages
short_term_memory_tokens: int

# length of summary to be generated when context is too long
summary_length: int
# hard limit for context size; when this limit is reached, processing stops,
# because a context this large cannot be summarized either
hard_max_context_size: int

@staticmethod
def get_config(model: str):
Expand All @@ -29,34 +32,39 @@ def get_config(model: str):
long_term_memory_tokens=512,
short_term_memory_tokens=2560,
summary_length=512,
hard_max_context_size=5*1024,
)
elif model == 'gpt-3.5-turbo-16k':
return ContextConfiguration(
model_name=model,
long_term_memory_tokens=1024,
short_term_memory_tokens=4096,
summary_length=1024,
hard_max_context_size=17*1024,
)
elif model == 'gpt-4':
return ContextConfiguration(
model_name=model,
long_term_memory_tokens=512,
short_term_memory_tokens=2048,
summary_length=1024,
hard_max_context_size=9*1024,
)
elif model == 'gpt-4-1106-preview':
return ContextConfiguration(
model_name=model,
long_term_memory_tokens=512,
short_term_memory_tokens=5120,
summary_length=2048,
hard_max_context_size=13*1024,
)
elif model == 'gpt-4-vision-preview':
return ContextConfiguration(
model_name=model,
long_term_memory_tokens=512,
short_term_memory_tokens=5120,
summary_length=2048,
hard_max_context_size=13*1024,
)
else:
raise ValueError(f'Unknown model name: {model}')
Expand Down
22 changes: 16 additions & 6 deletions app/context/dialog_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ async def process_dialog(self, message: types.Message) -> List[DialogMessage]:
# if it's a reply, we need to update activation time of dialog messages to be included in context next time
await self.db.update_activation_dtime([m.id for m in dialog_messages])

if self.user.auto_summarize and count_dialog_messages_tokens(m.message for m in dialog_messages) >= self.context_configuration.short_term_memory_tokens:
to_summarize, to_process = self.split_context_by_token_length(dialog_messages)
summarized_message = await self.summarize_messages(to_summarize)
self.dialog_messages = [summarized_message] + to_process
else:
self.dialog_messages = dialog_messages
self.dialog_messages = await self.summarize_messages_if_needed(dialog_messages)
return self.get_dialog_messages()

def split_context_by_token_length(self, messages: List[Message]):
Expand All @@ -61,6 +56,20 @@ def split_context_by_token_length(self, messages: List[Message]):
else:
return messages, []

async def summarize_messages_if_needed(self, messages: List[Message]):
    """Return the dialog messages, compacting the oldest ones into a summary
    when the context has grown past the short-term memory budget.

    Raises:
        ValueError: if the total token count exceeds the hard context limit —
            a context that large can be neither summarized nor processed.
    """
    total_tokens = count_dialog_messages_tokens(m.message for m in messages)

    # Safety net — we should never reach this point in normal operation.
    if total_tokens > self.context_configuration.hard_max_context_size:
        raise ValueError(f'Hard context size is exceeded: {total_tokens}')

    needs_summary = (
        self.user.auto_summarize
        and total_tokens >= self.context_configuration.short_term_memory_tokens
    )
    if not needs_summary:
        return messages

    older, recent = self.split_context_by_token_length(messages)
    summary_message = await self.summarize_messages(older)
    return [summary_message] + recent

async def summarize_messages(self, messages: List[Message]):
summarized, completion_usage = await summarize_messages(
[m.message for m in messages], self.user.current_model, self.context_configuration.summary_length
Expand All @@ -82,6 +91,7 @@ async def add_message_to_dialog(self, dialog_message: DialogMessage, tg_message_
self.user.id, self.chat_id, tg_message_id, dialog_message, self.dialog_messages
)
self.dialog_messages.append(dialog_message)
self.dialog_messages = await self.summarize_messages_if_needed(self.dialog_messages)
return self.get_dialog_messages()

def get_dialog_messages(self) -> List[DialogMessage]:
Expand Down
1 change: 1 addition & 0 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
OPENAI_CHAT_COMPLETION_TEMPERATURE = 0.3
MESSAGE_EXPIRATION_WINDOW = 60 * 60 # 1 hour
POSTGRES_TIMEZONE = pytz.timezone('UTC')
SUCCESSIVE_FUNCTION_CALLS_LIMIT = 12 # limit of successive function calls that model can make

# Database settings
# Change these if you know what you're doing
Expand Down

0 comments on commit bdb8ad8

Please sign in to comment.