In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

In [8]:
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 3x3 ``nn.Conv2d`` layer.

    Arguments:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        stride (int, optional): convolution stride. Default is 1.
        padding (int, optional): zero padding on each side. Default is 1.
        bias (bool, optional): whether the layer has a bias term.
        groups (int, optional): number of convolution groups. Default is 1.

    Returns:
        nn.Conv2d: the freshly constructed layer.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)

In [12]:
def conv1x1(in_channels, out_channels, groups=1):
    """Build a 1x1 (pointwise) ``nn.Conv2d`` layer with stride 1.

    Arguments:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels.
        groups (int, optional): number of convolution groups. Default is 1.

    Returns:
        nn.Conv2d: the freshly constructed layer.
    """
    pointwise = nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, groups=groups)
    return pointwise

In [31]:
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` groups.

    The channel axis is viewed as (groups, channels_per_group), the two
    axes are swapped, and the result is flattened back to a single
    channel axis.

    Arguments:
        x (Tensor): 4-D tensor unpacked as (batch, channels, height, width).
        groups (int): number of groups the channel axis is split into.

    Returns:
        Tensor: tensor with the same shape as ``x`` and shuffled channels.
    """
    n, c, h, w = x.shape
    per_group = c // groups

    # (n, c, h, w) -> (n, groups, per_group, h, w)
    grouped = x.view(n, groups, per_group, h, w)

    # Swap the group and per-group axes; contiguous() is needed so the
    # tensor can be viewed again afterwards.
    swapped = grouped.transpose(1, 2).contiguous()

    # Collapse back to a single channel axis.
    return swapped.view(n, -1, h, w)
    

In [5]:
class ShuffleUnit(nn.Module):
    """Single ShuffleNet bottleneck unit.

    The main branch runs: 1x1 (optionally grouped) compress conv + BN + ReLU,
    channel shuffle, 3x3 depthwise conv + BN, 1x1 grouped expand conv + BN.
    Its output is merged with the shortcut branch and passed through ReLU.

    Arguments:
        in_channels (int): number of channels of the input tensor.
        out_channels (int): number of channels produced by ``forward``.
        groups (int, optional): number of groups for the grouped 1x1
            convolutions and for the channel shuffle. Default is 3.
        grouped_conv (bool, optional): if False, the first (compress) 1x1
            convolution is a plain convolution (groups=1). Default is True.
        combine (str, optional): how the shortcut is merged.
            'add' keeps the spatial size (depthwise stride 1) and adds the
            input to the main branch. 'concat' halves the spatial size
            (depthwise stride 2), average-pools the input and concatenates
            it with the main branch along the channel axis.

    Raises:
        ValueError: if ``combine`` is neither 'add' nor 'concat'.
    """

    def __init__(self, in_channels, out_channels,
                 groups=3, grouped_conv=True,
                 combine='add'):
        super(ShuffleUnit, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grouped_conv = grouped_conv
        self.combine = combine
        self.groups = groups
        # Bottleneck width is a quarter of the unit's nominal output width.
        self.bottleneck_channels = self.out_channels // 4

        if self.combine == 'add':
            self.depthwise_stride = 1
            self._combine_func = self._add

        elif self.combine == 'concat':
            self.depthwise_stride = 2
            self._combine_func = self._concat

            # The shortcut contributes in_channels after concatenation, so
            # the main branch only has to produce the remainder.
            self.out_channels -= self.in_channels

        else:
            raise ValueError("Cannot combine tensors with \"{}\". "
                             "Only \"add\" and \"concat\" are "
                             "supported".format(self.combine))

        # 1x1 compress: in_channels -> bottleneck_channels (+ BN + ReLU).
        self.first_1x1_groups = self.groups if grouped_conv else 1
        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm=True,
            relu=True
            )

        # 3x3 depthwise convolution (one group per channel) followed by BN.
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels,
            self.bottleneck_channels,
            stride=self.depthwise_stride,
            groups=self.bottleneck_channels
            )
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # 1x1 expand: bottleneck_channels -> out_channels (+ BN, no ReLU;
        # the ReLU is applied after merging with the shortcut).
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm=True,
            relu=False
            )

    @staticmethod
    def _add(x, out):
        """Merge shortcut ``x`` and main branch ``out`` by element-wise sum."""
        return x + out

    @staticmethod
    def _concat(x, out):
        """Merge shortcut ``x`` and main branch ``out`` along dimension 1."""
        return torch.cat((x, out), 1)

    def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
                              batch_norm=True, relu=False):
        """Build a 1x1 conv optionally followed by BatchNorm2d and ReLU.

        Returns an ``nn.Sequential`` when at least one extra layer is
        requested, otherwise the bare convolution.
        """
        modules = OrderedDict()

        conv = conv1x1(in_channels, out_channels, groups=groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv

    def forward(self, x):
        """Apply the unit to ``x`` and return the merged, ReLU-ed result."""
        residual = x

        if self.combine == 'concat':
            # Downsample the shortcut so it matches the stride-2 main branch.
            residual = F.avg_pool2d(residual, kernel_size=3,
                                    stride=2, padding=1)

        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)

        out = self._combine_func(residual, out)
        return F.relu(out)
     

In [None]:
class ShuffleNet(nn.Module):
    """ShuffleNet classifier.

    Layout: 3x3 stride-2 convolution, 3x3 stride-2 max pooling, three
    stages of ShuffleUnits, global average pooling and a linear layer.
    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, groups=3, in_channels=3, num_classes=1000):
        """ShuffleNet constructor.
        Arguments:
            groups (int, optional): number of groups to be used in grouped 
                1x1 convolutions in each ShuffleUnit. Default is 3 for best
                performance according to original paper.
            in_channels (int, optional): number of channels in the input tensor.
                Default is 3 for RGB image inputs.
            num_classes (int, optional): number of classes to predict. Default
                is 1000 for ImageNet.
        Raises:
            ValueError: if ``groups`` is not one of 1, 2, 3, 4, 8.
        """
        super(ShuffleNet, self).__init__()

        self.groups = groups
        self.stage_repeats = [3, 7, 3]
        self.in_channels = in_channels
        self.num_classes = num_classes

        # index 0 is invalid and should never be called.
        # only used for indexing convenience.
        # Each later stage doubles the width of stage 2 (groups=1 ends at
        # 576 = 4 * 144).
        stage_out_channels_by_groups = {
            1: [-1, 24, 144, 288, 576],
            2: [-1, 24, 200, 400, 800],
            3: [-1, 24, 240, 480, 960],
            4: [-1, 24, 272, 544, 1088],
            8: [-1, 24, 384, 768, 1536],
        }
        if groups not in stage_out_channels_by_groups:
            raise ValueError(
                """{} groups is not supported for
                   1x1 Grouped Convolutions""".format(groups))
        self.stage_out_channels = stage_out_channels_by_groups[groups]

        # Stage 1 always has 24 output channels
        self.conv1 = conv3x3(self.in_channels,
                             self.stage_out_channels[1], # stage 1
                             stride=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage 2
        self.stage2 = self._make_stage(2)
        # Stage 3
        self.stage3 = self._make_stage(3)
        # Stage 4
        self.stage4 = self._make_stage(4)

        # Global pooling:
        # Undefined as PyTorch's functional API can be used for on-the-fly
        # shape inference if input size is not ImageNet's 224x224

        # Fully-connected classification layer
        num_inputs = self.stage_out_channels[-1]
        self.fc = nn.Linear(num_inputs, self.num_classes)
        self.init_params()


    def init_params(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule.

        Conv weights: Kaiming normal (fan_out); BatchNorm: weight 1, bias 0;
        Linear weights: normal with std 0.001; all biases: 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


    def _make_stage(self, stage):
        """Build stage ``stage`` (2, 3 or 4) as an ``nn.Sequential``.

        The stage is one 'concat' ShuffleUnit followed by
        ``stage_repeats[stage-2]`` 'add' ShuffleUnits.
        """
        modules = OrderedDict()
        stage_name = "ShuffleUnit_Stage{}".format(stage)
        
        # First ShuffleUnit in the stage
        # 1. non-grouped 1x1 convolution (i.e. pointwise convolution)
        #   is used in Stage 2. Group convolutions used everywhere else.
        grouped_conv = stage > 2
        
        # 2. concatenation unit is always used.
        first_module = ShuffleUnit(
            self.stage_out_channels[stage-1],
            self.stage_out_channels[stage],
            groups=self.groups,
            grouped_conv=grouped_conv,
            combine='concat'
            )
        modules[stage_name+"_0"] = first_module

        # add more ShuffleUnits depending on pre-defined number of repeats
        for i in range(self.stage_repeats[stage-2]):
            name = stage_name + "_{}".format(i+1)
            module = ShuffleUnit(
                self.stage_out_channels[stage],
                self.stage_out_channels[stage],
                groups=self.groups,
                grouped_conv=True,
                combine='add'
                )
            modules[name] = module

        return nn.Sequential(modules)


    def forward(self, x):
        """Run the network on ``x`` and return per-class log-probabilities."""
        x = self.conv1(x)
        x = self.maxpool(x)

        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # global average pooling layer
        x = F.avg_pool2d(x, x.size()[-2:])
        
        # flatten for input to fully-connected layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return F.log_softmax(x, dim=1)
        

In [30]:
# Sanity check: 8 channels of 4x4 consecutive integers, shuffled in 2 groups.
# Channels from the two halves should come out interleaved.
tensora = torch.arange(0, 128).reshape(1, 8, 4, -1)
tensora_ = channel_shuffle(tensora, groups=2)

tensora_

tensor([[[[  0,   1,   2,   3],
          [  4,   5,   6,   7],
          [  8,   9,  10,  11],
          [ 12,  13,  14,  15]],

         [[ 64,  65,  66,  67],
          [ 68,  69,  70,  71],
          [ 72,  73,  74,  75],
          [ 76,  77,  78,  79]],

         [[ 16,  17,  18,  19],
          [ 20,  21,  22,  23],
          [ 24,  25,  26,  27],
          [ 28,  29,  30,  31]],

         [[ 80,  81,  82,  83],
          [ 84,  85,  86,  87],
          [ 88,  89,  90,  91],
          [ 92,  93,  94,  95]],

         [[ 32,  33,  34,  35],
          [ 36,  37,  38,  39],
          [ 40,  41,  42,  43],
          [ 44,  45,  46,  47]],

         [[ 96,  97,  98,  99],
          [100, 101, 102, 103],
          [104, 105, 106, 107],
          [108, 109, 110, 111]],

         [[ 48,  49,  50,  51],
          [ 52,  53,  54,  55],
          [ 56,  57,  58,  59],
          [ 60,  61,  62,  63]],

         [[112, 113, 114, 115],
          [116, 117, 118, 119],
          [120, 121, 122, 

In [32]:
# Shape check for the convolution helpers: build one 3x3 and one 1x1 layer
# mapping 10 -> 20 channels, then push a random (3, 10, 4, 4) batch through
# the 1x1 layer and display the output shape.
conva = conv3x3(10, 20)
convb = conv1x1(10, 20)
tensorb = torch.randn(3,10,4,4)
convb(tensorb).shape

torch.Size([3, 20, 4, 4])