-
Notifications
You must be signed in to change notification settings - Fork 400
/
resnet_hparams.py
60 lines (48 loc) · 2.95 KB
/
resnet_hparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.ComposerResNet`."""
from dataclasses import dataclass
import yahp as hp
from composer.models.model_hparams import ModelHparams
from composer.models.resnet.model import ComposerResNet
# Public API of this module: the only name exported by ``from <module> import *``.
__all__ = ['ResNetHparams']
@dataclass
class ResNetHparams(ModelHparams):
    """Hyperparameter dataclass that describes and constructs a :class:`.ComposerResNet`.

    This is the `YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface
    for the model: each field below is declared with ``hp.optional`` and forwarded unchanged to the
    :class:`.ComposerResNet` constructor by :meth:`initialize_object`.

    Args:
        model_name (str): Name of the ResNet architecture to instantiate. :meth:`validate` rejects any
            value not found in ``ComposerResNet.valid_model_names`` (e.g. ``"resnet18"``, ``"resnet34"``,
            ``"resnet50"``, ``"resnet101"``, ``"resnet152"``). Default: ``''``, which is presumably not a
            valid name, so this field is expected to be set explicitly.
        num_classes (int): The number of classes. Not declared in this class (presumably inherited from
            :class:`.ModelHparams` -- confirm). Must not be ``None``: both :meth:`validate` and
            :meth:`initialize_object` raise ``ValueError`` otherwise.
        pretrained (bool, optional): If True, use ImageNet pretrained weights. Default: ``False``.
        groups (int, optional): Number of filter groups for the 3x3 convolution layer in bottleneck blocks.
            Default: ``1``.
        width_per_group (int, optional): Initial width for each convolution group. Width doubles after
            each stage. Default: ``64``.
        loss_name (str, optional): Name of the loss function, e.g. ``'soft_cross_entropy'`` or
            ``'binary_cross_entropy_with_logits'``. Default: ``'soft_cross_entropy'``.
        initializers (List[Initializer], optional): Initializers for the model. Not declared in this class
            (presumably inherited from :class:`.ModelHparams` -- confirm); passed through to
            :class:`.ComposerResNet` as-is.
    """

    # NOTE: field order is part of the dataclass interface (generated ``__init__``); keep it stable.
    model_name: str = hp.optional(
        f"ResNet architecture to instantiate, must be one of {ComposerResNet.valid_model_names}. (default: '')",
        default='',
    )
    pretrained: bool = hp.optional(
        'If true, use ImageNet pretrained weights. (default: ``False``)',
        default=False,
    )
    groups: int = hp.optional(
        'Number of filter groups for the 3x3 convolution layer in bottleneck block. (default: ``1``)',
        default=1,
    )
    width_per_group: int = hp.optional(
        'Initial width for each convolution group. Width doubles after each stage. (default: ``64``)',
        default=64,
    )
    loss_name: str = hp.optional(
        "Name of loss function. E.g. 'soft_cross_entropy' or 'binary_cross_entropy_with_logits'. (default: ``soft_cross_entropy``)",
        default='soft_cross_entropy',
    )

    def validate(self):
        """Check that the configured values can be used to build a model.

        Raises:
            ValueError: If ``model_name`` is not in ``ComposerResNet.valid_model_names``, or
                (checked second) if ``num_classes`` is ``None``.
        """
        name_is_supported = self.model_name in ComposerResNet.valid_model_names
        if not name_is_supported:
            raise ValueError(f'model_name must be one of {ComposerResNet.valid_model_names}, but got {self.model_name}')

        if self.num_classes is None:
            raise ValueError('num_classes must be specified')

    def initialize_object(self):
        """Construct the :class:`.ComposerResNet` described by these hyperparameters.

        Returns:
            The object produced by calling :class:`.ComposerResNet` with this instance's field values.

        Raises:
            ValueError: If ``num_classes`` is ``None``.
        """
        # Checked here as well as in ``validate`` so a direct call cannot build a model without it.
        if self.num_classes is None:
            raise ValueError('num_classes must be specified')

        # Collect the constructor arguments first; every field maps 1:1 onto a keyword of the same name.
        model_kwargs = dict(
            model_name=self.model_name,
            num_classes=self.num_classes,
            pretrained=self.pretrained,
            groups=self.groups,
            width_per_group=self.width_per_group,
            initializers=self.initializers,
            loss_name=self.loss_name,
        )
        return ComposerResNet(**model_kwargs)