Skip to content

Commit

Permalink
feat: Updated 2 files
Browse files Browse the repository at this point in the history
  • Loading branch information
d-walsh committed Jun 26, 2024
1 parent 732ddf0 commit 10f2c79
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
5 changes: 5 additions & 0 deletions sweepai/config/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@

SENTRY_URL = os.environ.get("SENTRY_URL", None)

DB_ENGINE = os.environ.get("DB_ENGINE", "hnsw")
HNSW_MAX_ELEMENTS = int(os.environ.get("HNSW_MAX_ELEMENTS", 100000))
HNSW_EF_CONSTRUCTION = int(os.environ.get("HNSW_EF_CONSTRUCTION", 200))
HNSW_M = int(os.environ.get("HNSW_M", 16))

CACHE_DIRECTORY = os.environ.get("CACHE_DIRECTORY", "/mnt/caches")

assert OPENAI_API_KEY, "OPENAI_API_KEY is required."
Expand Down
30 changes: 28 additions & 2 deletions sweepai/utils/multi_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re

import os
from hnswlib import Index
from loguru import logger
from sweepai.core.chat import ChatGPT
from sweepai.core.entities import Message
Expand Down Expand Up @@ -75,7 +77,7 @@
- Where are the Elasticsearch queries that power the autocomplete suggestions for the site's search bar, and what specific fields are being searched and returned?
- Where is the logic for automatically provisioning and scaling EC2 instances based on CPU and memory usage metrics from CloudWatch in the DevOps scripts?"""

def generate_multi_queries(input_query: str):
def generate_multi_queries(input_query: str, vectors, metadata):
chatgpt = ChatGPT(
messages=[
Message(
Expand All @@ -91,12 +93,36 @@ def generate_multi_queries(input_query: str):
temperature=0.7, # I bumped this and it improved the benchmarks
use_openai=True,
)

if not vectors:
logger.warning("No vectors provided, skipping HNSW index initialization.")
index = None
else:
dim = len(vectors[0])
max_elements = len(vectors)

try:
ef_construction = int(os.environ.get("HNSW_EF_CONSTRUCTION", 200))
m = int(os.environ.get("HNSW_M", 16))
except ValueError:
logger.warning("Invalid value for HNSW_EF_CONSTRUCTION or HNSW_M, using defaults.")
ef_construction = 200
m = 16

# Initialize HNSW index
index = Index(space='cosine', dim=dim)
index.init_index(max_elements=max_elements, ef_construction=ef_construction, M=m)
index.add_items(vectors, metadata)

pattern = re.compile(r"<query>(?P<query>.*?)</query>", re.DOTALL)
queries = []
for q in pattern.finditer(response):
query = q.group("query").strip()
if query:
queries.append(query)
# Use HNSW index to find similar vectors
labels, distances = index.knn_query(query, k=5)
queries.append((query, labels, distances))

logger.debug(f"Generated {len(queries)} queries from the input query.")
return queries

Expand Down

0 comments on commit 10f2c79

Please sign in to comment.