In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import d2lzh_pytorch as d2l

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])


class Inception(nn.Module):
    # c1-c4为每条线路里的层的输出通道数
    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        # 线路1，单1*1卷积层
        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)
        # 线路2,1*1卷积层后接3*3卷积层
        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 线路3，1*1卷积层后接5*5卷积层
        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 线路4,3*3最大池化后接1*1卷积
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(self.p2_1(x)))
        p3 = F.relu(self.p3_2(self.p3_1(x)))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # 在通道维上连接输出
        return torch.cat((p1, p2, p3, p4), dim=1)

In [3]:
b1 = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

In [4]:
b2 = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(64, 192, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

In [5]:
b3 = nn.Sequential(
    Inception(192, 64, (96, 128), (16, 32), 32),
    Inception(256, 128, (128, 192), (32, 96), 64),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

In [6]:
b4 = nn.Sequential(
    Inception(480, 192, (96, 208), (16, 48), 64),
    Inception(512, 160, (112, 224), (24, 64), 64),
    Inception(512, 128, (128, 256), (24, 64), 64),
    Inception(512, 112, (144, 288), (32, 64), 64),
    Inception(528, 256, (160, 320), (32, 128), 128),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)

In [7]:
b5 = nn.Sequential(
    Inception(832, 256, (160, 320), (32, 128), 128),
    Inception(832, 384, (192, 384), (48, 128), 128),
    GlobalAvgPool2d()
)
net = nn.Sequential(
    b1, b2, b3, b4, b5,
    d2l.FlattenLayer(),
    nn.Linear(1024, 10)
)

In [8]:
X = torch.rand(1, 1, 96, 96)
for blk in net.children():
    X = blk(X)
    print('output shape: ', X.shape)

output shape:  torch.Size([1, 64, 24, 24])
output shape:  torch.Size([1, 192, 12, 12])
output shape:  torch.Size([1, 480, 6, 6])
output shape:  torch.Size([1, 832, 3, 3])
output shape:  torch.Size([1, 1024, 1, 1])
output shape:  torch.Size([1, 1024])
output shape:  torch.Size([1, 10])


In [10]:
import os

path = os.path.dirname(os.getcwd())
path += '\data\FashionMNIST'
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96, root=path)
lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\train-images-idx3-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\train-labels-idx1-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to D:\python_code\pycharm_pro\DiveIntoDL\data\FashionMNIST\FashionMNIST\raw

training on  cuda
epoch 1, loss 2.1052, train acc 0.170, test acc 0.481, time 34.5 sec
epoch 2, loss 0.7674, train acc 0.700, test acc 0.799, time 33.5 sec
epoch 3, loss 0.4596, train acc 0.829, test acc 0.844, time 37.2 sec
epoch 4, loss 0.3832, train acc 0.856, test acc 0.852, time 41.9 sec
epoch 5, loss 0.3355, train acc 0.874, test acc 0.878, time 39.5 sec
