From 04de292abf47dd17252c3d2833ec785dcae5d164 Mon Sep 17 00:00:00 2001
From: Han Xiao
Date: Wed, 6 Oct 2021 17:18:35 +0200
Subject: [PATCH] feat(tailor): add high-level framework-agnostic convert (#97)

---
 finetuner/tailor/__init__.py | 46 ++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/finetuner/tailor/__init__.py b/finetuner/tailor/__init__.py
index e69de29bb..0a5df2be1 100644
--- a/finetuner/tailor/__init__.py
+++ b/finetuner/tailor/__init__.py
@@ -0,0 +1,46 @@
+from typing import overload, Optional, Tuple
+
+from ..helper import get_framework, AnyDNN
+
+
+# Keras Tailor
+@overload
+def convert(
+    model: AnyDNN,
+    freeze: bool = False,
+    embedding_layer_name: Optional[str] = None,
+    output_dim: Optional[int] = None,
+) -> AnyDNN:
+    ...
+
+
+# Pytorch and Paddle Tailor
+@overload
+def convert(
+    model: AnyDNN,
+    input_size: Tuple[int, ...],
+    freeze: bool = False,
+    embedding_layer_name: Optional[str] = None,
+    output_dim: Optional[int] = None,
+    input_dtype: str = 'float32',
+) -> AnyDNN:
+    ...
+
+
+def convert(model: AnyDNN, **kwargs) -> AnyDNN:
+    f_type = get_framework(model)
+
+    if f_type == 'keras':
+        from .keras import KerasTailor
+
+        ft = KerasTailor
+    elif f_type == 'torch':
+        from .pytorch import PytorchTailor
+
+        ft = PytorchTailor
+    elif f_type == 'paddle':
+        from .paddle import PaddleTailor
+
+        ft = PaddleTailor
+
+    return ft(model, **kwargs)().model
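
Usage sketch (not part of the patch): the overloads above suggest that a PyTorch or Paddle model needs an explicit input_size, while a Keras model does not. Below is a minimal, hedged example of calling the new high-level convert() on a PyTorch model; it assumes torchvision is installed, and the input_size and output_dim values are illustrative only.

    import torchvision
    from finetuner.tailor import convert

    # Any PyTorch model; get_framework() dispatches to PytorchTailor internally.
    model = torchvision.models.resnet18(pretrained=False)

    # Keyword arguments mirror the PyTorch/Paddle overload declared in the patch;
    # their exact effect depends on the Tailor implementations, not shown here.
    embed_model = convert(
        model,
        input_size=(3, 224, 224),  # required by the PyTorch/Paddle overload
        freeze=True,               # presumably freezes the pretrained weights
        output_dim=128,            # illustrative embedding dimensionality
    )

For a Keras model, the first overload implies the same call works without input_size, since Keras models carry their input shape themselves.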