Skip to content

Commit

Permalink
Make flatten robust
Browse files Browse the repository at this point in the history
  • Loading branch information
Erotemic committed Oct 21, 2019
1 parent 767b0b8 commit 22da3f4
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions mmdet/models/bbox_heads/convfc_bbox_head.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from functools import reduce
from operator import mul

import torch.nn as nn

from ..registry import HEADS
Expand Down Expand Up @@ -138,7 +141,9 @@ def forward(self, x):
if self.num_shared_fcs > 0:
if self.with_avg_pool:
x = self.avg_pool(x)
x = x.view(x.size(0), -1)

x = _view_flat_trailing_dims(x)

for fc in self.shared_fcs:
x = self.relu(fc(x))
# separate branches
Expand All @@ -150,7 +155,7 @@ def forward(self, x):
if x_cls.dim() > 2:
if self.with_avg_pool:
x_cls = self.avg_pool(x_cls)
x_cls = x_cls.view(x_cls.size(0), -1)
x_cls = _view_flat_trailing_dims(x_cls)
for fc in self.cls_fcs:
x_cls = self.relu(fc(x_cls))

Expand All @@ -159,7 +164,7 @@ def forward(self, x):
if x_reg.dim() > 2:
if self.with_avg_pool:
x_reg = self.avg_pool(x_reg)
x_reg = x_reg.view(x_reg.size(0), -1)
x_reg = _view_flat_trailing_dims(x_reg)
for fc in self.reg_fcs:
x_reg = self.relu(fc(x_reg))

Expand All @@ -168,6 +173,36 @@ def forward(self, x):
return cls_score, bbox_pred


def _view_flat_trailing_dims(x):
"""
Flattens trailing dimensions
Equivalent to `x.view(x.shape[0], -1)`, but has special handling of the
case where `x.shape[0] == 0`
Args:
x (Tensor): input tensor
Returns:
Tensor: reshaped tensor
Example:
>>> import torch
>>> x = _view_flat_trailing_dims(torch.empty(3, 5, 7))
>>> assert tuple(x.shape) == (3, 35)
>>> x = _view_flat_trailing_dims(torch.empty(0, 5, 7))
>>> assert tuple(x.shape) == (0, 35)
>>> x = _view_flat_trailing_dims(torch.empty(0,))
>>> assert tuple(x.shape) == (0, 1)
"""
if x.numel() == 0:
num_trailing = reduce(mul, x.shape[1:], 1)
x = x.view(x.size(0), num_trailing)
else:
x = x.view(x.size(0), -1)
return x


@HEADS.register_module
class SharedFCBBoxHead(ConvFCBBoxHead):

Expand Down

0 comments on commit 22da3f4

Please sign in to comment.