Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

R(2+1)D unit #322

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
72 changes: 72 additions & 0 deletions classy_vision/models/r2plus1_util.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn


def r2plus1_unit(
    dim_in,
    dim_out,
    temporal_stride,
    spatial_stride,
    groups,
    inplace_relu,
    bn_eps,
    bn_mmt,
    dim_mid=None,
):
    """
    Implementation of `R(2+1)D unit <https://arxiv.org/abs/1711.11248>`_.
    Decompose one 3D conv into one 2D spatial conv and one 1D temporal conv.
    Choose the middle dimensionality so that the total number of parameters
    in 2D spatial conv and 1D temporal conv is unchanged.

    Args:
        dim_in (int): the channel dimensions of the input.
        dim_out (int): the channel dimension of the output.
        temporal_stride (int): the temporal stride of the bottleneck.
        spatial_stride (int): the spatial_stride of the bottleneck.
        groups (int): number of groups for the convolution.
        inplace_relu (bool): calculate the relu on the original input
            without allocating new memory.
        bn_eps (float): epsilon for batch norm.
        bn_mmt (float): momentum for batch norm. Noted that BN momentum in
            PyTorch = 1 - BN momentum in Caffe2.
        dim_mid (Optional[int]): If not None, use the provided channel dimension
            for the output of the 2D spatial conv. If None, compute the output
            channel dimension of the 2D spatial conv so that the total number of
            model parameters remains unchanged.

    Returns:
        nn.Sequential: [1x3x3 conv, BN, ReLU, 3x1x1 conv], replacing one
        3x3x3 conv of shape (dim_in -> dim_out).
    """
    if dim_mid is None:
        # Solve dim_mid * (dim_in*3*3 + dim_out*3) ~= dim_in*dim_out*3*3*3 so
        # the factorized unit has (roughly) the same parameter count as the
        # original 3x3x3 convolution it replaces.
        dim_mid = int(dim_out * dim_in * 3 * 3 * 3 / (dim_in * 3 * 3 + dim_out * 3))
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logging.info(
            "dim_in: %d, dim_out: %d. Set dim_mid to %d", dim_in, dim_out, dim_mid
        )
    # 1x3x3 group conv, BN, ReLU: the 2D spatial half of the factorization.
    conv_middle = nn.Conv3d(
        dim_in,
        dim_mid,
        [1, 3, 3],  # kernel
        stride=[1, spatial_stride, spatial_stride],
        padding=[0, 1, 1],
        groups=groups,
        bias=False,
    )
    conv_middle_bn = nn.BatchNorm3d(dim_mid, eps=bn_eps, momentum=bn_mmt)
    conv_middle_relu = nn.ReLU(inplace=inplace_relu)
    # 3x1x1 group conv: the 1D temporal half of the factorization.
    conv = nn.Conv3d(
        dim_mid,
        dim_out,
        [3, 1, 1],  # kernel
        stride=[temporal_stride, 1, 1],
        padding=[1, 0, 0],
        groups=groups,
        bias=False,
    )
    return nn.Sequential(conv_middle, conv_middle_bn, conv_middle_relu, conv)
3 changes: 2 additions & 1 deletion classy_vision/models/resnext3d.py
Expand Up @@ -13,10 +13,11 @@
from . import register_model
from .classy_model import ClassyModel, ClassyModelEvaluationMode
from .resnext3d_stage import ResStage
from .resnext3d_stem import ResNeXt3DStem
from .resnext3d_stem import R2Plus1DStem, ResNeXt3DStem


# Registry mapping a stem name (as referenced from model config) to the stem
# module class that builds it.
model_stems = {
    "r2plus1d_stem": R2Plus1DStem,
    "resnext3d_stem": ResNeXt3DStem,
    # For more types of model stem, add them below
}
Expand Down
114 changes: 112 additions & 2 deletions classy_vision/models/resnext3d_block.py
Expand Up @@ -6,6 +6,8 @@

import torch.nn as nn

from .r2plus1_util import r2plus1_unit


class BasicTransformation(nn.Module):
"""
Expand Down Expand Up @@ -38,7 +40,28 @@ def __init__(
PyTorch = 1 - BN momentum in Caffe2.
"""
super(BasicTransformation, self).__init__()
self._construct_model(
dim_in,
dim_out,
temporal_stride,
spatial_stride,
groups,
inplace_relu,
bn_eps,
bn_mmt,
)

def _construct_model(
self,
dim_in,
dim_out,
temporal_stride,
spatial_stride,
groups,
inplace_relu,
bn_eps,
bn_mmt,
):
# 3x3x3 group conv, BN, ReLU.
branch2a = nn.Conv3d(
dim_in,
Expand All @@ -64,12 +87,98 @@ def __init__(
branch2b_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)
branch2b_bn.final_transform_op = True

self.basic_transform = nn.Sequential(
self.transform = nn.Sequential(
branch2a, branch2a_bn, branch2a_relu, branch2b, branch2b_bn
)

def forward(self, x):
return self.basic_transform(x)
return self.transform(x)


class BasicR2Plus1DTransformation(BasicTransformation):
    """
    Basic transformation where each of the two 3x3x3 group convs is factorized
    into an R(2+1)D unit (a 2D spatial conv followed by a 1D temporal conv).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temporal_stride,
        spatial_stride,
        groups,
        inplace_relu=True,
        bn_eps=1e-5,
        bn_mmt=0.1,
        **kwargs
    ):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            dim_out (int): the channel dimension of the output.
            temporal_stride (int): the temporal stride of the bottleneck.
            spatial_stride (int): the spatial_stride of the bottleneck.
            groups (int): number of groups for the convolution.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            bn_eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Noted that BN momentum in
                PyTorch = 1 - BN momentum in Caffe2.
        """
        super().__init__(
            dim_in,
            dim_out,
            temporal_stride,
            spatial_stride,
            groups,
            inplace_relu=inplace_relu,
            bn_eps=bn_eps,
            bn_mmt=bn_mmt,
        )

    def _construct_model(
        self,
        dim_in,
        dim_out,
        temporal_stride,
        spatial_stride,
        groups,
        inplace_relu,
        bn_eps,
        bn_mmt,
    ):
        # R(2+1)D operation <https://arxiv.org/abs/1711.11248>: each 3D conv is
        # decomposed into one 2D spatial conv and one 1D temporal conv.
        layers = [
            # First unit carries the temporal/spatial strides; BN + ReLU follow.
            r2plus1_unit(
                dim_in,
                dim_out,
                temporal_stride,
                spatial_stride,
                groups,
                inplace_relu,
                bn_eps,
                bn_mmt,
            ),
            nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt),
            nn.ReLU(inplace=inplace_relu),
            # Second unit is unstrided (temporal_stride=1, spatial_stride=1).
            r2plus1_unit(
                dim_out,
                dim_out,
                1,  # temporal_stride
                1,  # spatial_stride
                groups,
                inplace_relu,
                bn_eps,
                bn_mmt,
            ),
        ]
        final_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)
        # Mark the last BN so downstream init logic can identify it.
        final_bn.final_transform_op = True
        layers.append(final_bn)

        self.transform = nn.Sequential(*layers)


class PostactivatedBottleneckTransformation(nn.Module):
Expand Down Expand Up @@ -291,6 +400,7 @@ def forward(self, x):


residual_transformations = {
"basic_r2plus1d_transformation": BasicR2Plus1DTransformation,
"basic_transformation": BasicTransformation,
"postactivated_bottleneck_transformation": PostactivatedBottleneckTransformation,
"preactivated_bottleneck_transformation": PreactivatedBottleneckTransformation,
Expand Down