Skip to content

Commit

Permalink
Merge branch 'main' into requests-with-retry-enhance
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed May 18, 2023
2 parents d7e5e65 + df55ec5 commit 26c4675
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 40 deletions.
102 changes: 63 additions & 39 deletions haystack/document_stores/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

try:
import weaviate
from weaviate import client, AuthClientPassword, gql
from weaviate import client, AuthClientPassword, gql, AuthClientCredentials, AuthBearerToken
except (ImportError, ModuleNotFoundError) as ie:
from haystack.utils.import_utils import _optional_component_not_installed

Expand Down Expand Up @@ -78,6 +78,11 @@ def __init__(
timeout_config: tuple = (5, 15),
username: Optional[str] = None,
password: Optional[str] = None,
client_secret: Optional[str] = None,
scope: Optional[str] = "offline_access",
access_token: Optional[str] = None,
expires_in: Optional[int] = 60,
refresh_token: Optional[str] = None,
additional_headers: Optional[Dict[str, Any]] = None,
index: str = "Document",
embedding_dim: int = 768,
Expand All @@ -95,58 +100,58 @@ def __init__(
):
"""
:param host: Weaviate server connection URL for storing and processing documents and vectors.
For more details, refer "https://weaviate.io/developers/weaviate/current/getting-started/installation.html"
:param port: port of Weaviate instance
:param timeout_config: Weaviate Timeout config as a tuple of (retries, time out seconds).
:param username: username (standard authentication via http_auth)
:param password: password (standard authentication via http_auth)
:param additional_headers: additional headers to be included in the requests sent to Weaviate e.g. bearer token
:param index: Index name for document text, embedding and metadata (in Weaviate terminology, this is a "Class" in Weaviate schema).
For more details, see [Weaviate installation](https://weaviate.io/developers/weaviate/current/getting-started/installation.html).
:param port: The port of the Weaviate instance.
:param timeout_config: The Weaviate timeout config as a tuple of (retries, time out seconds).
:param username: The Weaviate username (standard authentication using http_auth).
:param password: Weaviate password (standard authentication using http_auth).
:param client_secret: The client secret to use when using the OIDC Client Credentials authentication flow.
:param scope: The scope of the credentials when using the OIDC Resource Owner Password or Client Credentials authentication flow.
:param access_token: Access token to use when using OIDC and bearer tokens to authenticate.
:param expires_in: The time in seconds after which the access token expires.
:param refresh_token: The refresh token to use when using OIDC and bearer tokens to authenticate.
:param additional_headers: Additional headers to be included in the requests sent to Weaviate, for example the bearer token.
:param index: Index name for document text, embedding, and metadata (in Weaviate terminology, this is a "Class" in the Weaviate schema).
:param embedding_dim: The embedding vector size. Default: 768.
:param content_field: Name of field that might contain the answer and will therefore be passed to the Reader Model (e.g. "full_text").
If no Reader is used (e.g. in FAQ-Style QA) the plain content of this field will just be returned.
:param name_field: Name of field that contains the title of the doc
:param similarity: The similarity function used to compare document vectors. Available options are 'cosine' (default), 'dot_product' and 'l2'.
:param content_field: Name of the field that might contain the answer and is passed to the Reader model (for example, "full_text").
If no Reader is used (for example, in FAQ-Style QA), the plain content of this field is returned.
:param name_field: Name of the field that contains the title of the doc.
:param similarity: The similarity function used to compare document vectors. Available options are 'cosine' (default), 'dot_product', and 'l2'.
'cosine' is recommended for Sentence Transformers.
:param index_type: Index type of any vector object defined in weaviate schema. The vector index type is pluggable.
Currently, HSNW is only supported.
See: https://weaviate.io/developers/weaviate/current/more-resources/performance.html
:param custom_schema: Allows to create custom schema in Weaviate, for more details
See https://weaviate.io/developers/weaviate/current/schema/schema-configuration.html
:param index_type: Index type of any vector object defined in the Weaviate schema. The vector index type is pluggable.
Currently, only HNSW is supported.
See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/more-resources/performance.html).
:param custom_schema: Allows to create a custom schema in Weaviate. For more details,
see [Weaviate documentation](https://weaviate.io/developers/weaviate/current/schema/schema-configuration.html).
:param module_name: Vectorization module to convert data into vectors. Default is "text2vec-trasnformers"
For more details, See https://weaviate.io/developers/weaviate/current/modules/
:param return_embedding: To return document embedding.
:param embedding_field: Name of field containing an embedding vector.
For more details, see [Weaviate documentation](https://weaviate.io/developers/weaviate/current/modules/).
:param return_embedding: Returns document embedding.
:param embedding_field: Name of the field containing an embedding vector.
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
:param duplicate_documents: Handle duplicate documents based on parameter options.
Parameter options : ( 'skip','overwrite','fail')
Parameter options: 'skip','overwrite','fail'
skip: Ignore the duplicate documents.
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already exists.
:param recreate_index: If set to True, an existing Weaviate index will be deleted and a new one will be
created using the config you are using for initialization. Be aware that all data in the old index will be
fail: Raises an error if the document ID of the document being added already exists.
:param recreate_index: If set to True, deletes an existing Weaviate index and creates a new one using the config you are using for initialization. Note that all data in the old index is
lost if you choose to recreate the index.
:param replication_factor: It sets the Weaviate Class's replication factor in Weaviate at the time of Class creation.
See: https://weaviate.io/developers/weaviate/current/configuration/replication.html
:param replication_factor: Sets the Weaviate Class's replication factor in Weaviate at the time of Class creation.
See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/configuration/replication.html).
"""
super().__init__()

# Connect to Weaviate server using python binding
weaviate_url = f"{host}:{port}"
if username and password:
secret = AuthClientPassword(username, password)
self.weaviate_client = client.Client(
url=weaviate_url,
auth_client_secret=secret,
timeout_config=timeout_config,
additional_headers=additional_headers,
)
else:
self.weaviate_client = client.Client(
url=weaviate_url, timeout_config=timeout_config, additional_headers=additional_headers
)

secret = self._get_auth_secret(
username, password, client_secret, access_token, expires_in, refresh_token, scope
)
self.weaviate_client = client.Client(
url=weaviate_url,
auth_client_secret=secret,
timeout_config=timeout_config,
additional_headers=additional_headers,
)
# Test Weaviate connection
try:
status = self.weaviate_client.is_ready()
Expand Down Expand Up @@ -185,6 +190,25 @@ def __init__(
self._create_schema_and_index(self.index, recreate_index=recreate_index)
self.uuid_format_warning_raised = False

@staticmethod
def _get_auth_secret(
    username: Optional[str] = None,
    password: Optional[str] = None,
    client_secret: Optional[str] = None,
    access_token: Optional[str] = None,
    expires_in: Optional[int] = 60,
    refresh_token: Optional[str] = None,
    scope: Optional[str] = "offline_access",
) -> Optional[Union["AuthClientPassword", "AuthClientCredentials", "AuthBearerToken"]]:
    """
    Build the Weaviate authentication object that matches the supplied credentials.

    The credential groups are checked in a fixed order and the first one that is set wins:
    username together with password, then client secret, then access token.

    :param username: The Weaviate username. Only used when `password` is set as well.
    :param password: The Weaviate password. Only used when `username` is set as well.
    :param client_secret: The client secret, passed to `AuthClientCredentials`.
    :param access_token: The access token, passed to `AuthBearerToken`.
    :param expires_in: Forwarded to `AuthBearerToken` together with the access token.
    :param refresh_token: Forwarded to `AuthBearerToken` together with the access token.
    :param scope: Forwarded to `AuthClientPassword` and `AuthClientCredentials`.
    :return: The authentication object for the first matching credential group,
        or None if no credential group is set.
    """
    # A username or a password on its own is not enough for password authentication.
    if username and password:
        return AuthClientPassword(username, password, scope=scope)

    if client_secret:
        return AuthClientCredentials(client_secret, scope=scope)

    if access_token:
        return AuthBearerToken(access_token, expires_in=expires_in, refresh_token=refresh_token)

    # No usable credentials were supplied.
    return None

def _sanitize_index_name(self, index: Optional[str]) -> Optional[str]:
if index is None:
return None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ milvus = [
"farm-haystack[sql,only-milvus]",
]
weaviate = [
"weaviate-client>=3.10.0,<4",
"weaviate-client<3.19.0",
]
only-pinecone = [
"pinecone-client>=2.0.11,<3",
Expand Down
20 changes: 20 additions & 0 deletions test/document_stores/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,26 @@ def test_get_embedding_count(self, ds, documents):
ds.write_documents(documents)
assert ds.get_embedding_count() == 9

@pytest.mark.unit
def test__get_auth_secret(self):
    """Check that each credential combination yields the matching Weaviate auth object."""
    build_secret = WeaviateDocumentStore._get_auth_secret

    # Username and password
    password_auth = build_secret("user", "pass", scope="some_scope")
    assert isinstance(password_auth, weaviate.AuthClientPassword)

    # Client secret only
    credentials_auth = build_secret(client_secret="client_secret_value", scope="some_scope")
    assert isinstance(credentials_auth, weaviate.AuthClientCredentials)

    # Access token with its expiry and refresh token
    bearer_auth = build_secret(
        access_token="access_token_value", expires_in=3600, refresh_token="refresh_token_value"
    )
    assert isinstance(bearer_auth, weaviate.AuthBearerToken)

    # No credentials at all
    no_auth = build_secret()
    assert no_auth is None

@pytest.mark.unit
def test__get_current_properties(self, mocked_ds):
mocked_ds.weaviate_client.schema.get.return_value = json.loads(
Expand Down

0 comments on commit 26c4675

Please sign in to comment.