Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced Type Hints for Improved Readability and Maintainability: Addressing Type Mismatch in query_quadrant Return Statement #2452

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
self._hnsw_config = self._retrieve_config.get("hnsw_config", None)
self._payload_indexing = self._retrieve_config.get("payload_indexing", False)

def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
def retrieve_docs(self, problem: str, n_results: Optional[int] = 20, search_string: Optional[str] = ""):
"""
Args:
problem (str): the problem to be solved.
Expand Down Expand Up @@ -150,21 +150,21 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =

def create_qdrant_from_dir(
dir_path: str,
max_tokens: int = 4000,
client: QdrantClient = None,
collection_name: str = "all-my-documents",
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
embedding_model: str = "BAAI/bge-small-en-v1.5",
custom_text_split_function: Callable = None,
custom_text_types: List[str] = TEXT_FORMATS,
recursive: bool = True,
extra_docs: bool = False,
parallel: int = 0,
on_disk: bool = False,
max_tokens: Optional[int] = 4000,
Hk669 marked this conversation as resolved.
Show resolved Hide resolved
client: Optional[QdrantClient] = None,
collection_name: Optional[str] = "all-my-documents",
chunk_mode: Optional[str] = "multi_lines",
must_break_at_empty_line: Optional[bool] = True,
embedding_model: Optional[str] = "BAAI/bge-small-en-v1.5",
custom_text_split_function: Optional[Callable] = None,
custom_text_types: Optional[List[str]] = TEXT_FORMATS,
recursive: Optional[bool] = True,
extra_docs: Optional[bool] = False,
parallel: Optional[int] = 0,
on_disk: Optional[bool] = False,
quantization_config: Optional[models.QuantizationConfig] = None,
hnsw_config: Optional[models.HnswConfigDiff] = None,
payload_indexing: bool = False,
payload_indexing: Optional[bool] = False,
qdrant_client_options: Optional[Dict] = {},
):
"""Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a
Expand Down Expand Up @@ -255,11 +255,11 @@ def create_qdrant_from_dir(

def query_qdrant(
query_texts: List[str],
n_results: int = 10,
client: QdrantClient = None,
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "BAAI/bge-small-en-v1.5",
n_results: Optional[int] = 10,
client: Optional[QdrantClient] = None,
collection_name: Optional[str] = "all-my-documents",
search_string: Optional[str] = "",
embedding_model: Optional[str] = "BAAI/bge-small-en-v1.5",
qdrant_client_options: Optional[Dict] = {},
) -> List[List[QueryResponse]]:
Hk669 marked this conversation as resolved.
Show resolved Hide resolved
"""Perform a similarity search with filters on a Qdrant collection
Expand Down Expand Up @@ -304,10 +304,18 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore
),
)

data = {
"ids": [[result.id for result in sublist] for sublist in results],
"documents": [[result.document for result in sublist] for sublist in results],
"distances": [[result.score for result in sublist] for sublist in results],
"metadatas": [[result.metadata for result in sublist] for sublist in results],
}
data = [
[
QueryResponse(
id=result.id,
embedding=result.embedding,
metadata=result.metadata,
document=result.document,
score=result.score,
)
for result in sublist
]
for sublist in results
]

Hk669 marked this conversation as resolved.
Show resolved Hide resolved
return data
27 changes: 14 additions & 13 deletions autogen/math_utils.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the code with a source url, shall we keep as original?
Another issue is that this file is not covered by test right now because the test that covered this file in v0.1 hasn't been migrated to v0.2.

Copy link
Contributor Author

@Hk669 Hk669 May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think, these changes are required if, these utils are being used in any of our modules or the developer. there are some return types to be updated, which is mentioned.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sonichi shall i cover the tests for the utils in this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible that we'll deprecate this util file. It's better to avoid changes to this file in this PR, especially the code which contains a source url.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure @sonichi , lets not make changes anymore to this file. Thanks

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional, Tuple, Union

from autogen import DEFAULT_MODEL, oai

Expand All @@ -9,7 +9,7 @@
}


def solve_problem(problem: str, **config) -> str:
def solve_problem(problem: str, **config) -> Tuple[Optional[str], int]:
"""(openai<1) Solve the math problem.

Args:
Expand All @@ -25,7 +25,7 @@ def solve_problem(problem: str, **config) -> str:
return results.get("voted_answer"), response["cost"]


def remove_boxed(string: str) -> Optional[str]:
def remove_boxed(string: str) -> Union[str, None]:
"""Source: https://github.com/hendrycks/math
Extract the text within a \\boxed{...} environment.
Example:
Expand All @@ -36,15 +36,15 @@ def remove_boxed(string: str) -> Optional[str]:
"""
left = "\\boxed{"
try:
if not all((string[: len(left)] == left, string[-1] == "}")):
if not ((string.startswith(left) and string[-1] == "}")):
raise AssertionError

return string[len(left) : -1]
except Exception:
return None


def last_boxed_only_string(string: str) -> Optional[str]:
def last_boxed_only_string(string: str) -> Union[str, None]:
"""Source: https://github.com/hendrycks/math
Extract the last \\boxed{...} or \\fbox{...} element from a string.
"""
Expand Down Expand Up @@ -96,7 +96,7 @@ def _fix_fracs(string: str) -> str:
new_str += substr
else:
try:
if not len(substr) >= 2:
if len(substr) < 2:
raise AssertionError
except Exception:
return string
Expand Down Expand Up @@ -132,7 +132,7 @@ def _fix_a_slash_b(string: str) -> str:
try:
a = int(a_str)
b = int(b_str)
if not string == "{}/{}".format(a, b):
if string != "{}/{}".format(a, b):
raise AssertionError
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
Expand All @@ -147,7 +147,7 @@ def _remove_right_units(string: str) -> str:
"""
if "\\text{ " in string:
splits = string.split("\\text{ ")
if not len(splits) == 2:
if len(splits) != 2:
raise AssertionError
return splits[0]
else:
Expand All @@ -161,16 +161,17 @@ def _fix_sqrt(string: str) -> str:
>>> _fix_sqrt("\\sqrt3")
\\sqrt{3}
"""
if "\\sqrt" not in string:
SQRT_LITERAL = "\\sqrt" # Define a constant for the repeated literal
if SQRT_LITERAL not in string:
return string
splits = string.split("\\sqrt")
splits = string.split(SQRT_LITERAL)
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
new_substr = SQRT_LITERAL + "{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_substr = SQRT_LITERAL + split
new_string += new_substr
return new_string

Expand Down Expand Up @@ -310,7 +311,7 @@ def voting_counts(responses):
return answers


def eval_math_responses(responses, solution=None, **args):
def eval_math_responses(responses, solution=None, **args) -> Dict:
Hk669 marked this conversation as resolved.
Show resolved Hide resolved
"""Select a response for a math problem using voting, and check if the response is correct if the solution is provided.

Args:
Expand Down
56 changes: 30 additions & 26 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import hashlib
import os
import re
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse

import chromadb
Expand Down Expand Up @@ -195,7 +195,9 @@ def split_files_to_chunks(
return chunks, sources


def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
def get_files_from_dir(
dir_path: Union[str, List[str]], types: Optional[list] = TEXT_FORMATS, recursive: Optional[bool] = True
):
"""Return a list of all the files in a given directory, a url, a file path or a list of them."""
if len(types) == 0:
raise ValueError("types cannot be empty.")
Expand Down Expand Up @@ -245,7 +247,7 @@ def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMA
return files


def parse_html_to_markdown(html: str, url: str = None) -> str:
def parse_html_to_markdown(html: str, url: Optional[str] = None) -> str:
"""Parse HTML to markdown."""
soup = BeautifulSoup(html, "html.parser")
title = soup.title.string
Expand Down Expand Up @@ -278,14 +280,16 @@ def parse_html_to_markdown(html: str, url: str = None) -> str:

def _generate_file_name_from_url(url: str, max_length=255) -> str:
url_bytes = url.encode("utf-8")
hash = hashlib.blake2b(url_bytes).hexdigest()
url_hash = hashlib.blake2b(url_bytes).hexdigest()
parsed_url = urlparse(url)
file_name = os.path.basename(url)
file_name = f"{parsed_url.netloc}_{file_name}_{hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
file_name = (
f"{parsed_url.netloc}_{file_name}_{url_hash[:min(8, max_length-len(parsed_url.netloc)-len(file_name)-1)]}"
)
return file_name


def get_file_from_url(url: str, save_path: str = None) -> Tuple[str, str]:
def get_file_from_url(url: str, save_path: str = None) -> Union[Tuple[str, str], None]:
"""Download a file from a URL."""
if save_path is None:
save_path = "tmp/chromadb"
Expand Down Expand Up @@ -333,19 +337,19 @@ def is_url(string: str):

def create_vector_db_from_dir(
dir_path: Union[str, List[str]],
max_tokens: int = 4000,
client: API = None,
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
custom_text_split_function: Callable = None,
custom_text_types: List[str] = TEXT_FORMATS,
recursive: bool = True,
extra_docs: bool = False,
max_tokens: Optional[int] = 4000,
client: Optional[API] = None,
db_path: Optional[str] = "tmp/chromadb.db",
collection_name: Optional[str] = "all-my-documents",
get_or_create: Optional[bool] = False,
chunk_mode: Optional[str] = "multi_lines",
must_break_at_empty_line: Optional[bool] = True,
embedding_model: Optional[str] = "all-MiniLM-L6-v2",
embedding_function: Optional[Callable] = None,
custom_text_split_function: Optional[Callable] = None,
custom_text_types: Optional[List[str]] = TEXT_FORMATS,
recursive: Optional[bool] = True,
extra_docs: Optional[bool] = False,
) -> API:
"""Create a vector db from all the files in a given directory, the directory can also be a single file or a url to
a single file. We support chromadb compatible APIs to create the vector db, this function is not required if
Expand Down Expand Up @@ -426,13 +430,13 @@ def create_vector_db_from_dir(

def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
db_path: str = "tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
embedding_function: Callable = None,
n_results: Optional[int] = 10,
client: Optional[API] = None,
db_path: Optional[str] = "tmp/chromadb.db",
collection_name: Optional[str] = "all-my-documents",
search_string: Optional[str] = "",
embedding_model: Optional[str] = "all-MiniLM-L6-v2",
embedding_function: Optional[Callable] = None,
) -> QueryResult:
"""Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db
and query function.
Expand Down
3 changes: 1 addition & 2 deletions autogen/runtime_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def start(logger_type: str = "sqlite", config: Optional[Dict[str, Any]] = None)
is_logging = True
except Exception as e:
logger.error(f"[runtime logging] Failed to start logging: {e}")
finally:
return session_id
return session_id


def log_chat_completion(
Expand Down
Loading