diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py index cb3891a62..d64e08bc7 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner_config.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Callable, Optional, Union from enum import Enum import logging @@ -104,8 +104,8 @@ def build_default_model(self, num_classes: int, pretrained = self.pretrained backbone_name = self.get_backbone_str() - - model = getattr(models, backbone_name)(pretrained=pretrained) + model_factory_func: Callable = getattr(models, backbone_name) + model = model_factory_func(pretrained=pretrained, **self.extra_args) if in_channels != 3: if not backbone_name.startswith('resnet'): diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index 4bef6b843..fd2599d04 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -239,6 +239,11 @@ class ModelConfig(Config): None, description='If specified, the model will be built from the ' 'definition from this external source, using Torch Hub.') + extra_args: dict = Field( + {}, + description='Other implementation-specific args that might be useful ' + 'for constructing the default model. This is ignored if using an ' + 'external model.') def get_backbone_str(self): return self.backbone.name diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py index 169ebabb4..aeda0c07e 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py @@ -210,7 +210,9 @@ def build_default_model(self, num_classes: int, in_channels: int, min_size=img_sz, max_size=img_sz, image_mean=image_mean, - image_std=image_std) + image_std=image_std, + **self.extra_args, + ) return model diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py index 5b7ac05d0..7e713b12d 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner_config.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Union from enum import Enum import logging @@ -17,6 +17,9 @@ RegressionRandomWindowGeoDataset) from rastervision.pytorch_learner.utils import adjust_conv_channels +if TYPE_CHECKING: + import torch + log = logging.getLogger(__name__) @@ -108,20 +111,23 @@ def scene_to_dataset(self, class RegressionModel(nn.Module): def __init__(self, - backbone_arch, - out_features, - pretrained=True, - pos_out_inds=None, - prob_out_inds=None): + backbone_arch: str, + out_features: int, + pretrained: bool = True, + pos_out_inds: Optional[Sequence[int]] = None, + prob_out_inds: Optional[Sequence[int]] = None, + **kwargs): super().__init__() - self.backbone = getattr(models, backbone_arch)(pretrained=pretrained) + model_factory_func: Callable = getattr(models, backbone_arch) + self.backbone: nn.Module = model_factory_func( + pretrained=pretrained, **kwargs) in_features = self.backbone.fc.in_features self.backbone.fc = nn.Linear(in_features, out_features) self.pos_out_inds = pos_out_inds self.prob_out_inds = prob_out_inds - def forward(self, x): - out = self.backbone(x) + def forward(self, x: 'torch.Tensor') -> 'torch.Tensor': + out: 'torch.Tensor' = self.backbone(x) if self.pos_out_inds: for ind in self.pos_out_inds: out[:, ind] = out[:, ind].exp() @@ -168,7 +174,8 @@ def build_default_model( out_features, pretrained=pretrained, pos_out_inds=pos_out_inds, - prob_out_inds=prob_out_inds) + prob_out_inds=prob_out_inds, + **self.extra_args) if in_channels != 3: if not backbone_name.startswith('resnet'): diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py index 858ffc973..69c1c6b06 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner_config.py @@ -173,7 +173,8 @@ def build_default_model(self, num_classes: int, model = model_factory_func( num_classes=num_classes, pretrained_backbone=pretrained, - aux_loss=False) + aux_loss=False, + **self.extra_args) if in_channels != 3: if not backbone_name.startswith('resnet'): diff --git a/tests/pytorch_learner/test_object_detection_learner_config.py b/tests/pytorch_learner/test_object_detection_learner_config.py new file mode 100644 index 000000000..e8513405b --- /dev/null +++ b/tests/pytorch_learner/test_object_detection_learner_config.py @@ -0,0 +1,14 @@ +import unittest + +from rastervision.pytorch_learner import Backbone, ObjectDetectionModelConfig + + +class TestObjectDetectionModelConfig(unittest.TestCase): + def test_extra_args(self): + cfg = ObjectDetectionModelConfig( + backbone=Backbone.resnet18, + pretrained=False, + extra_args=dict(box_nms_thresh=0.4)) + model = cfg.build_default_model( + num_classes=2, in_channels=3, img_sz=256) + self.assertEqual(model.roi_heads.nms_thresh, 0.4)