Skip to content

Commit

Permalink
feat: allow model selection in client (#775)
Browse files Browse the repository at this point in the history
* feat: allow model selection in client

* docs: update client model selection

* docs: revert

* fix: improve endpoint

* fix: rstrip endpoint
  • Loading branch information
ZiniuYu committed Jul 20, 2022
1 parent bc6b72e commit 32b11cd
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions client/clip_client/client.py
Expand Up @@ -69,6 +69,7 @@ def encode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'np.ndarray':
"""Encode images and texts into embeddings where the input is an iterable of raw strings.
Each image and text must be represented as a string. The following strings are acceptable:
Expand All @@ -79,6 +80,7 @@ def encode(
:param content: an iterator of image URIs or sentences, each element is an image or a text sentence as a string.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand All @@ -90,11 +92,13 @@ def encode(
*,
batch_size: Optional[int] = None,
show_progress: bool = False,
parameters: Optional[dict] = None,
) -> 'DocumentArray':
"""Encode images and texts into embeddings where the input is an iterable of :class:`docarray.Document`.
:param content: an iterable of :class:`docarray.Document`, each Document must be filled with `.uri`, `.text` or `.blob`.
:param batch_size: the number of elements in each request when sending ``content``
:param show_progress: if set, show a progress bar
:param parameters: the parameters for the encoding, you can specify the model to use when you have multiple models
:return: the embedding in a numpy ndarray with shape ``[N, D]``. ``N`` is in the same length of ``content``
"""
...
Expand Down Expand Up @@ -185,8 +189,10 @@ def _iter_doc(self, content) -> Generator['Document', None, None]:
)

def _get_post_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on='/',
on=f'/encode/{model_name}'.rstrip('/'),
inputs=self._iter_doc(content),
request_size=kwargs.get('batch_size', 8),
total_docs=len(content) if hasattr(content, '__len__') else None,
Expand Down Expand Up @@ -364,8 +370,10 @@ def _iter_rank_docs(
)

def _get_rank_payload(self, content, kwargs):
parameters = kwargs.get('parameters', {})
model_name = parameters.get('model', '')
payload = dict(
on='/rank',
on=f'/rank/{model_name}'.rstrip('/'),
inputs=self._iter_rank_docs(
content, _source=kwargs.get('source', 'matches')
),
Expand Down

0 comments on commit 32b11cd

Please sign in to comment.