# In [None]:
# -*- coding: utf-8 -*-
import sys
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
 
def conv3x3(in_channel, out_channel, stride=1):
    """Build a bias-free 3x3 convolution layer with padding 1.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` with kernel size 3, padding 1 and no bias term.
    """
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
     
class residual_block(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    When ``in_channel == out_channel`` the first convolution uses stride 1
    and the input is added back unchanged. Otherwise the first convolution
    uses stride 2 and a 1x1 stride-2 convolution (``conv3``) projects the
    input to ``out_channel`` channels so both branches can be summed.

    Args:
        in_channel: Number of channels of the block input.
        out_channel: Number of channels of the block output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        same_shape = in_channel == out_channel
        stride = 1 if same_shape else 2
        self.conv1 = conv3x3(in_channel, out_channel, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # Projection shortcut; stays None when the input can be added as-is.
        self.conv3 = None
        if not same_shape:
            self.conv3 = nn.Conv2d(in_channel, out_channel, 1, stride=stride)

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))`` for the input batch ``x``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out), inplace=True)
        out = self.conv2(out)
        # NOTE(review): ReLU is applied to the residual branch before the
        # addition as well as after it; the original ResNet design applies
        # it only after the addition. Kept as-is to preserve behavior.
        out = F.relu(self.bn2(out), inplace=True)
        # Explicit None check: do not rely on the truthiness of a module.
        if self.conv3 is not None:
            x = self.conv3(x)
        return F.relu(x + out, inplace=True)

class resnet(nn.Module):
    """ResNet-style classifier built from five stages and a linear head.

    Stage 1 is a single 7x7 stride-2 convolution. Stage 2 is a 3x3 stride-2
    max-pool followed by two 64-channel residual blocks. Stages 3-5 each
    hold two residual blocks that widen the channels to 128, 256 and 512;
    stage 5 ends with ``AvgPool2d(3)``. The flattened result feeds
    ``nn.Linear(512, num_classes)``, so the pooled feature map must be 1x1.
    That holds for 96 x 96 inputs, where the last feature map is 3 x 3.

    Args:
        in_channel: Number of channels of the input images.
        num_classes: Number of output classes (size of the logits vector).
        verbose: When True, ``forward`` prints the output shape of every
            stage.
    """

    def __init__(self, in_channel, num_classes, verbose = False):
        super().__init__()
        self.verbose = verbose
        self.block1 = nn.Conv2d(in_channel, 64, 7, 2)

        self.block2 = nn.Sequential(
            nn.MaxPool2d(3, 2),
            residual_block(64, 64),
            residual_block(64, 64)
        )
        self.block3 = nn.Sequential(
            residual_block(64, 128),
            residual_block(128, 128)
        )
        self.block4 = nn.Sequential(
            residual_block(128, 256),
            residual_block(256, 256)
        )
        self.block5 = nn.Sequential(
            residual_block(256, 512),
            residual_block(512, 512),
            nn.AvgPool2d(3)
        )
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run ``x`` through all stages and return the class logits."""
        stages = (self.block1, self.block2, self.block3, self.block4, self.block5)
        for index, stage in enumerate(stages, start=1):
            x = stage(x)
            # The verbose flag was previously stored but never used.
            if self.verbose:
                print("block {} output: {}".format(index, x.shape))
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        return self.classifier(x)
 
def data_tf(x):
    """Turn one image into a normalized, channel-first float32 tensor.

    Args:
        x: Image object exposing ``resize(size, resample)`` and convertible
            by ``np.array`` to a height x width x channel array. Presumably
            a PIL image as handed over by the dataset; confirm with caller.

    Returns:
        A ``torch.Tensor`` with the channel axis first and values mapped
        from [0, 255] to [-1, 1].
    """
    # Enlarge the image to 96 x 96. The second argument (2) is the resampling
    # filter code; for a PIL image that presumably selects bilinear (confirm).
    enlarged = x.resize((96, 96), 2)
    pixels = np.array(enlarged, dtype='float32') / 255
    # Normalize with mean 0.5 and std 0.5.
    pixels = (pixels - 0.5) / 0.5
    # Move the channel axis to the front, the layout PyTorch expects.
    channel_first = pixels.transpose((2, 0, 1))
    return torch.from_numpy(channel_first)
     
# CIFAR-10 training split and its loader; data_tf is applied to every sample.
train_set = CIFAR10(root='./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=64,
    shuffle=True,
)
# CIFAR-10 test split, read in a fixed order.
test_set = CIFAR10(root='./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=128,
    shuffle=False,
)
print(train_set.data.shape)

# Model for 3-channel inputs and 10 classes, trained with plain SGD on the
# cross-entropy loss.
net = resnet(in_channel=3, num_classes=10)
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
 
def get_acc(output, label):
    """Return the fraction of rows in ``output`` whose arg-max equals ``label``.

    Args:
        output: Score tensor; the prediction is the index of the maximum
            along dimension 1.
        label: Tensor of target class indices, one per row of ``output``.

    Returns:
        Number of correct predictions divided by ``output.shape[0]``.
    """
    predictions = output.max(1)[1]
    hits = (predictions == label).sum().item()
    return hits / output.shape[0]
 
def train(net, train_data, num_epochs, optimizer, criterion):
    """Train ``net`` and print the mean loss and accuracy of every epoch.

    The reported metrics are per-sample averages: each batch contributes in
    proportion to its size, so a smaller final batch does not skew them.
    This assumes ``criterion`` returns the mean loss of the batch (the
    default reduction of the PyTorch loss modules).

    Args:
        net: Model to train; it is switched to training mode.
        train_data: Iterable yielding ``(images, labels)`` batches; it is
            iterated once per epoch.
        num_epochs: Number of passes over ``train_data``.
        optimizer: Optimizer that updates the parameters of ``net``.
        criterion: Loss function called as ``criterion(output, label)``.
    """
    net = net.train()
    for epoch in range(num_epochs):
        loss_sum = 0.0
        num_correct = 0
        num_samples = 0
        for im, label in train_data:
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted statistics; integer hit counts keep
            # the epoch accuracy exact.
            batch_size = output.shape[0]
            _, pred_label = output.max(1)
            loss_sum += loss.item() * batch_size
            num_correct += (pred_label == label).sum().item()
            num_samples += batch_size
        # Guard against an empty loader instead of dividing by zero.
        denom = max(num_samples, 1)
        print("Epoch %d. Train Loss: %f, Train Acc: %f, " %
              (epoch, loss_sum / denom, num_correct / denom))
                                          
train(net, train_data, 20, optimizer, criterion)

# ResNet uses skip (cross-layer) connections, which make it possible to train
# very deep convolutional networks. Its convolution layers are also configured
# very simply, which makes the architecture easy to extend.

# Evaluate on the test set. Metrics are weighted by batch size so that a
# smaller final batch does not distort the averages (criterion returns the
# batch mean and get_acc the batch fraction).
test_loss = 0
test_acc = 0
num_test_samples = 0
net = net.eval()
with torch.no_grad():
    for im, label in test_data:
        output = net(im)
        loss = criterion(output, label)
        batch_size = output.shape[0]
        test_loss += loss.item() * batch_size
        test_acc += get_acc(output, label) * batch_size
        num_test_samples += batch_size
# Guard against an empty loader instead of dividing by zero.
test_denom = max(num_test_samples, 1)
print("Test Loss: %f, Test Acc: %f, " % (test_loss / test_denom, test_acc / test_denom))