diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py index eaf685aa1..7789f123d 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py @@ -64,7 +64,8 @@ def get_collate_fn(self): def train_step(self, batch, batch_ind): x, y = batch loss_dict = self.model(x, y) - return {'train_loss': loss_dict['total_loss']} + loss_dict['train_loss'] = sum(loss_dict.values()) + return loss_dict def validate_step(self, batch, batch_ind): x, y = batch @@ -83,10 +84,10 @@ def validate_end(self, outputs): num_class_ids = len(self.cfg.data.class_names) coco_eval = compute_coco_eval(outs, ys, num_class_ids) - metrics = {'map': 0.0, 'map50': 0.0} + metrics = {'mAP': 0.0, 'mAP50': 0.0} if coco_eval is not None: coco_metrics = coco_eval.stats - metrics = {'map': coco_metrics[0], 'map50': coco_metrics[1]} + metrics = {'mAP': coco_metrics[0], 'mAP50': coco_metrics[1]} return metrics def predict(self, diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py index aeda0c07e..00dc73e63 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner_config.py @@ -6,6 +6,7 @@ import albumentations as A from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from torchvision.models.detection.faster_rcnn import FasterRCNN from rastervision.core.data import Scene from rastervision.pipeline.config import (Config, register_config, Field, @@ -17,7 +18,6 @@ ObjectDetectionImageDataset, ObjectDetectionSlidingWindowGeoDataset, ObjectDetectionRandomWindowGeoDataset) from rastervision.pytorch_learner.utils import adjust_conv_channels -from torchvision.models.detection.faster_rcnn import FasterRCNN if TYPE_CHECKING: from rastervision.pytorch_learner.learner_config import SolverConfig diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_utils.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_utils.py index 5df52f8f3..3c1e17279 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_utils.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_utils.py @@ -317,7 +317,7 @@ def __init__(self, def forward(self, input: torch.Tensor, targets: Optional[Iterable[BoxList]] = None - ) -> Union[dict, List[BoxList]]: + ) -> Union[Dict[str, Any], List[BoxList]]: """Forward pass. Args: @@ -340,10 +340,7 @@ def forward(self, # models: a dict with keys, 'boxes' and 'labels'. # Note: labels (class IDs) must start at 1. _targets = [self.boxlist_to_model_input_dict(bl) for bl in targets] - loss_dict = self.model(input, _targets) - loss_dict['total_loss'] = sum(list(loss_dict.values())) - return loss_dict outs = self.model(input) @@ -353,8 +350,9 @@ def forward(self, return boxlists def boxlist_to_model_input_dict(self, boxlist: BoxList) -> dict: - """Convert BoxList to a dict compatible with torchvision detection - models. Also, make class labels 1-indexed. + """Convert BoxList to dict compatible w/ torchvision detection models. + + Also, make class labels 1-indexed. Args: boxlist (BoxList): A BoxList with a "class_ids" field. @@ -369,8 +367,9 @@ def boxlist_to_model_input_dict(self, boxlist: BoxList) -> dict: } def model_output_dict_to_boxlist(self, out: dict) -> BoxList: - """Convert torchvision detection dict to BoxList. Also, exclude any - null classes and make class labels 0-indexed. + """Convert model output dict to BoxList. + + Also, exclude any null classes and make class labels 0-indexed. Args: out (dict): A dict output by a torchvision detection model in eval diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py index 7a88dd7b8..09a6f7f7d 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py @@ -66,10 +66,9 @@ def validate_step(self, batch, batch_nb): torch.abs(out - y) / self.target_medians).sum(dim=0) metrics = {'val_loss': val_loss} - for ind, label in enumerate(self.cfg.data.class_names): - metrics['{}_abs_error'.format(label)] = abs_error[ind] - metrics['{}_scaled_abs_error'.format(label)] = scaled_abs_error[ - ind] + for i, label in enumerate(self.cfg.data.class_names): + metrics[f'{label}_abs_error'] = abs_error[i] + metrics[f'{label}_scaled_abs_error'] = scaled_abs_error[i] return metrics diff --git a/tests/pytorch_learner/test_object_detection_utils.py b/tests/pytorch_learner/test_object_detection_utils.py index 5dc4216e6..69a103c9d 100644 --- a/tests/pytorch_learner/test_object_detection_utils.py +++ b/tests/pytorch_learner/test_object_detection_utils.py @@ -47,7 +47,6 @@ def test_train_output(self): self.assertRaises(Exception, lambda: model(x)) out = model(x, y) self.assertIsInstance(out, dict) - self.assertIn('total_loss', out) def test_eval_output_with_bogus_class(self): true_num_classes = 3