Skip to content

Commit

Permalink
Make prompts dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
basicthinker committed May 21, 2023
1 parent 3e80980 commit 49e578c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 32 deletions.
10 changes: 3 additions & 7 deletions devchat/openai/openai_prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import json
import math
from typing import List
Expand All @@ -7,18 +8,13 @@
from .openai_message import OpenAIMessage


@dataclass
class OpenAIPrompt(Prompt):
"""
A class to represent a prompt and its corresponding responses from OpenAI APIs.
"""

def __init__(self, model: str, user_name: str, user_email: str):
super().__init__(model, user_name, user_email)
self._id: str = None

@property
def model(self) -> str:
return self._model
_id: str = None

@property
def id(self) -> str:
Expand Down
51 changes: 26 additions & 25 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import hashlib
import math
from typing import Dict, List
from devchat.message import MessageType, Message
from devchat.utils import unix_to_local_datetime


@dataclass
class Prompt(ABC):
"""
A class to represent a prompt and its corresponding responses from the chat API.
Attributes:
_model (str): The name of the language model.
_user_name (str): The name of the user.
_user_email (str): The email address of the user.
model (str): The name of the language model.
user_name (str): The name of the user.
user_email (str): The email address of the user.
_new_messages (dict): The messages for the current round of conversation.
_history_messages (dict): The messages for the history of conversation.
parent (str): The parent prompt hash.
Expand All @@ -24,26 +26,25 @@ class Prompt(ABC):
_hash (str): The hash of the prompt.
"""

def __init__(self, model: str, user_name: str, user_email: str):
self._model: str = model
self._user_name: str = user_name
self._user_email: str = user_email
self._new_messages = {
MessageType.INSTRUCT: [],
'request': None,
MessageType.CONTEXT: [],
'response': {}
}
self._history_messages: Dict[str, Message] = {
MessageType.CONTEXT: [],
MessageType.CHAT: []
}
self.parent: str = None
self.references: List[str] = []
self._timestamp: int = None
self._request_tokens: int = 0
self._response_tokens: int = 0
self._hash: str = None
model: str
user_name: str
user_email: str
_new_messages: Dict = field(default_factory=lambda: {
MessageType.INSTRUCT: [],
'request': None,
MessageType.CONTEXT: [],
'response': {}
})
_history_messages: Dict[str, Message] = field(default_factory=lambda: {
MessageType.CONTEXT: [],
MessageType.CHAT: []
})
parent: str = None
references: List[str] = field(default_factory=list)
_timestamp: int = None
_request_tokens: int = 0
_response_tokens: int = 0
_hash: str = None

@property
def timestamp(self) -> int:
Expand Down Expand Up @@ -156,7 +157,7 @@ def set_hash(self):

def formatted_header(self) -> str:
"""Formatted string header of the prompt."""
formatted_str = f"User: {self._user_name} <{self._user_email}>\n"
formatted_str = f"User: {self.user_name} <{self.user_email}>\n"

local_time = unix_to_local_datetime(self._timestamp)
formatted_str += f"Date: {local_time.strftime('%a %b %d %H:%M:%S %Y %z')}\n\n"
Expand All @@ -183,7 +184,7 @@ def shortlog(self) -> List[dict]:
logs = []
for message in self.response.values():
shortlog_data = {
"user": f"{self._user_name} <{self._user_email}>",
"user": f"{self.user_name} <{self.user_email}>",
"date": self._timestamp,
"context": [msg.to_dict() for msg in self.new_context],
"request": self.request.content,
Expand Down

0 comments on commit 49e578c

Please sign in to comment.