In [1]:
# from ghostnet.ghost_net import ghost_net
# from keras.datasets import mnist
import pandas as pd
import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn,optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import SimpleITK as sitk
import os
from VGG import *
# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
import torchvision
import torchvision.transforms as transforms
# ToTensor yields values in [0, 1]; Normalize((x - 0.5) / 0.5) maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# CIFAR-10 training split (downloaded into ./data if not already present).
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
# CIFAR-10 test split.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
batch_size = 10
# One loader per split; only the training data is reshuffled on every pass.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
class early_exit_Branch(nn.Module):
    """Side classifier ("early exit") attached to an intermediate feature map.

    Applies a grouped convolution with ``groups=inp`` (depthwise when
    ``oup == inp``), BatchNorm and ReLU, flattens the result and maps it to
    class logits with a single linear layer.

    Args:
        inp: number of input channels.
        oup: number of output channels of the convolution (must be a multiple
            of ``inp`` because the convolution uses ``groups=inp``).
        fc_inp: input size of the linear layer, i.e. the number of elements per
            sample of the convolution output; it must match the runtime size.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        stride: convolution stride.
        relu: accepted but currently unused -- ReLU is always applied.
        class_nums: number of output classes (size of the returned logits).
    """

    def __init__(self, inp, oup, fc_inp, kernel_size=3, stride=1, relu=False, class_nums=10):
        super(early_exit_Branch, self).__init__()
        padding = kernel_size // 2
        grouped_conv = nn.Conv2d(inp, oup, kernel_size, stride, padding, groups=inp, bias=False)
        self.dep_conv = nn.Sequential(grouped_conv, nn.BatchNorm2d(oup), nn.ReLU(inplace=True))
        self.classifier = nn.Linear(fc_inp, class_nums)

    def forward(self, x):
        """Return logits of shape (batch, class_nums) for feature map ``x``."""
        features = self.dep_conv(x)
        flat = features.reshape(features.size(0), -1)
        return self.classifier(flat)

class VGG16(nn.Module):
    """VGG16 (with BatchNorm) with four early-exit classifier branches.

    The 13 conv layers are split into four stages (``features1``..``features4``);
    after each stage an ``early_exit_Branch`` produces class logits. The
    ``fc_inp`` sizes of the branches (32768, 16384, 8192, 512) correspond to a
    3x32x32 input (e.g. CIFAR-10); other input sizes break the branch heads.

    Two modes, selected by ``train_flag``:

    * ``train_flag=True``: every stage and every exit head runs and ``forward``
      returns ``(final_logits, exit1, exit2, exit3, exit4)``.
    * ``train_flag=False``: early-exit inference. ``forward`` returns only the
      logits of the first exit whose batch-averaged confidence (see
      ``decision_model``) is ``>= conf_score``, falling back to the final
      classifier, and prints which exit was taken ("1出".."5出").

    Args:
        num_classes: number of output classes of every head.
        train_flag: initial mode (see above).
        conf_score: confidence threshold in [0, 1] for taking an early exit.
    """

    def __init__(self, num_classes=10, train_flag=True, conf_score=0.8):
        super(VGG16, self).__init__()
        self.train_flag = train_flag
        # Honour the caller's threshold (it used to be hard-coded to 0.8).
        self.conf_score = conf_score
        self.num_classes = num_classes
        self.features1 = nn.Sequential(
            # 1
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 2
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
        )
        # 128 channels * 16 * 16 after one 2x2 max-pool of a 32x32 input
        self.early_exit1 = early_exit_Branch(128, 128, fc_inp=32768, class_nums=num_classes)
        self.features2 = nn.Sequential(
            # 4
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 5
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 6
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True)
        )
        # 256 channels * 8 * 8
        self.early_exit2 = early_exit_Branch(256, 256, fc_inp=16384, class_nums=num_classes)
        self.features3 = nn.Sequential(
            # 7
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 8
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 9
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        # 512 channels * 4 * 4
        self.early_exit3 = early_exit_Branch(512, 512, fc_inp=8192, class_nums=num_classes)
        self.features4 = nn.Sequential(
            # 10
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 11
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 12
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 13
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.AvgPool2d(kernel_size=1, stride=1),  # 1x1 average pool: identity
        )
        # 512 channels * 1 * 1
        self.early_exit4 = early_exit_Branch(512, 512, fc_inp=512, class_nums=num_classes)
        self.classifier = nn.Sequential(
            # 14
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            # 15
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            # 16
            nn.Linear(4096, num_classes),
        )

    def decision_model(self, x):
        """Return the batch-averaged confidence of the logits ``x``.

        Per sample the confidence is ``1 - H(softmax(x)) / log(num_classes)``:
        1 for a one-hot prediction, 0 for a uniform one. The per-sample values
        are averaged over dim 0 and returned as a 0-dim tensor.

        Args:
            x: logits of shape (batch, num_classes), classes on dim 1.
        """
        # log_softmax instead of log(softmax(x)): when a probability underflows
        # to 0, log() yields -inf and 0 * -inf = NaN, which made every
        # ``score >= conf_score`` comparison False.
        log_probs = torch.log_softmax(x, dim=1)
        neg_entropy = (log_probs.exp() * log_probs).sum(dim=1)
        confidence = 1 + neg_entropy / np.log(self.num_classes)
        return confidence.mean()

    def turn_early_exixt(self, flag=False):
        """Set ``train_flag``.

        ``True`` -> ``forward`` runs all heads and returns all logits (no early
        exit); ``False`` -> early-exit inference.
        """
        self.train_flag = flag

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    turn_early_exit = turn_early_exixt

    def forward(self, x):
        """Run the network; see the class docstring for the two return forms."""
        stages = (
            (self.features1, self.early_exit1),
            (self.features2, self.early_exit2),
            (self.features3, self.early_exit3),
            (self.features4, self.early_exit4),
        )
        out = x
        if self.train_flag:
            # Run every stage and every exit head and return all of their logits.
            early_outs = []
            for features, exit_branch in stages:
                out = features(out)
                early_outs.append(exit_branch(out))
            out = self.classifier(out.view(out.size(0), -1))
            return (out, *early_outs)

        # Early-exit inference: stop at the first head that is confident enough.
        # The decision is taken for the whole batch, since decision_model averages over it.
        for stage_idx, (features, exit_branch) in enumerate(stages, start=1):
            out = features(out)
            early_out = exit_branch(out)
            if self.decision_model(early_out) >= self.conf_score:
                print("{}出".format(stage_idx))
                return early_out
        print("5出")
        return self.classifier(out.view(out.size(0), -1))
# Build the model on the selected device (GPU when available).
net = VGG16().to(device)
# Cross-entropy loss (applied to every head during training) and SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [5]:
# Sanity check with train_flag=True (the constructor default): forward returns
# (final, exit1, exit2, exit3, exit4); show the shape of the final head's logits.
examples = iter(test_loader)
# DataLoader iterators implement __next__; the .next() method is gone in recent PyTorch.
example_data, _ = next(examples)
was_training = net.training
net.eval()  # don't let a test batch update BatchNorm running statistics
with torch.no_grad():
    example_outputs = net(example_data.to(device))
net.train(was_training)
example_outputs[0].shape

torch.Size([10, 10])

In [6]:
# Enable early-exit inference (train_flag=False): forward returns only the logits
# of the exit that was taken and prints which exit it was.
net.turn_early_exixt(False)
examples = iter(test_loader)
# DataLoader iterators implement __next__; the .next() method is gone in recent PyTorch.
example_data, _ = next(examples)
was_training = net.training
net.eval()  # inference mode: fixed BatchNorm statistics, no dropout
with torch.no_grad():
    early_exit_logits = net(example_data.to(device))
net.train(was_training)
early_exit_logits.shape

5出


torch.Size([10, 10])

In [7]:
# Back to train_flag=True: forward returns (final, exit1, exit2, exit3, exit4).
net.turn_early_exixt(True)
examples = iter(test_loader)
# DataLoader iterators implement __next__; the .next() method is gone in recent PyTorch.
example_data, _ = next(examples)
was_training = net.training
net.eval()  # don't let a test batch update BatchNorm running statistics
with torch.no_grad():
    example_outputs = net(example_data.to(device))
net.train(was_training)
example_outputs[0].shape

torch.Size([10, 10])

In [14]:
num_epochs = 10
n_total_steps = len(train_loader)
LossList = []  # per epoch: mean over steps of the summed loss of all 5 heads
AccuryList = []  # per epoch: test accuracy (%) of the final classifier
for epoch in range(num_epochs):
    # -------- training --------
    net.train()
    totalLoss = 0
    # train_flag=True: forward returns all 5 heads' logits; no early exit is taken.
    net.turn_early_exixt(True)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = net(images)
        optimizer.zero_grad()
        # Sum the cross-entropy of every head and back-propagate once. The gradients
        # equal those of one backward() per head, but the graph no longer has to be
        # retained and the shared backbone is traversed once instead of five times.
        loss = sum(criterion(output, labels) for output in outputs)
        loss.backward()
        optimizer.step()
        totalLoss += loss.item()
        if (i + 1) % 1000 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, n_total_steps, totalLoss / (i + 1)))
    # n_total_steps instead of the loop variable, which is undefined for an empty loader.
    LossList.append(totalLoss / max(n_total_steps, 1))

    # -------- evaluation --------
    net.eval()
    # train_flag is still True, so outputs[0] is the final classifier's logits.
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs[0], dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        acc = 100.0 * correct / total  # accuracy on the whole test set
        AccuryList.append(acc)
        print('Accuracy of the network on the {} test images: {} %'.format(total, acc))
print("模型训练完成")

Epoch [1/10], Step [1000/5000], Loss: 0.0358
Epoch [1/10], Step [2000/5000], Loss: 0.0431
Epoch [1/10], Step [3000/5000], Loss: 0.0473
Epoch [1/10], Step [4000/5000], Loss: 0.0450
Epoch [1/10], Step [5000/5000], Loss: 0.0437
8677
Accuracy of the network on the 10000 test images: 86.77 %
Epoch [2/10], Step [1000/5000], Loss: 0.0131
Epoch [2/10], Step [2000/5000], Loss: 0.0132
Epoch [2/10], Step [3000/5000], Loss: 0.0148
Epoch [2/10], Step [4000/5000], Loss: 0.0217
Epoch [2/10], Step [5000/5000], Loss: 0.0294
8596
Accuracy of the network on the 10000 test images: 85.96 %
Epoch [3/10], Step [1000/5000], Loss: 0.0341
Epoch [3/10], Step [2000/5000], Loss: 0.0357
Epoch [3/10], Step [3000/5000], Loss: 0.0345
Epoch [3/10], Step [4000/5000], Loss: 0.0338
Epoch [3/10], Step [5000/5000], Loss: 0.0350
8647
Accuracy of the network on the 10000 test images: 86.47 %
Epoch [4/10], Step [1000/5000], Loss: 0.0335
Epoch [4/10], Step [2000/5000], Loss: 0.0358
Epoch [4/10], Step [3000/5000], Loss: 0.0421
E

In [13]:
# Pickle the entire module (architecture + weights).
# NOTE(review): loading this file requires the VGG16 / early_exit_Branch classes to be
# importable under the same names; saving net.state_dict() would be more portable.
torch.save(obj=net, f='FlexVGG16.pkl')