Skip to content

Commit

Permalink
refactor(labeler): use set_embeddings in labeler (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Oct 23, 2021
1 parent 0d8e0b5 commit d8d875f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 28 deletions.
32 changes: 5 additions & 27 deletions finetuner/labeler/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jina import Executor, DocumentArray, requests, DocumentArrayMemmap
from jina.helper import cached_property

from ..helper import get_framework
from ..embedding import set_embeddings
from ..tuner import fit, save


Expand Down Expand Up @@ -34,37 +34,15 @@ def embed(self, docs: DocumentArray, parameters: Dict, **kwargs):
if not docs:
return
self._all_data.reload()
da = self._all_data.sample(
_catalog = self._all_data.sample(
min(len(self._all_data), int(parameters.get('sample_size', 1000)))
)

f_type = get_framework(self._embed_model)

if f_type == 'keras':
da_input = da.blobs
docs_input = docs.blobs
da.embeddings = self._embed_model(da_input).numpy()
docs.embeddings = self._embed_model(docs_input).numpy()
elif f_type == 'torch':
import torch

self._embed_model.eval()
da_input = torch.from_numpy(da.blobs)
docs_input = torch.from_numpy(docs.blobs)
with torch.inference_mode():
da.embeddings = self._embed_model(da_input).detach().numpy()
docs.embeddings = self._embed_model(docs_input).detach().numpy()
elif f_type == 'paddle':
import paddle

self._embed_model.eval()
da_input = paddle.to_tensor(da.blobs)
docs_input = paddle.to_tensor(docs.blobs)
da.embeddings = self._embed_model(da_input).detach().numpy()
docs.embeddings = self._embed_model(docs_input).detach().numpy()
set_embeddings(docs, self._embed_model)
set_embeddings(_catalog, self._embed_model)

docs.match(
da,
_catalog,
metric=self._metric,
limit=int(parameters.get('topk', 10)),
exclude_self=True,
Expand Down
3 changes: 2 additions & 1 deletion finetuner/toydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,13 @@ def _download_fashion_doc(
):

_d = Document(
content=(raw_img / 255.0).astype(np.float32),
content=raw_img,
tags={
'class': int(lbl),
},
)
_d.convert_image_blob_to_uri()
_d.blob = (_d.blob / 255.0).astype(np.float32)
yield _d


Expand Down

0 comments on commit d8d875f

Please sign in to comment.