feat(helper): set_embedding function for all frameworks (#163)
hanxiao committed Oct 23, 2021
1 parent 870c5a2 commit 43480cc
Showing 9 changed files with 164 additions and 34 deletions.
68 changes: 68 additions & 0 deletions finetuner/embedding.py
@@ -0,0 +1,68 @@
from typing import Union

from jina import DocumentArray, DocumentArrayMemmap

from .helper import AnyDNN, get_framework


def set_embeddings(
    docs: Union[DocumentArray, DocumentArrayMemmap],
    embed_model: AnyDNN,
    device: str = 'cpu',
) -> None:
    """Fill the embedding of Documents in place by using `embed_model`.

    :param docs: the Documents to be embedded
    :param embed_model: the embedding model written in Keras/PyTorch/Paddle
    :param device: the computational device for `embed_model`, can be `cpu`, `cuda`, etc.
    """
    fm = get_framework(embed_model)
    # dispatch to the matching framework-specific helper defined below,
    # e.g. _set_embeddings_torch for a PyTorch model
    globals()[f'_set_embeddings_{fm}'](docs, embed_model, device)


def _set_embeddings_keras(
    docs: Union[DocumentArray, DocumentArrayMemmap],
    embed_model: AnyDNN,
    device: str = 'cpu',
):
    from .tuner.keras import get_device

    device = get_device(device)
    with device:
        embeddings = embed_model(docs.blobs).numpy()

    docs.embeddings = embeddings


def _set_embeddings_torch(
    docs: Union[DocumentArray, DocumentArrayMemmap],
    embed_model: AnyDNN,
    device: str = 'cpu',
):
    from .tuner.pytorch import get_device

    device = get_device(device)

    import torch

    tensor = torch.tensor(docs.blobs, device=device)
    with torch.inference_mode():
        embeddings = embed_model(tensor).cpu().numpy()

    docs.embeddings = embeddings


def _set_embeddings_paddle(
    docs: Union[DocumentArray, DocumentArrayMemmap],
    embed_model: AnyDNN,
    device: str = 'cpu',
):
    from .tuner.paddle import get_device

    # Paddle sets the device globally, so the call is for its side effect only
    get_device(device)

    import paddle

    embeddings = embed_model(paddle.Tensor(docs.blobs)).numpy()
    docs.embeddings = embeddings
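
The new helper picks the backend by inspecting the model's module via `get_framework` and then calls the matching `_set_embeddings_<framework>` function. A minimal usage sketch (not part of the commit), mirroring the unit test added below; the toy Keras model and `num_total=10` are illustrative choices:

import tensorflow as tf
from jina import DocumentArray

from finetuner.embedding import set_embeddings
from finetuner.toydata import generate_fashion_match

# build a tiny Keras encoder; a PyTorch or Paddle model works the same way
embed_model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32),
    ]
)

docs = DocumentArray(generate_fashion_match(num_total=10))
set_embeddings(docs, embed_model)  # framework auto-detected, runs on CPU by default
print(docs.embeddings.shape)  # expected: (10, 32)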
4 changes: 1 addition & 3 deletions finetuner/helper.py
@@ -56,9 +56,7 @@ def get_framework(dnn_model: AnyDNN) -> str:
     elif 'paddle' in dnn_model.__module__:
         return 'paddle'
     else:
-        raise ValueError(
-            f'can not determine the backend from embed_model from {dnn_model.__module__}'
-        )
+        raise ValueError(f'can not determine the backend of {dnn_model!r}')


 def is_seq_int(tp) -> bool:
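
For reference, the dispatch key comes from `get_framework`, which inspects `dnn_model.__module__`; given the `_set_embeddings_keras`/`_set_embeddings_torch`/`_set_embeddings_paddle` helpers above, it presumably returns `'keras'`, `'torch'` or `'paddle'`. A small sketch (not part of the commit) under that assumption:

import torch

from finetuner.helper import get_framework

model = torch.nn.Linear(4, 2)
# 'torch' is inferred from model.__module__ ('torch.nn.modules.linear')
assert get_framework(model) == 'torch'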
8 changes: 1 addition & 7 deletions finetuner/toydata.py
@@ -11,7 +11,6 @@
 import numpy as np
 from jina import Document, DocumentArray
 from jina.logging.profile import ProgressBar
-from jina.types.document import png_to_buffer

 from finetuner import __default_tag_key__

@@ -306,12 +305,7 @@ def _download_fashion_doc(
                 'class': int(lbl),
             },
         )
-
-        if kwargs['channels'] == 0:
-            png_bytes = png_to_buffer(
-                raw_img, width=28, height=28, resize_method='BILINEAR'
-            )
-            _d.uri = 'data:image/png;base64,' + base64.b64encode(png_bytes).decode()
+        _d.convert_image_blob_to_uri()
         yield _d

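
The manual PNG encoding is replaced by the Document's own conversion method; roughly, the old `png_to_buffer` branch collapses to a single call. A sketch (not part of the commit), assuming the installed jina version provides `convert_image_blob_to_uri` exactly as it is used above:

import numpy as np
from jina import Document

# a fake 28x28 grayscale image stored as blob
d = Document(blob=np.random.rand(28, 28))
d.convert_image_blob_to_uri()  # fills d.uri with a data URI, replacing the png_to_buffer path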
26 changes: 16 additions & 10 deletions finetuner/tuner/keras/__init__.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import tensorflow as tf
 from jina.logging.profile import ProgressBar
@@ -169,20 +169,12 @@ def fit(
                 inputs=eval_data, batch_size=batch_size, shuffle=False
             )

-        if device == 'cuda':
-            device = '/GPU:0'
-        elif device == 'cpu':
-            device = '/CPU:0'
-        else:
-            raise ValueError(f'Device {device} not recognized')
-        self.device = tf.device(device)
-
         _optimizer = self._get_optimizer(optimizer, optimizer_kwargs, learning_rate)

         m_train_loss = ScalarSummary('train')
         m_eval_loss = ScalarSummary('eval')

-        with self.device:
+        with get_device(device):
             for epoch in range(epochs):
                 lt = self._train(
                     _train_data,
@@ -209,3 +201,17 @@ def save(self, *args, **kwargs):
         """

         self.embed_model.save(*args, **kwargs)
+
+
+def get_device(device: str):
+    """Get tensorflow compute device.
+
+    :param device: device name
+    """
+
+    # translate our own alias into framework-compatible ones
+    if device == 'cuda':
+        device = '/GPU:0'
+    elif device == 'cpu':
+        device = '/CPU:0'
+    return tf.device(device)
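
A short sketch (not part of the commit) of the alias translation; any device string other than `cuda`/`cpu` is passed to `tf.device` unchanged:

import tensorflow as tf

from finetuner.tuner.keras import get_device

with get_device('cpu'):  # resolves to tf.device('/CPU:0')
    x = tf.zeros((2, 28, 28))
    print(x.device)  # something like '/job:localhost/.../device:CPU:0'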
22 changes: 16 additions & 6 deletions finetuner/tuner/paddle/__init__.py
@@ -160,12 +160,7 @@ def fit(
             ``"cpu"`` and ``"cuda"`` (for GPU)
         """

-        if device == 'cuda':
-            paddle.set_device('gpu:0')
-        elif device == 'cpu':
-            paddle.set_device('cpu')
-        else:
-            raise ValueError(f'Device {device} not recognized')
+        get_device(device)  #: this actually sets the device in Paddle

         _optimizer = self._get_optimizer(optimizer, optimizer_kwargs, learning_rate)

@@ -203,3 +198,18 @@ def save(self, *args, **kwargs):
         :param kwargs: Keyword arguments to pass to ``paddle.save`` function
         """
         paddle.save(self.embed_model.state_dict(), *args, **kwargs)
+
+
+def get_device(device: str):
+    """Get Paddle compute device.
+
+    :param device: device name
+    """
+
+    # translate our own alias into framework-compatible ones
+    if device == 'cuda':
+        paddle.set_device('gpu:0')
+    elif device == 'cpu':
+        paddle.set_device('cpu')
+    else:
+        paddle.set_device(device)
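
Unlike the Keras and PyTorch variants, this one has a side effect: `paddle.set_device` switches the global default device rather than returning a handle, which is why callers invoke it for its effect only. A brief sketch (not part of the commit):

import paddle

from finetuner.tuner.paddle import get_device

get_device('cpu')  # calls paddle.set_device('cpu') under the hood
print(paddle.get_device())  # expected: 'cpu'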
17 changes: 11 additions & 6 deletions finetuner/tuner/pytorch/__init__.py
@@ -171,12 +171,7 @@ def fit(
         :param device: The device to which to move the model. Supported options are
             ``"cpu"`` and ``"cuda"`` (for GPU)
         """
-        if device == 'cpu':
-            self.device = torch.device('cpu')
-        elif device == 'cuda':
-            self.device = torch.device('cuda')
-        else:
-            raise ValueError(f'Device {device} not recognized')
+        self.device = get_device(device)

         # Place model on device
         self._embed_model = self._embed_model.to(self.device)
@@ -218,3 +213,13 @@ def save(self, *args, **kwargs):
         :param kwargs: Keyword arguments to pass to ``torch.save`` function
         """
         torch.save(self.embed_model.state_dict(), *args, **kwargs)
+
+
+def get_device(device: str):
+    """Get Pytorch compute device.
+
+    :param device: device name
+    """
+
+    # translate our own alias into framework-compatible ones
+    return torch.device(device)
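
Here the alias needs no translation, since `torch.device` already accepts `'cpu'` and `'cuda'` directly; the wrapper just keeps the three backends symmetric. A minimal sketch (not part of the commit):

import torch

from finetuner.tuner.pytorch import get_device

device = get_device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(784, 32).to(device)  # parameters are moved to the resolved device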
2 changes: 0 additions & 2 deletions tests/integration/fit/test_fit_mlp.py
@@ -1,5 +1,3 @@
-import json
-
 import paddle
 import tensorflow as tf
 import torch
51 changes: 51 additions & 0 deletions tests/unit/test_embedding.py
@@ -0,0 +1,51 @@
import paddle
import pytest
import tensorflow as tf
import torch
from jina import DocumentArray, DocumentArrayMemmap

from finetuner.embedding import set_embeddings
from finetuner.toydata import generate_fashion_match

embed_models = {
    'keras': lambda: tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dense(32),
        ]
    ),
    'pytorch': lambda: torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(
            in_features=28 * 28,
            out_features=128,
        ),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features=128, out_features=32),
    ),
    'paddle': lambda: paddle.nn.Sequential(
        paddle.nn.Flatten(),
        paddle.nn.Linear(
            in_features=28 * 28,
            out_features=128,
        ),
        paddle.nn.ReLU(),
        paddle.nn.Linear(in_features=128, out_features=32),
    ),
}


@pytest.mark.parametrize('framework', ['keras', 'pytorch', 'paddle'])
def test_embedding_docs(framework, tmpdir):
    # works for DA
    embed_model = embed_models[framework]()
    docs = DocumentArray(generate_fashion_match(num_total=100))
    set_embeddings(docs, embed_model)
    assert docs.embeddings.shape == (100, 32)

    # works for DAM
    dam = DocumentArrayMemmap(tmpdir)
    dam.extend(generate_fashion_match(num_total=42))
    set_embeddings(dam, embed_model)
    assert dam.embeddings.shape == (42, 32)
Empty file removed tests/unit/test_fit.py