timm/models/_factory.py: 53 changes (37 additions & 16 deletions)
@@ -12,7 +12,7 @@
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']


-def parse_model_name(model_name):
+def parse_model_name(model_name: str):
if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use
model_name = model_name.replace('hf_hub', 'hf-hub')
@@ -26,7 +26,7 @@ def parse_model_name(model_name):
return 'timm', model_name


-def safe_model_name(model_name, remove_source=True):
+def safe_model_name(model_name: str, remove_source: bool = True):
# return a filename / path safe model name
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
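
For context, a minimal sketch of how these two helpers behave. The hf-hub branch of `parse_model_name` is collapsed in the hunks above, so its return value here is an assumption based on current timm, the import path is assumed from the module shown, and the repo id is illustrative only.

```py
>>> from timm.models import parse_model_name, safe_model_name

>>> # Plain names resolve to the built-in 'timm' source.
>>> parse_model_name('mobilenetv3_large_100')
('timm', 'mobilenetv3_large_100')

>>> # 'hf-hub:' prefixed names are routed to the Hugging Face Hub source
>>> # (assumed: the collapsed body returns the scheme and the repo id).
>>> parse_model_name('hf-hub:timm/mobilenetv3_large_100.ra_in1k')
('hf-hub', 'timm/mobilenetv3_large_100.ra_in1k')

>>> # safe_model_name strips the source and replaces non-alphanumeric
>>> # characters with '_' to produce a filename-safe name.
>>> safe_model_name('hf-hub:timm/mobilenetv3_large_100.ra_in1k')
'timm_mobilenetv3_large_100_ra_in1k'
```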
@@ -46,27 +46,48 @@ def create_model(
no_jit: Optional[bool] = None,
**kwargs,
):
"""Create a model
"""Create a model.

Lookup model's entrypoint function and pass relevant args to create a new model.

-**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
-and then the model class __init__(). kwargs values set to None are pruned before passing.
+<Tip>
+**kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
+and then the model class __init__(). kwargs values set to None are pruned before passing.
+</Tip>

Args:
-model_name (str): name of model to instantiate
-pretrained (bool): load pretrained ImageNet-1k weights if true
-pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
-pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
-checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
-scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
-exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
-no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
+model_name: Name of model to instantiate.
+pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+pretrained_cfg: Pass in an external pretrained_cfg for model.
+pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
+checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
+exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
+no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).

Keyword Args:
-drop_rate (float): dropout rate for training (default: 0.0)
-global_pool (str): global pool type (default: 'avg')
-**: other kwargs are consumed by builder or model __init__()
+drop_rate (float): Classifier dropout rate for training.
+drop_path_rate (float): Stochastic depth drop rate for training.
+global_pool (str): Classifier global pooling type.
+
+Example:
+
+```py
+>>> from timm import create_model
+
+>>> # Create a MobileNetV3-Large model with no pretrained weights.
+>>> model = create_model('mobilenetv3_large_100')
+
+>>> # Create a MobileNetV3-Large model with pretrained weights.
+>>> model = create_model('mobilenetv3_large_100', pretrained=True)
+>>> model.num_classes
+1000
+
+>>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
+>>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
+>>> model.num_classes
+10
+```
"""
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
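
Beyond the basic usage in the new docstring example, a short sketch of how the documented keyword args and the None-pruning behaviour described in the `<Tip>` combine; the rate values are illustrative only.

```py
>>> import timm

>>> # Keyword args are passed through the entrypoint to build_model_with_cfg()
>>> # and on to the model __init__(), as the docstring describes.
>>> model = timm.create_model(
...     'mobilenetv3_large_100',
...     pretrained=False,
...     num_classes=10,
...     drop_rate=0.1,        # classifier dropout
...     drop_path_rate=0.1,   # stochastic depth
...     global_pool='avg',    # classifier global pooling
... )

>>> # kwargs explicitly set to None are pruned before model creation, so the
>>> # model's own default applies rather than None overriding it.
>>> model = timm.create_model('mobilenetv3_large_100', drop_rate=None)
```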