# GoogLeNet基于MNIST数据集的实现

鉴于ImageNet数据集较为庞大，而本人设备的计算能力有限，且本文的主要目的是学习GoogLeNet的思想，这里参考其他大佬的代码仿真一下GoogLeNet框架，并直接通过torchvision获取MNIST数据集。相关代码如下：

In [1]:
# 导入模块
import torch.nn as nn
import torch
from torchvision import transforms, datasets
import torchvision
import json
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import numpy as np
import torch.optim as optim
import time
from torch.optim import lr_scheduler
from sklearn.metrics import confusion_matrix
import itertools
import random

In [2]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Number of passes over the training set.
epochs = 50

# Per-epoch metric histories, appended to by the training loop and plotted
# at the end of the notebook.
train_acc, test_acc = [], []
train_loss, test_loss = [], []

In [4]:
# Class names for the ten MNIST digits, indexed by class id.
classes = tuple(str(digit) for digit in range(10))

In [5]:
# Preprocessing pipelines keyed by split. The single-channel MNIST digits are
# upscaled to the 224x224 input size used by the GoogLeNet defined below.
data_transform = {
    # No horizontal flip: mirroring is not label-preserving for digits
    # (a mirrored 2, 3 or 7 is no longer that digit).
    # The crop scale is bounded below at 60% of the image area; the default
    # (0.08, 1.0) range can crop away most of the digit.
    "train": transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.50,), (0.50,))]),
    # Deterministic pipeline for evaluation: resize, then centre crop.
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize((0.50,), (0.50,))])}

In [6]:
# MNIST train/test splits stored under ../data (downloaded if missing);
# each split is wrapped with its matching preprocessing pipeline.
train_dataset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=data_transform["train"],
)
validate_dataset = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=data_transform["val"],
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz


102.8%


Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


112.7%

Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw




  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [7]:
batch_size = 128
train_num = len(train_dataset)
val_num = len(validate_dataset)
# With a fixed 8 workers the DataLoader emitted a cpuset_checked warning
# (suggested maximum exceeded). Cap the worker count at the number of CPUs;
# os.cpu_count() may return None, hence the fallback to 1.
num_workers = min(8, os.cpu_count() or 1)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=num_workers)

  cpuset_checked))


In [8]:
# Plot accuracy curves
def plot_acc_curves(array1, array2):
    """Plot per-epoch train (red) and test (blue) accuracy, save and show it.

    Args:
        array1: train accuracy values, one per recorded epoch.
        array2: test accuracy values, one per recorded epoch.

    The figure is saved as "acc_curves" (matplotlib's default save format),
    shown, and then closed.
    """
    fig = plt.figure(figsize=(10, 10))
    # Take x positions from each series' own length instead of the global
    # `epochs`, so an interrupted run (fewer recorded epochs) still plots.
    plt.plot(np.arange(1, len(array1) + 1), array1, color='r', label='Train_accuracy')
    plt.plot(np.arange(1, len(array2) + 1), array2, color='b', label='Test_accuracy')
    plt.legend()
    plt.title('accuracy of train and test sets in different epoch')

    plt.xlabel('epoch')
    plt.ylabel('accuracy: ')
    plt.savefig("acc_curves")
    plt.show()
    # Close this specific figure. plt.clf() would act on whatever figure is
    # current after show() (possibly a freshly created blank one), and it
    # would never release this figure.
    plt.close(fig)

In [9]:
# Plot confusion matrix
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Print a confusion matrix and save it as an annotated heat map.

    Args:
        cm: square matrix of counts, rows = true class, columns = predicted.
        classes: tick labels, one per class (row/column order).
        normalize: if True, divide each row by its total before plotting.
            Rows with no samples are left as zeros instead of becoming NaN.
        title: plot title.
        cmap: matplotlib colormap for the heat map.

    The figure is saved as "confusion_matrix" and then closed (not shown).
    """
    fig = plt.figure(figsize=(10, 10))
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Guard against classes with no true samples (0 / 0 -> nan).
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Use white text on dark (high-value) cells and black text elsewhere.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are accounted for.
    plt.tight_layout()
    plt.savefig("confusion_matrix")
    plt.close(fig)

In [10]:
# Collect model predictions over a whole loader
@torch.no_grad()
def get_all_preds(model, loader, target_device=None):
    """Run `model` over every batch of `loader` and stack the raw outputs.

    The model is put in eval mode for the pass and its previous train/eval
    mode is restored afterwards. In training mode the GoogLeNet in this file
    returns a tuple (main, aux2, aux1) instead of a single tensor, and dropout
    would be active.

    Args:
        model: module mapping a batch of images to per-class scores.
        loader: iterable of (images, labels) pairs; labels are ignored.
        target_device: device to run on. Defaults to the notebook-level
            `device` when None.

    Returns:
        Tensor with one row of model output per sample, concatenated along
        dim 0, on the run device. An empty tensor if `loader` yields nothing.
    """
    run_device = device if target_device is None else target_device
    model.to(run_device)
    was_training = model.training
    model.eval()
    try:
        batch_preds = [model(images.to(run_device)) for images, _ in loader]
    finally:
        model.train(was_training)
    if not batch_preds:
        return torch.tensor([]).to(run_device)
    # A single concatenation avoids re-copying the growing tensor every batch.
    return torch.cat(batch_preds, dim=0)

In [11]:
# Plot a grid of (misclassified) validation images
def plot_misclf_imgs(candidates, gts_np, preds_np, classes):
    """Save a 5x5 grid of validation images titled with true/predicted labels.

    Args:
        candidates: indices into the global `validate_dataset` to display.
            At most the first 25 are used; unused grid cells stay blank.
        gts_np: ground-truth class index per validation sample.
        preds_np: predicted class index per validation sample.
        classes: class names indexed by class index.

    The figure is saved as "misclf_imgs" and then closed (not shown).
    """
    size_figure_grid = 5  # a grid of 5 by 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(20, 20))

    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    # Slicing tolerates fewer than 25 candidates (indexing would raise).
    for k, idx in enumerate(candidates[:size_figure_grid * size_figure_grid]):
        i, j = divmod(k, size_figure_grid)
        # Dataset items are (image, label); show channel 0 of the transformed image.
        img = validate_dataset[idx][0].numpy()[0]
        ax[i, j].imshow(img, cmap='gray')
        ax[i, j].set_title("Label:" + str(classes[gts_np[idx]]), loc='left')
        ax[i, j].set_title("Predict:" + str(classes[preds_np[idx]]), loc='right')

    plt.savefig("misclf_imgs")
    plt.close(fig)

In [12]:
# Plot loss curves
def plot_loss_curves(array1, array2):
    """Plot per-epoch train (red) and test (blue) loss, save and show it.

    Args:
        array1: mean train loss values, one per recorded epoch.
        array2: mean test loss values, one per recorded epoch.

    The figure is saved as "loss_curves" (matplotlib's default save format),
    shown, and then closed.
    """
    fig = plt.figure(figsize=(10, 10))
    # Take x positions from each series' own length instead of the global
    # `epochs`, so an interrupted run (fewer recorded epochs) still plots.
    plt.plot(np.arange(1, len(array1) + 1), array1, color='r', label='Train_loss')
    plt.plot(np.arange(1, len(array2) + 1), array2, color='b', label='Test_loss')
    plt.legend()
    plt.title('loss of train and test sets in different epoch')

    plt.xlabel('epoch')
    plt.ylabel('loss: ')
    plt.savefig("loss_curves")
    plt.show()
    # Close this specific figure rather than clf()-ing whatever is current.
    plt.close(fig)

In [15]:
# Inception module definition
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along dim 1.

    Branches, in output-channel order:
      1. 1x1 conv -> ch1x1
      2. 1x1 conv -> ch3x3red, then 3x3 conv (padding 1) -> ch3x3
      3. 1x1 conv -> ch5x5red, then 5x5 conv (padding 2) -> ch5x5
      4. 3x3 max-pool (stride 1, padding 1), then 1x1 conv -> pool_proj

    All branches preserve the input's spatial size, so their outputs can be
    concatenated; the result has ch1x1 + ch3x3 + ch5x5 + pool_proj channels.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()

        # 1x1 branch
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        # 1x1 reduction followed by 3x3
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )
        # 1x1 reduction followed by 5x5
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
        )
        # max-pool followed by 1x1 projection
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        """Apply each branch to `x` and concatenate the results channel-wise."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: avg-pool(5, stride 3) -> 1x1 conv to 128 channels -> flatten
    -> dropout(0.5) -> Linear(2048, 1024) + ReLU -> dropout(0.5)
    -> Linear(1024, num_classes). Dropout is active only in training mode.

    fc1 expects 2048 = 128 * 4 * 4 features, i.e. a 4x4 map after pooling
    (a 14x14 input map gives this). Assumes the network input size is chosen
    accordingly.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return auxiliary class logits for feature map `x`."""
        flat = torch.flatten(self.conv(self.averagePool(x)), 1)
        hidden = F.relu(self.fc1(F.dropout(flat, 0.5, training=self.training)), inplace=True)
        return self.fc2(F.dropout(hidden, 0.5, training=self.training))

class BasicConv2d(nn.Module):
    """A 2-D convolution followed by an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are passed
    through to `nn.Conv2d` unchanged.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Convolve `x`, then rectify."""
        return self.relu(self.conv(x))


In [16]:
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) taking single-channel images.

    Args:
        num_classes: output size of the main and auxiliary classifiers.
        aux_logits: if True, build two auxiliary classifiers (after
            inception4a and inception4d). In training mode, forward then
            returns (main, aux2, aux1) instead of the main logits alone.
        init_weights: if True, re-initialise conv and linear layers with
            `_initialize_weights`.
    """

    def __init__(self, num_classes=10, aux_logits=True, init_weights=False):
        super().__init__()
        self.aux_logits = aux_logits

        # Stem: 1 input channel -> 192 channels, with stride-2 downsampling.
        self.conv1 = BasicConv2d(1, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 3
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 4
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 5
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary heads take the 512-channel (4a) and 528-channel (4d) outputs.
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Classifier: global average pool -> dropout -> linear.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return main logits, or (main, aux2, aux1) when training with aux heads."""
        collect_aux = self.training and self.aux_logits

        for layer in (self.conv1, self.maxpool1, self.conv2, self.conv3,
                      self.maxpool2, self.inception3a, self.inception3b,
                      self.maxpool3, self.inception4a):
            x = layer(x)
        aux1 = self.aux1(x) if collect_aux else None

        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x) if collect_aux else None

        for layer in (self.inception4e, self.maxpool4,
                      self.inception5a, self.inception5b):
            x = layer(x)

        pooled = torch.flatten(self.avgpool(x), 1)
        logits = self.fc(self.dropout(pooled))

        if collect_aux:
            return logits, aux2, aux1
        return logits

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, relu) for convs, N(0, 0.01) for linears; zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

In [17]:
# Model with auxiliary heads and explicit weight init, moved to `device`;
# cross-entropy loss, Adam, and a StepLR schedule that multiplies the
# learning rate by 0.1 every 7 scheduler steps.
net = GoogLeNet(num_classes=10, aux_logits=True, init_weights=True).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.005)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)

In [18]:
# Train for `epochs` epochs. After each epoch, evaluate on the validation set
# and checkpoint the weights with the best validation accuracy to save_path.
best_acc = 0.0
save_path = 'googleNet.pth'
since = time.time()
for epoch in range(epochs):
    # ---- training ----
    net.train()
    running_loss = 0.0
    running_corrects = 0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # In training mode the net returns (main, aux2, aux1) logits.
        logits, aux_logits2, aux_logits1 = net(images)
        _, predict_y = torch.max(logits, dim=1)
        loss0 = loss_function(logits, labels)
        loss1 = loss_function(aux_logits1, labels)
        loss2 = loss_function(aux_logits2, labels)
        # Auxiliary losses are down-weighted by 0.3.
        loss = loss0 + loss1 * 0.3 + loss2 * 0.3
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_corrects += (predict_y == labels).sum().item()
        # In-place text progress bar.
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()
    # The StepLR schedule only takes effect if it is stepped once per epoch.
    exp_lr_scheduler.step()
    # Train accuracy is the running accuracy measured during the epoch
    # (training mode), recorded once per epoch.
    accurate_train = running_corrects / train_num
    train_loss.append(running_loss / len(train_loader))
    train_acc.append(accurate_train)

    # ---- validation ----
    net.eval()
    acc = 0.0
    Loss_val = 0.0
    with torch.no_grad():
        for test_images, test_labels in validate_loader:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            outputs = net(test_images)
            Loss_val += loss_function(outputs, test_labels).item()
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels).sum().item()
    accurate_test = acc / val_num
    test_acc.append(accurate_test)
    test_loss.append(Loss_val / len(validate_loader))
    if accurate_test > best_acc:
        best_acc = accurate_test
        torch.save(net.state_dict(), save_path)
    print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
          (epoch + 1, running_loss / len(train_loader), accurate_test))
time_elapsed = time.time() - since

In [None]:
# Validation-set predictions, used for the confusion matrix and to pick
# misclassified samples to display.
test_preds = get_all_preds(net, validate_loader).cpu()
gts = validate_dataset.targets
preds = test_preds.argmax(dim=1)
gts_np = np.array(gts)
preds_np = np.array(preds)
mis_idxes = list(np.where(gts_np != preds_np)[0])
# random.sample raises ValueError when asked for more items than exist.
candidates = random.sample(mis_idxes, min(25, len(mis_idxes)))
cm = confusion_matrix(gts_np, preds_np)

plot_confusion_matrix(cm, classes)
print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

In [None]:
# Plot train vs. test accuracy per epoch
plot_acc_curves(array1=train_acc, array2=test_acc)

In [None]:
# Plot train vs. test loss per epoch
plot_loss_curves(array1=train_loss, array2=test_loss)

In [None]:
# Plot a sample of misclassified validation images
plot_misclf_imgs(candidates=candidates, gts_np=gts_np, preds_np=preds_np, classes=classes)