Skip to content

Commit

Permalink
Add CSN model to torch video model zoo (#1517)
Browse files Browse the repository at this point in the history
* add ircsn

* update model zoo

* fix lint
  • Loading branch information
bryanyzhu committed Nov 12, 2020
1 parent d4f2893 commit 7e7135f
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ r2plus1d_v1_resnet18_kinetics400 [6]_,Scratch,1,16 (32/2),71.72,340a5952,`config
r2plus1d_v1_resnet34_kinetics400 [6]_,Scratch,1,16 (32/2),72.63,5102fd17,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/r2plus1d_v1_resnet34_kinetics400.yaml>`_
r2plus1d_v1_resnet50_kinetics400 [6]_,Scratch,1,16 (32/2),74.92,9a3b665c,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/r2plus1d_v1_resnet50_kinetics400.yaml>`_
r2plus1d_v2_resnet152_kinetics400 [6]_,IG65M,1,16 (32/2),81.34,42707ffc,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/r2plus1d_v2_resnet152_kinetics400.yaml>`_
ircsn_v2_resnet152_f32s2_kinetics400,IG65M,1,32 (64/2),83.18,82855d2c,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/ircsn_v2_resnet152_f32s2_kinetics400.yaml>`_
i3d_resnet50_v1_kinetics400 [4]_,ImageNet,1,32 (64/2),74.87,18545497,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/i3d_resnet50_v1_kinetics400.yaml>`_
i3d_resnet101_v1_kinetics400 [4]_,ImageNet,1,32 (64/2),75.1,a9bb4f89,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/i3d_resnet101_v1_kinetics400.yaml>`_
i3d_nl5_resnet50_v1_kinetics400 [7]_,ImageNet,1,32 (64/2),75.17,9df1e103,`config <https://raw.githubusercontent.com/dmlc/gluon-cv/master/scripts/action-recognition/configuration/i3d_nl5_resnet50_v1_kinetics400.yaml>`_
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
('d0e8603120ab02118a8973d52a26b8296d1b4078', 'psp_resnet101_citys'),
('ef2bb40ad8f8f59f451969b2fabe4e548394e80a', 'deeplab_v3b_plus_wideresnet_citys'),
('909742b45d5a3844d6000248aa92fef0ae23a0f0', 'icnet_resnet50_citys'),
('63db8a7938586525256a0bdc6632ed986e4026cf', 'icnet_resnet50_mhpv1'),
('873d381a4bc246c5b9d3660ccf66c2f63d0b4e7c', 'icnet_resnet50_mhpv1'),
('cf6a7bb3d55360933de647a8505f7936003902a4', 'deeplab_resnet50_citys'),
('eb8477a91efc244c85b364c0736664078aaf0e65', 'deeplab_resnet101_citys'),
('95aad0b699ae17c67caa44b3ead4b23474e98954', 'fastscnn_citys'),
Expand Down
1 change: 1 addition & 0 deletions gluoncv/torch/model_zoo/action_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .r2plus1dv1 import *
from .r2plus1dv2 import *
from .tpn import *
from .ircsnv2 import *
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""
Video Classification with Channel-Separated Convolutional Networks
ICCV 2019, https://arxiv.org/abs/1904.02811
Large-scale weakly-supervised pre-training for video action recognition
CVPR 2019, https://arxiv.org/abs/1905.00561
"""

# pylint: disable=missing-function-docstring, missing-class-docstring
import torch
import torch.nn as nn


__all__ = ['ir_csn_resnet152_kinetics400']
__all__ = ['ResNet_IRCSNv2', 'ircsn_v2_resnet152_f32s2_kinetics400']


eps = 1e-3
Expand All @@ -18,7 +20,7 @@ class Affine(nn.Module):
def __init__(self, feature_in):
super(Affine, self).__init__()
self.weight = nn.Parameter(torch.randn(feature_in, 1, 1, 1))
self.bias = nn.Parameter(torch.randn(feature_in,1, 1, 1))
self.bias = nn.Parameter(torch.randn(feature_in, 1, 1, 1))
self.weight.requires_grad = False
self.bias.requires_grad = False

Expand All @@ -27,13 +29,11 @@ def forward(self, x):
return x


class ResNeXtBottleneck(nn.Module):
# expansion = 2

class Bottleneck_IRCSNv2(nn.Module):
def __init__(self, in_planes, planes, stride=1, temporal_stride=1,
down_sample=None, expansion=2, temporal_kernel=3, use_affine=True):

super(ResNeXtBottleneck, self).__init__()
super(Bottleneck_IRCSNv2, self).__init__()
self.expansion = expansion
self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=(1, 1, 1), bias=False, stride=(1, 1, 1))

Expand Down Expand Up @@ -87,7 +87,7 @@ def forward(self, x):
return out


class ResNeXt(nn.Module):
class ResNet_IRCSNv2(nn.Module):
def __init__(self,
block,
block_nums,
Expand All @@ -98,7 +98,7 @@ def __init__(self,
self.in_planes = 64
self.num_classes = num_classes

super(ResNeXt, self).__init__()
super(ResNet_IRCSNv2, self).__init__()

self.conv1 = nn.Conv3d(
3,
Expand Down Expand Up @@ -155,7 +155,7 @@ def _make_layer(self,
layers.append(
block(in_planes, planes, stride, temporal_stride, down_sample, expansion,
temporal_kernel=3, use_affine=self.use_affine))
for i in range(1, blocks):
for _ in range(1, blocks):
layers.append(block(planes * expansion, planes, expansion=expansion,
temporal_kernel=3, use_affine=self.use_affine))

Expand All @@ -182,15 +182,14 @@ def forward(self, x):
return logits


def ir_csn_resnet152_kinetics400(cfg):
model = ResNeXt(ResNeXtBottleneck,
num_classes=cfg.CONFIG.DATA.NUM_CLASSES,
block_nums=[3, 8, 36, 3],
use_affine=cfg.CONFIG.MODEL.USE_AFFINE)
def ircsn_v2_resnet152_f32s2_kinetics400(cfg):
model = ResNet_IRCSNv2(Bottleneck_IRCSNv2,
num_classes=cfg.CONFIG.DATA.NUM_CLASSES,
block_nums=[3, 8, 36, 3],
use_affine=cfg.CONFIG.MODEL.USE_AFFINE)

if cfg.CONFIG.MODEL.PRETRAINED:
from ..model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('ir_csn_resnet152_kinetics400',
model.load_state_dict(torch.load(get_model_file('ircsn_v2_resnet152_f32s2_kinetics400',
tag=cfg.CONFIG.MODEL.PRETRAINED)))

return model
return model
1 change: 0 additions & 1 deletion gluoncv/torch/model_zoo/action_recognition/r2plus1dv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
CVPR 2018, https://arxiv.org/abs/1711.11248
Large-scale weakly-supervised pre-training for video action recognition
CVPR 2019, https://arxiv.org/abs/1905.00561
"""
import torch
import torch.nn as nn
Expand Down
1 change: 1 addition & 0 deletions gluoncv/torch/model_zoo/model_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
('5102fd1736a2a205f5fd7bded5d2e0d2e5ca6307', 'r2plus1d_v1_resnet34_kinetics400'),
('9a3b665c182f81f22d4a105c75fac030ad20a628', 'r2plus1d_v1_resnet50_kinetics400'),
('42707ffcab518cdda93523004360de167863c5d8', 'r2plus1d_v2_resnet152_kinetics400'),
('82855d2c85a888e96477130253b90a4892bdc649', 'ircsn_v2_resnet152_f32s2_kinetics400'),
('368108eb6bca9143318937319c3efec09e0419af', 'tpn_resnet50_f8s8_kinetics400'),
('6bf899df92224d3c0c117c940ef0c95d51fedb26', 'tpn_resnet50_f16s4_kinetics400'),
('27710ce8091a317f50fc55e171b29c096b3d253e', 'tpn_resnet50_f32s2_kinetics400'),
Expand Down
1 change: 1 addition & 0 deletions gluoncv/torch/model_zoo/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
'r2plus1d_v1_resnet101_kinetics400': r2plus1d_v1_resnet101_kinetics400,
'r2plus1d_v1_resnet152_kinetics400': r2plus1d_v1_resnet152_kinetics400,
'r2plus1d_v2_resnet152_kinetics400': r2plus1d_v2_resnet152_kinetics400,
'ircsn_v2_resnet152_f32s2_kinetics400': ircsn_v2_resnet152_f32s2_kinetics400,
'tpn_resnet50_f8s8_kinetics400': tpn_resnet50_f8s8_kinetics400,
'tpn_resnet50_f16s4_kinetics400': tpn_resnet50_f16s4_kinetics400,
'tpn_resnet50_f32s2_kinetics400': tpn_resnet50_f32s2_kinetics400,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# ircsn_v2_resnet152_f32s2_kinetics400
# Configuration for the model registered under the name given in MODEL.NAME below.
# NOTE(review): the IP address and the /home/ubuntu/... paths in this file are
# machine-specific values; adjust them for your own environment.

# Distributed-training settings (DISTRIBUTED is True, backend is 'nccl').
DDP_CONFIG:
WORLD_SIZE: 1
WORLD_RANK: 0
GPU_WORLD_SIZE: 8
GPU_WORLD_RANK: 0
DIST_URL: 'tcp://172.31.72.252:23456'
# NOTE(review): 'WOLRD_URLS' looks like a misspelling of 'WORLD_URLS'. The key is
# read at runtime, so do not rename it here without confirming which spelling
# the config loader expects.
WOLRD_URLS: ['172.31.72.252']
AUTO_RANK_MATCH: True
DIST_BACKEND: 'nccl'
GPU: 0
DISTRIBUTED: True

CONFIG:
# Optimisation schedule.
TRAIN:
EPOCH_NUM: 58 # fine-tuned from a pretrained model, hence the small LR below
BATCH_SIZE: 8
LR: 0.000125
LR_POLICY: 'Step'
MOMENTUM: 0.9
W_DECAY: 1e-4
# Presumably the epochs at which the 'Step' policy changes the LR, with STEP as
# the multiplicative factor -- TODO confirm against the training script.
LR_MILESTONE: [32, 48]
STEP: 0.1

VAL:
FREQ: 2
BATCH_SIZE: 8

# Dataset locations and clip-sampling parameters.
DATA:
TRAIN_ANNO_PATH: '/home/ubuntu/data/kinetics400/k400_train.txt'
VAL_ANNO_PATH: '/home/ubuntu/data/kinetics400/k400_val.txt'
TRAIN_DATA_PATH: '/home/ubuntu/data/kinetics400/train_256/'
VAL_DATA_PATH: '/home/ubuntu/data/kinetics400/val_256/'
NUM_CLASSES: 400
# CLIP_LEN / FRAME_RATE mirror the 'f32s2' part of the model name
# (presumably 32 frames sampled with stride 2 -- TODO confirm).
CLIP_LEN: 32
FRAME_RATE: 2
NUM_SEGMENT: 1
NUM_CROP: 1
TEST_NUM_SEGMENT: 10
TEST_NUM_CROP: 3
MULTIGRID: False
KEEP_ASPECT_RATIO: True

MODEL:
NAME: 'ircsn_v2_resnet152_f32s2_kinetics400'
# NOTE(review): the model builder only loads model-zoo weights when PRETRAINED
# is truthy, yet the EPOCH_NUM comment above speaks of fine-tuning from a
# pretrained model -- confirm which value is intended for training.
# NOTE(review): the model builder also reads MODEL.USE_AFFINE, which is not set
# here; presumably it comes from the default config -- verify.
PRETRAINED: False

# Logging and checkpoint output locations.
LOG:
BASE_PATH: './logs/ircsn_v2_resnet152_f32s2_kinetics400'
LOG_DIR: 'tb_log'
SAVE_DIR: 'checkpoints'
EVAL_DIR: './logs/ircsn_v2_resnet152_f32s2_kinetics400/eval'
SAVE_FREQ: 2

0 comments on commit 7e7135f

Please sign in to comment.