Skip to content

Commit

Permalink
Refactor Store and Chat for loading prompts
Browse files Browse the repository at this point in the history
- Add load_prompt() method to Chat abstract class.
- Implement load_prompt() in OpenAIChat class.
- Modify Store to accept Chat instance in constructor.
- Update Store methods to use Chat instance for prompt handling.
- Update tests to reflect changes in Store and Chat classes.
  • Loading branch information
basicthinker committed May 21, 2023
1 parent 49e578c commit d7bda72
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 43 deletions.
22 changes: 15 additions & 7 deletions devchat/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def handle_errors():
sys.exit(os.EX_SOFTWARE)


def init_dir() -> Tuple[dict, Store]:
def init_dir() -> Tuple[dict, str]:
git_root = find_git_root()
chat_dir = os.path.join(git_root, ".chat")
if not os.path.exists(chat_dir):
Expand All @@ -51,9 +51,8 @@ def init_dir() -> Tuple[dict, Store]:
except FileNotFoundError:
config_data = default_config_data

store = Store(chat_dir)
git_ignore(git_root, chat_dir)
return config_data, store
return config_data, chat_dir


@main.command()
Expand Down Expand Up @@ -133,7 +132,7 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
```
"""
config, store = init_dir()
config, chat_dir = init_dir()

with handle_errors():
if content is None:
Expand All @@ -152,6 +151,7 @@ def prompt(content: Optional[str], parent: Optional[str], reference: Optional[Li
openai_config = OpenAIChatConfig(model=model, **config['OpenAI'])

chat = OpenAIChat(openai_config)
store = Store(chat_dir, chat)

assistant = Assistant(chat, store)
if 'tokens-per-prompt' in config:
Expand All @@ -174,9 +174,17 @@ def log(skip, max_count):
"""
Show the prompt history.
"""
_, store = init_dir()

recent_prompts = store.select_recent(skip, skip + max_count)
config, chat_dir = init_dir()
provider = config.get('provider')
recent_prompts = []
if provider == 'OpenAI':
openai_config = OpenAIChatConfig(**config['OpenAI'])
chat = OpenAIChat(openai_config)
store = Store(chat_dir, chat)
recent_prompts = store.select_recent(skip, skip + max_count)
else:
click.echo(f"Error: Invalid LLM in configuration '{provider}'", err=True)
sys.exit(os.EX_DATAERR)

logs = []
for record in recent_prompts:
Expand Down
9 changes: 9 additions & 0 deletions devchat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def init_prompt(self, request: str) -> Prompt:
The returned prompt can be combined with more instructions and context.
"""

@abstractmethod
def load_prompt(self, data: dict) -> Prompt:
    """
    Load a prompt from a dictionary.

    Args:
        data (dict): The dictionary containing the prompt data.

    Returns:
        Prompt: The prompt reconstructed from the dictionary.
    """

@abstractmethod
def complete_response(self, prompt: Prompt) -> str:
"""
Expand Down
26 changes: 3 additions & 23 deletions devchat/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,15 @@ def __init__(self, config: OpenAIChatConfig):
self.config = config

def init_prompt(self, request: str) -> OpenAIPrompt:
"""
Initialize a prompt for the chat system.
Args:
request (str): The basic request of the prompt.
The returned prompt can be combined with more instructions and context.
"""
user, email = get_git_user_info()
prompt = OpenAIPrompt(self.config.model, user, email)
prompt.set_request(request)
return prompt

def complete_response(self, prompt: OpenAIPrompt) -> str:
"""
Retrieve a complete response JSON string from the chat system.
def load_prompt(self, data: dict) -> OpenAIPrompt:
    """
    Load a prompt from a dictionary of stored prompt fields.

    Args:
        data (dict): Keyword fields matching the OpenAIPrompt constructor
            (as produced by dataclasses.asdict when the prompt was stored).

    Returns:
        OpenAIPrompt: The prompt reconstructed from the stored fields.
    """
    # NOTE(review): assumes `data` keys exactly match OpenAIPrompt's
    # constructor parameters; unexpected keys raise TypeError — confirm.
    return OpenAIPrompt(**data)

Args:
prompt (Prompt): A prompt of messages representing the conversation.
Returns:
str: A JSON string representing the complete response.
"""
def complete_response(self, prompt: OpenAIPrompt) -> str:
# Filter the config parameters with non-None values
config_params = {
key: value
Expand All @@ -78,14 +66,6 @@ def complete_response(self, prompt: OpenAIPrompt) -> str:
return response

def stream_response(self, prompt: OpenAIPrompt) -> Iterator[str]:
"""
Retrieve a streaming response as an iterator of JSON strings from the chat system.
Args:
prompt (Prompt): A prompt of messages representing the conversation.
Returns:
Iterator[str]: An iterator over JSON strings representing the streaming response events.
"""
# Filter the config parameters with non-None values
config_params = {
key: value
Expand Down
10 changes: 7 additions & 3 deletions devchat/store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from dataclasses import asdict
import os
import shelve
from typing import List
from xml.etree.ElementTree import ParseError
import networkx as nx
from devchat.chat import Chat
from devchat.prompt import Prompt


class Store:
def __init__(self, store_dir: str):
def __init__(self, store_dir: str, chat: Chat):
"""
Initializes a Store instance.
Expand All @@ -17,8 +19,10 @@ def __init__(self, store_dir: str):
store_dir = os.path.expanduser(store_dir)
if not os.path.isdir(store_dir):
os.makedirs(store_dir)

self._graph_path = os.path.join(store_dir, 'prompts.graphml')
self._db_path = os.path.join(store_dir, 'prompts')
self._chat = chat

if os.path.isfile(self._graph_path):
try:
Expand Down Expand Up @@ -61,7 +65,7 @@ def store_prompt(self, prompt: Prompt):
prompt.set_hash()

# Store the prompt object in the shelve database
self._db[prompt.hash] = prompt
self._db[prompt.hash] = asdict(prompt)
self._db.sync()

# Add the prompt to the graph
Expand Down Expand Up @@ -91,7 +95,7 @@ def get_prompt(self, prompt_hash: str) -> Prompt:
raise ValueError(f'Prompt {prompt_hash} not found in the store.')

# Retrieve the prompt object from the shelve database
return self._db[prompt_hash]
return self._chat.load_prompt(self._db[prompt_hash])

def select_recent(self, start: int, end: int) -> List[Prompt]:
"""
Expand Down
19 changes: 9 additions & 10 deletions tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from devchat.openai import OpenAIPrompt
from devchat.openai import OpenAIChatConfig, OpenAIChat
from devchat.store import Store
from devchat.utils import get_git_user_info


def test_get_prompt(tmp_path):
store = Store(tmp_path / "store.graphml")
name, email = get_git_user_info()
prompt = OpenAIPrompt(model="gpt-3.5-turbo", user_name=name, user_email=email)
prompt.set_request("Where was the 2020 World Series played?")
config = OpenAIChatConfig(model="gpt-3.5-turbo")
chat = OpenAIChat(config)
store = Store(tmp_path / "store.graphml", chat)
prompt = chat.init_prompt("Where was the 2020 World Series played?")
response_str = '''
{
"id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve",
Expand All @@ -34,14 +33,14 @@ def test_get_prompt(tmp_path):


def test_select_recent(tmp_path):
store = Store(tmp_path / "store.graphml")
name, email = get_git_user_info()
config = OpenAIChatConfig(model="gpt-3.5-turbo")
chat = OpenAIChat(config)
store = Store(tmp_path / "store.graphml", chat)

# Create and store 5 prompts
hashes = []
for index in range(5):
prompt = OpenAIPrompt(model="gpt-3.5-turbo", user_name=name, user_email=email)
prompt.set_request(f"Question {index}")
prompt = chat.init_prompt(f"Question {index}")
response_str = f'''
{{
"id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve",
Expand Down

0 comments on commit d7bda72

Please sign in to comment.