### PyTorch实现的ResNeXt

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### 定义基本单元

In [16]:
class Block(nn.Module):
    """Residual bottleneck unit built around a grouped 3x3 convolution.

    The main path is 1x1 conv -> grouped 3x3 conv -> 1x1 conv, each followed
    by batch normalisation, with ReLU after the first two. The result is added
    to a shortcut of the input and passed through a final ReLU.

    Args:
        in_planes: number of channels of the input tensor.
        cardinality: number of groups used by the 3x3 convolution.
        bottleneck_width: channels per group; the grouped convolution works on
            ``cardinality * bottleneck_width`` channels.
        stride: stride of the 3x3 convolution (and of the projection shortcut).
    """
    # Output channels = expansion * (cardinality * bottleneck_width).
    expansion=2

    def __init__(self,in_planes,cardinality=32,bottleneck_width=4,stride=1):
        super(Block,self).__init__()
        inner_planes=cardinality*bottleneck_width
        out_planes=self.expansion*inner_planes

        # Main path: reduce -> grouped spatial conv -> expand.
        self.conv1=nn.Conv2d(in_planes,inner_planes,kernel_size=1,bias=False)
        self.bn1=nn.BatchNorm2d(inner_planes)
        self.conv2=nn.Conv2d(
            inner_planes,inner_planes,
            kernel_size=3,stride=stride,padding=1,
            groups=cardinality,bias=False,
        )
        self.bn2=nn.BatchNorm2d(inner_planes)
        self.conv3=nn.Conv2d(inner_planes,out_planes,kernel_size=1,bias=False)
        self.bn3=nn.BatchNorm2d(out_planes)

        # The shortcut is an identity (empty Sequential) unless the main path
        # changes the spatial size or the channel count; then it is a 1x1
        # projection so both branches can be summed.
        shape_changes=stride!=1 or in_planes!=out_planes
        if shape_changes:
            self.shortcut=nn.Sequential(
                nn.Conv2d(in_planes,out_planes,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut=nn.Sequential()

    def forward(self,x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y=self.conv1(x)
        y=F.relu(self.bn1(y))
        y=self.conv2(y)
        y=F.relu(self.bn2(y))
        y=self.bn3(self.conv3(y))
        y=y+self.shortcut(x)
        return F.relu(y)

### 定义ResNeXt主体结构

In [114]:
class ResNeXt(nn.Module):
    """ResNeXt classifier: 7x7 stem, three stages of ``Block`` units, linear head.

    Args:
        num_blocks: sequence whose first three entries give the number of
            ``Block`` units in stage 1, 2 and 3.
        cardinality: number of convolution groups used by every block.
        bottleneck_width: channels per group in stage 1; it is doubled after
            each stage.
        num_classes: size of the output of the final linear layer.
    """
    def __init__(self,num_blocks,cardinality,bottleneck_width,num_classes=100):
        super(ResNeXt,self).__init__()
        self.cardinality=cardinality
        # Mutated by _make_layer: doubled once per stage that is built.
        self.bottleneck_width=bottleneck_width
        # Channel count fed to the next block; updated by _make_layer.
        self.in_planes=64
        # Stem implemented as in the original paper (7x7 kernel, stride 2).
        # The kernel size here is debatable: a 1x1 stem convolution without
        # max-pooling was the alternative previously used in this file.
        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(64)
        self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.layer1=self._make_layer(num_blocks[0],1)
        self.layer2=self._make_layer(num_blocks[1],2)
        self.layer3=self._make_layer(num_blocks[2],2)
        # Global average pooling down to 1x1 regardless of the feature-map
        # size. A fixed AvgPool2d(7) only yields 1x1 for a 7x7 map; any other
        # input resolution would flatten to more features than ``fc`` expects.
        self.avgpool=nn.AdaptiveAvgPool2d((1,1))
        # in_planes now holds the channel count produced by the last stage
        # (Block.expansion * cardinality * bottleneck_width * 4).
        self.fc=nn.Linear(self.in_planes,num_classes)

    def _make_layer(self,num_blocks,stride):
        """Build one stage of blocks and advance the channel bookkeeping.

        The first block uses ``stride``; the remaining ones use stride 1.
        At least one block is always created, even if ``num_blocks`` < 1.
        """
        strides=[stride]+[1]*(num_blocks-1)
        layers=[]
        for block_stride in strides:
            layers.append(Block(self.in_planes,self.cardinality,self.bottleneck_width,block_stride))
            self.in_planes=Block.expansion*self.cardinality*self.bottleneck_width
        # Double the bottleneck width after each stage.
        self.bottleneck_width*=2
        return nn.Sequential(*layers)

    def forward(self,x):
        """Map an image batch with 3 channels to ``(batch, num_classes)`` scores."""
        out=F.relu(self.bn1(self.conv1(x)))
        out=self.maxpool(out)
        out=self.layer1(out)
        out=self.layer2(out)
        out=self.layer3(out)
        out=self.avgpool(out)
        # Flatten (batch, C, 1, 1) to (batch, C) for the linear head.
        out=out.view(out.size(0),-1)
        out=self.fc(out)
        return out

### 定义各种结构的网络模型

In [115]:
def ResNeXt29_2x64d():
    """Build a ResNeXt with three 3-block stages, cardinality 2 and initial bottleneck width 64."""
    config=dict(num_blocks=[3,3,3],cardinality=2,bottleneck_width=64)
    return ResNeXt(**config)
def ResNext29_4x64d():
    """Build a ResNeXt with three 3-block stages, cardinality 4 and initial bottleneck width 64."""
    return ResNeXt(num_blocks=[3,3,3],cardinality=4,bottleneck_width=64)
# Alias spelled like the ``ResNeXt`` class and the other factories (capital
# "X"); the lowercase-"x" name above is kept for existing callers.
ResNeXt29_4x64d=ResNext29_4x64d
def ResNext29_8x64d():
    """Build a ResNeXt with three 3-block stages, cardinality 8 and initial bottleneck width 64."""
    return ResNeXt(num_blocks=[3,3,3],cardinality=8,bottleneck_width=64)
# Alias spelled like the ``ResNeXt`` class and the other factories (capital
# "X"); the lowercase-"x" name above is kept for existing callers.
ResNeXt29_8x64d=ResNext29_8x64d
def ResNeXt29_32x4d():
    """Build a ResNeXt with three 3-block stages, cardinality 32 and initial bottleneck width 4."""
    config=dict(num_blocks=[3,3,3],cardinality=32,bottleneck_width=4)
    return ResNeXt(**config)

In [118]:
def test_resnext():
    """Smoke test: build the 2x64d model, print it, and print the output size
    for one random 128x128 RGB image."""
    model=ResNeXt29_2x64d()
    # Use global adaptive pooling so the classifier head receives a 1x1
    # feature map for this input resolution.
    model.avgpool=nn.AdaptiveAvgPool2d((1,1))
    print(model)
    sample=torch.randn(1,3,128,128)
    scores=model(sample)
    print(scores.size())

In [119]:
# Run the smoke test; it prints the model structure followed by the output size.
test_resnext()

ResNeXt(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Block(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e

In [None]:
import torch