Skip to content

Commit

Permalink
refactoring and replace sample_duration with conv1_t_stride
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshohara committed Nov 22, 2018
1 parent afc24c0 commit dc83da1
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 74 deletions.
47 changes: 26 additions & 21 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,50 +29,55 @@ def generate_model(opt):
'resnet', 'preresnet', 'wideresnet', 'resnext', 'densenet'
]

if opt.sample_duration >= 32:
conv1_t_stride = 2
else:
conv1_t_stride = 1

if opt.model == 'resnet':
assert opt.model_depth in [10, 18, 34, 50, 101, 152, 200]

if opt.model_depth == 10:
model = resnet.resnet10(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 18:
model = resnet.resnet18(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 34:
model = resnet.resnet34(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 50:
model = resnet.resnet50(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 101:
model = resnet.resnet101(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 152:
model = resnet.resnet152(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model_depth == 200:
model = resnet.resnet200(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration,
conv1_t_stride=conv1_t_stride,
conv1_t_size=opt.conv1_t_size)
elif opt.model == 'wideresnet':
assert opt.model_depth in [50]
Expand All @@ -82,7 +87,7 @@ def generate_model(opt):
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
k=opt.wide_resnet_k,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model == 'resnext':
assert opt.model_depth in [50, 101, 152]

Expand All @@ -91,67 +96,67 @@ def generate_model(opt):
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
cardinality=opt.resnext_cardinality,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 101:
model = resnext.resnext101(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
cardinality=opt.resnext_cardinality,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 152:
model = resnext.resnext152(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
cardinality=opt.resnext_cardinality,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model == 'preresnet':
assert opt.model_depth in [18, 34, 50, 101, 152, 200]

if opt.model_depth == 18:
model = pre_act_resnet.resnet18(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 34:
model = pre_act_resnet.resnet34(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 50:
model = pre_act_resnet.resnet50(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 101:
model = pre_act_resnet.resnet101(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 152:
model = pre_act_resnet.resnet152(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 200:
model = pre_act_resnet.resnet200(
n_classes=opt.n_classes,
shortcut_type=opt.resnet_shortcut,
sample_duration=opt.sample_duration)
conv1_t_stride=conv1_t_stride)
elif opt.model == 'densenet':
assert opt.model_depth in [121, 169, 201, 264]

if opt.model_depth == 121:
model = densenet.densenet121(
n_classes=opt.n_classes, sample_duration=opt.sample_duration)
n_classes=opt.n_classes, conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 169:
model = densenet.densenet169(
n_classes=opt.n_classes, sample_duration=opt.sample_duration)
n_classes=opt.n_classes, conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 201:
model = densenet.densenet201(
n_classes=opt.n_classes, sample_duration=opt.sample_duration)
n_classes=opt.n_classes, conv1_t_stride=conv1_t_stride)
elif opt.model_depth == 264:
model = densenet.densenet264(
n_classes=opt.n_classes, sample_duration=opt.sample_duration)
n_classes=opt.n_classes, conv1_t_stride=conv1_t_stride)

if not opt.no_cuda:
model = nn.DataParallel(model, device_ids=None).cuda()
Expand Down
13 changes: 3 additions & 10 deletions models/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class DenseNet(nn.Module):
"""

def __init__(self,
sample_duration,
conv1_t_stride=1,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_init_features=64,
Expand All @@ -92,13 +92,6 @@ def __init__(self,

super().__init__()

self.sample_duration = sample_duration

if sample_duration >= 32:
first_t_stride = 2
else:
first_t_stride = 1

# First convolution
self.features = nn.Sequential(
OrderedDict([
Expand All @@ -107,8 +100,8 @@ def __init__(self,
3,
num_init_features,
kernel_size=7,
stride=(first_t_stride, 2, 2),
padding=(3, 3, 3),
stride=(conv1_t_stride, 2, 2),
padding=3,
bias=False)),
('norm0', nn.BatchNorm3d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
Expand Down
33 changes: 16 additions & 17 deletions models/pre_act_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F

from . import resnet
from .resnet import conv3x3x3, conv1x1x1, get_inplanes, ResNet


class PreActivationBasicBlock(nn.Module):
Expand Down Expand Up @@ -45,12 +45,11 @@ def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()

self.bn1 = nn.BatchNorm3d(inplanes)
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.conv1 = conv1x1x1(inplanes, planes)
self.bn2 = nn.BatchNorm3d(planes)
self.conv2 = resnet.conv3x3x3(planes, planes, stride)
self.conv2 = conv3x3x3(planes, planes, stride)
self.bn3 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(
planes, planes * self.expansion, kernel_size=1, bias=False)
self.conv3 = conv1x1x1(planes, planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
Expand Down Expand Up @@ -81,46 +80,46 @@ def forward(self, x):
def resnet18(**kwargs):
"""Constructs a ResNet-18 model.
"""
model = resnet.ResNet(PreActivationBasicBlock, [2, 2, 2, 2],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBasicBlock, [2, 2, 2, 2], get_inplanes(),
**kwargs)
return model


def resnet34(**kwargs):
"""Constructs a ResNet-34 model.
"""
model = resnet.ResNet(PreActivationBasicBlock, [3, 4, 6, 3],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBasicBlock, [3, 4, 6, 3], get_inplanes(),
**kwargs)
return model


def resnet50(**kwargs):
"""Constructs a ResNet-50 model.
"""
model = resnet.ResNet(PreActivationBottleneck, [3, 4, 6, 3],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBottleneck, [3, 4, 6, 3], get_inplanes(),
**kwargs)
return model


def resnet101(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = resnet.ResNet(PreActivationBottleneck, [3, 4, 23, 3],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBottleneck, [3, 4, 23, 3], get_inplanes(),
**kwargs)
return model


def resnet152(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = resnet.ResNet(PreActivationBottleneck, [3, 8, 36, 3],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBottleneck, [3, 8, 36, 3], get_inplanes(),
**kwargs)
return model


def resnet200(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = resnet.ResNet(PreActivationBottleneck, [3, 24, 36, 3],
resnet.get_inplanes(), **kwargs)
model = ResNet(PreActivationBottleneck, [3, 24, 36, 3], get_inplanes(),
**kwargs)
return model
29 changes: 12 additions & 17 deletions models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def get_inplanes():


def conv3x3x3(in_planes, out_planes, stride=1):
# 3x3x3 convolution with padding
return nn.Conv3d(
in_planes,
out_planes,
Expand All @@ -21,6 +20,11 @@ def conv3x3x3(in_planes, out_planes, stride=1):
bias=False)


def conv1x1x1(in_planes, out_planes, stride=1):
return nn.Conv3d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
expansion = 1

Expand Down Expand Up @@ -60,12 +64,11 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super().__init__()

self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.conv1 = conv1x1x1(inplanes, planes)
self.bn1 = nn.BatchNorm3d(planes)
self.conv2 = conv3x3x3(planes, planes, stride)
self.bn2 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(
planes, planes * self.expansion, kernel_size=1, bias=False)
self.conv3 = conv1x1x1(planes, planes * self.expansion)
self.bn3 = nn.BatchNorm3d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
Expand Down Expand Up @@ -100,28 +103,24 @@ def __init__(self,
block,
layers,
block_inplanes,
sample_duration,
conv1_t_size=7,
conv1_t_stride=1,
shortcut_type='B',
n_classes=400):
super().__init__()

self.inplanes = 64

if sample_duration >= 32:
first_t_stride = 2
else:
first_t_stride = 1
self.conv1 = nn.Conv3d(
3,
self.inplanes,
kernel_size=(conv1_t_size, 7, 7),
stride=(first_t_stride, 2, 2),
stride=(conv1_t_stride, 2, 2),
padding=(conv1_t_size // 2, 3, 3),
bias=False)
self.bn1 = nn.BatchNorm3d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
shortcut_type)
self.layer2 = self._make_layer(
Expand Down Expand Up @@ -164,12 +163,8 @@ def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
stride=stride)
else:
downsample = nn.Sequential(
nn.Conv3d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False), nn.BatchNorm3d(planes * block.expansion))
conv1x1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm3d(planes * block.expansion))

layers = []
layers.append(
Expand Down

0 comments on commit dc83da1

Please sign in to comment.