-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add build_model function (#584)
* feat: add build_model function * chore: reordered parameters * chore: updated changelog * test: added unit test for build_model * chore: bump commons and stubs * fix: corrected imports * chore: added detail to docstring * refactor: make device optional * fix: fixed tests * chore: removed unnecessary argument * chore: fix typo in docstring * test: added test for embedding
- Loading branch information
1 parent
cb2b594
commit 01fa333
Showing
3 changed files
with
115 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
import pytest | ||
from _finetuner.excepts import SelectModelRequired | ||
from _finetuner.models.inference import ONNXRuntimeInferenceEngine, TorchInferenceEngine | ||
from docarray import Document, DocumentArray | ||
|
||
import finetuner | ||
|
||
|
||
@pytest.mark.parametrize( | ||
'descriptor, select_model, is_onnx, expect_error', | ||
[ | ||
('bert-base-cased', None, False, None), | ||
('bert-base-cased', None, True, None), | ||
('openai/clip-vit-base-patch16', 'clip-text', False, None), | ||
('openai/clip-vit-base-patch16', 'clip-vision', False, None), | ||
('openai/clip-vit-base-patch16', None, False, SelectModelRequired), | ||
('MADE UP MODEL', None, False, ValueError), | ||
], | ||
) | ||
def test_build_model(descriptor, select_model, is_onnx, expect_error): | ||
|
||
if expect_error: | ||
with pytest.raises(expect_error): | ||
model = finetuner.build_model( | ||
name=descriptor, | ||
select_model=select_model, | ||
is_onnx=is_onnx, | ||
) | ||
else: | ||
model = finetuner.build_model( | ||
name=descriptor, select_model=select_model, is_onnx=is_onnx | ||
) | ||
|
||
if is_onnx: | ||
assert isinstance(model, ONNXRuntimeInferenceEngine) | ||
else: | ||
assert isinstance(model, TorchInferenceEngine) | ||
|
||
|
||
@pytest.mark.parametrize('is_onnx', [True, False]) | ||
def test_build_model_embedding(is_onnx): | ||
|
||
model = finetuner.build_model(name="bert-base-cased", is_onnx=is_onnx) | ||
|
||
da = DocumentArray(Document(text="TEST TEXT")) | ||
finetuner.encode(model=model, data=da) | ||
assert da.embeddings is not None | ||
assert isinstance(da.embeddings, np.ndarray) |