Refactoring and minor fix in v3/v3+
kazuto1011 committed Mar 26, 2019
1 parent 6755e22 commit 1c2f1f3
Showing 6 changed files with 95 additions and 99 deletions.
4 changes: 2 additions & 2 deletions libs/models/__init__.py
@@ -9,11 +9,11 @@

 def init_weights(module):
     if isinstance(module, nn.Conv2d):
-        nn.init.kaiming_normal_(module.weight)
+        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
         if module.bias is not None:
             nn.init.constant_(module.bias, 0)
     elif isinstance(module, nn.Linear):
-        nn.init.kaiming_normal_(module.weight)
+        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
         if module.bias is not None:
             nn.init.constant_(module.bias, 0)
     elif isinstance(module, nn.BatchNorm2d):
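The He ("Kaiming") initialization now states its fan mode and nonlinearity explicitly. How the helper is wired up is not part of this diff; a minimal sketch of the usual PyTorch pattern (the model below is a stand-in, not the repo's own):

import torch.nn as nn

# Hypothetical usage: nn.Module.apply walks every submodule, so init_weights
# re-initializes each Conv2d / Linear / BatchNorm2d it visits.
model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU())
model.apply(init_weights)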
11 changes: 6 additions & 5 deletions libs/models/deeplabv1.py
@@ -22,11 +22,12 @@ class DeepLabV1(nn.Sequential):

     def __init__(self, n_classes, n_blocks):
         super(DeepLabV1, self).__init__()
-        self.add_module("layer1", _Stem())
-        self.add_module("layer2", _ResLayer(n_blocks[0], 64, 64, 256, 1, 1))
-        self.add_module("layer3", _ResLayer(n_blocks[1], 256, 128, 512, 2, 1))
-        self.add_module("layer4", _ResLayer(n_blocks[2], 512, 256, 1024, 1, 2))
-        self.add_module("layer5", _ResLayer(n_blocks[3], 1024, 512, 2048, 1, 4))
+        ch = [64 * 2 ** p for p in range(6)]
+        self.add_module("layer1", _Stem(ch[0]))
+        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
+        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
+        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
+        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
         self.add_module("fc", nn.Conv2d(2048, n_classes, 1))


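For reference, the ch list introduced here just spells out the standard ResNet channel widths that used to be hard-coded; evaluating it is plain Python, nothing repo-specific:

ch = [64 * 2 ** p for p in range(6)]
print(ch)  # [64, 128, 256, 512, 1024, 2048]
# ch[0] feeds _Stem; ch[2]..ch[5] replace the former literals 256, 512, 1024, 2048.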
13 changes: 7 additions & 6 deletions libs/models/deeplabv2.py
@@ -43,12 +43,13 @@ class DeepLabV2(nn.Sequential):

     def __init__(self, n_classes, n_blocks, atrous_rates):
         super(DeepLabV2, self).__init__()
-        self.add_module("layer1", _Stem())
-        self.add_module("layer2", _ResLayer(n_blocks[0], 64, 64, 256, 1, 1))
-        self.add_module("layer3", _ResLayer(n_blocks[1], 256, 128, 512, 2, 1))
-        self.add_module("layer4", _ResLayer(n_blocks[2], 512, 256, 1024, 1, 2))
-        self.add_module("layer5", _ResLayer(n_blocks[3], 1024, 512, 2048, 1, 4))
-        self.add_module("aspp", _ASPP(2048, n_classes, atrous_rates))
+        ch = [64 * 2 ** p for p in range(6)]
+        self.add_module("layer1", _Stem(ch[0]))
+        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
+        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
+        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 1, 2))
+        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 1, 4))
+        self.add_module("aspp", _ASPP(ch[5], n_classes, atrous_rates))
 
     def freeze_bn(self):
         for m in self.modules():
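A minimal smoke test of the refactored constructor, assuming the usual ResNet-101 block counts and the classic DeepLab v2 ASPP rates (these argument values are illustrative, not taken from this commit):

import torch

# n_blocks=[3, 4, 23, 3] corresponds to ResNet-101; the rates follow the DeepLab v2 paper.
model = DeepLabV2(n_classes=21, n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24])
model.eval()
logits = model(torch.randn(1, 3, 513, 513))
print(logits.shape)  # expected: torch.Size([1, 21, 65, 65]), i.e. output stride 8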
74 changes: 34 additions & 40 deletions libs/models/deeplabv3.py
@@ -16,36 +16,38 @@
 from .resnet import _ConvBnReLU, _ResLayer, _Stem
 
 
+class _ImagePool(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.conv = _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1)
+
+    def forward(self, x):
+        _, _, H, W = x.shape
+        h = self.pool(x)
+        h = self.conv(h)
+        h = F.interpolate(h, size=(H, W), mode="bilinear", align_corners=False)
+        return h
+
+
 class _ASPP(nn.Module):
     """
     Atrous spatial pyramid pooling with image-level feature
     """
 
-    def __init__(self, n_in, n_out, rates):
+    def __init__(self, in_ch, out_ch, rates):
         super(_ASPP, self).__init__()
         self.stages = nn.Module()
-        self.stages.add_module("c0", _ConvBnReLU(n_in, n_out, 1, 1, 0, 1))
+        self.stages.add_module("c0", _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1))
         for i, rate in enumerate(rates):
             self.stages.add_module(
                 "c{}".format(i + 1),
-                _ConvBnReLU(n_in, n_out, 3, 1, padding=rate, dilation=rate),
+                _ConvBnReLU(in_ch, out_ch, 3, 1, padding=rate, dilation=rate),
             )
-        self.imagepool = nn.Sequential(
-            OrderedDict(
-                [
-                    ("pool", nn.AdaptiveAvgPool2d(1)),
-                    ("conv", _ConvBnReLU(n_in, n_out, 1, 1, 0, 1)),
-                ]
-            )
-        )
+        self.stages.add_module("imagepool", _ImagePool(in_ch, out_ch))
 
     def forward(self, x):
-        h = self.imagepool(x)
-        h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)]
-        for stage in self.stages.children():
-            h += [stage(x)]
-        h = torch.cat(h, dim=1)
-        return h
+        return torch.cat([stage(x) for stage in self.stages.children()], dim=1)
 
 
 class DeepLabV3(nn.Sequential):
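With the image pooling branch folded into self.stages, the concatenated ASPP output carries one tensor per 1x1 branch, per atrous rate, and per image pool, i.e. out_ch * (len(rates) + 2) channels, which is what the concat_ch line in the next hunk accounts for. A quick check of the arithmetic (the rates value is illustrative):

out_ch, rates = 256, [6, 12, 18]   # illustrative DeepLab v3 settings
branches = 1 + len(rates) + 1      # 1x1 conv + atrous convs + image pool
print(out_ch * branches)           # 1280 == 256 * (len(rates) + 2)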
@@ -56,33 +58,25 @@ class DeepLabV3(nn.Sequential):
     def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride):
         super(DeepLabV3, self).__init__()
 
         # Stride and dilation
         if output_stride == 8:
-            stride = [1, 2, 1, 1]
-            dilation = [1, 1, 2, 2]
+            s = [1, 2, 1, 1]
+            d = [1, 1, 2, 4]
         elif output_stride == 16:
-            stride = [1, 2, 2, 1]
-            dilation = [1, 1, 1, 2]
+            s = [1, 2, 2, 1]
+            d = [1, 1, 1, 2]
 
-        self.add_module("layer1", _Stem())
-        self.add_module(
-            "layer2", _ResLayer(n_blocks[0], 64, 64, 256, stride[0], dilation[0])
-        )
-        self.add_module(
-            "layer3", _ResLayer(n_blocks[1], 256, 128, 512, stride[1], dilation[1])
-        )
-        self.add_module(
-            "layer4", _ResLayer(n_blocks[2], 512, 256, 1024, stride[2], dilation[2])
-        )
-        self.add_module(
-            "layer5",
-            _ResLayer(
-                n_blocks[3], 1024, 512, 2048, stride[3], dilation[3], multi_grids
-            ),
-        )
-        self.add_module("aspp", _ASPP(2048, 256, atrous_rates))
-        self.add_module(
-            "fc1", _ConvBnReLU(256 * (len(atrous_rates) + 2), 256, 1, 1, 0, 1)
-        )
+        ch = [64 * 2 ** p for p in range(6)]
+        self.add_module("layer1", _Stem(ch[0]))
+        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0]))
+        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1]))
+        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2]))
+        self.add_module(
+            "layer5", _ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], multi_grids)
+        )
+        self.add_module("aspp", _ASPP(ch[5], 256, atrous_rates))
+        concat_ch = 256 * (len(atrous_rates) + 2)
+        self.add_module("fc1", _ConvBnReLU(concat_ch, 256, 1, 1, 0, 1))
         self.add_module("fc2", nn.Conv2d(256, n_classes, kernel_size=1))


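For reference, the per-layer strides s multiply on top of the stem's fixed stride of 4 (7x7 conv with stride 2 followed by max pooling with stride 2), which is how the two branches of the if realize the requested output stride; presumably the "minor fix" in the commit title is the last dilation for output stride 8, which now doubles again (4 instead of 2) where layer5 keeps stride 1. A small sanity check (plain arithmetic, not repo code):

stem_stride = 4  # conv1 (stride 2) followed by max pooling (stride 2)
for os, s in [(8, [1, 2, 1, 1]), (16, [1, 2, 2, 1])]:
    total = stem_stride
    for si in s:
        total *= si
    assert total == os, (os, total)
print("output strides check out")  # the dilations d grow where striding stops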
27 changes: 14 additions & 13 deletions libs/models/deeplabv3plus.py
@@ -25,23 +25,24 @@ class DeepLabV3Plus(nn.Module):
     def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride):
         super(DeepLabV3Plus, self).__init__()
 
         # Stride and dilation
         if output_stride == 8:
-            stride = [1, 2, 1, 1]
-            dilation = [1, 1, 2, 2]
+            s = [1, 2, 1, 1]
+            d = [1, 1, 2, 4]
         elif output_stride == 16:
-            stride = [1, 2, 2, 1]
-            dilation = [1, 1, 1, 2]
+            s = [1, 2, 2, 1]
+            d = [1, 1, 1, 2]
 
         # Encoder
-        self.layer1 = _Stem()
-        self.layer2 = _ResLayer(n_blocks[0], 64, 64, 256, stride[0], dilation[0])
-        self.layer3 = _ResLayer(n_blocks[1], 256, 128, 512, stride[1], dilation[1])
-        self.layer4 = _ResLayer(n_blocks[2], 512, 256, 1024, stride[2], dilation[2])
-        self.layer5 = _ResLayer(
-            n_blocks[3], 1024, 512, 2048, stride[3], dilation[3], multi_grids
-        )
-        self.aspp = _ASPP(2048, 256, atrous_rates)
-        self.fc1 = _ConvBnReLU(256 * (len(atrous_rates) + 2), 256, 1, 1, 0, 1)
+        ch = [64 * 2 ** p for p in range(6)]
+        self.layer1 = _Stem(ch[0])
+        self.layer2 = _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0])
+        self.layer3 = _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1])
+        self.layer4 = _ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2])
+        self.layer5 = _ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], multi_grids)
+        self.aspp = _ASPP(ch[5], 256, atrous_rates)
+        concat_ch = 256 * (len(atrous_rates) + 2)
+        self.add_module("fc1", _ConvBnReLU(concat_ch, 256, 1, 1, 0, 1))
 
         # Decoder
         self.reduce = _ConvBnReLU(256, 48, 1, 1, 0, 1)
65 changes: 32 additions & 33 deletions libs/models/resnet.py
@@ -20,6 +20,8 @@
 except:
     _BATCH_NORM = nn.BatchNorm2d
 
+_BOTTLENECK_EXPANSION = 4
+
 
 class _ConvBnReLU(nn.Sequential):
     """
@@ -49,8 +51,9 @@ class _Bottleneck(nn.Module):
     Bottleneck block of MSRA ResNet.
     """
 
-    def __init__(self, in_ch, mid_ch, out_ch, stride, dilation, downsample):
+    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
         super(_Bottleneck, self).__init__()
+        mid_ch = out_ch // _BOTTLENECK_EXPANSION
         self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True)
         self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
         self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False)
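_BOTTLENECK_EXPANSION = 4 is the usual ResNet bottleneck ratio, so callers now pass only the block's output width and the hidden width follows from it, for example:

out_ch = 256
mid_ch = out_ch // 4  # = 64: the 1x1-reduce / 3x3 / 1x1-expand widths become 64-64-256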
@@ -73,26 +76,25 @@ class _ResLayer(nn.Sequential):
     Residual layer with multi grids
     """
 
-    def __init__(
-        self, n_layers, in_ch, mid_ch, out_ch, stride, dilation, multi_grids=None
-    ):
+    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
         super(_ResLayer, self).__init__()
 
         if multi_grids is None:
             multi_grids = [1 for _ in range(n_layers)]
         else:
-            assert n_layers == len(
-                multi_grids
-            ), "{} values expected, but got: mg={}".format(n_layers, multi_grids)
+            assert n_layers == len(multi_grids)
 
-        self.add_module(
-            "block1",
-            _Bottleneck(in_ch, mid_ch, out_ch, stride, dilation * multi_grids[0], True),
-        )
-        for i, rate in zip(range(2, n_layers + 1), multi_grids[1:]):
+        # Downsampling is only in the first block
+        for i in range(n_layers):
             self.add_module(
-                "block" + str(i),
-                _Bottleneck(out_ch, mid_ch, out_ch, 1, dilation * rate, False),
+                "block{}".format(i + 1),
+                _Bottleneck(
+                    in_ch=(in_ch if i == 0 else out_ch),
+                    out_ch=out_ch,
+                    stride=(stride if i == 0 else 1),
+                    dilation=dilation * multi_grids[i],
+                    downsample=(True if i == 0 else False),
+                ),
             )


@@ -102,32 +104,29 @@ class _Stem(nn.Sequential):
     Note that the max pooling is different from both MSRA and FAIR ResNet.
     """
 
-    def __init__(self):
+    def __init__(self, out_ch):
         super(_Stem, self).__init__()
-        self.add_module("conv1", _ConvBnReLU(3, 64, 7, 2, 3, 1))
+        self.add_module("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1))
         self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True))
 
 
-class ResNet(nn.Module):
+class _Flatten(nn.Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class ResNet(nn.Sequential):
     def __init__(self, n_classes, n_blocks):
         super(ResNet, self).__init__()
-        self.add_module("layer1", _Stem())
-        self.add_module("layer2", _ResLayer(n_blocks[0], 64, 64, 256, 1, 1))
-        self.add_module("layer3", _ResLayer(n_blocks[1], 256, 128, 512, 2, 1))
-        self.add_module("layer4", _ResLayer(n_blocks[2], 512, 256, 1024, 2, 1))
-        self.add_module("layer5", _ResLayer(n_blocks[3], 1024, 512, 2048, 2, 1))
+        ch = [64 * 2 ** p for p in range(6)]
+        self.add_module("layer1", _Stem(ch[0]))
+        self.add_module("layer2", _ResLayer(n_blocks[0], ch[0], ch[2], 1, 1))
+        self.add_module("layer3", _ResLayer(n_blocks[1], ch[2], ch[3], 2, 1))
+        self.add_module("layer4", _ResLayer(n_blocks[2], ch[3], ch[4], 2, 1))
+        self.add_module("layer5", _ResLayer(n_blocks[3], ch[4], ch[5], 2, 1))
         self.add_module("pool5", nn.AdaptiveAvgPool2d(1))
-        self.add_module("fc", nn.Linear(2048, n_classes))
-
-    def forward(self, x):
-        h = self.layer1(x)
-        h = self.layer2(h)
-        h = self.layer3(h)
-        h = self.layer4(h)
-        h = self.layer5(h)
-        h = self.pool5(h)
-        h = self.fc(h.view(h.size(0), -1))
-        return h
+        self.add_module("flatten", _Flatten())
+        self.add_module("fc", nn.Linear(ch[5], n_classes))
 
 
 if __name__ == "__main__":
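The __main__ block is truncated in this view. For orientation, a sketch of how the refactored classes compose into a plain ImageNet-style classifier (block counts for ResNet-101; the input size and class count are illustrative, not taken from this commit):

import torch

model = ResNet(n_classes=1000, n_blocks=[3, 4, 23, 3])  # ResNet-101 layout
model.eval()
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])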
