In [6]:
# 使用残差结构
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10

In [2]:
def conv3x3(in_channel, out_channel, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    With ``stride=1`` the spatial size of the input is preserved; with
    ``stride=2`` height and width are (roughly) halved.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        # No bias: every use in this notebook is followed by BatchNorm,
        # which has its own learnable shift.
        bias=False,
    )
    return layer


In [10]:
class residual_block(nn.Module):
    """Basic two-layer residual block.

    Residual branch: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN.
    Output: ReLU(shortcut(x) + residual(x)).

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels produced by the block.
        same_shape: if True, the first conv uses stride 1 and the spatial size
            is preserved; if False, it uses stride 2 (halving H and W).

    The shortcut is the identity when the residual branch keeps both shape
    and channel count; otherwise it is a strided 1x1 convolution (``conv3``)
    so the two tensors can be added.
    """
    def __init__(self,in_channel,out_channel,same_shape=True):
        super(residual_block,self).__init__()
        self.same_shape=same_shape
        stride=1 if self.same_shape else 2
        self.conv1=conv3x3(in_channel,out_channel,stride=stride)
        self.bn1=nn.BatchNorm2d(out_channel)

        self.conv2=conv3x3(out_channel,out_channel)
        self.bn2=nn.BatchNorm2d(out_channel)
        # A projection is required whenever the residual branch changes the
        # spatial size (stride 2) OR the channel count; previously a channel
        # change with same_shape=True crashed on the addition in forward().
        if (not self.same_shape) or in_channel!=out_channel:
            self.conv3=nn.Conv2d(in_channel,out_channel,1,stride=stride)

    def forward(self,x):
        out=F.relu(self.bn1(self.conv1(x)),True)
        # No ReLU before the addition: the final non-linearity is applied to
        # the sum, as in the ResNet basic block. A ReLU here would clamp the
        # residual to be non-negative, so the block could only ever add to x.
        out=self.bn2(self.conv2(out))
        # conv3 exists only when a projection shortcut was created in __init__.
        shortcut=self.conv3(x) if hasattr(self,"conv3") else x
        return F.relu(shortcut+out,True)
        

In [11]:
# Check the residual block's input and output shapes.
# Case 1: input and output have the same shape (stride 1, same channel count).
# torch.autograd.Variable is deprecated (a no-op since PyTorch 0.4): use plain tensors.
test_net=residual_block(32,32)
test_x=torch.zeros(1,32,96,96)
print(test_x.shape)
# Shape check only: no gradients needed.
with torch.no_grad():
    test_y=test_net(test_x)
print(test_y.shape)
assert test_y.shape==(1,32,96,96)


torch.Size([1, 32, 96, 96])
torch.Size([1, 32, 96, 96])


In [12]:
# Case 2: input and output shapes differ (same_shape=False -> stride 2,
# so H and W are halved and the shortcut uses a 1x1 projection).
test_net=residual_block(32,32,False)
test_x=torch.zeros(1,32,96,96)
print(test_x.shape)
# Shape check only: no gradients needed.
with torch.no_grad():
    test_y=test_net(test_x)
print(test_y.shape)
assert test_y.shape==(1,32,48,48)

torch.Size([1, 32, 96, 96])
torch.Size([1, 32, 48, 48])


In [14]:
# Full ResNet: a stem convolution followed by stacked residual_block stages.
class resnet(nn.Module):
    """Small ResNet-style classifier built from ``residual_block`` stages.

    Stages:
        block1: 7x7 conv, stride 2, no padding, 64 channels.
        block2: 3x3/2 max-pool, then two 64-channel residual blocks.
        block3-5: each halves H and W and doubles channels (128, 256, 512).
        Global average pooling to 1x1, then a linear classifier.

    Args:
        in_channel: number of channels of the input images.
        num_classes: number of output logits.
        verbose: if True, ``forward`` prints the feature-map shape after each stage.
    """
    def __init__(self,in_channel,num_classes,verbose=False):
        super(resnet,self).__init__()
        self.verbose=verbose
        # NOTE(review): the reference ResNet stem uses padding=3 plus BN+ReLU;
        # kept as-is here so the parameter layout (state_dict keys) is unchanged.
        self.block1=nn.Conv2d(in_channel,64,7,2)
        self.block2=nn.Sequential(
            nn.MaxPool2d(3,2),
            residual_block(64,64),
            residual_block(64,64)
        )
        self.block3=nn.Sequential(
            residual_block(64,128,False),
            residual_block(128,128)
        )
        self.block4=nn.Sequential(
            residual_block(128,256,False),
            residual_block(256,256)
        )
        self.block5=nn.Sequential(
            residual_block(256,512,False),
            residual_block(512,512),
            # Adaptive pooling always yields 1x1, matching classifier's 512 inputs
            # for any input resolution. nn.AvgPool2d(3) only did so when the
            # final map was exactly 3x3 (i.e. a 96x96 input): larger inputs
            # crashed in the classifier. For 96x96 inputs the result is identical.
            nn.AdaptiveAvgPool2d(1)
        )
        self.classifier=nn.Linear(512,num_classes)

    def forward(self,x):
        """Return class logits of shape (N, num_classes)."""
        for stage in (self.block1,self.block2,self.block3,self.block4,self.block5):
            x=stage(x)
            if self.verbose:
                print(x.shape)
        # Flatten (N, 512, 1, 1) -> (N, 512) for the linear layer.
        x=x.view(x.shape[0],-1)
        return self.classifier(x)

In [15]:
# Smoke test: a single 96x96 RGB image through the full network, printing
# the shape after each stage (verbose=True) and the final logits shape.
test_net=resnet(3,10,True)
test_x=torch.zeros(1,3,96,96)
# Shape check only: no gradients needed.
with torch.no_grad():
    test_y=test_net(test_x)
print(test_y.shape)
assert test_y.shape==(1,10)

torch.Size([1, 64, 45, 45])
torch.Size([1, 64, 22, 22])
torch.Size([1, 128, 11, 11])
torch.Size([1, 256, 6, 6])
torch.Size([1, 512, 1, 1])
torch.Size([1, 10])
