Skip to content

Commit

Permalink
Add test cases for tokens-per-prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
basicthinker committed May 12, 2023
1 parent 90cc619 commit afefe48
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 10 deletions.
8 changes: 4 additions & 4 deletions devchat/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
```json
{
"model": "gpt-3.5-turbo",
"token_limit": 3000,
"tokens-per-prompt": 3000,
"provider": "OpenAI",
"OpenAI": {
"temperature": 0.2
Expand All @@ -116,7 +116,7 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
```json
{
"model": "gpt-4",
"token_limit": 6000,
"tokens-per-prompt": 6000,
"provider": "OpenAI",
"OpenAI": {
"temperature": 0.2,
Expand Down Expand Up @@ -154,8 +154,8 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
chat = OpenAIChat(openai_config)

assistant = Assistant(chat, store)
if 'token_limit' in config:
assistant.token_limit = config['token_limit']
if 'tokens-per-prompt' in config:
assistant.token_limit = config['tokens-per-prompt']

assistant.make_prompt(content, instruct_contents, context_contents,
parent, reference)
Expand Down
9 changes: 4 additions & 5 deletions devchat/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def available_tokens(self) -> int:

def _check_limit(self):
if self._prompt.request_tokens > self.token_limit:
raise ValueError(f"Request tokens {self._prompt.request_tokens} "
raise ValueError(f"Prompt tokens {self._prompt.request_tokens} "
f"beyond limit {self.token_limit}.")

def make_prompt(self, request: str,
Expand All @@ -48,13 +48,12 @@ def make_prompt(self, request: str,
combined_instruct = ''.join(instruct_contents)
self._prompt.append_new(MessageType.INSTRUCT, combined_instruct)
self._check_limit()

# Add context to the prompt
if context_contents:
for context_content in context_contents:
if not self._prompt.append_new(MessageType.CONTEXT, context_content,
self.available_tokens):
return
self._prompt.append_new(MessageType.CONTEXT, context_content)
self._check_limit()

# Add history to the prompt
self._prompt.references = validate_hashes(references)
for reference_hash in self._prompt.references:
Expand Down
3 changes: 3 additions & 0 deletions devchat/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
utils.py - Utility functions for DevChat.
"""
import os
import re
import getpass
Expand Down
63 changes: 62 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
test_cli.py - Tests for the command line interface.
"""
import os
import re
import json
Expand All @@ -7,6 +10,7 @@
from git import Repo
from click.testing import CliRunner
from devchat._cli import main
from devchat.utils import response_tokens

runner = CliRunner()

Expand Down Expand Up @@ -65,7 +69,7 @@ def test_main_with_temp_config_file(git_repo):
config_data = {
'model': 'gpt-3.5-turbo-0301',
'provider': 'OpenAI',
'tokens-per-prompt': 6000,
'tokens-per-prompt': 3000,
'OpenAI': {
'temperature': 0
}
Expand Down Expand Up @@ -115,3 +119,60 @@ def test_main_with_instruct_and_context(git_repo, temp_files): # pylint: disabl
"It is really scorching."])
assert result.exit_code == 0
assert _get_core_content(result.output) == "hot summer\n"


def test_main_response_tokens_exceed_config(git_repo):  # pylint: disable=W0613
    """Prompting content whose token count exceeds 'tokens-per-prompt' must fail.

    Builds a prompt string just over the configured limit and expects the CLI
    to exit non-zero with a "beyond limit" error message.
    """
    config_data = {
        'model': 'gpt-3.5-turbo',
        'provider': 'OpenAI',
        'tokens-per-prompt': 2000,
        'OpenAI': {
            'temperature': 0
        }
    }

    # Write the config where the CLI expects it: <repo>/.chat/config.json.
    chat_dir = os.path.join(git_repo, ".chat")
    # exist_ok=True is race-free, unlike the exists()-then-makedirs() pattern.
    os.makedirs(chat_dir, exist_ok=True)
    temp_config_path = os.path.join(chat_dir, "config.json")

    with open(temp_config_path, "w", encoding='utf-8') as temp_config_file:
        json.dump(config_data, temp_config_file)

    # Grow the content until its token count reaches the configured limit,
    # guaranteeing the resulting prompt exceeds it.
    content = ""
    while response_tokens(content, config_data["model"]) < config_data["tokens-per-prompt"]:
        content += "This is a test. Ignore what I say. This is a test. Ignore what I say."
    result = runner.invoke(main, ['prompt', content])
    assert result.exit_code != 0
    assert "beyond limit" in result.output


def test_main_response_tokens_exceed_config_with_file(git_repo, tmpdir):  # pylint: disable=W0613
    """A context file plus extra input that together exceed 'tokens-per-prompt' must fail.

    The context file alone stays just under the limit; adding the extra input
    string pushes the prompt over it, so the CLI should exit non-zero with a
    "beyond limit" error message.
    """
    config_data = {
        'model': 'gpt-3.5-turbo',
        'provider': 'OpenAI',
        'tokens-per-prompt': 2000,
        'OpenAI': {
            'temperature': 0
        }
    }

    # Write the config where the CLI expects it: <repo>/.chat/config.json.
    chat_dir = os.path.join(git_repo, ".chat")
    # exist_ok=True is race-free, unlike the exists()-then-makedirs() pattern.
    os.makedirs(chat_dir, exist_ok=True)
    temp_config_path = os.path.join(chat_dir, "config.json")

    with open(temp_config_path, "w", encoding='utf-8') as temp_config_file:
        json.dump(config_data, temp_config_file)

    # Grow the file content to just UNDER the limit: stop before the next
    # chunk would reach it, so only the additional input tips it over.
    content_file = tmpdir.join("content.txt")
    content = ""
    while response_tokens(content + "This is a test. Ignore what I say.", config_data["model"]) < \
            config_data["tokens-per-prompt"]:
        content += "This is a test. Ignore what I say."
    content_file.write(content)

    input_str = "This is a test. Ignore what I say."
    result = runner.invoke(main, ['prompt', '-c', str(content_file), input_str])
    assert result.exit_code != 0
    assert "beyond limit" in result.output

0 comments on commit afefe48

Please sign in to comment.