Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- added provisions for depthwise convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Sep 18, 2017
1 parent dfd3598 commit 4bb194e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
51 changes: 47 additions & 4 deletions inferno/extensions/layers/convolutional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.nn as nn
from ..initializers import OrthogonalWeightsZeroBias, KaimingNormalWeightsZeroBias
from ..initializers import Initializer
from ...utils.exceptions import assert_, ShapeError


__all__ = ['ConvActivation',
Expand All @@ -10,34 +11,51 @@
'StridedConvELU2D', 'StridedConvELU3D',
'DilatedConvELU2D', 'DilatedConvELU3D',
'Conv2D', 'Conv3D',
'BNReLUConv2D']
'BNReLUConv2D',
'BNReLUDepthwiseConv2D']


class ConvActivation(nn.Module):
"""Convolutional layer with 'SAME' padding followed by an activation."""
def __init__(self, in_channels, out_channels, kernel_size, dim, activation,
stride=1, dilation=1, bias=True, deconv=False, initialization=None):
stride=1, dilation=1, groups=None, depthwise=False, bias=True,
deconv=False, initialization=None):
super(ConvActivation, self).__init__()
# Validate dim
assert dim in [2, 3]
assert_(dim in [2, 3], "`dim` must be one of [2, 3], got {}.".format(dim), ShapeError)
self.dim = dim
# Check if depthwise
if depthwise:
assert_(in_channels == out_channels,
"For depthwise convolutions, number of input channels (given: {}) "
"must equal the number of output channels (given {})."
.format(in_channels, out_channels),
ValueError)
assert_(groups is None or groups == in_channels,
"For depthwise convolutions, groups (given: {}) must "
"equal the number of channels (given: {}).".format(groups, in_channels))
groups = in_channels
else:
groups = 1 if groups is None else groups
self.depthwise = depthwise
if not deconv:
# Get padding
padding = self.get_padding(kernel_size, dilation)
# Get convlayer
self.conv = getattr(nn, 'Conv{}d'.format(self.dim))(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias)
else:
self.conv = getattr(nn, 'ConvTranspose{}d'.format(self.dim))(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias)
if initialization is None:
pass
Expand Down Expand Up @@ -260,3 +278,28 @@ def forward(self, input):
activated = self.activation(normed)
conved = self.conv(activated)
return conved


class BNReLUDepthwiseConv2D(ConvActivation):
"""
2D BN-ReLU-Conv layer with 'SAME' padding, He weight initialization and depthwise convolution.
Note that depthwise convolutions require `in_channels == out_channels`.
"""
def __init__(self, in_channels, out_channels, kernel_size):
# We know that in_channels == out_channels, but we also want a consistent API.
# As a compromise, we allow that out_channels be None or 'auto'.
out_channels = in_channels if out_channels in [None, 'auto'] else out_channels
super(BNReLUDepthwiseConv2D, self).__init__(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dim=2,
depthwise=True,
activation=nn.ReLU(inplace=True),
initialization=KaimingNormalWeightsZeroBias(0))
self.batchnorm = nn.BatchNorm2d(in_channels)

def forward(self, input):
normed = self.batchnorm(input)
activated = self.activation(normed)
conved = self.conv(activated)
return conved
18 changes: 18 additions & 0 deletions tests/extensions/layers/convolutional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
import torch
from inferno.utils.model_utils import ModelTester


class TestConvolutional(unittest.TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "GPU not available.")
def test_bn_relu_depthwise_conv2d_pyinn(self):
from inferno.extensions.layers.convolutional import BNReLUDepthwiseConv2D
model = BNReLUDepthwiseConv2D(10, 'auto', 3)
ModelTester((1, 10, 100, 100),
(1, 10, 100, 100)).cuda()(model)
self.assertTrue(model.depthwise)
self.assertEqual(model.conv.groups, 10)


if __name__ == '__main__':
unittest.main()

0 comments on commit 4bb194e

Please sign in to comment.