Skip to content

Commit

Permalink
minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 24, 2023
1 parent 244f484 commit 97ad7e9
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 17 deletions.
Expand Up @@ -64,7 +64,8 @@ def get_collate_fn(self):
def train_step(self, batch, batch_ind):
x, y = batch
loss_dict = self.model(x, y)
return {'train_loss': loss_dict['total_loss']}
loss_dict['train_loss'] = sum(loss_dict.values())
return loss_dict

def validate_step(self, batch, batch_ind):
x, y = batch
Expand All @@ -83,10 +84,10 @@ def validate_end(self, outputs):
num_class_ids = len(self.cfg.data.class_names)
coco_eval = compute_coco_eval(outs, ys, num_class_ids)

metrics = {'map': 0.0, 'map50': 0.0}
metrics = {'mAP': 0.0, 'mAP50': 0.0}
if coco_eval is not None:
coco_metrics = coco_eval.stats
metrics = {'map': coco_metrics[0], 'map50': coco_metrics[1]}
metrics = {'mAP': coco_metrics[0], 'mAP50': coco_metrics[1]}
return metrics

def predict(self,
Expand Down
Expand Up @@ -6,6 +6,7 @@
import albumentations as A

from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FasterRCNN

from rastervision.core.data import Scene
from rastervision.pipeline.config import (Config, register_config, Field,
Expand All @@ -17,7 +18,6 @@
ObjectDetectionImageDataset, ObjectDetectionSlidingWindowGeoDataset,
ObjectDetectionRandomWindowGeoDataset)
from rastervision.pytorch_learner.utils import adjust_conv_channels
from torchvision.models.detection.faster_rcnn import FasterRCNN

if TYPE_CHECKING:
from rastervision.pytorch_learner.learner_config import SolverConfig
Expand Down
Expand Up @@ -317,7 +317,7 @@ def __init__(self,
def forward(self,
input: torch.Tensor,
targets: Optional[Iterable[BoxList]] = None
) -> Union[dict, List[BoxList]]:
) -> Union[Dict[str, Any], List[BoxList]]:
"""Forward pass.
Args:
Expand All @@ -340,10 +340,7 @@ def forward(self,
# models: a dict with keys, 'boxes' and 'labels'.
# Note: labels (class IDs) must start at 1.
_targets = [self.boxlist_to_model_input_dict(bl) for bl in targets]

loss_dict = self.model(input, _targets)
loss_dict['total_loss'] = sum(list(loss_dict.values()))

return loss_dict

outs = self.model(input)
Expand All @@ -353,8 +350,9 @@ def forward(self,
return boxlists

def boxlist_to_model_input_dict(self, boxlist: BoxList) -> dict:
"""Convert BoxList to a dict compatible with torchvision detection
models. Also, make class labels 1-indexed.
"""Convert BoxList to dict compatible w/ torchvision detection models.
Also, make class labels 1-indexed.
Args:
boxlist (BoxList): A BoxList with a "class_ids" field.
Expand All @@ -369,8 +367,9 @@ def boxlist_to_model_input_dict(self, boxlist: BoxList) -> dict:
}

def model_output_dict_to_boxlist(self, out: dict) -> BoxList:
"""Convert torchvision detection dict to BoxList. Also, exclude any
null classes and make class labels 0-indexed.
"""Convert model output dict to BoxList.
Also, exclude any null classes and make class labels 0-indexed.
Args:
out (dict): A dict output by a torchvision detection model in eval
Expand Down
Expand Up @@ -66,10 +66,9 @@ def validate_step(self, batch, batch_nb):
torch.abs(out - y) / self.target_medians).sum(dim=0)

metrics = {'val_loss': val_loss}
for ind, label in enumerate(self.cfg.data.class_names):
metrics['{}_abs_error'.format(label)] = abs_error[ind]
metrics['{}_scaled_abs_error'.format(label)] = scaled_abs_error[
ind]
for i, label in enumerate(self.cfg.data.class_names):
metrics[f'{label}_abs_error'] = abs_error[i]
metrics[f'{label}_scaled_abs_error'] = scaled_abs_error[i]

return metrics

Expand Down
1 change: 0 additions & 1 deletion tests/pytorch_learner/test_object_detection_utils.py
Expand Up @@ -47,7 +47,6 @@ def test_train_output(self):
self.assertRaises(Exception, lambda: model(x))
out = model(x, y)
self.assertIsInstance(out, dict)
self.assertIn('total_loss', out)

def test_eval_output_with_bogus_class(self):
true_num_classes = 3
Expand Down

0 comments on commit 97ad7e9

Please sign in to comment.