feat: add build_model function (#584)
* feat: add build_model function

* chore: reordered parameters

* chore: updated changelog

* test: added unit test for build_model

* chore: bump commons and stubs

* fix: corrected imports

* chore: added detail to docstring

* refactor: make device optional

* fix: fixed tests

* chore: removed unnecessary argument

* chore: fix typo in docstring

* test: added test for embedding
LMMilliken committed Oct 27, 2022
1 parent cb2b594 commit 01fa333
Showing 3 changed files with 115 additions and 2 deletions.
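
As the commit message above describes, this change adds a `build_model` function for creating zero-shot models, i.e. pre-trained backbones usable for encoding without a fine-tuning run. A minimal usage sketch based on the diff below (the model name and example text are illustrative):

```python
import finetuner
from docarray import Document, DocumentArray

# Build a pre-trained model without any fine-tuning (zero-shot).
# The device is resolved automatically: 'cuda' if a GPU is
# available, otherwise 'cpu'.
model = finetuner.build_model(name='bert-base-cased')

# Embed a batch of documents with the freshly built model.
data = DocumentArray([Document(text='hello, world')])
finetuner.encode(model=model, data=data)
print(data.embeddings.shape)  # numpy array of embeddings
```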
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add `build_model` function to create zero-shot models. ([#584](https://github.com/jina-ai/finetuner/pull/584))

- Use latest Hubble with `notebook_login` support. ([#576](https://github.com/jina-ai/finetuner/pull/576))

### Removed
66 changes: 64 additions & 2 deletions finetuner/__init__.py
@@ -1,7 +1,7 @@
import inspect
import os
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from _finetuner.runner.stubs import model as model_stub
from docarray import DocumentArray
@@ -27,6 +27,9 @@
from finetuner.finetuner import Finetuner
from finetuner.model import list_model_classes

if TYPE_CHECKING:
    from _finetuner.models.inference import InferenceEngine

ft = Finetuner()


@@ -308,6 +311,63 @@ def get_token() -> str:
    return ft.get_token()


def build_model(
    name: str,
    model_options: Optional[Dict[str, Any]] = None,
    batch_size: int = 32,
    select_model: Optional[str] = None,
    device: Optional[str] = None,
    is_onnx: bool = False,
) -> 'InferenceEngine':
    """
    Builds a pre-trained model from a given descriptor.

    :param name: Refers to a pre-trained model, see
        https://finetuner.jina.ai/walkthrough/choose-backbone/ or use the
        :meth:`finetuner.describe_models()` function for a list of all
        supported models.
    :param model_options: A dictionary of model-specific options.
    :param batch_size: Incoming documents are fed to the graph in batches, both to
        speed up inference and to avoid memory errors. This argument controls the
        number of documents that will be put in each batch.
    :param select_model: Finetuner run artifacts might contain multiple models. In
        such cases you can select which model to deploy using this argument. For
        CLIP fine-tuning, you can choose either `clip-vision` or `clip-text`.
    :param device: The device to run inference on, either `cpu` or `cuda`. If not
        set, `cuda` is used when a GPU is available, otherwise `cpu`.
    :param is_onnx: If set to `True`, the model is built for the ONNX runtime;
        otherwise a PyTorch model is built.
    :return: An instance of :class:`TorchInferenceEngine` or
        :class:`ONNXRuntimeInferenceEngine`.
    """
    import torch
    from _finetuner.models.inference import (
        ONNXRuntimeInferenceEngine,
        TorchInferenceEngine,
    )
    from _finetuner.runner.model import RunnerModel

    if not device:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    stub = model_stub.get_stub(
        name, select_model=select_model, model_options=model_options or {}
    )

    model = RunnerModel(stub=stub)
    if not is_onnx:
        return TorchInferenceEngine(
            artifact=model,
            batch_size=batch_size,
            device=device,
        )
    else:
        return ONNXRuntimeInferenceEngine(
            artifact=model,
            batch_size=batch_size,
            device=device,
        )


def get_model(
    artifact: str,
    token: Optional[str] = None,
@@ -316,7 +376,7 @@ def get_model(
    device: Optional[str] = None,
    logging_level: str = 'WARNING',
    is_onnx: bool = False,
) -> 'InferenceEngine':
    """Re-build a model from a Finetuner run artifact and return an inference engine.

    :param artifact: Specify a finetuner run artifact. Can be a path to a local
@@ -343,6 +403,7 @@
    .. note::
        Please install finetuner[full] to include all the dependencies.
    """

    import torch
    from _finetuner.models.inference import (
        ONNXRuntimeInferenceEngine,
@@ -398,6 +459,7 @@ def encode(
    .. note::
        Please install "finetuner[full]" to include all the dependencies.
    """

    from _finetuner.models.inference import ONNXRuntimeInferenceEngine

    for batch in data.batch(batch_size, show_progress=True):
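
For multi-model backbones such as CLIP, `select_model` picks which sub-model to build, and `is_onnx` switches between the PyTorch and ONNX runtime engines. A sketch of both paths, using the CLIP descriptor exercised in the tests below (variable names are illustrative):

```python
import finetuner

# CLIP artifacts contain two models; one must be selected explicitly,
# otherwise SelectModelRequired is raised (see the tests below).
clip_text = finetuner.build_model(
    name='openai/clip-vit-base-patch16',
    select_model='clip-text',
)

# The same descriptor built for the ONNX runtime instead of PyTorch.
clip_vision_onnx = finetuner.build_model(
    name='openai/clip-vit-base-patch16',
    select_model='clip-vision',
    is_onnx=True,
)
```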
49 changes: 49 additions & 0 deletions tests/unit/test___init__.py
@@ -0,0 +1,49 @@
import numpy as np
import pytest
from _finetuner.excepts import SelectModelRequired
from _finetuner.models.inference import ONNXRuntimeInferenceEngine, TorchInferenceEngine
from docarray import Document, DocumentArray

import finetuner


@pytest.mark.parametrize(
    'descriptor, select_model, is_onnx, expect_error',
    [
        ('bert-base-cased', None, False, None),
        ('bert-base-cased', None, True, None),
        ('openai/clip-vit-base-patch16', 'clip-text', False, None),
        ('openai/clip-vit-base-patch16', 'clip-vision', False, None),
        ('openai/clip-vit-base-patch16', None, False, SelectModelRequired),
        ('MADE UP MODEL', None, False, ValueError),
    ],
)
def test_build_model(descriptor, select_model, is_onnx, expect_error):

    if expect_error:
        with pytest.raises(expect_error):
            model = finetuner.build_model(
                name=descriptor,
                select_model=select_model,
                is_onnx=is_onnx,
            )
    else:
        model = finetuner.build_model(
            name=descriptor, select_model=select_model, is_onnx=is_onnx
        )

        if is_onnx:
            assert isinstance(model, ONNXRuntimeInferenceEngine)
        else:
            assert isinstance(model, TorchInferenceEngine)


@pytest.mark.parametrize('is_onnx', [True, False])
def test_build_model_embedding(is_onnx):

    model = finetuner.build_model(name="bert-base-cased", is_onnx=is_onnx)

    da = DocumentArray(Document(text="TEST TEXT"))
    finetuner.encode(model=model, data=da)
    assert da.embeddings is not None
    assert isinstance(da.embeddings, np.ndarray)
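
The two error cases in the parametrization above translate to the following behavior at call sites, sketched here for illustration:

```python
import finetuner
from _finetuner.excepts import SelectModelRequired

# A multi-model descriptor without `select_model` is ambiguous.
try:
    finetuner.build_model(name='openai/clip-vit-base-patch16')
except SelectModelRequired:
    print('pass select_model for CLIP backbones')

# Unknown descriptors are rejected.
try:
    finetuner.build_model(name='MADE UP MODEL')
except ValueError:
    print('unknown model descriptor')
```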
