Skip to content

Commit

Permalink
prod_resnext3d model with new Global Reasoning Unit (#181)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #181

- Implement Global Reasoning Networks (`glore`) (https://arxiv.org/abs/1811.12814). The reference implementation is on Github (https://github.com/facebookresearch/GloRe)

- For now, we don't plan to open source `glore`. Thus, we implement a `ProdResNeXt3D` model in `prod_resnext3d.py`, and refactor `ResNeXt3D` to reduce the code duplicated between `ResNeXt3D` and `ProdResNeXt3D`.
   - In the long run, this gives us more flexibility to add functionality to `ProdResNeXt3D` to meet prod needs while keeping the open source `ResNeXt3D` implementation minimal.

- Rename the JSON configs to indicate whether `glore` is `on` or `off`. Therefore:
  - `train_kinetics400_glore_off_i3d50` means `glore` is off, but we still use the baseline I3D-50 model from the `glore` paper.
  - `train_kinetics400_glore_on_i3d50` means `glore` is on and `glore` units are inserted into the baseline I3D-50 model.
  - Similar changes are made to the SlowFast (`sf`) JSON config. Since we only implement the baseline I3D model used in SF, which differs slightly from the I3D model in the `glore` paper, `glore` is always `off` there.

Reviewed By: vreis

Differential Revision: D18344566

fbshipit-source-id: ff46f3c49f68cefea3519743205f7a19c7b99596
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Nov 7, 2019
1 parent 95d3928 commit f2a6eee
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 134 deletions.
219 changes: 128 additions & 91 deletions classy_vision/models/resnext3d.py
Expand Up @@ -20,77 +20,25 @@
}


@register_model("resnext3d")
class ResNeXt3D(ClassyModel):
class ResNeXt3DBase(ClassyModel):
def __init__(
self,
input_key,
input_planes,
clip_crop_size,
skip_transformation_type,
residual_transformation_type,
frames_per_clip,
num_blocks,
stem_name,
stem_planes,
stem_temporal_kernel,
stem_spatial_kernel,
stem_maxpool,
stage_planes,
stage_temporal_kernel_basis,
temporal_conv_1x1,
stage_temporal_stride,
stage_spatial_stride,
num_groups,
width_per_group,
zero_init_residual_transform,
):
"""
Implementation of
1) conventional post-activated 3D ResNe(X)t
(https://arxiv.org/abs/1812.03982).
2) pre-activated 3D ResNe(X)t
(https://arxiv.org/abs/1811.12814).
The model consists of one stem, a number of stages, and one or multiple
heads that are attached to different blocks in the stage.
Args:
input_key (str): a key that can index into model input of dict type.
input_planes (int): the channel dimension of the input. Normally 3 is used
for rgb input.
clip_crop_size (int): spatial cropping size of video clip at train time.
skip_transformation_type (str): the type of skip transformation.
residual_transformation_type (str): the type of residual transformation.
frames_per_clip (int): No. of frames in a video clip.
num_blocks (list): list of the number of blocks in stages.
stem_name (str): name of model stem.
stem_planes (int): the output dimension of the convolution in the model
stem.
stem_temporal_kernel (int): the temporal kernel size of the convolution
in the model stem.
stem_spatial_kernel (int): the spatial kernel size of the convolution
in the model stem.
stem_maxpool (bool): If true, perform max pooling.
stage_planes (int): the output channel dimension of the 1st residual stage
stage_temporal_kernel_basis (list): Basis of temporal kernel sizes for
each of the stage.
temporal_conv_1x1 (bool): Only useful for BottleneckTransformation.
In a pathaway, if True, do temporal convolution in the first 1x1
Conv3d. Otherwise, do it in the second 3x3 Conv3d.
stage_temporal_stride (int): the temporal stride of the residual
transformation.
stage_spatial_stride (int): the spatial stride of the the residual
transformation.
num_groups (int): number of groups for the convolution.
num_groups = 1 is for standard ResNet like networks, and
num_groups > 1 is for ResNeXt like networks.
width_per_group (int): No. of channels per group in 2nd (group) conv in the
residual transformation in the first stage
zero_init_residual_transform (bool): if true, the weight of last
op, which could be either BatchNorm3D in post-activated transformation
or Conv3D in pre-activated transformation, in the residual
transformation is initialized to zero
ResNeXt3DBase implements everything in ResNeXt3D model except the
construction of 4 stages. See more details in ResNeXt3D.
"""
super(ResNeXt3D, self).__init__(num_classes=None)
super(ResNeXt3DBase, self).__init__(num_classes=None)

self._input_key = input_key
self.input_planes = input_planes
Expand All @@ -107,39 +55,8 @@ def __init__(
stem_maxpool,
)

num_stages = len(num_blocks)
out_planes = [stage_planes * 2 ** i for i in range(num_stages)]
in_planes = [stem_planes] + out_planes[:-1]
inner_planes = [
num_groups * width_per_group * 2 ** i for i in range(num_stages)
]

stages = []
for s in range(num_stages):
stage = ResStage(
s + 1, # stem is viewed as stage 0, and following stages start from 1
[in_planes[s]],
[out_planes[s]],
[inner_planes[s]],
[stage_temporal_kernel_basis[s]],
[temporal_conv_1x1[s]],
[stage_temporal_stride[s]],
[stage_spatial_stride[s]],
[num_blocks[s]],
[num_groups],
skip_transformation_type,
residual_transformation_type,
block_callback=self.build_attachable_block,
disable_pre_activation=(s == 0),
final_stage=(s == (num_stages - 1)),
)
stages.append(stage)

self.stages = nn.Sequential(*stages)
self._init_parameter(zero_init_residual_transform)

@classmethod
def from_config(cls, config):
def _parse_config(cls, config):
ret_config = {}
required_args = [
"input_planes",
Expand All @@ -152,6 +69,7 @@ def from_config(cls, config):
for arg in required_args:
assert arg in config, "resnext3d model requires argument %s" % arg
ret_config[arg] = config[arg]

# Default setting for model stem
# stem_planes: No. of output channles of conv op in stem
# stem_temporal_kernel: temporal size of conv op in stem
Expand Down Expand Up @@ -220,8 +138,7 @@ def from_config(cls, config):
assert is_pos_int_list(ret_config["stage_spatial_stride"])
assert is_pos_int(ret_config["num_groups"])
assert is_pos_int(ret_config["width_per_group"])

return cls(**ret_config)
return ret_config

def _init_parameter(self, zero_init_residual_transform):
for m in self.modules():
Expand All @@ -236,7 +153,7 @@ def _init_parameter(self, zero_init_residual_transform):
nn.init.kaiming_normal_(
m.weight, mode="fan_out", nonlinearity="relu"
)
if m.bias:
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm3d) and m.affine:
if (
Expand Down Expand Up @@ -374,3 +291,123 @@ def validate(self, dataset_output_shape):
# Thus, comparing it with dataset_output_shape will have varying results
# We skip validation and simply return True
return True


@register_model("resnext3d")
class ResNeXt3D(ResNeXt3DBase):
    def __init__(
        self,
        input_key,
        input_planes,
        clip_crop_size,
        skip_transformation_type,
        residual_transformation_type,
        frames_per_clip,
        num_blocks,
        stem_name,
        stem_planes,
        stem_temporal_kernel,
        stem_spatial_kernel,
        stem_maxpool,
        stage_planes,
        stage_temporal_kernel_basis,
        temporal_conv_1x1,
        stage_temporal_stride,
        stage_spatial_stride,
        num_groups,
        width_per_group,
        zero_init_residual_transform,
    ):
        """
        Implementation of
            1) conventional post-activated 3D ResNe(X)t
                (https://arxiv.org/abs/1812.03982).
            2) pre-activated 3D ResNe(X)t
                (https://arxiv.org/abs/1811.12814).
        The model consists of one stem, a number of stages, and one or multiple
        heads that are attached to different blocks in the stage.

        One residual stage is built per entry of `num_blocks`. Every per-stage
        list argument below is indexed by stage and therefore needs at least
        `len(num_blocks)` entries.

        Args:
            input_key (str): a key that can index into model input of dict type.
            input_planes (int): the channel dimension of the input. Normally 3 is
                used for rgb input.
            clip_crop_size (int): spatial cropping size of video clip at train time.
            skip_transformation_type (str): the type of skip transformation.
            residual_transformation_type (str): the type of residual
                transformation.
            frames_per_clip (int): No. of frames in a video clip.
            num_blocks (list): list of the number of blocks in stages.
            stem_name (str): name of model stem.
            stem_planes (int): the output dimension of the convolution in the
                model stem.
            stem_temporal_kernel (int): the temporal kernel size of the
                convolution in the model stem.
            stem_spatial_kernel (int): the spatial kernel size of the convolution
                in the model stem.
            stem_maxpool (bool): If true, perform max pooling.
            stage_planes (int): the output channel dimension of the 1st residual
                stage. It is doubled for every subsequent stage.
            stage_temporal_kernel_basis (list): Basis of temporal kernel sizes
                for each of the stage.
            temporal_conv_1x1 (list): one entry per stage. Only useful for
                BottleneckTransformation. In a pathway, if True, do temporal
                convolution in the first 1x1 Conv3d. Otherwise, do it in the
                second 3x3 Conv3d.
            stage_temporal_stride (list): one entry per stage, the temporal
                stride of the residual transformation.
            stage_spatial_stride (list): one entry per stage, the spatial stride
                of the residual transformation.
            num_groups (int): number of groups for the convolution.
                num_groups = 1 is for standard ResNet like networks, and
                num_groups > 1 is for ResNeXt like networks.
            width_per_group (int): No. of channels per group in 2nd (group) conv
                in the residual transformation in the first stage. The inner
                width is doubled for every subsequent stage.
            zero_init_residual_transform (bool): if true, the weight of last
                op, which could be either BatchNorm3D in post-activated
                transformation or Conv3D in pre-activated transformation, in
                the residual transformation is initialized to zero
        """
        super(ResNeXt3D, self).__init__(
            input_key,
            input_planes,
            clip_crop_size,
            frames_per_clip,
            num_blocks,
            stem_name,
            stem_planes,
            stem_temporal_kernel,
            stem_spatial_kernel,
            stem_maxpool,
        )

        num_stages = len(num_blocks)
        # Channel widths double from one stage to the next.
        out_planes = [stage_planes * 2 ** i for i in range(num_stages)]
        # The first stage consumes the stem output; later stages consume the
        # output of the stage before them.
        in_planes = [stem_planes] + out_planes[:-1]
        inner_planes = [
            num_groups * width_per_group * 2 ** i for i in range(num_stages)
        ]

        stages = []
        for s in range(num_stages):
            # ResStage is handed single-element lists for its per-stage settings.
            stage = ResStage(
                s + 1,  # stem is viewed as stage 0, and following stages start from 1
                [in_planes[s]],
                [out_planes[s]],
                [inner_planes[s]],
                [stage_temporal_kernel_basis[s]],
                [temporal_conv_1x1[s]],
                [stage_temporal_stride[s]],
                [stage_spatial_stride[s]],
                [num_blocks[s]],
                [num_groups],
                skip_transformation_type,
                residual_transformation_type,
                block_callback=self.build_attachable_block,
                disable_pre_activation=(s == 0),
                final_stage=(s == (num_stages - 1)),
            )
            stages.append(stage)

        self.stages = nn.Sequential(*stages)
        # Initialize only after all stages exist so that every module is covered.
        self._init_parameter(zero_init_residual_transform)

    @classmethod
    def from_config(cls, config):
        """
        Build a model instance from a config dict.

        The config is parsed with `cls._parse_config`, so a subclass that
        overrides `_parse_config` has its own parsing honored, and the result
        is forwarded to the constructor as keyword arguments.

        Args:
            config (dict): model configuration.

        Returns:
            An instance of `cls`.
        """
        # Dispatch through cls rather than naming ResNeXt3D explicitly, so that
        # subclasses do not silently fall back to the base parsing.
        ret_config = cls._parse_config(config)
        return cls(**ret_config)

0 comments on commit f2a6eee

Please sign in to comment.