Skip to content

Commit

Permalink
allow specifying extra args for default model in ModelConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 20, 2023
1 parent 1160cc1 commit a6981b3
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 15 deletions.
@@ -1,4 +1,4 @@
from typing import Union, Optional
from typing import Callable, Optional, Union
from enum import Enum
import logging

Expand Down Expand Up @@ -104,8 +104,8 @@ def build_default_model(self, num_classes: int,

pretrained = self.pretrained
backbone_name = self.get_backbone_str()

model = getattr(models, backbone_name)(pretrained=pretrained)
model_factory_func: Callable = getattr(models, backbone_name)
model = model_factory_func(pretrained=pretrained, **self.extra_args)

if in_channels != 3:
if not backbone_name.startswith('resnet'):
Expand Down
Expand Up @@ -239,6 +239,11 @@ class ModelConfig(Config):
None,
description='If specified, the model will be built from the '
'definition from this external source, using Torch Hub.')
extra_args: dict = Field(
{},
description='Other implementation-specific args that might be useful '
'for constructing the default model. This is ignored if using an '
'external model.')

def get_backbone_str(self):
return self.backbone.name
Expand Down
Expand Up @@ -210,7 +210,9 @@ def build_default_model(self, num_classes: int, in_channels: int,
min_size=img_sz,
max_size=img_sz,
image_mean=image_mean,
image_std=image_std)
image_std=image_std,
**self.extra_args,
)
return model


Expand Down
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Sequence, Union
from enum import Enum
import logging

Expand All @@ -17,6 +17,9 @@
RegressionRandomWindowGeoDataset)
from rastervision.pytorch_learner.utils import adjust_conv_channels

if TYPE_CHECKING:
import torch

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -108,20 +111,23 @@ def scene_to_dataset(self,

class RegressionModel(nn.Module):
def __init__(self,
backbone_arch,
out_features,
pretrained=True,
pos_out_inds=None,
prob_out_inds=None):
backbone_arch: str,
out_features: int,
pretrained: bool = True,
pos_out_inds: Optional[Sequence[int]] = None,
prob_out_inds: Optional[Sequence[int]] = None,
**kwargs):
super().__init__()
self.backbone = getattr(models, backbone_arch)(pretrained=pretrained)
model_factory_func: Callable = getattr(models, backbone_arch)
self.backbone: nn.Module = model_factory_func(
pretrained=pretrained, **kwargs)
in_features = self.backbone.fc.in_features
self.backbone.fc = nn.Linear(in_features, out_features)
self.pos_out_inds = pos_out_inds
self.prob_out_inds = prob_out_inds

def forward(self, x):
out = self.backbone(x)
def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
out: 'torch.Tensor' = self.backbone(x)
if self.pos_out_inds:
for ind in self.pos_out_inds:
out[:, ind] = out[:, ind].exp()
Expand Down Expand Up @@ -168,7 +174,8 @@ def build_default_model(
out_features,
pretrained=pretrained,
pos_out_inds=pos_out_inds,
prob_out_inds=prob_out_inds)
prob_out_inds=prob_out_inds,
**self.extra_args)

if in_channels != 3:
if not backbone_name.startswith('resnet'):
Expand Down
Expand Up @@ -173,7 +173,8 @@ def build_default_model(self, num_classes: int,
model = model_factory_func(
num_classes=num_classes,
pretrained_backbone=pretrained,
aux_loss=False)
aux_loss=False,
**self.extra_args)

if in_channels != 3:
if not backbone_name.startswith('resnet'):
Expand Down
14 changes: 14 additions & 0 deletions tests/pytorch_learner/test_object_detection_learner_config.py
@@ -0,0 +1,14 @@
import unittest

from rastervision.pytorch_learner import Backbone, ObjectDetectionModelConfig


class TestObjectDetectionModelConfig(unittest.TestCase):
def test_extra_args(self):
cfg = ObjectDetectionModelConfig(
backbone=Backbone.resnet18,
pretrained=False,
extra_args=dict(box_nms_thresh=0.4))
model = cfg.build_default_model(
num_classes=2, in_channels=3, img_sz=256)
self.assertEqual(model.roi_heads.nms_thresh, 0.4)

0 comments on commit a6981b3

Please sign in to comment.