Skip to content

Commit

Permalink
Calculate the hash of the whole prompt using SHA-256
Browse files Browse the repository at this point in the history
  • Loading branch information
basicthinker committed May 22, 2023
1 parent f14188b commit fe29e5c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
10 changes: 5 additions & 5 deletions devchat/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from dataclasses import dataclass, field, asdict
import hashlib
import math
from typing import Dict, List
Expand Down Expand Up @@ -149,10 +149,10 @@ def set_hash(self):
"""Set the hash of the prompt."""
if not self.request or not self.response:
raise ValueError("Prompt is incomplete for hash.")
hash_str = self.request.content
for response in self.response.values():
hash_str += response.content
self._hash = hashlib.sha1(hash_str.encode()).hexdigest()
data = asdict(self)
assert data.pop('_hash') is None
string = str(tuple(sorted(data.items())))
self._hash = hashlib.sha256(string.encode('utf-8')).hexdigest()
return self._hash

def formatted_header(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion devchat/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def get_prompt(self, prompt_hash: str) -> Prompt:

# Retrieve the prompt object from TinyDB
prompt_data = self._db.search(where('_hash') == prompt_hash)
print(prompt_data)
assert len(prompt_data) == 1
return self._chat.load_prompt(prompt_data[0])

Expand Down
2 changes: 1 addition & 1 deletion devchat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def is_valid_hash(hash_str):
"""Check if a string is a valid hash value."""
# Hash values are usually alphanumeric with a fixed length
# depending on the algorithm used to generate them
pattern = re.compile(r'^[a-fA-F0-9]{40}$') # Example pattern for SHA-1 hash
pattern = re.compile(r'^[a-f0-9]{64}$') # Example pattern for SHA-256 hash
return bool(pattern.match(hash_str))


Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ def test_main_no_args(git_repo): # pylint: disable=W0613


def _check_output_format(output) -> bool:
pattern = r"(User: .+ <.+@.+>\nDate: .+\n\n(?:.*\n)*\n(?:prompt [a-f0-9]{40}\n\n?)+)"
pattern = r"(User: .+ <.+@.+>\nDate: .+\n\n(?:.*\n)*\n(?:prompt [a-f0-9]{64}\n\n?)+)"
return bool(re.fullmatch(pattern, output))


def _get_core_content(output) -> str:
header_pattern = r"User: .+ <.+@.+>\nDate: .+\n\n"
footer_pattern = r"\n(?:prompt [a-f0-9]{40}\n\n?)+"
footer_pattern = r"\n(?:prompt [a-f0-9]{64}\n\n?)+"

core_content = re.sub(header_pattern, "", output)
core_content = re.sub(footer_pattern, "", core_content)
Expand Down

0 comments on commit fe29e5c

Please sign in to comment.