Bug fix: Handle empty instances in FCOS.
Summary:
### 🐛 One line bug info: [FCOS implementation in current master](https://github.com/facebookresearch/detectron2/blob/31ec19b3132a3ac609600802dd37b2b40a76b5c9/detectron2/modeling/meta_arch/fcos.py) is unable to handle empty instances.

This bug went unnoticed because: (a) images with empty instances are usually filtered while loading COCO annotations, and (b) FCOS should not encounter empty instances with its default training hyper-parameters. However, if I switch to a more aggressive large-scale jitter (LSJ) cropping augmentation, the model may encounter an image crop without any boxes in it. This crashes training abruptly due to a size mismatch in the pairwise anchor matching matrix.
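The failure mode can be reproduced in isolation: with zero GT boxes, the per-anchor reduction over the GT dimension of the pairwise matching matrix has nothing to reduce over. A minimal sketch (illustrative shapes, not the actual FCOS code):

```python
import torch

R, M = 5, 0  # R anchor points, M ground-truth boxes (an empty crop)
pairwise_match = torch.zeros(R, M)

# Reducing over the empty GT dimension raises, which is what aborts training.
try:
    pairwise_match.max(dim=1)
    crashed = False
except (RuntimeError, IndexError):
    crashed = True
```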

### Bug fix

This PR lets FCOS handle empty instances by adding a dummy `[0, 0, 0, 0]` box labeled as background (ID = `num_classes`), similar to how it is handled in `RetinaNet` class. Training FCOS with LSJ augmentation does not crash anymore.
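The fix follows the RetinaNet pattern. A minimal sketch of the empty-instance branch (names and shapes are illustrative, not the exact code in this diff):

```python
import torch

num_classes = 80   # e.g. COCO
num_anchors = 6    # total anchor points across all levels (illustrative)

# No GT instances in the image: every anchor gets a dummy all-zero box and
# the background label (ID == num_classes), mirroring RetinaNet.
matched_gt_boxes = torch.zeros(num_anchors, 4)
gt_labels = torch.full((num_anchors,), num_classes, dtype=torch.long)
```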

**Additional refactor:** While working on the fix, I noticed some inconsistencies in variable representations and naming conventions. For example, the `pairwise_match` variable was a `(R: anchor points, M: GT boxes)` matrix, which is the transpose of what a `match_quality_matrix` represents in `RetinaNet` and `GeneralizedRCNN`, namely `(M: GT boxes, R: anchor points)`. I refactored the code to make it more uniform with D2 (11528ce) conventions and to make the FCOS logic flow similar to RetinaNet, given that RetinaNet was the primary baseline in the FCOS paper. My changes include:
  - Refactored `pairwise_match` into a `match_quality_matrix` whose representation is consistent with the rest of the meta architectures. Renaming variables (`matched_boxes` -> `matched_gt_boxes`, `label` -> `gt_labels`, and `gt_index` -> `matched_indices`) also makes the code consistent with the naming convention of the other meta architectures.
  - Update: After ppwwyyxx's review, I added this explanation as a comment instead of refactoring the code: ~Use a `Matcher` instead of simply doing `pairwise_match.max()`. The original code replaced the indices of unmatched anchors with `-1` and accessed GT labels/boxes via `gt_index.clip(0)`, which felt non-ideal.~
  - Changed the internal method `match_anchors` to compute and return a per-image `match_quality_matrix` (similar to using `pairwise_iou` in R-CNN). This modifies the old behavior of returning `matched_indices`. Since this method is only used internally in `FCOS`, I renamed it to `_match_anchors`.
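Since the two layouts are transposes of each other, the per-anchor matching result is unchanged by the refactor. A quick sanity check under illustrative sizes:

```python
import torch

torch.manual_seed(0)
M, R = 3, 7  # M ground-truth boxes, R anchor points (illustrative)

pairwise_match = torch.rand(R, M)          # old FCOS layout: (R, M)
match_quality_matrix = pairwise_match.t()  # D2-style layout: (M, R)

# Per-anchor best GT: the old code reduced over dim=1, the new over dim=0.
old_vals, old_idx = pairwise_match.max(dim=1)
new_vals, new_idx = match_quality_matrix.max(dim=0)
assert torch.equal(old_vals, new_vals)
assert torch.equal(old_idx, new_idx)
```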

### Any API changes?

The call signature of `FCOS` is unchanged. `FCOS.match_anchors()` is now `FCOS._match_anchors()` with a different return value, but it was only used internally by `FCOS.label_anchors()`.

### Verification

I verified my changes with one full training run of FCOS with ResNet-50-FPN on COCO detection, using the built-in config. The validation curves overlap very closely (orange: `master` branch, blue: with my changes).

![image](https://user-images.githubusercontent.com/10494087/147985981-4d0eceb4-2103-468c-9a98-f351441303ae.png)

I should have fixed the training seed in both runs... so additionally, I manually checked the equality of `pairwise_match.transpose(0, 1)` and my `match_quality_matrix` for the first 20 iterations by fixing `train.seed = 0` in the config. Everything matches exactly.

Pull Request resolved: #3851

Reviewed By: wat3rBro

Differential Revision: D33971235

Pulled By: zhanghang1989

fbshipit-source-id: 9eca18ef79c2942588cf12ead218a4f89bc8a297
kdexd authored and facebook-github-bot committed Feb 7, 2022
1 parent 7cad0a7 commit ef2c3ab
Showing 3 changed files with 95 additions and 59 deletions.
140 changes: 84 additions & 56 deletions detectron2/modeling/meta_arch/fcos.py
@@ -4,10 +4,11 @@
from typing import List, Optional, Tuple
import torch
from fvcore.nn import sigmoid_focal_loss_jit
from torch import Tensor, nn
from torch import nn
from torch.nn import functional as F

from detectron2.layers import ShapeSpec, batched_nms
from detectron2.modeling.matcher import Matcher
from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance
from detectron2.utils.events import get_event_storage

@@ -19,7 +20,6 @@

__all__ = ["FCOS"]


logger = logging.getLogger(__name__)


@@ -96,77 +96,100 @@ def forward_training(self, images, features, predictions, gt_instances):
)

@torch.no_grad()
def match_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
def _match_anchors(self, gt_boxes: Boxes, anchors: List[Boxes]):
"""
Match anchors with ground truth boxes.
Match ground-truth boxes to a set of multi-level anchors.
Args:
anchors: #level boxes, from the highest resolution to lower resolution
gt_instances: ground truth instances per image
gt_boxes: Ground-truth boxes from instances of an image.
anchors: List of anchors for each feature map (of different scales).
Returns:
List[Tensor]:
#image tensors, each is a vector of matched gt
indices (or -1 for unmatched anchors) for all anchors.
torch.Tensor
A tensor of shape `(M, R)`, given `M` ground-truth boxes and total
`R` anchor points from all feature levels, indicating the quality
of match between m-th box and r-th anchor. Higher value indicates
better match.
"""
# Naming convention: (M = ground-truth boxes, R = anchor points)
# Anchor points are represented as square boxes of size = stride.
num_anchors_per_level = [len(x) for x in anchors]
anchors = Boxes.cat(anchors) # Rx4
anchor_centers = anchors.get_centers() # Rx2
anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # R
anchors = Boxes.cat(anchors) # (R, 4)
anchor_centers = anchors.get_centers() # (R, 2)
anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # (R, )

lower_bound = anchor_sizes * 4
lower_bound[: num_anchors_per_level[0]] = 0
upper_bound = anchor_sizes * 8
upper_bound[-num_anchors_per_level[-1] :] = float("inf")

matched_indices = []
for gt_per_image in gt_instances:
gt_centers = gt_per_image.gt_boxes.get_centers() # Nx2
# FCOS with center sampling: anchor point must be close enough to gt center.
pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max(
dim=2
).values < self.center_sampling_radius * anchor_sizes[:, None]
pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_per_image.gt_boxes)

# The original FCOS anchor matching rule: anchor point must be inside gt
pairwise_match &= pairwise_dist.min(dim=2).values > 0

# Multilevel anchor matching in FCOS: each anchor is only responsible
# for certain scale range.
pairwise_dist = pairwise_dist.max(dim=2).values
pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (
pairwise_dist < upper_bound[:, None]
)
gt_centers = gt_boxes.get_centers()

# FCOS with center sampling: anchor point must be close enough to
# ground-truth box center.
center_dists = (anchor_centers[None, :, :] - gt_centers[:, None, :]).abs_()
sampling_regions = self.center_sampling_radius * anchor_sizes[None, :]

match_quality_matrix = center_dists.max(dim=2).values < sampling_regions

# Match the GT box with minimum area, if there are multiple GT matches
gt_areas = gt_per_image.gt_boxes.area() # N
pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :])
min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match
matched_idx[min_values < 1e-5] = -1 # Unmatched anchors are assigned -1
pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes)
pairwise_dist = pairwise_dist.permute(1, 0, 2) # (M, R, 4)

# The original FCOS anchor matching rule: anchor point must be inside GT.
match_quality_matrix &= pairwise_dist.min(dim=2).values > 0

# Multilevel anchor matching in FCOS: each anchor is only responsible
# for certain scale range.
pairwise_dist = pairwise_dist.max(dim=2).values
match_quality_matrix &= (pairwise_dist > lower_bound[None, :]) & (
pairwise_dist < upper_bound[None, :]
)
# Match the GT box with minimum area, if there are multiple GT matches.
gt_areas = gt_boxes.area() # (M, )

matched_indices.append(matched_idx)
return matched_indices
match_quality_matrix = match_quality_matrix.to(torch.float32)
match_quality_matrix *= 1e8 - gt_areas[:, None]
return match_quality_matrix # (M, R)

@torch.no_grad()
def label_anchors(self, anchors, gt_instances):
def label_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
"""
Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS
anchor matching rule.
Unlike RetinaNet, there are no ignored anchors.
"""
matched_indices = self.match_anchors(anchors, gt_instances)

matched_labels, matched_boxes = [], []
for gt_index, gt_per_image in zip(matched_indices, gt_instances):
label = gt_per_image.gt_classes[gt_index.clip(min=0)]
label[gt_index < 0] = self.num_classes # background
gt_labels, matched_gt_boxes = [], []

matched_gt_boxes = gt_per_image.gt_boxes[gt_index.clip(min=0)]
for inst in gt_instances:
if len(inst) > 0:
match_quality_matrix = self._match_anchors(inst.gt_boxes, anchors)

matched_labels.append(label)
matched_boxes.append(matched_gt_boxes)
return matched_labels, matched_boxes
# Find matched ground-truth box per anchor. Un-matched anchors are
# assigned -1. This is equivalent to using an anchor matcher as used
# in R-CNN/RetinaNet: `Matcher(thresholds=[1e-5], labels=[0, 1])`
match_quality, matched_idxs = match_quality_matrix.max(dim=0)
matched_idxs[match_quality < 1e-5] = -1

matched_gt_boxes_i = inst.gt_boxes.tensor[matched_idxs.clip(min=0)]
gt_labels_i = inst.gt_classes[matched_idxs.clip(min=0)]

# Anchors with matched_idxs = -1 are labeled background.
gt_labels_i[matched_idxs < 0] = self.num_classes
else:
matched_gt_boxes_i = torch.zeros_like(Boxes.cat(anchors).tensor)
gt_labels_i = torch.full(
(len(matched_gt_boxes_i), ),
fill_value=self.num_classes,
dtype=torch.long,
device=matched_gt_boxes_i.device,
)

gt_labels.append(gt_labels_i)
matched_gt_boxes.append(matched_gt_boxes_i)

return gt_labels, matched_gt_boxes

def losses(
self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
@@ -176,7 +199,7 @@ def losses(
"loss_centerness" in the returned dict.
"""
num_images = len(gt_labels)
gt_labels = torch.stack(gt_labels) # (N, R)
gt_labels = torch.stack(gt_labels) # (M, R)

pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
num_pos_anchors = pos_mask.sum().item()
@@ -199,13 +222,13 @@ def losses(
anchors,
self.box2box_transform,
pred_anchor_deltas,
[x.tensor for x in gt_boxes],
gt_boxes,
pos_mask,
box_reg_loss_type="giou",
)

ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes) # NxR
pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2) # NxR
ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes) # (M, R)
pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2) # (M, R)
ctrness_loss = F.binary_cross_entropy_with_logits(
pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum"
)
@@ -215,9 +238,11 @@ def losses(
"loss_fcos_ctr": ctrness_loss / normalizer,
}

def compute_ctrness_targets(self, anchors, gt_boxes): # NxR
def compute_ctrness_targets(
self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]
):
anchors = Boxes.cat(anchors).tensor # Rx4
reg_targets = [self.box2box_transform.get_deltas(anchors, m.tensor) for m in gt_boxes]
reg_targets = [self.box2box_transform.get_deltas(anchors, m) for m in gt_boxes]
reg_targets = torch.stack(reg_targets, dim=0) # NxRx4
if len(reg_targets) == 0:
return reg_targets.new_zeros(len(reg_targets))
@@ -229,7 +254,10 @@ def compute_ctrness_targets(self, anchors, gt_boxes): # NxR
return torch.sqrt(ctrness)

def forward_inference(
self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]]
self,
images: ImageList,
features: List[torch.Tensor],
predictions: List[List[torch.Tensor]],
):
pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
predictions, [self.num_classes, 4, 1]
@@ -254,8 +282,8 @@ def forward_inference(
def inference_single_image(
self,
anchors: List[Boxes],
box_cls: List[Tensor],
box_delta: List[Tensor],
box_cls: List[torch.Tensor],
box_delta: List[torch.Tensor],
image_size: Tuple[int, int],
):
"""
10 changes: 7 additions & 3 deletions detectron2/utils/testing.py
@@ -4,6 +4,7 @@
import torch

from detectron2 import model_zoo
from detectron2.config import CfgNode, instantiate
from detectron2.data import DatasetCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
@@ -21,9 +22,12 @@ def get_model_no_weights(config_path):
Like model_zoo.get, but do not load any weights (even pretrained)
"""
cfg = model_zoo.get_config(config_path)
if not torch.cuda.is_available():
cfg.MODEL.DEVICE = "cpu"
return build_model(cfg)
if isinstance(cfg, CfgNode):
if not torch.cuda.is_available():
cfg.MODEL.DEVICE = "cpu"
return build_model(cfg)
else:
return instantiate(cfg.model)


def random_boxes(num_boxes, max_coord=100, device="cpu"):
4 changes: 4 additions & 0 deletions tests/modeling/test_model_e2e.py
@@ -207,6 +207,10 @@ def test_autocast(self):
self.assertEqual(out.scores.dtype, torch.float16)


class FCOSE2ETest(InstanceModelE2ETest, unittest.TestCase):
CONFIG_PATH = "COCO-Detection/fcos_R_50_FPN_1x.py"


class SemSegE2ETest(unittest.TestCase):
CONFIG_PATH = "Misc/semantic_R_50_FPN_1x.yaml"

