
feat: Add trust_remote_code init param to SentenceTransformer embedders #7356

Merged · 6 commits · Mar 14, 2024
Changes from 2 commits

File: haystack/components/embedders/backends/sentence_transformers_backend.py
@@ -15,12 +15,14 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}

@staticmethod
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, **kwargs):
embedding_backend_id = f"{model}{device}{auth_token}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
embedding_backend = _SentenceTransformersEmbeddingBackend(
model=model, device=device, auth_token=auth_token, trust_remote_code=kwargs.get("trust_remote_code", False)
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

@@ -30,10 +32,19 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""

def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
def __init__(
self,
model: str,
device: Optional[str] = None,
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
model_name_or_path=model,
device=device,
use_auth_token=auth_token.resolve_value() if auth_token else None,
trust_remote_code=trust_remote_code,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
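For readers skimming the diff, the factory above follows a simple keyed-cache pattern: the cache key is built from model, device, and auth token, so repeated calls with the same key return one shared backend, while extra keyword arguments such as trust_remote_code only influence the first construction for that key. A standalone sketch of the idea (illustrative stand-in classes, not the Haystack implementation):

# Illustrative stand-ins for the factory and backend above; not the Haystack classes themselves.
from typing import Dict, Optional


class _SketchBackend:
    def __init__(self, model: str, device: Optional[str], trust_remote_code: bool = False):
        self.model = model
        self.device = device
        self.trust_remote_code = trust_remote_code


class _SketchBackendFactory:
    _instances: Dict[str, _SketchBackend] = {}

    @staticmethod
    def get_backend(model: str, device: Optional[str] = None, **kwargs) -> _SketchBackend:
        backend_id = f"{model}{device}"  # cache key; extra kwargs are not part of it
        if backend_id in _SketchBackendFactory._instances:
            return _SketchBackendFactory._instances[backend_id]
        backend = _SketchBackend(
            model=model, device=device, trust_remote_code=kwargs.get("trust_remote_code", False)
        )
        _SketchBackendFactory._instances[backend_id] = backend
        return backend


first = _SketchBackendFactory.get_backend("my-model", "cpu", trust_remote_code=True)
second = _SketchBackendFactory.get_backend("my-model", "cpu")
assert first is second  # same key, so the cached instance is reused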

File: haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -39,6 +39,7 @@ def __init__(
normalize_embeddings: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersDocumentEmbedder component.
@@ -65,6 +66,9 @@ def __init__(
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param trust_remote_code:
If False, only Hugging Face verified model architectures are allowed.
If True, custom models and scripts are allowed.
"""

self.model = model
@@ -77,6 +81,7 @@ def __init__(
self.normalize_embeddings = normalize_embeddings
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.trust_remote_code = trust_remote_code

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -103,6 +108,7 @@ def to_dict(self) -> Dict[str, Any]:
normalize_embeddings=self.normalize_embeddings,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
trust_remote_code=self.trust_remote_code,
)

@classmethod
@@ -127,7 +133,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)

@component.output_types(documents=List[Document])
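To make the new flag concrete, a hedged usage sketch of the document embedder (the model name below is only an example of a Hub repository that ships custom architecture code and therefore needs trust_remote_code=True; sentence-transformers must be installed):

# Usage sketch; the model name is illustrative, any model with custom remote code behaves the same.
from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import (
    SentenceTransformersDocumentEmbedder,
)

embedder = SentenceTransformersDocumentEmbedder(
    model="nomic-ai/nomic-embed-text-v1",  # example repository containing custom model code
    trust_remote_code=True,                # without this, loading such a model will typically fail
    meta_fields_to_embed=["title"],
)
embedder.warm_up()  # builds the backend; trust_remote_code is forwarded to SentenceTransformer
docs = [Document(content="trust_remote_code is now configurable.", meta={"title": "Release notes"})]
result = embedder.run(documents=docs)
print(len(result["documents"][0].embedding))  # each returned Document carries its embedding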

File: haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -37,6 +37,7 @@ def __init__(
batch_size: int = 32,
progress_bar: bool = True,
normalize_embeddings: bool = False,
trust_remote_code: bool = False,
):
"""
Create a SentenceTransformersTextEmbedder component.
@@ -59,6 +60,9 @@ def __init__(
If True shows a progress bar when running.
:param normalize_embeddings:
If True returned vectors will have length 1.
:param trust_remote_code:
If False, only Hugging Face verified model architectures are allowed.
If True, custom models and scripts are allowed.

[Review comment, Member] Please also update this.

[Reply, Member Author] Done with 2bad887

"""

self.model = model
@@ -69,6 +73,7 @@ def __init__(
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
self.trust_remote_code = trust_remote_code

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -93,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
trust_remote_code=self.trust_remote_code,
)

@classmethod
@@ -117,7 +123,10 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device.to_torch_str(), auth_token=self.token
model=self.model,
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
)

@component.output_types(embedding=List[float])
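Because trust_remote_code is now part of to_dict/from_dict, the flag survives component (de)serialization, for example in saved pipelines. A minimal round-trip sketch, assuming the usual Haystack component format with an init_parameters dictionary:

# Round-trip sketch; mirrors the to_dict/from_dict tests further below.
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder

original = SentenceTransformersTextEmbedder(model="model", trust_remote_code=True)
data = original.to_dict()
assert data["init_parameters"]["trust_remote_code"] is True

restored = SentenceTransformersTextEmbedder.from_dict(data)
assert restored.trust_remote_code is True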

File: release note (new file)
@@ -0,0 +1,4 @@
---
features:
- |
Add the trust_remote_code parameter to SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder, allowing custom models and scripts to be loaded.
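Under the hood both embedders simply forward the flag to sentence_transformers, so the effect is the same as calling the library directly. A sketch of the equivalent direct call (the model name is illustrative; use_auth_token mirrors the backend code above):

# Equivalent direct sentence-transformers call performed by the backend; model name is illustrative.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    model_name_or_path="jinaai/jina-embeddings-v2-base-en",  # example repository with custom code
    device="cpu",
    use_auth_token=None,
    trust_remote_code=True,  # allow the repository's custom modelling code to be executed
)
print(model.encode(["Sentence Transformers embedders can now load custom models."]).shape)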

File: unit tests for SentenceTransformersDocumentEmbedder
@@ -1,10 +1,11 @@
from unittest.mock import patch, MagicMock
import pytest
from unittest.mock import MagicMock, patch

import numpy as np
from haystack.utils import Secret, ComponentDevice
import pytest

from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersDocumentEmbedder:
@@ -20,6 +21,7 @@ def test_init_default(self):
assert embedder.normalize_embeddings is False
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.trust_remote_code is False

def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
@@ -33,6 +35,7 @@ def test_init_with_parameters(self):
normalize_embeddings=True,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -44,6 +47,7 @@ def test_init_with_parameters(self):
assert embedder.normalize_embeddings is True
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.trust_remote_code

def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@@ -61,6 +65,7 @@ def test_to_dict(self):
"normalize_embeddings": False,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
"trust_remote_code": False,
},
}

@@ -76,6 +81,7 @@ def test_to_dict_with_custom_init_parameters(self):
normalize_embeddings=True,
meta_fields_to_embed=["meta_field"],
embedding_separator=" - ",
trust_remote_code=True,
)
data = component.to_dict()

@@ -91,6 +97,7 @@ def test_to_dict_with_custom_init_parameters(self):
"progress_bar": False,
"normalize_embeddings": True,
"embedding_separator": " - ",
"trust_remote_code": True,
"meta_fields_to_embed": ["meta_field"],
},
}
@@ -107,6 +114,7 @@ def test_from_dict(self):
"normalize_embeddings": True,
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
"trust_remote_code": True,
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
@@ -123,6 +131,7 @@ def test_from_dict(self):
assert component.progress_bar is False
assert component.normalize_embeddings is True
assert component.embedding_separator == " - "
assert component.trust_remote_code
assert component.meta_fields_to_embed == ["meta_field"]

@patch(
@@ -134,7 +143,9 @@ def test_warmup(self, mocked_factory):
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)

@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"

File: unit tests for the Sentence Transformers embedding backend
@@ -1,5 +1,7 @@
from unittest.mock import patch

import pytest

from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@@ -23,10 +25,10 @@ def test_factory_behavior(mock_sentence_transformer):
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
)



File: unit tests for SentenceTransformersTextEmbedder
@@ -1,10 +1,10 @@
from unittest.mock import patch, MagicMock
import pytest
from haystack.utils import Secret, ComponentDevice
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersTextEmbedder:
@@ -18,6 +18,7 @@ def test_init_default(self):
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.trust_remote_code is False

def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
@@ -29,6 +30,7 @@ def test_init_with_parameters(self):
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -38,6 +40,7 @@ def test_init_with_parameters(self):
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.trust_remote_code

def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
Expand All @@ -53,6 +56,7 @@ def test_to_dict(self):
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}

@@ -66,6 +70,7 @@ def test_to_dict_with_custom_init_parameters(self):
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
)
data = component.to_dict()
assert data == {
@@ -79,6 +84,7 @@ def test_to_dict_with_custom_init_parameters(self):
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"trust_remote_code": True,
},
}

@@ -99,6 +105,7 @@ def test_from_dict(self):
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -110,6 +117,7 @@ def test_from_dict(self):
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -118,7 +126,9 @@ def test_warmup(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(model="model", token=None, device=ComponentDevice.from_str("cpu"))
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
)

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"