
Commit

feat(helper): add batch_size to embed fn (#183)
hanxiao committed Nov 1, 2021
1 parent 43d62f0 commit ae8e399
Showing 1 changed file with 13 additions and 11 deletions.
finetuner/embedding.py (24 changes: 13 additions & 11 deletions)

@@ -9,61 +9,63 @@ def embed(
     docs: Union[DocumentArray, DocumentArrayMemmap],
     embed_model: AnyDNN,
     device: str = 'cpu',
+    batch_size: int = 256,
 ) -> None:
     """Fill the embedding of Documents inplace by using `embed_model`
     :param docs: the Documents to be embedded
     :param embed_model: the embedding model written in Keras/Pytorch/Paddle
     :param device: the computational device for `embed_model`, can be either `cpu` or `cuda`.
+    :param batch_size: number of Documents in a batch for embedding
     """
     fm = get_framework(embed_model)
-    globals()[f'_set_embeddings_{fm}'](docs, embed_model, device)
+    globals()[f'_set_embeddings_{fm}'](docs, embed_model, device, batch_size)
 
 
 def _set_embeddings_keras(
     docs: Union[DocumentArray, DocumentArrayMemmap],
     embed_model: AnyDNN,
     device: str = 'cpu',
+    batch_size: int = 256,
 ):
     from .tuner.keras import get_device
 
     device = get_device(device)
     with device:
-        embeddings = embed_model(docs.blobs).numpy()
-
-    docs.embeddings = embeddings
+        for b in docs.batch(batch_size):
+            b.embeddings = embed_model(b.blobs).numpy()
 
 
 def _set_embeddings_torch(
     docs: Union[DocumentArray, DocumentArrayMemmap],
     embed_model: AnyDNN,
     device: str = 'cpu',
+    batch_size: int = 256,
 ):
     from .tuner.pytorch import get_device
 
     device = get_device(device)
 
     import torch
 
-    tensor = torch.tensor(docs.blobs, device=device)
     embed_model = embed_model.to(device)
     with torch.inference_mode():
-        embeddings = embed_model(tensor).cpu().detach().numpy()
-
-    docs.embeddings = embeddings
+        for b in docs.batch(batch_size):
+            tensor = torch.tensor(b.blobs, device=device)
+            b.embeddings = embed_model(tensor).cpu().detach().numpy()
 
 
 def _set_embeddings_paddle(
     docs: Union[DocumentArray, DocumentArrayMemmap],
     embed_model: AnyDNN,
     device: str = 'cpu',
+    batch_size: int = 256,
 ):
     from .tuner.paddle import get_device
 
     get_device(device)
 
     import paddle
 
-    embeddings = embed_model(paddle.Tensor(docs.blobs)).numpy()
-    docs.embeddings = embeddings
+    for b in docs.batch(batch_size):
+        b.embeddings = embed_model(paddle.Tensor(b.blobs)).numpy()
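For context, a minimal usage sketch of the batched helper. It assumes the jina `Document`/`DocumentArray` API that finetuner built on at the time (`Document(blob=...)`, `docs.embeddings`), that `embed` is importable from `finetuner.embedding`, and a toy PyTorch model; the sizes and the batch size of 128 are illustrative, not taken from the commit.

    import numpy as np
    import torch
    from jina import Document, DocumentArray

    from finetuner.embedding import embed  # the helper changed in this commit

    # 1,000 random 32-dimensional blobs wrapped as Documents (toy data)
    docs = DocumentArray(
        Document(blob=np.random.random(32).astype('float32')) for _ in range(1000)
    )

    # any Keras/PyTorch/Paddle model can serve as `embed_model`; a single Linear layer is enough here
    model = torch.nn.Linear(in_features=32, out_features=16)

    # embeddings are filled in place, 128 Documents per forward pass instead of all at once
    embed(docs, model, device='cpu', batch_size=128)

    print(docs.embeddings.shape)  # expected: (1000, 16)

With the default `batch_size=256` the result is the same as before the change; what differs is peak memory, since the model now sees at most `batch_size` blobs per forward pass rather than the whole array.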
