Commit

fix resnext3d_block to enable model_complexity_hook to compute resnext3d model complexity

Differential Revision: D18368224

fbshipit-source-id: 5007178c554c42db433efc177ce91ebaaf8ad9f3
stephenyan1231 authored and facebook-github-bot committed Nov 7, 2019
1 parent fe2e4ce commit 95d3928
Showing 2 changed files with 40 additions and 39 deletions.
classy_vision/hooks/model_complexity_hook.py (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ def on_start(
         except NotImplementedError:
             logging.warning(
                 """Model contains unsupported modules:
-            Could not compute FLOPs for model forward pass"""
+            Could not compute FLOPs for model forward pass. Exception:""", exc_info=True
             )
         logging.info(
             "Number of parameters in model: %d" % count_params(task.base_model)
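The hook change itself is one line: the warning that FLOPs could not be computed now logs the triggering exception. As a minimal standalone sketch (plain Python logging, not code from this repository; the function name below is illustrative), exc_info=True makes the logging call attach the active exception and its traceback to the record, so the failure is identifiable instead of being silently swallowed:

    import logging

    def warn_on_unsupported_model() -> None:
        # Stand-in for the hook's on_start flow; the hook catches the
        # NotImplementedError raised while profiling FLOPs.
        try:
            raise NotImplementedError("FLOPs not implemented for CustomModule")
        except NotImplementedError:
            # exc_info=True appends the active exception and traceback to the log.
            logging.warning(
                "Could not compute FLOPs for model forward pass. Exception:",
                exc_info=True,
            )

    warn_on_unsupported_model()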
classy_vision/models/resnext3d_block.py (77 changes: 39 additions & 38 deletions)
@@ -315,22 +315,19 @@ def __init__(
     ):
         super(PostactivatedShortcutTransformation, self).__init__()
         # Use skip connection with projection if dim or spatial/temporal res change.
-        if (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1):
-            self.branch1 = nn.Conv3d(
-                dim_in,
-                dim_out,
-                kernel_size=1,
-                stride=[temporal_stride, spatial_stride, spatial_stride],
-                padding=0,
-                bias=False,
-            )
-            self.branch1_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)
+        assert (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1)
+        self.branch1 = nn.Conv3d(
+            dim_in,
+            dim_out,
+            kernel_size=1,
+            stride=[temporal_stride, spatial_stride, spatial_stride],
+            padding=0,
+            bias=False,
+        )
+        self.branch1_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)

     def forward(self, x):
-        if hasattr(self, "branch1") and hasattr(self, "branch1_bn"):
-            return self.branch1_bn(self.branch1(x))
-        else:
-            return x
+        return self.branch1_bn(self.branch1(x))


 class PreactivatedShortcutTransformation(nn.Module):
@@ -353,24 +350,23 @@ def __init__(
     ):
         super(PreactivatedShortcutTransformation, self).__init__()
         # Use skip connection with projection if dim or spatial/temporal res change.
-        if (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1):
-            if not disable_pre_activation:
-                self.branch1_bn = nn.BatchNorm3d(dim_in, eps=bn_eps, momentum=bn_mmt)
-                self.branch1_relu = nn.ReLU(inplace=inplace_relu)
-            self.branch1 = nn.Conv3d(
-                dim_in,
-                dim_out,
-                kernel_size=1,
-                stride=[temporal_stride, spatial_stride, spatial_stride],
-                padding=0,
-                bias=False,
-            )
+        assert (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1)
+        if not disable_pre_activation:
+            self.branch1_bn = nn.BatchNorm3d(dim_in, eps=bn_eps, momentum=bn_mmt)
+            self.branch1_relu = nn.ReLU(inplace=inplace_relu)
+        self.branch1 = nn.Conv3d(
+            dim_in,
+            dim_out,
+            kernel_size=1,
+            stride=[temporal_stride, spatial_stride, spatial_stride],
+            padding=0,
+            bias=False,
+        )

     def forward(self, x):
         if hasattr(self, "branch1_bn") and hasattr(self, "branch1_relu"):
             x = self.branch1_relu(self.branch1_bn(x))
-        if hasattr(self, "branch1"):
-            x = self.branch1(x)
+        x = self.branch1(x)
         return x


@@ -432,15 +428,17 @@ def __init__(
         assert skip_transformation_type in skip_transformations, (
             "unknown skip transformation: %s" % skip_transformation_type
         )
-        self.skip = skip_transformations[skip_transformation_type](
-            dim_in,
-            dim_out,
-            temporal_stride,
-            spatial_stride,
-            bn_eps=bn_eps,
-            bn_mmt=bn_mmt,
-            disable_pre_activation=disable_pre_activation,
-        )
+
+        if (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1):
+            self.skip = skip_transformations[skip_transformation_type](
+                dim_in,
+                dim_out,
+                temporal_stride,
+                spatial_stride,
+                bn_eps=bn_eps,
+                bn_mmt=bn_mmt,
+                disable_pre_activation=disable_pre_activation,
+            )

         assert residual_transformation_type in residual_transformations, (
             "unknown residual transformation: %s" % residual_transformation_type
@@ -459,6 +457,9 @@ def __init__(
         self.relu = nn.ReLU(inplace_relu)

     def forward(self, x):
-        x = self.skip(x) + self.residual(x)
+        if hasattr(self, "skip"):
+            x = self.skip(x) + self.residual(x)
+        else:
+            x = x + self.residual(x)
         x = self.relu(x)
         return x
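Taken together, the resnext3d_block changes move the "does this block need a projection shortcut?" decision out of the shortcut modules' forward methods (which previously relied on hasattr checks at run time) and into ResBlock.__init__: a shortcut transformation is now only constructed when a projection is actually required, and once constructed it always runs its conv/bn path. A condensed sketch of the resulting structure (plain PyTorch; Shortcut, Block, and the residual conv below are simplified stand-ins, while branch1, branch1_bn, and the stride/assert logic mirror the diff):

    import torch
    import torch.nn as nn


    class Shortcut(nn.Module):
        # Only constructed when a projection is needed, so forward is
        # unconditional and every registered submodule is actually used.
        def __init__(self, dim_in, dim_out, temporal_stride, spatial_stride):
            super().__init__()
            assert (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1)
            self.branch1 = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=1,
                stride=[temporal_stride, spatial_stride, spatial_stride],
                padding=0,
                bias=False,
            )
            self.branch1_bn = nn.BatchNorm3d(dim_out)

        def forward(self, x):
            return self.branch1_bn(self.branch1(x))


    class Block(nn.Module):
        # The identity-vs-projection choice now lives in the block constructor.
        def __init__(self, dim_in, dim_out, temporal_stride=1, spatial_stride=1):
            super().__init__()
            if (dim_in != dim_out) or (spatial_stride != 1) or (temporal_stride != 1):
                self.skip = Shortcut(dim_in, dim_out, temporal_stride, spatial_stride)
            self.residual = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=3,
                stride=[temporal_stride, spatial_stride, spatial_stride],
                padding=1,
                bias=False,
            )
            self.relu = nn.ReLU()

        def forward(self, x):
            if hasattr(self, "skip"):
                x = self.skip(x) + self.residual(x)
            else:
                x = x + self.residual(x)
            return self.relu(x)


    # Identity case: no Shortcut is registered at all, so there is no
    # registered-but-unused projection conv left over for a profiler to account for.
    block = Block(dim_in=64, dim_out=64)
    print(block(torch.randn(1, 64, 4, 8, 8)).shape)  # torch.Size([1, 64, 4, 8, 8])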
