论文中有两种不同的`residual block`的设计     
- 浅层网络时， 采用的是两个3x3conv
- 深层网络（>50）时，使用的是`bottle neck`,三层卷积，(1x1, 64),  (3x3, 64),  (1x1, 256)

In [0]:
import torch
import torch as t
from torch import  nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim

In [0]:
class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out):

        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.extra(x)

        return F.relu(out)




class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16)
        )
        self.blk1 = ResBlk(16, 16)
        self.blk2 = ResBlk(16, 32)
        self.fc1 = nn.Linear(32*32*32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)

        return x

In [15]:
batch_size = 32

cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
    transforms.Resize(32, 32),
    transforms.ToTensor()
]), download=True)
train_loader = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
    transforms.Resize(32, 32),
    transforms.ToTensor()
]), download=True)
test_loader = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)  

x, label = iter(train_loader).next()
print('x: ', x.shape, 'label:  ', label.shape)

Files already downloaded and verified
Files already downloaded and verified
x:  torch.Size([32, 3, 32, 32]) label:   torch.Size([32])


In [16]:
device = t.device('cuda')
model = ResNet18().to(device) 

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)

ResNet18(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (blk1): ResBlk(
    (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential()
  )
  (blk2): ResBlk(
    (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (extra): Sequential(
      (0): Conv2d(16, 32, ker

In [0]:
for epoch in range(100):
    
    # train
    model.train()
    for index, (x, label) in enumerate(train_loader):
        x, label = x.to(device), label.to(device)
        out = model(x)
        loss = criterion(out, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print('epoch: ', epoch, 'loss: ', loss.item())
        
    # test
    model.eval()
    with t.no_grad():
        total_correct = 0
        total_num = 0
        
        for x, label in test_loader:
            x, label = x.to(device), label.to(device)
            out = model(x)
        
            pred = out.argmax(dim=1)
            correct = t.eq(pred, label).float().sum().item()
            
            total_correct += correct
            total_num += x.size(0)
            
        acc = total_correct / total_num
    
    print('epoch: ', epoch, 'acc: ', acc)
        

epoch:  0 loss:  0.05174073576927185
epoch:  0 acc:  0.6321
epoch:  1 loss:  0.08459293097257614
epoch:  1 acc:  0.6203
epoch:  2 loss:  0.03969340771436691
epoch:  2 acc:  0.6279
epoch:  3 loss:  0.05770719051361084
epoch:  3 acc:  0.6171
epoch:  4 loss:  0.18921205401420593
epoch:  4 acc:  0.622
epoch:  5 loss:  0.0482497476041317
epoch:  5 acc:  0.63
epoch:  6 loss:  0.0025149062275886536
epoch:  6 acc:  0.6147
epoch:  7 loss:  0.04263240844011307
epoch:  7 acc:  0.6231
epoch:  8 loss:  0.010829329490661621
epoch:  8 acc:  0.6232
epoch:  9 loss:  0.012511417269706726
epoch:  9 acc:  0.6247
epoch:  10 loss:  0.035785011947155
epoch:  10 acc:  0.6246
epoch:  11 loss:  0.2755345404148102
epoch:  11 acc:  0.6264
epoch:  12 loss:  0.032314591109752655
epoch:  12 acc:  0.6251
epoch:  13 loss:  0.0012970268726348877
epoch:  13 acc:  0.6324
epoch:  14 loss:  0.0144390519708395
epoch:  14 acc:  0.6331
epoch:  15 loss:  0.10761986672878265
epoch:  15 acc:  0.6309
epoch:  16 loss:  0.007928140